diff --git a/client.go b/client.go index 05af68f..ba3b994 100644 --- a/client.go +++ b/client.go @@ -127,9 +127,26 @@ func (c *APIClient) GetQueryID() string { return fmt.Sprintf("%s.%d", c.SessionID, c.QuerySeq) } +func (c *APIClient) NeedSticky() bool { + if c.sessionState != nil { + return c.sessionState.NeedSticky + } + return false +} + +func (c *APIClient) NeedKeepAlive() bool { + if c.sessionState != nil { + return c.sessionState.NeedKeepAlive + } + return false +} + func NewAPIHttpClientFromConfig(cfg *Config) *http.Client { + jar := NewIgnoreDomainCookieJar() + jar.SetCookies(nil, []*http.Cookie{{Name: "cookie_enabled", Value: "true"}}) cli := &http.Client{ Timeout: cfg.Timeout, + Jar: jar, } if cfg.EnableOpenTelemetry { cli.Transport = otelhttp.NewTransport(http.DefaultTransport) @@ -148,7 +165,7 @@ func NewAPIClientFromConfig(cfg *Config) *APIClient { // if role is set in config, we'd prefer to limit it as the only effective role, // so you could limit the privileges by setting a role with limited privileges. - // however this can be overridden by executing `SET SECONDARY ROLES ALL` in the + // however, this can be overridden by executing `SET SECONDARY ROLES ALL` in the // query. // secondaryRoles now have two viable values: // - nil: means enabling ALL the granted roles of the user @@ -202,7 +219,7 @@ func initAccessTokenLoader(cfg *Config) AccessTokenLoader { return nil } -func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, resp interface{}, respHeaders *http.Header) error { +func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, needSticky bool, resp interface{}, respHeaders *http.Header) error { if c.doRequestFunc != nil { return c.doRequestFunc(method, path, req, resp) } @@ -226,6 +243,9 @@ func (c *APIClient) doRequest(ctx context.Context, method, path string, req inte maxRetries := 2 for i := 1; i <= maxRetries; i++ { headers, err := c.makeHeaders(ctx) + if needSticky { + headers.Set(DatabendQueryStickyNode, c.NodeID) + } if err != nil { return errors.Wrap(err, "failed to make request headers") } @@ -484,7 +504,7 @@ func (c *APIClient) startQueryRequest(ctx context.Context, request *QueryRequest respHeaders http.Header ) err := c.doRetry(func() error { - return c.doRequest(ctx, "POST", path, request, &resp, &respHeaders) + return c.doRequest(ctx, "POST", path, request, c.NeedSticky(), &resp, &respHeaders) }, Query, ) if err != nil { @@ -520,7 +540,7 @@ func (c *APIClient) PollQuery(ctx context.Context, nextURI string) (*QueryRespon var result QueryResponse err := c.doRetry( func() error { - return c.doRequest(ctx, "GET", nextURI, nil, &result, nil) + return c.doRequest(ctx, "GET", nextURI, nil, true, &result, nil) }, Page, ) @@ -539,7 +559,7 @@ func (c *APIClient) KillQuery(ctx context.Context, response *QueryResponse) erro ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() _ = c.doRetry(func() error { - return c.doRequest(ctx, "GET", response.KillURI, nil, nil, nil) + return c.doRequest(ctx, "GET", response.KillURI, nil, true, nil, nil) }, Kill, ) } @@ -551,7 +571,7 @@ func (c *APIClient) CloseQuery(ctx context.Context, response *QueryResponse) err ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() _ = c.doRetry(func() error { - return c.doRequest(ctx, "GET", response.FinalURI, nil, nil, nil) + return c.doRequest(ctx, "GET", response.FinalURI, nil, true, nil, nil) }, Final, ) } @@ -723,6 +743,14 @@ func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation return nil } +func (c *APIClient) Logout(ctx context.Context) error { + if c.NeedKeepAlive() { + req := &struct{}{} + return c.doRequest(ctx, "POST", "/v1/session/logout/", req, c.NeedSticky(), nil, nil) + } + return nil +} + func randRouteHint() string { charset := "abcdef0123456789" b := make([]byte, 16) diff --git a/connection.go b/connection.go index 69a667e..2075fd6 100644 --- a/connection.go +++ b/connection.go @@ -78,7 +78,7 @@ func (dc *DatabendConn) BeginTx( } func (dc *DatabendConn) cleanup() { - // must flush log buffer while the process is running. + dc.rest.Logout(dc.ctx) dc.rest = nil dc.cfg = nil } diff --git a/const.go b/const.go index 5d840d2..6c69874 100644 --- a/const.go +++ b/const.go @@ -6,6 +6,7 @@ const ( DatabendQueryIDHeader = "X-DATABEND-QUERY-ID" DatabendRouteHintHeader = "X-DATABEND-ROUTE-HINT" DatabendQueryIDNode = "X-DATABEND-NODE-ID" + DatabendQueryStickyNode = "X-DATABEND-STICKY-NODE" Authorization = "Authorization" WarehouseRoute = "X-DATABEND-ROUTE" UserAgent = "User-Agent" diff --git a/cookie_jar.go b/cookie_jar.go new file mode 100644 index 0000000..e6cee0c --- /dev/null +++ b/cookie_jar.go @@ -0,0 +1,36 @@ +package godatabend + +import ( + "net/http" + "net/url" + "sync" +) + +type IgnoreDomainCookieJar struct { + mu sync.Mutex + cookies map[string]*http.Cookie +} + +func NewIgnoreDomainCookieJar() *IgnoreDomainCookieJar { + return &IgnoreDomainCookieJar{ + cookies: make(map[string]*http.Cookie), + } +} + +func (jar *IgnoreDomainCookieJar) SetCookies(_u *url.URL, cookies []*http.Cookie) { + jar.mu.Lock() + defer jar.mu.Unlock() + for _, cookie := range cookies { + jar.cookies[cookie.Name] = cookie + } +} + +func (jar *IgnoreDomainCookieJar) Cookies(u *url.URL) []*http.Cookie { + jar.mu.Lock() + defer jar.mu.Unlock() + result := make([]*http.Cookie, 0, len(jar.cookies)) + for _, cookie := range jar.cookies { + result = append(result, cookie) + } + return result +} diff --git a/query.go b/query.go index 7433abd..8fb9b5a 100644 --- a/query.go +++ b/query.go @@ -104,7 +104,9 @@ type SessionState struct { Settings map[string]string `json:"settings,omitempty"` // txn - TxnState TxnState `json:"txn_state,omitempty"` // "Active", "AutoCommit" + TxnState TxnState `json:"txn_state,omitempty"` // "Active", "AutoCommit" + NeedSticky bool `json:"need_sticky,omitempty"` + NeedKeepAlive bool `json:"need_keep_alive,omitempty"` } type StageAttachmentConfig struct { diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 1493207..d3ede21 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -6,7 +6,7 @@ services: volumes: - ./data:/data databend: - image: datafuselabs/databend + image: datafuselabs/databend:nightly environment: - QUERY_DEFAULT_USER=databend - QUERY_DEFAULT_PASSWORD=databend diff --git a/tests/session_test.go b/tests/session_test.go index c6537dd..55eb57a 100644 --- a/tests/session_test.go +++ b/tests/session_test.go @@ -1,6 +1,7 @@ package tests import ( + "context" "database/sql" "fmt" "github.com/stretchr/testify/require" @@ -95,3 +96,32 @@ func (s *DatabendTestSuite) TestSessionVariable() { r.Nil(err) r.Equal(int64(100), result) } + +func (s *DatabendTestSuite) TestTempTable() { + r := require.New(s.T()) + + var result int64 + ctx := context.Background() + conn, err := s.db.Conn(ctx) + defer func() { + err = conn.Close() + r.Nil(err) + }() + _, err = conn.ExecContext(ctx, "create temp table t_temp (a int64)") + r.Nil(err) + _, err = conn.ExecContext(ctx, "insert into t_temp values (1), (2)") + r.Nil(err) + rows, err := conn.QueryContext(ctx, "select * from t_temp") + r.Nil(err) + defer rows.Close() + + r.True(rows.Next()) + err = rows.Scan(&result) + r.Equal(int64(1), result) + + r.True(rows.Next()) + err = rows.Scan(&result) + r.Equal(int64(2), result) + + r.False(rows.Next()) +}