mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-26 16:15:30 +08:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
57b73d8b49 | ||
|
|
a79cee12ee | ||
|
|
7a921f66e6 | ||
|
|
12e235efb0 | ||
|
|
01060cf16d | ||
|
|
0786862a35 | ||
|
|
efa43483b2 | ||
|
|
771371e051 | ||
|
|
2ee95f8981 | ||
|
|
5bc01e4bfd | ||
|
|
510e966982 | ||
|
|
10e3b8ac80 | ||
|
|
04059bbf5a | ||
|
|
d643007c79 | ||
|
|
fc43876cc5 | ||
|
|
a926cb514f | ||
|
|
25cab2f273 | ||
|
|
8d2e2753a2 | ||
|
|
cc4c50e3eb | ||
|
|
751072bdb0 | ||
|
|
e97e1f10db | ||
|
|
0bd2a0656c | ||
|
|
71a2b20301 | ||
|
|
8df7de94e3 | ||
|
|
bf21203297 | ||
|
|
ae98375194 | ||
|
|
82d1ccf376 | ||
|
|
bb6d49c17e | ||
|
|
ed735ec47c | ||
|
|
ba4bac3a03 | ||
|
|
08433d7e04 | ||
|
|
a3b525b50d | ||
|
|
097f6886f2 | ||
|
|
07a1549634 | ||
|
|
befca26c58 | ||
|
|
3556a2eef4 | ||
|
|
807765f77e | ||
|
|
e44584e549 | ||
|
|
acd48f0abb | ||
|
|
f919bc6713 | ||
|
|
a0030b8f45 | ||
|
|
a5f0cce1b1 | ||
|
|
4d13dda605 | ||
|
|
b56cc8e459 | ||
|
|
c435811479 | ||
|
|
c686c93fb5 | ||
|
|
da8f76e6bd | ||
|
|
99596a4149 | ||
|
|
ec2a9f2c57 | ||
|
|
fd73ced6dc | ||
|
|
5071736ab4 |
67
.github/workflows/codeql-analysis.yml
vendored
Normal file
67
.github/workflows/codeql-analysis.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# For most projects, this workflow file will not need changing; you simply need
|
||||||
|
# to commit it to your repository.
|
||||||
|
#
|
||||||
|
# You may wish to alter this file to override the set of languages analyzed,
|
||||||
|
# or to provide custom queries or build logic.
|
||||||
|
#
|
||||||
|
# ******** NOTE ********
|
||||||
|
# We have attempted to detect the languages in your repository. Please check
|
||||||
|
# the `language` matrix defined below to confirm you have the correct set of
|
||||||
|
# supported CodeQL languages.
|
||||||
|
#
|
||||||
|
name: "CodeQL"
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ master ]
|
||||||
|
pull_request:
|
||||||
|
# The branches below must be a subset of the branches above
|
||||||
|
branches: [ master ]
|
||||||
|
schedule:
|
||||||
|
- cron: '18 19 * * 6'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
analyze:
|
||||||
|
name: Analyze
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
language: [ 'go' ]
|
||||||
|
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
|
||||||
|
# Learn more:
|
||||||
|
# https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
# Initializes the CodeQL tools for scanning.
|
||||||
|
- name: Initialize CodeQL
|
||||||
|
uses: github/codeql-action/init@v1
|
||||||
|
with:
|
||||||
|
languages: ${{ matrix.language }}
|
||||||
|
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||||
|
# By default, queries listed here will override any specified in a config file.
|
||||||
|
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||||
|
# queries: ./path/to/local/query, your-org/your-repo/queries@main
|
||||||
|
|
||||||
|
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||||
|
# If this step fails, then you should remove it and run the build manually (see below)
|
||||||
|
- name: Autobuild
|
||||||
|
uses: github/codeql-action/autobuild@v1
|
||||||
|
|
||||||
|
# ℹ️ Command-line programs to run using the OS shell.
|
||||||
|
# 📚 https://git.io/JvXDl
|
||||||
|
|
||||||
|
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
|
||||||
|
# and modify them (or add more) to build your code if your project
|
||||||
|
# uses a compiled language
|
||||||
|
|
||||||
|
#- run: |
|
||||||
|
# make bootstrap
|
||||||
|
# make release
|
||||||
|
|
||||||
|
- name: Perform CodeQL Analysis
|
||||||
|
uses: github/codeql-action/analyze@v1
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,6 +4,7 @@
|
|||||||
# Unignore all with extensions
|
# Unignore all with extensions
|
||||||
!*.*
|
!*.*
|
||||||
!**/Dockerfile
|
!**/Dockerfile
|
||||||
|
!**/Makefile
|
||||||
|
|
||||||
# Unignore all dirs
|
# Unignore all dirs
|
||||||
!*/
|
!*/
|
||||||
@@ -12,7 +13,6 @@
|
|||||||
.idea
|
.idea
|
||||||
**/.DS_Store
|
**/.DS_Store
|
||||||
**/logs
|
**/logs
|
||||||
!Makefile
|
|
||||||
|
|
||||||
# gitlab ci
|
# gitlab ci
|
||||||
.cache
|
.cache
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
// RollingWindowOption let callers customize the RollingWindow.
|
||||||
RollingWindowOption func(rollingWindow *RollingWindow)
|
RollingWindowOption func(rollingWindow *RollingWindow)
|
||||||
|
|
||||||
|
// RollingWindow defines a rolling window to calculate the events in buckets with time interval.
|
||||||
RollingWindow struct {
|
RollingWindow struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
size int
|
size int
|
||||||
@@ -17,10 +19,12 @@ type (
|
|||||||
interval time.Duration
|
interval time.Duration
|
||||||
offset int
|
offset int
|
||||||
ignoreCurrent bool
|
ignoreCurrent bool
|
||||||
lastTime time.Duration
|
lastTime time.Duration // start time of the last bucket
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NewRollingWindow returns a RollingWindow that with size buckets and time interval,
|
||||||
|
// use opts to customize the RollingWindow.
|
||||||
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
|
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
|
||||||
if size < 1 {
|
if size < 1 {
|
||||||
panic("size must be greater than 0")
|
panic("size must be greater than 0")
|
||||||
@@ -38,6 +42,7 @@ func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOpt
|
|||||||
return w
|
return w
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add adds value to current bucket.
|
||||||
func (rw *RollingWindow) Add(v float64) {
|
func (rw *RollingWindow) Add(v float64) {
|
||||||
rw.lock.Lock()
|
rw.lock.Lock()
|
||||||
defer rw.lock.Unlock()
|
defer rw.lock.Unlock()
|
||||||
@@ -45,6 +50,7 @@ func (rw *RollingWindow) Add(v float64) {
|
|||||||
rw.win.add(rw.offset, v)
|
rw.win.add(rw.offset, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
|
||||||
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
|
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
|
||||||
rw.lock.RLock()
|
rw.lock.RLock()
|
||||||
defer rw.lock.RUnlock()
|
defer rw.lock.RUnlock()
|
||||||
@@ -79,26 +85,18 @@ func (rw *RollingWindow) updateOffset() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
offset := rw.offset
|
offset := rw.offset
|
||||||
start := offset + 1
|
|
||||||
steps := start + span
|
|
||||||
var remainder int
|
|
||||||
if steps > rw.size {
|
|
||||||
remainder = steps - rw.size
|
|
||||||
steps = rw.size
|
|
||||||
}
|
|
||||||
|
|
||||||
// reset expired buckets
|
// reset expired buckets
|
||||||
for i := start; i < steps; i++ {
|
for i := 0; i < span; i++ {
|
||||||
rw.win.resetBucket(i)
|
rw.win.resetBucket((offset + i + 1) % rw.size)
|
||||||
}
|
|
||||||
for i := 0; i < remainder; i++ {
|
|
||||||
rw.win.resetBucket(i)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rw.offset = (offset + span) % rw.size
|
rw.offset = (offset + span) % rw.size
|
||||||
rw.lastTime = timex.Now()
|
now := timex.Now()
|
||||||
|
// align to interval time boundary
|
||||||
|
rw.lastTime = now - (now-rw.lastTime)%rw.interval
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bucket defines the bucket that holds sum and num of additions.
|
||||||
type Bucket struct {
|
type Bucket struct {
|
||||||
Sum float64
|
Sum float64
|
||||||
Count int64
|
Count int64
|
||||||
@@ -144,6 +142,7 @@ func (w *window) resetBucket(offset int) {
|
|||||||
w.buckets[offset%w.size].reset()
|
w.buckets[offset%w.size].reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
|
||||||
func IgnoreCurrentBucket() RollingWindowOption {
|
func IgnoreCurrentBucket() RollingWindowOption {
|
||||||
return func(w *RollingWindow) {
|
return func(w *RollingWindow) {
|
||||||
w.ignoreCurrent = true
|
w.ignoreCurrent = true
|
||||||
|
|||||||
@@ -105,6 +105,37 @@ func TestRollingWindowReduce(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRollingWindowBucketTimeBoundary(t *testing.T) {
|
||||||
|
const size = 3
|
||||||
|
interval := time.Millisecond * 30
|
||||||
|
r := NewRollingWindow(size, interval)
|
||||||
|
listBuckets := func() []float64 {
|
||||||
|
var buckets []float64
|
||||||
|
r.Reduce(func(b *Bucket) {
|
||||||
|
buckets = append(buckets, b.Sum)
|
||||||
|
})
|
||||||
|
return buckets
|
||||||
|
}
|
||||||
|
assert.Equal(t, []float64{0, 0, 0}, listBuckets())
|
||||||
|
r.Add(1)
|
||||||
|
assert.Equal(t, []float64{0, 0, 1}, listBuckets())
|
||||||
|
time.Sleep(time.Millisecond * 45)
|
||||||
|
r.Add(2)
|
||||||
|
r.Add(3)
|
||||||
|
assert.Equal(t, []float64{0, 1, 5}, listBuckets())
|
||||||
|
// sleep time should be less than interval, and make the bucket change happen
|
||||||
|
time.Sleep(time.Millisecond * 20)
|
||||||
|
r.Add(4)
|
||||||
|
r.Add(5)
|
||||||
|
r.Add(6)
|
||||||
|
assert.Equal(t, []float64{1, 5, 15}, listBuckets())
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
r.Add(7)
|
||||||
|
r.Add(8)
|
||||||
|
r.Add(9)
|
||||||
|
assert.Equal(t, []float64{0, 0, 24}, listBuckets())
|
||||||
|
}
|
||||||
|
|
||||||
func TestRollingWindowDataRace(t *testing.T) {
|
func TestRollingWindowDataRace(t *testing.T) {
|
||||||
const size = 3
|
const size = 3
|
||||||
r := NewRollingWindow(size, duration)
|
r := NewRollingWindow(size, duration)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/tal-tech/go-zero/core/mapping"
|
"github.com/tal-tech/go-zero/core/mapping"
|
||||||
@@ -19,7 +20,7 @@ func LoadConfig(file string, v interface{}) error {
|
|||||||
if content, err := ioutil.ReadFile(file); err != nil {
|
if content, err := ioutil.ReadFile(file); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if loader, ok := loaders[path.Ext(file)]; ok {
|
} else if loader, ok := loaders[path.Ext(file)]; ok {
|
||||||
return loader(content, v)
|
return loader([]byte(os.ExpandEnv(string(content))), v)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("unrecoginized file type: %s", file)
|
return fmt.Errorf("unrecoginized file type: %s", file)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,13 +17,14 @@ func TestConfigJson(t *testing.T) {
|
|||||||
}
|
}
|
||||||
text := `{
|
text := `{
|
||||||
"a": "foo",
|
"a": "foo",
|
||||||
"b": 1
|
"b": 1,
|
||||||
|
"c": "${FOO}"
|
||||||
}`
|
}`
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test, func(t *testing.T) {
|
t.Run(test, func(t *testing.T) {
|
||||||
t.Parallel()
|
os.Setenv("FOO", "2")
|
||||||
|
defer os.Unsetenv("FOO")
|
||||||
tmpfile, err := createTempFile(test, text)
|
tmpfile, err := createTempFile(test, text)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpfile)
|
||||||
@@ -31,10 +32,12 @@ func TestConfigJson(t *testing.T) {
|
|||||||
var val struct {
|
var val struct {
|
||||||
A string `json:"a"`
|
A string `json:"a"`
|
||||||
B int `json:"b"`
|
B int `json:"b"`
|
||||||
|
C string `json:"c"`
|
||||||
}
|
}
|
||||||
MustLoad(tmpfile, &val)
|
MustLoad(tmpfile, &val)
|
||||||
assert.Equal(t, "foo", val.A)
|
assert.Equal(t, "foo", val.A)
|
||||||
assert.Equal(t, 1, val.B)
|
assert.Equal(t, 1, val.B)
|
||||||
|
assert.Equal(t, "2", val.C)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package executors
|
|||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tal-tech/go-zero/core/lang"
|
"github.com/tal-tech/go-zero/core/lang"
|
||||||
@@ -35,6 +36,7 @@ type (
|
|||||||
// avoid race condition on waitGroup when calling wg.Add/Done/Wait(...)
|
// avoid race condition on waitGroup when calling wg.Add/Done/Wait(...)
|
||||||
wgBarrier syncx.Barrier
|
wgBarrier syncx.Barrier
|
||||||
confirmChan chan lang.PlaceholderType
|
confirmChan chan lang.PlaceholderType
|
||||||
|
inflight int32
|
||||||
guarded bool
|
guarded bool
|
||||||
newTicker func(duration time.Duration) timex.Ticker
|
newTicker func(duration time.Duration) timex.Ticker
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
@@ -91,18 +93,16 @@ func (pe *PeriodicalExecutor) Wait() {
|
|||||||
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
|
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
|
||||||
pe.lock.Lock()
|
pe.lock.Lock()
|
||||||
defer func() {
|
defer func() {
|
||||||
var start bool
|
|
||||||
if !pe.guarded {
|
if !pe.guarded {
|
||||||
pe.guarded = true
|
pe.guarded = true
|
||||||
start = true
|
// defer to unlock quickly
|
||||||
|
defer pe.backgroundFlush()
|
||||||
}
|
}
|
||||||
pe.lock.Unlock()
|
pe.lock.Unlock()
|
||||||
if start {
|
|
||||||
pe.backgroundFlush()
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if pe.container.AddTask(task) {
|
if pe.container.AddTask(task) {
|
||||||
|
atomic.AddInt32(&pe.inflight, 1)
|
||||||
return pe.container.RemoveAll(), true
|
return pe.container.RemoveAll(), true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,6 +111,9 @@ func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool)
|
|||||||
|
|
||||||
func (pe *PeriodicalExecutor) backgroundFlush() {
|
func (pe *PeriodicalExecutor) backgroundFlush() {
|
||||||
threading.GoSafe(func() {
|
threading.GoSafe(func() {
|
||||||
|
// flush before quit goroutine to avoid missing tasks
|
||||||
|
defer pe.Flush()
|
||||||
|
|
||||||
ticker := pe.newTicker(pe.interval)
|
ticker := pe.newTicker(pe.interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -120,6 +123,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
|
|||||||
select {
|
select {
|
||||||
case vals := <-pe.commander:
|
case vals := <-pe.commander:
|
||||||
commanded = true
|
commanded = true
|
||||||
|
atomic.AddInt32(&pe.inflight, -1)
|
||||||
pe.enterExecution()
|
pe.enterExecution()
|
||||||
pe.confirmChan <- lang.Placeholder
|
pe.confirmChan <- lang.Placeholder
|
||||||
pe.executeTasks(vals)
|
pe.executeTasks(vals)
|
||||||
@@ -129,13 +133,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
|
|||||||
commanded = false
|
commanded = false
|
||||||
} else if pe.Flush() {
|
} else if pe.Flush() {
|
||||||
last = timex.Now()
|
last = timex.Now()
|
||||||
} else if timex.Since(last) > pe.interval*idleRound {
|
} else if pe.shallQuit(last) {
|
||||||
pe.lock.Lock()
|
|
||||||
pe.guarded = false
|
|
||||||
pe.lock.Unlock()
|
|
||||||
|
|
||||||
// flush again to avoid missing tasks
|
|
||||||
pe.Flush()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -178,3 +176,19 @@ func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pe *PeriodicalExecutor) shallQuit(last time.Duration) (stop bool) {
|
||||||
|
if timex.Since(last) <= pe.interval*idleRound {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// checking pe.inflight and setting pe.guarded should be locked together
|
||||||
|
pe.lock.Lock()
|
||||||
|
if atomic.LoadInt32(&pe.inflight) == 0 {
|
||||||
|
pe.guarded = false
|
||||||
|
stop = true
|
||||||
|
}
|
||||||
|
pe.lock.Unlock()
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -140,6 +140,26 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
|
|||||||
assert.Equal(t, total, cnt)
|
assert.Equal(t, total, cnt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPeriodicalExecutor_Deadlock(t *testing.T) {
|
||||||
|
executor := NewBulkExecutor(func(tasks []interface{}) {
|
||||||
|
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
|
||||||
|
for i := 0; i < 1e5; i++ {
|
||||||
|
executor.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
|
||||||
|
exec.newTicker = func(d time.Duration) timex.Ticker {
|
||||||
|
return ticker
|
||||||
|
}
|
||||||
|
assert.False(t, exec.hasTasks(nil))
|
||||||
|
assert.True(t, exec.hasTasks(1))
|
||||||
|
}
|
||||||
|
|
||||||
// go test -benchtime 10s -bench .
|
// go test -benchtime 10s -bench .
|
||||||
func BenchmarkExecutor(b *testing.B) {
|
func BenchmarkExecutor(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ var mock tracespec.Trace = new(mockTrace)
|
|||||||
|
|
||||||
func TestTraceLog(t *testing.T) {
|
func TestTraceLog(t *testing.T) {
|
||||||
var buf mockWriter
|
var buf mockWriter
|
||||||
|
atomic.StoreUint32(&initialized, 1)
|
||||||
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
|
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
|
||||||
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
|
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
|
||||||
assert.True(t, strings.Contains(buf.String(), mockTraceId))
|
assert.True(t, strings.Contains(buf.String(), mockTraceId))
|
||||||
|
|||||||
@@ -153,58 +153,57 @@ func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fie
|
|||||||
key := strings.TrimSpace(segments[0])
|
key := strings.TrimSpace(segments[0])
|
||||||
options := segments[1:]
|
options := segments[1:]
|
||||||
|
|
||||||
if len(options) > 0 {
|
if len(options) == 0 {
|
||||||
var fieldOpts fieldOptions
|
return key, nil, nil
|
||||||
|
|
||||||
for _, segment := range options {
|
|
||||||
option := strings.TrimSpace(segment)
|
|
||||||
switch {
|
|
||||||
case option == stringOption:
|
|
||||||
fieldOpts.FromString = true
|
|
||||||
case strings.HasPrefix(option, optionalOption):
|
|
||||||
segs := strings.Split(option, equalToken)
|
|
||||||
switch len(segs) {
|
|
||||||
case 1:
|
|
||||||
fieldOpts.Optional = true
|
|
||||||
case 2:
|
|
||||||
fieldOpts.Optional = true
|
|
||||||
fieldOpts.OptionalDep = segs[1]
|
|
||||||
default:
|
|
||||||
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
|
|
||||||
}
|
|
||||||
case option == optionalOption:
|
|
||||||
fieldOpts.Optional = true
|
|
||||||
case strings.HasPrefix(option, optionsOption):
|
|
||||||
segs := strings.Split(option, equalToken)
|
|
||||||
if len(segs) != 2 {
|
|
||||||
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
|
|
||||||
} else {
|
|
||||||
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
|
|
||||||
}
|
|
||||||
case strings.HasPrefix(option, defaultOption):
|
|
||||||
segs := strings.Split(option, equalToken)
|
|
||||||
if len(segs) != 2 {
|
|
||||||
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
|
|
||||||
} else {
|
|
||||||
fieldOpts.Default = strings.TrimSpace(segs[1])
|
|
||||||
}
|
|
||||||
case strings.HasPrefix(option, rangeOption):
|
|
||||||
segs := strings.Split(option, equalToken)
|
|
||||||
if len(segs) != 2 {
|
|
||||||
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
|
|
||||||
}
|
|
||||||
if nr, err := parseNumberRange(segs[1]); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
} else {
|
|
||||||
fieldOpts.Range = nr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return key, &fieldOpts, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return key, nil, nil
|
var fieldOpts fieldOptions
|
||||||
|
for _, segment := range options {
|
||||||
|
option := strings.TrimSpace(segment)
|
||||||
|
switch {
|
||||||
|
case option == stringOption:
|
||||||
|
fieldOpts.FromString = true
|
||||||
|
case strings.HasPrefix(option, optionalOption):
|
||||||
|
segs := strings.Split(option, equalToken)
|
||||||
|
switch len(segs) {
|
||||||
|
case 1:
|
||||||
|
fieldOpts.Optional = true
|
||||||
|
case 2:
|
||||||
|
fieldOpts.Optional = true
|
||||||
|
fieldOpts.OptionalDep = segs[1]
|
||||||
|
default:
|
||||||
|
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
|
||||||
|
}
|
||||||
|
case option == optionalOption:
|
||||||
|
fieldOpts.Optional = true
|
||||||
|
case strings.HasPrefix(option, optionsOption):
|
||||||
|
segs := strings.Split(option, equalToken)
|
||||||
|
if len(segs) != 2 {
|
||||||
|
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
|
||||||
|
} else {
|
||||||
|
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(option, defaultOption):
|
||||||
|
segs := strings.Split(option, equalToken)
|
||||||
|
if len(segs) != 2 {
|
||||||
|
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
|
||||||
|
} else {
|
||||||
|
fieldOpts.Default = strings.TrimSpace(segs[1])
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(option, rangeOption):
|
||||||
|
segs := strings.Split(option, equalToken)
|
||||||
|
if len(segs) != 2 {
|
||||||
|
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
|
||||||
|
}
|
||||||
|
if nr, err := parseNumberRange(segs[1]); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
} else {
|
||||||
|
fieldOpts.Range = nr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, &fieldOpts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
|
func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
|
||||||
|
|||||||
16
core/prof/profilecenter_test.go
Normal file
16
core/prof/profilecenter_test.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package prof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReport(t *testing.T) {
|
||||||
|
once.Do(func() {})
|
||||||
|
assert.NotContains(t, generateReport(), "foo")
|
||||||
|
report("foo", time.Second)
|
||||||
|
assert.Contains(t, generateReport(), "foo")
|
||||||
|
report("foo", time.Second)
|
||||||
|
}
|
||||||
23
core/prof/profiler_test.go
Normal file
23
core/prof/profiler_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package prof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProfiler(t *testing.T) {
|
||||||
|
EnableProfiling()
|
||||||
|
Start()
|
||||||
|
Report("foo", ProfilePoint{
|
||||||
|
ElapsedTimer: utils.NewElapsedTimer(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNullProfiler(t *testing.T) {
|
||||||
|
p := newNullProfiler()
|
||||||
|
p.Start()
|
||||||
|
p.Report("foo", ProfilePoint{
|
||||||
|
ElapsedTimer: utils.NewElapsedTimer(),
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -70,8 +70,6 @@ func (g *sharedGroup) createCall(key string) (c *call, done bool) {
|
|||||||
|
|
||||||
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
|
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
|
||||||
defer func() {
|
defer func() {
|
||||||
// delete key first, done later. can't reverse the order, because if reverse,
|
|
||||||
// another Do call might wg.Wait() without get notified with wg.Done()
|
|
||||||
g.lock.Lock()
|
g.lock.Lock()
|
||||||
delete(g.calls, key)
|
delete(g.calls, key)
|
||||||
g.lock.Unlock()
|
g.lock.Unlock()
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ go get -u github.com/tal-tech/go-zero
|
|||||||
the .api files also can be generate by goctl, like below:
|
the .api files also can be generate by goctl, like below:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
goctl api -o greet.api
|
goctl api -o greet.api
|
||||||
```
|
```
|
||||||
|
|
||||||
3. generate the go server side code
|
3. generate the go server side code
|
||||||
@@ -208,3 +208,7 @@ goctl api -o greet.api
|
|||||||
|
|
||||||
* [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md)
|
* [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md)
|
||||||
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md)
|
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md)
|
||||||
|
|
||||||
|
## 9. Chat group
|
||||||
|
|
||||||
|
Join the chat via https://discord.gg/4JQvC5A4Fe
|
||||||
|
|||||||
15
readme.md
15
readme.md
@@ -5,8 +5,9 @@
|
|||||||
[English](readme-en.md) | 简体中文
|
[English](readme-en.md) | 简体中文
|
||||||
|
|
||||||
[](https://github.com/tal-tech/go-zero/actions)
|
[](https://github.com/tal-tech/go-zero/actions)
|
||||||
[](https://codecov.io/gh/tal-tech/go-zero)
|
|
||||||
[](https://goreportcard.com/report/github.com/tal-tech/go-zero)
|
[](https://goreportcard.com/report/github.com/tal-tech/go-zero)
|
||||||
|
[](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)
|
||||||
|
[](https://codecov.io/gh/tal-tech/go-zero)
|
||||||
[](https://github.com/tal-tech/go-zero)
|
[](https://github.com/tal-tech/go-zero)
|
||||||
[](https://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
@@ -95,7 +96,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
|||||||
|
|
||||||
[快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
[快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
||||||
|
|
||||||
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
|
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
|
||||||
|
|
||||||
1. 安装 goctl 工具
|
1. 安装 goctl 工具
|
||||||
|
|
||||||
@@ -162,7 +163,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
|||||||
|
|
||||||
* awesome 系列
|
* awesome 系列
|
||||||
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
||||||
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
|
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
|
||||||
* [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md)
|
* [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md)
|
||||||
* [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md)
|
* [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md)
|
||||||
* [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md)
|
* [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md)
|
||||||
@@ -172,7 +173,13 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
|||||||
* [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md)
|
* [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md)
|
||||||
* [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md)
|
* [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md)
|
||||||
|
|
||||||
## 8. 微信交流群
|
## 8. 微信公众号
|
||||||
|
|
||||||
|
`go-zero` 相关文章都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注,也可以通过公众号私信我 👏
|
||||||
|
|
||||||
|
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat-micro.jpg" alt="wechat" width="300" />
|
||||||
|
|
||||||
|
## 9. 微信交流群
|
||||||
|
|
||||||
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
|
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
|
||||||
|
|
||||||
|
|||||||
171
rest/engine_test.go
Normal file
171
rest/engine_test.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tal-tech/go-zero/core/conf"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewEngine(t *testing.T) {
|
||||||
|
yamls := []string{
|
||||||
|
`Name: foo
|
||||||
|
Port: 54321
|
||||||
|
`,
|
||||||
|
`Name: foo
|
||||||
|
Port: 54321
|
||||||
|
CpuThreshold: 500
|
||||||
|
`,
|
||||||
|
`Name: foo
|
||||||
|
Port: 54321
|
||||||
|
CpuThreshold: 500
|
||||||
|
Verbose: true
|
||||||
|
`,
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := []featuredRoutes{
|
||||||
|
{
|
||||||
|
jwt: jwtSetting{},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
prevSecret: "thesecret",
|
||||||
|
},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{
|
||||||
|
enabled: true,
|
||||||
|
SignatureConf: SignatureConf{
|
||||||
|
Strict: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{
|
||||||
|
enabled: true,
|
||||||
|
SignatureConf: SignatureConf{
|
||||||
|
Strict: true,
|
||||||
|
PrivateKeys: []PrivateKeyConf{
|
||||||
|
{
|
||||||
|
Fingerprint: "a",
|
||||||
|
KeyFile: "b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, yaml := range yamls {
|
||||||
|
for _, route := range routes {
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
|
||||||
|
ng := newEngine(cnf)
|
||||||
|
ng.AddRoutes(route)
|
||||||
|
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
assert.NotNil(t, ng.StartWithRouter(mockedRouter{}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockedRouter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedRouter) Handle(method string, path string, handler http.Handler) error {
|
||||||
|
return errors.New("foo")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedRouter) SetNotFoundHandler(handler http.Handler) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedRouter) SetNotAllowedHandler(handler http.Handler) {
|
||||||
|
}
|
||||||
@@ -46,18 +46,18 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
|
|||||||
parser := token.NewTokenParser()
|
parser := token.NewTokenParser()
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
|
tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unauthorized(w, r, err, authOpts.Callback)
|
unauthorized(w, r, err, authOpts.Callback)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !token.Valid {
|
if !tok.Valid {
|
||||||
unauthorized(w, r, errInvalidToken, authOpts.Callback)
|
unauthorized(w, r, errInvalidToken, authOpts.Callback)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
claims, ok := tok.Claims.(jwt.MapClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
unauthorized(w, r, errNoClaims, authOpts.Callback)
|
unauthorized(w, r, errNoClaims, authOpts.Callback)
|
||||||
return
|
return
|
||||||
@@ -122,6 +122,12 @@ func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (grw *guardedResponseWriter) Flush() {
|
||||||
|
if flusher, ok := grw.writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (grw *guardedResponseWriter) Header() http.Header {
|
func (grw *guardedResponseWriter) Header() http.Header {
|
||||||
return grw.writer.Header()
|
return grw.writer.Header()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ func TestAuthHandler(t *testing.T) {
|
|||||||
w.Header().Set("X-Test", "test")
|
w.Header().Set("X-Test", "test")
|
||||||
_, err := w.Write([]byte("content"))
|
_, err := w.Write([]byte("content"))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
assert.True(t, ok)
|
||||||
|
flusher.Flush()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -83,6 +83,12 @@ func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *cryptionResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *cryptionResponseWriter) Header() http.Header {
|
func (w *cryptionResponseWriter) Header() http.Header {
|
||||||
return w.ResponseWriter.Header()
|
return w.ResponseWriter.Header()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,3 +87,19 @@ func TestCryptionHandlerWriteHeader(t *testing.T) {
|
|||||||
handler.ServeHTTP(recorder, req)
|
handler.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCryptionHandlerFlush(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||||
|
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(respText))
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
assert.True(t, ok)
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,6 +38,12 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
|
|||||||
w.code = code
|
w.code = code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *LoggedResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.w.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func LogHandler(next http.Handler) http.Handler {
|
func LogHandler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
timer := utils.NewElapsedTimer()
|
timer := utils.NewElapsedTimer()
|
||||||
@@ -68,6 +74,10 @@ func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buff
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *DetailLoggedResponseWriter) Flush() {
|
||||||
|
w.writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
func (w *DetailLoggedResponseWriter) Header() http.Header {
|
func (w *DetailLoggedResponseWriter) Header() http.Header {
|
||||||
return w.writer.Header()
|
return w.writer.Header()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ func TestLogHandler(t *testing.T) {
|
|||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
_, err := w.Write([]byte("content"))
|
_, err := w.Write([]byte("content"))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
assert.True(t, ok)
|
||||||
|
flusher.Flush()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -7,6 +7,12 @@ type WithCodeResponseWriter struct {
|
|||||||
Code int
|
Code int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WithCodeResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.Writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WithCodeResponseWriter) Header() http.Header {
|
func (w *WithCodeResponseWriter) Header() http.Header {
|
||||||
return w.Writer.Header()
|
return w.Writer.Header()
|
||||||
}
|
}
|
||||||
|
|||||||
33
rest/internal/security/withcoderesponsewriter_test.go
Normal file
33
rest/internal/security/withcoderesponsewriter_test.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithCodeResponseWriter(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cw := &WithCodeResponseWriter{Writer: w}
|
||||||
|
|
||||||
|
cw.Header().Set("X-Test", "test")
|
||||||
|
cw.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
assert.Equal(t, cw.Code, http.StatusServiceUnavailable)
|
||||||
|
|
||||||
|
_, err := cw.Write([]byte("content"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
flusher, ok := http.ResponseWriter(cw).(http.Flusher)
|
||||||
|
assert.True(t, ok)
|
||||||
|
flusher.Flush()
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(resp, req)
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||||
|
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||||
|
assert.Equal(t, "content", resp.Body.String())
|
||||||
|
}
|
||||||
@@ -64,7 +64,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
allow, ok := pr.methodNotAllowed(r.Method, reqPath)
|
allows, ok := pr.methodsAllowed(r.Method, reqPath)
|
||||||
if !ok {
|
if !ok {
|
||||||
pr.handleNotFound(w, r)
|
pr.handleNotFound(w, r)
|
||||||
return
|
return
|
||||||
@@ -73,7 +73,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if pr.notAllowed != nil {
|
if pr.notAllowed != nil {
|
||||||
pr.notAllowed.ServeHTTP(w, r)
|
pr.notAllowed.ServeHTTP(w, r)
|
||||||
} else {
|
} else {
|
||||||
w.Header().Set(allowHeader, allow)
|
w.Header().Set(allowHeader, allows)
|
||||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -94,7 +94,7 @@ func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pr *patRouter) methodNotAllowed(method, path string) (string, bool) {
|
func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
|
||||||
var allows []string
|
var allows []string
|
||||||
|
|
||||||
for treeMethod, tree := range pr.trees {
|
for treeMethod, tree := range pr.trees {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package rest
|
package rest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@@ -24,6 +23,9 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MustNewServer returns a server with given config of c and options defined in opts.
|
||||||
|
// Be aware that later RunOption might overwrite previous one that write the same option.
|
||||||
|
// The process will exit if error occurs.
|
||||||
func MustNewServer(c RestConf, opts ...RunOption) *Server {
|
func MustNewServer(c RestConf, opts ...RunOption) *Server {
|
||||||
engine, err := NewServer(c, opts...)
|
engine, err := NewServer(c, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -33,11 +35,9 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
|
|||||||
return engine
|
return engine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewServer returns a server with given config of c and options defined in opts.
|
||||||
|
// Be aware that later RunOption might overwrite previous one that write the same option.
|
||||||
func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
||||||
if len(opts) > 1 {
|
|
||||||
return nil, errors.New("only one RunOption is allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.SetUp(); err != nil {
|
if err := c.SetUp(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,18 +8,84 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tal-tech/go-zero/core/conf"
|
||||||
"github.com/tal-tech/go-zero/rest/httpx"
|
"github.com/tal-tech/go-zero/rest/httpx"
|
||||||
"github.com/tal-tech/go-zero/rest/router"
|
"github.com/tal-tech/go-zero/rest/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewServer(t *testing.T) {
|
func TestNewServer(t *testing.T) {
|
||||||
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
|
const configYaml = `
|
||||||
assert.NotNil(t, err)
|
Name: foo
|
||||||
|
Port: 54321
|
||||||
|
`
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
|
failStart := func(server *Server) {
|
||||||
|
server.opts.start = func(e *engine) error {
|
||||||
|
return http.ErrServerClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
c RestConf
|
||||||
|
opts []RunOption
|
||||||
|
fail bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
c: RestConf{},
|
||||||
|
opts: []RunOption{failStart},
|
||||||
|
fail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
c: cnf,
|
||||||
|
opts: []RunOption{failStart},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
c: cnf,
|
||||||
|
opts: []RunOption{WithNotAllowedHandler(nil), failStart},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
c: cnf,
|
||||||
|
opts: []RunOption{WithNotFoundHandler(nil), failStart},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
c: cnf,
|
||||||
|
opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
c: cnf,
|
||||||
|
opts: []RunOption{WithUnsignedCallback(nil), failStart},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
srv, err := NewServer(test.c, test.opts...)
|
||||||
|
if test.fail {
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
srv.AddRoute(Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: nil,
|
||||||
|
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
|
||||||
|
WithJwtTransition("preivous", "thenewone"))
|
||||||
|
srv.Start()
|
||||||
|
srv.Stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWithMiddleware(t *testing.T) {
|
func TestWithMiddleware(t *testing.T) {
|
||||||
m := make(map[string]string)
|
m := make(map[string]string)
|
||||||
router := router.NewRouter()
|
rt := router.NewRouter()
|
||||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
var v struct {
|
var v struct {
|
||||||
Nickname string `form:"nickname"`
|
Nickname string `form:"nickname"`
|
||||||
@@ -56,14 +122,14 @@ func TestWithMiddleware(t *testing.T) {
|
|||||||
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
|
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
|
||||||
}
|
}
|
||||||
for _, route := range rs {
|
for _, route := range rs {
|
||||||
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
|
assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
|
||||||
}
|
}
|
||||||
for _, url := range urls {
|
for _, url := range urls {
|
||||||
r, err := http.NewRequest(http.MethodGet, url, nil)
|
r, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
router.ServeHTTP(rr, r)
|
rt.ServeHTTP(rr, r)
|
||||||
|
|
||||||
assert.Equal(t, "whatever:200000", rr.Body.String())
|
assert.Equal(t, "whatever:200000", rr.Body.String())
|
||||||
}
|
}
|
||||||
@@ -76,7 +142,7 @@ func TestWithMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
func TestMultiMiddlewares(t *testing.T) {
|
func TestMultiMiddlewares(t *testing.T) {
|
||||||
m := make(map[string]string)
|
m := make(map[string]string)
|
||||||
router := router.NewRouter()
|
rt := router.NewRouter()
|
||||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
var v struct {
|
var v struct {
|
||||||
Nickname string `form:"nickname"`
|
Nickname string `form:"nickname"`
|
||||||
@@ -127,14 +193,14 @@ func TestMultiMiddlewares(t *testing.T) {
|
|||||||
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
|
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
|
||||||
}
|
}
|
||||||
for _, route := range rs {
|
for _, route := range rs {
|
||||||
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
|
assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
|
||||||
}
|
}
|
||||||
for _, url := range urls {
|
for _, url := range urls {
|
||||||
r, err := http.NewRequest(http.MethodGet, url, nil)
|
r, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
router.ServeHTTP(rr, r)
|
rt.ServeHTTP(rr, r)
|
||||||
|
|
||||||
assert.Equal(t, "whatever:200000200000", rr.Body.String())
|
assert.Equal(t, "whatever:200000200000", rr.Body.String())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"go/format"
|
||||||
"go/scanner"
|
"go/scanner"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
"github.com/tal-tech/go-zero/core/errorx"
|
"github.com/tal-tech/go-zero/core/errorx"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/api/parser"
|
"github.com/tal-tech/go-zero/tools/goctl/api/parser"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/api/util"
|
"github.com/tal-tech/go-zero/tools/goctl/api/util"
|
||||||
|
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
"github.com/urfave/cli"
|
"github.com/urfave/cli"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -103,24 +105,108 @@ func apiFormat(data string) (string, error) {
|
|||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
s := bufio.NewScanner(strings.NewReader(data))
|
s := bufio.NewScanner(strings.NewReader(data))
|
||||||
var tapCount = 0
|
var tapCount = 0
|
||||||
|
var newLineCount = 0
|
||||||
|
var preLine string
|
||||||
for s.Scan() {
|
for s.Scan() {
|
||||||
line := strings.TrimSpace(s.Text())
|
line := strings.TrimSpace(s.Text())
|
||||||
|
if len(line) == 0 {
|
||||||
|
if newLineCount > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newLineCount++
|
||||||
|
} else {
|
||||||
|
if preLine == rightBrace {
|
||||||
|
builder.WriteString(ctlutil.NL)
|
||||||
|
}
|
||||||
|
newLineCount = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if tapCount == 0 {
|
||||||
|
format, err := formatGoTypeDef(line, s, &builder)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if format {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
noCommentLine := util.RemoveComment(line)
|
noCommentLine := util.RemoveComment(line)
|
||||||
if noCommentLine == rightParenthesis || noCommentLine == rightBrace {
|
if noCommentLine == rightParenthesis || noCommentLine == rightBrace {
|
||||||
tapCount -= 1
|
tapCount -= 1
|
||||||
}
|
}
|
||||||
if tapCount < 0 {
|
if tapCount < 0 {
|
||||||
line = strings.TrimSuffix(line, rightBrace)
|
line := strings.TrimSuffix(noCommentLine, rightBrace)
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
if strings.HasSuffix(line, leftBrace) {
|
if strings.HasSuffix(line, leftBrace) {
|
||||||
tapCount += 1
|
tapCount += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
util.WriteIndent(&builder, tapCount)
|
util.WriteIndent(&builder, tapCount)
|
||||||
builder.WriteString(line + "\n")
|
builder.WriteString(line + ctlutil.NL)
|
||||||
if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) {
|
if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) {
|
||||||
tapCount += 1
|
tapCount += 1
|
||||||
}
|
}
|
||||||
|
preLine = line
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(builder.String()), nil
|
return strings.TrimSpace(builder.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatGoTypeDef(line string, scanner *bufio.Scanner, builder *strings.Builder) (bool, error) {
|
||||||
|
noCommentLine := util.RemoveComment(line)
|
||||||
|
tokenCount := 0
|
||||||
|
if strings.HasPrefix(noCommentLine, "type") && (strings.HasSuffix(noCommentLine, leftParenthesis) ||
|
||||||
|
strings.HasSuffix(noCommentLine, leftBrace)) {
|
||||||
|
var typeBuilder strings.Builder
|
||||||
|
typeBuilder.WriteString(mayInsertStructKeyword(line, &tokenCount) + ctlutil.NL)
|
||||||
|
for scanner.Scan() {
|
||||||
|
noCommentLine := util.RemoveComment(scanner.Text())
|
||||||
|
typeBuilder.WriteString(mayInsertStructKeyword(scanner.Text(), &tokenCount) + ctlutil.NL)
|
||||||
|
if noCommentLine == rightBrace || noCommentLine == rightParenthesis {
|
||||||
|
tokenCount--
|
||||||
|
}
|
||||||
|
if tokenCount == 0 {
|
||||||
|
ts, err := format.Source([]byte(typeBuilder.String()))
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.New("error format \n" + typeBuilder.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
result := strings.ReplaceAll(string(ts), " struct ", " ")
|
||||||
|
result = strings.ReplaceAll(result, "type ()", "")
|
||||||
|
builder.WriteString(result)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mayInsertStructKeyword(line string, token *int) string {
|
||||||
|
insertStruct := func() string {
|
||||||
|
if strings.Contains(line, " struct") {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
index := strings.Index(line, leftBrace)
|
||||||
|
return line[:index] + " struct " + line[index:]
|
||||||
|
}
|
||||||
|
|
||||||
|
noCommentLine := util.RemoveComment(line)
|
||||||
|
if strings.HasSuffix(noCommentLine, leftBrace) {
|
||||||
|
*token++
|
||||||
|
return insertStruct()
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(noCommentLine, rightBrace) {
|
||||||
|
noCommentLine = strings.TrimSuffix(noCommentLine, rightBrace)
|
||||||
|
noCommentLine = util.RemoveComment(noCommentLine)
|
||||||
|
if strings.HasSuffix(noCommentLine, leftBrace) {
|
||||||
|
return insertStruct()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(noCommentLine, leftParenthesis) {
|
||||||
|
*token++
|
||||||
|
}
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,11 +24,11 @@ handler: GreetHandler
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
formattedStr = `type Request struct {
|
formattedStr = `type Request {
|
||||||
Name string
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
type Response {
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ service A-api {
|
|||||||
}`
|
}`
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestInlineTypeNotExist(t *testing.T) {
|
func TestFormat(t *testing.T) {
|
||||||
r, err := apiFormat(notFormattedStr)
|
r, err := apiFormat(notFormattedStr)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, r, formattedStr)
|
assert.Equal(t, r, formattedStr)
|
||||||
|
|||||||
@@ -38,11 +38,12 @@ func RevertTemplate(name string) error {
|
|||||||
return util.CreateTemplate(category, name, content)
|
return util.CreateTemplate(category, name, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Update(category string) error {
|
func Update() error {
|
||||||
err := Clean()
|
err := Clean()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.InitTemplates(category, templates)
|
return util.InitTemplates(category, templates)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,6 +51,6 @@ func Clean() error {
|
|||||||
return util.Clean(category)
|
return util.Clean(category)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCategory() string {
|
func Category() string {
|
||||||
return category
|
return category
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func TestUpdate(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, string(data), modifyData)
|
assert.Equal(t, string(data), modifyData)
|
||||||
|
|
||||||
assert.Nil(t, Update(category))
|
assert.Nil(t, Update())
|
||||||
|
|
||||||
data, err = ioutil.ReadFile(file)
|
data, err = ioutil.ReadFile(file)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package javagen
|
package javagen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"path"
|
"path"
|
||||||
@@ -17,6 +18,8 @@ const (
|
|||||||
package com.xhb.logic.http.packet.{{.packet}}.model;
|
package com.xhb.logic.http.packet.{{.packet}}.model;
|
||||||
|
|
||||||
import com.xhb.logic.http.DeProguardable;
|
import com.xhb.logic.http.DeProguardable;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.jetbrains.annotations.Nullable;
|
||||||
|
|
||||||
{{.componentType}}
|
{{.componentType}}
|
||||||
`
|
`
|
||||||
@@ -28,7 +31,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, ty := range types {
|
for _, ty := range types {
|
||||||
if err := createComponent(dir, packetName, ty); err != nil {
|
if err := createComponent(dir, packetName, ty, api.Types); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -36,7 +39,7 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createComponent(dir, packetName string, ty spec.Type) error {
|
func createComponent(dir, packetName string, ty spec.Type, types []spec.Type) error {
|
||||||
modelFile := util.Title(ty.Name) + ".java"
|
modelFile := util.Title(ty.Name) + ".java"
|
||||||
filename := path.Join(dir, modelDir, modelFile)
|
filename := path.Join(dir, modelDir, modelFile)
|
||||||
if err := util.RemoveOrQuit(filename); err != nil {
|
if err := util.RemoveOrQuit(filename); err != nil {
|
||||||
@@ -52,7 +55,7 @@ func createComponent(dir, packetName string, ty spec.Type) error {
|
|||||||
}
|
}
|
||||||
defer fp.Close()
|
defer fp.Close()
|
||||||
|
|
||||||
tys, err := buildType(ty)
|
tys, err := buildType(ty, types)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -64,22 +67,66 @@ func createComponent(dir, packetName string, ty spec.Type) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildType(ty spec.Type) (string, error) {
|
func buildType(ty spec.Type, types []spec.Type) (string, error) {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
if err := writeType(&builder, ty); err != nil {
|
if err := writeType(&builder, ty, types); err != nil {
|
||||||
return "", apiutil.WrapErr(err, "Type "+ty.Name+" generate error")
|
return "", apiutil.WrapErr(err, "Type "+ty.Name+" generate error")
|
||||||
}
|
}
|
||||||
return builder.String(), nil
|
return builder.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeType(writer io.Writer, tp spec.Type) error {
|
func writeType(writer io.Writer, tp spec.Type, types []spec.Type) error {
|
||||||
fmt.Fprintf(writer, "public class %s implements DeProguardable {\n", util.Title(tp.Name))
|
fmt.Fprintf(writer, "public class %s implements DeProguardable {\n", util.Title(tp.Name))
|
||||||
for _, member := range tp.Members {
|
var members []spec.Member
|
||||||
if err := writeProperty(writer, member, 1); err != nil {
|
err := writeMembers(writer, types, tp.Members, &members, 1)
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
genGetSet(writer, members, 1)
|
||||||
|
fmt.Fprintf(writer, "}")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMembers(writer io.Writer, types []spec.Type, members []spec.Member, allMembers *[]spec.Member, indent int) error {
|
||||||
|
for _, member := range members {
|
||||||
|
if !member.IsInline {
|
||||||
|
_, err := member.GetPropertyName()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !member.IsBodyMember() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range *allMembers {
|
||||||
|
if item.Name == member.Name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if member.IsInline {
|
||||||
|
hasInline := false
|
||||||
|
for _, ty := range types {
|
||||||
|
if strings.ToLower(ty.Name) == strings.ToLower(member.Name) {
|
||||||
|
err := writeMembers(writer, types, ty.Members, allMembers, indent)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hasInline = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasInline {
|
||||||
|
return errors.New("inline type " + member.Name + " not exist, please correct api file")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := writeProperty(writer, member, indent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*allMembers = append(*allMembers, member)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
genGetSet(writer, tp, 1)
|
|
||||||
fmt.Fprintf(writer, "}\n")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,23 +19,27 @@ const packetTemplate = `package com.xhb.logic.http.packet.{{.packet}};
|
|||||||
|
|
||||||
import com.google.gson.Gson;
|
import com.google.gson.Gson;
|
||||||
import com.xhb.commons.JSON;
|
import com.xhb.commons.JSON;
|
||||||
import com.xhb.commons.JsonParser;
|
import com.xhb.commons.JsonMarshal;
|
||||||
import com.xhb.core.network.HttpRequestClient;
|
import com.xhb.core.network.HttpRequestClient;
|
||||||
import com.xhb.core.packet.HttpRequestPacket;
|
import com.xhb.core.packet.HttpRequestPacket;
|
||||||
import com.xhb.core.response.HttpResponseData;
|
import com.xhb.core.response.HttpResponseData;
|
||||||
import com.xhb.logic.http.DeProguardable;
|
import com.xhb.logic.http.DeProguardable;
|
||||||
|
{{if not .HasRequestBody}}
|
||||||
|
import com.xhb.logic.http.request.EmptyRequest;
|
||||||
|
{{end}}
|
||||||
{{.import}}
|
{{.import}}
|
||||||
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.jetbrains.annotations.Nullable;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
|
|
||||||
public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> {
|
public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> {
|
||||||
|
|
||||||
{{.paramsDeclaration}}
|
{{.paramsDeclaration}}
|
||||||
|
|
||||||
public {{.packetName}}({{.params}}{{.requestType}} request) {
|
public {{.packetName}}({{.params}}{{if .HasRequestBody}}, {{.requestType}} request{{end}}) {
|
||||||
super(request);
|
{{if .HasRequestBody}}super(request);{{else}}super(EmptyRequest.instance);{{end}}
|
||||||
this.request = request;{{.paramsSet}}
|
{{if .HasRequestBody}}this.request = request;{{end}}{{.paramsSet}}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -113,7 +117,8 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
|
|||||||
} else {
|
} else {
|
||||||
fmt.Fprintln(&builder)
|
fmt.Fprintln(&builder)
|
||||||
}
|
}
|
||||||
if err := genType(&builder, tp); err != nil {
|
|
||||||
|
if err := genType(&builder, tp, api.Types); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -126,7 +131,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
|
|||||||
|
|
||||||
t := template.Must(template.New("packetTemplate").Parse(packetTemplate))
|
t := template.Must(template.New("packetTemplate").Parse(packetTemplate))
|
||||||
var tmplBytes bytes.Buffer
|
var tmplBytes bytes.Buffer
|
||||||
err = t.Execute(&tmplBytes, map[string]string{
|
err = t.Execute(&tmplBytes, map[string]interface{}{
|
||||||
"packetName": packet,
|
"packetName": packet,
|
||||||
"method": strings.ToUpper(route.Method),
|
"method": strings.ToUpper(route.Method),
|
||||||
"uri": processUri(route),
|
"uri": processUri(route),
|
||||||
@@ -137,6 +142,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
|
|||||||
"paramsSet": paramsSet,
|
"paramsSet": paramsSet,
|
||||||
"packet": packetName,
|
"packet": packetName,
|
||||||
"requestType": util.Title(route.RequestType.Name),
|
"requestType": util.Title(route.RequestType.Name),
|
||||||
|
"HasRequestBody": len(route.RequestType.GetBodyMembers()) > 0,
|
||||||
"import": getImports(api, route, packetName),
|
"import": getImports(api, route, packetName),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -209,7 +215,7 @@ func paramsForRoute(route spec.Route) string {
|
|||||||
builder.WriteString(fmt.Sprintf("String %s, ", cop[1:]))
|
builder.WriteString(fmt.Sprintf("String %s, ", cop[1:]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return builder.String()
|
return strings.TrimSuffix(builder.String(), ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func declarationForRoute(route spec.Route) string {
|
func declarationForRoute(route spec.Route) string {
|
||||||
@@ -260,18 +266,22 @@ func processUri(route spec.Route) string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func genType(writer io.Writer, tp spec.Type) error {
|
func genType(writer io.Writer, tp spec.Type, types []spec.Type) error {
|
||||||
writeIndent(writer, 1)
|
if len(tp.GetBodyMembers()) == 0 {
|
||||||
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))
|
return nil
|
||||||
for _, member := range tp.Members {
|
|
||||||
if err := writeProperty(writer, member, 2); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
writeBreakline(writer)
|
|
||||||
writeIndent(writer, 1)
|
writeIndent(writer, 1)
|
||||||
genGetSet(writer, tp, 2)
|
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))
|
||||||
|
var members []spec.Member
|
||||||
|
err := writeMembers(writer, types, tp.Members, &members, 2)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
writeNewline(writer)
|
||||||
|
writeIndent(writer, 1)
|
||||||
|
genGetSet(writer, members, 2)
|
||||||
writeIndent(writer, 1)
|
writeIndent(writer, 1)
|
||||||
fmt.Fprintln(writer, "}")
|
fmt.Fprintln(writer, "}")
|
||||||
|
|
||||||
|
|||||||
@@ -67,8 +67,8 @@ func indentString(indent int) string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeBreakline(writer io.Writer) {
|
func writeNewline(writer io.Writer) {
|
||||||
fmt.Fprint(writer, "\n")
|
fmt.Fprint(writer, util.NL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPrimitiveType(tp string) bool {
|
func isPrimitiveType(tp string) bool {
|
||||||
@@ -87,6 +87,7 @@ func goTypeToJava(tp string) (string, error) {
|
|||||||
if len(tp) == 0 {
|
if len(tp) == 0 {
|
||||||
return "", errors.New("property type empty")
|
return "", errors.New("property type empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(tp, "*") {
|
if strings.HasPrefix(tp, "*") {
|
||||||
tp = tp[1:]
|
tp = tp[1:]
|
||||||
}
|
}
|
||||||
@@ -107,39 +108,44 @@ func goTypeToJava(tp string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tys) == 0 {
|
if len(tys) == 0 {
|
||||||
return "", fmt.Errorf("%s tp parse error", tp)
|
return "", fmt.Errorf("%s tp parse error", tp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(tys[0])), nil
|
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(tys[0])), nil
|
||||||
} else if strings.HasPrefix(tp, "map") {
|
} else if strings.HasPrefix(tp, "map") {
|
||||||
tys, err := apiutil.DecomposeType(tp)
|
tys, err := apiutil.DecomposeType(tp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tys) == 2 {
|
if len(tys) == 2 {
|
||||||
return "", fmt.Errorf("%s tp parse error", tp)
|
return "", fmt.Errorf("%s tp parse error", tp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(tys[1])), nil
|
return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(tys[1])), nil
|
||||||
}
|
}
|
||||||
return util.Title(tp), nil
|
return util.Title(tp), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
|
func genGetSet(writer io.Writer, members []spec.Member, indent int) error {
|
||||||
t := template.Must(template.New("getSetTemplate").Parse(getSetTemplate))
|
t := template.Must(template.New("getSetTemplate").Parse(getSetTemplate))
|
||||||
for _, member := range tp.Members {
|
for _, member := range members {
|
||||||
var tmplBytes bytes.Buffer
|
var tmplBytes bytes.Buffer
|
||||||
|
|
||||||
oty, err := goTypeToJava(member.Type)
|
oty, err := goTypeToJava(member.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tyString := oty
|
tyString := oty
|
||||||
decorator := ""
|
decorator := ""
|
||||||
if !isPrimitiveType(member.Type) {
|
if !isPrimitiveType(member.Type) {
|
||||||
if member.IsOptional() {
|
if member.IsOptional() {
|
||||||
decorator = "@org.jetbrains.annotations.Nullable "
|
decorator = "@Nullable "
|
||||||
} else {
|
} else {
|
||||||
decorator = "@org.jetbrains.annotations.NotNull "
|
decorator = "@NotNull "
|
||||||
}
|
}
|
||||||
tyString = decorator + tyString
|
tyString = decorator + tyString
|
||||||
}
|
}
|
||||||
@@ -155,6 +161,7 @@ func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
r := tmplBytes.String()
|
r := tmplBytes.String()
|
||||||
r = strings.Replace(r, " boolean get", " boolean is", 1)
|
r = strings.Replace(r, " boolean get", " boolean is", 1)
|
||||||
writer.Write([]byte(r))
|
writer.Write([]byte(r))
|
||||||
|
|||||||
@@ -63,10 +63,6 @@ func (m Member) IsOmitempty() bool {
|
|||||||
|
|
||||||
func (m Member) GetPropertyName() (string, error) {
|
func (m Member) GetPropertyName() (string, error) {
|
||||||
tags := m.Tags()
|
tags := m.Tags()
|
||||||
if len(tags) == 0 {
|
|
||||||
return "", errors.New("json property name not exist, member: " + m.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if stringx.Contains(definedKeys, tag.Key) {
|
if stringx.Contains(definedKeys, tag.Key) {
|
||||||
if tag.Name == "-" {
|
if tag.Name == "-" {
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func genHandler(dir, webApi, caller string, api *spec.ApiSpec, unwrapApi bool) e
|
|||||||
imports += fmt.Sprintf(`import * as components from "%s"`, "./"+outputFile)
|
imports += fmt.Sprintf(`import * as components from "%s"`, "./"+outputFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
apis, err := genApi(api, localTypes, caller, prefixForType)
|
apis, err := genApi(api, caller, prefixForType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -119,32 +119,34 @@ func genTypes(localTypes []spec.Type, inlineType func(string) (*spec.Type, error
|
|||||||
return types, nil
|
return types, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genApi(api *spec.ApiSpec, localTypes []spec.Type, caller string, prefixForType func(string) string) (string, error) {
|
func genApi(api *spec.ApiSpec, caller string, prefixForType func(string) string) (string, error) {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for _, route := range api.Service.Routes() {
|
for _, group := range api.Service.Groups {
|
||||||
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
|
for _, route := range group.Routes {
|
||||||
if !ok {
|
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
|
||||||
return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
|
if !ok {
|
||||||
}
|
return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
|
||||||
handler = util.Untitle(handler)
|
|
||||||
handler = strings.Replace(handler, "Handler", "", 1)
|
|
||||||
comment := commentForRoute(route)
|
|
||||||
if len(comment) > 0 {
|
|
||||||
fmt.Fprintf(&builder, "%s\n", comment)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
|
|
||||||
writeIndent(&builder, 1)
|
|
||||||
responseGeneric := "<null>"
|
|
||||||
if len(route.ResponseType.Name) > 0 {
|
|
||||||
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
responseGeneric = fmt.Sprintf("<%s>", val)
|
handler = util.Untitle(handler)
|
||||||
|
handler = strings.Replace(handler, "Handler", "", 1)
|
||||||
|
comment := commentForRoute(route)
|
||||||
|
if len(comment) > 0 {
|
||||||
|
fmt.Fprintf(&builder, "%s\n", comment)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&builder, "export function %s(%s) {\n", handler, paramsForRoute(route, prefixForType))
|
||||||
|
writeIndent(&builder, 1)
|
||||||
|
responseGeneric := "<null>"
|
||||||
|
if len(route.ResponseType.Name) > 0 {
|
||||||
|
val, err := goTypeToTs(route.ResponseType.Name, prefixForType)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
responseGeneric = fmt.Sprintf("<%s>", val)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
|
||||||
|
util.Title(responseGeneric), callParamsForRoute(route, group))
|
||||||
|
builder.WriteString("\n}\n\n")
|
||||||
}
|
}
|
||||||
fmt.Fprintf(&builder, `return %s.%s%s(%s)`, caller, strings.ToLower(route.Method),
|
|
||||||
util.Title(responseGeneric), callParamsForRoute(route))
|
|
||||||
builder.WriteString("\n}\n\n")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
apis := builder.String()
|
apis := builder.String()
|
||||||
@@ -188,21 +190,28 @@ func commentForRoute(route spec.Route) string {
|
|||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func callParamsForRoute(route spec.Route) string {
|
func callParamsForRoute(route spec.Route, group spec.Group) string {
|
||||||
hasParams := pathHasParams(route)
|
hasParams := pathHasParams(route)
|
||||||
hasBody := hasRequestBody(route)
|
hasBody := hasRequestBody(route)
|
||||||
if hasParams && hasBody {
|
if hasParams && hasBody {
|
||||||
return fmt.Sprintf("%s, %s, %s", pathForRoute(route), "params", "req")
|
return fmt.Sprintf("%s, %s, %s", pathForRoute(route, group), "params", "req")
|
||||||
} else if hasParams {
|
} else if hasParams {
|
||||||
return fmt.Sprintf("%s, %s", pathForRoute(route), "params")
|
return fmt.Sprintf("%s, %s", pathForRoute(route, group), "params")
|
||||||
} else if hasBody {
|
} else if hasBody {
|
||||||
return fmt.Sprintf("%s, %s", pathForRoute(route), "req")
|
return fmt.Sprintf("%s, %s", pathForRoute(route, group), "req")
|
||||||
}
|
}
|
||||||
return pathForRoute(route)
|
return pathForRoute(route, group)
|
||||||
}
|
}
|
||||||
|
|
||||||
func pathForRoute(route spec.Route) string {
|
func pathForRoute(route spec.Route, group spec.Group) string {
|
||||||
return "\"" + route.Path + "\""
|
value, ok := apiutil.GetAnnotationValue(group.Annotations, "server", pathPrefix)
|
||||||
|
if !ok {
|
||||||
|
return "\"" + route.Path + "\""
|
||||||
|
} else {
|
||||||
|
value = strings.TrimPrefix(value, `"`)
|
||||||
|
value = strings.TrimSuffix(value, `"`)
|
||||||
|
return fmt.Sprintf(`"%s/%s"`, value, strings.TrimPrefix(route.Path, "/"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func pathHasParams(route spec.Route) bool {
|
func pathHasParams(route spec.Route) bool {
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ package tsgen
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
packagePrefix = "components."
|
packagePrefix = "components."
|
||||||
|
pathPrefix = "pathPrefix"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,15 +9,17 @@ import (
|
|||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/logrusorgru/aurora"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
"github.com/urfave/cli"
|
"github.com/urfave/cli"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
etcDir = "etc"
|
dockerfileName = "Dockerfile"
|
||||||
yamlEtx = ".yaml"
|
etcDir = "etc"
|
||||||
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time
|
yamlEtx = ".yaml"
|
||||||
|
cstOffset = 60 * 60 * 8 // 8 hours offset for Chinese Standard Time
|
||||||
)
|
)
|
||||||
|
|
||||||
type Docker struct {
|
type Docker struct {
|
||||||
@@ -25,10 +27,18 @@ type Docker struct {
|
|||||||
GoRelPath string
|
GoRelPath string
|
||||||
GoFile string
|
GoFile string
|
||||||
ExeFile string
|
ExeFile string
|
||||||
|
HasPort bool
|
||||||
|
Port int
|
||||||
Argument string
|
Argument string
|
||||||
}
|
}
|
||||||
|
|
||||||
func DockerCommand(c *cli.Context) error {
|
func DockerCommand(c *cli.Context) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println(aurora.Green("Done."))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
goFile := c.String("go")
|
goFile := c.String("go")
|
||||||
if len(goFile) == 0 {
|
if len(goFile) == 0 {
|
||||||
return errors.New("-go can't be empty")
|
return errors.New("-go can't be empty")
|
||||||
@@ -38,8 +48,9 @@ func DockerCommand(c *cli.Context) error {
|
|||||||
return fmt.Errorf("file %q not found", goFile)
|
return fmt.Errorf("file %q not found", goFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
port := c.Int("port")
|
||||||
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
|
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
|
||||||
return generateDockerfile(goFile)
|
return generateDockerfile(goFile, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := findConfig(goFile, etcDir)
|
cfg, err := findConfig(goFile, etcDir)
|
||||||
@@ -47,13 +58,13 @@ func DockerCommand(c *cli.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := generateDockerfile(goFile, "-f", "etc/"+cfg); err != nil {
|
if err := generateDockerfile(goFile, port, "-f", "etc/"+cfg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
projDir, ok := util.FindProjectPath(goFile)
|
projDir, ok := util.FindProjectPath(goFile)
|
||||||
if ok {
|
if ok {
|
||||||
fmt.Printf("Run \"docker build ...\" command in dir %q\n", projDir)
|
fmt.Printf("Hint: run \"docker build ...\" command in dir %q\n", projDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -88,18 +99,22 @@ func findConfig(file, dir string) (string, error) {
|
|||||||
return files[0], nil
|
return files[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateDockerfile(goFile string, args ...string) error {
|
func generateDockerfile(goFile string, port int, args ...string) error {
|
||||||
projPath, err := getFilePath(filepath.Dir(goFile))
|
projPath, err := getFilePath(filepath.Dir(goFile))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pos := strings.IndexByte(projPath, '/')
|
if len(projPath) == 0 {
|
||||||
if pos >= 0 {
|
projPath = "."
|
||||||
projPath = projPath[pos+1:]
|
} else {
|
||||||
|
pos := strings.IndexByte(projPath, os.PathSeparator)
|
||||||
|
if pos >= 0 {
|
||||||
|
projPath = projPath[pos+1:]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := util.CreateIfNotExist("Dockerfile")
|
out, err := util.CreateIfNotExist(dockerfileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -122,6 +137,8 @@ func generateDockerfile(goFile string, args ...string) error {
|
|||||||
GoRelPath: projPath,
|
GoRelPath: projPath,
|
||||||
GoFile: goFile,
|
GoFile: goFile,
|
||||||
ExeFile: util.FileNameWithoutExt(filepath.Base(goFile)),
|
ExeFile: util.FileNameWithoutExt(filepath.Base(goFile)),
|
||||||
|
HasPort: port > 0,
|
||||||
|
Port: port,
|
||||||
Argument: builder.String(),
|
Argument: builder.String(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,34 +14,59 @@ LABEL stage=gobuilder
|
|||||||
|
|
||||||
ENV CGO_ENABLED 0
|
ENV CGO_ENABLED 0
|
||||||
ENV GOOS linux
|
ENV GOOS linux
|
||||||
{{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct{{end}}
|
{{if .Chinese}}ENV GOPROXY https://goproxy.cn,direct
|
||||||
|
{{end}}
|
||||||
WORKDIR /build/zero
|
WORKDIR /build/zero
|
||||||
|
|
||||||
ADD go.mod .
|
ADD go.mod .
|
||||||
ADD go.sum .
|
ADD go.sum .
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
COPY . .
|
COPY . .
|
||||||
COPY {{.GoRelPath}}/etc /app/etc
|
{{if .Argument}}COPY {{.GoRelPath}}/etc /app/etc
|
||||||
RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
|
{{end}}RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
|
||||||
|
|
||||||
|
|
||||||
FROM alpine
|
FROM alpine
|
||||||
|
|
||||||
RUN apk update --no-cache
|
RUN apk update --no-cache && apk add --no-cache ca-certificates tzdata
|
||||||
RUN apk add --no-cache ca-certificates
|
|
||||||
RUN apk add --no-cache tzdata
|
|
||||||
ENV TZ Asia/Shanghai
|
ENV TZ Asia/Shanghai
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}}
|
COPY --from=builder /app/{{.ExeFile}} /app/{{.ExeFile}}{{if .Argument}}
|
||||||
COPY --from=builder /app/etc /app/etc
|
COPY --from=builder /app/etc /app/etc{{end}}
|
||||||
|
{{if .HasPort}}
|
||||||
|
EXPOSE {{.Port}}
|
||||||
|
{{end}}
|
||||||
CMD ["./{{.ExeFile}}"{{.Argument}}]
|
CMD ["./{{.ExeFile}}"{{.Argument}}]
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func Clean() error {
|
||||||
|
return util.Clean(category)
|
||||||
|
}
|
||||||
|
|
||||||
func GenTemplates(_ *cli.Context) error {
|
func GenTemplates(_ *cli.Context) error {
|
||||||
|
return initTemplate()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Category() string {
|
||||||
|
return category
|
||||||
|
}
|
||||||
|
|
||||||
|
func RevertTemplate(name string) error {
|
||||||
|
return util.CreateTemplate(category, name, dockerTemplate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Update() error {
|
||||||
|
err := Clean()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return initTemplate()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initTemplate() error {
|
||||||
return util.InitTemplates(category, map[string]string{
|
return util.InitTemplates(category, map[string]string{
|
||||||
dockerTemplateFile: dockerTemplate,
|
dockerTemplateFile: dockerTemplate,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
BuildVersion = "20201125"
|
BuildVersion = "1.1.1"
|
||||||
commands = []cli.Command{
|
commands = []cli.Command{
|
||||||
{
|
{
|
||||||
Name: "api",
|
Name: "api",
|
||||||
@@ -54,14 +54,12 @@ var (
|
|||||||
Usage: "the format target dir",
|
Usage: "the format target dir",
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "iu",
|
Name: "iu",
|
||||||
Usage: "ignore update",
|
Usage: "ignore update",
|
||||||
Required: false,
|
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "stdin",
|
Name: "stdin",
|
||||||
Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF",
|
Usage: "use stdin to input api doc content, press \"ctrl + d\" to send EOF",
|
||||||
Required: false,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Action: format.GoFormatApi,
|
Action: format.GoFormatApi,
|
||||||
@@ -101,9 +99,8 @@ var (
|
|||||||
Usage: "the api file",
|
Usage: "the api file",
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "style",
|
Name: "style",
|
||||||
Required: false,
|
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Action: gogen.GoCommand,
|
Action: gogen.GoCommand,
|
||||||
@@ -136,19 +133,16 @@ var (
|
|||||||
Usage: "the api file",
|
Usage: "the api file",
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "webapi",
|
Name: "webapi",
|
||||||
Usage: "the web api file path",
|
Usage: "the web api file path",
|
||||||
Required: false,
|
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "caller",
|
Name: "caller",
|
||||||
Usage: "the web api caller",
|
Usage: "the web api caller",
|
||||||
Required: false,
|
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "unwrap",
|
Name: "unwrap",
|
||||||
Usage: "unwrap the webapi caller for import",
|
Usage: "unwrap the webapi caller for import",
|
||||||
Required: false,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Action: tsgen.TsCommand,
|
Action: tsgen.TsCommand,
|
||||||
@@ -204,9 +198,8 @@ var (
|
|||||||
Usage: "the api file",
|
Usage: "the api file",
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "style",
|
Name: "style",
|
||||||
Required: false,
|
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Action: plugin.PluginCommand,
|
Action: plugin.PluginCommand,
|
||||||
@@ -221,6 +214,11 @@ var (
|
|||||||
Name: "go",
|
Name: "go",
|
||||||
Usage: "the file that contains main function",
|
Usage: "the file that contains main function",
|
||||||
},
|
},
|
||||||
|
cli.IntFlag{
|
||||||
|
Name: "port",
|
||||||
|
Usage: "the port to expose, default none",
|
||||||
|
Value: 0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Action: docker.DockerCommand,
|
Action: docker.DockerCommand,
|
||||||
},
|
},
|
||||||
@@ -248,9 +246,8 @@ var (
|
|||||||
Required: true,
|
Required: true,
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "secret",
|
Name: "secret",
|
||||||
Usage: "the image pull secret",
|
Usage: "the secret to image pull from registry",
|
||||||
Required: true,
|
|
||||||
},
|
},
|
||||||
cli.IntFlag{
|
cli.IntFlag{
|
||||||
Name: "requestCpu",
|
Name: "requestCpu",
|
||||||
@@ -321,9 +318,8 @@ var (
|
|||||||
Usage: `generate rpc demo service`,
|
Usage: `generate rpc demo service`,
|
||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "style",
|
Name: "style",
|
||||||
Required: false,
|
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "idea",
|
Name: "idea",
|
||||||
@@ -360,9 +356,8 @@ var (
|
|||||||
Usage: `the target path of the code`,
|
Usage: `the target path of the code`,
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "style",
|
Name: "style",
|
||||||
Required: false,
|
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "idea",
|
Name: "idea",
|
||||||
@@ -394,9 +389,8 @@ var (
|
|||||||
Usage: "the target dir",
|
Usage: "the target dir",
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "style",
|
Name: "style",
|
||||||
Required: false,
|
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "cache, c",
|
Name: "cache, c",
|
||||||
@@ -430,9 +424,8 @@ var (
|
|||||||
Usage: "the target dir",
|
Usage: "the target dir",
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "style",
|
Name: "style",
|
||||||
Required: false,
|
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
|
||||||
},
|
},
|
||||||
cli.BoolFlag{
|
cli.BoolFlag{
|
||||||
Name: "idea",
|
Name: "idea",
|
||||||
@@ -476,7 +469,7 @@ var (
|
|||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "category,c",
|
Name: "category,c",
|
||||||
Usage: "the category of template, enum [api,rpc,model]",
|
Usage: "the category of template, enum [api,rpc,model,docker,kube]",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Action: tpl.UpdateTemplates,
|
Action: tpl.UpdateTemplates,
|
||||||
@@ -487,7 +480,7 @@ var (
|
|||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "category,c",
|
Name: "category,c",
|
||||||
Usage: "the category of template, enum [api,rpc,model]",
|
Usage: "the category of template, enum [api,rpc,model,docker,kube]",
|
||||||
},
|
},
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "name,n",
|
Name: "name,n",
|
||||||
|
|||||||
@@ -47,9 +47,9 @@ spec:
|
|||||||
volumeMounts:
|
volumeMounts:
|
||||||
- name: timezone
|
- name: timezone
|
||||||
mountPath: /etc/localtime
|
mountPath: /etc/localtime
|
||||||
imagePullSecrets:
|
{{if .Secret}}imagePullSecrets:
|
||||||
- name: {{.Secret}}
|
- name: {{.Secret}}
|
||||||
volumes:
|
{{end}}volumes:
|
||||||
- name: timezone
|
- name: timezone
|
||||||
hostPath:
|
hostPath:
|
||||||
path: /usr/share/zoneinfo/Asia/Shanghai
|
path: /usr/share/zoneinfo/Asia/Shanghai
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package kube
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/logrusorgru/aurora"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
"github.com/urfave/cli"
|
"github.com/urfave/cli"
|
||||||
)
|
)
|
||||||
@@ -16,47 +18,23 @@ const (
|
|||||||
portLimit = 32767
|
portLimit = 32767
|
||||||
)
|
)
|
||||||
|
|
||||||
var errUnknownServiceType = errors.New("unknown service type")
|
type Deployment struct {
|
||||||
|
Name string
|
||||||
type (
|
Namespace string
|
||||||
ServiceType string
|
Image string
|
||||||
|
Secret string
|
||||||
KubeRequest struct {
|
Replicas int
|
||||||
Env string
|
Revisions int
|
||||||
ServiceName string
|
Port int
|
||||||
ServiceType ServiceType
|
NodePort int
|
||||||
Namespace string
|
UseNodePort bool
|
||||||
Schedule string
|
RequestCpu int
|
||||||
Replicas int
|
RequestMem int
|
||||||
RevisionHistoryLimit int
|
LimitCpu int
|
||||||
Port int
|
LimitMem int
|
||||||
LimitCpu int
|
MinReplicas int
|
||||||
LimitMem int
|
MaxReplicas int
|
||||||
RequestCpu int
|
}
|
||||||
RequestMem int
|
|
||||||
SuccessfulJobsHistoryLimit int
|
|
||||||
HpaMinReplicas int
|
|
||||||
HpaMaxReplicas int
|
|
||||||
}
|
|
||||||
|
|
||||||
Deployment struct {
|
|
||||||
Name string
|
|
||||||
Namespace string
|
|
||||||
Image string
|
|
||||||
Secret string
|
|
||||||
Replicas int
|
|
||||||
Revisions int
|
|
||||||
Port int
|
|
||||||
NodePort int
|
|
||||||
UseNodePort bool
|
|
||||||
RequestCpu int
|
|
||||||
RequestMem int
|
|
||||||
LimitCpu int
|
|
||||||
LimitMem int
|
|
||||||
MinReplicas int
|
|
||||||
MaxReplicas int
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func DeploymentCommand(c *cli.Context) error {
|
func DeploymentCommand(c *cli.Context) error {
|
||||||
nodePort := c.Int("nodePort")
|
nodePort := c.Int("nodePort")
|
||||||
@@ -77,7 +55,7 @@ func DeploymentCommand(c *cli.Context) error {
|
|||||||
defer out.Close()
|
defer out.Close()
|
||||||
|
|
||||||
t := template.Must(template.New("deploymentTemplate").Parse(text))
|
t := template.Must(template.New("deploymentTemplate").Parse(text))
|
||||||
return t.Execute(out, Deployment{
|
err = t.Execute(out, Deployment{
|
||||||
Name: c.String("name"),
|
Name: c.String("name"),
|
||||||
Namespace: c.String("namespace"),
|
Namespace: c.String("namespace"),
|
||||||
Image: c.String("image"),
|
Image: c.String("image"),
|
||||||
@@ -94,6 +72,20 @@ func DeploymentCommand(c *cli.Context) error {
|
|||||||
MinReplicas: c.Int("minReplicas"),
|
MinReplicas: c.Int("minReplicas"),
|
||||||
MaxReplicas: c.Int("maxReplicas"),
|
MaxReplicas: c.Int("maxReplicas"),
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(aurora.Green("Done."))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Category() string {
|
||||||
|
return category
|
||||||
|
}
|
||||||
|
|
||||||
|
func Clean() error {
|
||||||
|
return util.Clean(category)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenTemplates(_ *cli.Context) error {
|
func GenTemplates(_ *cli.Context) error {
|
||||||
@@ -102,3 +94,19 @@ func GenTemplates(_ *cli.Context) error {
|
|||||||
jobTemplateFile: jobTmeplate,
|
jobTemplateFile: jobTmeplate,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RevertTemplate(name string) error {
|
||||||
|
return util.CreateTemplate(category, name, deploymentTemplate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Update() error {
|
||||||
|
err := Clean()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.InitTemplates(category, map[string]string{
|
||||||
|
deployTemplateFile: deploymentTemplate,
|
||||||
|
jobTemplateFile: jobTmeplate,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,9 +61,9 @@ func FieldNames(in interface{}) []string {
|
|||||||
// gets us a StructField
|
// gets us a StructField
|
||||||
fi := typ.Field(i)
|
fi := typ.Field(i)
|
||||||
if tagv := fi.Tag.Get(dbTag); tagv != "" {
|
if tagv := fi.Tag.Get(dbTag); tagv != "" {
|
||||||
out = append(out, tagv)
|
out = append(out, fmt.Sprintf("`%v`", tagv))
|
||||||
} else {
|
} else {
|
||||||
out = append(out, fi.Name)
|
out = append(out, fmt.Sprintf("`%v`", fi.Name))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -28,8 +28,7 @@ var userFields = FieldNames(User{})
|
|||||||
func TestFieldNames(t *testing.T) {
|
func TestFieldNames(t *testing.T) {
|
||||||
var u User
|
var u User
|
||||||
out := FieldNames(&u)
|
out := FieldNames(&u)
|
||||||
fmt.Println(out)
|
actual := []string{"`id`", "`user_name`", "`sex`", "`uuid`", "`age`"}
|
||||||
actual := []string{"id", "user_name", "sex", "uuid", "age"}
|
|
||||||
assert.Equal(t, out, actual)
|
assert.Equal(t, out, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +53,7 @@ func TestBuilderSql(t *testing.T) {
|
|||||||
sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL()
|
sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL()
|
||||||
fmt.Println(sql, args, err)
|
fmt.Println(sql, args, err)
|
||||||
|
|
||||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id=?"
|
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id=?"
|
||||||
actualArgs := []interface{}{"123123"}
|
actualArgs := []interface{}{"123123"}
|
||||||
assert.Equal(t, sql, actualSql)
|
assert.Equal(t, sql, actualSql)
|
||||||
assert.Equal(t, args, actualArgs)
|
assert.Equal(t, args, actualArgs)
|
||||||
@@ -68,7 +67,7 @@ func TestBuildSqlDefaultValue(t *testing.T) {
|
|||||||
sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL()
|
sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL()
|
||||||
fmt.Println(sql, args, err)
|
fmt.Println(sql, args, err)
|
||||||
|
|
||||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE age=? AND user_name=?"
|
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE age=? AND user_name=?"
|
||||||
actualArgs := []interface{}{0, ""}
|
actualArgs := []interface{}{0, ""}
|
||||||
assert.Equal(t, sql, actualSql)
|
assert.Equal(t, sql, actualSql)
|
||||||
assert.Equal(t, args, actualArgs)
|
assert.Equal(t, args, actualArgs)
|
||||||
@@ -83,7 +82,7 @@ func TestBuilderSqlIn(t *testing.T) {
|
|||||||
sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL()
|
sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL()
|
||||||
fmt.Println(sql, args, err)
|
fmt.Println(sql, args, err)
|
||||||
|
|
||||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id IN (?,?,?) AND age>?"
|
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id IN (?,?,?) AND age>?"
|
||||||
actualArgs := []interface{}{"1", "2", "3", 18}
|
actualArgs := []interface{}{"1", "2", "3", 18}
|
||||||
assert.Equal(t, sql, actualSql)
|
assert.Equal(t, sql, actualSql)
|
||||||
assert.Equal(t, args, actualArgs)
|
assert.Equal(t, args, actualArgs)
|
||||||
@@ -94,7 +93,7 @@ func TestBuildSqlLike(t *testing.T) {
|
|||||||
sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL()
|
sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL()
|
||||||
fmt.Println(sql, args, err)
|
fmt.Println(sql, args, err)
|
||||||
|
|
||||||
actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE name LIKE ?"
|
actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE name LIKE ?"
|
||||||
actualArgs := []interface{}{"%wang%"}
|
actualArgs := []interface{}{"%wang%"}
|
||||||
assert.Equal(t, sql, actualSql)
|
assert.Equal(t, sql, actualSql)
|
||||||
assert.Equal(t, args, actualArgs)
|
assert.Equal(t, args, actualArgs)
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# generate model with cache from ddl
|
# generate model with cache from ddl
|
||||||
fromDDL:
|
fromDDLWithCache:
|
||||||
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -cache
|
goctl template clean;
|
||||||
|
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache/user" -cache;
|
||||||
|
|
||||||
|
fromDDLWithoutCache:
|
||||||
|
goctl template clean;
|
||||||
|
goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/nocache/user";
|
||||||
|
|
||||||
|
|
||||||
# generate model with cache from data source
|
# generate model with cache from data source
|
||||||
@@ -12,4 +17,5 @@ datasource=127.0.0.1:3306
|
|||||||
database=gozero
|
database=gozero
|
||||||
|
|
||||||
fromDataSource:
|
fromDataSource:
|
||||||
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero
|
goctl template clean;
|
||||||
|
goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero;
|
||||||
@@ -36,7 +36,7 @@ func genDelete(table Table, withCache bool) (string, string, error) {
|
|||||||
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
|
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
|
||||||
"dataType": table.PrimaryKey.DataType,
|
"dataType": table.PrimaryKey.DataType,
|
||||||
"keys": strings.Join(keySet.KeysStr(), "\n"),
|
"keys": strings.Join(keySet.KeysStr(), "\n"),
|
||||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||||
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
|
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func genFindOne(table Table, withCache bool) (string, string, error) {
|
|||||||
"withCache": withCache,
|
"withCache": withCache,
|
||||||
"upperStartCamelObject": camel,
|
"upperStartCamelObject": camel,
|
||||||
"lowerStartCamelObject": stringx.From(camel).Untitle(),
|
"lowerStartCamelObject": stringx.From(camel).Untitle(),
|
||||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||||
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
|
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
|
||||||
"dataType": table.PrimaryKey.DataType,
|
"dataType": table.PrimaryKey.DataType,
|
||||||
"cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression,
|
"cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression,
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
|||||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||||
"lowerStartCamelField": stringx.From(camelFieldName).Untitle(),
|
"lowerStartCamelField": stringx.From(camelFieldName).Untitle(),
|
||||||
"upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
|
"upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
|
||||||
"originalField": field.Name.Source(),
|
"originalField": wrapWithRawString(field.Name.Source()),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -82,7 +82,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
|||||||
"upperStartCamelObject": camelTableName,
|
"upperStartCamelObject": camelTableName,
|
||||||
"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
|
"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
|
||||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||||
"originalPrimaryField": table.PrimaryKey.Name.Source(),
|
"originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -21,9 +21,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
pwd = "."
|
pwd = "."
|
||||||
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
|
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
|
||||||
NamingLower = "lower"
|
|
||||||
NamingCamel = "camel"
|
|
||||||
NamingSnake = "snake"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@@ -280,3 +277,20 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
|
|||||||
|
|
||||||
return output.String(), nil
|
return output.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func wrapWithRawString(v string) string {
|
||||||
|
if v == "`" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(v, "`") {
|
||||||
|
v = "`" + v
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(v, "`") {
|
||||||
|
v = v + "`"
|
||||||
|
} else if len(v) == 1 {
|
||||||
|
v = v + "`"
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
package gen
|
package gen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tal-tech/go-zero/core/logx"
|
"github.com/tal-tech/go-zero/core/logx"
|
||||||
|
"github.com/tal-tech/go-zero/core/stringx"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -79,3 +84,32 @@ func TestNamingModel(t *testing.T) {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}())
|
}())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWrapWithRawString(t *testing.T) {
|
||||||
|
assert.Equal(t, "``", wrapWithRawString(""))
|
||||||
|
assert.Equal(t, "``", wrapWithRawString("``"))
|
||||||
|
assert.Equal(t, "`a`", wrapWithRawString("a"))
|
||||||
|
assert.Equal(t, "` `", wrapWithRawString(" "))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFields(t *testing.T) {
|
||||||
|
type Student struct {
|
||||||
|
Id int64 `db:"id"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
Age sql.NullInt64 `db:"age"`
|
||||||
|
Score sql.NullFloat64 `db:"score"`
|
||||||
|
CreateTime time.Time `db:"create_time"`
|
||||||
|
UpdateTime sql.NullTime `db:"update_time"`
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
studentFieldNames = builderx.FieldNames(&Student{})
|
||||||
|
studentRows = strings.Join(studentFieldNames, ",")
|
||||||
|
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
|
||||||
|
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames)
|
||||||
|
assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows)
|
||||||
|
assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
|
||||||
|
assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ func genNew(table Table, withCache bool) (string, error) {
|
|||||||
output, err := util.With("new").
|
output, err := util.With("new").
|
||||||
Parse(text).
|
Parse(text).
|
||||||
Execute(map[string]interface{}{
|
Execute(map[string]interface{}{
|
||||||
"table": table.Name.Source(),
|
"table": wrapWithRawString(table.Name.Source()),
|
||||||
"withCache": withCache,
|
"withCache": withCache,
|
||||||
"upperStartCamelObject": table.Name.ToCamel(),
|
"upperStartCamelObject": table.Name.ToCamel(),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ var templates = map[string]string{
|
|||||||
errTemplateFile: template.Error,
|
errTemplateFile: template.Error,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Category() string {
|
||||||
|
return category
|
||||||
|
}
|
||||||
|
|
||||||
|
func Clean() error {
|
||||||
|
return util.Clean(category)
|
||||||
|
}
|
||||||
|
|
||||||
func GenTemplates(_ *cli.Context) error {
|
func GenTemplates(_ *cli.Context) error {
|
||||||
return util.InitTemplates(category, templates)
|
return util.InitTemplates(category, templates)
|
||||||
}
|
}
|
||||||
@@ -66,18 +74,10 @@ func RevertTemplate(name string) error {
|
|||||||
return util.CreateTemplate(category, name, content)
|
return util.CreateTemplate(category, name, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Clean() error {
|
func Update() error {
|
||||||
return util.Clean(category)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Update(category string) error {
|
|
||||||
err := Clean()
|
err := Clean()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return util.InitTemplates(category, templates)
|
return util.InitTemplates(category, templates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCategory() string {
|
|
||||||
return category
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func TestUpdate(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, string(data), modifyData)
|
assert.Equal(t, string(data), modifyData)
|
||||||
|
|
||||||
assert.Nil(t, Update(category))
|
assert.Nil(t, Update())
|
||||||
|
|
||||||
data, err = ioutil.ReadFile(file)
|
data, err = ioutil.ReadFile(file)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
|
|||||||
"primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression,
|
"primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression,
|
||||||
"primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
|
"primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
|
||||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||||
"expressionValues": strings.Join(expressionValues, ", "),
|
"expressionValues": strings.Join(expressionValues, ", "),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func genVars(table Table, withCache bool) (string, error) {
|
|||||||
"upperStartCamelObject": camel,
|
"upperStartCamelObject": camel,
|
||||||
"cacheKeys": strings.Join(keys, "\n"),
|
"cacheKeys": strings.Join(keys, "\n"),
|
||||||
"autoIncrement": table.PrimaryKey.AutoIncrement,
|
"autoIncrement": table.PrimaryKey.AutoIncrement,
|
||||||
"originalPrimaryKey": table.PrimaryKey.Name.Source(),
|
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
|
||||||
"withCache": withCache,
|
"withCache": withCache,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
DDLModel struct {
|
|
||||||
conn sqlx.SqlConn
|
|
||||||
}
|
|
||||||
DDL struct {
|
|
||||||
Table string `db:"Table"`
|
|
||||||
DDL string `db:"Create Table"`
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewDDLModel(conn sqlx.SqlConn) *DDLModel {
|
|
||||||
return &DDLModel{conn: conn}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DDLModel) ShowDDL(table ...string) ([]string, error) {
|
|
||||||
var ddl []string
|
|
||||||
for _, t := range table {
|
|
||||||
query := `show create table ` + t
|
|
||||||
var resp DDL
|
|
||||||
err := m.conn.QueryRow(&resp, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ddl = append(ddl, resp.DDL)
|
|
||||||
}
|
|
||||||
return ddl, nil
|
|
||||||
}
|
|
||||||
@@ -1,12 +1,14 @@
|
|||||||
package template
|
package template
|
||||||
|
|
||||||
var Vars = `
|
import "fmt"
|
||||||
|
|
||||||
|
var Vars = fmt.Sprintf(`
|
||||||
var (
|
var (
|
||||||
{{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{})
|
{{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{})
|
||||||
{{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",")
|
{{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",")
|
||||||
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "create_time", "update_time"), ",")
|
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ",")
|
||||||
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "create_time", "update_time"), "=?,") + "=?"
|
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?"
|
||||||
|
|
||||||
{{if .withCache}}{{.cacheKeys}}{{end}}
|
{{if .withCache}}{{.cacheKeys}}{{end}}
|
||||||
)
|
)
|
||||||
`
|
`, "`", "`", "`", "`", "`", "`", "`", "`")
|
||||||
|
|||||||
235
tools/goctl/model/sql/test/model/model_test.go
Normal file
235
tools/goctl/model/sql/test/model/model_test.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/cache"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/redis"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
|
||||||
|
mocksql "github.com/tal-tech/go-zero/tools/goctl/model/sql/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStudentModel(t *testing.T) {
|
||||||
|
var (
|
||||||
|
testTimeValue = time.Now()
|
||||||
|
testTable = "`student`"
|
||||||
|
testUpdateName = "gozero1"
|
||||||
|
testRowsAffected int64 = 1
|
||||||
|
testInsertId int64 = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var data Student
|
||||||
|
data.Id = testInsertId
|
||||||
|
data.Name = "gozero"
|
||||||
|
data.Age = sql.NullInt64{
|
||||||
|
Int64: 1,
|
||||||
|
Valid: true,
|
||||||
|
}
|
||||||
|
data.Score = sql.NullFloat64{
|
||||||
|
Float64: 100,
|
||||||
|
Valid: true,
|
||||||
|
}
|
||||||
|
data.CreateTime = testTimeValue
|
||||||
|
data.UpdateTime = sql.NullTime{
|
||||||
|
Time: testTimeValue,
|
||||||
|
Valid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mockStudent(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
|
||||||
|
WithArgs(data.Name, data.Age, data.Score).
|
||||||
|
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||||
|
}, func(m StudentModel) {
|
||||||
|
r, err := m.Insert(data)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
lastInsertId, err := r.LastInsertId()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, testInsertId, lastInsertId)
|
||||||
|
|
||||||
|
rowsAffected, err := r.RowsAffected()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, testRowsAffected, rowsAffected)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
|
||||||
|
WithArgs(testInsertId).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
|
||||||
|
}, func(m StudentModel) {
|
||||||
|
result, err := m.FindOne(testInsertId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, *result, data)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(testUpdateName, data.Age, data.Score, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||||
|
}, func(m StudentModel) {
|
||||||
|
data.Name = testUpdateName
|
||||||
|
err := m.Update(data)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
|
||||||
|
WithArgs(testInsertId).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
|
||||||
|
}, func(m StudentModel) {
|
||||||
|
result, err := m.FindOne(testInsertId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, *result, data)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockStudent(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||||
|
}, func(m StudentModel) {
|
||||||
|
err := m.Delete(testInsertId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserModel(t *testing.T) {
|
||||||
|
var (
|
||||||
|
testTimeValue = time.Now()
|
||||||
|
testTable = "`user`"
|
||||||
|
testUpdateName = "gozero1"
|
||||||
|
testUser = "gozero"
|
||||||
|
testPassword = "test"
|
||||||
|
testMobile = "test_mobile"
|
||||||
|
testGender = "男"
|
||||||
|
testNickname = "test_nickname"
|
||||||
|
testRowsAffected int64 = 1
|
||||||
|
testInsertId int64 = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var data User
|
||||||
|
data.Id = testInsertId
|
||||||
|
data.User = testUser
|
||||||
|
data.Name = "gozero"
|
||||||
|
data.Password = testPassword
|
||||||
|
data.Mobile = testMobile
|
||||||
|
data.Gender = testGender
|
||||||
|
data.Nickname = testNickname
|
||||||
|
data.CreateTime = testTimeValue
|
||||||
|
data.UpdateTime = testTimeValue
|
||||||
|
|
||||||
|
err := mockUser(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
|
||||||
|
WithArgs(data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname).
|
||||||
|
WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||||
|
}, func(m UserModel) {
|
||||||
|
r, err := m.Insert(data)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
lastInsertId, err := r.LastInsertId()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, testInsertId, lastInsertId)
|
||||||
|
|
||||||
|
rowsAffected, err := r.RowsAffected()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, testRowsAffected, rowsAffected)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
|
||||||
|
WithArgs(testInsertId).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
|
||||||
|
}, func(m UserModel) {
|
||||||
|
result, err := m.FindOne(testInsertId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, *result, data)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||||
|
}, func(m UserModel) {
|
||||||
|
data.Name = testUpdateName
|
||||||
|
err := m.Update(data)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
|
||||||
|
WithArgs(testInsertId).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
|
||||||
|
}, func(m UserModel) {
|
||||||
|
result, err := m.FindOne(testInsertId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, *result, data)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = mockUser(func(mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
|
||||||
|
}, func(m UserModel) {
|
||||||
|
err := m.Delete(testInsertId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// with cache
|
||||||
|
func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) error {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mockFn(mock)
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
conn := mocksql.NewMockConn(db)
|
||||||
|
r, clean, err := redistest.CreateRedis()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer clean()
|
||||||
|
|
||||||
|
m := NewStudentModel(conn, cache.CacheConf{
|
||||||
|
{
|
||||||
|
RedisConf: redis.RedisConf{
|
||||||
|
Host: r.Addr,
|
||||||
|
Type: "node",
|
||||||
|
},
|
||||||
|
Weight: 100,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fn(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// without cache
|
||||||
|
func mockUser(mockFn func(mock sqlmock.Sqlmock), fn func(m UserModel)) error {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mockFn(mock)
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
conn := mocksql.NewMockConn(db)
|
||||||
|
m := NewUserModel(conn)
|
||||||
|
fn(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
105
tools/goctl/model/sql/test/model/studentmodel.go
Executable file
105
tools/goctl/model/sql/test/model/studentmodel.go
Executable file
@@ -0,0 +1,105 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/cache"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/sqlc"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||||
|
"github.com/tal-tech/go-zero/core/stringx"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
studentFieldNames = builderx.FieldNames(&Student{})
|
||||||
|
studentRows = strings.Join(studentFieldNames, ",")
|
||||||
|
studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
|
||||||
|
studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
|
||||||
|
|
||||||
|
cacheStudentIdPrefix = "cache#Student#id#"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
StudentModel interface {
|
||||||
|
Insert(data Student) (sql.Result, error)
|
||||||
|
FindOne(id int64) (*Student, error)
|
||||||
|
Update(data Student) error
|
||||||
|
Delete(id int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultStudentModel struct {
|
||||||
|
sqlc.CachedConn
|
||||||
|
table string
|
||||||
|
}
|
||||||
|
|
||||||
|
Student struct {
|
||||||
|
Id int64 `db:"id"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
Age sql.NullInt64 `db:"age"`
|
||||||
|
Score sql.NullFloat64 `db:"score"`
|
||||||
|
CreateTime time.Time `db:"create_time"`
|
||||||
|
UpdateTime sql.NullTime `db:"update_time"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel {
|
||||||
|
return &defaultStudentModel{
|
||||||
|
CachedConn: sqlc.NewConn(conn, c),
|
||||||
|
table: "`student`",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultStudentModel) Insert(data Student) (sql.Result, error) {
|
||||||
|
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?)", m.table, studentRowsExpectAutoSet)
|
||||||
|
ret, err := m.ExecNoCache(query, data.Name, data.Age, data.Score)
|
||||||
|
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultStudentModel) FindOne(id int64) (*Student, error) {
|
||||||
|
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
|
||||||
|
var resp Student
|
||||||
|
err := m.QueryRow(&resp, studentIdKey, func(conn sqlx.SqlConn, v interface{}) error {
|
||||||
|
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
|
||||||
|
return conn.QueryRow(v, query, id)
|
||||||
|
})
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &resp, nil
|
||||||
|
case sqlc.ErrNotFound:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultStudentModel) Update(data Student) error {
|
||||||
|
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, data.Id)
|
||||||
|
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||||
|
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, studentRowsWithPlaceHolder)
|
||||||
|
return conn.Exec(query, data.Name, data.Age, data.Score, data.Id)
|
||||||
|
}, studentIdKey)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultStudentModel) Delete(id int64) error {
|
||||||
|
|
||||||
|
studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
|
||||||
|
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||||
|
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
|
||||||
|
return conn.Exec(query, id)
|
||||||
|
}, studentIdKey)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultStudentModel) formatPrimary(primary interface{}) string {
|
||||||
|
return fmt.Sprintf("%s%v", cacheStudentIdPrefix, primary)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultStudentModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
|
||||||
|
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
|
||||||
|
return conn.QueryRow(v, query, primary)
|
||||||
|
}
|
||||||
130
tools/goctl/model/sql/test/model/usermodel.go
Executable file
130
tools/goctl/model/sql/test/model/usermodel.go
Executable file
@@ -0,0 +1,130 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/sqlc"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||||
|
"github.com/tal-tech/go-zero/core/stringx"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
userFieldNames = builderx.FieldNames(&User{})
|
||||||
|
userRows = strings.Join(userFieldNames, ",")
|
||||||
|
userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
|
||||||
|
userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
UserModel interface {
|
||||||
|
Insert(data User) (sql.Result, error)
|
||||||
|
FindOne(id int64) (*User, error)
|
||||||
|
FindOneByUser(user string) (*User, error)
|
||||||
|
FindOneByName(name string) (*User, error)
|
||||||
|
FindOneByMobile(mobile string) (*User, error)
|
||||||
|
Update(data User) error
|
||||||
|
Delete(id int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultUserModel struct {
|
||||||
|
conn sqlx.SqlConn
|
||||||
|
table string
|
||||||
|
}
|
||||||
|
|
||||||
|
User struct {
|
||||||
|
Id int64 `db:"id"`
|
||||||
|
User string `db:"user"` // 用户
|
||||||
|
Name string `db:"name"` // 用户名称
|
||||||
|
Password string `db:"password"` // 用户密码
|
||||||
|
Mobile string `db:"mobile"` // 手机号
|
||||||
|
Gender string `db:"gender"` // 男|女|未公开
|
||||||
|
Nickname string `db:"nickname"` // 用户昵称
|
||||||
|
CreateTime time.Time `db:"create_time"`
|
||||||
|
UpdateTime time.Time `db:"update_time"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewUserModel(conn sqlx.SqlConn) UserModel {
|
||||||
|
return &defaultUserModel{
|
||||||
|
conn: conn,
|
||||||
|
table: "`user`",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) Insert(data User) (sql.Result, error) {
|
||||||
|
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?, ?)", m.table, userRowsExpectAutoSet)
|
||||||
|
ret, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname)
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) FindOne(id int64) (*User, error) {
|
||||||
|
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", userRows, m.table)
|
||||||
|
var resp User
|
||||||
|
err := m.conn.QueryRow(&resp, query, id)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &resp, nil
|
||||||
|
case sqlc.ErrNotFound:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) FindOneByUser(user string) (*User, error) {
|
||||||
|
var resp User
|
||||||
|
query := fmt.Sprintf("select %s from %s where `user` = ? limit 1", userRows, m.table)
|
||||||
|
err := m.conn.QueryRow(&resp, query, user)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &resp, nil
|
||||||
|
case sqlc.ErrNotFound:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) FindOneByName(name string) (*User, error) {
|
||||||
|
var resp User
|
||||||
|
query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table)
|
||||||
|
err := m.conn.QueryRow(&resp, query, name)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &resp, nil
|
||||||
|
case sqlc.ErrNotFound:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) {
|
||||||
|
var resp User
|
||||||
|
query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table)
|
||||||
|
err := m.conn.QueryRow(&resp, query, mobile)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &resp, nil
|
||||||
|
case sqlc.ErrNotFound:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) Update(data User) error {
|
||||||
|
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, userRowsWithPlaceHolder)
|
||||||
|
_, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) Delete(id int64) error {
|
||||||
|
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
|
||||||
|
_, err := m.conn.Exec(query, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
5
tools/goctl/model/sql/test/model/vars.go
Normal file
5
tools/goctl/model/sql/test/model/vars.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||||
|
|
||||||
|
var ErrNotFound = sqlx.ErrNotFound
|
||||||
255
tools/goctl/model/sql/test/orm.go
Normal file
255
tools/goctl/model/sql/test/orm.go
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
// copy from core/stores/sqlx/orm.go
|
||||||
|
package mocksql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
const tagName = "db"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotMatchDestination = errors.New("not matching destination to scan")
|
||||||
|
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
|
||||||
|
ErrNotSettable = errors.New("passed in variable is not settable")
|
||||||
|
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
|
||||||
|
)
|
||||||
|
|
||||||
|
type rowsScanner interface {
|
||||||
|
Columns() ([]string, error)
|
||||||
|
Err() error
|
||||||
|
Next() bool
|
||||||
|
Scan(v ...interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
|
||||||
|
rt := mapping.Deref(v.Type())
|
||||||
|
size := rt.NumField()
|
||||||
|
result := make(map[string]interface{}, size)
|
||||||
|
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
key := parseTagName(rt.Field(i))
|
||||||
|
if len(key) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
valueField := reflect.Indirect(v).Field(i)
|
||||||
|
switch valueField.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
if !valueField.CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
if valueField.IsNil() {
|
||||||
|
baseValueType := mapping.Deref(valueField.Type())
|
||||||
|
valueField.Set(reflect.New(baseValueType))
|
||||||
|
}
|
||||||
|
result[key] = valueField.Interface()
|
||||||
|
default:
|
||||||
|
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
result[key] = valueField.Addr().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
|
||||||
|
fields := unwrapFields(v)
|
||||||
|
if strict && len(columns) < len(fields) {
|
||||||
|
return nil, ErrNotMatchDestination
|
||||||
|
}
|
||||||
|
|
||||||
|
taggedMap, err := getTaggedFieldValueMap(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]interface{}, len(columns))
|
||||||
|
if len(taggedMap) == 0 {
|
||||||
|
for i := 0; i < len(values); i++ {
|
||||||
|
valueField := fields[i]
|
||||||
|
switch valueField.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
if !valueField.CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
if valueField.IsNil() {
|
||||||
|
baseValueType := mapping.Deref(valueField.Type())
|
||||||
|
valueField.Set(reflect.New(baseValueType))
|
||||||
|
}
|
||||||
|
values[i] = valueField.Interface()
|
||||||
|
default:
|
||||||
|
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
values[i] = valueField.Addr().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i, column := range columns {
|
||||||
|
if tagged, ok := taggedMap[column]; ok {
|
||||||
|
values[i] = tagged
|
||||||
|
} else {
|
||||||
|
var anonymous interface{}
|
||||||
|
values[i] = &anonymous
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return values, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTagName(field reflect.StructField) string {
|
||||||
|
key := field.Tag.Get(tagName)
|
||||||
|
if len(key) == 0 {
|
||||||
|
return ""
|
||||||
|
} else {
|
||||||
|
options := strings.Split(key, ",")
|
||||||
|
return options[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
|
||||||
|
if !scanner.Next() {
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rte := reflect.TypeOf(v).Elem()
|
||||||
|
rve := rv.Elem()
|
||||||
|
switch rte.Kind() {
|
||||||
|
case reflect.Bool,
|
||||||
|
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||||
|
reflect.Float32, reflect.Float64,
|
||||||
|
reflect.String:
|
||||||
|
if rve.CanSet() {
|
||||||
|
return scanner.Scan(v)
|
||||||
|
} else {
|
||||||
|
return ErrNotSettable
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
columns, err := scanner.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
return scanner.Scan(values...)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ErrUnsupportedValueType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := reflect.TypeOf(v)
|
||||||
|
rte := rt.Elem()
|
||||||
|
rve := rv.Elem()
|
||||||
|
switch rte.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if rve.CanSet() {
|
||||||
|
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||||
|
appendFn := func(item reflect.Value) {
|
||||||
|
if ptr {
|
||||||
|
rve.Set(reflect.Append(rve, item))
|
||||||
|
} else {
|
||||||
|
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fillFn := func(value interface{}) error {
|
||||||
|
if rve.CanSet() {
|
||||||
|
if err := scanner.Scan(value); err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
appendFn(reflect.ValueOf(value))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ErrNotSettable
|
||||||
|
}
|
||||||
|
|
||||||
|
base := mapping.Deref(rte.Elem())
|
||||||
|
switch base.Kind() {
|
||||||
|
case reflect.Bool,
|
||||||
|
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||||
|
reflect.Float32, reflect.Float64,
|
||||||
|
reflect.String:
|
||||||
|
for scanner.Next() {
|
||||||
|
value := reflect.New(base)
|
||||||
|
if err := fillFn(value.Interface()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
columns, err := scanner.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for scanner.Next() {
|
||||||
|
value := reflect.New(base)
|
||||||
|
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
if err := scanner.Scan(values...); err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
appendFn(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ErrUnsupportedValueType
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return ErrNotSettable
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ErrUnsupportedValueType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unwrapFields(v reflect.Value) []reflect.Value {
|
||||||
|
var fields []reflect.Value
|
||||||
|
indirect := reflect.Indirect(v)
|
||||||
|
|
||||||
|
for i := 0; i < indirect.NumField(); i++ {
|
||||||
|
child := indirect.Field(i)
|
||||||
|
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||||
|
baseValueType := mapping.Deref(child.Type())
|
||||||
|
child.Set(reflect.New(baseValueType))
|
||||||
|
}
|
||||||
|
|
||||||
|
child = reflect.Indirect(child)
|
||||||
|
childType := indirect.Type().Field(i)
|
||||||
|
if child.Kind() == reflect.Struct && childType.Anonymous {
|
||||||
|
fields = append(fields, unwrapFields(child)...)
|
||||||
|
} else {
|
||||||
|
fields = append(fields, child)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields
|
||||||
|
}
|
||||||
90
tools/goctl/model/sql/test/sqlconn.go
Normal file
90
tools/goctl/model/sql/test/sqlconn.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
// copy from core/stores/sqlx/sqlconn.go
|
||||||
|
package mocksql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
MockConn struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
statement struct {
|
||||||
|
stmt *sql.Stmt
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewMockConn(db *sql.DB) *MockConn {
|
||||||
|
return &MockConn{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *MockConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
return exec(conn.db, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *MockConn) Prepare(query string) (sqlx.StmtSession, error) {
|
||||||
|
st, err := conn.db.Prepare(query)
|
||||||
|
return statement{stmt: st}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *MockConn) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||||
|
return query(conn.db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(v, rows, true)
|
||||||
|
}, q, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *MockConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||||
|
return query(conn.db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(v, rows, false)
|
||||||
|
}, q, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *MockConn) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||||
|
return query(conn.db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(v, rows, true)
|
||||||
|
}, q, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||||
|
return query(conn.db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(v, rows, false)
|
||||||
|
}, q, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s statement) Close() error {
|
||||||
|
return s.stmt.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
|
||||||
|
return execStmt(s.stmt, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
|
||||||
|
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(v, rows, true)
|
||||||
|
}, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
|
||||||
|
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(v, rows, false)
|
||||||
|
}, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
|
||||||
|
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(v, rows, true)
|
||||||
|
}, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
|
||||||
|
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(v, rows, false)
|
||||||
|
}, args...)
|
||||||
|
}
|
||||||
122
tools/goctl/model/sql/test/stmt.go
Normal file
122
tools/goctl/model/sql/test/stmt.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
// copy from core/stores/sqlx/stmt.go
|
||||||
|
|
||||||
|
package mocksql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/logx"
|
||||||
|
"github.com/tal-tech/go-zero/core/timex"
|
||||||
|
)
|
||||||
|
|
||||||
|
const slowThreshold = time.Millisecond * 500
|
||||||
|
|
||||||
|
func exec(db *sql.DB, q string, args ...interface{}) (sql.Result, error) {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
err = tx.Commit()
|
||||||
|
default:
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
stmt, err := format(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := timex.Now()
|
||||||
|
result, err := tx.Exec(q, args...)
|
||||||
|
duration := timex.Since(startTime)
|
||||||
|
if duration > slowThreshold {
|
||||||
|
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||||
|
} else {
|
||||||
|
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logSqlError(stmt, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func execStmt(conn *sql.Stmt, args ...interface{}) (sql.Result, error) {
|
||||||
|
stmt := fmt.Sprint(args...)
|
||||||
|
startTime := timex.Now()
|
||||||
|
result, err := conn.Exec(args...)
|
||||||
|
duration := timex.Since(startTime)
|
||||||
|
if duration > slowThreshold {
|
||||||
|
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||||
|
} else {
|
||||||
|
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logSqlError(stmt, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func query(db *sql.DB, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
err = tx.Commit()
|
||||||
|
default:
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
stmt, err := format(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := timex.Now()
|
||||||
|
rows, err := tx.Query(q, args...)
|
||||||
|
duration := timex.Since(startTime)
|
||||||
|
if duration > slowThreshold {
|
||||||
|
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||||
|
} else {
|
||||||
|
logx.WithDuration(duration).Infof("sql query: %s", stmt)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logSqlError(stmt, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanner(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryStmt(conn *sql.Stmt, scanner func(*sql.Rows) error, args ...interface{}) error {
|
||||||
|
stmt := fmt.Sprint(args...)
|
||||||
|
startTime := timex.Now()
|
||||||
|
rows, err := conn.Query(args...)
|
||||||
|
duration := timex.Since(startTime)
|
||||||
|
if duration > slowThreshold {
|
||||||
|
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
|
||||||
|
} else {
|
||||||
|
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logSqlError(stmt, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return scanner(rows)
|
||||||
|
}
|
||||||
105
tools/goctl/model/sql/test/utils.go
Normal file
105
tools/goctl/model/sql/test/utils.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
// copy from core/stores/sqlx/utils.go
|
||||||
|
package mocksql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/logx"
|
||||||
|
"github.com/tal-tech/go-zero/core/mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrNotFound = sql.ErrNoRows
|
||||||
|
|
||||||
|
func desensitize(datasource string) string {
|
||||||
|
// remove account
|
||||||
|
pos := strings.LastIndex(datasource, "@")
|
||||||
|
if 0 <= pos && pos+1 < len(datasource) {
|
||||||
|
datasource = datasource[pos+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return datasource
|
||||||
|
}
|
||||||
|
|
||||||
|
func escape(input string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
for _, ch := range input {
|
||||||
|
switch ch {
|
||||||
|
case '\x00':
|
||||||
|
b.WriteString(`\x00`)
|
||||||
|
case '\r':
|
||||||
|
b.WriteString(`\r`)
|
||||||
|
case '\n':
|
||||||
|
b.WriteString(`\n`)
|
||||||
|
case '\\':
|
||||||
|
b.WriteString(`\\`)
|
||||||
|
case '\'':
|
||||||
|
b.WriteString(`\'`)
|
||||||
|
case '"':
|
||||||
|
b.WriteString(`\"`)
|
||||||
|
case '\x1a':
|
||||||
|
b.WriteString(`\x1a`)
|
||||||
|
default:
|
||||||
|
b.WriteRune(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func format(query string, args ...interface{}) (string, error) {
|
||||||
|
numArgs := len(args)
|
||||||
|
if numArgs == 0 {
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
argIndex := 0
|
||||||
|
|
||||||
|
for _, ch := range query {
|
||||||
|
if ch == '?' {
|
||||||
|
if argIndex >= numArgs {
|
||||||
|
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
arg := args[argIndex]
|
||||||
|
argIndex++
|
||||||
|
|
||||||
|
switch v := arg.(type) {
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
b.WriteByte('1')
|
||||||
|
} else {
|
||||||
|
b.WriteByte('0')
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
b.WriteByte('\'')
|
||||||
|
b.WriteString(escape(v))
|
||||||
|
b.WriteByte('\'')
|
||||||
|
default:
|
||||||
|
b.WriteString(mapping.Repr(v))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
b.WriteRune(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if argIndex < numArgs {
|
||||||
|
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func logInstanceError(datasource string, err error) {
|
||||||
|
datasource = desensitize(datasource)
|
||||||
|
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func logSqlError(stmt string, err error) {
|
||||||
|
if err != nil && err != ErrNotFound {
|
||||||
|
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,9 +25,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Plugin struct {
|
type Plugin struct {
|
||||||
Api *spec.ApiSpec
|
Api *spec.ApiSpec
|
||||||
Style string
|
ApiFilePath string
|
||||||
Dir string
|
Style string
|
||||||
|
Dir string
|
||||||
}
|
}
|
||||||
|
|
||||||
func PluginCommand(c *cli.Context) error {
|
func PluginCommand(c *cli.Context) error {
|
||||||
@@ -86,6 +87,12 @@ func prepareArgs(c *cli.Context) ([]byte, error) {
|
|||||||
transferData.Api = api
|
transferData.Api = api
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absApiFilePath, err := filepath.Abs(apiPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
transferData.ApiFilePath = absApiFilePath
|
||||||
dirAbs, err := filepath.Abs(c.String("dir"))
|
dirAbs, err := filepath.Abs(c.String("dir"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"go/build"
|
"go/build"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -44,7 +45,9 @@ func TestRpcGenerate(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = execx.Run("go test "+projectName, projectDir)
|
_, err = execx.Run("go test "+projectName, projectDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
assert.Contains(t, err.Error(), "not in GOROOT")
|
assert.True(t, func() bool {
|
||||||
|
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||||
|
}())
|
||||||
}
|
}
|
||||||
|
|
||||||
// case go mod
|
// case go mod
|
||||||
@@ -61,7 +64,9 @@ func TestRpcGenerate(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = execx.Run("go test "+projectName, projectDir)
|
_, err = execx.Run("go test "+projectName, projectDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
assert.Contains(t, err.Error(), "not in GOROOT")
|
assert.True(t, func() bool {
|
||||||
|
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||||
|
}())
|
||||||
}
|
}
|
||||||
|
|
||||||
// case not in go mod and go path
|
// case not in go mod and go path
|
||||||
@@ -69,7 +74,9 @@ func TestRpcGenerate(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = execx.Run("go test "+projectName, projectDir)
|
_, err = execx.Run("go test "+projectName, projectDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
assert.Contains(t, err.Error(), "not in GOROOT")
|
assert.True(t, func() bool {
|
||||||
|
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||||
|
}())
|
||||||
}
|
}
|
||||||
|
|
||||||
// invalid directory
|
// invalid directory
|
||||||
|
|||||||
@@ -15,12 +15,12 @@ const svcTemplate = `package svc
|
|||||||
import {{.imports}}
|
import {{.imports}}
|
||||||
|
|
||||||
type ServiceContext struct {
|
type ServiceContext struct {
|
||||||
c config.Config
|
Config config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceContext(c config.Config) *ServiceContext {
|
func NewServiceContext(c config.Config) *ServiceContext {
|
||||||
return &ServiceContext{
|
return &ServiceContext{
|
||||||
c:c,
|
Config:c,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|||||||
@@ -54,14 +54,15 @@ func Clean() error {
|
|||||||
return util.Clean(category)
|
return util.Clean(category)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Update(category string) error {
|
func Update() error {
|
||||||
err := Clean()
|
err := Clean()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.InitTemplates(category, templates)
|
return util.InitTemplates(category, templates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCategory() string {
|
func Category() string {
|
||||||
return category
|
return category
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -97,8 +97,7 @@ func TestUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, "modify", string(data))
|
assert.Equal(t, "modify", string(data))
|
||||||
|
|
||||||
err = Update(category)
|
assert.Nil(t, Update())
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
data, err = ioutil.ReadFile(mainTpl)
|
data, err = ioutil.ReadFile(mainTpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -109,6 +108,6 @@ func TestUpdate(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetCategory(t *testing.T) {
|
func TestGetCategory(t *testing.T) {
|
||||||
_ = Clean()
|
_ = Clean()
|
||||||
result := GetCategory()
|
result := Category()
|
||||||
assert.Equal(t, category, result)
|
assert.Equal(t, category, result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,12 +76,16 @@ func UpdateTemplates(ctx *cli.Context) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
switch category {
|
switch category {
|
||||||
case gogen.GetCategory():
|
case docker.Category():
|
||||||
return gogen.Update(category)
|
return docker.Update()
|
||||||
case rpcgen.GetCategory():
|
case gogen.Category():
|
||||||
return rpcgen.Update(category)
|
return gogen.Update()
|
||||||
case modelgen.GetCategory():
|
case kube.Category():
|
||||||
return modelgen.Update(category)
|
return kube.Update()
|
||||||
|
case rpcgen.Category():
|
||||||
|
return rpcgen.Update()
|
||||||
|
case modelgen.Category():
|
||||||
|
return modelgen.Update()
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unexpected category: %s", category)
|
err = fmt.Errorf("unexpected category: %s", category)
|
||||||
return
|
return
|
||||||
@@ -97,11 +101,15 @@ func RevertTemplates(ctx *cli.Context) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
switch category {
|
switch category {
|
||||||
case gogen.GetCategory():
|
case docker.Category():
|
||||||
|
return docker.RevertTemplate(filename)
|
||||||
|
case kube.Category():
|
||||||
|
return kube.RevertTemplate(filename)
|
||||||
|
case gogen.Category():
|
||||||
return gogen.RevertTemplate(filename)
|
return gogen.RevertTemplate(filename)
|
||||||
case rpcgen.GetCategory():
|
case rpcgen.Category():
|
||||||
return rpcgen.RevertTemplate(filename)
|
return rpcgen.RevertTemplate(filename)
|
||||||
case modelgen.GetCategory():
|
case modelgen.Category():
|
||||||
return modelgen.RevertTemplate(filename)
|
return modelgen.RevertTemplate(filename)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unexpected category: %s", category)
|
err = fmt.Errorf("unexpected category: %s", category)
|
||||||
|
|||||||
@@ -60,7 +60,6 @@ func FindGoModPath(dir string) (string, bool) {
|
|||||||
var hasGoMod = false
|
var hasGoMod = false
|
||||||
for {
|
for {
|
||||||
if FileExists(filepath.Join(tempPath, goModeIdentifier)) {
|
if FileExists(filepath.Join(tempPath, goModeIdentifier)) {
|
||||||
tempPath = filepath.Dir(tempPath)
|
|
||||||
rootPath = strings.TrimPrefix(absDir[len(tempPath):], "/")
|
rootPath = strings.TrimPrefix(absDir[len(tempPath):], "/")
|
||||||
hasGoMod = true
|
hasGoMod = true
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -19,10 +19,11 @@ func Untitle(s string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Index(slice []string, item string) int {
|
func Index(slice []string, item string) int {
|
||||||
for i, _ := range slice {
|
for i := range slice {
|
||||||
if slice[i] == item {
|
if slice[i] == item {
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,10 +64,10 @@ func (l *Logger) Warning(args ...interface{}) {
|
|||||||
// ignore builtin grpc warning
|
// ignore builtin grpc warning
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Warningln(args ...interface{}) {
|
|
||||||
// ignore builtin grpc warning
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Warningf(format string, args ...interface{}) {
|
func (l *Logger) Warningf(format string, args ...interface{}) {
|
||||||
// ignore builtin grpc warning
|
// ignore builtin grpc warning
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Warningln(args ...interface{}) {
|
||||||
|
// ignore builtin grpc warning
|
||||||
|
}
|
||||||
|
|||||||
83
zrpc/internal/rpclogger_test.go
Normal file
83
zrpc/internal/rpclogger_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const content = "foo"
|
||||||
|
|
||||||
|
func TestLoggerError(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Error(content)
|
||||||
|
assert.Contains(t, builder.String(), content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerErrorf(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Errorf(content)
|
||||||
|
assert.Contains(t, builder.String(), content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerErrorln(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Errorln(content)
|
||||||
|
assert.Contains(t, builder.String(), content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerFatal(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Fatal(content)
|
||||||
|
assert.Contains(t, builder.String(), content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerFatalf(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Fatalf(content)
|
||||||
|
assert.Contains(t, builder.String(), content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerFatalln(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Fatalln(content)
|
||||||
|
assert.Contains(t, builder.String(), content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerWarning(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Warning(content)
|
||||||
|
assert.Empty(t, builder.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerWarningf(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Warningf(content)
|
||||||
|
assert.Empty(t, builder.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerWarningln(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
log.SetOutput(&builder)
|
||||||
|
logger := new(Logger)
|
||||||
|
logger.Warningln(content)
|
||||||
|
assert.Empty(t, builder.String())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user