Compare commits

...

79 Commits

Author SHA1 Message Date
Kevin Wan
744c18b7cb simplify cgroup controller separation (#384) 2021-01-13 20:58:33 +08:00
miaogaolin
8d6f6f933e fix cgroup bug (#380) 2021-01-13 20:39:57 +08:00
Kevin Wan
37c3b9f5c1 make sure unlock safe even if listeners panic (#383)
* make sure unlock safe even if listeners panic

* fix #378

* fix #378
2021-01-13 18:43:42 +08:00
卢永杰
1f1dcd16e6 fix server.start return nil points (#379)
Co-authored-by: luyongjie <luyongjie@37.com>
2021-01-13 18:40:39 +08:00
文杰
3285436f75 f-fix spell (#381)
Co-authored-by: chenwenjie <chenwenjie@zzstc.cn>
2021-01-13 18:07:31 +08:00
kingxt
7f49bd8a31 code optimized (#382) 2021-01-13 16:37:33 +08:00
kingxt
9cd2015661 fix inner type generate error (#377)
* fix point type bug

* optimized

* fix inner type error
2021-01-13 11:54:53 +08:00
kingxt
cf3a1020b0 Java optimized (#376)
* optiimzed java gen

* optiimzed java gen

* fix
2021-01-12 14:14:49 +08:00
kingxt
ee19fb736b feature: refactor api parse to g4 (#365)
* feature: refactor api parse to g4

* new g4 parser

* add CHANGE_LOG.MD

* refactor

* fix byte bug

* refactor

* optimized

* optimized

* revert

* update readme.md

* update readme.md

* update readme.md

* update readme.md

* remove no need

* fix java gen

* add upgrade

* resolve confilits

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2021-01-11 15:10:51 +08:00
Kevin Wan
b0ccfb8eb4 add more tests for conf (#371) 2021-01-10 21:53:16 +08:00
Kevin Wan
444e5a711f update doc to use table to render plugins (#370) 2021-01-09 19:54:34 +08:00
Kevin Wan
8774d72ddb remove duplicated code in goctl (#369) 2021-01-09 00:17:23 +08:00
HarryWang29
e3fcdbf040 fix return in for (#367)
Co-authored-by: HarryWang29 <wrz890829@gmail.com>
2021-01-08 22:47:27 +08:00
Kevin Wan
2854ca03b4 update goctl version to 1.1.3 (#364) 2021-01-08 14:02:59 +08:00
anqiansong
6c624a6ed0 Feature model fix (#362)
* fix sql builderx adding raw string quotation marks incompatibility bug

* add unit test

* remove comments

* fix sql builderx adding raw string quotation marks incompatibility bug
2021-01-08 12:01:21 +08:00
Kevin Wan
57b73d8b49 make sure offset less than size even it's checked inside (#354) 2021-01-05 16:06:36 +08:00
Kevin Wan
a79cee12ee add godoc for RollingWindow (#351) 2021-01-04 22:43:55 +08:00
zjbztianya
7a921f66e6 simple rolling windows code (#346) 2021-01-04 22:11:18 +08:00
kingxt
12e235efb0 optimized goctl format (#336)
* fix format

* refactor

* refactor

* optimized

* refactor

* refactor

* refactor

* add js path prefix
2021-01-04 18:59:48 +08:00
Kevin Wan
01060cf16d close issue of #337 (#347) 2021-01-04 16:36:27 +08:00
Kevin Wan
0786862a35 align bucket boundary to interval in rolling window (#345) 2021-01-04 11:17:59 +08:00
Kevin Wan
efa43483b2 fix potential data race in PeriodicalExecutor (#344)
* fix potential data race in PeriodicalExecutor

* add comment
2021-01-03 20:56:17 +08:00
Kevin Wan
771371e051 simplify rolling window code, and make tests run faster (#343) 2021-01-03 20:47:29 +08:00
zjbztianya
2ee95f8981 fix rolling window bug (#340) 2021-01-03 20:27:47 +08:00
Kevin Wan
5bc01e4bfd set guarded to false only on quitting background flush (#342)
* set guarded to false only on quitting background flush

* set guarded to false only on quitting background flush, cont.
2021-01-03 19:54:11 +08:00
Kevin Wan
510e966982 simplify periodical executor background routine (#339) 2021-01-03 14:02:51 +08:00
Kevin Wan
10e3b8ac80 optimize code that fixes issue #317 (#338) 2021-01-02 19:01:37 +08:00
Kevin Wan
04059bbf5a add discord chat group in readme 2021-01-02 18:35:33 +08:00
weibobo
d643007c79 fix bug #317 (#335)
* fix bug #317.
* add counter for current task. If it's bigger then zero, do not quit background thread

* Revert "fix issue #317 (#331)"

This reverts commit fc43876cc5.
2021-01-02 18:04:04 +08:00
Kevin Wan
fc43876cc5 fix issue #317 (#331) 2021-01-01 13:24:28 +08:00
FengZhang
a926cb514f modify the goctl gensvc template (#323) 2020-12-30 10:05:26 +08:00
kingxt
25cab2f273 Java (#327)
* add g4 file

* new define api by g4

* reactor parser to g4gen

* add syntax parser & test

* add syntax parser & test

* add syntax parser & test

* update g4 file

* add import parse & test

* ractor AT lexer

* panic with error

* revert AT

* update g4 file

* update g4 file

* update g4 file

* optimize parser

* update g4 file

* parse info

* optimized java generator

* revert

* optimize java generator

* update java generator

* update java generator

* update java generator

* update java generator

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2020-12-29 17:50:41 +08:00
Kevin Wan
8d2e2753a2 simplify http.Flusher implementation (#326)
* simplify code with http.Flusher type conversion

* simplify code with http.Flusher type conversion, better version
2020-12-29 15:02:36 +08:00
Kevin Wan
cc4c50e3eb fix broken link. 2020-12-29 11:54:32 +08:00
Kevin Wan
751072bdb0 fix broken doc link 2020-12-29 11:52:55 +08:00
Kevin Wan
e97e1f10db simplify code with http.Flusher type conversion (#325)
* simplify code with http.Flusher type conversion

* simplify code with http.Flusher type conversion, better version
2020-12-29 10:25:55 +08:00
jichangyun
0bd2a0656c The ResponseWriters defined in rest.handler add Flush interface. (#318) 2020-12-28 21:30:24 +08:00
Kevin Wan
71a2b20301 add more tests for prof (#322) 2020-12-27 14:45:14 +08:00
Kevin Wan
8df7de94e3 add more tests for zrpc (#321) 2020-12-27 14:08:24 +08:00
Kevin Wan
bf21203297 add more tests (#320) 2020-12-27 12:26:31 +08:00
Kevin Wan
ae98375194 add more tests (#319) 2020-12-26 20:30:02 +08:00
Kevin Wan
82d1ccf376 fixes #286 (#315) 2020-12-25 19:47:27 +08:00
Kevin Wan
bb6d49c17e add go report card back (#313)
* add go report card back

* avoid test failure, run tests sequentially
2020-12-25 12:09:59 +08:00
Kevin Wan
ed735ec47c Update codeql-analysis.yml
disable python code analysis, python code is in examples.
2020-12-25 12:09:43 +08:00
Kevin Wan
ba4bac3a03 format code (#312) 2020-12-25 11:53:37 +08:00
FengZhang
08433d7e04 add config load support env var (#309) 2020-12-25 11:42:19 +08:00
anqiansong
a3b525b50d feature model fix (#296)
* add raw stirng quote for sql field

* remove unused code
2020-12-21 09:43:32 +08:00
Kevin Wan
097f6886f2 Update readme.md 2020-12-15 23:47:41 +08:00
Kevin Wan
07a1549634 add wechat micro practice qrcode image (#289) 2020-12-14 17:49:58 +08:00
Kevin Wan
befca26c58 Update readme.md
add goproxy.cn download badge
2020-12-13 00:02:32 +08:00
Kevin Wan
3556a2eef4 Update readme-en.md
goreportcard is not working, submitted an issue to them.
2020-12-12 23:40:26 +08:00
Kevin Wan
807765f77e Update readme.md
goreportcard is not working, submitted a issue to them.
2020-12-12 23:39:28 +08:00
Kevin Wan
e44584e549 Create codeql-analysis.yml 2020-12-12 23:01:15 +08:00
Kevin Wan
acd48f0abb optimize dockerfile generation (#284) 2020-12-12 16:53:06 +08:00
kingxt
f919bc6713 refactor (#283) 2020-12-12 11:18:22 +08:00
Kevin Wan
a0030b8f45 format dockerfile on non-chinese mode (#282) 2020-12-12 10:13:33 +08:00
Kevin Wan
a5f0cce1b1 Update readme-en.md 2020-12-12 09:06:09 +08:00
Kevin Wan
4d13dda605 add EXPOSE in dockerfile generation (#281) 2020-12-12 08:18:01 +08:00
songmeizi
b56cc8e459 optimize test case of TestRpcGenerate (#279)
Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2020-12-11 21:57:04 +08:00
Kevin Wan
c435811479 fix gocyclo warnings (#278) 2020-12-11 20:57:48 +08:00
Kevin Wan
c686c93fb5 fix dockerfile generation bug (#277) 2020-12-11 20:31:31 +08:00
Kevin Wan
da8f76e6bd add category docker & kube (#276) 2020-12-11 18:53:40 +08:00
Kevin Wan
99596a4149 fix issue #266 (#275)
* optimize dockerfile

* fix issue #266
2020-12-11 16:12:33 +08:00
wayne
ec2a9f2c57 fix tracelogger_test TestTraceLog (#271) 2020-12-10 17:04:57 +08:00
Kevin Wan
fd73ced6dc optimize dockerfile (#272) 2020-12-10 16:21:06 +08:00
Kevin Wan
5071736ab4 fmt code (#270) 2020-12-10 15:16:13 +08:00
Kevin Wan
0d7f1d23b4 require go 1.14 (#263)
* refactor & format code

* optimized parse tag (#256)

* feature plugin custom flag (#251)

* support plugin custom flags

* add short name

* remove log

* remove log

* require go 1.14

Co-authored-by: kingxt <kingxt4job@gmail.com>
Co-authored-by: songmeizi <anqiansong@xiaoheiban.cn>
2020-12-09 22:43:42 +08:00
songmeizi
84ab11ac09 feature plugin custom flag (#251)
* support plugin custom flags

* add short name

* remove log

* remove log
2020-12-09 18:08:17 +08:00
kingxt
67804a6bb2 optimized parse tag (#256) 2020-12-09 11:16:38 +08:00
Kevin Wan
65ee877236 refactor & format code (#255) 2020-12-08 23:01:25 +08:00
songmeizi
b060867009 Feature bookstore update (#253)
* update bookstore

* update bookstore
2020-12-08 22:36:48 +08:00
songmeizi
4d53045c6b improve data type conversion (#236)
* improve data type conversion

* update doc
2020-12-08 18:06:15 +08:00
kingxt
cecd4b1b75 goctl add plugin support (#243)
* add plugin support

* add plugin support

* add plugin support

* add plugin support

* add plugin support

* add plugin support

* add plugin support

* add plugin support

* add plugin support

* add plugin support

* add plugin support

* remove no need

* add plugin support

* rename

* rename

* add plugin support

* refactor

* update plugin

* refactor

* refactor

* refactor

* update plugin

* newline

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2020-12-07 14:55:10 +08:00
Kevin Wan
7cd0463953 fix lint errors (#249)
* simplify code, format makefile

* simplify code

* some optimize by kevwan and benying (#240)

Co-authored-by: 杨志泉 <zhiquan.yang@yiducloud.cn>

* optimization (#241)

* optimize docker file generation, make docker build faster

* support k8s deployment yaml generation

* fix lint errors

Co-authored-by: benying <31179034+benyingY@users.noreply.github.com>
Co-authored-by: 杨志泉 <zhiquan.yang@yiducloud.cn>
Co-authored-by: bittoy <bittoy@qq.com>
2020-12-07 11:12:02 +08:00
Kevin Wan
7a82cf80ce support k8s deployment yaml generation (#247)
* simplify code, format makefile

* simplify code

* some optimize by kevwan and benying (#240)

Co-authored-by: 杨志泉 <zhiquan.yang@yiducloud.cn>

* optimization (#241)

* optimize docker file generation, make docker build faster

* support k8s deployment yaml generation

Co-authored-by: benying <31179034+benyingY@users.noreply.github.com>
Co-authored-by: 杨志泉 <zhiquan.yang@yiducloud.cn>
Co-authored-by: bittoy <bittoy@qq.com>
2020-12-07 00:07:50 +08:00
Kevin Wan
f997aee3ba optimize docker file generation, make docker build faster (#244)
* simplify code, format makefile

* simplify code

* some optimize by kevwan and benying (#240)

Co-authored-by: 杨志泉 <zhiquan.yang@yiducloud.cn>

* optimization (#241)

* optimize docker file generation, make docker build faster

Co-authored-by: benying <31179034+benyingY@users.noreply.github.com>
Co-authored-by: 杨志泉 <zhiquan.yang@yiducloud.cn>
Co-authored-by: bittoy <bittoy@qq.com>
2020-12-05 21:48:09 +08:00
bittoy
88ec89bdbd optimization (#241) 2020-12-02 15:00:07 +08:00
benying
7d1b43780a some optimize by kevwan and benying (#240)
Co-authored-by: 杨志泉 <zhiquan.yang@yiducloud.cn>
2020-12-01 06:44:32 +08:00
Kevin Wan
4b5c2de376 simplify code (#234)
* simplify code, format makefile

* simplify code
2020-11-29 12:41:42 +08:00
215 changed files with 17948 additions and 4029 deletions

67
.github/workflows/codeql-analysis.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ master ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ master ]
schedule:
- cron: '18 19 * * 6'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
# Learn more:
# https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
steps:
- name: Checkout repository
uses: actions/checkout@v2
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v1
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# queries: ./path/to/local/query, your-org/your-repo/queries@main
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v1
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
# and modify them (or add more) to build your code if your project
# uses a compiled language
#- run: |
# make bootstrap
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v1

2
.gitignore vendored
View File

@@ -4,6 +4,7 @@
# Unignore all with extensions # Unignore all with extensions
!*.* !*.*
!**/Dockerfile !**/Dockerfile
!**/Makefile
# Unignore all dirs # Unignore all dirs
!*/ !*/
@@ -12,7 +13,6 @@
.idea .idea
**/.DS_Store **/.DS_Store
**/logs **/logs
!Makefile
# gitlab ci # gitlab ci
.cache .cache

View File

@@ -8,8 +8,10 @@ import (
) )
type ( type (
// RollingWindowOption let callers customize the RollingWindow.
RollingWindowOption func(rollingWindow *RollingWindow) RollingWindowOption func(rollingWindow *RollingWindow)
// RollingWindow defines a rolling window to calculate the events in buckets with time interval.
RollingWindow struct { RollingWindow struct {
lock sync.RWMutex lock sync.RWMutex
size int size int
@@ -17,10 +19,12 @@ type (
interval time.Duration interval time.Duration
offset int offset int
ignoreCurrent bool ignoreCurrent bool
lastTime time.Duration lastTime time.Duration // start time of the last bucket
} }
) )
// NewRollingWindow returns a RollingWindow that with size buckets and time interval,
// use opts to customize the RollingWindow.
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow { func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
if size < 1 { if size < 1 {
panic("size must be greater than 0") panic("size must be greater than 0")
@@ -38,6 +42,7 @@ func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOpt
return w return w
} }
// Add adds value to current bucket.
func (rw *RollingWindow) Add(v float64) { func (rw *RollingWindow) Add(v float64) {
rw.lock.Lock() rw.lock.Lock()
defer rw.lock.Unlock() defer rw.lock.Unlock()
@@ -45,6 +50,7 @@ func (rw *RollingWindow) Add(v float64) {
rw.win.add(rw.offset, v) rw.win.add(rw.offset, v)
} }
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) { func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
rw.lock.RLock() rw.lock.RLock()
defer rw.lock.RUnlock() defer rw.lock.RUnlock()
@@ -74,29 +80,23 @@ func (rw *RollingWindow) span() int {
func (rw *RollingWindow) updateOffset() { func (rw *RollingWindow) updateOffset() {
span := rw.span() span := rw.span()
if span > 0 { if span <= 0 {
offset := rw.offset return
start := offset + 1
steps := start + span
var remainder int
if steps > rw.size {
remainder = steps - rw.size
steps = rw.size
}
// reset expired buckets
for i := start; i < steps; i++ {
rw.win.resetBucket(i)
}
for i := 0; i < remainder; i++ {
rw.win.resetBucket(i)
}
rw.offset = (offset + span) % rw.size
rw.lastTime = timex.Now()
} }
offset := rw.offset
// reset expired buckets
for i := 0; i < span; i++ {
rw.win.resetBucket((offset + i + 1) % rw.size)
}
rw.offset = (offset + span) % rw.size
now := timex.Now()
// align to interval time boundary
rw.lastTime = now - (now-rw.lastTime)%rw.interval
} }
// Bucket defines the bucket that holds sum and num of additions.
type Bucket struct { type Bucket struct {
Sum float64 Sum float64
Count int64 Count int64
@@ -118,9 +118,9 @@ type window struct {
} }
func newWindow(size int) *window { func newWindow(size int) *window {
var buckets []*Bucket buckets := make([]*Bucket, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
buckets = append(buckets, new(Bucket)) buckets[i] = new(Bucket)
} }
return &window{ return &window{
buckets: buckets, buckets: buckets,
@@ -134,14 +134,15 @@ func (w *window) add(offset int, v float64) {
func (w *window) reduce(start, count int, fn func(b *Bucket)) { func (w *window) reduce(start, count int, fn func(b *Bucket)) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
fn(w.buckets[(start+i)%len(w.buckets)]) fn(w.buckets[(start+i)%w.size])
} }
} }
func (w *window) resetBucket(offset int) { func (w *window) resetBucket(offset int) {
w.buckets[offset].reset() w.buckets[offset%w.size].reset()
} }
// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
func IgnoreCurrentBucket() RollingWindowOption { func IgnoreCurrentBucket() RollingWindowOption {
return func(w *RollingWindow) { return func(w *RollingWindow) {
w.ignoreCurrent = true w.ignoreCurrent = true

View File

@@ -105,6 +105,37 @@ func TestRollingWindowReduce(t *testing.T) {
} }
} }
func TestRollingWindowBucketTimeBoundary(t *testing.T) {
const size = 3
interval := time.Millisecond * 30
r := NewRollingWindow(size, interval)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
}
assert.Equal(t, []float64{0, 0, 0}, listBuckets())
r.Add(1)
assert.Equal(t, []float64{0, 0, 1}, listBuckets())
time.Sleep(time.Millisecond * 45)
r.Add(2)
r.Add(3)
assert.Equal(t, []float64{0, 1, 5}, listBuckets())
// sleep time should be less than interval, and make the bucket change happen
time.Sleep(time.Millisecond * 20)
r.Add(4)
r.Add(5)
r.Add(6)
assert.Equal(t, []float64{1, 5, 15}, listBuckets())
time.Sleep(time.Millisecond * 100)
r.Add(7)
r.Add(8)
r.Add(9)
assert.Equal(t, []float64{0, 0, 24}, listBuckets())
}
func TestRollingWindowDataRace(t *testing.T) { func TestRollingWindowDataRace(t *testing.T) {
const size = 3 const size = 3
r := NewRollingWindow(size, duration) r := NewRollingWindow(size, duration)

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"os"
"path" "path"
"github.com/tal-tech/go-zero/core/mapping" "github.com/tal-tech/go-zero/core/mapping"
@@ -19,7 +20,7 @@ func LoadConfig(file string, v interface{}) error {
if content, err := ioutil.ReadFile(file); err != nil { if content, err := ioutil.ReadFile(file); err != nil {
return err return err
} else if loader, ok := loaders[path.Ext(file)]; ok { } else if loader, ok := loaders[path.Ext(file)]; ok {
return loader(content, v) return loader([]byte(os.ExpandEnv(string(content))), v)
} else { } else {
return fmt.Errorf("unrecoginized file type: %s", file) return fmt.Errorf("unrecoginized file type: %s", file)
} }

View File

@@ -6,9 +6,21 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/fs"
"github.com/tal-tech/go-zero/core/hash" "github.com/tal-tech/go-zero/core/hash"
) )
func TestLoadConfig_notExists(t *testing.T) {
assert.NotNil(t, LoadConfig("not_a_file", nil))
}
func TestLoadConfig_notRecogFile(t *testing.T) {
filename, err := fs.TempFilenameWithText("hello")
assert.Nil(t, err)
defer os.Remove(filename)
assert.NotNil(t, LoadConfig(filename, nil))
}
func TestConfigJson(t *testing.T) { func TestConfigJson(t *testing.T) {
tests := []string{ tests := []string{
".json", ".json",
@@ -17,13 +29,14 @@ func TestConfigJson(t *testing.T) {
} }
text := `{ text := `{
"a": "foo", "a": "foo",
"b": 1 "b": 1,
"c": "${FOO}"
}` }`
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
t.Parallel() os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text) tmpfile, err := createTempFile(test, text)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile) defer os.Remove(tmpfile)
@@ -31,10 +44,12 @@ func TestConfigJson(t *testing.T) {
var val struct { var val struct {
A string `json:"a"` A string `json:"a"`
B int `json:"b"` B int `json:"b"`
C string `json:"c"`
} }
MustLoad(tmpfile, &val) MustLoad(tmpfile, &val)
assert.Equal(t, "foo", val.A) assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B) assert.Equal(t, 1, val.B)
assert.Equal(t, "2", val.C)
}) })
} }
} }

View File

@@ -24,6 +24,20 @@ func TestProperties(t *testing.T) {
assert.Equal(t, "test", props.GetString("app.name")) assert.Equal(t, "test", props.GetString("app.name"))
assert.Equal(t, "app", props.GetString("app.program")) assert.Equal(t, "app", props.GetString("app.program"))
assert.Equal(t, 5, props.GetInt("app.threads")) assert.Equal(t, 5, props.GetInt("app.threads"))
val := props.ToString()
assert.Contains(t, val, "app.name")
assert.Contains(t, val, "app.program")
assert.Contains(t, val, "app.threads")
}
func TestLoadProperties_badContent(t *testing.T) {
filename, err := fs.TempFilenameWithText("hello")
assert.Nil(t, err)
defer os.Remove(filename)
_, err = LoadProperties(filename)
assert.NotNil(t, err)
assert.True(t, len(err.Error()) > 0)
} }
func TestSetString(t *testing.T) { func TestSetString(t *testing.T) {

View File

@@ -18,7 +18,8 @@ type (
disconnected bool disconnected bool
currentState connectivity.State currentState connectivity.State
listeners []func() listeners []func()
lock sync.Mutex // lock only guards listeners, because only listens can be accessed by other goroutines.
lock sync.Mutex
} }
) )
@@ -32,27 +33,33 @@ func (sw *stateWatcher) addListener(l func()) {
sw.lock.Unlock() sw.lock.Unlock()
} }
func (sw *stateWatcher) notifyListeners() {
sw.lock.Lock()
defer sw.lock.Unlock()
for _, l := range sw.listeners {
l()
}
}
func (sw *stateWatcher) updateState(conn etcdConn) {
sw.currentState = conn.GetState()
switch sw.currentState {
case connectivity.TransientFailure, connectivity.Shutdown:
sw.disconnected = true
case connectivity.Ready:
if sw.disconnected {
sw.disconnected = false
sw.notifyListeners()
}
}
}
func (sw *stateWatcher) watch(conn etcdConn) { func (sw *stateWatcher) watch(conn etcdConn) {
sw.currentState = conn.GetState() sw.currentState = conn.GetState()
for { for {
if conn.WaitForStateChange(context.Background(), sw.currentState) { if conn.WaitForStateChange(context.Background(), sw.currentState) {
newState := conn.GetState() sw.updateState(conn)
sw.lock.Lock()
sw.currentState = newState
switch newState {
case connectivity.TransientFailure, connectivity.Shutdown:
sw.disconnected = true
case connectivity.Ready:
if sw.disconnected {
sw.disconnected = false
for _, l := range sw.listeners {
l()
}
}
}
sw.lock.Unlock()
} }
} }
} }

View File

@@ -3,6 +3,7 @@ package executors
import ( import (
"reflect" "reflect"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/tal-tech/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
@@ -35,6 +36,7 @@ type (
// avoid race condition on waitGroup when calling wg.Add/Done/Wait(...) // avoid race condition on waitGroup when calling wg.Add/Done/Wait(...)
wgBarrier syncx.Barrier wgBarrier syncx.Barrier
confirmChan chan lang.PlaceholderType confirmChan chan lang.PlaceholderType
inflight int32
guarded bool guarded bool
newTicker func(duration time.Duration) timex.Ticker newTicker func(duration time.Duration) timex.Ticker
lock sync.Mutex lock sync.Mutex
@@ -91,18 +93,16 @@ func (pe *PeriodicalExecutor) Wait() {
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) { func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
pe.lock.Lock() pe.lock.Lock()
defer func() { defer func() {
var start bool
if !pe.guarded { if !pe.guarded {
pe.guarded = true pe.guarded = true
start = true // defer to unlock quickly
defer pe.backgroundFlush()
} }
pe.lock.Unlock() pe.lock.Unlock()
if start {
pe.backgroundFlush()
}
}() }()
if pe.container.AddTask(task) { if pe.container.AddTask(task) {
atomic.AddInt32(&pe.inflight, 1)
return pe.container.RemoveAll(), true return pe.container.RemoveAll(), true
} }
@@ -111,6 +111,9 @@ func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool)
func (pe *PeriodicalExecutor) backgroundFlush() { func (pe *PeriodicalExecutor) backgroundFlush() {
threading.GoSafe(func() { threading.GoSafe(func() {
// flush before quit goroutine to avoid missing tasks
defer pe.Flush()
ticker := pe.newTicker(pe.interval) ticker := pe.newTicker(pe.interval)
defer ticker.Stop() defer ticker.Stop()
@@ -120,6 +123,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
select { select {
case vals := <-pe.commander: case vals := <-pe.commander:
commanded = true commanded = true
atomic.AddInt32(&pe.inflight, -1)
pe.enterExecution() pe.enterExecution()
pe.confirmChan <- lang.Placeholder pe.confirmChan <- lang.Placeholder
pe.executeTasks(vals) pe.executeTasks(vals)
@@ -129,13 +133,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
commanded = false commanded = false
} else if pe.Flush() { } else if pe.Flush() {
last = timex.Now() last = timex.Now()
} else if timex.Since(last) > pe.interval*idleRound { } else if pe.shallQuit(last) {
pe.lock.Lock()
pe.guarded = false
pe.lock.Unlock()
// flush again to avoid missing tasks
pe.Flush()
return return
} }
} }
@@ -178,3 +176,19 @@ func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
return true return true
} }
} }
func (pe *PeriodicalExecutor) shallQuit(last time.Duration) (stop bool) {
if timex.Since(last) <= pe.interval*idleRound {
return
}
// checking pe.inflight and setting pe.guarded should be locked together
pe.lock.Lock()
if atomic.LoadInt32(&pe.inflight) == 0 {
pe.guarded = false
stop = true
}
pe.lock.Unlock()
return
}

View File

@@ -140,6 +140,26 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
assert.Equal(t, total, cnt) assert.Equal(t, total, cnt)
} }
func TestPeriodicalExecutor_Deadlock(t *testing.T) {
executor := NewBulkExecutor(func(tasks []interface{}) {
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
for i := 0; i < 1e5; i++ {
executor.Add(1)
}
}
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
ticker := timex.NewFakeTicker()
defer ticker.Stop()
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
assert.False(t, exec.hasTasks(nil))
assert.True(t, exec.hasTasks(1))
}
// go test -benchtime 10s -bench . // go test -benchtime 10s -bench .
func BenchmarkExecutor(b *testing.B) { func BenchmarkExecutor(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()

View File

@@ -21,6 +21,7 @@ var mock tracespec.Trace = new(mockTrace)
func TestTraceLog(t *testing.T) { func TestTraceLog(t *testing.T) {
var buf mockWriter var buf mockWriter
atomic.StoreUint32(&initialized, 1)
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog) WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceId))

View File

@@ -153,58 +153,57 @@ func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fie
key := strings.TrimSpace(segments[0]) key := strings.TrimSpace(segments[0])
options := segments[1:] options := segments[1:]
if len(options) > 0 { if len(options) == 0 {
var fieldOpts fieldOptions return key, nil, nil
for _, segment := range options {
option := strings.TrimSpace(segment)
switch {
case option == stringOption:
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
} else {
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
}
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
} else {
fieldOpts.Default = strings.TrimSpace(segs[1])
}
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
}
if nr, err := parseNumberRange(segs[1]); err != nil {
return "", nil, err
} else {
fieldOpts.Range = nr
}
}
}
return key, &fieldOpts, nil
} }
return key, nil, nil var fieldOpts fieldOptions
for _, segment := range options {
option := strings.TrimSpace(segment)
switch {
case option == stringOption:
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
} else {
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
}
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
} else {
fieldOpts.Default = strings.TrimSpace(segs[1])
}
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
}
if nr, err := parseNumberRange(segs[1]); err != nil {
return "", nil, err
} else {
fieldOpts.Range = nr
}
}
}
return key, &fieldOpts, nil
} }
func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) { func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {

View File

@@ -4,12 +4,14 @@ package proc
import "time" import "time"
// AddShutdownListener returns fn itself on windows, lets callers call fn on their own.
func AddShutdownListener(fn func()) func() { func AddShutdownListener(fn func()) func() {
return nil return fn
} }
// AddWrapUpListener returns fn itself on windows, lets callers call fn on their own.
func AddWrapUpListener(fn func()) func() { func AddWrapUpListener(fn func()) func() {
return nil return fn
} }
func SetTimeoutToForceQuit(duration time.Duration) { func SetTimeoutToForceQuit(duration time.Duration) {

View File

@@ -0,0 +1,16 @@
package prof
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestReport(t *testing.T) {
once.Do(func() {})
assert.NotContains(t, generateReport(), "foo")
report("foo", time.Second)
assert.Contains(t, generateReport(), "foo")
report("foo", time.Second)
}

View File

@@ -0,0 +1,23 @@
package prof
import (
"testing"
"github.com/tal-tech/go-zero/core/utils"
)
func TestProfiler(t *testing.T) {
EnableProfiling()
Start()
Report("foo", ProfilePoint{
ElapsedTimer: utils.NewElapsedTimer(),
})
}
func TestNullProfiler(t *testing.T) {
p := newNullProfiler()
p.Start()
p.Report("foo", ProfilePoint{
ElapsedTimer: utils.NewElapsedTimer(),
})
}

View File

@@ -92,11 +92,11 @@ func currentCgroup() (*cgroup, error) {
continue continue
} }
cgroups[subsys] = path.Join(cgroupDir, subsys) // https://man7.org/linux/man-pages/man7/cgroups.7.html
if strings.Contains(subsys, ",") { // comma-separated list of controllers for cgroup version 1
for _, k := range strings.Split(subsys, ",") { fields := strings.Split(subsys, ",")
cgroups[k] = path.Join(cgroupDir, k) for _, val := range fields {
} cgroups[val] = path.Join(cgroupDir, val)
} }
} }

View File

@@ -556,7 +556,7 @@ func TestRedis_SortedSet(t *testing.T) {
val, err = client.Zscore("key", "value1") val, err = client.Zscore("key", "value1")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(5), val) assert.Equal(t, int64(5), val)
val, err = NewRedis(client.Addr, "").Zadds("key") _, err = NewRedis(client.Addr, "").Zadds("key")
assert.NotNil(t, err) assert.NotNil(t, err)
val, err = client.Zadds("key", Pair{ val, err = client.Zadds("key", Pair{
Key: "value2", Key: "value2",
@@ -567,9 +567,9 @@ func TestRedis_SortedSet(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(2), val) assert.Equal(t, int64(2), val)
pairs, err := NewRedis(client.Addr, "").ZRevRangeWithScores("key", 1, 3) _, err = NewRedis(client.Addr, "").ZRevRangeWithScores("key", 1, 3)
assert.NotNil(t, err) assert.NotNil(t, err)
pairs, err = client.ZRevRangeWithScores("key", 1, 3) pairs, err := client.ZRevRangeWithScores("key", 1, 3)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []Pair{ assert.EqualValues(t, []Pair{
{ {

View File

@@ -70,8 +70,6 @@ func (g *sharedGroup) createCall(key string) (c *call, done bool) {
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) { func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
defer func() { defer func() {
// delete key first, done later. can't reverse the order, because if reverse,
// another Do call might wg.Wait() without get notified with wg.Done()
g.lock.Lock() g.lock.Lock()
delete(g.calls, key) delete(g.calls, key)
g.lock.Unlock() g.lock.Unlock()

View File

@@ -1,33 +1,29 @@
type ( type (
addReq struct { addReq {
book string `form:"book"` book string `form:"book"`
price int64 `form:"price"` price int64 `form:"price"`
} }
addResp struct { addResp {
ok bool `json:"ok"` ok bool `json:"ok"`
} }
) )
type ( type (
checkReq struct { checkReq {
book string `form:"book"` book string `form:"book"`
} }
checkResp struct { checkResp {
found bool `json:"found"` found bool `json:"found"`
price int64 `json:"price"` price int64 `json:"price"`
} }
) )
service bookstore-api { service bookstore-api {
@server( @handler AddHandler
handler: AddHandler get /add (addReq) returns (addResp)
)
get /add (addReq) returns (addResp)
@server( @handler CheckHandler
handler: CheckHandler get /check (checkReq) returns (checkResp)
)
get /check (checkReq) returns (checkResp)
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
) )
func addHandler(ctx *svc.ServiceContext) http.HandlerFunc { func AddHandler(ctx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.AddReq var req types.AddReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {

View File

@@ -10,7 +10,7 @@ import (
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
) )
func checkHandler(ctx *svc.ServiceContext) http.HandlerFunc { func CheckHandler(ctx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.CheckReq var req types.CheckReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {

View File

@@ -10,16 +10,18 @@ import (
) )
func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) { func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) {
engine.AddRoutes([]rest.Route{ engine.AddRoutes(
{ []rest.Route{
Method: http.MethodGet, {
Path: "/add", Method: http.MethodGet,
Handler: addHandler(serverCtx), Path: "/add",
Handler: AddHandler(serverCtx),
},
{
Method: http.MethodGet,
Path: "/check",
Handler: CheckHandler(serverCtx),
},
}, },
{ )
Method: http.MethodGet,
Path: "/check",
Handler: checkHandler(serverCtx),
},
})
} }

View File

@@ -29,7 +29,8 @@ func (l *CheckLogic) Check(req types.CheckReq) (*types.CheckResp, error) {
Book: req.Book, Book: req.Book,
}) })
if err != nil { if err != nil {
return nil, err logx.Error(err)
return &types.CheckResp{}, err
} }
return &types.CheckResp{ return &types.CheckResp{

View File

@@ -8,4 +8,5 @@ require (
github.com/tal-tech/go-zero v1.0.27 github.com/tal-tech/go-zero v1.0.27
golang.org/x/net v0.0.0-20200707034311-ab3426394381 golang.org/x/net v0.0.0-20200707034311-ab3426394381
google.golang.org/grpc v1.29.1 google.golang.org/grpc v1.29.1
google.golang.org/protobuf v1.25.0
) )

View File

@@ -7,8 +7,8 @@ import (
"flag" "flag"
"fmt" "fmt"
"bookstore/rpc/add/add"
"bookstore/rpc/add/internal/config" "bookstore/rpc/add/internal/config"
add "bookstore/rpc/add/internal/pb"
"bookstore/rpc/add/internal/server" "bookstore/rpc/add/internal/server"
"bookstore/rpc/add/internal/svc" "bookstore/rpc/add/internal/svc"

View File

@@ -0,0 +1,305 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.14.0
// source: add.proto
package add
import (
context "context"
proto "github.com/golang/protobuf/proto"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type AddReq struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Book string `protobuf:"bytes,1,opt,name=book,proto3" json:"book,omitempty"`
Price int64 `protobuf:"varint,2,opt,name=price,proto3" json:"price,omitempty"`
}
func (x *AddReq) Reset() {
*x = AddReq{}
if protoimpl.UnsafeEnabled {
mi := &file_add_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *AddReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AddReq) ProtoMessage() {}
func (x *AddReq) ProtoReflect() protoreflect.Message {
mi := &file_add_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AddReq.ProtoReflect.Descriptor instead.
func (*AddReq) Descriptor() ([]byte, []int) {
return file_add_proto_rawDescGZIP(), []int{0}
}
func (x *AddReq) GetBook() string {
if x != nil {
return x.Book
}
return ""
}
func (x *AddReq) GetPrice() int64 {
if x != nil {
return x.Price
}
return 0
}
type AddResp struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Ok bool `protobuf:"varint,1,opt,name=ok,proto3" json:"ok,omitempty"`
}
func (x *AddResp) Reset() {
*x = AddResp{}
if protoimpl.UnsafeEnabled {
mi := &file_add_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *AddResp) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AddResp) ProtoMessage() {}
func (x *AddResp) ProtoReflect() protoreflect.Message {
mi := &file_add_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AddResp.ProtoReflect.Descriptor instead.
func (*AddResp) Descriptor() ([]byte, []int) {
return file_add_proto_rawDescGZIP(), []int{1}
}
func (x *AddResp) GetOk() bool {
if x != nil {
return x.Ok
}
return false
}
var File_add_proto protoreflect.FileDescriptor
var file_add_proto_rawDesc = []byte{
0x0a, 0x09, 0x61, 0x64, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x61, 0x64, 0x64,
0x22, 0x32, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x52, 0x65, 0x71, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f,
0x6f, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x6f, 0x6b, 0x12, 0x14,
0x0a, 0x05, 0x70, 0x72, 0x69, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x70,
0x72, 0x69, 0x63, 0x65, 0x22, 0x19, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x52, 0x65, 0x73, 0x70, 0x12,
0x0e, 0x0a, 0x02, 0x6f, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x02, 0x6f, 0x6b, 0x32,
0x29, 0x0a, 0x05, 0x61, 0x64, 0x64, 0x65, 0x72, 0x12, 0x20, 0x0a, 0x03, 0x61, 0x64, 0x64, 0x12,
0x0b, 0x2e, 0x61, 0x64, 0x64, 0x2e, 0x61, 0x64, 0x64, 0x52, 0x65, 0x71, 0x1a, 0x0c, 0x2e, 0x61,
0x64, 0x64, 0x2e, 0x61, 0x64, 0x64, 0x52, 0x65, 0x73, 0x70, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (
file_add_proto_rawDescOnce sync.Once
file_add_proto_rawDescData = file_add_proto_rawDesc
)
func file_add_proto_rawDescGZIP() []byte {
file_add_proto_rawDescOnce.Do(func() {
file_add_proto_rawDescData = protoimpl.X.CompressGZIP(file_add_proto_rawDescData)
})
return file_add_proto_rawDescData
}
var file_add_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_add_proto_goTypes = []interface{}{
(*AddReq)(nil), // 0: add.addReq
(*AddResp)(nil), // 1: add.addResp
}
var file_add_proto_depIdxs = []int32{
0, // 0: add.adder.add:input_type -> add.addReq
1, // 1: add.adder.add:output_type -> add.addResp
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_add_proto_init() }
func file_add_proto_init() {
if File_add_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_add_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*AddReq); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_add_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*AddResp); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_add_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_add_proto_goTypes,
DependencyIndexes: file_add_proto_depIdxs,
MessageInfos: file_add_proto_msgTypes,
}.Build()
File_add_proto = out.File
file_add_proto_rawDesc = nil
file_add_proto_goTypes = nil
file_add_proto_depIdxs = nil
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConnInterface
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion6
// AdderClient is the client API for Adder service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type AdderClient interface {
Add(ctx context.Context, in *AddReq, opts ...grpc.CallOption) (*AddResp, error)
}
type adderClient struct {
cc grpc.ClientConnInterface
}
func NewAdderClient(cc grpc.ClientConnInterface) AdderClient {
return &adderClient{cc}
}
func (c *adderClient) Add(ctx context.Context, in *AddReq, opts ...grpc.CallOption) (*AddResp, error) {
out := new(AddResp)
err := c.cc.Invoke(ctx, "/add.adder/add", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// AdderServer is the server API for Adder service.
type AdderServer interface {
Add(context.Context, *AddReq) (*AddResp, error)
}
// UnimplementedAdderServer can be embedded to have forward compatible implementations.
type UnimplementedAdderServer struct {
}
func (*UnimplementedAdderServer) Add(context.Context, *AddReq) (*AddResp, error) {
return nil, status.Errorf(codes.Unimplemented, "method Add not implemented")
}
func RegisterAdderServer(s *grpc.Server, srv AdderServer) {
s.RegisterService(&_Adder_serviceDesc, srv)
}
func _Adder_Add_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AddReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AdderServer).Add(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/add.adder/Add",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AdderServer).Add(ctx, req.(*AddReq))
}
return interceptor(ctx, in, info, handler)
}
var _Adder_serviceDesc = grpc.ServiceDesc{
ServiceName: "add.adder",
HandlerType: (*AdderServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "add",
Handler: _Adder_Add_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "add.proto",
}

View File

@@ -8,7 +8,7 @@ package adder
import ( import (
"context" "context"
add "bookstore/rpc/add/internal/pb" "bookstore/rpc/add/add"
"github.com/tal-tech/go-zero/zrpc" "github.com/tal-tech/go-zero/zrpc"
) )
@@ -33,6 +33,6 @@ func NewAdder(cli zrpc.Client) Adder {
} }
func (m *defaultAdder) Add(ctx context.Context, in *AddReq) (*AddResp, error) { func (m *defaultAdder) Add(ctx context.Context, in *AddReq) (*AddResp, error) {
adder := add.NewAdderClient(m.cli.Conn()) client := add.NewAdderClient(m.cli.Conn())
return adder.Add(ctx, in) return client.Add(ctx, in)
} }

View File

@@ -1,49 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: adder.go
// Package adder is a generated GoMock package.
package adder
import (
context "context"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockAdder is a mock of Adder interface
type MockAdder struct {
ctrl *gomock.Controller
recorder *MockAdderMockRecorder
}
// MockAdderMockRecorder is the mock recorder for MockAdder
type MockAdderMockRecorder struct {
mock *MockAdder
}
// NewMockAdder creates a new mock instance
func NewMockAdder(ctrl *gomock.Controller) *MockAdder {
mock := &MockAdder{ctrl: ctrl}
mock.recorder = &MockAdderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockAdder) EXPECT() *MockAdderMockRecorder {
return m.recorder
}
// Add mocks base method
func (m *MockAdder) Add(ctx context.Context, in *AddReq) (*AddResp, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", ctx, in)
ret0, _ := ret[0].(*AddResp)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Add indicates an expected call of Add
func (mr *MockAdderMockRecorder) Add(ctx, in interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockAdder)(nil).Add), ctx, in)
}

View File

@@ -8,6 +8,5 @@ import (
type Config struct { type Config struct {
zrpc.RpcServerConf zrpc.RpcServerConf
DataSource string DataSource string
Table string
Cache cache.CacheConf Cache cache.CacheConf
} }

View File

@@ -3,7 +3,7 @@ package logic
import ( import (
"context" "context"
add "bookstore/rpc/add/internal/pb" add "bookstore/rpc/add/adder"
"bookstore/rpc/add/internal/svc" "bookstore/rpc/add/internal/svc"
"bookstore/rpc/model" "bookstore/rpc/model"

View File

@@ -1,167 +0,0 @@
// Code generated by protoc-gen-go.
// source: add.proto
// DO NOT EDIT!
/*
Package add is a generated protocol buffer package.
It is generated from these files:
add.proto
It has these top-level messages:
AddReq
AddResp
*/
package add
import (
"fmt"
"math"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type AddReq struct {
Book string `protobuf:"bytes,1,opt,name=book" json:"book,omitempty"`
Price int64 `protobuf:"varint,2,opt,name=price" json:"price,omitempty"`
}
func (m *AddReq) Reset() { *m = AddReq{} }
func (m *AddReq) String() string { return proto.CompactTextString(m) }
func (*AddReq) ProtoMessage() {}
func (*AddReq) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *AddReq) GetBook() string {
if m != nil {
return m.Book
}
return ""
}
func (m *AddReq) GetPrice() int64 {
if m != nil {
return m.Price
}
return 0
}
type AddResp struct {
Ok bool `protobuf:"varint,1,opt,name=ok" json:"ok,omitempty"`
}
func (m *AddResp) Reset() { *m = AddResp{} }
func (m *AddResp) String() string { return proto.CompactTextString(m) }
func (*AddResp) ProtoMessage() {}
func (*AddResp) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *AddResp) GetOk() bool {
if m != nil {
return m.Ok
}
return false
}
func init() {
proto.RegisterType((*AddReq)(nil), "add.addReq")
proto.RegisterType((*AddResp)(nil), "add.addResp")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Adder service
type AdderClient interface {
Add(ctx context.Context, in *AddReq, opts ...grpc.CallOption) (*AddResp, error)
}
type adderClient struct {
cc *grpc.ClientConn
}
func NewAdderClient(cc *grpc.ClientConn) AdderClient {
return &adderClient{cc}
}
func (c *adderClient) Add(ctx context.Context, in *AddReq, opts ...grpc.CallOption) (*AddResp, error) {
out := new(AddResp)
err := grpc.Invoke(ctx, "/add.adder/add", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Adder service
type AdderServer interface {
Add(context.Context, *AddReq) (*AddResp, error)
}
func RegisterAdderServer(s *grpc.Server, srv AdderServer) {
s.RegisterService(&_Adder_serviceDesc, srv)
}
func _Adder_Add_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AddReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AdderServer).Add(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/add.adder/Add",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AdderServer).Add(ctx, req.(*AddReq))
}
return interceptor(ctx, in, info, handler)
}
var _Adder_serviceDesc = grpc.ServiceDesc{
ServiceName: "add.adder",
HandlerType: (*AdderServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "add",
Handler: _Adder_Add_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "add.proto",
}
func init() { proto.RegisterFile("add.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 136 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4c, 0x4c, 0x49, 0xd1,
0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x4e, 0x4c, 0x49, 0x51, 0x32, 0xe2, 0x62, 0x4b, 0x4c,
0x49, 0x09, 0x4a, 0x2d, 0x14, 0x12, 0xe2, 0x62, 0x49, 0xca, 0xcf, 0xcf, 0x96, 0x60, 0x54, 0x60,
0xd4, 0xe0, 0x0c, 0x02, 0xb3, 0x85, 0x44, 0xb8, 0x58, 0x0b, 0x8a, 0x32, 0x93, 0x53, 0x25, 0x98,
0x14, 0x18, 0x35, 0x98, 0x83, 0x20, 0x1c, 0x25, 0x49, 0x2e, 0x76, 0xb0, 0x9e, 0xe2, 0x02, 0x21,
0x3e, 0x2e, 0x26, 0xa8, 0x16, 0x8e, 0x20, 0xa6, 0xfc, 0x6c, 0x23, 0x4d, 0x2e, 0xd6, 0xc4, 0x94,
0x94, 0xd4, 0x22, 0x21, 0x05, 0x2e, 0x90, 0xf1, 0x42, 0xdc, 0x7a, 0x20, 0xfb, 0x20, 0x36, 0x48,
0xf1, 0x20, 0x38, 0xc5, 0x05, 0x49, 0x6c, 0x60, 0x57, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff,
0xe2, 0x6d, 0xb5, 0x91, 0x92, 0x00, 0x00, 0x00,
}

View File

@@ -6,8 +6,8 @@ package server
import ( import (
"context" "context"
"bookstore/rpc/add/add"
"bookstore/rpc/add/internal/logic" "bookstore/rpc/add/internal/logic"
add "bookstore/rpc/add/internal/pb"
"bookstore/rpc/add/internal/svc" "bookstore/rpc/add/internal/svc"
) )

View File

@@ -9,12 +9,12 @@ import (
type ServiceContext struct { type ServiceContext struct {
c config.Config c config.Config
Model *model.BookModel Model model.BookModel
} }
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{ return &ServiceContext{
c: c, c: c,
Model: model.NewBookModel(sqlx.NewMysql(c.DataSource), c.Cache, c.Table), Model: model.NewBookModel(sqlx.NewMysql(c.DataSource), c.Cache),
} }
} }

View File

@@ -7,8 +7,8 @@ import (
"flag" "flag"
"fmt" "fmt"
"bookstore/rpc/check/check"
"bookstore/rpc/check/internal/config" "bookstore/rpc/check/internal/config"
check "bookstore/rpc/check/internal/pb"
"bookstore/rpc/check/internal/server" "bookstore/rpc/check/internal/server"
"bookstore/rpc/check/internal/svc" "bookstore/rpc/check/internal/svc"

View File

@@ -0,0 +1,306 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.14.0
// source: check.proto
package check
import (
context "context"
proto "github.com/golang/protobuf/proto"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type CheckReq struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Book string `protobuf:"bytes,1,opt,name=book,proto3" json:"book,omitempty"`
}
func (x *CheckReq) Reset() {
*x = CheckReq{}
if protoimpl.UnsafeEnabled {
mi := &file_check_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CheckReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CheckReq) ProtoMessage() {}
func (x *CheckReq) ProtoReflect() protoreflect.Message {
mi := &file_check_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CheckReq.ProtoReflect.Descriptor instead.
func (*CheckReq) Descriptor() ([]byte, []int) {
return file_check_proto_rawDescGZIP(), []int{0}
}
func (x *CheckReq) GetBook() string {
if x != nil {
return x.Book
}
return ""
}
type CheckResp struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Found bool `protobuf:"varint,1,opt,name=found,proto3" json:"found,omitempty"`
Price int64 `protobuf:"varint,2,opt,name=price,proto3" json:"price,omitempty"`
}
func (x *CheckResp) Reset() {
*x = CheckResp{}
if protoimpl.UnsafeEnabled {
mi := &file_check_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CheckResp) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CheckResp) ProtoMessage() {}
func (x *CheckResp) ProtoReflect() protoreflect.Message {
mi := &file_check_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CheckResp.ProtoReflect.Descriptor instead.
func (*CheckResp) Descriptor() ([]byte, []int) {
return file_check_proto_rawDescGZIP(), []int{1}
}
func (x *CheckResp) GetFound() bool {
if x != nil {
return x.Found
}
return false
}
func (x *CheckResp) GetPrice() int64 {
if x != nil {
return x.Price
}
return 0
}
var File_check_proto protoreflect.FileDescriptor
var file_check_proto_rawDesc = []byte{
0x0a, 0x0b, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x63,
0x68, 0x65, 0x63, 0x6b, 0x22, 0x1e, 0x0a, 0x08, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x71,
0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x6f, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04,
0x62, 0x6f, 0x6f, 0x6b, 0x22, 0x37, 0x0a, 0x09, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73,
0x70, 0x12, 0x14, 0x0a, 0x05, 0x66, 0x6f, 0x75, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x52, 0x05, 0x66, 0x6f, 0x75, 0x6e, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x70, 0x72, 0x69, 0x63, 0x65,
0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x70, 0x72, 0x69, 0x63, 0x65, 0x32, 0x35, 0x0a,
0x07, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x65, 0x72, 0x12, 0x2a, 0x0a, 0x05, 0x63, 0x68, 0x65, 0x63,
0x6b, 0x12, 0x0f, 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x52,
0x65, 0x71, 0x1a, 0x10, 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x2e, 0x63, 0x68, 0x65, 0x63, 0x6b,
0x52, 0x65, 0x73, 0x70, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_check_proto_rawDescOnce sync.Once
file_check_proto_rawDescData = file_check_proto_rawDesc
)
func file_check_proto_rawDescGZIP() []byte {
file_check_proto_rawDescOnce.Do(func() {
file_check_proto_rawDescData = protoimpl.X.CompressGZIP(file_check_proto_rawDescData)
})
return file_check_proto_rawDescData
}
var file_check_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_check_proto_goTypes = []interface{}{
(*CheckReq)(nil), // 0: check.checkReq
(*CheckResp)(nil), // 1: check.checkResp
}
var file_check_proto_depIdxs = []int32{
0, // 0: check.checker.check:input_type -> check.checkReq
1, // 1: check.checker.check:output_type -> check.checkResp
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_check_proto_init() }
func file_check_proto_init() {
if File_check_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_check_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CheckReq); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_check_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CheckResp); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_check_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_check_proto_goTypes,
DependencyIndexes: file_check_proto_depIdxs,
MessageInfos: file_check_proto_msgTypes,
}.Build()
File_check_proto = out.File
file_check_proto_rawDesc = nil
file_check_proto_goTypes = nil
file_check_proto_depIdxs = nil
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConnInterface
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion6
// CheckerClient is the client API for Checker service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type CheckerClient interface {
Check(ctx context.Context, in *CheckReq, opts ...grpc.CallOption) (*CheckResp, error)
}
type checkerClient struct {
cc grpc.ClientConnInterface
}
func NewCheckerClient(cc grpc.ClientConnInterface) CheckerClient {
return &checkerClient{cc}
}
func (c *checkerClient) Check(ctx context.Context, in *CheckReq, opts ...grpc.CallOption) (*CheckResp, error) {
out := new(CheckResp)
err := c.cc.Invoke(ctx, "/check.checker/check", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// CheckerServer is the server API for Checker service.
type CheckerServer interface {
Check(context.Context, *CheckReq) (*CheckResp, error)
}
// UnimplementedCheckerServer can be embedded to have forward compatible implementations.
type UnimplementedCheckerServer struct {
}
func (*UnimplementedCheckerServer) Check(context.Context, *CheckReq) (*CheckResp, error) {
return nil, status.Errorf(codes.Unimplemented, "method Check not implemented")
}
func RegisterCheckerServer(s *grpc.Server, srv CheckerServer) {
s.RegisterService(&_Checker_serviceDesc, srv)
}
func _Checker_Check_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CheckReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(CheckerServer).Check(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/check.checker/Check",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(CheckerServer).Check(ctx, req.(*CheckReq))
}
return interceptor(ctx, in, info, handler)
}
var _Checker_serviceDesc = grpc.ServiceDesc{
ServiceName: "check.checker",
HandlerType: (*CheckerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "check",
Handler: _Checker_Check_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "check.proto",
}

View File

@@ -8,7 +8,7 @@ package checker
import ( import (
"context" "context"
check "bookstore/rpc/check/internal/pb" "bookstore/rpc/check/check"
"github.com/tal-tech/go-zero/zrpc" "github.com/tal-tech/go-zero/zrpc"
) )
@@ -33,6 +33,6 @@ func NewChecker(cli zrpc.Client) Checker {
} }
func (m *defaultChecker) Check(ctx context.Context, in *CheckReq) (*CheckResp, error) { func (m *defaultChecker) Check(ctx context.Context, in *CheckReq) (*CheckResp, error) {
checker := check.NewCheckerClient(m.cli.Conn()) client := check.NewCheckerClient(m.cli.Conn())
return checker.Check(ctx, in) return client.Check(ctx, in)
} }

View File

@@ -1,49 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: checker.go
// Package checker is a generated GoMock package.
package checker
import (
context "context"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockChecker is a mock of Checker interface
type MockChecker struct {
ctrl *gomock.Controller
recorder *MockCheckerMockRecorder
}
// MockCheckerMockRecorder is the mock recorder for MockChecker
type MockCheckerMockRecorder struct {
mock *MockChecker
}
// NewMockChecker creates a new mock instance
func NewMockChecker(ctrl *gomock.Controller) *MockChecker {
mock := &MockChecker{ctrl: ctrl}
mock.recorder = &MockCheckerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockChecker) EXPECT() *MockCheckerMockRecorder {
return m.recorder
}
// Check mocks base method
func (m *MockChecker) Check(ctx context.Context, in *CheckReq) (*CheckResp, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Check", ctx, in)
ret0, _ := ret[0].(*CheckResp)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Check indicates an expected call of Check
func (mr *MockCheckerMockRecorder) Check(ctx, in interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockChecker)(nil).Check), ctx, in)
}

View File

@@ -8,6 +8,5 @@ import (
type Config struct { type Config struct {
zrpc.RpcServerConf zrpc.RpcServerConf
DataSource string DataSource string
Table string
Cache cache.CacheConf Cache cache.CacheConf
} }

View File

@@ -3,7 +3,7 @@ package logic
import ( import (
"context" "context"
check "bookstore/rpc/check/internal/pb" check "bookstore/rpc/check/checker"
"bookstore/rpc/check/internal/svc" "bookstore/rpc/check/internal/svc"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"

View File

@@ -1,167 +0,0 @@
// Code generated by protoc-gen-go.
// source: check.proto
// DO NOT EDIT!
/*
Package check is a generated protocol buffer package.
It is generated from these files:
check.proto
It has these top-level messages:
CheckReq
CheckResp
*/
package check
import (
"fmt"
"math"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type CheckReq struct {
Book string `protobuf:"bytes,1,opt,name=book" json:"book,omitempty"`
}
func (m *CheckReq) Reset() { *m = CheckReq{} }
func (m *CheckReq) String() string { return proto.CompactTextString(m) }
func (*CheckReq) ProtoMessage() {}
func (*CheckReq) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *CheckReq) GetBook() string {
if m != nil {
return m.Book
}
return ""
}
type CheckResp struct {
Found bool `protobuf:"varint,1,opt,name=found" json:"found,omitempty"`
Price int64 `protobuf:"varint,2,opt,name=price" json:"price,omitempty"`
}
func (m *CheckResp) Reset() { *m = CheckResp{} }
func (m *CheckResp) String() string { return proto.CompactTextString(m) }
func (*CheckResp) ProtoMessage() {}
func (*CheckResp) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *CheckResp) GetFound() bool {
if m != nil {
return m.Found
}
return false
}
func (m *CheckResp) GetPrice() int64 {
if m != nil {
return m.Price
}
return 0
}
func init() {
proto.RegisterType((*CheckReq)(nil), "check.checkReq")
proto.RegisterType((*CheckResp)(nil), "check.checkResp")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Checker service
type CheckerClient interface {
Check(ctx context.Context, in *CheckReq, opts ...grpc.CallOption) (*CheckResp, error)
}
type checkerClient struct {
cc *grpc.ClientConn
}
func NewCheckerClient(cc *grpc.ClientConn) CheckerClient {
return &checkerClient{cc}
}
func (c *checkerClient) Check(ctx context.Context, in *CheckReq, opts ...grpc.CallOption) (*CheckResp, error) {
out := new(CheckResp)
err := grpc.Invoke(ctx, "/check.checker/check", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Checker service
type CheckerServer interface {
Check(context.Context, *CheckReq) (*CheckResp, error)
}
func RegisterCheckerServer(s *grpc.Server, srv CheckerServer) {
s.RegisterService(&_Checker_serviceDesc, srv)
}
func _Checker_Check_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CheckReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(CheckerServer).Check(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/check.checker/Check",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(CheckerServer).Check(ctx, req.(*CheckReq))
}
return interceptor(ctx, in, info, handler)
}
var _Checker_serviceDesc = grpc.ServiceDesc{
ServiceName: "check.checker",
HandlerType: (*CheckerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "check",
Handler: _Checker_Check_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "check.proto",
}
func init() { proto.RegisterFile("check.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 136 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4e, 0xce, 0x48, 0x4d,
0xce, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0x94, 0xe4, 0xb8, 0x38, 0xc0,
0x8c, 0xa0, 0xd4, 0x42, 0x21, 0x21, 0x2e, 0x96, 0xa4, 0xfc, 0xfc, 0x6c, 0x09, 0x46, 0x05, 0x46,
0x0d, 0xce, 0x20, 0x30, 0x5b, 0xc9, 0x9c, 0x8b, 0x13, 0x2a, 0x5f, 0x5c, 0x20, 0x24, 0xc2, 0xc5,
0x9a, 0x96, 0x5f, 0x9a, 0x97, 0x02, 0x56, 0xc1, 0x11, 0x04, 0xe1, 0x80, 0x44, 0x0b, 0x8a, 0x32,
0x93, 0x53, 0x25, 0x98, 0x14, 0x18, 0x35, 0x98, 0x83, 0x20, 0x1c, 0x23, 0x53, 0x2e, 0x76, 0xb0,
0xc6, 0xd4, 0x22, 0x21, 0x2d, 0x2e, 0x88, 0x65, 0x42, 0xfc, 0x7a, 0x10, 0x17, 0xc0, 0x6c, 0x94,
0x12, 0x40, 0x15, 0x28, 0x2e, 0x48, 0x62, 0x03, 0xbb, 0xce, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff,
0x6e, 0x6f, 0xa7, 0x1d, 0xac, 0x00, 0x00, 0x00,
}

View File

@@ -6,8 +6,8 @@ package server
import ( import (
"context" "context"
"bookstore/rpc/check/check"
"bookstore/rpc/check/internal/logic" "bookstore/rpc/check/internal/logic"
check "bookstore/rpc/check/internal/pb"
"bookstore/rpc/check/internal/svc" "bookstore/rpc/check/internal/svc"
) )

View File

@@ -9,12 +9,12 @@ import (
type ServiceContext struct { type ServiceContext struct {
c config.Config c config.Config
Model *model.BookModel Model model.BookModel
} }
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{ return &ServiceContext{
c: c, c: c,
Model: model.NewBookModel(sqlx.NewMysql(c.DataSource), c.Cache, c.Table), Model: model.NewBookModel(sqlx.NewMysql(c.DataSource), c.Cache),
} }
} }

View File

@@ -18,11 +18,18 @@ var (
bookRowsExpectAutoSet = strings.Join(stringx.Remove(bookFieldNames, "create_time", "update_time"), ",") bookRowsExpectAutoSet = strings.Join(stringx.Remove(bookFieldNames, "create_time", "update_time"), ",")
bookRowsWithPlaceHolder = strings.Join(stringx.Remove(bookFieldNames, "book", "create_time", "update_time"), "=?,") + "=?" bookRowsWithPlaceHolder = strings.Join(stringx.Remove(bookFieldNames, "book", "create_time", "update_time"), "=?,") + "=?"
bookPrefix = "cache#Book#book#" cacheBookPrefix = "cache#Book#book#"
) )
type ( type (
BookModel struct { BookModel interface {
Insert(data Book) (sql.Result, error)
FindOne(book string) (*Book, error)
Update(data Book) error
Delete(book string) error
}
defaultBookModel struct {
sqlc.CachedConn sqlc.CachedConn
table string table string
} }
@@ -33,23 +40,25 @@ type (
} }
) )
func NewBookModel(conn sqlx.SqlConn, c cache.CacheConf, table string) *BookModel { func NewBookModel(conn sqlx.SqlConn, c cache.CacheConf) BookModel {
return &BookModel{ return &defaultBookModel{
CachedConn: sqlc.NewConn(conn, c), CachedConn: sqlc.NewConn(conn, c),
table: table, table: "book",
} }
} }
func (m *BookModel) Insert(data Book) (sql.Result, error) { func (m *defaultBookModel) Insert(data Book) (sql.Result, error) {
query := `insert into ` + m.table + ` (` + bookRowsExpectAutoSet + `) values (?, ?)` query := fmt.Sprintf("insert into %s (%s) values (?, ?)", m.table, bookRowsExpectAutoSet)
return m.ExecNoCache(query, data.Book, data.Price) ret, err := m.ExecNoCache(query, data.Book, data.Price)
return ret, err
} }
func (m *BookModel) FindOne(book string) (*Book, error) { func (m *defaultBookModel) FindOne(book string) (*Book, error) {
bookKey := fmt.Sprintf("%s%v", bookPrefix, book) bookKey := fmt.Sprintf("%s%v", cacheBookPrefix, book)
var resp Book var resp Book
err := m.QueryRow(&resp, bookKey, func(conn sqlx.SqlConn, v interface{}) error { err := m.QueryRow(&resp, bookKey, func(conn sqlx.SqlConn, v interface{}) error {
query := `select ` + bookRows + ` from ` + m.table + ` where book = ? limit 1` query := fmt.Sprintf("select %s from %s where book = ? limit 1", bookRows, m.table)
return conn.QueryRow(v, query, book) return conn.QueryRow(v, query, book)
}) })
switch err { switch err {
@@ -62,20 +71,30 @@ func (m *BookModel) FindOne(book string) (*Book, error) {
} }
} }
func (m *BookModel) Update(data Book) error { func (m *defaultBookModel) Update(data Book) error {
bookKey := fmt.Sprintf("%s%v", bookPrefix, data.Book) bookKey := fmt.Sprintf("%s%v", cacheBookPrefix, data.Book)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := `update ` + m.table + ` set ` + bookRowsWithPlaceHolder + ` where book = ?` query := fmt.Sprintf("update %s set %s where book = ?", m.table, bookRowsWithPlaceHolder)
return conn.Exec(query, data.Price, data.Book) return conn.Exec(query, data.Price, data.Book)
}, bookKey) }, bookKey)
return err return err
} }
func (m *BookModel) Delete(book string) error { func (m *defaultBookModel) Delete(book string) error {
bookKey := fmt.Sprintf("%s%v", bookPrefix, book)
bookKey := fmt.Sprintf("%s%v", cacheBookPrefix, book)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := `delete from ` + m.table + ` where book = ?` query := fmt.Sprintf("delete from %s where book = ?", m.table)
return conn.Exec(query, book) return conn.Exec(query, book)
}, bookKey) }, bookKey)
return err return err
} }
func (m *defaultBookModel) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", cacheBookPrefix, primary)
}
func (m *defaultBookModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where book = ? limit 1", bookRows, m.table)
return conn.QueryRow(v, query, primary)
}

0
example/bookstore/rpc/model/vars.go Executable file → Normal file
View File

View File

@@ -37,6 +37,7 @@ github.com/coreos/go-systemd/v22 v22.0.0 h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQa
github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM=
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -49,6 +50,7 @@ github.com/dsymonds/gotoc v0.0.0-20160928043926-5aebcfc91819 h1:9778zj477h/VauD8
github.com/dsymonds/gotoc v0.0.0-20160928043926-5aebcfc91819/go.mod h1:MvzMVHq8BH2Ji/o8TGDocVA70byvLrAgFTxkEnmjO4Y= github.com/dsymonds/gotoc v0.0.0-20160928043926-5aebcfc91819/go.mod h1:MvzMVHq8BH2Ji/o8TGDocVA70byvLrAgFTxkEnmjO4Y=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4 h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4 h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/emicklei/proto v1.9.0 h1:l0QiNT6Qs7Yj0Mb4X6dnWBQer4ebei2BFcgQLbGqUDc=
github.com/emicklei/proto v1.9.0/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A= github.com/emicklei/proto v1.9.0/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -131,6 +133,7 @@ github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334 h1:VHgatEHNcBFEB7inlalqfNqw65aNkM1lGX2yt3NmbS8= github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334 h1:VHgatEHNcBFEB7inlalqfNqw65aNkM1lGX2yt3NmbS8=
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE= github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
github.com/iancoleman/strcase v0.1.2 h1:gnomlvw9tnV3ITTAxzKSgTF+8kFWcU/f+TgttpXGz1U=
github.com/iancoleman/strcase v0.1.2/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE= github.com/iancoleman/strcase v0.1.2/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
@@ -220,6 +223,7 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shirou/gopsutil v0.0.0-20180427012116-c95755e4bcd7/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v0.0.0-20180427012116-c95755e4bcd7/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc=
@@ -252,6 +256,7 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6/go.mod h1
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA= github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA=
github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli v1.22.5 h1:lNq9sAHXK2qfdI8W+GRItjCEkI+2oR4d+MEHy1CKXoU=
github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=

8
go.mod
View File

@@ -1,16 +1,18 @@
module github.com/tal-tech/go-zero module github.com/tal-tech/go-zero
go 1.13 go 1.14
require ( require (
github.com/ClickHouse/clickhouse-go v1.4.3 github.com/ClickHouse/clickhouse-go v1.4.3
github.com/DATA-DOG/go-sqlmock v1.4.1 github.com/DATA-DOG/go-sqlmock v1.4.1
github.com/alicebob/miniredis/v2 v2.14.1 github.com/alicebob/miniredis/v2 v2.14.1
github.com/antlr/antlr4 v0.0.0-20210105212045-464bcbc32de2
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
github.com/dchest/siphash v1.2.1 github.com/dchest/siphash v1.2.1
github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/emicklei/proto v1.9.0 github.com/emicklei/proto v1.9.0
github.com/fatih/color v1.9.0 // indirect github.com/fatih/color v1.9.0 // indirect
github.com/fatih/structtag v1.2.0
github.com/frankban/quicktest v1.7.2 // indirect github.com/frankban/quicktest v1.7.2 // indirect
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8
github.com/go-redis/redis v6.15.7+incompatible github.com/go-redis/redis v6.15.7+incompatible
@@ -55,10 +57,10 @@ require (
golang.org/x/tools v0.0.0-20200410132612-ae9902aceb98 // indirect golang.org/x/tools v0.0.0-20200410132612-ae9902aceb98 // indirect
google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f // indirect google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f // indirect
google.golang.org/grpc v1.29.1 google.golang.org/grpc v1.29.1
google.golang.org/protobuf v1.25.0 google.golang.org/protobuf v1.25.0 // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.28 gopkg.in/cheggaaa/pb.v1 v1.0.28
gopkg.in/h2non/gock.v1 v1.0.15 gopkg.in/h2non/gock.v1 v1.0.15
gopkg.in/yaml.v2 v2.3.0 gopkg.in/yaml.v2 v2.4.0
honnef.co/go/tools v0.0.1-2020.1.4 // indirect honnef.co/go/tools v0.0.1-2020.1.4 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect sigs.k8s.io/yaml v1.2.0 // indirect
) )

8
go.sum
View File

@@ -16,6 +16,8 @@ github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGn
github.com/alicebob/miniredis/v2 v2.14.1 h1:GjlbSeoJ24bzdLRs13HoMEeaRZx9kg5nHoRW7QV/nCs= github.com/alicebob/miniredis/v2 v2.14.1 h1:GjlbSeoJ24bzdLRs13HoMEeaRZx9kg5nHoRW7QV/nCs=
github.com/alicebob/miniredis/v2 v2.14.1/go.mod h1:uS970Sw5Gs9/iK3yBg0l9Uj9s25wXxSpQUE9EaJ/Blg= github.com/alicebob/miniredis/v2 v2.14.1/go.mod h1:uS970Sw5Gs9/iK3yBg0l9Uj9s25wXxSpQUE9EaJ/Blg=
github.com/antihax/optional v0.0.0-20180407024304-ca021399b1a6/go.mod h1:V8iCPQYkqmusNa815XgQio277wI47sdRh1dUOLdyC6Q= github.com/antihax/optional v0.0.0-20180407024304-ca021399b1a6/go.mod h1:V8iCPQYkqmusNa815XgQio277wI47sdRh1dUOLdyC6Q=
github.com/antlr/antlr4 v0.0.0-20210105212045-464bcbc32de2 h1:rL2miklL5rhxUaZO7hntBcy/VHaiyuPQ4EJoy/NMwaM=
github.com/antlr/antlr4 v0.0.0-20210105212045-464bcbc32de2/go.mod h1:T7PbCXFs94rrTttyxjbyT5+/1V8T2TYDejxUfHJjw1Y=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -63,6 +65,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk= github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk=
github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
@@ -450,8 +454,8 @@ gopkg.in/yaml.v2 v2.2.5 h1:ymVxjfMaHvXD8RqPRmzHHsB3VvucivSkIAvJFDI5O3c=
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=

View File

@@ -129,7 +129,7 @@ go get -u github.com/tal-tech/go-zero
the .api files also can be generate by goctl, like below: the .api files also can be generate by goctl, like below:
```shell ```shell
goctl api -o greet.api goctl api -o greet.api
``` ```
3. generate the go server side code 3. generate the go server side code
@@ -208,3 +208,7 @@ goctl api -o greet.api
* [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md) * [Rapid development of microservice systems](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl-en.md)
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md) * [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore-en.md)
## 9. Chat group
Join the chat via https://discord.gg/4JQvC5A4Fe

View File

@@ -5,8 +5,9 @@
[English](readme-en.md) | 简体中文 [English](readme-en.md) | 简体中文
[![Go](https://github.com/tal-tech/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/tal-tech/go-zero/actions) [![Go](https://github.com/tal-tech/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/tal-tech/go-zero/actions)
[![codecov](https://codecov.io/gh/tal-tech/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/tal-tech/go-zero)
[![Go Report Card](https://goreportcard.com/badge/github.com/tal-tech/go-zero)](https://goreportcard.com/report/github.com/tal-tech/go-zero) [![Go Report Card](https://goreportcard.com/badge/github.com/tal-tech/go-zero)](https://goreportcard.com/report/github.com/tal-tech/go-zero)
[![goproxy](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)](https://goproxy.cn/stats/github.com/tal-tech/go-zero/badges/download-count.svg)
[![codecov](https://codecov.io/gh/tal-tech/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/tal-tech/go-zero)
[![Release](https://img.shields.io/github/v/release/tal-tech/go-zero.svg?style=flat-square)](https://github.com/tal-tech/go-zero) [![Release](https://img.shields.io/github/v/release/tal-tech/go-zero.svg?style=flat-square)](https://github.com/tal-tech/go-zero)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
@@ -95,7 +96,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
[快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md) [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md) [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
1. 安装 goctl 工具 1. 安装 goctl 工具
@@ -156,13 +157,13 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
## 7. 文档 ## 7. 文档
* API 文档 (逐步完善中) * API 文档
[https://www.yuque.com/tal-tech/go-zero](https://www.yuque.com/tal-tech/go-zero) [https://www.yuque.com/tal-tech/go-zero](https://www.yuque.com/tal-tech/go-zero)
* awesome 系列 * awesome 系列(全部收录于『微服务实践』公众号)
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md) * [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md) * [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore.md)
* [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md) * [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md)
* [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md) * [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md)
* [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md) * [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md)
@@ -172,7 +173,21 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md) * [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md)
* [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md) * [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md)
## 8. 微信交流群 * 精选 `goctl` 插件
| 插件 | 用途 |
| ------------- |:-------------|
| [goctl-swagger](https://github.com/zeromicro/goctl-swagger) | 一键生成 `api` 的 `swagger` 文档 |
| [goctl-android](https://github.com/zeromicro/goctl-android) | 生成 `java (android)` 端 `http client` 请求代码 |
| [goctl-go-compact](https://github.com/zeromicro/goctl-go-compact) | 合并 `api` 里同一个 `group` 里的 `handler` 到一个 `go` 文件 |
## 8. 微信公众号
`go-zero` 相关文章都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注,也可以通过公众号私信我 👏
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat-micro.jpg" alt="wechat" width="300" />
## 9. 微信交流群
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。 如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。

171
rest/engine_test.go Normal file
View File

@@ -0,0 +1,171 @@
package rest
import (
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/conf"
)
func TestNewEngine(t *testing.T) {
yamls := []string{
`Name: foo
Port: 54321
`,
`Name: foo
Port: 54321
CpuThreshold: 500
`,
`Name: foo
Port: 54321
CpuThreshold: 500
Verbose: true
`,
}
routes := []featuredRoutes{
{
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
prevSecret: "thesecret",
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
PrivateKeys: []PrivateKeyConf{
{
Fingerprint: "a",
KeyFile: "b",
},
},
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
},
}
for _, yaml := range yamls {
for _, route := range routes {
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
ng := newEngine(cnf)
ng.AddRoutes(route)
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
}
})
assert.NotNil(t, ng.StartWithRouter(mockedRouter{}))
}
}
}
type mockedRouter struct {
}
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
func (m mockedRouter) Handle(method string, path string, handler http.Handler) error {
return errors.New("foo")
}
func (m mockedRouter) SetNotFoundHandler(handler http.Handler) {
}
func (m mockedRouter) SetNotAllowedHandler(handler http.Handler) {
}

View File

@@ -46,18 +46,18 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
parser := token.NewTokenParser() parser := token.NewTokenParser()
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret) tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
if err != nil { if err != nil {
unauthorized(w, r, err, authOpts.Callback) unauthorized(w, r, err, authOpts.Callback)
return return
} }
if !token.Valid { if !tok.Valid {
unauthorized(w, r, errInvalidToken, authOpts.Callback) unauthorized(w, r, errInvalidToken, authOpts.Callback)
return return
} }
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := tok.Claims.(jwt.MapClaims)
if !ok { if !ok {
unauthorized(w, r, errNoClaims, authOpts.Callback) unauthorized(w, r, errNoClaims, authOpts.Callback)
return return
@@ -122,6 +122,12 @@ func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
} }
} }
func (grw *guardedResponseWriter) Flush() {
if flusher, ok := grw.writer.(http.Flusher); ok {
flusher.Flush()
}
}
func (grw *guardedResponseWriter) Header() http.Header { func (grw *guardedResponseWriter) Header() http.Header {
return grw.writer.Header() return grw.writer.Header()
} }

View File

@@ -41,6 +41,10 @@ func TestAuthHandler(t *testing.T) {
w.Header().Set("X-Test", "test") w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content")) _, err := w.Write([]byte("content"))
assert.Nil(t, err) assert.Nil(t, err)
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
})) }))
resp := httptest.NewRecorder() resp := httptest.NewRecorder()

View File

@@ -83,6 +83,12 @@ func newCryptionResponseWriter(w http.ResponseWriter) *cryptionResponseWriter {
} }
} }
func (w *cryptionResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *cryptionResponseWriter) Header() http.Header { func (w *cryptionResponseWriter) Header() http.Header {
return w.ResponseWriter.Header() return w.ResponseWriter.Header()
} }

View File

@@ -87,3 +87,19 @@ func TestCryptionHandlerWriteHeader(t *testing.T) {
handler.ServeHTTP(recorder, req) handler.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
} }
func TestCryptionHandlerFlush(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/any", nil)
handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(respText))
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
assert.Nil(t, err)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}

View File

@@ -38,6 +38,12 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
w.code = code w.code = code
} }
func (w *LoggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
func LogHandler(next http.Handler) http.Handler { func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer() timer := utils.NewElapsedTimer()
@@ -68,6 +74,10 @@ func newDetailLoggedResponseWriter(writer *LoggedResponseWriter, buf *bytes.Buff
} }
} }
func (w *DetailLoggedResponseWriter) Flush() {
w.writer.Flush()
}
func (w *DetailLoggedResponseWriter) Header() http.Header { func (w *DetailLoggedResponseWriter) Header() http.Header {
return w.writer.Header() return w.writer.Header()
} }

View File

@@ -30,6 +30,10 @@ func TestLogHandler(t *testing.T) {
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content")) _, err := w.Write([]byte("content"))
assert.Nil(t, err) assert.Nil(t, err)
flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
})) }))
resp := httptest.NewRecorder() resp := httptest.NewRecorder()

View File

@@ -10,7 +10,6 @@ import (
) )
const ( const (
multipartFormData = "multipart/form-data"
formKey = "form" formKey = "form"
pathKey = "path" pathKey = "path"
emptyJson = "{}" emptyJson = "{}"
@@ -39,12 +38,12 @@ func Parse(r *http.Request, v interface{}) error {
// Parses the form request. // Parses the form request.
func ParseForm(r *http.Request, v interface{}) error { func ParseForm(r *http.Request, v interface{}) error {
if strings.Contains(r.Header.Get(ContentType), multipartFormData) { if err := r.ParseForm(); err != nil {
if err := r.ParseMultipartForm(maxMemory); err != nil { return err
return err }
}
} else { if err := r.ParseMultipartForm(maxMemory); err != nil {
if err := r.ParseForm(); err != nil { if err != http.ErrNotMultipart {
return err return err
} }
} }

View File

@@ -7,6 +7,12 @@ type WithCodeResponseWriter struct {
Code int Code int
} }
func (w *WithCodeResponseWriter) Flush() {
if flusher, ok := w.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *WithCodeResponseWriter) Header() http.Header { func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header() return w.Writer.Header()
} }

View File

@@ -0,0 +1,33 @@
package security
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWithCodeResponseWriter(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := &WithCodeResponseWriter{Writer: w}
cw.Header().Set("X-Test", "test")
cw.WriteHeader(http.StatusServiceUnavailable)
assert.Equal(t, cw.Code, http.StatusServiceUnavailable)
_, err := cw.Write([]byte("content"))
assert.Nil(t, err)
flusher, ok := http.ResponseWriter(cw).(http.Flusher)
assert.True(t, ok)
flusher.Flush()
})
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}

View File

@@ -64,7 +64,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
allow, ok := pr.methodNotAllowed(r.Method, reqPath) allows, ok := pr.methodsAllowed(r.Method, reqPath)
if !ok { if !ok {
pr.handleNotFound(w, r) pr.handleNotFound(w, r)
return return
@@ -73,7 +73,7 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if pr.notAllowed != nil { if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r) pr.notAllowed.ServeHTTP(w, r)
} else { } else {
w.Header().Set(allowHeader, allow) w.Header().Set(allowHeader, allows)
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} }
} }
@@ -94,7 +94,7 @@ func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
} }
} }
func (pr *patRouter) methodNotAllowed(method, path string) (string, bool) { func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
var allows []string var allows []string
for treeMethod, tree := range pr.trees { for treeMethod, tree := range pr.trees {

View File

@@ -1,7 +1,6 @@
package rest package rest
import ( import (
"errors"
"log" "log"
"net/http" "net/http"
@@ -24,6 +23,9 @@ type (
} }
) )
// MustNewServer returns a server with given config of c and options defined in opts.
// Be aware that later RunOption might overwrite previous one that write the same option.
// The process will exit if error occurs.
func MustNewServer(c RestConf, opts ...RunOption) *Server { func MustNewServer(c RestConf, opts ...RunOption) *Server {
engine, err := NewServer(c, opts...) engine, err := NewServer(c, opts...)
if err != nil { if err != nil {
@@ -33,11 +35,9 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
return engine return engine
} }
// NewServer returns a server with given config of c and options defined in opts.
// Be aware that later RunOption might overwrite previous one that write the same option.
func NewServer(c RestConf, opts ...RunOption) (*Server, error) { func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
if len(opts) > 1 {
return nil, errors.New("only one RunOption is allowed")
}
if err := c.SetUp(); err != nil { if err := c.SetUp(); err != nil {
return nil, err return nil, err
} }

View File

@@ -8,18 +8,84 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/router" "github.com/tal-tech/go-zero/rest/router"
) )
func TestNewServer(t *testing.T) { func TestNewServer(t *testing.T) {
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil)) const configYaml = `
assert.NotNil(t, err) Name: foo
Port: 54321
`
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
failStart := func(server *Server) {
server.opts.start = func(e *engine) error {
return http.ErrServerClosed
}
}
tests := []struct {
c RestConf
opts []RunOption
fail bool
}{
{
c: RestConf{},
opts: []RunOption{failStart},
fail: true,
},
{
c: cnf,
opts: []RunOption{failStart},
},
{
c: cnf,
opts: []RunOption{WithNotAllowedHandler(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithNotFoundHandler(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
},
{
c: cnf,
opts: []RunOption{WithUnsignedCallback(nil), failStart},
},
}
for _, test := range tests {
srv, err := NewServer(test.c, test.opts...)
if test.fail {
assert.NotNil(t, err)
}
if err != nil {
continue
}
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
})
}))
srv.AddRoute(Route{
Method: http.MethodGet,
Path: "/",
Handler: nil,
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
WithJwtTransition("preivous", "thenewone"))
srv.Start()
srv.Stop()
}
} }
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := router.NewRouter() rt := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
var v struct { var v struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`
@@ -56,14 +122,14 @@ func TestWithMiddleware(t *testing.T) {
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
} }
for _, route := range rs { for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
} }
for _, url := range urls { for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil) r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err) assert.Nil(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
router.ServeHTTP(rr, r) rt.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000", rr.Body.String()) assert.Equal(t, "whatever:200000", rr.Body.String())
} }
@@ -76,7 +142,7 @@ func TestWithMiddleware(t *testing.T) {
func TestMultiMiddlewares(t *testing.T) { func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := router.NewRouter() rt := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
var v struct { var v struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`
@@ -127,14 +193,14 @@ func TestMultiMiddlewares(t *testing.T) {
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
} }
for _, route := range rs { for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler))
} }
for _, url := range urls { for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil) r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err) assert.Nil(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
router.ServeHTTP(rr, r) rt.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000200000", rr.Body.String()) assert.Equal(t, "whatever:200000200000", rr.Body.String())
} }

15
tools/goctl/CHANGE_LOG.MD Normal file
View File

@@ -0,0 +1,15 @@
# 2020-01-08
## features ![](https://img.shields.io/static/v1?label=&message=new&color=red)
* reactor api parse by g4
* add syntax lexer for api
* support java-style documentation comments
* support parsing doc and comment
* support import group
> original: [api grammar document](./api/parser/readme.md)
# 2020-01-08
* add change log

View File

@@ -12,18 +12,21 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
const apiTemplate = `info( const apiTemplate = `
syntax = "v1"
info(
title: // TODO: add title title: // TODO: add title
desc: // TODO: add description desc: // TODO: add description
author: "{{.gitUser}}" author: "{{.gitUser}}"
email: "{{.gitEmail}}" email: "{{.gitEmail}}"
) )
type request struct { type request {
// TODO: add members here and delete this comment // TODO: add members here and delete this comment
} }
type response struct { type response {
// TODO: add members here and delete this comment // TODO: add members here and delete this comment
} }

View File

@@ -19,11 +19,7 @@ func DartCommand(c *cli.Context) error {
return errors.New("missing -dir") return errors.New("missing -dir")
} }
p, err := parser.NewParser(apiFile) api, err := parser.Parse(apiFile)
if err != nil {
return err
}
api, err := p.Parse()
if err != nil { if err != nil {
return err return err
} }

View File

@@ -59,7 +59,6 @@ func genData(dir string, api *spec.ApiSpec) error {
return e return e
} }
convertMemberType(api)
return t.Execute(file, api) return t.Execute(file, api)
} }

View File

@@ -1,12 +1,10 @@
package dartgen package dartgen
import ( import (
"log"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
) )
@@ -39,47 +37,6 @@ func tagGet(tag, k string) (reflect.Value, error) {
return reflect.ValueOf(out), nil return reflect.ValueOf(out), nil
} }
func convertMemberType(api *spec.ApiSpec) {
for i, t := range api.Types {
for j, mem := range t.Members {
api.Types[i].Members[j].Type = goTypeToDart(mem.Type)
}
}
}
func goTypeToDart(t string) string {
t = strings.Replace(t, "*", "", -1)
if strings.HasPrefix(t, "[]") {
return "List<" + goTypeToDart(t[2:]) + ">"
}
if strings.HasPrefix(t, "map") {
tys, e := util.DecomposeType(t)
if e != nil {
log.Fatal(e)
}
if len(tys) != 2 {
log.Fatal("Map type number !=2")
}
return "Map<String," + goTypeToDart(tys[1]) + ">"
}
switch t {
case "string":
return "String"
case "int", "int32", "int64":
return "int"
case "float", "float32", "float64":
return "double"
case "bool":
return "bool"
default:
return t
}
}
func isDirectType(s string) bool { func isDirectType(s string) bool {
return isAtomicType(s) || isListType(s) && isAtomicType(getCoreType(s)) return isAtomicType(s) || isListType(s) && isAtomicType(getCoreType(s))
} }

View File

@@ -41,7 +41,7 @@ func genDoc(api *spec.ApiSpec, dir string, filename string) error {
var builder strings.Builder var builder strings.Builder
for index, route := range api.Service.Routes() { for index, route := range api.Service.Routes() {
routeComment, _ := util.GetAnnotationValue(route.Annotations, "doc", "summary") routeComment := route.JoinedDoc()
if len(routeComment) == 0 { if len(routeComment) == 0 {
routeComment = "N/A" routeComment = "N/A"
} }
@@ -58,8 +58,8 @@ func genDoc(api *spec.ApiSpec, dir string, filename string) error {
"routeComment": routeComment, "routeComment": routeComment,
"method": strings.ToUpper(route.Method), "method": strings.ToUpper(route.Method),
"uri": route.Path, "uri": route.Path,
"requestType": "`" + stringx.TakeOne(route.RequestType.Name, "-") + "`", "requestType": "`" + stringx.TakeOne(route.RequestTypeName(), "-") + "`",
"responseType": "`" + stringx.TakeOne(route.ResponseType.Name, "-") + "`", "responseType": "`" + stringx.TakeOne(route.ResponseTypeName(), "-") + "`",
"responseContent": responseContent, "responseContent": responseContent,
}) })
if err != nil { if err != nil {
@@ -73,10 +73,28 @@ func genDoc(api *spec.ApiSpec, dir string, filename string) error {
} }
func responseBody(api *spec.ApiSpec, route spec.Route) (string, error) { func responseBody(api *spec.ApiSpec, route spec.Route) (string, error) {
tps := util.GetLocalTypes(api, route) if len(route.ResponseTypeName()) == 0 {
return "", nil
}
var tps = make([]spec.Type, 0)
tps = append(tps, route.ResponseType)
if definedType, ok := route.ResponseType.(spec.DefineStruct); ok {
associatedTypes(definedType, &tps)
}
value, err := gogen.BuildTypes(tps) value, err := gogen.BuildTypes(tps)
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("\n\n```golang\n%s\n```\n", value), nil return fmt.Sprintf("\n\n```golang\n%s\n```\n", value), nil
} }
func associatedTypes(tp spec.DefineStruct, tps *[]spec.Type) {
*tps = append(*tps, tp)
for _, item := range tp.Members {
if definedType, ok := item.Type.(spec.DefineStruct); ok {
associatedTypes(definedType, tps)
}
}
}

View File

@@ -29,14 +29,11 @@ func DocCommand(c *cli.Context) error {
return err return err
} }
for _, f := range files { for _, f := range files {
p, err := parser.NewParser(f) api, err := parser.Parse(f)
if err != nil { if err != nil {
return errors.New(fmt.Sprintf("parse file: %s, err: %s", f, err.Error())) return errors.New(fmt.Sprintf("parse file: %s, err: %s", f, err.Error()))
} }
api, err := p.Parse()
if err != nil {
return err
}
index := strings.Index(f, dir) index := strings.Index(f, dir)
if index < 0 { if index < 0 {
continue continue

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
"go/format"
"go/scanner" "go/scanner"
"io/ioutil" "io/ioutil"
"os" "os"
@@ -13,6 +14,7 @@ import (
"github.com/tal-tech/go-zero/core/errorx" "github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/tools/goctl/api/parser" "github.com/tal-tech/go-zero/tools/goctl/api/parser"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -52,10 +54,12 @@ func GoFormatApi(c *cli.Context) error {
}) })
be.Add(err) be.Add(err)
} }
if be.NotNil() { if be.NotNil() {
scanner.PrintError(os.Stderr, be.Err()) scanner.PrintError(os.Stderr, be.Err())
os.Exit(1) os.Exit(1)
} }
return be.Err() return be.Err()
} }
@@ -71,10 +75,7 @@ func ApiFormatByStdin() error {
} }
_, err = fmt.Print(result) _, err = fmt.Print(result)
if err != nil { return err
return err
}
return nil
} }
func ApiFormatByPath(apiFilePath string) error { func ApiFormatByPath(apiFilePath string) error {
@@ -88,14 +89,16 @@ func ApiFormatByPath(apiFilePath string) error {
return err return err
} }
if err := ioutil.WriteFile(apiFilePath, []byte(result), os.ModePerm); err != nil { _, err = parser.ParseContent(result)
if err != nil {
return err return err
} }
return nil
return ioutil.WriteFile(apiFilePath, []byte(result), os.ModePerm)
} }
func apiFormat(data string) (string, error) { func apiFormat(data string) (string, error) {
_, err := parser.ParseApi(data) _, err := parser.ParseContent(data)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -103,24 +106,114 @@ func apiFormat(data string) (string, error) {
var builder strings.Builder var builder strings.Builder
s := bufio.NewScanner(strings.NewReader(data)) s := bufio.NewScanner(strings.NewReader(data))
var tapCount = 0 var tapCount = 0
var newLineCount = 0
var preLine string
for s.Scan() { for s.Scan() {
line := strings.TrimSpace(s.Text()) line := strings.TrimSpace(s.Text())
if len(line) == 0 {
if newLineCount > 0 {
continue
}
newLineCount++
} else {
if preLine == rightBrace {
builder.WriteString(ctlutil.NL)
}
newLineCount = 0
}
if tapCount == 0 {
format, err := formatGoTypeDef(line, s, &builder)
if err != nil {
return "", err
}
if format {
continue
}
}
noCommentLine := util.RemoveComment(line) noCommentLine := util.RemoveComment(line)
if noCommentLine == rightParenthesis || noCommentLine == rightBrace { if noCommentLine == rightParenthesis || noCommentLine == rightBrace {
tapCount -= 1 tapCount -= 1
} }
if tapCount < 0 { if tapCount < 0 {
line = strings.TrimSuffix(line, rightBrace) line := strings.TrimSuffix(noCommentLine, rightBrace)
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if strings.HasSuffix(line, leftBrace) { if strings.HasSuffix(line, leftBrace) {
tapCount += 1 tapCount += 1
} }
} }
util.WriteIndent(&builder, tapCount) util.WriteIndent(&builder, tapCount)
builder.WriteString(line + "\n") builder.WriteString(line + ctlutil.NL)
if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) { if strings.HasSuffix(noCommentLine, leftParenthesis) || strings.HasSuffix(noCommentLine, leftBrace) {
tapCount += 1 tapCount += 1
} }
preLine = line
} }
return strings.TrimSpace(builder.String()), nil return strings.TrimSpace(builder.String()), nil
} }
func formatGoTypeDef(line string, scanner *bufio.Scanner, builder *strings.Builder) (bool, error) {
noCommentLine := util.RemoveComment(line)
tokenCount := 0
if strings.HasPrefix(noCommentLine, "type") && (strings.HasSuffix(noCommentLine, leftParenthesis) ||
strings.HasSuffix(noCommentLine, leftBrace)) {
var typeBuilder strings.Builder
typeBuilder.WriteString(mayInsertStructKeyword(line, &tokenCount) + ctlutil.NL)
for scanner.Scan() {
noCommentLine := util.RemoveComment(scanner.Text())
typeBuilder.WriteString(mayInsertStructKeyword(scanner.Text(), &tokenCount) + ctlutil.NL)
if noCommentLine == rightBrace || noCommentLine == rightParenthesis {
tokenCount--
}
if tokenCount == 0 {
ts, err := format.Source([]byte(typeBuilder.String()))
if err != nil {
return false, errors.New("error format \n" + typeBuilder.String())
}
result := strings.ReplaceAll(string(ts), " struct ", " ")
result = strings.ReplaceAll(result, "type ()", "")
builder.WriteString(result)
break
}
}
return true, nil
}
return false, nil
}
func mayInsertStructKeyword(line string, token *int) string {
insertStruct := func() string {
if strings.Contains(line, " struct") {
return line
}
index := strings.Index(line, leftBrace)
return line[:index] + " struct " + line[index:]
}
noCommentLine := util.RemoveComment(line)
if strings.HasSuffix(noCommentLine, leftBrace) {
*token++
return insertStruct()
}
if strings.HasSuffix(noCommentLine, rightBrace) {
noCommentLine = strings.TrimSuffix(noCommentLine, rightBrace)
noCommentLine = util.RemoveComment(noCommentLine)
if strings.HasSuffix(noCommentLine, leftBrace) {
return insertStruct()
}
}
if strings.HasSuffix(noCommentLine, leftParenthesis) {
*token++
}
if strings.Contains(noCommentLine, "`") {
return util.UpperFirst(strings.TrimSpace(line))
}
return line
}

View File

@@ -9,13 +9,11 @@ import (
const ( const (
notFormattedStr = ` notFormattedStr = `
type Request struct { type Request struct {
Name string Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
} }
type Response struct { type Response struct {
Message string Message string ` + "`" + `json:"message"` + "`" + `
} }
service A-api { service A-api {
@server( @server(
handler: GreetHandler handler: GreetHandler
@@ -24,14 +22,12 @@ handler: GreetHandler
} }
` `
formattedStr = `type Request struct { formattedStr = `type Request {
Name string Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
} }
type Response {
type Response struct { Message string ` + "`" + `json:"message"` + "`" + `
Message string
} }
service A-api { service A-api {
@server( @server(
handler: GreetHandler handler: GreetHandler
@@ -40,7 +36,7 @@ service A-api {
}` }`
) )
func TestInlineTypeNotExist(t *testing.T) { func TestFormat(t *testing.T) {
r, err := apiFormat(notFormattedStr) r, err := apiFormat(notFormattedStr)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, r, formattedStr) assert.Equal(t, r, formattedStr)

View File

@@ -41,11 +41,7 @@ func GoCommand(c *cli.Context) error {
} }
func DoGenProject(apiFile, dir, style string) error { func DoGenProject(apiFile, dir, style string) error {
p, err := parser.NewParser(apiFile) api, err := parser.Parse(apiFile)
if err != nil {
return err
}
api, err := p.Parse()
if err != nil { if err != nil {
return err return err
} }

View File

@@ -16,16 +16,16 @@ import (
const testApiTemplate = ` const testApiTemplate = `
info( info(
title: doc title title: doc title
desc: > desc: ">
doc description first part, doc description first part,
doc description second part< doc description second part<"
version: 1.0 version: 1.0
) )
// TODO: test // TODO: test
// { // {
type Request struct { // TODO: test type Request struct { // TODO: test
// TOOD // TODO
Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` // } Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` // }
} // TODO: test } // TODO: test
@@ -55,9 +55,7 @@ service A-api {
const testMultiServiceTemplate = ` const testMultiServiceTemplate = `
info( info(
title: doc title title: doc title
desc: > desc: doc description first part
doc description first part,
doc description second part<
version: 1.0 version: 1.0
) )
@@ -229,7 +227,7 @@ type Response struct {
} }
service A-api { service A-api {
@doc(helloworld) @doc ("helloworld")
@server( @server(
handler: GreetHandler handler: GreetHandler
) )
@@ -249,7 +247,7 @@ type Response struct {
} }
service A-api { service A-api {
@doc(helloworld) @doc ("helloworld")
@server( @server(
handler: GreetHandler handler: GreetHandler
) )
@@ -325,10 +323,7 @@ func TestParser(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) api, err := parser.Parse(filename)
assert.Nil(t, err)
api, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, len(api.Types), 2) assert.Equal(t, len(api.Types), 2)
@@ -337,8 +332,8 @@ func TestParser(t *testing.T) {
assert.Equal(t, api.Service.Routes()[0].Path, "/greet/from/:name") assert.Equal(t, api.Service.Routes()[0].Path, "/greet/from/:name")
assert.Equal(t, api.Service.Routes()[1].Path, "/greet/get") assert.Equal(t, api.Service.Routes()[1].Path, "/greet/get")
assert.Equal(t, api.Service.Routes()[1].RequestType.Name, "Request") assert.Equal(t, api.Service.Routes()[1].RequestTypeName(), "Request")
assert.Equal(t, api.Service.Routes()[1].ResponseType.Name, "") assert.Equal(t, api.Service.Routes()[1].ResponseType, nil)
validate(t, filename) validate(t, filename)
} }
@@ -349,10 +344,7 @@ func TestMultiService(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) api, err := parser.Parse(filename)
assert.Nil(t, err)
api, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, len(api.Service.Routes()), 2) assert.Equal(t, len(api.Service.Routes()), 2)
@@ -367,10 +359,7 @@ func TestApiNoInfo(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
validate(t, filename) validate(t, filename)
@@ -382,7 +371,7 @@ func TestInvalidApiFile(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
_, err = parser.NewParser(filename) _, err = parser.Parse(filename)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@@ -392,14 +381,11 @@ func TestAnonymousAnnotation(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) api, err := parser.Parse(filename)
assert.Nil(t, err)
api, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, len(api.Service.Routes()), 1) assert.Equal(t, len(api.Service.Routes()), 1)
assert.Equal(t, api.Service.Routes()[0].Annotations[0].Value, "GreetHandler") assert.Equal(t, api.Service.Routes()[0].Handler, "GreetHandler")
validate(t, filename) validate(t, filename)
} }
@@ -410,10 +396,7 @@ func TestApiHasMiddleware(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
validate(t, filename) validate(t, filename)
@@ -425,10 +408,7 @@ func TestApiHasJwt(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
validate(t, filename) validate(t, filename)
@@ -440,10 +420,7 @@ func TestApiHasJwtAndMiddleware(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
validate(t, filename) validate(t, filename)
@@ -455,13 +432,8 @@ func TestApiHasNoRequestBody(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err) assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
} }
func TestApiRoutes(t *testing.T) { func TestApiRoutes(t *testing.T) {
@@ -470,10 +442,7 @@ func TestApiRoutes(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
validate(t, filename) validate(t, filename)
@@ -485,10 +454,7 @@ func TestHasCommentRoutes(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
validate(t, filename) validate(t, filename)
@@ -500,13 +466,8 @@ func TestInlineTypeNotExist(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) _, err = parser.Parse(filename)
assert.Nil(t, err) assert.NotNil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
} }
func TestHasImportApi(t *testing.T) { func TestHasImportApi(t *testing.T) {
@@ -520,15 +481,12 @@ func TestHasImportApi(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(importApiName) defer os.Remove(importApiName)
parser, err := parser.NewParser(filename) api, err := parser.Parse(filename)
assert.Nil(t, err)
api, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
var hasInline bool var hasInline bool
for _, ty := range api.Types { for _, ty := range api.Types {
if ty.Name == "ImportData" { if ty.Name() == "ImportData" {
hasInline = true hasInline = true
break break
} }
@@ -544,10 +502,7 @@ func TestNoStructApi(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := parser.NewParser(filename) spec, err := parser.Parse(filename)
assert.Nil(t, err)
spec, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, len(spec.Types), 5) assert.Equal(t, len(spec.Types), 5)
@@ -559,8 +514,8 @@ func TestNestTypeApi(t *testing.T) {
err := ioutil.WriteFile(filename, []byte(nestTypeApi), os.ModePerm) err := ioutil.WriteFile(filename, []byte(nestTypeApi), os.ModePerm)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
_, err = parser.NewParser(filename)
_, err = parser.Parse(filename)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@@ -569,7 +524,8 @@ func TestCamelStyle(t *testing.T) {
err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm) err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
_, err = parser.NewParser(filename)
_, err = parser.Parse(filename)
assert.Nil(t, err) assert.Nil(t, err)
validateWithCamel(t, filename, "GoZero") validateWithCamel(t, filename, "GoZero")

View File

@@ -1,15 +1,11 @@
package gogen package gogen
import ( import (
"bytes"
"fmt" "fmt"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -39,38 +35,24 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
return err return err
} }
fp, created, err := util.MaybeCreateFile(dir, configDir, filename+".go")
if err != nil {
return err
}
if !created {
return nil
}
defer fp.Close()
var authNames = getAuths(api) var authNames = getAuths(api)
var auths []string var auths []string
for _, item := range authNames { for _, item := range authNames {
auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate)) auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate))
} }
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)
text, err := ctlutil.LoadTemplate(category, configTemplateFile, configTemplate)
if err != nil {
return err
}
t := template.Must(template.New("configTemplate").Parse(text)) return genFile(fileGenConfig{
buffer := new(bytes.Buffer) dir: dir,
err = t.Execute(buffer, map[string]string{ subdir: configDir,
"authImport": authImportStr, filename: filename + ".go",
"auth": strings.Join(auths, "\n"), templateName: "configTemplate",
category: category,
templateFile: configTemplateFile,
builtinTemplate: configTemplate,
data: map[string]string{
"authImport": authImportStr,
"auth": strings.Join(auths, "\n"),
},
}) })
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }

View File

@@ -1,15 +1,11 @@
package gogen package gogen
import ( import (
"bytes"
"fmt" "fmt"
"strconv" "strconv"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
) )
@@ -28,42 +24,22 @@ func genEtc(dir string, cfg *config.Config, api *spec.ApiSpec) error {
return err return err
} }
fp, created, err := util.MaybeCreateFile(dir, etcDir, fmt.Sprintf("%s.yaml", filename))
if err != nil {
return err
}
if !created {
return nil
}
defer fp.Close()
service := api.Service service := api.Service
host, ok := util.GetAnnotationValue(service.Groups[0].Annotations, "server", "host") host := "0.0.0.0"
if !ok { port := strconv.Itoa(defaultPort)
host = "0.0.0.0"
}
port, ok := util.GetAnnotationValue(service.Groups[0].Annotations, "server", "port")
if !ok {
port = strconv.Itoa(defaultPort)
}
text, err := ctlutil.LoadTemplate(category, etcTemplateFile, etcTemplate) return genFile(fileGenConfig{
if err != nil { dir: dir,
return err subdir: etcDir,
} filename: fmt.Sprintf("%s.yaml", filename),
templateName: "etcTemplate",
t := template.Must(template.New("etcTemplate").Parse(text)) category: category,
buffer := new(bytes.Buffer) templateFile: etcTemplateFile,
err = t.Execute(buffer, map[string]string{ builtinTemplate: etcTemplate,
"serviceName": service.Name, data: map[string]string{
"host": host, "serviceName": service.Name,
"port": port, "host": host,
"port": port,
},
}) })
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }

View File

@@ -1,16 +1,11 @@
package gogen package gogen
import ( import (
"bytes"
"errors"
"fmt" "fmt"
"path" "path"
"strings" "strings"
"text/template"
"unicode"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
@@ -67,44 +62,31 @@ func genHandler(dir string, cfg *config.Config, group spec.Group, route spec.Rou
return doGenToFile(dir, handler, cfg, group, route, Handler{ return doGenToFile(dir, handler, cfg, group, route, Handler{
ImportPackages: genHandlerImports(group, route, parentPkg), ImportPackages: genHandlerImports(group, route, parentPkg),
HandlerName: handler, HandlerName: handler,
RequestType: util.Title(route.RequestType.Name), RequestType: util.Title(route.RequestTypeName()),
LogicType: strings.Title(getLogicName(route)), LogicType: strings.Title(getLogicName(route)),
Call: strings.Title(strings.TrimSuffix(handler, "Handler")), Call: strings.Title(strings.TrimSuffix(handler, "Handler")),
HasResp: len(route.ResponseType.Name) > 0, HasResp: len(route.ResponseTypeName()) > 0,
HasRequest: len(route.RequestType.Name) > 0, HasRequest: len(route.RequestTypeName()) > 0,
}) })
} }
func doGenToFile(dir, handler string, cfg *config.Config, group spec.Group, route spec.Route, handleObj Handler) error { func doGenToFile(dir, handler string, cfg *config.Config, group spec.Group,
route spec.Route, handleObj Handler) error {
filename, err := format.FileNamingFormat(cfg.NamingFormat, handler) filename, err := format.FileNamingFormat(cfg.NamingFormat, handler)
if err != nil { if err != nil {
return err return err
} }
filename = filename + ".go" return genFile(fileGenConfig{
fp, created, err := apiutil.MaybeCreateFile(dir, getHandlerFolderPath(group, route), filename) dir: dir,
if err != nil { subdir: getHandlerFolderPath(group, route),
return err filename: filename + ".go",
} templateName: "handlerTemplate",
if !created { category: category,
return nil templateFile: handlerTemplateFile,
} builtinTemplate: handlerTemplate,
defer fp.Close() data: handleObj,
})
text, err := util.LoadTemplate(category, handlerTemplateFile, handlerTemplate)
if err != nil {
return err
}
buffer := new(bytes.Buffer)
err = template.Must(template.New("handlerTemplate").Parse(text)).Execute(buffer, handleObj)
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }
func genHandlers(dir string, cfg *config.Config, api *spec.ApiSpec) error { func genHandlers(dir string, cfg *config.Config, api *spec.ApiSpec) error {
@@ -124,7 +106,7 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str
imports = append(imports, fmt.Sprintf("\"%s\"", imports = append(imports, fmt.Sprintf("\"%s\"",
util.JoinPackages(parentPkg, getLogicFolderPath(group, route)))) util.JoinPackages(parentPkg, getLogicFolderPath(group, route))))
imports = append(imports, fmt.Sprintf("\"%s\"", util.JoinPackages(parentPkg, contextDir))) imports = append(imports, fmt.Sprintf("\"%s\"", util.JoinPackages(parentPkg, contextDir)))
if len(route.RequestType.Name) > 0 { if len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", util.JoinPackages(parentPkg, typesDir)))
} }
imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/rest/httpx\"", vars.ProjectOpenSourceUrl))
@@ -133,18 +115,7 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str
} }
func getHandlerBaseName(route spec.Route) (string, error) { func getHandlerBaseName(route spec.Route) (string, error) {
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler") handler := route.Handler
if !ok {
return "", fmt.Errorf("missing handler annotation for %q", route.Path)
}
for _, char := range handler {
if !unicode.IsDigit(char) && !unicode.IsLetter(char) {
return "", errors.New(fmt.Sprintf("route [%s] handler [%s] invalid, handler name should only contains letter or digit",
route.Path, handler))
}
}
handler = strings.TrimSpace(handler) handler = strings.TrimSpace(handler)
handler = strings.TrimSuffix(handler, "handler") handler = strings.TrimSuffix(handler, "handler")
handler = strings.TrimSuffix(handler, "Handler") handler = strings.TrimSuffix(handler, "Handler")
@@ -152,10 +123,10 @@ func getHandlerBaseName(route spec.Route) (string, error) {
} }
func getHandlerFolderPath(group spec.Group, route spec.Route) string { func getHandlerFolderPath(group spec.Group, route spec.Route) string {
folder, ok := apiutil.GetAnnotationValue(route.Annotations, "server", groupProperty) folder := route.GetAnnotation(groupProperty)
if !ok { if len(folder) == 0 {
folder, ok = apiutil.GetAnnotationValue(group.Annotations, "server", groupProperty) folder = group.GetAnnotation(groupProperty)
if !ok { if len(folder) == 0 {
return handlerDir return handlerDir
} }
} }

View File

@@ -1,14 +1,11 @@
package gogen package gogen
import ( import (
"bytes"
"fmt" "fmt"
"path" "path"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
@@ -61,17 +58,6 @@ func genLogicByRoute(dir string, cfg *config.Config, group spec.Group, route spe
return err return err
} }
goFile = goFile + ".go"
fp, created, err := util.MaybeCreateFile(dir, getLogicFolderPath(group, route), goFile)
if err != nil {
return err
}
if !created {
return nil
}
defer fp.Close()
parentPkg, err := getParentPackage(dir) parentPkg, err := getParentPackage(dir)
if err != nil { if err != nil {
return err return err
@@ -81,47 +67,46 @@ func genLogicByRoute(dir string, cfg *config.Config, group spec.Group, route spe
var responseString string var responseString string
var returnString string var returnString string
var requestString string var requestString string
if len(route.ResponseType.Name) > 0 { if len(route.ResponseTypeName()) > 0 {
resp := strings.Title(route.ResponseType.Name) resp := responseGoTypeName(route, typesPacket)
responseString = "(*types." + resp + ", error)" responseString = "(" + resp + ", error)"
returnString = fmt.Sprintf("return &types.%s{}, nil", resp) if strings.HasPrefix(resp, "*") {
returnString = fmt.Sprintf("return &%s{}, nil", strings.TrimPrefix(resp, "*"))
} else {
returnString = fmt.Sprintf("return %s{}, nil", resp)
}
} else { } else {
responseString = "error" responseString = "error"
returnString = "return nil" returnString = "return nil"
} }
if len(route.RequestType.Name) > 0 { if len(route.RequestTypeName()) > 0 {
requestString = "req " + "types." + strings.Title(route.RequestType.Name) requestString = "req " + requestGoTypeName(route, typesPacket)
} }
text, err := ctlutil.LoadTemplate(category, logicTemplateFile, logicTemplate) return genFile(fileGenConfig{
if err != nil { dir: dir,
return err subdir: getLogicFolderPath(group, route),
} filename: goFile + ".go",
templateName: "logicTemplate",
t := template.Must(template.New("logicTemplate").Parse(text)) category: category,
buffer := new(bytes.Buffer) templateFile: logicTemplateFile,
err = t.Execute(fp, map[string]string{ builtinTemplate: logicTemplate,
"imports": imports, data: map[string]string{
"logic": strings.Title(logic), "imports": imports,
"function": strings.Title(strings.TrimSuffix(logic, "Logic")), "logic": strings.Title(logic),
"responseType": responseString, "function": strings.Title(strings.TrimSuffix(logic, "Logic")),
"returnString": returnString, "responseType": responseString,
"request": requestString, "returnString": returnString,
"request": requestString,
},
}) })
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }
func getLogicFolderPath(group spec.Group, route spec.Route) string { func getLogicFolderPath(group spec.Group, route spec.Route) string {
folder, ok := util.GetAnnotationValue(route.Annotations, "server", groupProperty) folder := route.GetAnnotation(groupProperty)
if !ok { if len(folder) == 0 {
folder, ok = util.GetAnnotationValue(group.Annotations, "server", groupProperty) folder = group.GetAnnotation(groupProperty)
if !ok { if len(folder) == 0 {
return logicDir return logicDir
} }
} }
@@ -134,7 +119,7 @@ func genLogicImports(route spec.Route, parentPkg string) string {
var imports []string var imports []string
imports = append(imports, `"context"`+"\n") imports = append(imports, `"context"`+"\n")
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir))) imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir)))
if len(route.ResponseType.Name) > 0 || len(route.RequestType.Name) > 0 { if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir))) imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
} }
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceUrl)) imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceUrl))

View File

@@ -1,13 +1,10 @@
package gogen package gogen
import ( import (
"bytes"
"fmt" "fmt"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
@@ -52,39 +49,24 @@ func genMain(dir string, cfg *config.Config, api *spec.ApiSpec) error {
return err return err
} }
goFile := filename + ".go"
fp, created, err := util.MaybeCreateFile(dir, "", goFile)
if err != nil {
return err
}
if !created {
return nil
}
defer fp.Close()
parentPkg, err := getParentPackage(dir) parentPkg, err := getParentPackage(dir)
if err != nil { if err != nil {
return err return err
} }
text, err := ctlutil.LoadTemplate(category, mainTemplateFile, mainTemplate) return genFile(fileGenConfig{
if err != nil { dir: dir,
return err subdir: "",
} filename: filename + ".go",
templateName: "mainTemplate",
t := template.Must(template.New("mainTemplate").Parse(text)) category: category,
buffer := new(bytes.Buffer) templateFile: mainTemplateFile,
err = t.Execute(buffer, map[string]string{ builtinTemplate: mainTemplate,
"importPackages": genMainImports(parentPkg), data: map[string]string{
"serviceName": api.Service.Name, "importPackages": genMainImports(parentPkg),
"serviceName": api.Service.Name,
},
}) })
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }
func genMainImports(parentPkg string) string { func genMainImports(parentPkg string) string {

View File

@@ -1,12 +1,9 @@
package gogen package gogen
import ( import (
"bytes"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
) )
@@ -37,34 +34,26 @@ func genMiddleware(dir string, cfg *config.Config, api *spec.ApiSpec) error {
var middlewares = getMiddleware(api) var middlewares = getMiddleware(api)
for _, item := range middlewares { for _, item := range middlewares {
middlewareFilename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "_middleware" middlewareFilename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "_middleware"
formatName, err := format.FileNamingFormat(cfg.NamingFormat, middlewareFilename) filename, err := format.FileNamingFormat(cfg.NamingFormat, middlewareFilename)
if err != nil { if err != nil {
return err return err
} }
filename := formatName + ".go"
fp, created, err := util.MaybeCreateFile(dir, middlewareDir, filename)
if err != nil {
return err
}
if !created {
return nil
}
defer fp.Close()
name := strings.TrimSuffix(item, "Middleware") + "Middleware" name := strings.TrimSuffix(item, "Middleware") + "Middleware"
t := template.Must(template.New("contextTemplate").Parse(middlewareImplementCode)) err = genFile(fileGenConfig{
buffer := new(bytes.Buffer) dir: dir,
err = t.Execute(buffer, map[string]string{ subdir: middlewareDir,
"name": strings.Title(name), filename: filename + ".go",
templateName: "contextTemplate",
builtinTemplate: middlewareImplementCode,
data: map[string]string{
"name": strings.Title(name),
},
}) })
if err != nil { if err != nil {
return err return err
} }
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }
return nil return nil
} }

View File

@@ -1,7 +1,6 @@
package gogen package gogen
import ( import (
"bytes"
"fmt" "fmt"
"os" "os"
"path" "path"
@@ -11,7 +10,6 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
@@ -132,28 +130,19 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
filename := path.Join(dir, handlerDir, routeFilename) filename := path.Join(dir, handlerDir, routeFilename)
os.Remove(filename) os.Remove(filename)
fp, created, err := apiutil.MaybeCreateFile(dir, handlerDir, routeFilename) return genFile(fileGenConfig{
if err != nil { dir: dir,
return err subdir: handlerDir,
} filename: routeFilename,
if !created { templateName: "routesTemplate",
return nil category: "",
} templateFile: "",
defer fp.Close() builtinTemplate: routesTemplate,
data: map[string]string{
t := template.Must(template.New("routesTemplate").Parse(routesTemplate)) "importPackages": genRouteImports(parentPkg, api),
buffer := new(bytes.Buffer) "routesAdditions": strings.TrimSpace(builder.String()),
err = t.Execute(buffer, map[string]string{ },
"importPackages": genRouteImports(parentPkg, api),
"routesAdditions": strings.TrimSpace(builder.String()),
}) })
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }
func genRouteImports(parentPkg string, api *spec.ApiSpec) string { func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
@@ -161,10 +150,10 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
importSet.AddStr(fmt.Sprintf("\"%s\"", util.JoinPackages(parentPkg, contextDir))) importSet.AddStr(fmt.Sprintf("\"%s\"", util.JoinPackages(parentPkg, contextDir)))
for _, group := range api.Service.Groups { for _, group := range api.Service.Groups {
for _, route := range group.Routes { for _, route := range group.Routes {
folder, ok := apiutil.GetAnnotationValue(route.Annotations, "server", groupProperty) folder := route.GetAnnotation(groupProperty)
if !ok { if len(folder) == 0 {
folder, ok = apiutil.GetAnnotationValue(group.Annotations, "server", groupProperty) folder = group.GetAnnotation(groupProperty)
if !ok { if len(folder) == 0 {
continue continue
} }
} }
@@ -186,12 +175,12 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
for _, r := range g.Routes { for _, r := range g.Routes {
handler := getHandlerName(r) handler := getHandlerName(r)
handler = handler + "(serverCtx)" handler = handler + "(serverCtx)"
folder, ok := apiutil.GetAnnotationValue(r.Annotations, "server", groupProperty) folder := r.GetAnnotation(groupProperty)
if ok { if len(folder) > 0 {
handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:] handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:]
} else { } else {
folder, ok = apiutil.GetAnnotationValue(g.Annotations, "server", groupProperty) folder = g.GetAnnotation(groupProperty)
if ok { if len(folder) > 0 {
handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:] handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:]
} }
} }
@@ -202,12 +191,14 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
}) })
} }
if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "jwt"); ok { jwt := g.GetAnnotation("jwt")
groupedRoutes.authName = value if len(jwt) > 0 {
groupedRoutes.authName = jwt
groupedRoutes.jwtEnabled = true groupedRoutes.jwtEnabled = true
} }
if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok { middleware := g.GetAnnotation("middleware")
for _, item := range strings.Split(value, ",") { if len(middleware) > 0 {
for _, item := range strings.Split(middleware, ",") {
groupedRoutes.middlewares = append(groupedRoutes.middlewares, item) groupedRoutes.middlewares = append(groupedRoutes.middlewares, item)
} }
} }

View File

@@ -1,13 +1,10 @@
package gogen package gogen
import ( import (
"bytes"
"fmt" "fmt"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format" "github.com/tal-tech/go-zero/tools/goctl/util/format"
@@ -33,7 +30,6 @@ func NewServiceContext(c {{.config}}) *ServiceContext {
{{.middlewareAssignment}} {{.middlewareAssignment}}
} }
} }
` `
) )
@@ -43,15 +39,6 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
return err return err
} }
fp, created, err := util.MaybeCreateFile(dir, contextDir, filename+".go")
if err != nil {
return err
}
if !created {
return nil
}
defer fp.Close()
var authNames = getAuths(api) var authNames = getAuths(api)
var auths []string var auths []string
for _, item := range authNames { for _, item := range authNames {
@@ -63,11 +50,6 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
return err return err
} }
text, err := ctlutil.LoadTemplate(category, contextTemplateFile, contextTemplate)
if err != nil {
return err
}
var middlewareStr string var middlewareStr string
var middlewareAssignment string var middlewareAssignment string
var middlewares = getMiddleware(api) var middlewares = getMiddleware(api)
@@ -75,7 +57,8 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
for _, item := range middlewares { for _, item := range middlewares {
middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item) middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
name := strings.TrimSuffix(item, "Middleware") + "Middleware" name := strings.TrimSuffix(item, "Middleware") + "Middleware"
middlewareAssignment += fmt.Sprintf("%s: %s,\n", item, fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle")) middlewareAssignment += fmt.Sprintf("%s: %s,\n", item,
fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle"))
} }
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
@@ -84,19 +67,19 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl) configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceUrl)
} }
t := template.Must(template.New("contextTemplate").Parse(text)) return genFile(fileGenConfig{
buffer := new(bytes.Buffer) dir: dir,
err = t.Execute(buffer, map[string]string{ subdir: contextDir,
"configImport": configImport, filename: filename + ".go",
"config": "config.Config", templateName: "contextTemplate",
"middleware": middlewareStr, category: category,
"middlewareAssignment": middlewareAssignment, templateFile: contextTemplateFile,
builtinTemplate: contextTemplate,
data: map[string]string{
"configImport": configImport,
"config": "config.Config",
"middleware": middlewareStr,
"middlewareAssignment": middlewareAssignment,
},
}) })
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }

View File

@@ -1,14 +1,12 @@
package gogen package gogen
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os" "os"
"path" "path"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
@@ -37,8 +35,8 @@ func BuildTypes(types []spec.Type) (string, error) {
} else { } else {
builder.WriteString("\n\n") builder.WriteString("\n\n")
} }
if err := writeType(&builder, tp, types); err != nil { if err := writeType(&builder, tp); err != nil {
return "", apiutil.WrapErr(err, "Type "+tp.Name+" generate error") return "", apiutil.WrapErr(err, "Type "+tp.Name()+" generate error")
} }
} }
@@ -55,91 +53,43 @@ func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
if err != nil { if err != nil {
return err return err
} }
typeFilename = typeFilename + ".go" typeFilename = typeFilename + ".go"
filename := path.Join(dir, typesDir, typeFilename) filename := path.Join(dir, typesDir, typeFilename)
os.Remove(filename) os.Remove(filename)
fp, created, err := apiutil.MaybeCreateFile(dir, typesDir, typeFilename) return genFile(fileGenConfig{
if err != nil { dir: dir,
return err subdir: typesDir,
} filename: typeFilename,
templateName: "typesTemplate",
if !created { category: "",
return nil templateFile: "",
} builtinTemplate: typesTemplate,
defer fp.Close() data: map[string]interface{}{
"types": val,
t := template.Must(template.New("typesTemplate").Parse(typesTemplate)) "containsTime": false,
buffer := new(bytes.Buffer) },
err = t.Execute(buffer, map[string]interface{}{
"types": val,
"containsTime": api.ContainsTime(),
}) })
if err != nil {
return err
}
formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode)
return err
} }
func convertTypeCase(types []spec.Type, t string) (string, error) { func writeType(writer io.Writer, tp spec.Type) error {
ts, err := apiutil.DecomposeType(t) structType, ok := tp.(spec.DefineStruct)
if err != nil { if !ok {
return "", err return errors.New(fmt.Sprintf("unspport struct type: %s", tp.Name()))
} }
var defTypes []string fmt.Fprintf(writer, "type %s struct {\n", util.Title(tp.Name()))
for _, tp := range ts { for _, member := range structType.Members {
for _, typ := range types {
if typ.Name == tp {
defTypes = append(defTypes, tp)
}
}
}
for _, tp := range defTypes {
t = strings.ReplaceAll(t, tp, util.Title(tp))
}
return t, nil
}
func writeType(writer io.Writer, tp spec.Type, types []spec.Type) error {
fmt.Fprintf(writer, "type %s struct {\n", util.Title(tp.Name))
for _, member := range tp.Members {
if member.IsInline { if member.IsInline {
var found = false if _, err := fmt.Fprintf(writer, "%s\n", strings.Title(member.Type.Name())); err != nil {
for _, ty := range types {
if strings.ToLower(ty.Name) == strings.ToLower(member.Name) {
found = true
}
}
if !found {
return errors.New("inline type " + member.Name + " not exist, please correct api file")
}
if _, err := fmt.Fprintf(writer, "%s\n", strings.Title(member.Type)); err != nil {
return err return err
} else { } else {
continue continue
} }
} }
tpString, err := convertTypeCase(types, member.Type)
if err != nil { if err := writeProperty(writer, member.Name, member.Tag, member.GetComment(), member.Type, 1); err != nil {
return err
}
pm, err := member.GetPropertyName()
if err != nil {
return err
}
if !strings.Contains(pm, "_") {
if strings.Title(member.Name) != strings.Title(pm) {
fmt.Printf("type: %s, property name %s json tag illegal, "+
"should set json tag as `json:\"%s\"` \n", tp.Name, member.Name, util.Untitle(member.Name))
}
}
if err := writeProperty(writer, member.Name, tpString, member.Tag, member.GetComment(), 1); err != nil {
return err return err
} }
} }

View File

@@ -38,11 +38,12 @@ func RevertTemplate(name string) error {
return util.CreateTemplate(category, name, content) return util.CreateTemplate(category, name, content)
} }
func Update(category string) error { func Update() error {
err := Clean() err := Clean()
if err != nil { if err != nil {
return err return err
} }
return util.InitTemplates(category, templates) return util.InitTemplates(category, templates)
} }
@@ -50,6 +51,6 @@ func Clean() error {
return util.Clean(category) return util.Clean(category)
} }
func GetCategory() string { func Category() string {
return category return category
} }

View File

@@ -84,7 +84,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, string(data), modifyData) assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category)) assert.Nil(t, Update())
data, err = ioutil.ReadFile(file) data, err = ioutil.ReadFile(file)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -1,18 +1,64 @@
package gogen package gogen
import ( import (
"bytes"
"fmt" "fmt"
goformat "go/format" goformat "go/format"
"io" "io"
"path/filepath" "path/filepath"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx" "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
) )
type fileGenConfig struct {
dir string
subdir string
filename string
templateName string
category string
templateFile string
builtinTemplate string
data interface{}
}
func genFile(c fileGenConfig) error {
fp, created, err := util.MaybeCreateFile(c.dir, c.subdir, c.filename)
if err != nil {
return err
}
if !created {
return nil
}
defer fp.Close()
var text string
if len(c.category) == 0 || len(c.templateFile) == 0 {
text = c.builtinTemplate
} else {
text, err = ctlutil.LoadTemplate(c.category, c.templateFile, c.builtinTemplate)
if err != nil {
return err
}
}
t := template.Must(template.New(c.templateName).Parse(text))
buffer := new(bytes.Buffer)
err = t.Execute(buffer, c.data)
if err != nil {
return err
}
code := formatCode(buffer.String())
_, err = fp.WriteString(code)
return err
}
func getParentPackage(dir string) (string, error) { func getParentPackage(dir string) (string, error) {
abs, err := filepath.Abs(dir) abs, err := filepath.Abs(dir)
if err != nil { if err != nil {
@@ -26,15 +72,15 @@ func getParentPackage(dir string) (string, error) {
return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
} }
func writeProperty(writer io.Writer, name, tp, tag, comment string, indent int) error { func writeProperty(writer io.Writer, name, tag, comment string, tp spec.Type, indent int) error {
util.WriteIndent(writer, indent) util.WriteIndent(writer, indent)
var err error var err error
if len(comment) > 0 { if len(comment) > 0 {
comment = strings.TrimPrefix(comment, "//") comment = strings.TrimPrefix(comment, "//")
comment = "//" + comment comment = "//" + comment
_, err = fmt.Fprintf(writer, "%s %s %s %s\n", strings.Title(name), tp, tag, comment) _, err = fmt.Fprintf(writer, "%s %s %s %s\n", strings.Title(name), tp.Name(), tag, comment)
} else { } else {
_, err = fmt.Fprintf(writer, "%s %s %s\n", strings.Title(name), tp, tag) _, err = fmt.Fprintf(writer, "%s %s %s\n", strings.Title(name), tp.Name(), tag)
} }
return err return err
} }
@@ -42,11 +88,13 @@ func writeProperty(writer io.Writer, name, tp, tag, comment string, indent int)
func getAuths(api *spec.ApiSpec) []string { func getAuths(api *spec.ApiSpec) []string {
authNames := collection.NewSet() authNames := collection.NewSet()
for _, g := range api.Service.Groups { for _, g := range api.Service.Groups {
if value, ok := util.GetAnnotationValue(g.Annotations, "server", "jwt"); ok { jwt := g.GetAnnotation("jwt")
authNames.Add(value) if len(jwt) > 0 {
authNames.Add(jwt)
} }
if value, ok := util.GetAnnotationValue(g.Annotations, "server", "signature"); ok { signature := g.GetAnnotation("signature")
authNames.Add(value) if len(signature) > 0 {
authNames.Add(signature)
} }
} }
return authNames.KeysStr() return authNames.KeysStr()
@@ -55,8 +103,9 @@ func getAuths(api *spec.ApiSpec) []string {
func getMiddleware(api *spec.ApiSpec) []string { func getMiddleware(api *spec.ApiSpec) []string {
result := collection.NewSet() result := collection.NewSet()
for _, g := range api.Service.Groups { for _, g := range api.Service.Groups {
if value, ok := util.GetAnnotationValue(g.Annotations, "server", "middleware"); ok { middleware := g.GetAnnotation("middleware")
for _, item := range strings.Split(value, ",") { if len(middleware) > 0 {
for _, item := range strings.Split(middleware, ",") {
result.Add(strings.TrimSpace(item)) result.Add(strings.TrimSpace(item))
} }
} }
@@ -72,3 +121,70 @@ func formatCode(code string) string {
return string(ret) return string(ret)
} }
func responseGoTypeName(r spec.Route, pkg ...string) string {
if r.ResponseType == nil {
return ""
}
return golangExpr(r.ResponseType, pkg...)
}
func requestGoTypeName(r spec.Route, pkg ...string) string {
if r.RequestType == nil {
return ""
}
return golangExpr(r.RequestType, pkg...)
}
func golangExpr(ty spec.Type, pkg ...string) string {
switch v := ty.(type) {
case spec.PrimitiveType:
return v.RawName
case spec.DefineStruct:
if len(pkg) > 1 {
panic("package cannot be more than 1")
}
if len(pkg) == 0 {
return v.RawName
}
return fmt.Sprintf("%s.%s", pkg[0], strings.Title(v.RawName))
case spec.ArrayType:
if len(pkg) > 1 {
panic("package cannot be more than 1")
}
if len(pkg) == 0 {
return v.RawName
}
return fmt.Sprintf("[]%s", golangExpr(v.Value, pkg...))
case spec.MapType:
if len(pkg) > 1 {
panic("package cannot be more than 1")
}
if len(pkg) == 0 {
return v.RawName
}
return fmt.Sprintf("map[%s]%s", v.Key, golangExpr(v.Value, pkg...))
case spec.PointerType:
if len(pkg) > 1 {
panic("package cannot be more than 1")
}
if len(pkg) == 0 {
return v.RawName
}
return fmt.Sprintf("*%s", golangExpr(v.Type, pkg...))
case spec.InterfaceType:
return v.RawName
}
return ""
}

View File

@@ -22,11 +22,7 @@ func JavaCommand(c *cli.Context) error {
return errors.New("missing -dir") return errors.New("missing -dir")
} }
p, err := parser.NewParser(apiFile) api, err := parser.Parse(apiFile)
if err != nil {
return err
}
api, err := p.Parse()
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,12 +1,16 @@
package javagen package javagen
import ( import (
"bufio"
"bytes"
"errors"
"fmt" "fmt"
"io" "io"
"path" "path"
"strings" "strings"
"text/template" "text/template"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
@@ -16,19 +20,81 @@ const (
componentTemplate = `// Code generated by goctl. DO NOT EDIT. componentTemplate = `// Code generated by goctl. DO NOT EDIT.
package com.xhb.logic.http.packet.{{.packet}}.model; package com.xhb.logic.http.packet.{{.packet}}.model;
import com.xhb.logic.http.DeProguardable; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
{{.imports}}
{{.componentType}} public class {{.className}} extends {{.superClassName}} {
{{.properties}}
{{if .HasProperty}}
public {{.className}}() {
}
public {{.className}}({{.params}}) {
{{.constructorSetter}}
}
{{end}}
{{.getSet}}
}
` `
getSetTemplate = `
{{.indent}}{{.decorator}}
{{.indent}}public {{.returnType}} get{{.property}}() {
{{.indent}} return this.{{.tagValue}};
{{.indent}}}
{{.indent}}public void set{{.property}}({{.type}} {{.propertyValue}}) {
{{.indent}} this.{{.tagValue}} = {{.propertyValue}};
{{.indent}}}
`
boolTemplate = `
{{.indent}}{{.decorator}}
{{.indent}}public {{.returnType}} is{{.property}}() {
{{.indent}} return this.{{.tagValue}};
{{.indent}}}
{{.indent}}public void set{{.property}}({{.type}} {{.propertyValue}}) {
{{.indent}} this.{{.tagValue}} = {{.propertyValue}};
{{.indent}}}
`
httpResponseData = "import com.xhb.core.response.HttpResponseData;"
httpData = "import com.xhb.core.packet.HttpData;"
) )
type componentsContext struct {
api *spec.ApiSpec
requestTypes []spec.Type
responseTypes []spec.Type
imports []string
members []spec.Member
}
func genComponents(dir, packetName string, api *spec.ApiSpec) error { func genComponents(dir, packetName string, api *spec.ApiSpec) error {
types := apiutil.GetSharedTypes(api) types := api.Types
if len(types) == 0 { if len(types) == 0 {
return nil return nil
} }
var requestTypes []spec.Type
var responseTypes []spec.Type
for _, group := range api.Service.Groups {
for _, route := range group.Routes {
if route.RequestType != nil {
requestTypes = append(requestTypes, route.RequestType)
}
if route.ResponseType != nil {
responseTypes = append(responseTypes, route.ResponseType)
}
}
}
context := componentsContext{api: api, requestTypes: requestTypes, responseTypes: responseTypes}
for _, ty := range types { for _, ty := range types {
if err := createComponent(dir, packetName, ty); err != nil { if err := context.createComponent(dir, packetName, ty); err != nil {
return err return err
} }
} }
@@ -36,13 +102,55 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
return nil return nil
} }
func createComponent(dir, packetName string, ty spec.Type) error { func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
modelFile := util.Title(ty.Name) + ".java" defineStruct, ok := ty.(spec.DefineStruct)
if !ok {
return errors.New("unsupported type %s" + ty.Name())
}
for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return nil
}
}
}
modelFile := util.Title(ty.Name()) + ".java"
filename := path.Join(dir, modelDir, modelFile) filename := path.Join(dir, modelDir, modelFile)
if err := util.RemoveOrQuit(filename); err != nil { if err := util.RemoveOrQuit(filename); err != nil {
return err return err
} }
propertiesString, err := c.buildProperties(defineStruct)
if err != nil {
return err
}
getSetString, err := c.buildGetterSetter(defineStruct)
if err != nil {
return err
}
superClassName := "HttpData"
for _, item := range c.responseTypes {
if item.Name() == defineStruct.Name() {
superClassName = "HttpResponseData"
if !stringx.Contains(c.imports, httpResponseData) {
c.imports = append(c.imports, httpResponseData)
}
break
}
}
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) {
c.imports = append(c.imports, httpData)
}
params, constructorSetter, err := c.buildConstructor()
if err != nil {
return err
}
fp, created, err := apiutil.MaybeCreateFile(dir, modelDir, modelFile) fp, created, err := apiutil.MaybeCreateFile(dir, modelDir, modelFile)
if err != nil { if err != nil {
return err return err
@@ -52,34 +160,184 @@ func createComponent(dir, packetName string, ty spec.Type) error {
} }
defer fp.Close() defer fp.Close()
tys, err := buildType(ty) buffer := new(bytes.Buffer)
t := template.Must(template.New("componentType").Parse(componentTemplate))
err = t.Execute(buffer, map[string]interface{}{
"properties": propertiesString,
"params": params,
"constructorSetter": constructorSetter,
"getSet": getSetString,
"packet": packetName,
"imports": strings.Join(c.imports, "\n"),
"className": util.Title(defineStruct.Name()),
"superClassName": superClassName,
"HasProperty": len(strings.TrimSpace(propertiesString)) > 0,
})
if err != nil { if err != nil {
return err return err
} }
t := template.Must(template.New("componentType").Parse(componentTemplate)) _, err = fp.WriteString(formatSource(buffer.String()))
return t.Execute(fp, map[string]string{ return err
"componentType": tys,
"packet": packetName,
})
} }
func buildType(ty spec.Type) (string, error) { func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
var builder strings.Builder var builder strings.Builder
if err := writeType(&builder, ty); err != nil { if err := c.writeType(&builder, defineStruct); err != nil {
return "", apiutil.WrapErr(err, "Type "+ty.Name+" generate error") return "", apiutil.WrapErr(err, "Type "+defineStruct.Name()+" generate error")
} }
return builder.String(), nil return builder.String(), nil
} }
func writeType(writer io.Writer, tp spec.Type) error { func (c *componentsContext) buildGetterSetter(defineStruct spec.DefineStruct) (string, error) {
fmt.Fprintf(writer, "public class %s implements DeProguardable {\n", util.Title(tp.Name)) var builder strings.Builder
for _, member := range tp.Members { if err := c.genGetSet(&builder, 1); err != nil {
if err := writeProperty(writer, member, 1); err != nil { return "", apiutil.WrapErr(err, "Type "+defineStruct.Name()+" get or set generate error")
return err
}
} }
genGetSet(writer, tp, 1)
fmt.Fprintf(writer, "}\n") return builder.String(), nil
}
func (c *componentsContext) writeType(writer io.Writer, defineStruct spec.DefineStruct) error {
c.members = make([]spec.Member, 0)
err := c.writeMembers(writer, defineStruct, 1)
if err != nil {
return err
}
return nil return nil
} }
func (c *componentsContext) writeMembers(writer io.Writer, tp spec.Type, indent int) error {
definedType, ok := tp.(spec.DefineStruct)
if !ok {
pointType, ok := tp.(spec.PointerType)
if ok {
return c.writeMembers(writer, pointType.Type, indent)
}
return errors.New(fmt.Sprintf("type %s not supported", tp.Name()))
}
for _, member := range definedType.Members {
if member.IsInline {
err := c.writeMembers(writer, member.Type, indent)
if err != nil {
return err
}
continue
}
if member.IsBodyMember() || member.IsFormMember() {
if err := writeProperty(writer, member, indent); err != nil {
return err
}
c.members = append(c.members, member)
}
}
return nil
}
func (c *componentsContext) buildConstructor() (string, string, error) {
var params strings.Builder
var constructorSetter strings.Builder
for index, member := range c.members {
tp, err := specTypeToJava(member.Type)
if err != nil {
return "", "", err
}
params.WriteString(fmt.Sprintf("%s %s", tp, util.Untitle(member.Name)))
pn, err := member.GetPropertyName()
if err != nil {
return "", "", err
}
if index != len(c.members)-1 {
params.WriteString(", ")
}
writeIndent(&constructorSetter, 2)
constructorSetter.WriteString(fmt.Sprintf("this.%s = %s;", pn, util.Untitle(member.Name)))
if index != len(c.members)-1 {
constructorSetter.WriteString(util.NL)
}
}
return params.String(), constructorSetter.String(), nil
}
func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
var members = c.members
for _, member := range members {
javaType, err := specTypeToJava(member.Type)
if err != nil {
return nil
}
var property = util.Title(member.Name)
var templateStr = getSetTemplate
if javaType == "boolean" {
templateStr = boolTemplate
property = strings.TrimPrefix(property, "Is")
property = strings.TrimPrefix(property, "is")
}
t := template.Must(template.New(templateStr).Parse(getSetTemplate))
var tmplBytes bytes.Buffer
tyString := javaType
decorator := ""
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
if !stringx.Contains(javaPrimitiveType, javaType) {
if member.IsOptional() || member.IsOmitEmpty() {
decorator = "@Nullable "
} else {
decorator = "@NotNull "
}
tyString = decorator + tyString
}
tagName, err := member.GetPropertyName()
if err != nil {
return err
}
err = t.Execute(&tmplBytes, map[string]string{
"property": property,
"propertyValue": util.Untitle(member.Name),
"tagValue": tagName,
"type": tyString,
"decorator": decorator,
"returnType": javaType,
"indent": indentString(indent),
})
if err != nil {
return err
}
r := tmplBytes.String()
r = strings.Replace(r, " boolean get", " boolean is", 1)
writer.Write([]byte(r))
}
return nil
}
func formatSource(source string) string {
var builder strings.Builder
scanner := bufio.NewScanner(strings.NewReader(source))
preIsBreakLine := false
for scanner.Scan() {
text := strings.TrimSpace(scanner.Text())
if text == "" && preIsBreakLine {
continue
}
preIsBreakLine = text == ""
builder.WriteString(scanner.Text() + "\n")
}
if err := scanner.Err(); err != nil {
fmt.Println(err)
}
return builder.String()
}

View File

@@ -1,11 +1,8 @@
package javagen package javagen
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io"
"os"
"strings" "strings"
"text/template" "text/template"
@@ -17,25 +14,17 @@ import (
const packetTemplate = `package com.xhb.logic.http.packet.{{.packet}}; const packetTemplate = `package com.xhb.logic.http.packet.{{.packet}};
import com.google.gson.Gson; import com.xhb.core.packet.HttpPacket;
import com.xhb.commons.JSON;
import com.xhb.commons.JsonParser;
import com.xhb.core.network.HttpRequestClient; import com.xhb.core.network.HttpRequestClient;
import com.xhb.core.packet.HttpRequestPacket; {{.imports}}
import com.xhb.core.response.HttpResponseData;
import com.xhb.logic.http.DeProguardable;
{{.import}}
import org.jetbrains.annotations.NotNull;
import org.json.JSONObject;
public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> {
{{.doc}}
public class {{.packetName}} extends HttpPacket<{{.responseType}}> {
{{.paramsDeclaration}} {{.paramsDeclaration}}
public {{.packetName}}({{.params}}{{.requestType}} request) { public {{.packetName}}({{.params}}{{if .HasRequestBody}}{{.requestType}} request{{end}}) {
super(request); {{if .HasRequestBody}}super(request);{{else}}super(EmptyRequest.instance);{{end}}
this.request = request;{{.paramsSet}} {{if .HasRequestBody}}this.request = request;{{end}}{{.paramsSetter}}
} }
@Override @Override
@@ -47,32 +36,6 @@ public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packet
public String requestUri() { public String requestUri() {
return {{.uri}}; return {{.uri}};
} }
@Override
public {{.packetName}}Response newInstanceFrom(JSON json) {
return new {{.packetName}}Response(json);
}
public static class {{.packetName}}Response extends HttpResponseData {
private {{.responseType}} responseData;
{{.packetName}}Response(@NotNull JSON json) {
super(json);
JSONObject jsonObject = json.asObject();
if (JsonParser.hasKey(jsonObject, "data")) {
Gson gson = new Gson();
JSONObject dataJson = JsonParser.getJSONObject(jsonObject, "data");
responseData = gson.fromJson(dataJson.toString(), {{.responseType}}.class);
}
}
public {{.responseType}} get{{.responseType}} () {
return responseData;
}
}
{{.types}}
} }
` `
@@ -87,10 +50,11 @@ func genPacket(dir, packetName string, api *spec.ApiSpec) error {
} }
func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName string) error { func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName string) error {
packet, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler") packet := route.Handler
packet = strings.Replace(packet, "Handler", "Packet", 1) packet = strings.Replace(packet, "Handler", "Packet", 1)
if !ok { packet = strings.Title(packet)
return fmt.Errorf("missing packet annotation for %q", route.Path) if !strings.HasSuffix(packet, "Packet") {
packet += "Packet"
} }
javaFile := packet + ".java" javaFile := packet + ".java"
@@ -103,80 +67,68 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
} }
defer fp.Close() defer fp.Close()
var builder strings.Builder var hasRequestBody = false
var first bool if route.RequestType != nil {
tps := apiutil.GetLocalTypes(api, route) if defineStruct, ok := route.RequestType.(spec.DefineStruct); ok {
hasRequestBody = len(defineStruct.GetBodyMembers()) > 0 || len(defineStruct.GetFormMembers()) > 0
for _, tp := range tps {
if first {
first = false
} else {
fmt.Fprintln(&builder)
}
if err := genType(&builder, tp); err != nil {
return err
} }
} }
types := builder.String()
writeIndent(&builder, 1)
params := paramsForRoute(route) params := strings.TrimSpace(paramsForRoute(route))
if len(params) > 0 && hasRequestBody {
params += ", "
}
paramsDeclaration := declarationForRoute(route) paramsDeclaration := declarationForRoute(route)
paramsSet := paramsSet(route) paramsSetter := paramsSet(route)
imports := getImports(api, packetName)
if len(route.ResponseTypeName()) == 0 {
imports += fmt.Sprintf("\v%s", "import com.xhb.core.response.EmptyResponse;")
}
t := template.Must(template.New("packetTemplate").Parse(packetTemplate)) t := template.Must(template.New("packetTemplate").Parse(packetTemplate))
var tmplBytes bytes.Buffer var tmplBytes bytes.Buffer
err = t.Execute(&tmplBytes, map[string]string{ err = t.Execute(&tmplBytes, map[string]interface{}{
"packetName": packet, "packetName": packet,
"method": strings.ToUpper(route.Method), "method": strings.ToUpper(route.Method),
"uri": processUri(route), "uri": processUri(route),
"types": strings.TrimSpace(types), "responseType": stringx.TakeOne(util.Title(route.ResponseTypeName()), "EmptyResponse"),
"responseType": stringx.TakeOne(util.Title(route.ResponseType.Name), "Object"),
"params": params, "params": params,
"paramsDeclaration": strings.TrimSpace(paramsDeclaration), "paramsDeclaration": strings.TrimSpace(paramsDeclaration),
"paramsSet": paramsSet, "paramsSetter": paramsSetter,
"packet": packetName, "packet": packetName,
"requestType": util.Title(route.RequestType.Name), "requestType": util.Title(route.RequestTypeName()),
"import": getImports(api, route, packetName), "HasRequestBody": hasRequestBody,
"imports": imports,
"doc": doc(route),
}) })
if err != nil { if err != nil {
return err return err
} }
formatFile(&tmplBytes, fp)
_, err = fp.WriteString(formatSource(tmplBytes.String()))
return nil return nil
} }
func getImports(api *spec.ApiSpec, route spec.Route, packetName string) string { func doc(route spec.Route) string {
var builder strings.Builder comment := route.JoinedDoc()
allTypes := apiutil.GetAllTypes(api, route) if len(comment) > 0 {
sharedTypes := apiutil.GetSharedTypes(api) formatter := `
for _, at := range allTypes { /*
for _, item := range sharedTypes { %s
if item.Name == at.Name { */`
fmt.Fprintf(&builder, "import com.xhb.logic.http.packet.%s.model.%s;\n", packetName, item.Name) return fmt.Sprintf(formatter, comment)
break
}
}
} }
return builder.String() return ""
} }
func formatFile(tmplBytes *bytes.Buffer, file *os.File) { func getImports(api *spec.ApiSpec, packetName string) string {
scanner := bufio.NewScanner(tmplBytes) var builder strings.Builder
builder := bufio.NewWriter(file) allTypes := api.Types
defer builder.Flush() if len(allTypes) > 0 {
preIsBreakLine := false fmt.Fprintf(&builder, "import com.xhb.logic.http.packet.%s.model.*;\n", packetName)
for scanner.Scan() {
text := strings.TrimSpace(scanner.Text())
if text == "" && preIsBreakLine {
continue
}
preIsBreakLine = text == ""
builder.WriteString(scanner.Text() + "\n")
}
if err := scanner.Err(); err != nil {
fmt.Println(err)
} }
return builder.String()
} }
func paramsSet(route spec.Route) string { func paramsSet(route spec.Route) string {
@@ -209,7 +161,7 @@ func paramsForRoute(route spec.Route) string {
builder.WriteString(fmt.Sprintf("String %s, ", cop[1:])) builder.WriteString(fmt.Sprintf("String %s, ", cop[1:]))
} }
} }
return builder.String() return strings.TrimSuffix(builder.String(), ", ")
} }
func declarationForRoute(route spec.Route) string { func declarationForRoute(route spec.Route) string {
@@ -235,6 +187,7 @@ func declarationForRoute(route spec.Route) string {
func processUri(route spec.Route) string { func processUri(route spec.Route) string {
path := route.Path path := route.Path
var builder strings.Builder var builder strings.Builder
cops := strings.Split(path, "/") cops := strings.Split(path, "/")
for index, cop := range cops { for index, cop := range cops {
@@ -255,25 +208,37 @@ func processUri(route spec.Route) string {
result = result[:len(result)-4] result = result[:len(result)-4]
} }
if strings.HasPrefix(result, "/") { if strings.HasPrefix(result, "/") {
result = strings.TrimPrefix(result, "/")
result = "\"" + result result = "\"" + result
} }
return result return result + formString(route)
} }
func genType(writer io.Writer, tp spec.Type) error { func formString(route spec.Route) string {
writeIndent(writer, 1) var keyValues []string
fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name)) if defineStruct, ok := route.RequestType.(spec.DefineStruct); ok {
for _, member := range tp.Members { forms := defineStruct.GetFormMembers()
if err := writeProperty(writer, member, 2); err != nil { for _, item := range forms {
return err name, err := item.GetPropertyName()
if err != nil {
panic(err)
}
strcat := "?"
if len(keyValues) > 0 {
strcat = "&"
}
if item.Type.Name() == "bool" {
name = strings.TrimPrefix(name, "Is")
name = strings.TrimPrefix(name, "is")
keyValues = append(keyValues, fmt.Sprintf(`"%s%s=" + request.is%s()`, strcat, name, strings.Title(name)))
} else {
keyValues = append(keyValues, fmt.Sprintf(`"%s%s=" + request.get%s()`, strcat, name, strings.Title(name)))
}
}
if len(keyValues) > 0 {
return " + " + strings.Join(keyValues, " + ")
} }
} }
return ""
writeBreakline(writer)
writeIndent(writer, 1)
genGetSet(writer, tp, 2)
writeIndent(writer, 1)
fmt.Fprintln(writer, "}")
return nil
} }

View File

@@ -1,52 +1,53 @@
package javagen package javagen
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"strings" "strings"
"text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
const getSetTemplate = `
{{.indent}}{{.decorator}}
{{.indent}}public {{.returnType}} get{{.property}}() {
{{.indent}} return this.{{.propertyValue}};
{{.indent}}}
{{.indent}}public void set{{.property}}({{.type}} {{.propertyValue}}) {
{{.indent}} this.{{.propertyValue}} = {{.propertyValue}};
{{.indent}}}
`
func writeProperty(writer io.Writer, member spec.Member, indent int) error { func writeProperty(writer io.Writer, member spec.Member, indent int) error {
if len(member.Comment) > 0 {
writeIndent(writer, indent)
fmt.Fprint(writer, member.Comment+util.NL)
}
writeIndent(writer, indent) writeIndent(writer, indent)
ty, err := goTypeToJava(member.Type) ty, err := specTypeToJava(member.Type)
ty = strings.Replace(ty, "*", "", 1) ty = strings.Replace(ty, "*", "", 1)
if err != nil { if err != nil {
return err return err
} }
name, err := member.GetPropertyName() name, err := member.GetPropertyName()
if err != nil { if err != nil {
return err return err
} }
_, err = fmt.Fprintf(writer, "private %s %s", ty, name) _, err = fmt.Fprintf(writer, "private %s %s", ty, name)
if err != nil { if err != nil {
return err return err
} }
writeDefaultValue(writer, member)
err = writeDefaultValue(writer, member)
if err != nil {
return err
}
fmt.Fprint(writer, ";\n") fmt.Fprint(writer, ";\n")
return err return err
} }
func writeDefaultValue(writer io.Writer, member spec.Member) error { func writeDefaultValue(writer io.Writer, member spec.Member) error {
switch member.Type { javaType, err := specTypeToJava(member.Type)
case "string": if err != nil {
return err
}
if javaType == "String" {
_, err := fmt.Fprintf(writer, " = \"\"") _, err := fmt.Fprintf(writer, " = \"\"")
return err return err
} }
@@ -67,97 +68,71 @@ func indentString(indent int) string {
return result return result
} }
func writeBreakline(writer io.Writer) { func specTypeToJava(tp spec.Type) (string, error) {
fmt.Fprint(writer, "\n") switch v := tp.(type) {
case spec.DefineStruct:
return util.Title(tp.Name()), nil
case spec.PrimitiveType:
r, ok := primitiveType(tp.Name())
if !ok {
return "", errors.New("unsupported primitive type " + tp.Name())
}
return r, nil
case spec.MapType:
valueType, err := specTypeToJava(v.Value)
if err != nil {
return "", err
}
return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(valueType)), nil
case spec.ArrayType:
if tp.Name() == "[]byte" {
return "byte[]", nil
}
valueType, err := specTypeToJava(v.Value)
if err != nil {
return "", err
}
switch valueType {
case "int":
return "Integer[]", nil
case "long":
return "Long[]", nil
case "float":
return "Float[]", nil
case "double":
return "Double[]", nil
case "boolean":
return "Boolean[]", nil
}
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
case spec.InterfaceType:
return "Object", nil
case spec.PointerType:
return specTypeToJava(v.Type)
}
return "", errors.New("unsupported primitive type " + tp.Name())
} }
func isPrimitiveType(tp string) bool { func primitiveType(tp string) (string, bool) {
switch tp {
case "int", "int32", "int64":
return true
case "float", "float32", "float64":
return true
case "bool":
return true
}
return false
}
func goTypeToJava(tp string) (string, error) {
if len(tp) == 0 {
return "", errors.New("property type empty")
}
if strings.HasPrefix(tp, "*") {
tp = tp[1:]
}
switch tp { switch tp {
case "string": case "string":
return "String", nil return "String", true
case "int64": case "int64", "uint64":
return "long", nil return "long", true
case "int", "int8", "int32": case "int", "int8", "int32", "uint", "uint8", "uint16", "uint32":
return "int", nil return "int", true
case "float", "float32", "float64": case "float", "float32":
return "double", nil return "float", true
case "float64":
return "double", true
case "bool": case "bool":
return "boolean", nil return "boolean", true
} }
if strings.HasPrefix(tp, "[]") {
tys, err := apiutil.DecomposeType(tp) return "", false
if err != nil {
return "", err
}
if len(tys) == 0 {
return "", fmt.Errorf("%s tp parse error", tp)
}
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(tys[0])), nil
} else if strings.HasPrefix(tp, "map") {
tys, err := apiutil.DecomposeType(tp)
if err != nil {
return "", err
}
if len(tys) == 2 {
return "", fmt.Errorf("%s tp parse error", tp)
}
return fmt.Sprintf("java.util.HashMap<String, %s>", util.Title(tys[1])), nil
}
return util.Title(tp), nil
}
func genGetSet(writer io.Writer, tp spec.Type, indent int) error {
t := template.Must(template.New("getSetTemplate").Parse(getSetTemplate))
for _, member := range tp.Members {
var tmplBytes bytes.Buffer
oty, err := goTypeToJava(member.Type)
if err != nil {
return err
}
tyString := oty
decorator := ""
if !isPrimitiveType(member.Type) {
if member.IsOptional() {
decorator = "@org.jetbrains.annotations.Nullable "
} else {
decorator = "@org.jetbrains.annotations.NotNull "
}
tyString = decorator + tyString
}
err = t.Execute(&tmplBytes, map[string]string{
"property": util.Title(member.Name),
"propertyValue": util.Untitle(member.Name),
"type": tyString,
"decorator": decorator,
"returnType": oty,
"indent": indentString(indent),
})
if err != nil {
return err
}
r := tmplBytes.String()
r = strings.Replace(r, " boolean get", " boolean is", 1)
writer.Write([]byte(r))
}
return nil
} }

View File

@@ -21,11 +21,7 @@ func KtCommand(c *cli.Context) error {
return errors.New("missing -pkg") return errors.New("missing -pkg")
} }
p, e := parser.NewParser(apiFile) api, e := parser.Parse(apiFile)
if e != nil {
return e
}
api, e := p.Parse()
if e != nil { if e != nil {
return e return e
} }

View File

@@ -1,6 +1,7 @@
package ktgen package ktgen
import ( import (
"fmt"
"log" "log"
"strings" "strings"
"text/template" "text/template"
@@ -44,7 +45,7 @@ func parseType(t string) string {
} }
if strings.HasPrefix(t, "map") { if strings.HasPrefix(t, "map") {
tys, e := util.DecomposeType(t) tys, e := decomposeType(t)
if e != nil { if e != nil {
log.Fatal(e) log.Fatal(e)
} }
@@ -68,6 +69,47 @@ func parseType(t string) string {
} }
} }
func decomposeType(t string) (result []string, err error) {
add := func(tp string) error {
ret, err := decomposeType(tp)
if err != nil {
return err
}
result = append(result, ret...)
return nil
}
if strings.HasPrefix(t, "map") {
t = strings.ReplaceAll(t, "map", "")
if t[0] == '[' {
pos := strings.Index(t, "]")
if pos > 1 {
if err = add(t[1:pos]); err != nil {
return
}
if len(t) > pos+1 {
err = add(t[pos+1:])
return
}
}
}
} else if strings.HasPrefix(t, "[]") {
if len(t) > 2 {
err = add(t[2:])
return
}
} else if strings.HasPrefix(t, "*") {
err = add(t[1:])
return
} else {
result = append(result, t)
return
}
err = fmt.Errorf("bad type %q", t)
return
}
func add(a, i int) int { func add(a, i int) int {
return a + i return a + i
} }

View File

@@ -126,10 +126,25 @@ func genBase(dir, pkg string, api *spec.ApiSpec) error {
} }
func genApi(dir, pkg string, api *spec.ApiSpec) error { func genApi(dir, pkg string, api *spec.ApiSpec) error {
name := strcase.ToCamel(api.Info.Title + "Api") properties := api.Info.Properties
if properties == nil {
return fmt.Errorf("none properties")
}
title := properties["Title"]
if len(title) == 0 {
return fmt.Errorf("none title")
}
desc := properties["Desc"]
if len(desc) == 0 {
return fmt.Errorf("none desc")
}
name := strcase.ToCamel(title + "Api")
path := filepath.Join(dir, name+".kt") path := filepath.Join(dir, name+".kt")
api.Info.Title = name api.Info.Title = name
api.Info.Desc = pkg api.Info.Desc = desc
e := os.MkdirAll(dir, 0755) e := os.MkdirAll(dir, 0755)
if e != nil { if e != nil {

View File

@@ -1,268 +0,0 @@
package parser
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"strings"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/api/util"
)
const (
tokenInfo = "info"
tokenImport = "import"
tokenType = "type"
tokenService = "service"
tokenServiceAnnotation = "@server"
tokenStruct = "struct"
)
type (
ApiStruct struct {
Info string
Type string
Service string
Imports string
serviceBeginLine int
}
apiFileState interface {
process(api *ApiStruct, token string) (apiFileState, error)
}
apiRootState struct {
*baseState
}
apiInfoState struct {
*baseState
}
apiImportState struct {
*baseState
}
apiTypeState struct {
*baseState
}
apiServiceState struct {
*baseState
}
)
func ParseApi(src string) (*ApiStruct, error) {
var buffer = new(bytes.Buffer)
buffer.WriteString(src)
api := new(ApiStruct)
var lineNumber = api.serviceBeginLine
apiFile := baseState{r: bufio.NewReader(buffer), lineNumber: &lineNumber}
st := apiRootState{&apiFile}
for {
st, err := st.process(api, "")
if err == io.EOF {
return api, nil
}
if err != nil {
return nil, fmt.Errorf("near line: %d, %s", lineNumber, err.Error())
}
if st == nil {
return api, nil
}
}
}
func (s *apiRootState) process(api *ApiStruct, _ string) (apiFileState, error) {
var builder strings.Builder
for {
ch, err := s.readSkipComment()
if err != nil {
return nil, err
}
switch {
case isSpace(ch) || isNewline(ch) || ch == leftParenthesis:
token := builder.String()
token = strings.TrimSpace(token)
if len(token) == 0 {
continue
}
builder.Reset()
switch token {
case tokenInfo:
info := apiInfoState{s.baseState}
return info.process(api, token+string(ch))
case tokenImport:
tp := apiImportState{s.baseState}
return tp.process(api, token+string(ch))
case tokenType:
ty := apiTypeState{s.baseState}
return ty.process(api, token+string(ch))
case tokenService:
server := apiServiceState{s.baseState}
return server.process(api, token+string(ch))
case tokenServiceAnnotation:
server := apiServiceState{s.baseState}
return server.process(api, token+string(ch))
default:
if strings.HasPrefix(token, "//") {
continue
}
return nil, errors.New(fmt.Sprintf("invalid token %s at line %d", token, *s.lineNumber))
}
default:
builder.WriteRune(ch)
}
}
}
func (s *apiInfoState) process(api *ApiStruct, token string) (apiFileState, error) {
for {
line, err := s.readLine()
if err != nil {
return nil, err
}
api.Info += newline + token + line
token = ""
if strings.TrimSpace(line) == string(rightParenthesis) {
return &apiRootState{s.baseState}, nil
}
}
}
func (s *apiImportState) process(api *ApiStruct, token string) (apiFileState, error) {
line, err := s.readLine()
if err != nil {
return nil, err
}
line = token + line
line = util.RemoveComment(line)
if len(strings.Fields(line)) != 2 {
return nil, errors.New("import syntax error: " + line)
}
api.Imports += newline + line
return &apiRootState{s.baseState}, nil
}
func (s *apiTypeState) process(api *ApiStruct, token string) (apiFileState, error) {
var blockCount = 0
var braceCount = 0
for {
line, err := s.readLine()
if err != nil {
return nil, err
}
line = token + line
if braceCount == 0 {
line = mayInsertStructKeyword(line)
}
api.Type += newline + newline + line
line = strings.TrimSpace(line)
line = util.RemoveComment(line)
token = ""
if strings.HasSuffix(line, leftBrace) {
blockCount++
braceCount++
}
if strings.HasSuffix(line, string(leftParenthesis)) {
blockCount++
}
if strings.HasSuffix(line, string(rightBrace)) {
blockCount--
braceCount--
}
if strings.HasSuffix(line, string(rightParenthesis)) {
blockCount--
}
if braceCount >= 2 {
return nil, errors.New("nested type not supported: " + line)
}
if braceCount < 0 {
line = strings.TrimSuffix(line, string(rightBrace))
line = strings.TrimSpace(line)
if strings.HasSuffix(line, leftBrace) {
blockCount++
braceCount++
}
}
if blockCount == 0 {
return &apiRootState{s.baseState}, nil
}
}
}
func (s *apiServiceState) process(api *ApiStruct, token string) (apiFileState, error) {
var blockCount = 0
for {
line, err := s.readLineSkipComment()
if err != nil {
return nil, err
}
line = token + line
token = ""
api.Service += newline + line
line = strings.TrimSpace(line)
line = util.RemoveComment(line)
if strings.HasSuffix(line, leftBrace) {
blockCount++
}
if strings.HasSuffix(line, string(leftParenthesis)) {
blockCount++
}
if line == string(rightBrace) {
blockCount--
}
if line == string(rightParenthesis) {
blockCount--
}
if blockCount == 0 {
return &apiRootState{s.baseState}, nil
}
}
}
func mayInsertStructKeyword(line string) string {
line = util.RemoveComment(line)
if !strings.HasSuffix(line, leftBrace) && !strings.HasSuffix(line, string(rightBrace)) {
return line
}
fields := strings.Fields(line)
if stringx.Contains(fields, tokenStruct) ||
stringx.Contains(fields, tokenStruct+leftBrace) ||
stringx.Contains(fields, tokenStruct+leftBrace+string(rightBrace)) ||
len(fields) <= 1 {
return line
}
var insertIndex int
if fields[0] == tokenType {
insertIndex = 2
} else {
insertIndex = 1
}
if insertIndex >= len(fields) {
return line
}
var result []string
result = append(result, fields[:insertIndex]...)
result = append(result, tokenStruct)
result = append(result, fields[insertIndex:]...)
return strings.Join(result, " ")
}

View File

@@ -1,236 +0,0 @@
package parser
import (
"bufio"
"fmt"
"strings"
)
const (
startState = iota
attrNameState
attrValueState
attrColonState
multilineState
)
type baseState struct {
r *bufio.Reader
lineNumber *int
}
func newBaseState(r *bufio.Reader, lineNumber *int) *baseState {
return &baseState{
r: r,
lineNumber: lineNumber,
}
}
func (s *baseState) parseProperties() (map[string]string, error) {
var r = s.r
var attributes = make(map[string]string)
var builder strings.Builder
var key string
var st = startState
for {
ch, err := s.readSkipComment()
if err != nil {
return nil, err
}
switch st {
case startState:
switch {
case isNewline(ch):
return nil, fmt.Errorf("%q should be on the same line with %q", leftParenthesis, infoDirective)
case isSpace(ch):
continue
case ch == leftParenthesis:
st = attrNameState
default:
return nil, fmt.Errorf("unexpected char %q after %q", ch, infoDirective)
}
case attrNameState:
switch {
case isNewline(ch):
if builder.Len() > 0 {
return nil, fmt.Errorf("unexpected newline after %q", builder.String())
}
case isLetterDigit(ch):
builder.WriteRune(ch)
case isSpace(ch):
if builder.Len() > 0 {
key = builder.String()
builder.Reset()
st = attrColonState
}
case ch == colon:
if builder.Len() == 0 {
return nil, fmt.Errorf("unexpected leading %q", ch)
}
key = builder.String()
builder.Reset()
st = attrValueState
case ch == rightParenthesis:
return attributes, nil
}
case attrColonState:
switch {
case isSpace(ch):
continue
case ch == colon:
st = attrValueState
default:
return nil, fmt.Errorf("bad char %q after %q in %q", ch, key, infoDirective)
}
case attrValueState:
switch {
case ch == multilineBeginTag:
if builder.Len() > 0 {
return nil, fmt.Errorf("%q before %q", builder.String(), multilineBeginTag)
} else {
st = multilineState
}
case isSpace(ch):
if builder.Len() > 0 {
builder.WriteRune(ch)
}
case isNewline(ch):
attributes[key] = builder.String()
builder.Reset()
st = attrNameState
case ch == rightParenthesis:
attributes[key] = builder.String()
builder.Reset()
return attributes, nil
default:
builder.WriteRune(ch)
}
case multilineState:
switch {
case ch == multilineEndTag:
attributes[key] = builder.String()
builder.Reset()
st = attrNameState
case isNewline(ch):
var multipleNewlines bool
loopAfterNewline:
for {
next, err := read(r)
if err != nil {
return nil, err
}
switch {
case isSpace(next):
continue
case isNewline(next):
multipleNewlines = true
default:
if err := unread(r); err != nil {
return nil, err
}
break loopAfterNewline
}
}
if multipleNewlines {
fmt.Fprintln(&builder)
} else {
builder.WriteByte(' ')
}
case ch == rightParenthesis:
if builder.Len() > 0 {
attributes[key] = builder.String()
builder.Reset()
}
return attributes, nil
default:
builder.WriteRune(ch)
}
}
}
}
func (s *baseState) read() (rune, error) {
value, err := read(s.r)
if err != nil {
return 0, err
}
if isNewline(value) {
*s.lineNumber++
}
return value, nil
}
func (s *baseState) readSkipComment() (rune, error) {
ch, err := s.read()
if err != nil {
return 0, err
}
if isSlash(ch) {
value, err := s.mayReadToEndOfLine()
if err != nil {
return 0, err
}
if value > 0 {
ch = value
}
}
return ch, nil
}
func (s *baseState) mayReadToEndOfLine() (rune, error) {
ch, err := s.read()
if err != nil {
return 0, err
}
if isSlash(ch) {
for {
value, err := s.read()
if err != nil {
return 0, err
}
if isNewline(value) {
return value, nil
}
}
}
err = s.unread()
return 0, err
}
func (s *baseState) readLineSkipComment() (string, error) {
line, err := s.readLine()
if err != nil {
return "", err
}
var commentIdx = strings.Index(line, "//")
if commentIdx >= 0 {
return line[:commentIdx], nil
}
return line, nil
}
func (s *baseState) readLine() (string, error) {
line, _, err := s.r.ReadLine()
if err != nil {
return "", err
}
*s.lineNumber++
return string(line), nil
}
func (s *baseState) skipSpaces() error {
return skipSpaces(s.r)
}
func (s *baseState) unread() error {
return unread(s.r)
}

View File

@@ -1,20 +0,0 @@
package parser
import (
"bufio"
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProperties(t *testing.T) {
const text = `(summary: hello world)`
var builder bytes.Buffer
builder.WriteString(text)
var lineNumber = 1
var state = newBaseState(bufio.NewReader(&builder), &lineNumber)
m, err := state.parseProperties()
assert.Nil(t, err)
assert.Equal(t, "hello world", m["summary"])
}

View File

@@ -1,146 +0,0 @@
package parser
import (
"errors"
"fmt"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
)
type (
entity struct {
state *baseState
api *spec.ApiSpec
parser entityParser
}
entityParser interface {
parseLine(line string, api *spec.ApiSpec, annos []spec.Annotation) error
setEntityName(name string)
}
)
func newEntity(state *baseState, api *spec.ApiSpec, parser entityParser) entity {
return entity{
state: state,
api: api,
parser: parser,
}
}
func (s *entity) process() error {
line, err := s.state.readLineSkipComment()
if err != nil {
return err
}
fields := strings.Fields(line)
if len(fields) < 2 {
return fmt.Errorf("invalid type definition for %q",
strings.TrimSpace(strings.Trim(string(line), "{")))
}
if len(fields) == 2 {
if fields[1] != leftBrace {
return fmt.Errorf("bad string %q after type", fields[1])
}
} else if len(fields) == 3 {
if fields[1] != typeStruct {
return fmt.Errorf("bad string %q after type", fields[1])
}
if fields[2] != leftBrace {
return fmt.Errorf("bad string %q after type", fields[2])
}
}
s.parser.setEntityName(fields[0])
var annos []spec.Annotation
memberLoop:
for {
ch, err := s.state.readSkipComment()
if err != nil {
return err
}
var annoName string
var builder strings.Builder
switch {
case ch == at:
annotationLoop:
for {
next, err := s.state.readSkipComment()
if err != nil {
return err
}
switch {
case isSpace(next):
if builder.Len() > 0 && annoName == "" {
annoName = builder.String()
builder.Reset()
}
case isNewline(next):
if builder.Len() == 0 {
return errors.New("invalid annotation format")
}
if len(annoName) > 0 {
value := builder.String()
if value != string(leftParenthesis) {
builder.Reset()
annos = append(annos, spec.Annotation{
Name: annoName,
Value: value,
})
annoName = ""
break annotationLoop
}
}
case next == leftParenthesis:
if builder.Len() == 0 {
return errors.New("invalid annotation format")
}
annoName = builder.String()
builder.Reset()
if err := s.state.unread(); err != nil {
return err
}
attrs, err := s.state.parseProperties()
if err != nil {
return err
}
annos = append(annos, spec.Annotation{
Name: annoName,
Properties: attrs,
})
annoName = ""
break annotationLoop
default:
builder.WriteRune(next)
}
}
case ch == rightBrace:
break memberLoop
case isLetterDigit(ch):
if err := s.state.unread(); err != nil {
return err
}
var line string
line, err = s.state.readLineSkipComment()
if err != nil {
return err
}
line = strings.TrimSpace(line)
if err := s.parser.parseLine(line, s.api, annos); err != nil {
return err
}
annos = nil
}
}
return nil
}

View File

@@ -0,0 +1,46 @@
lexer grammar ApiLexer;
// Keywords
ATDOC: '@doc';
ATHANDLER: '@handler';
INTERFACE: 'interface{}';
ATSERVER: '@server';
// Whitespace and comments
WS: [ \t\r\n\u000C]+ -> channel(HIDDEN);
COMMENT: '/*' .*? '*/' -> channel(88);
LINE_COMMENT: '//' ~[\r\n]* -> channel(88);
STRING: '"' (~["\\] | EscapeSequence)* '"';
RAW_STRING: '`' (~[`\\\r\n] | EscapeSequence)+ '`';
LINE_VALUE: ':' [ \t]* (STRING|(~[\r\n"`]*));
ID: Letter LetterOrDigit*;
fragment ExponentPart
: [eE] [+-]? Digits
;
fragment EscapeSequence
: '\\' [btnfr"'\\]
| '\\' ([0-3]? [0-7])? [0-7]
| '\\' 'u'+ HexDigit HexDigit HexDigit HexDigit
;
fragment HexDigits
: HexDigit ((HexDigit | '_')* HexDigit)?
;
fragment HexDigit
: [0-9a-fA-F]
;
fragment Digits
: [0-9] ([0-9_]* [0-9])?
;
fragment LetterOrDigit
: Letter
| [0-9]
;
fragment Letter
: [a-zA-Z$_] // these are the "java letters" below 0x7F
| ~[\u0000-\u007F\uD800-\uDBFF] // covers all characters above 0x7F which are not a surrogate
| [\uD800-\uDBFF] [\uDC00-\uDFFF] // covers UTF-16 surrogate pairs encodings for U+10000 to U+10FFFF
;

View File

@@ -0,0 +1,73 @@
grammar ApiParser;
import ApiLexer;
@lexer::members{
const COMEMNTS = 88
}
api: spec*;
spec: syntaxLit
|importSpec
|infoSpec
|typeSpec
|serviceSpec
;
// syntax
syntaxLit: {match(p,"syntax")}syntaxToken=ID assign='=' {checkVersion(p)}version=STRING;
// import
importSpec: importLit|importBlock;
importLit: {match(p,"import")}importToken=ID importValue ;
importBlock: {match(p,"import")}importToken=ID '(' importBlockValue+ ')';
importBlockValue: importValue;
importValue: {checkImportValue(p)}STRING;
// info
infoSpec: {match(p,"info")}infoToken=ID lp='(' kvLit+ rp=')';
// type
typeSpec: typeLit
|typeBlock;
// eg: type Foo int
typeLit: {match(p,"type")}typeToken=ID typeLitBody;
// eg: type (...)
typeBlock: {match(p,"type")}typeToken=ID lp='(' typeBlockBody* rp=')';
typeLitBody: typeStruct|typeAlias;
typeBlockBody: typeBlockStruct|typeBlockAlias;
typeStruct: {checkKeyword(p)}structName=ID structToken=ID? lbrace='{' field* rbrace='}';
typeAlias: {checkKeyword(p)}alias=ID assign='='? dataType;
typeBlockStruct: {checkKeyword(p)}structName=ID structToken=ID? lbrace='{' field* rbrace='}';
typeBlockAlias: {checkKeyword(p)}alias=ID assign='='? dataType;
field: {isNormal(p)}? normalField|anonymousFiled ;
normalField: {checkKeyword(p)}fieldName=ID dataType tag=RAW_STRING?;
anonymousFiled: star='*'? ID;
dataType: {isInterface(p)}ID
|mapType
|arrayType
|inter='interface{}'
|time='time.Time'
|pointerType
|typeStruct
;
pointerType: star='*' {checkKeyword(p)}ID;
mapType: {match(p,"map")}mapToken=ID lbrack='[' {checkKey(p)}key=ID rbrack=']' value=dataType;
arrayType: lbrack='[' rbrack=']' dataType;
// service
serviceSpec: atServer? serviceApi;
atServer: ATSERVER lp='(' kvLit+ rp=')';
serviceApi: {match(p,"service")}serviceToken=ID serviceName lbrace='{' serviceRoute* rbrace='}';
serviceRoute: atDoc? (atServer|atHandler) route;
atDoc: ATDOC lp='('? ((kvLit+)|STRING) rp=')'?;
atHandler: ATHANDLER ID;
route: {checkHttpMethod(p)}httpMethod=ID path request=body? returnToken=ID? response=replybody?;
body: lp='(' (ID)? rp=')';
replybody: lp='(' dataType? rp=')';
// kv
kvLit: key=ID {checkKeyValue(p)}value=LINE_VALUE;
serviceName: (ID '-'?)+;
path: (('/' (ID ('-' ID)*))|('/:' (ID ('-' ID)?)))+;

View File

@@ -0,0 +1,251 @@
package ast
import (
"fmt"
"sort"
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
)
type Api struct {
LinePrefix string
Syntax *SyntaxExpr
Import []*ImportExpr
importM map[string]PlaceHolder
Info *InfoExpr
Type []TypeExpr
typeM map[string]PlaceHolder
Service []*Service
serviceM map[string]PlaceHolder
handlerM map[string]PlaceHolder
routeM map[string]PlaceHolder
}
func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
defer func() {
if p := recover(); p != nil {
panic(fmt.Errorf("%+v", p))
}
}()
var final Api
final.importM = map[string]PlaceHolder{}
final.typeM = map[string]PlaceHolder{}
final.serviceM = map[string]PlaceHolder{}
final.handlerM = map[string]PlaceHolder{}
final.routeM = map[string]PlaceHolder{}
for _, each := range ctx.AllSpec() {
root := each.Accept(v).(*Api)
if root.Syntax != nil {
if final.Syntax != nil {
v.panic(root.Syntax.Syntax, fmt.Sprintf("mutiple syntax declaration"))
}
final.Syntax = root.Syntax
}
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
}
final.importM[imp.Value.Text()] = Holder
final.Import = append(final.Import, imp)
}
if root.Info != nil {
infoM := map[string]PlaceHolder{}
if final.Info != nil {
v.panic(root.Info.Info, fmt.Sprintf("mutiple info declaration"))
}
for _, value := range root.Info.Kvs {
if _, ok := infoM[value.Key.Text()]; ok {
v.panic(value.Key, fmt.Sprintf("duplicate key '%s'", value.Key.Text()))
}
infoM[value.Key.Text()] = Holder
}
final.Info = root.Info
}
for _, tp := range root.Type {
if _, ok := final.typeM[tp.NameExpr().Text()]; ok {
v.panic(tp.NameExpr(), fmt.Sprintf("duplicate type '%s'", tp.NameExpr().Text()))
}
final.typeM[tp.NameExpr().Text()] = Holder
final.Type = append(final.Type, tp)
}
for _, service := range root.Service {
if _, ok := final.serviceM[service.ServiceApi.Name.Text()]; !ok && len(final.serviceM) > 0 {
v.panic(service.ServiceApi.Name, fmt.Sprintf("mutiple service declaration"))
}
if service.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range service.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
}
}
for _, route := range service.ServiceApi.ServiceRoute {
uniqueRoute := fmt.Sprintf("%s %s", route.Route.Method.Text(), route.Route.Path.Text())
if _, ok := final.routeM[uniqueRoute]; ok {
v.panic(route.Route.Method, fmt.Sprintf("duplicate route '%s'", uniqueRoute))
}
final.routeM[uniqueRoute] = Holder
var handlerExpr Expr
if route.AtServer != nil {
atServerM := map[string]PlaceHolder{}
for _, kv := range route.AtServer.Kv {
if _, ok := atServerM[kv.Key.Text()]; ok {
v.panic(kv.Key, fmt.Sprintf("duplicate key '%s'", kv.Key.Text()))
}
atServerM[kv.Key.Text()] = Holder
if kv.Key.Text() == "handler" {
handlerExpr = kv.Value
}
}
}
if route.AtHandler != nil {
handlerExpr = route.AtHandler.Name
}
if handlerExpr == nil {
v.panic(route.Route.Method, fmt.Sprintf("mismtached handler"))
}
if handlerExpr.Text() == "" {
v.panic(handlerExpr, fmt.Sprintf("mismtached handler"))
}
if _, ok := final.handlerM[handlerExpr.Text()]; ok {
v.panic(handlerExpr, fmt.Sprintf("duplicate handler '%s'", handlerExpr.Text()))
}
final.handlerM[handlerExpr.Text()] = Holder
}
final.Service = append(final.Service, service)
}
}
return &final
}
func (v *ApiVisitor) VisitSpec(ctx *api.SpecContext) interface{} {
var root Api
if ctx.SyntaxLit() != nil {
root.Syntax = ctx.SyntaxLit().Accept(v).(*SyntaxExpr)
}
if ctx.ImportSpec() != nil {
root.Import = ctx.ImportSpec().Accept(v).([]*ImportExpr)
}
if ctx.InfoSpec() != nil {
root.Info = ctx.InfoSpec().Accept(v).(*InfoExpr)
}
if ctx.TypeSpec() != nil {
tp := ctx.TypeSpec().Accept(v)
root.Type = tp.([]TypeExpr)
}
if ctx.ServiceSpec() != nil {
root.Service = []*Service{ctx.ServiceSpec().Accept(v).(*Service)}
}
return &root
}
func (a *Api) Format() error {
// todo
return nil
}
func (a *Api) Equal(v interface{}) bool {
if v == nil {
return false
}
root, ok := v.(*Api)
if !ok {
return false
}
if !a.Syntax.Equal(root.Syntax) {
return false
}
if len(a.Import) != len(root.Import) {
return false
}
var expectingImport, actualImport []*ImportExpr
expectingImport = append(expectingImport, a.Import...)
actualImport = append(actualImport, root.Import...)
sort.Slice(expectingImport, func(i, j int) bool {
return expectingImport[i].Value.Text() < expectingImport[j].Value.Text()
})
sort.Slice(actualImport, func(i, j int) bool {
return actualImport[i].Value.Text() < actualImport[j].Value.Text()
})
for index, each := range expectingImport {
ac := actualImport[index]
if !each.Equal(ac) {
return false
}
}
if !a.Info.Equal(root.Info) {
return false
}
if len(a.Type) != len(root.Type) {
return false
}
var expectingType, actualType []TypeExpr
expectingType = append(expectingType, a.Type...)
actualType = append(actualType, root.Type...)
sort.Slice(expectingType, func(i, j int) bool {
return expectingType[i].NameExpr().Text() < expectingType[j].NameExpr().Text()
})
sort.Slice(actualType, func(i, j int) bool {
return actualType[i].NameExpr().Text() < actualType[j].NameExpr().Text()
})
for index, each := range expectingType {
ac := actualType[index]
if !each.Equal(ac) {
return false
}
}
if len(a.Service) != len(root.Service) {
return false
}
var expectingService, actualService []*Service
expectingService = append(expectingService, a.Service...)
actualService = append(actualService, root.Service...)
for index, each := range expectingService {
ac := actualService[index]
if !each.Equal(ac) {
return false
}
}
return true
}

View File

@@ -0,0 +1,405 @@
package ast
import (
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
type (
Parser struct {
linePrefix string
debug bool
log console.Console
antlr.DefaultErrorListener
}
ParserOption func(p *Parser)
)
func NewParser(options ...ParserOption) *Parser {
p := &Parser{
log: console.NewColorConsole(),
}
for _, opt := range options {
opt(p)
}
return p
}
// Accept can parse any terminalNode of api tree by fn.
// -- for debug
func (p *Parser) Accept(fn func(p *api.ApiParserParser, visitor *ApiVisitor) interface{}, content string) (v interface{}, err error) {
defer func() {
p := recover()
if p != nil {
switch e := p.(type) {
case error:
err = e
default:
err = fmt.Errorf("%+v", p)
}
}
}()
inputStream := antlr.NewInputStream(content)
lexer := api.NewApiParserLexer(inputStream)
lexer.RemoveErrorListeners()
tokens := antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel)
apiParser := api.NewApiParserParser(tokens)
apiParser.RemoveErrorListeners()
apiParser.AddErrorListener(p)
var visitorOptions []VisitorOption
visitorOptions = append(visitorOptions, WithVisitorPrefix(p.linePrefix))
if p.debug {
visitorOptions = append(visitorOptions, WithVisitorDebug())
}
visitor := NewApiVisitor(visitorOptions...)
v = fn(apiParser, visitor)
return
}
// Parse is used to parse the api from the specified file name
func (p *Parser) Parse(filename string) (*Api, error) {
data, err := p.readContent(filename)
if err != nil {
return nil, err
}
return p.parse(filename, data)
}
// ParseContent is used to parse the api from the specified content
func (p *Parser) ParseContent(content string) (*Api, error) {
return p.parse("", content)
}
// parse is used to parse api from the content
// filename is only used to mark the file where the error is located
func (p *Parser) parse(filename, content string) (*Api, error) {
root, err := p.invoke(filename, content)
if err != nil {
return nil, err
}
var apiAstList []*Api
apiAstList = append(apiAstList, root)
for _, imp := range root.Import {
path := imp.Value.Text()
data, err := p.readContent(path)
if err != nil {
return nil, err
}
nestedApi, err := p.invoke(path, data)
if err != nil {
return nil, err
}
err = p.valid(root, nestedApi)
if err != nil {
return nil, err
}
apiAstList = append(apiAstList, nestedApi)
}
err = p.checkTypeDeclaration(apiAstList)
if err != nil {
return nil, err
}
allApi := p.memberFill(apiAstList)
return allApi, nil
}
func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
defer func() {
p := recover()
if p != nil {
switch e := p.(type) {
case error:
err = e
default:
err = fmt.Errorf("%+v", p)
}
}
}()
if linePrefix != "" {
p.linePrefix = linePrefix
}
inputStream := antlr.NewInputStream(content)
lexer := api.NewApiParserLexer(inputStream)
lexer.RemoveErrorListeners()
tokens := antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel)
apiParser := api.NewApiParserParser(tokens)
apiParser.RemoveErrorListeners()
apiParser.AddErrorListener(p)
var visitorOptions []VisitorOption
visitorOptions = append(visitorOptions, WithVisitorPrefix(p.linePrefix))
if p.debug {
visitorOptions = append(visitorOptions, WithVisitorDebug())
}
visitor := NewApiVisitor(visitorOptions...)
v = apiParser.Api().Accept(visitor).(*Api)
v.LinePrefix = p.linePrefix
return
}
func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
if len(nestedApi.Import) > 0 {
importToken := nestedApi.Import[0].Import
return fmt.Errorf("%s line %d:%d the nested api does not support import",
nestedApi.LinePrefix, importToken.Line(), importToken.Column())
}
if mainApi.Syntax != nil && nestedApi.Syntax != nil {
if mainApi.Syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text())
}
}
if len(mainApi.Service) > 0 {
mainService := mainApi.Service[0]
for _, service := range nestedApi.Service {
if mainService.ServiceApi.Name.Text() != service.ServiceApi.Name.Text() {
return fmt.Errorf("%s multiple service name declaration, expecting service name '%s', but found '%s'",
nestedApi.LinePrefix, mainService.ServiceApi.Name.Text(), service.ServiceApi.Name.Text())
}
}
}
mainHandlerMap := make(map[string]PlaceHolder)
mainRouteMap := make(map[string]PlaceHolder)
mainTypeMap := make(map[string]PlaceHolder)
routeMap := func(list []*ServiceRoute) (map[string]PlaceHolder, map[string]PlaceHolder) {
handlerMap := make(map[string]PlaceHolder)
routeMap := make(map[string]PlaceHolder)
for _, g := range list {
handler := g.GetHandler()
if handler.IsNotNil() {
var handlerName = handler.Text()
handlerMap[handlerName] = Holder
path := fmt.Sprintf("%s://%s", g.Route.Method.Text(), g.Route.Path.Text())
routeMap[path] = Holder
}
}
return handlerMap, routeMap
}
for _, each := range mainApi.Service {
h, r := routeMap(each.ServiceApi.ServiceRoute)
for k, v := range h {
mainHandlerMap[k] = v
}
for k, v := range r {
mainRouteMap[k] = v
}
}
for _, each := range mainApi.Type {
mainTypeMap[each.NameExpr().Text()] = Holder
}
// duplicate route check
for _, each := range nestedApi.Service {
for _, r := range each.ServiceApi.ServiceRoute {
handler := r.GetHandler()
if !handler.IsNotNil() {
return fmt.Errorf("%s handler not exist near line %d", nestedApi.LinePrefix, r.Route.Method.Line())
}
if _, ok := mainHandlerMap[handler.Text()]; ok {
return fmt.Errorf("%s line %d:%d duplicate handler '%s'",
nestedApi.LinePrefix, handler.Line(), handler.Column(), handler.Text())
}
path := fmt.Sprintf("%s://%s", r.Route.Method.Text(), r.Route.Path.Text())
if _, ok := mainRouteMap[path]; ok {
return fmt.Errorf("%s line %d:%d duplicate route '%s'",
nestedApi.LinePrefix, r.Route.Method.Line(), r.Route.Method.Column(), r.Route.Method.Text()+" "+r.Route.Path.Text())
}
}
}
// duplicate type check
for _, each := range nestedApi.Type {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok {
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text())
}
}
return nil
}
func (p *Parser) memberFill(apiList []*Api) *Api {
var root Api
for index, each := range apiList {
if index == 0 {
root.Syntax = each.Syntax
root.Info = each.Info
root.Import = each.Import
}
root.Type = append(root.Type, each.Type...)
root.Service = append(root.Service, each.Service...)
}
return &root
}
// checkTypeDeclaration checks whether a struct type has been declared in context
func (p *Parser) checkTypeDeclaration(apiList []*Api) error {
types := make(map[string]TypeExpr)
for _, root := range apiList {
for _, each := range root.Type {
types[each.NameExpr().Text()] = each
}
}
for _, apiItem := range apiList {
linePrefix := apiItem.LinePrefix
for _, each := range apiItem.Type {
tp, ok := each.(*TypeStruct)
if !ok {
continue
}
for _, member := range tp.Fields {
err := p.checkType(linePrefix, types, member.DataType)
if err != nil {
return err
}
}
}
for _, service := range apiItem.Service {
for _, each := range service.ServiceApi.ServiceRoute {
route := each.Route
if route.Req != nil && route.Req.Name.IsNotNil() && route.Req.Name.Expr().IsNotNil() {
_, ok := types[route.Req.Name.Expr().Text()]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Req.Name.Expr().Line(), route.Req.Name.Expr().Column(), route.Req.Name.Expr().Text())
}
}
if route.Reply != nil && route.Reply.Name.IsNotNil() && route.Reply.Name.Expr().IsNotNil() {
reply := route.Reply.Name
var structName string
switch tp := reply.(type) {
case *Literal:
structName = tp.Literal.Text()
case *Array:
switch innerTp := tp.Literal.(type) {
case *Literal:
structName = innerTp.Literal.Text()
case *Pointer:
structName = innerTp.Name.Text()
}
}
if api.IsBasicType(structName) {
continue
}
_, ok := types[structName]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, route.Reply.Name.Expr().Line(), route.Reply.Name.Expr().Column(), structName)
}
}
}
}
}
return nil
}
func (p *Parser) checkType(linePrefix string, types map[string]TypeExpr, expr DataType) error {
if expr == nil {
return nil
}
switch v := expr.(type) {
case *Literal:
name := v.Literal.Text()
if api.IsBasicType(name) {
return nil
}
_, ok := types[name]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, v.Literal.Line(), v.Literal.Column(), name)
}
case *Pointer:
name := v.Name.Text()
if api.IsBasicType(name) {
return nil
}
_, ok := types[name]
if !ok {
return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
linePrefix, v.Name.Line(), v.Name.Column(), name)
}
case *Map:
return p.checkType(linePrefix, types, v.Value)
case *Array:
return p.checkType(linePrefix, types, v.Literal)
default:
return nil
}
return nil
}
func (p *Parser) readContent(filename string) (string, error) {
filename = strings.ReplaceAll(filename, `"`, "")
abs, err := filepath.Abs(filename)
if err != nil {
return "", err
}
data, err := ioutil.ReadFile(abs)
if err != nil {
return "", err
}
return string(data), nil
}
func (p *Parser) SyntaxError(_ antlr.Recognizer, _ interface{}, line, column int, msg string, _ antlr.RecognitionException) {
str := fmt.Sprintf(`%s line %d:%d %s`, p.linePrefix, line, column, msg)
if p.debug {
p.log.Error(str)
}
panic(str)
}
func WithParserDebug() ParserOption {
return func(p *Parser) {
p.debug = true
}
}
func WithParserPrefix(prefix string) ParserOption {
return func(p *Parser) {
p.linePrefix = prefix
}
}

Some files were not shown because too many files have changed in this diff Show More