Compare commits

..

51 Commits

Author SHA1 Message Date
Kevin Wan
57b73d8b49 make sure offset less than size even it's checked inside (#354) 2021-01-05 16:06:36 +08:00
Kevin Wan
a79cee12ee add godoc for RollingWindow (#351) 2021-01-04 22:43:55 +08:00
zjbztianya
7a921f66e6 simple rolling windows code (#346) 2021-01-04 22:11:18 +08:00
kingxt
12e235efb0 optimized goctl format (#336)
* fix format

* refactor

* refactor

* optimized

* refactor

* refactor

* refactor

* add js path prefix
2021-01-04 18:59:48 +08:00
Kevin Wan
01060cf16d close issue of #337 (#347) 2021-01-04 16:36:27 +08:00
Kevin Wan
0786862a35 align bucket boundary to interval in rolling window (#345) 2021-01-04 11:17:59 +08:00
Kevin Wan
efa43483b2 fix potential data race in PeriodicalExecutor (#344)
* fix potential data race in PeriodicalExecutor

* add comment
2021-01-03 20:56:17 +08:00
Kevin Wan
771371e051 simplify rolling window code, and make tests run faster (#343) 2021-01-03 20:47:29 +08:00
zjbztianya
2ee95f8981 fix rolling window bug (#340) 2021-01-03 20:27:47 +08:00
Kevin Wan
5bc01e4bfd set guarded to false only on quitting background flush (#342)
* set guarded to false only on quitting background flush

* set guarded to false only on quitting background flush, cont.
2021-01-03 19:54:11 +08:00
Kevin Wan
510e966982 simplify periodical executor background routine (#339) 2021-01-03 14:02:51 +08:00
Kevin Wan
10e3b8ac80 optimize code that fixes issue #317 (#338) 2021-01-02 19:01:37 +08:00
Kevin Wan
04059bbf5a add discord chat group in readme 2021-01-02 18:35:33 +08:00
weibobo
d643007c79 fix bug #317 (#335)
* fix bug #317.
* add counter for current task. If it's bigger then zero, do not quit background thread

* Revert "fix issue #317 (#331)"

This reverts commit fc43876cc5.
2021-01-02 18:04:04 +08:00
Kevin Wan
fc43876cc5 fix issue #317 (#331) 2021-01-01 13:24:28 +08:00
FengZhang
a926cb514f modify the goctl gensvc template (#323) 2020-12-30 10:05:26 +08:00
kingxt
25cab2f273 Java (#327)
* add g4 file

* new define api by g4

* reactor parser to g4gen

* add syntax parser & test

* add syntax parser & test

* add syntax parser & test

* update g4 file

* add import parse & test

* ractor AT lexer

* panic with error

* revert AT

* update g4 file

* update g4 file

* update g4 file

* optimize parser

* update g4 file

* parse info

* optimized java generator

* revert

* optimize java generator

* update java generator

* update java generator

* update java generator

* update java generator

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2020-12-29 17:50:41 +08:00
Kevin Wan
8d2e2753a2 simplify http.Flusher implementation (#326)
* simplify code with http.Flusher type conversion

* simplify code with http.Flusher type conversion, better version
2020-12-29 15:02:36 +08:00
Kevin Wan
cc4c50e3eb fix broken link. 2020-12-29 11:54:32 +08:00
Kevin Wan
751072bdb0 fix broken doc link 2020-12-29 11:52:55 +08:00
Kevin Wan
e97e1f10db simplify code with http.Flusher type conversion (#325)
* simplify code with http.Flusher type conversion

* simplify code with http.Flusher type conversion, better version
2020-12-29 10:25:55 +08:00
jichangyun
0bd2a0656c The ResponseWriters defined in rest.handler add Flush interface. (#318) 2020-12-28 21:30:24 +08:00
Kevin Wan
71a2b20301 add more tests for prof (#322) 2020-12-27 14:45:14 +08:00
Kevin Wan
8df7de94e3 add more tests for zrpc (#321) 2020-12-27 14:08:24 +08:00
Kevin Wan
bf21203297 add more tests (#320) 2020-12-27 12:26:31 +08:00
Kevin Wan
ae98375194 add more tests (#319) 2020-12-26 20:30:02 +08:00
Kevin Wan
82d1ccf376 fixes #286 (#315) 2020-12-25 19:47:27 +08:00
Kevin Wan
bb6d49c17e add go report card back (#313)
* add go report card back

* avoid test failure, run tests sequentially
2020-12-25 12:09:59 +08:00
Kevin Wan
ed735ec47c Update codeql-analysis.yml
disable python code analysis, python code is in examples.
2020-12-25 12:09:43 +08:00
Kevin Wan
ba4bac3a03 format code (#312) 2020-12-25 11:53:37 +08:00
FengZhang
08433d7e04 add config load support env var (#309) 2020-12-25 11:42:19 +08:00
anqiansong
a3b525b50d feature model fix (#296)
* add raw stirng quote for sql field

* remove unused code
2020-12-21 09:43:32 +08:00
Kevin Wan
097f6886f2 Update readme.md 2020-12-15 23:47:41 +08:00
Kevin Wan
07a1549634 add wechat micro practice qrcode image (#289) 2020-12-14 17:49:58 +08:00
Kevin Wan
befca26c58 Update readme.md
add goproxy.cn download badge
2020-12-13 00:02:32 +08:00
Kevin Wan
3556a2eef4 Update readme-en.md
goreportcard is not working, submitted an issue to them.
2020-12-12 23:40:26 +08:00
Kevin Wan
807765f77e Update readme.md
goreportcard is not working, submitted a issue to them.
2020-12-12 23:39:28 +08:00
Kevin Wan
e44584e549 Create codeql-analysis.yml 2020-12-12 23:01:15 +08:00
Kevin Wan
acd48f0abb optimize dockerfile generation (#284) 2020-12-12 16:53:06 +08:00
kingxt
f919bc6713 refactor (#283) 2020-12-12 11:18:22 +08:00
Kevin Wan
a0030b8f45 format dockerfile on non-chinese mode (#282) 2020-12-12 10:13:33 +08:00
Kevin Wan
a5f0cce1b1 Update readme-en.md 2020-12-12 09:06:09 +08:00
Kevin Wan
4d13dda605 add EXPOSE in dockerfile generation (#281) 2020-12-12 08:18:01 +08:00
songmeizi
b56cc8e459 optimize test case of TestRpcGenerate (#279)
Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2020-12-11 21:57:04 +08:00
Kevin Wan
c435811479 fix gocyclo warnings (#278) 2020-12-11 20:57:48 +08:00
Kevin Wan
c686c93fb5 fix dockerfile generation bug (#277) 2020-12-11 20:31:31 +08:00
Kevin Wan
da8f76e6bd add category docker & kube (#276) 2020-12-11 18:53:40 +08:00
Kevin Wan
99596a4149 fix issue #266 (#275)
* optimize dockerfile

* fix issue #266
2020-12-11 16:12:33 +08:00
wayne
ec2a9f2c57 fix tracelogger_test TestTraceLog (#271) 2020-12-10 17:04:57 +08:00
Kevin Wan
fd73ced6dc optimize dockerfile (#272) 2020-12-10 16:21:06 +08:00
Kevin Wan
5071736ab4 fmt code (#270) 2020-12-10 15:16:13 +08:00
75 changed files with 2265 additions and 387 deletions

67
.github/workflows/codeql-analysis.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ master ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ master ]
schedule:
- cron: '18 19 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
# Learn more:
# https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
steps:
- name: Checkout repository
uses: actions/checkout@v2
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v1
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# queries: ./path/to/local/query, your-org/your-repo/queries@main
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v1
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
# and modify them (or add more) to build your code if your project
# uses a compiled language
#- run: |
# make bootstrap
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v1

2
.gitignore vendored
View File

@@ -4,6 +4,7 @@
# Unignore all with extensions # Unignore all with extensions
!*.* !*.*
!**/Dockerfile !**/Dockerfile
!**/Makefile
# Unignore all dirs # Unignore all dirs
!*/ !*/
@@ -12,7 +13,6 @@
.idea .idea
**/.DS_Store **/.DS_Store
**/logs **/logs
!Makefile
# gitlab ci # gitlab ci
.cache .cache

View File

@@ -8,8 +8,10 @@ import (
) )
type ( type (
// RollingWindowOption let callers customize the RollingWindow.
RollingWindowOption func(rollingWindow *RollingWindow) RollingWindowOption func(rollingWindow *RollingWindow)
// RollingWindow defines a rolling window to calculate the events in buckets with time interval.
RollingWindow struct { RollingWindow struct {
lock sync.RWMutex lock sync.RWMutex
size int size int
@@ -17,10 +19,12 @@ type (
interval time.Duration interval time.Duration
offset int offset int
ignoreCurrent bool ignoreCurrent bool
lastTime time.Duration lastTime time.Duration // start time of the last bucket
} }
) )
// NewRollingWindow returns a RollingWindow that with size buckets and time interval,
// use opts to customize the RollingWindow.
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow { func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
if size < 1 { if size < 1 {
panic("size must be greater than 0") panic("size must be greater than 0")
@@ -38,6 +42,7 @@ func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOpt
return w return w
} }
// Add adds value to current bucket.
func (rw *RollingWindow) Add(v float64) { func (rw *RollingWindow) Add(v float64) {
rw.lock.Lock() rw.lock.Lock()
defer rw.lock.Unlock() defer rw.lock.Unlock()
@@ -45,6 +50,7 @@ func (rw *RollingWindow) Add(v float64) {
rw.win.add(rw.offset, v) rw.win.add(rw.offset, v)
} }
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) { func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
rw.lock.RLock() rw.lock.RLock()
defer rw.lock.RUnlock() defer rw.lock.RUnlock()
@@ -79,26 +85,18 @@ func (rw *RollingWindow) updateOffset() {
} }
offset := rw.offset offset := rw.offset
start := offset + 1
steps := start + span
var remainder int
if steps > rw.size {
remainder = steps - rw.size
steps = rw.size
}
// reset expired buckets // reset expired buckets
for i := start; i < steps; i++ { for i := 0; i < span; i++ {
rw.win.resetBucket(i) rw.win.resetBucket((offset + i + 1) % rw.size)
}
for i := 0; i < remainder; i++ {
rw.win.resetBucket(i)
} }
rw.offset = (offset + span) % rw.size rw.offset = (offset + span) % rw.size
rw.lastTime = timex.Now() now := timex.Now()
// align to interval time boundary
rw.lastTime = now - (now-rw.lastTime)%rw.interval
} }
// Bucket defines the bucket that holds sum and num of additions.
type Bucket struct { type Bucket struct {
Sum float64 Sum float64
Count int64 Count int64
@@ -144,6 +142,7 @@ func (w *window) resetBucket(offset int) {
w.buckets[offset%w.size].reset() w.buckets[offset%w.size].reset()
} }
// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
func IgnoreCurrentBucket() RollingWindowOption { func IgnoreCurrentBucket() RollingWindowOption {
return func(w *RollingWindow) { return func(w *RollingWindow) {
w.ignoreCurrent = true w.ignoreCurrent = true

View File

@@ -105,6 +105,37 @@ func TestRollingWindowReduce(t *testing.T) {
} }
} }
func TestRollingWindowBucketTimeBoundary(t *testing.T) {
const size = 3
interval := time.Millisecond * 30
r := NewRollingWindow(size, interval)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
}
assert.Equal(t, []float64{0, 0, 0}, listBuckets())
r.Add(1)
assert.Equal(t, []float64{0, 0, 1}, listBuckets())
time.Sleep(time.Millisecond * 45)
r.Add(2)
r.Add(3)
assert.Equal(t, []float64{0, 1, 5}, listBuckets())
// sleep time should be less than interval, and make the bucket change happen
time.Sleep(time.Millisecond * 20)
r.Add(4)
r.Add(5)
r.Add(6)
assert.Equal(t, []float64{1, 5, 15}, listBuckets())
time.Sleep(time.Millisecond * 100)
r.Add(7)
r.Add(8)
r.Add(9)
assert.Equal(t, []float64{0, 0, 24}, listBuckets())
}
func TestRollingWindowDataRace(t *testing.T) { func TestRollingWindowDataRace(t *testing.T) {
const size = 3 const size = 3
r := NewRollingWindow(size, duration) r := NewRollingWindow(size, duration)

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"os"
"path" "path"
"github.com/tal-tech/go-zero/core/mapping" "github.com/tal-tech/go-zero/core/mapping"
@@ -19,7 +20,7 @@ func LoadConfig(file string, v interface{}) error {
if content, err := ioutil.ReadFile(file); err != nil { if content, err := ioutil.ReadFile(file); err != nil {
return err return err
} else if loader, ok := loaders[path.Ext(file)]; ok { } else if loader, ok := loaders[path.Ext(file)]; ok {
return loader(content, v) return loader([]byte(os.ExpandEnv(string(content))), v)
} else { } else {
return fmt.Errorf("unrecoginized file type: %s", file) return fmt.Errorf("unrecoginized file type: %s", file)
} }

View File

@@ -17,13 +17,14 @@ func TestConfigJson(t *testing.T) {
} }
text := `{ text := `{
"a": "foo", "a": "foo",
"b": 1 "b": 1,
"c": "${FOO}"
}` }`
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
t.Parallel() os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text) tmpfile, err := createTempFile(test, text)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile) defer os.Remove(tmpfile)
@@ -31,10 +32,12 @@ func TestConfigJson(t *testing.T) {
var val struct { var val struct {
A string `json:"a"` A string `json:"a"`
B int `json:"b"` B int `json:"b"`
C string `json:"c"`
} }
MustLoad(tmpfile, &val) MustLoad(tmpfile, &val)
assert.Equal(t, "foo", val.A) assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B) assert.Equal(t, 1, val.B)
assert.Equal(t, "2", val.C)
}) })
} }
} }

View File

@@ -3,6 +3,7 @@ package executors
import ( import (
"reflect" "reflect"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/tal-tech/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
@@ -35,6 +36,7 @@ type (
// avoid race condition on waitGroup when calling wg.Add/Done/Wait(...) // avoid race condition on waitGroup when calling wg.Add/Done/Wait(...)
wgBarrier syncx.Barrier wgBarrier syncx.Barrier
confirmChan chan lang.PlaceholderType confirmChan chan lang.PlaceholderType
inflight int32
guarded bool guarded bool
newTicker func(duration time.Duration) timex.Ticker newTicker func(duration time.Duration) timex.Ticker
lock sync.Mutex lock sync.Mutex
@@ -91,18 +93,16 @@ func (pe *PeriodicalExecutor) Wait() {
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) { func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
pe.lock.Lock() pe.lock.Lock()
defer func() { defer func() {
var start bool
if !pe.guarded { if !pe.guarded {
pe.guarded = true pe.guarded = true
start = true // defer to unlock quickly
defer pe.backgroundFlush()
} }
pe.lock.Unlock() pe.lock.Unlock()
if start {
pe.backgroundFlush()
}
}() }()
if pe.container.AddTask(task) { if pe.container.AddTask(task) {
atomic.AddInt32(&pe.inflight, 1)
return pe.container.RemoveAll(), true return pe.container.RemoveAll(), true
} }
@@ -111,6 +111,9 @@ func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool)
func (pe *PeriodicalExecutor) backgroundFlush() { func (pe *PeriodicalExecutor) backgroundFlush() {
threading.GoSafe(func() { threading.GoSafe(func() {
// flush before quit goroutine to avoid missing tasks
defer pe.Flush()
ticker := pe.newTicker(pe.interval) ticker := pe.newTicker(pe.interval)
defer ticker.Stop() defer ticker.Stop()
@@ -120,6 +123,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
select { select {
case vals := <-pe.commander: case vals := <-pe.commander:
commanded = true commanded = true
atomic.AddInt32(&pe.inflight, -1)
pe.enterExecution() pe.enterExecution()
pe.confirmChan <- lang.Placeholder pe.confirmChan <- lang.Placeholder
pe.executeTasks(vals) pe.executeTasks(vals)
@@ -129,13 +133,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
commanded = false commanded = false
} else if pe.Flush() { } else if pe.Flush() {
last = timex.Now() last = timex.Now()
} else if timex.Since(last) > pe.interval*idleRound { } else if pe.shallQuit(last) {
pe.lock.Lock()
pe.guarded = false
pe.lock.Unlock()
// flush again to avoid missing tasks
pe.Flush()
return return
} }
} }
@@ -178,3 +176,19 @@ func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
return true return true
} }
} }
func (pe *PeriodicalExecutor) shallQuit(last time.Duration) (stop bool) {
if timex.Since(last) <= pe.interval*idleRound {
return
}
// checking pe.inflight and setting pe.guarded should be locked together
pe.lock.Lock()
if atomic.LoadInt32(&pe.inflight) == 0 {
pe.guarded = false
stop = true
}
pe.lock.Unlock()
return
}

View File

@@ -140,6 +140,26 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
assert.Equal(t, total, cnt) assert.Equal(t, total, cnt)
} }
func TestPeriodicalExecutor_Deadlock(t *testing.T) {
executor := NewBulkExecutor(func(tasks []interface{}) {
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
for i := 0; i < 1e5; i++ {
executor.Add(1)
}
}
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
ticker := timex.NewFakeTicker()
defer ticker.Stop()
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
assert.False(t, exec.hasTasks(nil))
assert.True(t, exec.hasTasks(1))
}
// go test -benchtime 10s -bench . // go test -benchtime 10s -bench .
func BenchmarkExecutor(b *testing.B) { func BenchmarkExecutor(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()

View File

@@ -21,6 +21,7 @@ var mock tracespec.Trace = new(mockTrace)
func TestTraceLog(t *testing.T) { func TestTraceLog(t *testing.T) {
var buf mockWriter var buf mockWriter
atomic.StoreUint32(&initialized, 1)
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog) WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceId))

View File

@@ -153,58 +153,57 @@ func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fie
key := strings.TrimSpace(segments[0]) key := strings.TrimSpace(segments[0])
options := segments[1:] options := segments[1:]
if len(options) > 0 { if len(options) == 0 {
var fieldOpts fieldOptions return key, nil, nil
for _, segment := range options {
option := strings.TrimSpace(segment)
switch {
case option == stringOption:
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
} else {
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
}
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
} else {
fieldOpts.Default = strings.TrimSpace(segs[1])
}
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
}
if nr, err := parseNumberRange(segs[1]); err != nil {
return "", nil, err
} else {
fieldOpts.Range = nr
}
}
}
return key, &fieldOpts, nil
} }
return key, nil, nil var fieldOpts fieldOptions
for _, segment := range options {
option := strings.TrimSpace(segment)
switch {
case option == stringOption:
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
} else {
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
}
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
} else {
fieldOpts.Default = strings.TrimSpace(segs[1])
}
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
}
if nr, err := parseNumberRange(segs[1]); err != nil {
return "", nil, err
} else {
fieldOpts.Range = nr
}
}
}
return key, &fieldOpts, nil
} }
func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) { func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {

View File

@@ -0,0 +1,16 @@
package prof
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestReport(t *testing.T) {
once.Do(func() {})
assert.NotContains(t, generateReport(), "foo")
report("foo", time.Second)
assert.Contains(t, generateReport(), "foo")
report("foo", time.Second)
}

View File

@@ -0,0 +1,23 @@
package prof
import (
"testing"
"github.com/tal-tech/go-zero/core/utils"
)
func TestProfiler(t *testing.T) {
EnableProfiling()
Start()
Report("foo", ProfilePoint{
ElapsedTimer: utils.NewElapsedTimer(),
})
}
func TestNullProfiler(t *testing.T) {
p := newNullProfiler()
p.Start()
p.Report("foo", ProfilePoint{
ElapsedTimer: utils.NewElapsedTimer(),
})
}

View File

@@ -70,8 +70,6 @@ func (g *sharedGroup) createCall(key string) (c *call, done bool) {
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) { func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
defer func() { defer func() {
// delete key first, done later. can't reverse the order, because if reverse,
// another Do call might wg.Wait() without get notified with wg.Done()
g.lock.Lock() g.lock.Lock()
delete(g.calls, key) delete(g.calls, key)
g.lock.Unlock() g.lock.Unlock()

View File

@@ -129,7 +129,7 @@ go get -u github.com/tal-tech/go-zero
the .api files also can be generate by goctl, like below: the .api files also can be generate by goctl, like below:
```shell ```shell
goctl api -o greet.api goctl api -o greet.api
``` ```
3. generate the go server side code 3. generate the go server side code
@@ -208,3 +208,7 @@ goctl api -o greet.api
* [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md) * [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md)
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md) * [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md)
## 9. Chat group
Join the chat via https://discord.gg/4JQvC5A4Fe

View File

@@ -5,8 +5,9 @@
[English](readme-en.md) | 简体中文 [English](readme-en.md) | 简体中文
[![Go](https://github.com/tal-tech/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/tal-tech/go-zero/actions) [![Go](https://github.com/tal-tech/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/tal-tech/go-zero/actions)
[![codecov](https://codecov.io/gh/tal-tech/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/tal-tech/go-zero)
[![Go Report Card](https://goreportcard.com/badge/github.com/tal-tech/go-zero)](https://goreportcard.com/report/github.com/tal-tech/go-zero) [![Go Report Card](https://goreportcard.com/badge/github.com/tal-tech/go-zero)](https://goreportcard.com/report/github.com/tal-tech/go-zero)
[![goproxy](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)
[![codecov](https://codecov.io/gh/tal-tech/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/tal-tech/go-zero)
[![Release](https://img.shields.io/github/v/release/tal-tech/go-zero.svg?style=flat-square)](https://github.com/tal-tech/go-zero) [![Release](https://img.shields.io/github/v/release/tal-tech/go-zero.svg?style=flat-square)](https://github.com/tal-tech/go-zero)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
@@ -95,7 +96,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
[快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md) [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md) [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
1. 安装 goctl 工具 1. 安装 goctl 工具
@@ -162,7 +163,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* awesome 系列 * awesome 系列
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md) * [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md) * [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
* [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md) * [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md)
* [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md) * [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md)
* [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md) * [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md)
@@ -172,7 +173,13 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md) * [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md)
* [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md) * [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md)
## 8. 微信交流群 ## 8. 微信公众号
`go-zero` 相关文章都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注,也可以通过公众号私信我 👏
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat-micro.jpg" alt="wechat" width="300" />
## 9. 微信交流群
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。 如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。

171
rest/engine_test.go Normal file
View File

@@ -0,0 +1,171 @@
package rest
import (
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/conf"
)
func TestNewEngine(t *testing.T) {
yamls := []string{
`Name: foo
Port: 54321
`,
`Name: foo
Port: 54321
CpuThreshold: 500
`,
`Name: foo
Port: 54321
CpuThreshold: 500
Verbose: true
`,
}
routes := []featuredRoutes{
{
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
prevSecret: "thesecret",
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
PrivateKeys: []PrivateKeyConf{
{
Fingerprint: "a",
KeyFile: "b",
},
},
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
}
for _, yaml := range yamls {
for _, route := range routes {
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
ng := newEngine(cnf)
ng.AddRoutes(route)
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
}
})
assert.NotNil(t, ng.StartWithRouter(mockedRouter{}))
}
}
}
type mockedRouter struct {
}
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
func (m mockedRouter) Handle(method string, path string, handler http.Handler) error {
return errors.New("foo")
}
func (m mockedRouter) SetNotFoundHandler(handler http.Handler) {
}
func (m mockedRouter) SetNotAllowedHandler(handler http.Handler) {
}

View File

@@ -46,18 +46,18 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
parser := token.NewTokenParser() parser := token.NewTokenParser()
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret) tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
if err != nil { if err != nil {
unauthorized(w, r, err, authOpts.Callback) unauthorized(w, r, err, authOpts.Callback)
return return
} }
if !token.Valid { if !tok.Valid {
unauthorized(w, r, errInvalidToken, authOpts.Callback) unauthorized(w, r, errInvalidToken, authOpts.Callback)
return return
} }
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := tok.Claims.(jwt.MapClaims)
if !ok { if !ok {
unauthorized(w, r, errNoClaims, authOpts.Callback) unauthorized(w, r, errNoClaims, authOpts.Callback)
return return
@@ -122,6 +122,12 @@ func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
} }
} }
func (grw *guardedResponseWriter) Flush() {
if flusher, ok := grw.writer.(http.Flusher); ok {
flusher.Flush()
}
}
func (grw *guardedResponseWriter) Header() http.Header { func (grw *guardedResponseWriter) Header() http.Header {
return grw.writer.Header() return grw.writer.Header()
} }

View File

@@ -41,6 +41,10 @@ func TestAuthHandler(t *testing.T) {
w.Header().Set("X-Test", "test") w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content")) _, err := w.Write([]byte("content"))
assert.Nil(t, err) assert.Nil(t, err)
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
})) }))
resp := httptest.NewRecorder() resp := httptest.NewRecorder()

View File

@@ -83,6 +83,12 @@ func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
} }
} }
func (w *cryptionResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *cryptionResponseWriter) Header() http.Header { func (w *cryptionResponseWriter) Header() http.Header {
return w.ResponseWriter.Header() return w.ResponseWriter.Header()
} }

View File

@@ -87,3 +87,19 @@ func TestCryptionHandlerWriteHeader(t *testing.T) {
handler.ServeHTTP(recorder, req) handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
} }
func TestCryptionHandlerFlush(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", nil)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(respText))
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}

View File

@@ -38,6 +38,12 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
w.code = code w.code = code
} }
func (w *LoggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
func LogHandler(next http.Handler) http.Handler { func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer() timer := utils.NewElapsedTimer()
@@ -68,6 +74,10 @@ func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buff
} }
} }
func (w *DetailLoggedResponseWriter) Flush() {
w.writer.Flush()
}
func (w *DetailLoggedResponseWriter) Header() http.Header { func (w *DetailLoggedResponseWriter) Header() http.Header {
return w.writer.Header() return w.writer.Header()
} }

View File

@@ -30,6 +30,10 @@ func TestLogHandler(t *testing.T) {
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content")) _, err := w.Write([]byte("content"))
assert.Nil(t, err) assert.Nil(t, err)
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
})) }))
resp := httptest.NewRecorder() resp := httptest.NewRecorder()

View File

@@ -7,6 +7,12 @@ type WithCodeResponseWriter struct {
Code int Code int
} }
func (w *WithCodeResponseWriter) Flush() {
if flusher, ok := w.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *WithCodeResponseWriter) Header() http.Header { func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header() return w.Writer.Header()
} }

View File

@@ -0,0 +1,33 @@
package security
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWithCodeResponseWriter(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := &WithCodeResponseWriter{Writer: w}
cw.Header().Set("X-Test", "test")
cw.WriteHeader(http.StatusServiceUnavailable)
assert.Equal(t, cw.Code, http.StatusServiceUnavailable)
_, err := cw.Write([]byte("content"))
assert.Nil(t, err)
flusher, ok := http.ResponseWriter(cw).(http.Flusher)
assert.True(t, ok)
flusher.Flush()
})
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}

View File

@@ -64,7 +64,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
allow, ok := pr.methodNotAllowed(r.Method, reqPath) allows, ok := pr.methodsAllowed(r.Method, reqPath)
if !ok { if !ok {
pr.handleNotFound(w, r) pr.handleNotFound(w, r)
return return
@@ -73,7 +73,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if pr.notAllowed != nil { if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r) pr.notAllowed.ServeHTTP(w, r)
} else { } else {
w.Header().Set(allowHeader, allow) w.Header().Set(allowHeader, allows)
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} }
} }
@@ -94,7 +94,7 @@ func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
} }
} }
func (pr *patRouter) methodNotAllowed(method, path string) (string, bool) { func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
var allows []string var allows []string
for treeMethod, tree := range pr.trees { for treeMethod, tree := range pr.trees {

View File

@@ -1,7 +1,6 @@
package rest package rest
import ( import (
"errors"
"log" "log"
"net/http" "net/http"
@@ -24,6 +23,9 @@ type (
} }
) )
// MustNewServer returns a server with given config of c and options defined in opts.
// Be aware that later RunOption might overwrite previous one that write the same option.
// The process will exit if error occurs.
func MustNewServer(c RestConf, opts ...RunOption) *Server { func MustNewServer(c RestConf, opts ...RunOption) *Server {
engine, err := NewServer(c, opts...) engine, err := NewServer(c, opts...)
if err != nil { if err != nil {
@@ -33,11 +35,9 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
return engine return engine
} }
// NewServer returns a server with given config of c and options defined in opts.
// Be aware that later RunOption might overwrite previous one that write the same option.
func NewServer(c RestConf, opts ...RunOption) (*Server, error) { func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
if len(opts) > 1 {
return nil, errors.New("only one RunOption is allowed")
}
if err := c.SetUp(); err != nil { if err := c.SetUp(); err != nil {
return nil, err return nil, err
} }

View File

@@ -8,18 +8,84 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/router" "github.com/tal-tech/go-zero/rest/router"
) )
func TestNewServer(t *testing.T) { func TestNewServer(t *testing.T) {
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil)) const configYaml = `
assert.NotNil(t, err) Name: foo
Port: 54321
`
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
failStart := func(server *Server) {
server.opts.start = func(e *engine) error {
return http.ErrServerClosed
}
}
tests := []struct {
c RestConf
opts []RunOption
fail bool
}{
{
c: RestConf{},
opts: []RunOption{failStart},
fail: true,
},
{
c: cnf,
opts: []RunOption{failStart},
},
{
c: cnf,
opts: []RunOption{WithNotAllowedHandler(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithNotFoundHandler(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithUnsignedCallback(nil), failStart},
},
}
for _, test := range tests {
srv, err := NewServer(test.c, test.opts...)
if test.fail {
assert.NotNil(t, err)
}
if err != nil {
continue
}
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
})
}))
srv.AddRoute(Route{
Method: http.MethodGet,
Path: "/",
Handler: nil,
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
WithJwtTransition("preivous", "thenewone"))
srv.Start()
srv.Stop()
}
} }
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := router.NewRouter() rt := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
var v struct { var v struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`
@@ -56,14 +122,14 @@ func TestWithMiddleware(t *testing.T) {
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
} }
for _, route := range rs { for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
} }
for _, url := range urls { for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil) r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err) assert.Nil(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
router.ServeHTTP(rr, r) rt.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000", rr.Body.String()) assert.Equal(t, "whatever:200000", rr.Body.String())
} }
@@ -76,7 +142,7 @@ func TestWithMiddleware(t *testing.T) {
func TestMultiMiddlewares(t *testing.T) { func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := router.NewRouter() rt := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
var v struct { var v struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`
@@ -127,14 +193,14 @@ func TestMultiMiddlewares(t *testing.T) {
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
} }
for _, route := range rs { for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
} }
for _, url := range urls { for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil) r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err) assert.Nil(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
router.ServeHTTP(rr, r) rt.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000200000", rr.Body.String()) assert.Equal(t, "whatever:200000200000", rr.Body.String())
} }

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
"go/format"
"go/scanner" "go/scanner"
"io/ioutil" "io/ioutil"
"os" "os"
@@ -13,6 +14,7 @@ import (
"github.com/tal-tech/go-zero/core/errorx" "github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/tools/goctl/api/parser" "github.com/tal-tech/go-zero/tools/goctl/api/parser"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -103,24 +105,108 @@ func apiFormat(data string) (string, error) {
var builder strings.Builder var builder strings.Builder
s := bufio.NewScanner(strings.NewReader(data)) s := bufio.NewScanner(strings.NewReader(data))
var tapCount = 0 var tapCount = 0
var newLineCount = 0
var preLine string
for s.Scan() { for s.Scan() {
line := strings.TrimSpace(s.Text()) line := strings.TrimSpace(s.Text())
if len(line) == 0 {
if newLineCount > 0 {
continue
}
newLineCount++
} else {
if preLine == rightBrace {
builder.WriteString(ctlutil.NL)
}
newLineCount = 0
}
if tapCount == 0 {
format, err := formatGoTypeDef(line, s, &builder)
if err != nil {
return "", err
}
if format {
continue
}
}
noCommentLine := util.RemoveComment(line) noCommentLine := util.RemoveComment(line)
if noCommentLine == rightParenthesis || noCommentLine == rightBrace { if noCommentLine == rightParenthesis || noCommentLine == rightBrace {
tapCount -= 1 tapCount -= 1
} }
if tapCount < 0 { if tapCount < 0 {
line = strings.TrimSuffix(line, rightBrace) line := strings.TrimSuffix(noCommentLine, rightBrace)
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if strings.HasSuffix(line, leftBrace) { if strings.HasSuffix(line, leftBrace) {
tapCount += 1 tapCount += 1
} }
} }
util.WriteIndent(&builder, tapCount) util.WriteIndent(&builder, tapCount)
builder.WriteString(line + "\n") builder.WriteString(line + ctlutil.NL)
if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) { if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) {
tapCount += 1 tapCount += 1
} }
preLine = line
} }
return strings.TrimSpace(builder.String()), nil return strings.TrimSpace(builder.String()), nil
} }
func formatGoTypeDef(line string, scanner *bufio.Scanner, builder *strings.Builder) (bool, error) {
noCommentLine := util.RemoveComment(line)
tokenCount := 0
if strings.HasPrefix(noCommentLine, "type") && (strings.HasSuffix(noCommentLine, leftParenthesis) ||
strings.HasSuffix(noCommentLine, leftBrace)) {
var typeBuilder strings.Builder
typeBuilder.WriteString(mayInsertStructKeyword(line, &tokenCount) + ctlutil.NL)
for scanner.Scan() {
noCommentLine := util.RemoveComment(scanner.Text())
typeBuilder.WriteString(mayInsertStructKeyword(scanner.Text(), &tokenCount) + ctlutil.NL)
if noCommentLine == rightBrace || noCommentLine == rightParenthesis {
tokenCount--
}
if tokenCount == 0 {
ts, err := format.Source([]byte(typeBuilder.String()))
if err != nil {
return false, errors.New("error format \n" + typeBuilder.String())
}
result := strings.ReplaceAll(string(ts), " struct ", " ")
result = strings.ReplaceAll(result, "type ()", "")
builder.WriteString(result)
break
}
}
return true, nil
}
return false, nil
}
func mayInsertStructKeyword(line string, token *int) string {
insertStruct := func() string {
if strings.Contains(line, " struct") {
return line
}
index := strings.Index(line, leftBrace)
return line[:index] + " struct " + line[index:]
}
noCommentLine := util.RemoveComment(line)
if strings.HasSuffix(noCommentLine, leftBrace) {
*token++
return insertStruct()
}
if strings.HasSuffix(noCommentLine, rightBrace) {
noCommentLine = strings.TrimSuffix(noCommentLine, rightBrace)
noCommentLine = util.RemoveComment(noCommentLine)
if strings.HasSuffix(noCommentLine, leftBrace) {
return insertStruct()
}
}
if strings.HasSuffix(noCommentLine, leftParenthesis) {
*token++
}
return line
}

View File

@@ -24,11 +24,11 @@ handler: GreetHandler
} }
` `
formattedStr = `type Request struct { formattedStr = `type Request {
Name string Name string
} }
type Response struct { type Response {
Message string Message string
} }
@@ -40,7 +40,7 @@ service A-api {
}` }`
) )
func TestInlineTypeNotExist(t *testing.T) { func TestFormat(t *testing.T) {
r, err := apiFormat(notFormattedStr) r, err := apiFormat(notFormattedStr)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, r, formattedStr) assert.Equal(t, r, formattedStr)

View File

@@ -38,11 +38,12 @@ func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, content) return util.CreateTemplate(category, name, content)
} }
func Update(category string) error { func Update() error {
err := Clean() err := Clean()
if err != nil { if err != nil {
return err return err
} }
return util.InitTemplates(category, templates) return util.InitTemplates(category, templates)
} }
@@ -50,6 +51,6 @@ func Clean() error {
return util.Clean(category) return util.Clean(category)
} }
func GetCategory() string { func Category() string {
return category return category
} }

View File

@@ -84,7 +84,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, string(data), modifyData) assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category)) assert.Nil(t, Update())
data, err = ioutil.ReadFile(file) data, err = ioutil.ReadFile(file)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -1,6 +1,7 @@
package javagen package javagen
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"path" "path"
@@ -17,6 +18,8 @@ const (
package com.xhb.logic.http.packet.{{.packet}}.model; package com.xhb.logic.http.packet.{{.packet}}.model;
import com.xhb.logic.http.DeProguardable; import com.xhb.logic.http.DeProguardable;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
{{.componentType}} {{.componentType}}
` `
@@ -28,7 +31,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
return nil return nil
} }
for _, ty := range types { for _, ty := range types {
if err := createComponent(dir, packetName, ty); err != nil { if err := createComponent(dir, packetName, ty, api.Types); err != nil {
return err return err
} }
} }
@@ -36,7 +39,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
return nil return nil
} }
func createComponent(dir, packetName string, ty spec.Type) error { func createComponent(dir, packetName string, ty spec.Type, types []spec.Type) error {
modelFile := util.Title(ty.Name) + ".java" modelFile := util.Title(ty.Name) + ".java"
filename := path.Join(dir, modelDir, modelFile) filename := path.Join(dir, modelDir, modelFile)
if err := util.RemoveOrQuit(filename); err != nil { if err := util.RemoveOrQuit(filename); err != nil {
@@ -52,7 +55,7 @@ func createComponent(dir, packetName string, ty spec.Type) error {
} }
defer fp.Close() defer fp.Close()
tys, err := buildType(ty) tys, err := buildType(ty, types)
if err != nil { if err != nil {
return err return err
} }
@@ -64,22 +67,66 @@ func createComponent(dir, packetName string, ty spec.Type) error {
}) })
} }
func buildType(ty spec.Type) (string, error) { func buildType(ty spec.Type, types []spec.Type) (string, error) {
var builder strings.Builder var builder strings.Builder
if err := writeType(&builder, ty); err != nil { if err := writeType(&builder, ty, types); err != nil {
return "", apiutil.WrapErr(err, "Type "+ty.Name+" generate error") return "", apiutil.WrapErr(err, "Type "+ty.Name+" generate error")
} }
return builder.String(), nil return builder.String(), nil
} }
func writeType(writer io.Writer, tp spec.Type) error { func writeType(writer io.Writer, tp spec.Type, types []spec.Type) error {
fmt.Fprintf(writer, "public class %s implements DeProguardable {\n", util.Title(tp.Name)) fmt.Fprintf(writer, "public class %s implements DeProguardable {\n", util.Title(tp.Name))
for _, member := range tp.Members { var members []spec.Member
if err := writeProperty(writer, member, 1); err != nil { err := writeMembers(writer, types, tp.Members, &members, 1)
return err if err != nil {
} return err
}
genGetSet(writer, members, 1)
fmt.Fprintf(writer, "}")
return nil
}
func writeMembers(writer io.Writer, types []spec.Type, members []spec.Member, allMembers *[]spec.Member, indent int) error {
for _, member := range members {
if !member.IsInline {
_, err := member.GetPropertyName()
if err != nil {
return err
}
}
if !member.IsBodyMember() {
continue
}
for _, item := range *allMembers {
if item.Name == member.Name {
continue
}
}
if member.IsInline {
hasInline := false
for _, ty := range types {
if strings.ToLower(ty.Name) == strings.ToLower(member.Name) {
err := writeMembers(writer, types, ty.Members, allMembers, indent)
if err != nil {
return err
}
hasInline = true
break
}
}
if !hasInline {
return errors.New("inline type " + member.Name + " not exist, please correct api file")
}
} else {
if err := writeProperty(writer, member, indent); err != nil {
return err
}
*allMembers = append(*allMembers, member)
}
} }
genGetSet(writer, tp, 1)
fmt.Fprintf(writer, "}\n")
return nil return nil
} }

View File

@@ -19,23 +19,27 @@ const packetTemplate = `package com.xhb.logic.http.packet.{{.packet}};
import com.google.gson.Gson; import com.google.gson.Gson;
import com.xhb.commons.JSON; import com.xhb.commons.JSON;
import com.xhb.commons.JsonParser; import com.xhb.commons.JsonMarshal;
import com.xhb.core.network.HttpRequestClient; import com.xhb.core.network.HttpRequestClient;
import com.xhb.core.packet.HttpRequestPacket; import com.xhb.core.packet.HttpRequestPacket;
import com.xhb.core.response.HttpResponseData; import com.xhb.core.response.HttpResponseData;
import com.xhb.logic.http.DeProguardable; import com.xhb.logic.http.DeProguardable;
{{if not .HasRequestBody}}
import com.xhb.logic.http.request.EmptyRequest;
{{end}}
{{.import}} {{.import}}
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.json.JSONObject; import org.json.JSONObject;
public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> { public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> {
{{.paramsDeclaration}} {{.paramsDeclaration}}
public {{.packetName}}({{.params}}{{.requestType}} request) { public {{.packetName}}({{.params}}{{if .HasRequestBody}}, {{.requestType}} request{{end}}) {
super(request); {{if .HasRequestBody}}super(request);{{else}}super(EmptyRequest.instance);{{end}}
this.request = request;{{.paramsSet}} {{if .HasRequestBody}}this.request = request;{{end}}{{.paramsSet}}
} }
@Override @Override
@@ -113,7 +117,8 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
} else { } else {
fmt.Fprintln(&builder) fmt.Fprintln(&builder)
} }
if err := genType(&builder, tp); err != nil {
if err := genType(&builder, tp, api.Types); err != nil {
return err return err
} }
} }
@@ -126,7 +131,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
t := template.Must(template.New("packetTemplate").Parse(packetTemplate)) t := template.Must(template.New("packetTemplate").Parse(packetTemplate))
var tmplBytes bytes.Buffer var tmplBytes bytes.Buffer
err = t.Execute(&tmplBytes, map[string]string{ err = t.Execute(&tmplBytes, map[string]interface{}{
"packetName": packet, "packetName": packet,
"method": strings.ToUpper(route.Method), "method": strings.ToUpper(route.Method),
"uri": processUri(route), "uri": processUri(route),
@@ -137,6 +142,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
"paramsSet": paramsSet, "paramsSet": paramsSet,
"packet": packetName, "packet": packetName,
"requestType": util.Title(route.RequestType.Name), "requestType": util.Title(route.RequestType.Name),
"HasRequestBody": len(route.RequestType.GetBodyMembers()) > 0,
"import": getImports(api, route, packetName), "import": getImports(api, route, packetName),
}) })
if err != nil { if err != nil {
@@ -209,7 +215,7 @@ func paramsForRoute(route spec.Route) string {
builder.WriteString(fmt.Sprintf("String %s, ", cop[1:])) builder.WriteString(fmt.Sprintf("String %s, ", cop[1:]))
} }
} }
return builder.String() return strings.TrimSuffix(builder.String(), ", ")
} }
func declarationForRoute(route spec.Route) string { func declarationForRoute(route spec.Route) string {
@@ -260,18 +266,22 @@ func processUri(route spec.Route) string {
return result return result
} }
func genType(writer io.Writer, tp spec.Type) error { func genType(writer io.Writer, tp spec.Type, types []spec.Type) error {
writeIndent(writer, 1) if len(tp.GetBodyMembers()) == 0 {
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name)) return nil
for _, member := range tp.Members {
if err := writeProperty(writer, member, 2); err != nil {
return err
}
} }
writeBreakline(writer)
writeIndent(writer, 1) writeIndent(writer, 1)
genGetSet(writer, tp, 2) fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))
var members []spec.Member
err := writeMembers(writer, types, tp.Members, &members, 2)
if err != nil {
return err
}
writeNewline(writer)
writeIndent(writer, 1)
genGetSet(writer, members, 2)
writeIndent(writer, 1) writeIndent(writer, 1)
fmt.Fprintln(writer, "}") fmt.Fprintln(writer, "}")

View File

@@ -67,8 +67,8 @@ func indentString(indent int) string {
return result return result
} }
func writeBreakline(writer io.Writer) { func writeNewline(writer io.Writer) {
fmt.Fprint(writer, "\n") fmt.Fprint(writer, util.NL)
} }
func isPrimitiveType(tp string) bool { func isPrimitiveType(tp string) bool {
@@ -87,6 +87,7 @@ func goTypeToJava(tp string) (string, error) {
if len(tp) == 0 { if len(tp) == 0 {
return "", errors.New("property type empty") return "", errors.New("property type empty")
} }
if strings.HasPrefix(tp, "*") { if strings.HasPrefix(tp, "*") {
tp = tp[1:] tp = tp[1:]
} }
@@ -107,39 +108,44 @@ func goTypeToJava(tp string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
if len(tys) == 0 { if len(tys) == 0 {
return "", fmt.Errorf("%s tp parse error", tp) return "", fmt.Errorf("%s tp parse error", tp)
} }
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(tys[0])), nil return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(tys[0])), nil
} else if strings.HasPrefix(tp, "map") { } else if strings.HasPrefix(tp, "map") {
tys, err := apiutil.DecomposeType(tp) tys, err := apiutil.DecomposeType(tp)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(tys) == 2 { if len(tys) == 2 {
return "", fmt.Errorf("%s tp parse error", tp) return "", fmt.Errorf("%s tp parse error", tp)
} }
return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(tys[1])), nil return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(tys[1])), nil
} }
return util.Title(tp), nil return util.Title(tp), nil
} }
func genGetSet(writer io.Writer, tp spec.Type, indent int) error { func genGetSet(writer io.Writer, members []spec.Member, indent int) error {
t := template.Must(template.New("getSetTemplate").Parse(getSetTemplate)) t := template.Must(template.New("getSetTemplate").Parse(getSetTemplate))
for _, member := range tp.Members { for _, member := range members {
var tmplBytes bytes.Buffer var tmplBytes bytes.Buffer
oty, err := goTypeToJava(member.Type) oty, err := goTypeToJava(member.Type)
if err != nil { if err != nil {
return err return err
} }
tyString := oty tyString := oty
decorator := "" decorator := ""
if !isPrimitiveType(member.Type) { if !isPrimitiveType(member.Type) {
if member.IsOptional() { if member.IsOptional() {
decorator = "@org.jetbrains.annotations.Nullable " decorator = "@Nullable "
} else { } else {
decorator = "@org.jetbrains.annotations.NotNull " decorator = "@NotNull "
} }
tyString = decorator + tyString tyString = decorator + tyString
} }
@@ -155,6 +161,7 @@ func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
if err != nil { if err != nil {
return err return err
} }
r := tmplBytes.String() r := tmplBytes.String()
r = strings.Replace(r, " boolean get", " boolean is", 1) r = strings.Replace(r, " boolean get", " boolean is", 1)
writer.Write([]byte(r)) writer.Write([]byte(r))

View File

@@ -63,10 +63,6 @@ func (m Member) IsOmitempty() bool {
func (m Member) GetPropertyName() (string, error) { func (m Member) GetPropertyName() (string, error) {
tags := m.Tags() tags := m.Tags()
if len(tags) == 0 {
return "", errors.New("json property name not exist, member: " + m.Name)
}
for _, tag := range tags { for _, tag := range tags {
if stringx.Contains(definedKeys, tag.Key) { if stringx.Contains(definedKeys, tag.Key) {
if tag.Name == "-" { if tag.Name == "-" {

View File

@@ -85,7 +85,7 @@ func genHandler(dir, webApi, caller string, api *spec.ApiSpec, unwrapApi bool) e
imports += fmt.Sprintf(`import * as components from "%s"`, "./"+outputFile) imports += fmt.Sprintf(`import * as components from "%s"`, "./"+outputFile)
} }
apis, err := genApi(api, localTypes, caller, prefixForType) apis, err := genApi(api, caller, prefixForType)
if err != nil { if err != nil {
return err return err
} }
@@ -119,32 +119,34 @@ func genTypes(localTypes []spec.Type, inlineType func(string) (*spec.Type, error
return types, nil return types, nil
} }
func genApi(api *spec.ApiSpec, localTypes []spec.Type, caller string, prefixForType func(string) string) (string, error) { func genApi(api *spec.ApiSpec, caller string, prefixForType func(string) string) (string, error) {
var builder strings.Builder var builder strings.Builder
for _, route := range api.Service.Routes() { for _, group := range api.Service.Groups {
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler") for _, route := range group.Routes {
if !ok { handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
return "", fmt.Errorf("missing handler annotation for route %q", route.Path) if !ok {
} return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
handler = util.Untitle(handler)
handler = strings.Replace(handler, "Handler", "", 1)
comment := commentForRoute(route)
if len(comment) > 0 {
fmt.Fprintf(&builder, "%s\n", comment)
}
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
writeIndent(&builder, 1)
responseGeneric := "<null>"
if len(route.ResponseType.Name) > 0 {
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
if err != nil {
return "", err
} }
responseGeneric = fmt.Sprintf("<%s>", val) handler = util.Untitle(handler)
handler = strings.Replace(handler, "Handler", "", 1)
comment := commentForRoute(route)
if len(comment) > 0 {
fmt.Fprintf(&builder, "%s\n", comment)
}
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
writeIndent(&builder, 1)
responseGeneric := "<null>"
if len(route.ResponseType.Name) > 0 {
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
if err != nil {
return "", err
}
responseGeneric = fmt.Sprintf("<%s>", val)
}
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
util.Title(responseGeneric), callParamsForRoute(route, group))
builder.WriteString("\n}\n\n")
} }
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
util.Title(responseGeneric), callParamsForRoute(route))
builder.WriteString("\n}\n\n")
} }
apis := builder.String() apis := builder.String()
@@ -188,21 +190,28 @@ func commentForRoute(route spec.Route) string {
return builder.String() return builder.String()
} }
func callParamsForRoute(route spec.Route) string { func callParamsForRoute(route spec.Route, group spec.Group) string {
hasParams := pathHasParams(route) hasParams := pathHasParams(route)
hasBody := hasRequestBody(route) hasBody := hasRequestBody(route)
if hasParams && hasBody { if hasParams && hasBody {
return fmt.Sprintf("%s, %s, %s", pathForRoute(route), "params", "req") return fmt.Sprintf("%s, %s, %s", pathForRoute(route, group), "params", "req")
} else if hasParams { } else if hasParams {
return fmt.Sprintf("%s, %s", pathForRoute(route), "params") return fmt.Sprintf("%s, %s", pathForRoute(route, group), "params")
} else if hasBody { } else if hasBody {
return fmt.Sprintf("%s, %s", pathForRoute(route), "req") return fmt.Sprintf("%s, %s", pathForRoute(route, group), "req")
} }
return pathForRoute(route) return pathForRoute(route, group)
} }
func pathForRoute(route spec.Route) string { func pathForRoute(route spec.Route, group spec.Group) string {
return "\"" + route.Path + "\"" value, ok := apiutil.GetAnnotationValue(group.Annotations, "server", pathPrefix)
if !ok {
return "\"" + route.Path + "\""
} else {
value = strings.TrimPrefix(value, `"`)
value = strings.TrimSuffix(value, `"`)
return fmt.Sprintf(`"%s/%s"`, value, strings.TrimPrefix(route.Path, "/"))
}
} }
func pathHasParams(route spec.Route) bool { func pathHasParams(route spec.Route) bool {

View File

@@ -2,4 +2,5 @@ package tsgen
const ( const (
packagePrefix = "components." packagePrefix = "components."
pathPrefix = "pathPrefix"
) )

View File

@@ -9,15 +9,17 @@ import (
"text/template" "text/template"
"time" "time"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
const ( const (
etcDir = "etc" dockerfileName = "Dockerfile"
yamlEtx = ".yaml" etcDir = "etc"
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time yamlEtx = ".yaml"
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time
) )
type Docker struct { type Docker struct {
@@ -25,10 +27,18 @@ type Docker struct {
GoRelPath string GoRelPath string
GoFile string GoFile string
ExeFile string ExeFile string
HasPort bool
Port int
Argument string Argument string
} }
func DockerCommand(c *cli.Context) error { func DockerCommand(c *cli.Context) (err error) {
defer func() {
if err == nil {
fmt.Println(aurora.Green("Done."))
}
}()
goFile := c.String("go") goFile := c.String("go")
if len(goFile) == 0 { if len(goFile) == 0 {
return errors.New("-go can't be empty") return errors.New("-go can't be empty")
@@ -38,8 +48,9 @@ func DockerCommand(c *cli.Context) error {
return fmt.Errorf("file %q not found", goFile) return fmt.Errorf("file %q not found", goFile)
} }
port := c.Int("port")
if _, err := os.Stat(etcDir); os.IsNotExist(err) { if _, err := os.Stat(etcDir); os.IsNotExist(err) {
return generateDockerfile(goFile) return generateDockerfile(goFile, port)
} }
cfg, err := findConfig(goFile, etcDir) cfg, err := findConfig(goFile, etcDir)
@@ -47,13 +58,13 @@ func DockerCommand(c *cli.Context) error {
return err return err
} }
if err := generateDockerfile(goFile, "-f", "etc/"+cfg); err != nil { if err := generateDockerfile(goFile, port, "-f", "etc/"+cfg); err != nil {
return err return err
} }
projDir, ok := util.FindProjectPath(goFile) projDir, ok := util.FindProjectPath(goFile)
if ok { if ok {
fmt.Printf("Run \"docker build ...\" command in dir %q\n", projDir) fmt.Printf("Hint: run \"docker build ...\" command in dir %q\n", projDir)
} }
return nil return nil
@@ -88,18 +99,22 @@ func findConfig(file, dir string) (string, error) {
return files[0], nil return files[0], nil
} }
func generateDockerfile(goFile string, args ...string) error { func generateDockerfile(goFile string, port int, args ...string) error {
projPath, err := getFilePath(filepath.Dir(goFile)) projPath, err := getFilePath(filepath.Dir(goFile))
if err != nil { if err != nil {
return err return err
} }
pos := strings.IndexByte(projPath, '/') if len(projPath) == 0 {
if pos >= 0 { projPath = "."
projPath = projPath[pos+1:] } else {
pos := strings.IndexByte(projPath, os.PathSeparator)
if pos >= 0 {
projPath = projPath[pos+1:]
}
} }
out, err := util.CreateIfNotExist("Dockerfile") out, err := util.CreateIfNotExist(dockerfileName)
if err != nil { if err != nil {
return err return err
} }
@@ -122,6 +137,8 @@ func generateDockerfile(goFile string, args ...string) error {
GoRelPath: projPath, GoRelPath: projPath,
GoFile: goFile, GoFile: goFile,
ExeFile: util.FileNameWithoutExt(filepath.Base(goFile)), ExeFile: util.FileNameWithoutExt(filepath.Base(goFile)),
HasPort: port > 0,
Port: port,
Argument: builder.String(), Argument: builder.String(),
}) })
} }

View File

@@ -14,34 +14,59 @@ LABEL stage=gobuilder
ENV CGO_ENABLED 0 ENV CGO_ENABLED 0
ENV GOOS linux ENV GOOS linux
{{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct{{end}} {{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct
{{end}}
WORKDIR /build/zero WORKDIR /build/zero
ADD go.mod . ADD go.mod .
ADD go.sum . ADD go.sum .
RUN go mod download RUN go mod download
COPY . . COPY . .
COPY {{.GoRelPath}}/etc /app/etc {{if .Argument}}COPY {{.GoRelPath}}/etc /app/etc
RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}} {{end}}RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
FROM alpine FROM alpine
RUN apk update --no-cache RUN apk update --no-cache && apk add --no-cache ca-certificates tzdata
RUN apk add --no-cache ca-certificates
RUN apk add --no-cache tzdata
ENV TZ Asia/Shanghai ENV TZ Asia/Shanghai
WORKDIR /app WORKDIR /app
COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}} COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}}{{if .Argument}}
COPY --from=builder /app/etc /app/etc COPY --from=builder /app/etc /app/etc{{end}}
{{if .HasPort}}
EXPOSE {{.Port}}
{{end}}
CMD ["./{{.ExeFile}}"{{.Argument}}] CMD ["./{{.ExeFile}}"{{.Argument}}]
` `
) )
func Clean() error {
return util.Clean(category)
}
func GenTemplates(_ *cli.Context) error { func GenTemplates(_ *cli.Context) error {
return initTemplate()
}
func Category() string {
return category
}
func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, dockerTemplate)
}
func Update() error {
err := Clean()
if err != nil {
return err
}
return initTemplate()
}
func initTemplate() error {
return util.InitTemplates(category, map[string]string{ return util.InitTemplates(category, map[string]string{
dockerTemplateFile: dockerTemplate, dockerTemplateFile: dockerTemplate,
}) })

View File

@@ -27,7 +27,7 @@ import (
) )
var ( var (
BuildVersion = "20201125" BuildVersion = "1.1.1"
commands = []cli.Command{ commands = []cli.Command{
{ {
Name: "api", Name: "api",
@@ -54,14 +54,12 @@ var (
Usage: "the format target dir", Usage: "the format target dir",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "iu", Name: "iu",
Usage: "ignore update", Usage: "ignore update",
Required: false,
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "stdin", Name: "stdin",
Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF", Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF",
Required: false,
}, },
}, },
Action: format.GoFormatApi, Action: format.GoFormatApi,
@@ -101,9 +99,8 @@ var (
Usage: "the api file", Usage: "the api file",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Required: false, Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
}, },
Action: gogen.GoCommand, Action: gogen.GoCommand,
@@ -136,19 +133,16 @@ var (
Usage: "the api file", Usage: "the api file",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "webapi", Name: "webapi",
Usage: "the web api file path", Usage: "the web api file path",
Required: false,
}, },
cli.StringFlag{ cli.StringFlag{
Name: "caller", Name: "caller",
Usage: "the web api caller", Usage: "the web api caller",
Required: false,
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "unwrap", Name: "unwrap",
Usage: "unwrap the webapi caller for import", Usage: "unwrap the webapi caller for import",
Required: false,
}, },
}, },
Action: tsgen.TsCommand, Action: tsgen.TsCommand,
@@ -204,9 +198,8 @@ var (
Usage: "the api file", Usage: "the api file",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Required: false, Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
}, },
Action: plugin.PluginCommand, Action: plugin.PluginCommand,
@@ -221,6 +214,11 @@ var (
Name: "go", Name: "go",
Usage: "the file that contains main function", Usage: "the file that contains main function",
}, },
cli.IntFlag{
Name: "port",
Usage: "the port to expose, default none",
Value: 0,
},
}, },
Action: docker.DockerCommand, Action: docker.DockerCommand,
}, },
@@ -248,9 +246,8 @@ var (
Required: true, Required: true,
}, },
cli.StringFlag{ cli.StringFlag{
Name: "secret", Name: "secret",
Usage: "the image pull secret", Usage: "the secret to image pull from registry",
Required: true,
}, },
cli.IntFlag{ cli.IntFlag{
Name: "requestCpu", Name: "requestCpu",
@@ -321,9 +318,8 @@ var (
Usage: `generate rpc demo service`, Usage: `generate rpc demo service`,
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Required: false, Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
@@ -360,9 +356,8 @@ var (
Usage: `the target path of the code`, Usage: `the target path of the code`,
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Required: false, Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
@@ -394,9 +389,8 @@ var (
Usage: "the target dir", Usage: "the target dir",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Required: false, Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "cache, c", Name: "cache, c",
@@ -430,9 +424,8 @@ var (
Usage: "the target dir", Usage: "the target dir",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Required: false, Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
@@ -476,7 +469,7 @@ var (
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.StringFlag{ cli.StringFlag{
Name: "category,c", Name: "category,c",
Usage: "the category of template, enum [api,rpc,model]", Usage: "the category of template, enum [api,rpc,model,docker,kube]",
}, },
}, },
Action: tpl.UpdateTemplates, Action: tpl.UpdateTemplates,
@@ -487,7 +480,7 @@ var (
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.StringFlag{ cli.StringFlag{
Name: "category,c", Name: "category,c",
Usage: "the category of template, enum [api,rpc,model]", Usage: "the category of template, enum [api,rpc,model,docker,kube]",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "name,n", Name: "name,n",

View File

@@ -47,9 +47,9 @@ spec:
volumeMounts: volumeMounts:
- name: timezone - name: timezone
mountPath: /etc/localtime mountPath: /etc/localtime
imagePullSecrets: {{if .Secret}}imagePullSecrets:
- name: {{.Secret}} - name: {{.Secret}}
volumes: {{end}}volumes:
- name: timezone - name: timezone
hostPath: hostPath:
path: /usr/share/zoneinfo/Asia/Shanghai path: /usr/share/zoneinfo/Asia/Shanghai

View File

@@ -2,8 +2,10 @@ package kube
import ( import (
"errors" "errors"
"fmt"
"text/template" "text/template"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -16,47 +18,23 @@ const (
portLimit = 32767 portLimit = 32767
) )
var errUnknownServiceType = errors.New("unknown service type") type Deployment struct {
Name string
type ( Namespace string
ServiceType string Image string
Secret string
KubeRequest struct { Replicas int
Env string Revisions int
ServiceName string Port int
ServiceType ServiceType NodePort int
Namespace string UseNodePort bool
Schedule string RequestCpu int
Replicas int RequestMem int
RevisionHistoryLimit int LimitCpu int
Port int LimitMem int
LimitCpu int MinReplicas int
LimitMem int MaxReplicas int
RequestCpu int }
RequestMem int
SuccessfulJobsHistoryLimit int
HpaMinReplicas int
HpaMaxReplicas int
}
Deployment struct {
Name string
Namespace string
Image string
Secret string
Replicas int
Revisions int
Port int
NodePort int
UseNodePort bool
RequestCpu int
RequestMem int
LimitCpu int
LimitMem int
MinReplicas int
MaxReplicas int
}
)
func DeploymentCommand(c *cli.Context) error { func DeploymentCommand(c *cli.Context) error {
nodePort := c.Int("nodePort") nodePort := c.Int("nodePort")
@@ -77,7 +55,7 @@ func DeploymentCommand(c *cli.Context) error {
defer out.Close() defer out.Close()
t := template.Must(template.New("deploymentTemplate").Parse(text)) t := template.Must(template.New("deploymentTemplate").Parse(text))
return t.Execute(out, Deployment{ err = t.Execute(out, Deployment{
Name: c.String("name"), Name: c.String("name"),
Namespace: c.String("namespace"), Namespace: c.String("namespace"),
Image: c.String("image"), Image: c.String("image"),
@@ -94,6 +72,20 @@ func DeploymentCommand(c *cli.Context) error {
MinReplicas: c.Int("minReplicas"), MinReplicas: c.Int("minReplicas"),
MaxReplicas: c.Int("maxReplicas"), MaxReplicas: c.Int("maxReplicas"),
}) })
if err != nil {
return err
}
fmt.Println(aurora.Green("Done."))
return nil
}
func Category() string {
return category
}
func Clean() error {
return util.Clean(category)
} }
func GenTemplates(_ *cli.Context) error { func GenTemplates(_ *cli.Context) error {
@@ -102,3 +94,19 @@ func GenTemplates(_ *cli.Context) error {
jobTemplateFile: jobTmeplate, jobTemplateFile: jobTmeplate,
}) })
} }
func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, deploymentTemplate)
}
func Update() error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, map[string]string{
deployTemplateFile: deploymentTemplate,
jobTemplateFile: jobTmeplate,
})
}

View File

@@ -61,9 +61,9 @@ func FieldNames(in interface{}) []string {
// gets us a StructField // gets us a StructField
fi := typ.Field(i) fi := typ.Field(i)
if tagv := fi.Tag.Get(dbTag); tagv != "" { if tagv := fi.Tag.Get(dbTag); tagv != "" {
out = append(out, tagv) out = append(out, fmt.Sprintf("`%v`", tagv))
} else { } else {
out = append(out, fi.Name) out = append(out, fmt.Sprintf("`%v`", fi.Name))
} }
} }
return out return out

View File

@@ -28,8 +28,7 @@ var userFields = FieldNames(User{})
func TestFieldNames(t *testing.T) { func TestFieldNames(t *testing.T) {
var u User var u User
out := FieldNames(&u) out := FieldNames(&u)
fmt.Println(out) actual := []string{"`id`", "`user_name`", "`sex`", "`uuid`", "`age`"}
actual := []string{"id", "user_name", "sex", "uuid", "age"}
assert.Equal(t, out, actual) assert.Equal(t, out, actual)
} }
@@ -54,7 +53,7 @@ func TestBuilderSql(t *testing.T) {
sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL() sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL()
fmt.Println(sql, args, err) fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id=?" actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id=?"
actualArgs := []interface{}{"123123"} actualArgs := []interface{}{"123123"}
assert.Equal(t, sql, actualSql) assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs) assert.Equal(t, args, actualArgs)
@@ -68,7 +67,7 @@ func TestBuildSqlDefaultValue(t *testing.T) {
sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL() sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL()
fmt.Println(sql, args, err) fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE age=? AND user_name=?" actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE age=? AND user_name=?"
actualArgs := []interface{}{0, ""} actualArgs := []interface{}{0, ""}
assert.Equal(t, sql, actualSql) assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs) assert.Equal(t, args, actualArgs)
@@ -83,7 +82,7 @@ func TestBuilderSqlIn(t *testing.T) {
sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL() sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL()
fmt.Println(sql, args, err) fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id IN (?,?,?) AND age>?" actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id IN (?,?,?) AND age>?"
actualArgs := []interface{}{"1", "2", "3", 18} actualArgs := []interface{}{"1", "2", "3", 18}
assert.Equal(t, sql, actualSql) assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs) assert.Equal(t, args, actualArgs)
@@ -94,7 +93,7 @@ func TestBuildSqlLike(t *testing.T) {
sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL() sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL()
fmt.Println(sql, args, err) fmt.Println(sql, args, err)
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE name LIKE ?" actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE name LIKE ?"
actualArgs := []interface{}{"%wang%"} actualArgs := []interface{}{"%wang%"}
assert.Equal(t, sql, actualSql) assert.Equal(t, sql, actualSql)
assert.Equal(t, args, actualArgs) assert.Equal(t, args, actualArgs)

View File

@@ -1,8 +1,13 @@
#!/bin/bash #!/bin/bash
# generate model with cache from ddl # generate model with cache from ddl
fromDDL: fromDDLWithCache:
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -cache goctl template clean;
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache/user" -cache;
fromDDLWithoutCache:
goctl template clean;
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/nocache/user";
# generate model with cache from data source # generate model with cache from data source
@@ -12,4 +17,5 @@ datasource=127.0.0.1:3306
database=gozero database=gozero
fromDataSource: fromDataSource:
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero goctl template clean;
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero;

View File

@@ -36,7 +36,7 @@ func genDelete(table Table, withCache bool) (string, string, error) {
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
"dataType": table.PrimaryKey.DataType, "dataType": table.PrimaryKey.DataType,
"keys": strings.Join(keySet.KeysStr(), "\n"), "keys": strings.Join(keySet.KeysStr(), "\n"),
"originalPrimaryKey": table.PrimaryKey.Name.Source(), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "), "keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
}) })
if err != nil { if err != nil {

View File

@@ -19,7 +19,7 @@ func genFindOne(table Table, withCache bool) (string, string, error) {
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": camel, "upperStartCamelObject": camel,
"lowerStartCamelObject": stringx.From(camel).Untitle(), "lowerStartCamelObject": stringx.From(camel).Untitle(),
"originalPrimaryKey": table.PrimaryKey.Name.Source(), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
"dataType": table.PrimaryKey.DataType, "dataType": table.PrimaryKey.DataType,
"cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression, "cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression,

View File

@@ -39,7 +39,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"lowerStartCamelField": stringx.From(camelFieldName).Untitle(), "lowerStartCamelField": stringx.From(camelFieldName).Untitle(),
"upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(), "upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
"originalField": field.Name.Source(), "originalField": wrapWithRawString(field.Name.Source()),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -82,7 +82,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left, "primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"originalPrimaryField": table.PrimaryKey.Name.Source(), "originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()),
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -21,9 +21,6 @@ import (
const ( const (
pwd = "." pwd = "."
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
NamingLower = "lower"
NamingCamel = "camel"
NamingSnake = "snake"
) )
type ( type (
@@ -280,3 +277,20 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return output.String(), nil return output.String(), nil
} }
func wrapWithRawString(v string) string {
if v == "`" {
return v
}
if !strings.HasPrefix(v, "`") {
v = "`" + v
}
if !strings.HasSuffix(v, "`") {
v = v + "`"
} else if len(v) == 1 {
v = v + "`"
}
return v
}

View File

@@ -1,13 +1,18 @@
package gen package gen
import ( import (
"database/sql"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
) )
var ( var (
@@ -79,3 +84,32 @@ func TestNamingModel(t *testing.T) {
return err == nil return err == nil
}()) }())
} }
func TestWrapWithRawString(t *testing.T) {
assert.Equal(t, "``", wrapWithRawString(""))
assert.Equal(t, "``", wrapWithRawString("``"))
assert.Equal(t, "`a`", wrapWithRawString("a"))
assert.Equal(t, "` `", wrapWithRawString(" "))
}
func TestFields(t *testing.T) {
type Student struct {
Id int64 `db:"id"`
Name string `db:"name"`
Age sql.NullInt64 `db:"age"`
Score sql.NullFloat64 `db:"score"`
CreateTime time.Time `db:"create_time"`
UpdateTime sql.NullTime `db:"update_time"`
}
var (
studentFieldNames = builderx.FieldNames(&Student{})
studentRows = strings.Join(studentFieldNames, ",")
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
)
assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows)
assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
}

View File

@@ -14,7 +14,7 @@ func genNew(table Table, withCache bool) (string, error) {
output, err := util.With("new"). output, err := util.With("new").
Parse(text). Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"table": table.Name.Source(), "table": wrapWithRawString(table.Name.Source()),
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": table.Name.ToCamel(), "upperStartCamelObject": table.Name.ToCamel(),
}) })

View File

@@ -54,6 +54,14 @@ var templates = map[string]string{
errTemplateFile: template.Error, errTemplateFile: template.Error,
} }
func Category() string {
return category
}
func Clean() error {
return util.Clean(category)
}
func GenTemplates(_ *cli.Context) error { func GenTemplates(_ *cli.Context) error {
return util.InitTemplates(category, templates) return util.InitTemplates(category, templates)
} }
@@ -66,18 +74,10 @@ func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, content) return util.CreateTemplate(category, name, content)
} }
func Clean() error { func Update() error {
return util.Clean(category)
}
func Update(category string) error {
err := Clean() err := Clean()
if err != nil { if err != nil {
return err return err
} }
return util.InitTemplates(category, templates) return util.InitTemplates(category, templates)
} }
func GetCategory() string {
return category
}

View File

@@ -85,7 +85,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, string(data), modifyData) assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category)) assert.Nil(t, Update())
data, err = ioutil.ReadFile(file) data, err = ioutil.ReadFile(file)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -35,7 +35,7 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
"primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression, "primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression,
"primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable, "primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"originalPrimaryKey": table.PrimaryKey.Name.Source(), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"expressionValues": strings.Join(expressionValues, ", "), "expressionValues": strings.Join(expressionValues, ", "),
}) })
if err != nil { if err != nil {

View File

@@ -27,7 +27,7 @@ func genVars(table Table, withCache bool) (string, error) {
"upperStartCamelObject": camel, "upperStartCamelObject": camel,
"cacheKeys": strings.Join(keys, "\n"), "cacheKeys": strings.Join(keys, "\n"),
"autoIncrement": table.PrimaryKey.AutoIncrement, "autoIncrement": table.PrimaryKey.AutoIncrement,
"originalPrimaryKey": table.PrimaryKey.Name.Source(), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
"withCache": withCache, "withCache": withCache,
}) })
if err != nil { if err != nil {

View File

@@ -1,34 +0,0 @@
package model
import (
"github.com/tal-tech/go-zero/core/stores/sqlx"
)
type (
DDLModel struct {
conn sqlx.SqlConn
}
DDL struct {
Table string `db:"Table"`
DDL string `db:"Create Table"`
}
)
func NewDDLModel(conn sqlx.SqlConn) *DDLModel {
return &DDLModel{conn: conn}
}
func (m *DDLModel) ShowDDL(table ...string) ([]string, error) {
var ddl []string
for _, t := range table {
query := `show create table ` + t
var resp DDL
err := m.conn.QueryRow(&resp, query)
if err != nil {
return nil, err
}
ddl = append(ddl, resp.DDL)
}
return ddl, nil
}

View File

@@ -1,12 +1,14 @@
package template package template
var Vars = ` import "fmt"
var Vars = fmt.Sprintf(`
var ( var (
{{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{}) {{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{})
{{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",") {{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",")
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "create_time", "update_time"), ",") {{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ",")
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "create_time", "update_time"), "=?,") + "=?" {{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?"
{{if .withCache}}{{.cacheKeys}}{{end}} {{if .withCache}}{{.cacheKeys}}{{end}}
) )
` `, "`", "`", "`", "`", "`", "`", "`", "`")

View File

@@ -0,0 +1,235 @@
package model
import (
"database/sql"
"fmt"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
mocksql "github.com/tal-tech/go-zero/tools/goctl/model/sql/test"
)
func TestStudentModel(t *testing.T) {
var (
testTimeValue = time.Now()
testTable = "`student`"
testUpdateName = "gozero1"
testRowsAffected int64 = 1
testInsertId int64 = 1
)
var data Student
data.Id = testInsertId
data.Name = "gozero"
data.Age = sql.NullInt64{
Int64: 1,
Valid: true,
}
data.Score = sql.NullFloat64{
Float64: 100,
Valid: true,
}
data.CreateTime = testTimeValue
data.UpdateTime = sql.NullTime{
Time: testTimeValue,
Valid: true,
}
err := mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
WithArgs(data.Name, data.Age, data.Score).
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m StudentModel) {
r, err := m.Insert(data)
assert.Nil(t, err)
lastInsertId, err := r.LastInsertId()
assert.Nil(t, err)
assert.Equal(t, testInsertId, lastInsertId)
rowsAffected, err := r.RowsAffected()
assert.Nil(t, err)
assert.Equal(t, testRowsAffected, rowsAffected)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
}, func(m StudentModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(testUpdateName, data.Age, data.Score, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m StudentModel) {
data.Name = testUpdateName
err := m.Update(data)
assert.Nil(t, err)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
}, func(m StudentModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockStudent(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m StudentModel) {
err := m.Delete(testInsertId)
assert.Nil(t, err)
})
assert.Nil(t, err)
}
func TestUserModel(t *testing.T) {
var (
testTimeValue = time.Now()
testTable = "`user`"
testUpdateName = "gozero1"
testUser = "gozero"
testPassword = "test"
testMobile = "test_mobile"
testGender = "男"
testNickname = "test_nickname"
testRowsAffected int64 = 1
testInsertId int64 = 1
)
var data User
data.Id = testInsertId
data.User = testUser
data.Name = "gozero"
data.Password = testPassword
data.Mobile = testMobile
data.Gender = testGender
data.Nickname = testNickname
data.CreateTime = testTimeValue
data.UpdateTime = testTimeValue
err := mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
WithArgs(data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname).
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m UserModel) {
r, err := m.Insert(data)
assert.Nil(t, err)
lastInsertId, err := r.LastInsertId()
assert.Nil(t, err)
assert.Equal(t, testInsertId, lastInsertId)
rowsAffected, err := r.RowsAffected()
assert.Nil(t, err)
assert.Equal(t, testRowsAffected, rowsAffected)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
}, func(m UserModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m UserModel) {
data.Name = testUpdateName
err := m.Update(data)
assert.Nil(t, err)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
WithArgs(testInsertId).
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
}, func(m UserModel) {
result, err := m.FindOne(testInsertId)
assert.Nil(t, err)
assert.Equal(t, *result, data)
})
assert.Nil(t, err)
err = mockUser(func(mock sqlmock.Sqlmock) {
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
}, func(m UserModel) {
err := m.Delete(testInsertId)
assert.Nil(t, err)
})
assert.Nil(t, err)
}
// with cache
func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) error {
db, mock, err := sqlmock.New()
if err != nil {
return err
}
defer db.Close()
mock.ExpectBegin()
mockFn(mock)
mock.ExpectCommit()
conn := mocksql.NewMockConn(db)
r, clean, err := redistest.CreateRedis()
if err != nil {
return err
}
defer clean()
m := NewStudentModel(conn, cache.CacheConf{
{
RedisConf: redis.RedisConf{
Host: r.Addr,
Type: "node",
},
Weight: 100,
},
})
fn(m)
return nil
}
// without cache
func mockUser(mockFn func(mock sqlmock.Sqlmock), fn func(m UserModel)) error {
db, mock, err := sqlmock.New()
if err != nil {
return err
}
defer db.Close()
mock.ExpectBegin()
mockFn(mock)
mock.ExpectCommit()
conn := mocksql.NewMockConn(db)
m := NewUserModel(conn)
fn(m)
return nil
}

View File

@@ -0,0 +1,105 @@
package model
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
studentFieldNames = builderx.FieldNames(&Student{})
studentRows = strings.Join(studentFieldNames, ",")
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
cacheStudentIdPrefix = "cache#Student#id#"
)
type (
StudentModel interface {
Insert(data Student) (sql.Result, error)
FindOne(id int64) (*Student, error)
Update(data Student) error
Delete(id int64) error
}
defaultStudentModel struct {
sqlc.CachedConn
table string
}
Student struct {
Id int64 `db:"id"`
Name string `db:"name"`
Age sql.NullInt64 `db:"age"`
Score sql.NullFloat64 `db:"score"`
CreateTime time.Time `db:"create_time"`
UpdateTime sql.NullTime `db:"update_time"`
}
)
func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel {
return &defaultStudentModel{
CachedConn: sqlc.NewConn(conn, c),
table: "`student`",
}
}
func (m *defaultStudentModel) Insert(data Student) (sql.Result, error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?)", m.table, studentRowsExpectAutoSet)
ret, err := m.ExecNoCache(query, data.Name, data.Age, data.Score)
return ret, err
}
func (m *defaultStudentModel) FindOne(id int64) (*Student, error) {
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
var resp Student
err := m.QueryRow(&resp, studentIdKey, func(conn sqlx.SqlConn, v interface{}) error {
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
return conn.QueryRow(v, query, id)
})
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultStudentModel) Update(data Student) error {
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, data.Id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, studentRowsWithPlaceHolder)
return conn.Exec(query, data.Name, data.Age, data.Score, data.Id)
}, studentIdKey)
return err
}
func (m *defaultStudentModel) Delete(id int64) error {
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
return conn.Exec(query, id)
}, studentIdKey)
return err
}
func (m *defaultStudentModel) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", cacheStudentIdPrefix, primary)
}
func (m *defaultStudentModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
return conn.QueryRow(v, query, primary)
}

View File

@@ -0,0 +1,130 @@
package model
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/tal-tech/go-zero/core/stores/sqlc"
"github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
)
var (
userFieldNames = builderx.FieldNames(&User{})
userRows = strings.Join(userFieldNames, ",")
userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
)
type (
UserModel interface {
Insert(data User) (sql.Result, error)
FindOne(id int64) (*User, error)
FindOneByUser(user string) (*User, error)
FindOneByName(name string) (*User, error)
FindOneByMobile(mobile string) (*User, error)
Update(data User) error
Delete(id int64) error
}
defaultUserModel struct {
conn sqlx.SqlConn
table string
}
User struct {
Id int64 `db:"id"`
User string `db:"user"` // 用户
Name string `db:"name"` // 用户名称
Password string `db:"password"` // 用户密码
Mobile string `db:"mobile"` // 手机号
Gender string `db:"gender"` // 男|女|未公开
Nickname string `db:"nickname"` // 用户昵称
CreateTime time.Time `db:"create_time"`
UpdateTime time.Time `db:"update_time"`
}
)
func NewUserModel(conn sqlx.SqlConn) UserModel {
return &defaultUserModel{
conn: conn,
table: "`user`",
}
}
func (m *defaultUserModel) Insert(data User) (sql.Result, error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?, ?)", m.table, userRowsExpectAutoSet)
ret, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname)
return ret, err
}
func (m *defaultUserModel) FindOne(id int64) (*User, error) {
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", userRows, m.table)
var resp User
err := m.conn.QueryRow(&resp, query, id)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) FindOneByUser(user string) (*User, error) {
var resp User
query := fmt.Sprintf("select %s from %s where `user` = ? limit 1", userRows, m.table)
err := m.conn.QueryRow(&resp, query, user)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) FindOneByName(name string) (*User, error) {
var resp User
query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table)
err := m.conn.QueryRow(&resp, query, name)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) {
var resp User
query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table)
err := m.conn.QueryRow(&resp, query, mobile)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) Update(data User) error {
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, userRowsWithPlaceHolder)
_, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id)
return err
}
func (m *defaultUserModel) Delete(id int64) error {
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
_, err := m.conn.Exec(query, id)
return err
}

View File

@@ -0,0 +1,5 @@
package model
import "github.com/tal-tech/go-zero/core/stores/sqlx"
var ErrNotFound = sqlx.ErrNotFound

View File

@@ -0,0 +1,255 @@
// copy from core/stores/sqlx/orm.go
package mocksql
import (
"errors"
"reflect"
"strings"
"github.com/tal-tech/go-zero/core/mapping"
)
const tagName = "db"
var (
ErrNotMatchDestination = errors.New("not matching destination to scan")
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
ErrNotSettable = errors.New("passed in variable is not settable")
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
)
type rowsScanner interface {
Columns() ([]string, error)
Err() error
Next() bool
Scan(v ...interface{}) error
}
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
rt := mapping.Deref(v.Type())
size := rt.NumField()
result := make(map[string]interface{}, size)
for i := 0; i < size; i++ {
key := parseTagName(rt.Field(i))
if len(key) == 0 {
return nil, nil
}
valueField := reflect.Indirect(v).Field(i)
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
result[key] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
result[key] = valueField.Addr().Interface()
}
}
return result, nil
}
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
fields := unwrapFields(v)
if strict && len(columns) < len(fields) {
return nil, ErrNotMatchDestination
}
taggedMap, err := getTaggedFieldValueMap(v)
if err != nil {
return nil, err
}
values := make([]interface{}, len(columns))
if len(taggedMap) == 0 {
for i := 0; i < len(values); i++ {
valueField := fields[i]
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
values[i] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
values[i] = valueField.Addr().Interface()
}
}
} else {
for i, column := range columns {
if tagged, ok := taggedMap[column]; ok {
values[i] = tagged
} else {
var anonymous interface{}
values[i] = &anonymous
}
}
}
return values, nil
}
func parseTagName(field reflect.StructField) string {
key := field.Tag.Get(tagName)
if len(key) == 0 {
return ""
} else {
options := strings.Split(key, ",")
return options[0]
}
}
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
if !scanner.Next() {
if err := scanner.Err(); err != nil {
return err
}
return ErrNotFound
}
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rte := reflect.TypeOf(v).Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
if rve.CanSet() {
return scanner.Scan(v)
} else {
return ErrNotSettable
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
return err
} else {
return scanner.Scan(values...)
}
default:
return ErrUnsupportedValueType
}
}
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rt := reflect.TypeOf(v)
rte := rt.Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Slice:
if rve.CanSet() {
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
}
fillFn := func(value interface{}) error {
if rve.CanSet() {
if err := scanner.Scan(value); err != nil {
return err
} else {
appendFn(reflect.ValueOf(value))
return nil
}
}
return ErrNotSettable
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
return err
} else {
if err := scanner.Scan(values...); err != nil {
return err
} else {
appendFn(value)
}
}
}
default:
return ErrUnsupportedValueType
}
return nil
} else {
return ErrNotSettable
}
default:
return ErrUnsupportedValueType
}
}
func unwrapFields(v reflect.Value) []reflect.Value {
var fields []reflect.Value
indirect := reflect.Indirect(v)
for i := 0; i < indirect.NumField(); i++ {
child := indirect.Field(i)
if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType))
}
child = reflect.Indirect(child)
childType := indirect.Type().Field(i)
if child.Kind() == reflect.Struct && childType.Anonymous {
fields = append(fields, unwrapFields(child)...)
} else {
fields = append(fields, child)
}
}
return fields
}

View File

@@ -0,0 +1,90 @@
// copy from core/stores/sqlx/sqlconn.go
package mocksql
import (
"database/sql"
"github.com/tal-tech/go-zero/core/stores/sqlx"
)
type (
MockConn struct {
db *sql.DB
}
statement struct {
stmt *sql.Stmt
}
)
func NewMockConn(db *sql.DB) *MockConn {
return &MockConn{db: db}
}
func (conn *MockConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return exec(conn.db, query, args...)
}
func (conn *MockConn) Prepare(query string) (sqlx.StmtSession, error) {
st, err := conn.db.Prepare(query)
return statement{stmt: st}, err
}
func (conn *MockConn) QueryRow(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (conn *MockConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (conn *MockConn) QueryRows(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return query(conn.db, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
return nil
}
func (s statement) Close() error {
return s.stmt.Close()
}
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
return execStmt(s.stmt, args...)
}
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, args...)
}
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, args...)
}
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, args...)
}
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, args...)
}

View File

@@ -0,0 +1,122 @@
// copy from core/stores/sqlx/stmt.go
package mocksql
import (
"database/sql"
"fmt"
"time"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/timex"
)
const slowThreshold = time.Millisecond * 500
func exec(db *sql.DB, q string, args ...interface{}) (sql.Result, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer func() {
switch err {
case nil:
err = tx.Commit()
default:
tx.Rollback()
}
}()
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := tx.Exec(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func execStmt(conn *sql.Stmt, args ...interface{}) (sql.Result, error) {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
result, err := conn.Exec(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func query(db *sql.DB, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
switch err {
case nil:
err = tx.Commit()
default:
tx.Rollback()
}
}()
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := tx.Query(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql query: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}
func queryStmt(conn *sql.Stmt, scanner func(*sql.Rows) error, args ...interface{}) error {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
rows, err := conn.Query(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}

View File

@@ -0,0 +1,105 @@
// copy from core/stores/sqlx/utils.go
package mocksql
import (
"database/sql"
"fmt"
"strings"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/mapping"
)
var ErrNotFound = sql.ErrNoRows
func desensitize(datasource string) string {
// remove account
pos := strings.LastIndex(datasource, "@")
if 0 <= pos && pos+1 < len(datasource) {
datasource = datasource[pos+1:]
}
return datasource
}
func escape(input string) string {
var b strings.Builder
for _, ch := range input {
switch ch {
case '\x00':
b.WriteString(`\x00`)
case '\r':
b.WriteString(`\r`)
case '\n':
b.WriteString(`\n`)
case '\\':
b.WriteString(`\\`)
case '\'':
b.WriteString(`\'`)
case '"':
b.WriteString(`\"`)
case '\x1a':
b.WriteString(`\x1a`)
default:
b.WriteRune(ch)
}
}
return b.String()
}
func format(query string, args ...interface{}) (string, error) {
numArgs := len(args)
if numArgs == 0 {
return query, nil
}
var b strings.Builder
argIndex := 0
for _, ch := range query {
if ch == '?' {
if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
}
arg := args[argIndex]
argIndex++
switch v := arg.(type) {
case bool:
if v {
b.WriteByte('1')
} else {
b.WriteByte('0')
}
case string:
b.WriteByte('\'')
b.WriteString(escape(v))
b.WriteByte('\'')
default:
b.WriteString(mapping.Repr(v))
}
} else {
b.WriteRune(ch)
}
}
if argIndex < numArgs {
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
}
return b.String(), nil
}
func logInstanceError(datasource string, err error) {
datasource = desensitize(datasource)
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
}
func logSqlError(stmt string, err error) {
if err != nil && err != ErrNotFound {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
}
}

View File

@@ -25,9 +25,10 @@ const (
) )
type Plugin struct { type Plugin struct {
Api *spec.ApiSpec Api *spec.ApiSpec
Style string ApiFilePath string
Dir string Style string
Dir string
} }
func PluginCommand(c *cli.Context) error { func PluginCommand(c *cli.Context) error {
@@ -86,6 +87,12 @@ func prepareArgs(c *cli.Context) ([]byte, error) {
transferData.Api = api transferData.Api = api
} }
absApiFilePath, err := filepath.Abs(apiPath)
if err != nil {
return nil, err
}
transferData.ApiFilePath = absApiFilePath
dirAbs, err := filepath.Abs(c.String("dir")) dirAbs, err := filepath.Abs(c.String("dir"))
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -4,6 +4,7 @@ import (
"go/build" "go/build"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -44,7 +45,9 @@ func TestRpcGenerate(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir) _, err = execx.Run("go test "+projectName, projectDir)
if err != nil { if err != nil {
assert.Contains(t, err.Error(), "not in GOROOT") assert.True(t, func() bool {
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
}())
} }
// case go mod // case go mod
@@ -61,7 +64,9 @@ func TestRpcGenerate(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir) _, err = execx.Run("go test "+projectName, projectDir)
if err != nil { if err != nil {
assert.Contains(t, err.Error(), "not in GOROOT") assert.True(t, func() bool {
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
}())
} }
// case not in go mod and go path // case not in go mod and go path
@@ -69,7 +74,9 @@ func TestRpcGenerate(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir) _, err = execx.Run("go test "+projectName, projectDir)
if err != nil { if err != nil {
assert.Contains(t, err.Error(), "not in GOROOT") assert.True(t, func() bool {
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
}())
} }
// invalid directory // invalid directory

View File

@@ -15,12 +15,12 @@ const svcTemplate = `package svc
import {{.imports}} import {{.imports}}
type ServiceContext struct { type ServiceContext struct {
c config.Config Config config.Config
} }
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{ return &ServiceContext{
c:c, Config:c,
} }
} }
` `

View File

@@ -54,14 +54,15 @@ func Clean() error {
return util.Clean(category) return util.Clean(category)
} }
func Update(category string) error { func Update() error {
err := Clean() err := Clean()
if err != nil { if err != nil {
return err return err
} }
return util.InitTemplates(category, templates) return util.InitTemplates(category, templates)
} }
func GetCategory() string { func Category() string {
return category return category
} }

View File

@@ -97,8 +97,7 @@ func TestUpdate(t *testing.T) {
} }
assert.Equal(t, "modify", string(data)) assert.Equal(t, "modify", string(data))
err = Update(category) assert.Nil(t, Update())
assert.Nil(t, err)
data, err = ioutil.ReadFile(mainTpl) data, err = ioutil.ReadFile(mainTpl)
if err != nil { if err != nil {
@@ -109,6 +108,6 @@ func TestUpdate(t *testing.T) {
func TestGetCategory(t *testing.T) { func TestGetCategory(t *testing.T) {
_ = Clean() _ = Clean()
result := GetCategory() result := Category()
assert.Equal(t, category, result) assert.Equal(t, category, result)
} }

View File

@@ -76,12 +76,16 @@ func UpdateTemplates(ctx *cli.Context) (err error) {
} }
}() }()
switch category { switch category {
case gogen.GetCategory(): case docker.Category():
return gogen.Update(category) return docker.Update()
case rpcgen.GetCategory(): case gogen.Category():
return rpcgen.Update(category) return gogen.Update()
case modelgen.GetCategory(): case kube.Category():
return modelgen.Update(category) return kube.Update()
case rpcgen.Category():
return rpcgen.Update()
case modelgen.Category():
return modelgen.Update()
default: default:
err = fmt.Errorf("unexpected category: %s", category) err = fmt.Errorf("unexpected category: %s", category)
return return
@@ -97,11 +101,15 @@ func RevertTemplates(ctx *cli.Context) (err error) {
} }
}() }()
switch category { switch category {
case gogen.GetCategory(): case docker.Category():
return docker.RevertTemplate(filename)
case kube.Category():
return kube.RevertTemplate(filename)
case gogen.Category():
return gogen.RevertTemplate(filename) return gogen.RevertTemplate(filename)
case rpcgen.GetCategory(): case rpcgen.Category():
return rpcgen.RevertTemplate(filename) return rpcgen.RevertTemplate(filename)
case modelgen.GetCategory(): case modelgen.Category():
return modelgen.RevertTemplate(filename) return modelgen.RevertTemplate(filename)
default: default:
err = fmt.Errorf("unexpected category: %s", category) err = fmt.Errorf("unexpected category: %s", category)

View File

@@ -60,7 +60,6 @@ func FindGoModPath(dir string) (string, bool) {
var hasGoMod = false var hasGoMod = false
for { for {
if FileExists(filepath.Join(tempPath, goModeIdentifier)) { if FileExists(filepath.Join(tempPath, goModeIdentifier)) {
tempPath = filepath.Dir(tempPath)
rootPath = strings.TrimPrefix(absDir[len(tempPath):], "/") rootPath = strings.TrimPrefix(absDir[len(tempPath):], "/")
hasGoMod = true hasGoMod = true
break break

View File

@@ -19,10 +19,11 @@ func Untitle(s string) string {
} }
func Index(slice []string, item string) int { func Index(slice []string, item string) int {
for i, _ := range slice { for i := range slice {
if slice[i] == item { if slice[i] == item {
return i return i
} }
} }
return -1 return -1
} }

View File

@@ -64,10 +64,10 @@ func (l *Logger) Warning(args ...interface{}) {
// ignore builtin grpc warning // ignore builtin grpc warning
} }
func (l *Logger) Warningln(args ...interface{}) {
// ignore builtin grpc warning
}
func (l *Logger) Warningf(format string, args ...interface{}) { func (l *Logger) Warningf(format string, args ...interface{}) {
// ignore builtin grpc warning // ignore builtin grpc warning
} }
func (l *Logger) Warningln(args ...interface{}) {
// ignore builtin grpc warning
}

View File

@@ -0,0 +1,83 @@
package internal
import (
"log"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
const content = "foo"
func TestLoggerError(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Error(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerErrorf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Errorf(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerErrorln(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Errorln(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerFatal(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Fatal(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerFatalf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Fatalf(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerFatalln(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Fatalln(content)
assert.Contains(t, builder.String(), content)
}
func TestLoggerWarning(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Warning(content)
assert.Empty(t, builder.String())
}
func TestLoggerWarningf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Warningf(content)
assert.Empty(t, builder.String())
}
func TestLoggerWarningln(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
logger := new(Logger)
logger.Warningln(content)
assert.Empty(t, builder.String())
}