Skip to content

Commit

Permalink
Properly unescape keyspace name in FindAllShardsInKeyspace (#15765)
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Lord <mattalord@gmail.com>
  • Loading branch information
mattlord authored Apr 24, 2024
1 parent 204bc50 commit 5f47800
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
9 changes: 9 additions & 0 deletions go/vt/topo/keyspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/constants/sidecar"
"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -213,6 +214,14 @@ func (ts *Server) FindAllShardsInKeyspace(ctx context.Context, keyspace string,
opt.Concurrency = DefaultConcurrency
}

// Unescape the keyspace name as this can e.g. come from the VSchema where
// a keyspace/database name will need to be SQL escaped if it has special
// characters such as a dash.
keyspace, err := sqlescape.UnescapeID(keyspace)
if err != nil {
return nil, vterrors.Wrapf(err, "FindAllShardsInKeyspace(%s) invalid keyspace name", keyspace)
}

// First try to get all shards using List if we can.
buildResultFromList := func(kvpairs []KVInfo) (map[string]*ShardInfo, error) {
result := make(map[string]*ShardInfo, len(kvpairs))
Expand Down
28 changes: 23 additions & 5 deletions go/vt/topo/keyspace_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/memorytopo"
Expand All @@ -32,10 +33,12 @@ import (
)

func TestServerFindAllShardsInKeyspace(t *testing.T) {
const defaultKeyspace = "keyspace"
tests := []struct {
name string
shards int
opt *topo.FindAllShardsInKeyspaceOptions
name string
shards int
keyspace string // If you want to override the default
opt *topo.FindAllShardsInKeyspaceOptions
}{
{
name: "negative concurrency",
Expand All @@ -54,9 +57,25 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {
shards: 32,
opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8},
},
{
name: "SQL escaped keyspace",
shards: 32,
keyspace: "`my-keyspace`",
opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8},
},
}

for _, tt := range tests {
keyspace := defaultKeyspace
if tt.keyspace != "" {
// Most calls such as CreateKeyspace will not accept invalid characters
// in the value so we'll only use the original test case value in
// FindAllShardsInKeyspace. This allows us to test and confirm that
// FindAllShardsInKeyspace can handle SQL escaped or backtick'd names.
keyspace, _ = sqlescape.UnescapeID(tt.keyspace)
} else {
tt.keyspace = defaultKeyspace
}
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -66,7 +85,6 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {

// Create an ephemeral keyspace and generate shard records within
// the keyspace to fetch later.
const keyspace = "keyspace"
require.NoError(t, ts.CreateKeyspace(ctx, keyspace, &topodatapb.Keyspace{}))

shards, err := key.GenerateShardRanges(tt.shards)
Expand All @@ -78,7 +96,7 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {

// Verify that we return a complete list of shards and that each
// key range is present in the output.
out, err := ts.FindAllShardsInKeyspace(ctx, keyspace, tt.opt)
out, err := ts.FindAllShardsInKeyspace(ctx, tt.keyspace, tt.opt)
require.NoError(t, err)
require.Len(t, out, tt.shards)

Expand Down
7 changes: 4 additions & 3 deletions go/vt/vtctl/workflow/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ func TestVDiffCreate(t *testing.T) {
wantErr string
}{
{
name: "no values",
req: &vtctldatapb.VDiffCreateRequest{},
wantErr: "FindAllShardsInKeyspace(): List: node doesn't exist: keyspaces/shards", // We did not provide any keyspace or shard
name: "no values",
req: &vtctldatapb.VDiffCreateRequest{},
// We did not provide any keyspace or shard.
wantErr: "FindAllShardsInKeyspace() invalid keyspace name: UnescapeID err: invalid input identifier ''",
},
}
for _, tt := range tests {
Expand Down

0 comments on commit 5f47800

Please sign in to comment.