Skip to content

Commit 5a158c3

Browse files
authored
feat: add support for statement-scoped connection state (#599)
Add support for statement-scoped connection state. Statement-scoped values are only valid for the duration of the execution of a single statement. These values can be set in ExecOptions and used as an argument for QueryContext and ExecContext. A separate pull request will be added later that enables the use of statement-scoped connection variables in statement hints. That is, the following statement will apply the hints that correspond with valid connection properties as statement-scoped properties for the statement: ```sql @{rpc_priority=high, statement_tag='my_tag'} select * from my_table ```
1 parent b09720f commit 5a158c3

File tree

10 files changed

+441
-33
lines changed

10 files changed

+441
-33
lines changed

conn.go

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,16 +332,22 @@ func (c *conn) showConnectionVariable(identifier parser.Identifier) (any, bool,
332332
return c.state.GetValue(extension, name)
333333
}
334334

335-
func (c *conn) setConnectionVariable(identifier parser.Identifier, value string, local bool, transaction bool) error {
335+
func (c *conn) setConnectionVariable(identifier parser.Identifier, value string, local bool, transaction bool, statementScoped bool) error {
336336
if transaction && !local {
337337
// When transaction == true, then local must also be true.
338338
// We should never hit this condition, as this is an indication of a bug in the driver code.
339339
return status.Errorf(codes.FailedPrecondition, "transaction properties must be set as a local value")
340340
}
341+
if statementScoped && local {
342+
return status.Errorf(codes.FailedPrecondition, "cannot specify both statementScoped and local")
343+
}
341344
extension, name, err := toExtensionAndName(identifier)
342345
if err != nil {
343346
return err
344347
}
348+
if statementScoped {
349+
return c.state.SetStatementScopedValue(extension, name, value)
350+
}
345351
if local {
346352
return c.state.SetLocalValue(extension, name, value, transaction)
347353
}
@@ -568,7 +574,10 @@ func (c *conn) startBatchDDL() (driver.Result, error) {
568574
}
569575

570576
func (c *conn) startBatchDML(automatic bool) (driver.Result, error) {
571-
execOptions := c.options( /*reset = */ true)
577+
execOptions, err := c.options( /*reset = */ true)
578+
if err != nil {
579+
return nil, err
580+
}
572581

573582
if c.inTransaction() {
574583
return c.tx.StartBatchDML(execOptions.QueryOptions, automatic)
@@ -779,6 +788,10 @@ func (c *conn) CheckNamedValue(value *driver.NamedValue) error {
779788
c.tempExecOptions = &execOptions
780789
return driver.ErrRemoveArgument
781790
}
791+
if execOptions, ok := value.Value.(*ExecOptions); ok {
792+
c.tempExecOptions = execOptions
793+
return driver.ErrRemoveArgument
794+
}
782795

783796
if checkIsValidType(value.Value) {
784797
return nil
@@ -823,7 +836,10 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) {
823836
}
824837

825838
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
826-
execOptions := c.options( /* reset = */ true)
839+
execOptions, err := c.options( /* reset = */ true)
840+
if err != nil {
841+
return nil, err
842+
}
827843
parsedSQL, args, err := c.parser.ParseParameters(query)
828844
if err != nil {
829845
return nil, err
@@ -894,13 +910,37 @@ func (c *conn) transactionDeadline() (time.Time, bool, error) {
894910
return deadline, hasDeadline, nil
895911
}
896912

913+
func (c *conn) applyStatementScopedValues(execOptions *ExecOptions) (cleanup func(), returnedErr error) {
914+
if execOptions == nil {
915+
return func() {}, nil
916+
}
917+
918+
defer func() {
919+
if returnedErr != nil {
920+
// Clear any statement values that might have been set if we return an error, as that also means that we
921+
// are not returning a cleanup function.
922+
c.state.ClearStatementScopedValues()
923+
}
924+
}()
925+
values := execOptions.PropertyValues
926+
for _, value := range values {
927+
if err := c.setConnectionVariable(value.Identifier, value.Value /*local=*/, false /*transaction=*/, false /*statementScoped=*/, true); err != nil {
928+
return func() {}, err
929+
}
930+
}
931+
return c.state.ClearStatementScopedValues, nil
932+
}
933+
897934
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
898935
// Execute client side statement if it is one.
899936
clientStmt, err := c.parser.ParseClientSideStatement(query)
900937
if err != nil {
901938
return nil, err
902939
}
903-
execOptions := c.options( /* reset = */ clientStmt == nil)
940+
execOptions, err := c.options( /* reset = */ clientStmt == nil)
941+
if err != nil {
942+
return nil, err
943+
}
904944
if clientStmt != nil {
905945
execStmt, err := createExecutableStatement(clientStmt)
906946
if err != nil {
@@ -924,6 +964,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
924964
}()
925965
// Clear the commit timestamp of this connection before we execute the query.
926966
c.clearCommitResponse()
967+
927968
// Check if the execution options contains an instruction to execute
928969
// a specific partition of a PartitionedQuery.
929970
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
@@ -1040,7 +1081,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
10401081
if err != nil {
10411082
return nil, err
10421083
}
1043-
execOptions := c.options( /*reset = */ stmt == nil)
1084+
execOptions, err := c.options( /*reset = */ stmt == nil)
1085+
if err != nil {
1086+
return nil, err
1087+
}
10441088
if stmt != nil {
10451089
execStmt, err := createExecutableStatement(stmt)
10461090
if err != nil {
@@ -1119,7 +1163,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecO
11191163
}
11201164

11211165
// options returns and optionally resets the ExecOptions for the next statement.
1122-
func (c *conn) options(reset bool) *ExecOptions {
1166+
func (c *conn) options(reset bool) (*ExecOptions, error) {
11231167
if reset {
11241168
defer func() {
11251169
// Only reset the transaction tag if there is no active transaction on the connection.
@@ -1130,11 +1174,24 @@ func (c *conn) options(reset bool) *ExecOptions {
11301174
c.tempExecOptions = nil
11311175
}()
11321176
}
1177+
// TODO: Refactor this to only use connection state as the (temporary) storage, and remove the tempExecOptions field
1178+
if c.tempExecOptions != nil {
1179+
cleanup, err := c.applyStatementScopedValues(c.tempExecOptions)
1180+
if err != nil {
1181+
return nil, err
1182+
}
1183+
defer cleanup()
1184+
}
1185+
1186+
// TODO: Refactor this to work 'the other way around'. That is:
1187+
// The ExecOptions that are given for a statement should update the connection state.
1188+
// The statement execution should read the state from the connection state.
11331189
effectiveOptions := &ExecOptions{
11341190
AutocommitDMLMode: c.AutocommitDMLMode(),
11351191
DecodeToNativeArrays: c.DecodeToNativeArrays(),
11361192
QueryOptions: spanner.QueryOptions{
11371193
RequestTag: c.StatementTag(),
1194+
Priority: propertyRpcPriority.GetValueOrDefault(c.state),
11381195
},
11391196
TransactionOptions: spanner.TransactionOptions{
11401197
ExcludeTxnFromChangeStreams: c.ExcludeTxnFromChangeStreams(),
@@ -1152,7 +1209,7 @@ func (c *conn) options(reset bool) *ExecOptions {
11521209
if c.tempExecOptions != nil {
11531210
effectiveOptions.merge(c.tempExecOptions)
11541211
}
1155-
return effectiveOptions
1212+
return effectiveOptions, nil
11561213
}
11571214

11581215
func (c *conn) Close() error {
@@ -1442,7 +1499,11 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
14421499
// Reset the transaction_tag after starting the transaction.
14431500
_ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser)
14441501
}()
1445-
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true))
1502+
execOptions, err := c.options( /*reset=*/ true)
1503+
if err != nil {
1504+
execOptions = &ExecOptions{}
1505+
}
1506+
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, execOptions)
14461507
})
14471508
if err != nil {
14481509
cancel()

conn_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package spannerdriver
16+
17+
import (
18+
"database/sql"
19+
"testing"
20+
21+
"github.com/googleapis/go-sql-spanner/connectionstate"
22+
"github.com/googleapis/go-sql-spanner/parser"
23+
"google.golang.org/grpc/codes"
24+
"google.golang.org/grpc/status"
25+
)
26+
27+
func TestApplyStatementScopedValues(t *testing.T) {
28+
t.Parallel()
29+
30+
c := &conn{
31+
logger: noopLogger,
32+
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
33+
}
34+
if g, w := propertyIsolationLevel.GetValueOrDefault(c.state), sql.LevelDefault; g != w {
35+
t.Fatalf("default isolation level mismatch\n Got: %v\nWant: %v", g, w)
36+
}
37+
38+
// Add a statement-scoped connection property value for isolation_level.
39+
cleanup, err := c.applyStatementScopedValues(&ExecOptions{
40+
PropertyValues: []PropertyValue{
41+
CreatePropertyValue("isolation_level", "repeatable read"),
42+
},
43+
})
44+
if err != nil {
45+
t.Fatal(err)
46+
}
47+
if g, w := propertyIsolationLevel.GetValueOrDefault(c.state), sql.LevelRepeatableRead; g != w {
48+
t.Fatalf("statement isolation level mismatch\n Got: %v\nWant: %v", g, w)
49+
}
50+
// Clean up the statement-scoped connection properties.
51+
cleanup()
52+
53+
// The isolation level should now be back to default.
54+
if g, w := propertyIsolationLevel.GetValueOrDefault(c.state), sql.LevelDefault; g != w {
55+
t.Fatalf("default isolation level mismatch\n Got: %v\nWant: %v", g, w)
56+
}
57+
}
58+
59+
func TestApplyStatementScopedValuesWithInvalidValue(t *testing.T) {
60+
t.Parallel()
61+
62+
c := &conn{
63+
logger: noopLogger,
64+
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
65+
}
66+
67+
// Add a statement-scoped connection property value for isolation_level with an invalid value.
68+
_, err := c.applyStatementScopedValues(&ExecOptions{
69+
PropertyValues: []PropertyValue{
70+
CreatePropertyValue("isolation_level", "not an isolation level"),
71+
},
72+
})
73+
if g, w := status.Code(err), codes.InvalidArgument; g != w {
74+
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
75+
}
76+
if g, w := propertyIsolationLevel.GetValueOrDefault(c.state), sql.LevelDefault; g != w {
77+
t.Fatalf("statement isolation level mismatch\n Got: %v\nWant: %v", g, w)
78+
}
79+
}
80+
81+
func TestApplyStatementScopedValuesWithUnknownProperty(t *testing.T) {
82+
t.Parallel()
83+
84+
c := &conn{
85+
logger: noopLogger,
86+
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
87+
}
88+
89+
// Add a statement-scoped connection property value for an unknown property without an extension.
90+
_, err := c.applyStatementScopedValues(&ExecOptions{
91+
PropertyValues: []PropertyValue{
92+
CreatePropertyValue("non_existing_property", "some-value"),
93+
},
94+
})
95+
if g, w := status.Code(err), codes.InvalidArgument; g != w {
96+
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
97+
}
98+
if g, w := propertyIsolationLevel.GetValueOrDefault(c.state), sql.LevelDefault; g != w {
99+
t.Fatalf("statement isolation level mismatch\n Got: %v\nWant: %v", g, w)
100+
}
101+
}
102+
103+
func TestApplyStatementScopedValuesWithExtension(t *testing.T) {
104+
t.Parallel()
105+
106+
c := &conn{
107+
logger: noopLogger,
108+
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
109+
}
110+
111+
// Add a statement-scoped connection property value for an unknown property with an extension.
112+
// Property values for unknown properties with an extension are allowed.
113+
propValue := PropertyValue{
114+
Identifier: parser.Identifier{Parts: []string{"my_extension", "my_property"}},
115+
Value: "some-value",
116+
}
117+
cleanup, err := c.applyStatementScopedValues(&ExecOptions{
118+
PropertyValues: []PropertyValue{propValue},
119+
})
120+
if g, w := status.Code(err), codes.OK; g != w {
121+
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
122+
}
123+
val, ok, err := c.state.GetValue("my_extension", "my_property")
124+
if err != nil {
125+
t.Fatal(err)
126+
}
127+
if !ok {
128+
t.Fatal("missing my_extension.my_property")
129+
}
130+
if g, w := val, "some-value"; g != w {
131+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
132+
}
133+
cleanup()
134+
135+
// The value should now be gone.
136+
_, ok, err = c.state.GetValue("my_extension", "my_property")
137+
if g, w := status.Code(err), codes.InvalidArgument; g != w {
138+
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
139+
}
140+
if ok {
141+
t.Fatal("got unexpected value for my_extension.my_property")
142+
}
143+
}

connection_properties.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,27 @@ import (
2323
"cloud.google.com/go/spanner"
2424
"cloud.google.com/go/spanner/apiv1/spannerpb"
2525
"github.com/googleapis/go-sql-spanner/connectionstate"
26+
"github.com/googleapis/go-sql-spanner/parser"
2627
"google.golang.org/grpc/codes"
2728
"google.golang.org/grpc/status"
2829
)
2930

31+
// PropertyValue is an untyped property value for a connection property.
32+
// These can be set on an ExecOptions instance to set statement-scoped connection property values for a
33+
// single statement execution.
34+
type PropertyValue struct {
35+
Identifier parser.Identifier
36+
Value string
37+
}
38+
39+
// CreatePropertyValue creates an untyped property value for a connection variable.
40+
func CreatePropertyValue(name, value string) PropertyValue {
41+
return PropertyValue{
42+
Identifier: parser.Identifier{Parts: []string{name}},
43+
Value: value,
44+
}
45+
}
46+
3047
// connectionProperties contains all supported connection properties for Spanner.
3148
// These properties are added to all connectionstate.ConnectionState instances that are created for Spanner connections.
3249
var connectionProperties = map[string]connectionstate.ConnectionProperty{}

0 commit comments

Comments
 (0)