mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-16 19:28:18 +08:00
Compare commits
11 Commits
tools/goct
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b5e7b1c26 | ||
|
|
4ad4fd43b7 | ||
|
|
3f3dca6f72 | ||
|
|
4ea0d9e9ea | ||
|
|
52d2bdadcd | ||
|
|
3738be1945 | ||
|
|
5b74b9ab7b | ||
|
|
4a67261b7b | ||
|
|
22bdae0787 | ||
|
|
e8675d6a9a | ||
|
|
e441c44975 |
2
.github/workflows/go.yml
vendored
2
.github/workflows/go.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
- name: Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
uses: codecov/codecov-action@v6
|
||||
with:
|
||||
files: ./coverage.txt
|
||||
flags: unittests
|
||||
|
||||
@@ -263,14 +263,24 @@ func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []
|
||||
for _, ev := range events {
|
||||
switch ev.Type {
|
||||
case clientv3.EventTypePut:
|
||||
evKey := string(ev.Kv.Key)
|
||||
evVal := string(ev.Kv.Value)
|
||||
c.lock.Lock()
|
||||
watcher.values[string(ev.Kv.Key)] = string(ev.Kv.Value)
|
||||
oldVal, exists := watcher.values[evKey]
|
||||
watcher.values[evKey] = evVal
|
||||
c.lock.Unlock()
|
||||
if exists && oldVal == evVal {
|
||||
// duplicate PUT with same value, skip to prevent unbounded growth
|
||||
continue
|
||||
}
|
||||
if exists {
|
||||
// key moved to a new value, notify delete of old entry first
|
||||
for _, l := range listeners {
|
||||
l.OnDelete(KV{Key: evKey, Val: oldVal})
|
||||
}
|
||||
}
|
||||
for _, l := range listeners {
|
||||
l.OnAdd(KV{
|
||||
Key: string(ev.Kv.Key),
|
||||
Val: string(ev.Kv.Value),
|
||||
})
|
||||
l.OnAdd(KV{Key: evKey, Val: evVal})
|
||||
}
|
||||
case clientv3.EventTypeDelete:
|
||||
c.lock.Lock()
|
||||
@@ -433,7 +443,7 @@ func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.C
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cli.Ctx())
|
||||
|
||||
|
||||
c.lock.Lock()
|
||||
if watcher, ok := c.watchers[key]; ok {
|
||||
watcher.cancel = cancel
|
||||
|
||||
@@ -517,7 +517,7 @@ func TestCluster_ConcurrentMonitor(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
key := keys[idx%len(keys)]
|
||||
|
||||
|
||||
if idx%2 == 0 {
|
||||
// Half the goroutines add listeners (write operation)
|
||||
c.addListener(key, &mockListener{})
|
||||
@@ -543,6 +543,50 @@ func TestCluster_ConcurrentMonitor(t *testing.T) {
|
||||
close(c.done)
|
||||
}
|
||||
|
||||
func TestCluster_handleWatchEvents_DuplicatePut(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
listener := NewMockUpdateListener(ctrl)
|
||||
// OnAdd must be called exactly once despite two PUT events with the same key+value.
|
||||
listener.EXPECT().OnAdd(KV{Key: "hello", Val: "world"}).Times(1)
|
||||
|
||||
c := newCluster([]string{"any"})
|
||||
key := watchKey{key: "any"}
|
||||
c.watchers[key] = &watchValue{
|
||||
listeners: []UpdateListener{listener},
|
||||
values: make(map[string]string),
|
||||
}
|
||||
events := []*clientv3.Event{
|
||||
{Type: clientv3.EventTypePut, Kv: &mvccpb.KeyValue{Key: []byte("hello"), Value: []byte("world")}},
|
||||
{Type: clientv3.EventTypePut, Kv: &mvccpb.KeyValue{Key: []byte("hello"), Value: []byte("world")}},
|
||||
}
|
||||
c.handleWatchEvents(context.Background(), key, events)
|
||||
}
|
||||
|
||||
func TestCluster_handleWatchEvents_ValueChange(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
listener := NewMockUpdateListener(ctrl)
|
||||
gomock.InOrder(
|
||||
listener.EXPECT().OnAdd(KV{Key: "hello", Val: "world1"}),
|
||||
listener.EXPECT().OnDelete(KV{Key: "hello", Val: "world1"}),
|
||||
listener.EXPECT().OnAdd(KV{Key: "hello", Val: "world2"}),
|
||||
)
|
||||
|
||||
c := newCluster([]string{"any"})
|
||||
key := watchKey{key: "any"}
|
||||
c.watchers[key] = &watchValue{
|
||||
listeners: []UpdateListener{listener},
|
||||
values: make(map[string]string),
|
||||
}
|
||||
c.handleWatchEvents(context.Background(), key, []*clientv3.Event{
|
||||
{Type: clientv3.EventTypePut, Kv: &mvccpb.KeyValue{Key: []byte("hello"), Value: []byte("world1")}},
|
||||
})
|
||||
c.handleWatchEvents(context.Background(), key, []*clientv3.Event{
|
||||
{Type: clientv3.EventTypePut, Kv: &mvccpb.KeyValue{Key: []byte("hello"), Value: []byte("world2")}},
|
||||
})
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
}
|
||||
|
||||
|
||||
@@ -141,12 +141,23 @@ func (c *container) addKv(key, value string) ([]string, bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
oldVal, alreadyMapped := c.mapping[key]
|
||||
if alreadyMapped && oldVal == value {
|
||||
// duplicate PUT with same key+value, nothing to do
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.dirty.Set(true)
|
||||
if alreadyMapped {
|
||||
// key moved to a new value; remove stale entry to prevent leak
|
||||
c.doRemoveKey(key)
|
||||
}
|
||||
|
||||
keys := c.values[value]
|
||||
previous := append([]string(nil), keys...)
|
||||
early := len(keys) > 0
|
||||
if c.exclusive && early {
|
||||
for _, each := range keys {
|
||||
for _, each := range previous {
|
||||
c.doRemoveKey(each)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,6 +201,81 @@ func TestContainer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainer_DuplicateAdd(t *testing.T) {
|
||||
c := newContainer(false)
|
||||
// Simulate 100 duplicate PUT events for the same key+value.
|
||||
for i := 0; i < 100; i++ {
|
||||
c.OnAdd(internal.KV{Key: "etcd-key", Val: "host:1234"})
|
||||
}
|
||||
assert.ElementsMatch(t, []string{"host:1234"}, c.GetValues())
|
||||
// Internal slice must not have grown beyond one entry.
|
||||
c.lock.Lock()
|
||||
assert.Len(t, c.values["host:1234"], 1)
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
func TestContainer_KeyValueChange(t *testing.T) {
|
||||
c := newContainer(false)
|
||||
c.OnAdd(internal.KV{Key: "etcd-key", Val: "host:1234"})
|
||||
assert.ElementsMatch(t, []string{"host:1234"}, c.GetValues())
|
||||
|
||||
// Key moves to a different server value.
|
||||
c.OnAdd(internal.KV{Key: "etcd-key", Val: "host:5678"})
|
||||
assert.ElementsMatch(t, []string{"host:5678"}, c.GetValues())
|
||||
|
||||
// Old server must be fully removed; a subsequent delete must leave nothing.
|
||||
c.OnDelete(internal.KV{Key: "etcd-key", Val: "host:5678"})
|
||||
assert.Empty(t, c.GetValues())
|
||||
}
|
||||
|
||||
// TestContainer_ExclusiveMode verifies that adding successive keys for the same
|
||||
// value in exclusive mode leaves only the latest key and evicts all prior ones.
|
||||
func TestContainer_ExclusiveMode(t *testing.T) {
|
||||
c := newContainer(true)
|
||||
c.OnAdd(internal.KV{Key: "key1", Val: "server1"})
|
||||
c.OnAdd(internal.KV{Key: "key2", Val: "server1"})
|
||||
c.OnAdd(internal.KV{Key: "key3", Val: "server1"})
|
||||
|
||||
assert.ElementsMatch(t, []string{"server1"}, c.GetValues())
|
||||
c.lock.Lock()
|
||||
assert.Equal(t, []string{"key3"}, c.values["server1"], "only the latest key must remain")
|
||||
assert.NotContains(t, c.mapping, "key1", "key1 must have been evicted")
|
||||
assert.NotContains(t, c.mapping, "key2", "key2 must have been evicted")
|
||||
assert.Equal(t, "server1", c.mapping["key3"])
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
// TestContainer_ExclusiveMode_MultipleEvictions injects 3 keys for the same
|
||||
// value directly into internal state and then triggers the exclusive eviction
|
||||
// loop via OnAdd. This exercises the range-over-previous fix: iterating over
|
||||
// the live slice (range keys) would corrupt iteration when doRemoveKey
|
||||
// compacts the shared underlying array in-place, causing the second and third
|
||||
// keys to be skipped; ranging over the deep copy (range previous) is safe.
|
||||
func TestContainer_ExclusiveMode_MultipleEvictions(t *testing.T) {
|
||||
c := newContainer(true)
|
||||
|
||||
// Bypass the exclusive invariant to simulate 3 pre-existing keys for the
|
||||
// same value — the state that would expose the in-place aliasing bug.
|
||||
c.lock.Lock()
|
||||
c.values["server1"] = []string{"key1", "key2", "key3"}
|
||||
c.mapping["key1"] = "server1"
|
||||
c.mapping["key2"] = "server1"
|
||||
c.mapping["key3"] = "server1"
|
||||
c.lock.Unlock()
|
||||
|
||||
// Adding key4 must evict all three existing keys via the exclusive loop.
|
||||
c.OnAdd(internal.KV{Key: "key4", Val: "server1"})
|
||||
|
||||
assert.ElementsMatch(t, []string{"server1"}, c.GetValues())
|
||||
c.lock.Lock()
|
||||
assert.Equal(t, []string{"key4"}, c.values["server1"], "all prior keys must be evicted")
|
||||
assert.NotContains(t, c.mapping, "key1", "key1 must be evicted")
|
||||
assert.NotContains(t, c.mapping, "key2", "key2 must be evicted")
|
||||
assert.NotContains(t, c.mapping, "key3", "key3 must be evicted")
|
||||
assert.Equal(t, "server1", c.mapping["key4"])
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
func TestSubscriber(t *testing.T) {
|
||||
sub := new(Subscriber)
|
||||
Exclusive()(sub)
|
||||
|
||||
13
go.mod
13
go.mod
@@ -11,19 +11,19 @@ require (
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||
github.com/golang/protobuf v1.5.4
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/grafana/pyroscope-go v1.2.8
|
||||
github.com/grafana/pyroscope-go v1.3.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/jhump/protoreflect v1.18.0
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.0
|
||||
github.com/pelletier/go-toml/v2 v2.3.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.18.0
|
||||
github.com/redis/go-redis/v9 v9.19.0
|
||||
github.com/spaolacci/murmur3 v1.1.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/titanous/json5 v1.0.0
|
||||
go.etcd.io/etcd/api/v3 v3.5.21
|
||||
go.etcd.io/etcd/client/v3 v3.5.21
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0
|
||||
go.mongodb.org/mongo-driver/v2 v2.6.0
|
||||
go.opentelemetry.io/otel v1.40.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.40.0
|
||||
@@ -38,7 +38,7 @@ require (
|
||||
golang.org/x/sys v0.41.0
|
||||
golang.org/x/time v0.14.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409
|
||||
google.golang.org/grpc v1.79.3
|
||||
google.golang.org/grpc v1.80.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/cheggaaa/pb.v1 v1.0.28
|
||||
gopkg.in/h2non/gock.v1 v1.1.2
|
||||
@@ -58,7 +58,6 @@ require (
|
||||
github.com/coreos/go-semver v0.3.1 // indirect
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/emicklei/go-restful/v3 v3.12.2 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
|
||||
@@ -73,7 +72,7 @@ require (
|
||||
github.com/google/gnostic-models v0.7.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.10 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@@ -82,7 +81,7 @@ require (
|
||||
github.com/jhump/protoreflect/v2 v2.0.0-beta.1 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/compress v1.18.6 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
|
||||
38
go.sum
38
go.sum
@@ -28,8 +28,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/emicklei/go-restful/v3 v3.12.2 h1:DhwDP0vY3k8ZzE0RunuJy8GhNpPL6zqLkDf9B/a0/xU=
|
||||
github.com/emicklei/go-restful/v3 v3.12.2/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
|
||||
@@ -82,10 +80,10 @@ github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgY
|
||||
github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grafana/pyroscope-go v1.2.8 h1:UvCwIhlx9DeV7F6TW/z8q1Mi4PIm3vuUJ2ZlCEvmA4M=
|
||||
github.com/grafana/pyroscope-go v1.2.8/go.mod h1:SSi59eQ1/zmKoY/BKwa5rSFsJaq+242Bcrr4wPix1g8=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 h1:c1Us8i6eSmkW+Ez05d3co8kasnuOY813tbMN8i/a3Og=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||
github.com/grafana/pyroscope-go v1.3.0 h1:t3Jehad8vvqN4oRAB0LdmfQ5ZSUXQw3asoft+K4GAT8=
|
||||
github.com/grafana/pyroscope-go v1.3.0/go.mod h1:XA7I3usNx+UdjOZfQnl1WV8y924vsJo9KIVrKB+9jx4=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.10 h1:dvhndEbyavTb59vFCd6PsrAG5qi69/qZZtegh/TJKSY=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.10/go.mod h1:XnWRGg2XO5uxZdiz1rfeJH6w1eZ+YICCBVXNWOfH86g=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||
@@ -109,10 +107,10 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao=
|
||||
github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
@@ -169,8 +167,8 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
|
||||
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
||||
github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k=
|
||||
github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robertkrimen/otto v0.2.1 h1:FVP0PJ0AHIjC+N4pKCG9yCDz6LHNPCwi/GKID5pGGF0=
|
||||
@@ -218,16 +216,16 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs=
|
||||
github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s=
|
||||
go.etcd.io/etcd/api/v3 v3.5.21 h1:A6O2/JDb3tvHhiIz3xf9nJ7REHvtEFJJ3veW3FbCnS8=
|
||||
go.etcd.io/etcd/api/v3 v3.5.21/go.mod h1:c3aH5wcvXv/9dqIw2Y810LDXJfhSYdHQ0vxmP3CCHVY=
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.21 h1:lPBu71Y7osQmzlflM9OfeIV2JlmpBjqBNlLtcoBqUTc=
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.21/go.mod h1:BgqT/IXPjK9NkeSDjbzwsHySX3yIle2+ndz28nVsjUs=
|
||||
go.etcd.io/etcd/client/v3 v3.5.21 h1:T6b1Ow6fNjOLOtM0xSoKNQt1ASPCLWrF9XMHcH9pEyY=
|
||||
go.etcd.io/etcd/client/v3 v3.5.21/go.mod h1:mFYy67IOqmbRf/kRUvsHixzo3iG+1OF2W2+jVIQRAnU=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.mongodb.org/mongo-driver/v2 v2.6.0 h1:b9sJOYrkmt4l8bY43ZenFBcPlhYIjaOfYHLtbB/5qi8=
|
||||
go.mongodb.org/mongo-driver/v2 v2.6.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
|
||||
@@ -327,14 +325,14 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
33
mcp/options.go
Normal file
33
mcp/options.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package mcp
|
||||
|
||||
import "net/http"
|
||||
|
||||
// RequestMetadataExtractor extracts request metadata for downstream handlers.
|
||||
type RequestMetadataExtractor func(*http.Request) RequestMetadata
|
||||
|
||||
// McpOption customizes MCP server construction.
|
||||
type McpOption interface {
|
||||
apply(*serverOptions)
|
||||
}
|
||||
|
||||
type mcpOptionFunc func(*serverOptions)
|
||||
|
||||
func (f mcpOptionFunc) apply(opts *serverOptions) {
|
||||
f(opts)
|
||||
}
|
||||
|
||||
type serverOptions struct {
|
||||
requestMetadataExtractor RequestMetadataExtractor
|
||||
}
|
||||
|
||||
func defaultServerOptions() serverOptions {
|
||||
return serverOptions{}
|
||||
}
|
||||
|
||||
// WithRequestMetadataExtractor installs an extractor that runs for each incoming
|
||||
// MCP HTTP request, and stores the extracted metadata into handler context.
|
||||
func WithRequestMetadataExtractor(extractor RequestMetadataExtractor) McpOption {
|
||||
return mcpOptionFunc(func(opts *serverOptions) {
|
||||
opts.requestMetadataExtractor = extractor
|
||||
})
|
||||
}
|
||||
@@ -15,6 +15,7 @@ This package provides a go-zero integration for the [Model Context Protocol (MCP
|
||||
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
||||
- **Type-Safe Tool Handlers**: Generic tool handlers with automatic JSON schema generation
|
||||
- **Prompts and Resources**: Full support for MCP prompts and resources
|
||||
- **Request Metadata Bridge**: Optional request metadata extraction into handler context
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -220,6 +221,35 @@ mcp:
|
||||
messageEndpoint: /message
|
||||
```
|
||||
|
||||
## Request Metadata Bridge
|
||||
|
||||
For multi-tenant or request-context-aware tools, you can extract selected HTTP request metadata once at the transport boundary and read it from `context.Context` in handlers.
|
||||
|
||||
```go
|
||||
server := mcp.NewMcpServerWithOptions(c,
|
||||
mcp.WithRequestMetadataExtractor(mcp.DefaultRequestMetadataExtractor),
|
||||
)
|
||||
|
||||
handler := func(ctx context.Context, req *mcp.CallToolRequest, args SomeArgs) (*mcp.CallToolResult, any, error) {
|
||||
tenant, _ := mcp.HeaderFromContext(ctx, "X-Tenant-Id")
|
||||
traceID, _ := mcp.QueryFromContext(ctx, "trace")
|
||||
scope, _ := mcp.PathFromContext(ctx, "scope")
|
||||
|
||||
_ = tenant
|
||||
_ = traceID
|
||||
_ = scope
|
||||
|
||||
return &mcp.CallToolResult{}, nil, nil
|
||||
}
|
||||
```
|
||||
|
||||
Available helpers:
|
||||
|
||||
- `RequestMetadataFromContext(ctx)`
|
||||
- `HeaderFromContext(ctx, key)`
|
||||
- `QueryFromContext(ctx, key)`
|
||||
- `PathFromContext(ctx, key)`
|
||||
|
||||
## Configuration Options
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|
||||
150
mcp/request_metadata.go
Normal file
150
mcp/request_metadata.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest/pathvar"
|
||||
)
|
||||
|
||||
// RequestMetadata carries selected request-scoped values into MCP handlers.
|
||||
type RequestMetadata struct {
|
||||
Headers map[string][]string
|
||||
Query map[string][]string
|
||||
Path map[string]string
|
||||
}
|
||||
|
||||
type requestMetadataCtxKey struct{}
|
||||
|
||||
// RequestMetadataFromContext returns metadata extracted at the transport boundary.
|
||||
func RequestMetadataFromContext(ctx context.Context) (RequestMetadata, bool) {
|
||||
metadata, ok := requestMetadataFromContext(ctx)
|
||||
if !ok {
|
||||
return RequestMetadata{}, false
|
||||
}
|
||||
|
||||
return normalizeRequestMetadata(metadata), true
|
||||
}
|
||||
|
||||
// HeaderFromContext returns the first header value for key.
|
||||
func HeaderFromContext(ctx context.Context, key string) (string, bool) {
|
||||
metadata, ok := requestMetadataFromContext(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
vals := metadata.Headers[http.CanonicalHeaderKey(key)]
|
||||
if len(vals) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return vals[0], true
|
||||
}
|
||||
|
||||
// QueryFromContext returns the first query value for key.
|
||||
func QueryFromContext(ctx context.Context, key string) (string, bool) {
|
||||
metadata, ok := requestMetadataFromContext(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
vals := metadata.Query[key]
|
||||
if len(vals) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return vals[0], true
|
||||
}
|
||||
|
||||
// PathFromContext returns the path variable value for key.
|
||||
func PathFromContext(ctx context.Context, key string) (string, bool) {
|
||||
metadata, ok := requestMetadataFromContext(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
val, ok := metadata.Path[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return val, true
|
||||
}
|
||||
|
||||
func requestMetadataFromContext(ctx context.Context) (RequestMetadata, bool) {
|
||||
metadata, ok := ctx.Value(requestMetadataCtxKey{}).(RequestMetadata)
|
||||
if !ok {
|
||||
return RequestMetadata{}, false
|
||||
}
|
||||
|
||||
return metadata, true
|
||||
}
|
||||
|
||||
// DefaultRequestMetadataExtractor extracts headers, query values, and path variables.
|
||||
func DefaultRequestMetadataExtractor(r *http.Request) RequestMetadata {
|
||||
metadata := RequestMetadata{
|
||||
Headers: make(map[string][]string, len(r.Header)),
|
||||
Query: make(map[string][]string),
|
||||
Path: clonePathVars(pathvar.Vars(r)),
|
||||
}
|
||||
|
||||
for key, vals := range r.Header {
|
||||
metadata.Headers[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...)
|
||||
}
|
||||
|
||||
if r.URL != nil {
|
||||
for key, vals := range r.URL.Query() {
|
||||
metadata.Query[key] = append([]string(nil), vals...)
|
||||
}
|
||||
}
|
||||
|
||||
return metadata
|
||||
}
|
||||
|
||||
func normalizeRequestMetadata(metadata RequestMetadata) RequestMetadata {
|
||||
return RequestMetadata{
|
||||
Headers: cloneCanonicalHeaderValues(metadata.Headers),
|
||||
Query: cloneHeaderValues(metadata.Query),
|
||||
Path: clonePathVars(metadata.Path),
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHeaderValues(values map[string][]string) map[string][]string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make(map[string][]string, len(values))
|
||||
for key, vals := range values {
|
||||
cloned[key] = append([]string(nil), vals...)
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneCanonicalHeaderValues(values map[string][]string) map[string][]string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make(map[string][]string, len(values))
|
||||
for key, vals := range values {
|
||||
canonical := http.CanonicalHeaderKey(key)
|
||||
cloned[canonical] = append(cloned[canonical], vals...)
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
func clonePathVars(values map[string]string) map[string]string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make(map[string]string, len(values))
|
||||
for key, val := range values {
|
||||
cloned[key] = val
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
185
mcp/request_metadata_test.go
Normal file
185
mcp/request_metadata_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/rest/pathvar"
|
||||
)
|
||||
|
||||
func TestDefaultRequestMetadataExtractor(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/sse?tenant=t1&trace=abc", nil)
|
||||
req.Header.Add("X-Tenant-Id", "tenant-from-header")
|
||||
req = pathvar.WithVars(req, map[string]string{"tool": "sum"})
|
||||
|
||||
metadata := DefaultRequestMetadataExtractor(req)
|
||||
header, ok := metadata.Headers["X-Tenant-Id"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"tenant-from-header"}, header)
|
||||
assert.Equal(t, []string{"t1"}, metadata.Query["tenant"])
|
||||
assert.Equal(t, "sum", metadata.Path["tool"])
|
||||
}
|
||||
|
||||
func TestRequestMetadataContextHelpers(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||||
Query: map[string][]string{"tenant": {"foo"}},
|
||||
Path: map[string]string{"scope": "prod"},
|
||||
})
|
||||
|
||||
metadata, ok := RequestMetadataFromContext(ctx)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"trace-1"}, metadata.Headers["X-Trace-Id"])
|
||||
|
||||
header, ok := HeaderFromContext(ctx, "x-trace-id")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "trace-1", header)
|
||||
|
||||
query, ok := QueryFromContext(ctx, "tenant")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "foo", query)
|
||||
|
||||
path, ok := PathFromContext(ctx, "scope")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "prod", path)
|
||||
}
|
||||
|
||||
func TestRequestMetadataContextHelpersMissingKeys(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||||
Query: map[string][]string{"tenant": {"foo"}},
|
||||
Path: map[string]string{"scope": "prod"},
|
||||
})
|
||||
|
||||
_, ok := HeaderFromContext(ctx, "x-missing")
|
||||
assert.False(t, ok)
|
||||
|
||||
_, ok = QueryFromContext(ctx, "missing")
|
||||
assert.False(t, ok)
|
||||
|
||||
_, ok = PathFromContext(ctx, "missing")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestRequestMetadataFromContextNotFound(t *testing.T) {
|
||||
_, ok := RequestMetadataFromContext(context.Background())
|
||||
assert.False(t, ok)
|
||||
|
||||
_, ok = HeaderFromContext(context.Background(), "x-test")
|
||||
assert.False(t, ok)
|
||||
|
||||
_, ok = QueryFromContext(context.Background(), "tenant")
|
||||
assert.False(t, ok)
|
||||
|
||||
_, ok = PathFromContext(context.Background(), "tenant")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestWrapRequestMetadata(t *testing.T) {
|
||||
s := &mcpServerImpl{
|
||||
options: serverOptions{
|
||||
requestMetadataExtractor: DefaultRequestMetadataExtractor,
|
||||
},
|
||||
}
|
||||
|
||||
called := false
|
||||
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
header, ok := HeaderFromContext(r.Context(), "x-tenant-id")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "tenant-1", header)
|
||||
|
||||
query, ok := QueryFromContext(r.Context(), "tenant")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "q-tenant", query)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/sse?tenant=q-tenant", nil)
|
||||
req.Header.Set("X-Tenant-Id", "tenant-1")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestWrapRequestMetadataNoExtractor(t *testing.T) {
|
||||
s := &mcpServerImpl{}
|
||||
|
||||
called := false
|
||||
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
_, ok := RequestMetadataFromContext(r.Context())
|
||||
assert.False(t, ok)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil))
|
||||
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestWrapRequestMetadataCanonicalizesCustomHeaders(t *testing.T) {
|
||||
s := &mcpServerImpl{
|
||||
options: serverOptions{
|
||||
requestMetadataExtractor: func(*http.Request) RequestMetadata {
|
||||
return RequestMetadata{
|
||||
Headers: map[string][]string{
|
||||
"x-tenant-id": {"tenant-lower"},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
called := false
|
||||
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
header, ok := HeaderFromContext(r.Context(), "X-Tenant-Id")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "tenant-lower", header)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil))
|
||||
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestRequestMetadataFromContextReturnsCopy(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||||
})
|
||||
|
||||
metadata, ok := RequestMetadataFromContext(ctx)
|
||||
assert.True(t, ok)
|
||||
metadata.Headers["X-Trace-Id"][0] = "mutated"
|
||||
metadata.Headers["X-New"] = []string{"new"}
|
||||
|
||||
fresh, ok := RequestMetadataFromContext(ctx)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"trace-1"}, fresh.Headers["X-Trace-Id"])
|
||||
assert.Nil(t, fresh.Headers["X-New"])
|
||||
}
|
||||
|
||||
func TestRequestMetadataFromContextWithEmptyAndCanonicalizedHeaders(t *testing.T) {
|
||||
emptyCtx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{})
|
||||
empty, ok := RequestMetadataFromContext(emptyCtx)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, empty.Headers)
|
||||
assert.Nil(t, empty.Query)
|
||||
assert.Nil(t, empty.Path)
|
||||
|
||||
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||
Headers: map[string][]string{
|
||||
"x-tenant-id": {"a"},
|
||||
"X-Tenant-Id": {"b"},
|
||||
},
|
||||
})
|
||||
|
||||
metadata, ok := RequestMetadataFromContext(ctx)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"a", "b"}, metadata.Headers["X-Tenant-Id"])
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
@@ -20,10 +21,23 @@ type mcpServerImpl struct {
|
||||
conf McpConf
|
||||
httpServer *rest.Server
|
||||
mcpServer *sdkmcp.Server
|
||||
options serverOptions
|
||||
}
|
||||
|
||||
// NewMcpServer creates a new MCP server using the official SDK
|
||||
func NewMcpServer(c McpConf) McpServer {
|
||||
return NewMcpServerWithOptions(c)
|
||||
}
|
||||
|
||||
// NewMcpServerWithOptions creates a new MCP server with optional customizations.
|
||||
func NewMcpServerWithOptions(c McpConf, opts ...McpOption) McpServer {
|
||||
serverOpts := defaultServerOptions()
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt.apply(&serverOpts)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the underlying rest HTTP server
|
||||
var httpServer *rest.Server
|
||||
if len(c.Mcp.Cors) == 0 {
|
||||
@@ -52,6 +66,7 @@ func NewMcpServer(c McpConf) McpServer {
|
||||
conf: c,
|
||||
httpServer: httpServer,
|
||||
mcpServer: mcpServer,
|
||||
options: serverOpts,
|
||||
}
|
||||
|
||||
// Choose transport based on configuration
|
||||
@@ -85,7 +100,7 @@ func (s *mcpServerImpl) setupSSETransport() {
|
||||
return s.mcpServer
|
||||
}, nil)
|
||||
|
||||
s.registerRoutes(handler, s.conf.Mcp.SseEndpoint)
|
||||
s.registerRoutes(s.wrapRequestMetadata(handler), s.conf.Mcp.SseEndpoint)
|
||||
}
|
||||
|
||||
// setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec)
|
||||
@@ -96,7 +111,7 @@ func (s *mcpServerImpl) setupStreamableTransport() {
|
||||
return s.mcpServer
|
||||
}, nil)
|
||||
|
||||
s.registerRoutes(handler, s.conf.Mcp.MessageEndpoint)
|
||||
s.registerRoutes(s.wrapRequestMetadata(handler), s.conf.Mcp.MessageEndpoint)
|
||||
}
|
||||
|
||||
func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) {
|
||||
@@ -113,3 +128,16 @@ func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) {
|
||||
Handler: handler.ServeHTTP,
|
||||
}, rest.WithTimeout(s.conf.Mcp.MessageTimeout))
|
||||
}
|
||||
|
||||
func (s *mcpServerImpl) wrapRequestMetadata(next http.Handler) http.Handler {
|
||||
extractor := s.options.requestMetadataExtractor
|
||||
if extractor == nil {
|
||||
return next
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
metadata := normalizeRequestMetadata(extractor(r))
|
||||
ctx := context.WithValue(r.Context(), requestMetadataCtxKey{}, metadata)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,11 +3,14 @@ package mcp
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
)
|
||||
@@ -391,3 +394,148 @@ func TestAddToolWithCustomServer(t *testing.T) {
|
||||
return nil, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestMetadataIntegrationSSEToolCall(t *testing.T) {
|
||||
port := getFreePort(t)
|
||||
|
||||
c := McpConf{}
|
||||
c.Host = "127.0.0.1"
|
||||
c.Port = port
|
||||
c.Mcp.Name = "metadata-integration-test"
|
||||
c.Mcp.UseStreamable = false
|
||||
c.Mcp.SseEndpoint = "/sse/:scope"
|
||||
c.Mcp.MessageTimeout = 2 * time.Second
|
||||
c.Mcp.SseTimeout = 2 * time.Second
|
||||
|
||||
server := NewMcpServerWithOptions(c, WithRequestMetadataExtractor(DefaultRequestMetadataExtractor))
|
||||
|
||||
tool := &Tool{
|
||||
Name: "inspect_metadata",
|
||||
Description: "Inspect metadata in handler context",
|
||||
}
|
||||
|
||||
type Args struct{}
|
||||
|
||||
AddTool(server, tool, func(ctx context.Context, req *CallToolRequest, args Args) (*CallToolResult, any, error) {
|
||||
header, ok := HeaderFromContext(ctx, "x-tenant-id")
|
||||
if !ok || header != "tenant-header" {
|
||||
return nil, nil, fmt.Errorf("unexpected header from context: %q", header)
|
||||
}
|
||||
|
||||
query, ok := QueryFromContext(ctx, "tenant")
|
||||
if !ok || query != "tenant-query" {
|
||||
return nil, nil, fmt.Errorf("unexpected query from context: %q", query)
|
||||
}
|
||||
|
||||
scope, ok := PathFromContext(ctx, "scope")
|
||||
if !ok || scope != "prod" {
|
||||
return nil, nil, fmt.Errorf("unexpected path from context: %q", scope)
|
||||
}
|
||||
|
||||
return &CallToolResult{
|
||||
Content: []Content{&TextContent{Text: "metadata-ok"}},
|
||||
}, nil, nil
|
||||
})
|
||||
|
||||
go server.Start()
|
||||
t.Cleanup(server.Stop)
|
||||
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d/sse/prod?tenant=tenant-query", port)
|
||||
waitForServerReady(t, baseURL, 2*time.Second)
|
||||
|
||||
client := sdkmcp.NewClient(&sdkmcp.Implementation{
|
||||
Name: "metadata-client",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 2 * time.Second,
|
||||
Transport: metadataHeaderRoundTripper{
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
}
|
||||
|
||||
transport := &sdkmcp.SSEClientTransport{
|
||||
Endpoint: baseURL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
|
||||
session, err := client.Connect(context.Background(), transport, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = session.Close()
|
||||
})
|
||||
|
||||
res, err := session.CallTool(context.Background(), &sdkmcp.CallToolParams{
|
||||
Name: "inspect_metadata",
|
||||
Arguments: map[string]any{},
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
if !assert.NotNil(t, res) {
|
||||
return
|
||||
}
|
||||
assert.False(t, res.IsError)
|
||||
}
|
||||
|
||||
type metadataHeaderRoundTripper struct {
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (r metadataHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
next := r.next
|
||||
if next == nil {
|
||||
next = http.DefaultTransport
|
||||
}
|
||||
|
||||
clone := req.Clone(req.Context())
|
||||
clone.Header.Set("X-Tenant-Id", "tenant-header")
|
||||
return next.RoundTrip(clone)
|
||||
}
|
||||
|
||||
func getFreePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if !assert.NoError(t, err) {
|
||||
return 0
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
addr, ok := listener.Addr().(*net.TCPAddr)
|
||||
if !assert.True(t, ok) {
|
||||
return 0
|
||||
}
|
||||
|
||||
return addr.Port
|
||||
}
|
||||
|
||||
func waitForServerReady(t *testing.T, endpoint string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build readiness request: %v", err)
|
||||
}
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode > 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("server did not become ready for %s within %s", endpoint, timeout)
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ require (
|
||||
github.com/emicklei/proto v1.14.3
|
||||
github.com/fatih/structtag v1.2.0
|
||||
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/gookit/color v1.6.0
|
||||
github.com/go-sql-driver/mysql v1.10.0
|
||||
github.com/gookit/color v1.6.1
|
||||
github.com/iancoleman/strcase v0.3.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/pflag v1.0.10
|
||||
@@ -18,13 +18,13 @@ require (
|
||||
github.com/zeromicro/ddl-parser v1.0.5
|
||||
github.com/zeromicro/go-zero v1.10.1
|
||||
golang.org/x/text v0.34.0
|
||||
google.golang.org/grpc v1.79.3
|
||||
google.golang.org/grpc v1.80.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
filippo.io/edwards25519 v1.2.0 // indirect
|
||||
github.com/alicebob/miniredis/v2 v2.37.0 // indirect
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
|
||||
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68=
|
||||
@@ -51,8 +51,8 @@ github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e h1:auobAirzhPsL
|
||||
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e/go.mod h1:NAKTe9SplQBxIUlHlsuId1jk1I7bWTVV/2q/GtdRi6g=
|
||||
github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU=
|
||||
github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw=
|
||||
github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
|
||||
@@ -72,8 +72,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gookit/assert v0.1.1 h1:lh3GcawXe/p+cU7ESTZ5Ui3Sm/x8JWpIis4/1aF0mY0=
|
||||
github.com/gookit/assert v0.1.1/go.mod h1:jS5bmIVQZTIwk42uXl4lyj4iaaxx32tqH16CFj0VX2E=
|
||||
github.com/gookit/color v1.6.0 h1:JjJXBTk1ETNyqyilJhkTXJYYigHG24TM9Xa2M1xAhRA=
|
||||
github.com/gookit/color v1.6.0/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs=
|
||||
github.com/gookit/color v1.6.1 h1:KoTnDxJPRgrL0SoX0f8rCFg2zI0t4E3GZZBMo2nN8LU=
|
||||
github.com/gookit/color v1.6.1/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs=
|
||||
github.com/grafana/pyroscope-go v1.2.8 h1:UvCwIhlx9DeV7F6TW/z8q1Mi4PIm3vuUJ2ZlCEvmA4M=
|
||||
github.com/grafana/pyroscope-go v1.2.8/go.mod h1:SSi59eQ1/zmKoY/BKwa5rSFsJaq+242Bcrr4wPix1g8=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 h1:c1Us8i6eSmkW+Ez05d3co8kasnuOY813tbMN8i/a3Og=
|
||||
@@ -282,14 +282,14 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -13,10 +13,10 @@ type discovBuilder struct{}
|
||||
|
||||
func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (
|
||||
resolver.Resolver, error) {
|
||||
hosts := strings.FieldsFunc(targets.GetAuthority(target), func(r rune) bool {
|
||||
hosts := strings.FieldsFunc(targets.GetHosts(target), func(r rune) bool {
|
||||
return r == EndpointSepChar
|
||||
})
|
||||
sub, err := discov.NewSubscriber(hosts, targets.GetEndpoints(target))
|
||||
sub, err := discov.NewSubscriber(hosts, targets.GetKey(target))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func TestDiscovBuilder_Build(t *testing.T) {
|
||||
for _, server := range servers.Servers {
|
||||
addrs = append(addrs, server.Address)
|
||||
}
|
||||
u, err := url.Parse(fmt.Sprintf("%s://%s", DiscovScheme, strings.Join(addrs, ",")))
|
||||
u, err := url.Parse(fmt.Sprintf("%s:///%s?key=test", DiscovScheme, strings.Join(addrs, ",")))
|
||||
assert.NoError(t, err)
|
||||
|
||||
var b discovBuilder
|
||||
|
||||
@@ -17,3 +17,29 @@ func GetAuthority(target resolver.Target) string {
|
||||
func GetEndpoints(target resolver.Target) string {
|
||||
return strings.Trim(target.URL.Path, slashSeparator)
|
||||
}
|
||||
|
||||
// GetHosts returns the comma-separated etcd hosts from the target URL.
|
||||
// It supports two formats:
|
||||
// - New format (etcd:///h1:port,h2:port?key=k): hosts are in the URL path (empty authority)
|
||||
// - Legacy format (etcd://h1:port/key): host is in the URL authority
|
||||
func GetHosts(target resolver.Target) string {
|
||||
if target.URL.Host == "" {
|
||||
// New format: hosts encoded in URL path to avoid RFC 3986 authority issues
|
||||
return GetEndpoints(target)
|
||||
}
|
||||
// Legacy format: single host in authority
|
||||
return target.URL.Host
|
||||
}
|
||||
|
||||
// GetKey returns the etcd key from the target URL.
|
||||
// It supports two formats:
|
||||
// - New format (etcd:///h1:port,h2:port?key=k): key is in the "key" query parameter
|
||||
// - Legacy format (etcd://h1:port/key): key is in the URL path
|
||||
func GetKey(target resolver.Target) string {
|
||||
if target.URL.Host == "" {
|
||||
// New format: key is in the query parameter
|
||||
return target.URL.Query().Get("key")
|
||||
}
|
||||
// Legacy format: key is in the path
|
||||
return strings.Trim(target.URL.Path, slashSeparator)
|
||||
}
|
||||
|
||||
@@ -87,3 +87,83 @@ func TestGetEndpoints(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHosts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "single host",
|
||||
url: "etcd:///localhost:2379?key=foo",
|
||||
want: "localhost:2379",
|
||||
},
|
||||
{
|
||||
name: "multiple hosts",
|
||||
url: "etcd:///host1:2379,host2:2379,host3:2379?key=foo",
|
||||
want: "host1:2379,host2:2379,host3:2379",
|
||||
},
|
||||
{
|
||||
name: "legacy single host in authority",
|
||||
url: "etcd://localhost:2379/my-service",
|
||||
want: "localhost:2379",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
uri, err := url.Parse(test.url)
|
||||
assert.Nil(t, err)
|
||||
target := resolver.Target{
|
||||
URL: *uri,
|
||||
}
|
||||
assert.Equal(t, test.want, GetHosts(target))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "simple key",
|
||||
url: "etcd:///localhost:2379?key=my-service",
|
||||
want: "my-service",
|
||||
},
|
||||
{
|
||||
name: "key with slashes",
|
||||
url: "etcd:///localhost:2379?key=%2Fgrpc%2Fmy-service",
|
||||
want: "/grpc/my-service",
|
||||
},
|
||||
{
|
||||
name: "no key",
|
||||
url: "etcd:///localhost:2379",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "legacy key in path",
|
||||
url: "etcd://localhost:2379/my-service",
|
||||
want: "my-service",
|
||||
},
|
||||
{
|
||||
name: "legacy key with leading slash",
|
||||
url: "etcd://localhost:2379/grpc/my-service",
|
||||
want: "grpc/my-service",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
uri, err := url.Parse(test.url)
|
||||
assert.Nil(t, err)
|
||||
target := resolver.Target{
|
||||
URL: *uri,
|
||||
}
|
||||
assert.Equal(t, test.want, GetKey(target))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/zrpc/resolver/internal"
|
||||
@@ -14,7 +15,9 @@ func BuildDirectTarget(endpoints []string) string {
|
||||
}
|
||||
|
||||
// BuildDiscovTarget returns a string that represents the given endpoints with discov schema.
|
||||
// The format is etcd:///host1:port,host2:port?key=<etcd-key> to avoid placing comma-separated
|
||||
// hosts in the URI authority, which Go 1.26+ rejects per RFC 3986.
|
||||
func BuildDiscovTarget(endpoints []string, key string) string {
|
||||
return fmt.Sprintf("%s://%s/%s", internal.EtcdScheme,
|
||||
strings.Join(endpoints, internal.EndpointSep), key)
|
||||
return fmt.Sprintf("%s:///%s?key=%s", internal.EtcdScheme,
|
||||
strings.Join(endpoints, internal.EndpointSep), url.QueryEscape(key))
|
||||
}
|
||||
|
||||
@@ -13,5 +13,10 @@ func TestBuildDirectTarget(t *testing.T) {
|
||||
|
||||
func TestBuildDiscovTarget(t *testing.T) {
|
||||
target := BuildDiscovTarget([]string{"localhost:123", "localhost:456"}, "foo")
|
||||
assert.Equal(t, "etcd://localhost:123,localhost:456/foo", target)
|
||||
assert.Equal(t, "etcd:///localhost:123,localhost:456?key=foo", target)
|
||||
}
|
||||
|
||||
func TestBuildDiscovTargetWithSlashKey(t *testing.T) {
|
||||
target := BuildDiscovTarget([]string{"localhost:2379"}, "/grpc/my-service")
|
||||
assert.Equal(t, "etcd:///localhost:2379?key=%2Fgrpc%2Fmy-service", target)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user