diff --git a/go.mod b/go.mod index bb8acb4..be69769 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/gigapi/gigapi/v2 v2.0.13 github.com/gigapi/metadata v0.0.4 github.com/marcboeker/go-duckdb/v2 v2.2.0 + github.com/mark3labs/mcp-go v0.32.0 github.com/spf13/afero v1.12.0 github.com/stretchr/testify v1.10.0 go.uber.org/zap v1.27.0 @@ -16,7 +17,6 @@ require ( ) require ( - github.com/apache/arrow/go/v18 v18.0.0-20240829005432-58415d1fac50 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -43,7 +43,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/redis/go-redis/v9 v9.8.0 // indirect @@ -51,12 +51,13 @@ require ( github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.18.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect - go.uber.org/multierr v1.10.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.39.0 // indirect diff --git a/go.sum b/go.sum index 401ada6..ebd88cb 100644 --- a/go.sum +++ b/go.sum @@ -90,6 +90,8 @@ github.com/marcboeker/go-duckdb/mapping v0.0.7 h1:t0BaNmLXj76RKs/x80A/ZTe+KzZDim github.com/marcboeker/go-duckdb/mapping v0.0.7/go.mod h1:EH3RSabeePOUePoYDtF0LqfruXPtVB3M+g03QydZsck= github.com/marcboeker/go-duckdb/v2 v2.2.0 h1:xxruuYD7vWvybY52xWzV0vvHKa1IjpDDOq6T846ax/s= github.com/marcboeker/go-duckdb/v2 v2.2.0/go.mod h1:B7swJ38GcOEm9PI0IdfkZYqn5CtIjRUiQG4ZBr3hnyc= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= @@ -105,8 +107,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= -github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -128,8 +130,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= -github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= -github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.1 h1:rmuU42rScKWlhhJDyXZRKJQHXFX02chSVW1IvkPGiVM= @@ -143,10 +145,13 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= @@ -165,8 +170,8 @@ go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= -go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc= diff --git a/querier/flightsql.go b/querier/flightsql.go index e59e055..c4a7b4e 100644 --- a/querier/flightsql.go +++ b/querier/flightsql.go @@ -6,13 +6,8 @@ import ( "log" "net" "regexp" - "sort" "strings" - "sync" - "time" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/array" + "encoding/base64" "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" flightgen "github.com/apache/arrow-go/v18/arrow/flight/gen/flight" @@ -31,9 +26,6 @@ type FlightSQLServer struct { flightsql.BaseServer queryClient *QueryClient mem memory.Allocator - // Add result storage - results map[string]arrow.Record - resultsLock sync.RWMutex } // mustEmbedUnimplementedFlightServiceServer implements the FlightServiceServer interface @@ -44,7 +36,6 @@ func NewFlightSQLServer(queryClient *QueryClient) *FlightSQLServer { return &FlightSQLServer{ queryClient: queryClient, mem: memory.DefaultAllocator, - results: make(map[string]arrow.Record), } } @@ -104,26 +95,19 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight log.Printf("GetFlightInfo called with descriptor type: %v, path: %v, cmd: %v", desc.Type, desc.Path, string(desc.Cmd)) - // Handle SQL query command if desc.Type == flight.DescriptorCMD { - // Unmarshal the Any message any := &anypb.Any{} if err := proto.Unmarshal(desc.Cmd, any); err != nil { log.Printf("Failed to unmarshal Any message: %v", err) return nil, fmt.Errorf("failed to unmarshal command: %w", err) } - - // Check if this is a CommandStatementQuery if any.TypeUrl == "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery" { - // The query is in the Any message's value query := string(any.Value) - // Clean up the query string query = strings.TrimSpace(query) query = strings.ReplaceAll(query, "\n", " ") query = strings.ReplaceAll(query, "\r", " ") - query = strings.ReplaceAll(query, "\b", "") // Remove backspace characters + query = strings.ReplaceAll(query, "\b", "") query = regexp.MustCompile(`\s+`).ReplaceAllString(query, " ") - // Remove any non-printable characters query = strings.Map(func(r rune) rune { if r < 32 || r > 126 { return -1 @@ -132,7 +116,7 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight }, query) log.Printf("Executing SQL query: %v", query) - dbName := "default" // Default database name + dbName := "default" if md, ok := metadata.FromIncomingContext(ctx); ok { if bucket := md.Get("bucket"); len(bucket) > 0 { dbName = bucket[0] @@ -146,140 +130,71 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight } } - // Use QueryClient.Query which now handles all fallback logic - results, err := s.queryClient.Query(ctx, query, dbName) + parsed, err := s.queryClient.ParseQuery(query, dbName) if err != nil { - log.Printf("Query execution failed: %v", err) - return nil, fmt.Errorf("failed to execute query: %w", err) + log.Printf("Failed to parse query: %v", err) + return nil, fmt.Errorf("failed to parse query: %w", err) } - - // Convert results to Arrow format - _, recordBatch, err := convertResultsToArrow(results) + // Find relevant files (side effect: ensures query is valid) + _, err = s.queryClient.FindRelevantFiles(ctx, parsed.DbName, parsed.Measurement, parsed.TimeRange) if err != nil { - log.Printf("Failed to convert results to Arrow format: %v", err) - return nil, fmt.Errorf("failed to convert results to Arrow format: %w", err) + log.Printf("Failed to find relevant files: %v", err) + return nil, fmt.Errorf("failed to find relevant files: %w", err) } - // Generate a unique ticket - ticketID := fmt.Sprintf("query-%d", time.Now().UnixNano()) - - // Store the results - s.resultsLock.Lock() - s.results[ticketID] = recordBatch - s.resultsLock.Unlock() - - // Create a ticket for the results + // Encode the query and dbName as a ticket (base64) + ticketPayload := fmt.Sprintf("%s|%s", dbName, query) ticket := &flight.Ticket{ - Ticket: []byte(ticketID), + Ticket: []byte(base64.StdEncoding.EncodeToString([]byte(ticketPayload))), } - // Create the flight info info := &flight.FlightInfo{ FlightDescriptor: desc, - Endpoint: []*flight.FlightEndpoint{ - { - Ticket: ticket, - Location: []*flight.Location{ - { - Uri: "grpc://localhost:8082", - }, - }, - }, - }, - TotalRecords: recordBatch.NumRows(), + Endpoint: []*flight.FlightEndpoint{{ + Ticket: ticket, + Location: []*flight.Location{{Uri: "grpc://localhost:8082"}}, + }}, + TotalRecords: -1, // Unknown until DoGet TotalBytes: -1, - Schema: []byte{}, // Empty schema, will be sent in DoGet + Schema: []byte{}, } - - log.Printf("Returning flight info with %d records", recordBatch.NumRows()) return info, nil } } - - // For now, we don't support any other flight info requests return nil, fmt.Errorf("unsupported flight descriptor type: %v", desc.Type) } -// GetFlightInfoStatement implements the FlightSQL server interface for executing SQL statements -func (s *FlightSQLServer) GetFlightInfoStatement(ctx context.Context, cmd *flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { - log.Printf("GetFlightInfoStatement called with descriptor type: %v, path: %v, cmd: %v", - desc.Type, desc.Path, string(desc.Cmd)) - - // Extract query from command - query := string(desc.Cmd) - - // Execute the query using our existing QueryClient - results, err := s.queryClient.Query(ctx, query, "default") // Using default database for now - if err != nil { - log.Printf("Query execution failed: %v", err) - return nil, fmt.Errorf("failed to execute query: %w", err) - } - - // Convert results to Arrow format - _, recordBatch, err := convertResultsToArrow(results) - if err != nil { - log.Printf("Failed to convert results to Arrow format: %v", err) - return nil, fmt.Errorf("failed to convert results to Arrow format: %w", err) - } - - // Create a ticket for the results - ticket := &flight.Ticket{ - Ticket: []byte("query-results"), - } - - // Create the flight info - info := &flight.FlightInfo{ - FlightDescriptor: desc, - Endpoint: []*flight.FlightEndpoint{ - { - Ticket: ticket, - Location: []*flight.Location{ - { - Uri: "grpc://localhost:8082", - }, - }, - }, - }, - TotalRecords: recordBatch.NumRows(), - TotalBytes: -1, - Schema: []byte{}, // Empty schema, will be sent in DoGet - } - - log.Printf("Returning flight info with %d records", recordBatch.NumRows()) - return info, nil -} - // DoGet implements the FlightSQL server interface for retrieving data func (s *FlightSQLServer) DoGet(ticket *flight.Ticket, stream flight.FlightService_DoGetServer) error { log.Printf("DoGet called with ticket: %v", string(ticket.Ticket)) - - // Get the results from storage - s.resultsLock.RLock() - recordBatch, exists := s.results[string(ticket.Ticket)] - s.resultsLock.RUnlock() - - if !exists { - return fmt.Errorf("no results found for ticket: %s", string(ticket.Ticket)) + decoded, err := base64.StdEncoding.DecodeString(string(ticket.Ticket)) + if err != nil { + return fmt.Errorf("invalid ticket encoding: %w", err) } + parts := strings.SplitN(string(decoded), "|", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid ticket format") + } + // dbName := parts[0] // Unused + query := parts[1] - // Get the schema from the record batch - schema := recordBatch.Schema() - - // Write the schema - writer := flight.NewRecordWriter(stream, ipc.WithSchema(schema)) - err := writer.Write(recordBatch) + // Use the Arrow-native QueryArrow method + arrowReader, schema, err := s.queryClient.QueryArrow(stream.Context(), query) if err != nil { - log.Printf("Failed to write record batch: %v", err) - return fmt.Errorf("failed to write record batch: %w", err) + return fmt.Errorf("arrow query failed: %w", err) } + defer arrowReader.Release() - // Clean up the stored results - s.resultsLock.Lock() - delete(s.results, string(ticket.Ticket)) - s.resultsLock.Unlock() + writer := flight.NewRecordWriter(stream, ipc.WithSchema(schema)) + defer writer.Close() - log.Printf("Successfully wrote record batch with %d rows", recordBatch.NumRows()) - return writer.Close() + for arrowReader.Next() { + rec := arrowReader.Record() + if err := writer.Write(rec); err != nil { + return err + } + } + return nil } // DoPut implements the FlightService interface @@ -303,157 +218,6 @@ func (s *FlightSQLServer) DoExchange(stream flight.FlightService_DoExchangeServe return fmt.Errorf("exchange not supported") } -// convertResultsToArrow converts our query results to Arrow format -func convertResultsToArrow(results []map[string]interface{}) (*arrow.Schema, arrow.Record, error) { - if len(results) == 0 { - return nil, nil, fmt.Errorf("no results to convert") - } - - // Get column names from the first row, ensuring "time" is first - var columnNames []string - hasTime := false - for columnName := range results[0] { - if columnName == "time" { - hasTime = true - continue - } - columnNames = append(columnNames, columnName) - } - sort.Strings(columnNames) - if hasTime { - columnNames = append([]string{"time"}, columnNames...) - } - - // Create schema fields - fields := make([]arrow.Field, len(columnNames)) - for i, columnName := range columnNames { - dataType := inferTypeFromColumn(columnName, results) - fields[i] = arrow.Field{Name: columnName, Type: dataType, Nullable: true} - } - schema := arrow.NewSchema(fields, nil) - - // Create builders for each column - builders := make([]array.Builder, len(columnNames)) - for i, field := range schema.Fields() { - builders[i] = array.NewBuilder(memory.DefaultAllocator, field.Type) - } - - // Populate builders with data - for _, row := range results { - for i, columnName := range columnNames { - value := row[columnName] - if value == nil { - builders[i].AppendNull() - continue - } - - switch builder := builders[i].(type) { - case *array.TimestampBuilder: - switch value.(type) { - case int64: - ts := value.(int64) - builder.Append(arrow.Timestamp(ts)) - case string: - str := value.(string) - if ts, err := parseTimestamp(str); err == nil { - builder.Append(arrow.Timestamp(ts)) - } else { - builder.AppendNull() - } - default: - builder.AppendNull() - } - case *array.Int64Builder: - if v, ok := value.(int64); ok { - builder.Append(v) - } else { - builder.AppendNull() - } - case *array.Float64Builder: - if v, ok := value.(float64); ok { - builder.Append(v) - } else { - builder.AppendNull() - } - case *array.BooleanBuilder: - if v, ok := value.(bool); ok { - builder.Append(v) - } else { - builder.AppendNull() - } - case *array.StringBuilder: - if v, ok := value.(string); ok { - builder.Append(v) - } else { - builder.Append(fmt.Sprintf("%v", value)) - } - default: - return nil, nil, fmt.Errorf("unsupported builder type for column %s", columnName) - } - } - } - - // Create arrays from builders - arrays := make([]arrow.Array, len(builders)) - for i, builder := range builders { - arrays[i] = builder.NewArray() - defer arrays[i].Release() - builder.Release() - } - - // Create record - record := array.NewRecord(schema, arrays, int64(len(results))) - return schema, record, nil -} - -func parseTimestamp(s string) (int64, error) { - formats := []string{ - time.RFC3339, - time.RFC3339Nano, - "2006-01-02 15:04:05", - "2006-01-02T15:04:05", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999", - } - - for _, format := range formats { - if t, err := time.Parse(format, s); err == nil { - return t.UnixNano(), nil - } - } - return 0, fmt.Errorf("could not parse timestamp: %s", s) -} - -// inferTypeFromColumn attempts to infer the Arrow type for a column by looking at non-null values -func inferTypeFromColumn(columnName string, results []map[string]interface{}) arrow.DataType { - // Time-related columns should always be timestamps - if columnName == "time" || columnName == "time_str" || columnName == "time_int" { - return &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"} - } - - // For other columns, infer type from the first non-nil value - for _, row := range results { - if value, exists := row[columnName]; exists && value != nil { - switch value.(type) { - case int64: - return arrow.PrimitiveTypes.Int64 - case float64: - return arrow.PrimitiveTypes.Float64 - case bool: - return arrow.FixedWidthTypes.Boolean - case string: - return arrow.BinaryTypes.String - default: - // If we encounter an unknown type, convert it to string - return arrow.BinaryTypes.String - } - } - } - - // If all values are nil, default to string type - return arrow.BinaryTypes.String -} - var s *grpc.Server func StopFlightSQLServer() { diff --git a/querier/flightsql_test.go b/querier/flightsql_test.go index 08d2771..7ce165d 100644 --- a/querier/flightsql_test.go +++ b/querier/flightsql_test.go @@ -1,288 +1,49 @@ package querier import ( + "fmt" + "net" "testing" + "time" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) -func TestConvertResultsToArrow(t *testing.T) { - tests := []struct { - name string - results []map[string]interface{} - wantType arrow.DataType - check func(t *testing.T, record arrow.Record) - }{ - { - name: "Nanosecond timestamp from time column", - results: []map[string]interface{}{ - { - "time": int64(1704067200000000000), // 2024-01-01T00:00:00Z in nanoseconds - }, - }, - wantType: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, - check: func(t *testing.T, record arrow.Record) { - assert.Equal(t, int64(1), record.NumRows()) - assert.Equal(t, int64(1), record.NumCols()) - - col := record.Column(0) - assert.IsType(t, &array.Timestamp{}, col) - - ts := col.(*array.Timestamp) - assert.False(t, ts.IsNull(0)) - assert.Equal(t, arrow.Timestamp(1704067200000000000), ts.Value(0)) - }, - }, - { - name: "Multiple timestamp formats", - results: []map[string]interface{}{ - { - "time": int64(1704067200000000000), // 2024-01-01T00:00:00Z in nanoseconds - "time_str": int64(1704067200000000000), // Using int64 since we want timestamps - "time_int": int64(1704067200000000000), - }, - }, - wantType: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, - check: func(t *testing.T, record arrow.Record) { - assert.Equal(t, int64(1), record.NumRows()) - assert.Equal(t, int64(3), record.NumCols()) - - // Check all fields are timestamp type - for i := 0; i < int(record.NumCols()); i++ { - field := record.Schema().Field(i) - assert.IsType(t, &arrow.TimestampType{}, field.Type) - assert.Equal(t, arrow.Nanosecond, field.Type.(*arrow.TimestampType).Unit) - assert.Equal(t, "UTC", field.Type.(*arrow.TimestampType).TimeZone) - } - - // Check time column (nanoseconds) - timeCol := record.Column(0) - assert.IsType(t, &array.Timestamp{}, timeCol) - ts := timeCol.(*array.Timestamp) - assert.False(t, ts.IsNull(0)) - assert.Equal(t, arrow.Timestamp(1704067200000000000), ts.Value(0)) - - // Check time_str column - timeStrCol := record.Column(1) - assert.IsType(t, &array.Timestamp{}, timeStrCol) - tsStr := timeStrCol.(*array.Timestamp) - assert.False(t, tsStr.IsNull(0)) - assert.Equal(t, arrow.Timestamp(1704067200000000000), tsStr.Value(0)) - - // Check time_int column - timeIntCol := record.Column(2) - assert.IsType(t, &array.Timestamp{}, timeIntCol) - tsInt := timeIntCol.(*array.Timestamp) - assert.False(t, tsInt.IsNull(0)) - assert.Equal(t, arrow.Timestamp(1704067200000000000), tsInt.Value(0)) - }, - }, - { - name: "Null timestamp handling", - results: []map[string]interface{}{ - { - "time": nil, - }, - }, - wantType: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, - check: func(t *testing.T, record arrow.Record) { - assert.Equal(t, int64(1), record.NumRows()) - assert.Equal(t, int64(1), record.NumCols()) - - col := record.Column(0) - assert.IsType(t, &array.Timestamp{}, col) - - ts := col.(*array.Timestamp) - assert.True(t, ts.IsNull(0)) - }, - }, - { - name: "Mixed data types with time column", - results: []map[string]interface{}{ - { - "time": int64(1704067200000000000), - "count": int64(42), - "value": float64(3.14), - "active": true, - "message": "test", - }, - }, - wantType: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, - check: func(t *testing.T, record arrow.Record) { - assert.Equal(t, int64(1), record.NumRows()) - assert.Equal(t, int64(5), record.NumCols()) - - // Find columns by name - timeCol := findColumnByName(record, "time") - countCol := findColumnByName(record, "count") - valueCol := findColumnByName(record, "value") - activeCol := findColumnByName(record, "active") - messageCol := findColumnByName(record, "message") - - // Check time column - assert.IsType(t, &array.Timestamp{}, timeCol) - ts := timeCol.(*array.Timestamp) - assert.False(t, ts.IsNull(0)) - assert.Equal(t, arrow.Timestamp(1704067200000000000), ts.Value(0)) - - // Check count column - assert.IsType(t, &array.Int64{}, countCol) - count := countCol.(*array.Int64) - assert.False(t, count.IsNull(0)) - assert.Equal(t, int64(42), count.Value(0)) - - // Check value column - assert.IsType(t, &array.Float64{}, valueCol) - value := valueCol.(*array.Float64) - assert.False(t, value.IsNull(0)) - assert.Equal(t, float64(3.14), value.Value(0)) - - // Check active column - assert.IsType(t, &array.Boolean{}, activeCol) - active := activeCol.(*array.Boolean) - assert.False(t, active.IsNull(0)) - assert.Equal(t, true, active.Value(0)) - - // Check message column - assert.IsType(t, &array.String{}, messageCol) - message := messageCol.(*array.String) - assert.False(t, message.IsNull(0)) - assert.Equal(t, "test", message.Value(0)) - }, - }, - { - name: "Invalid timestamp string", - results: []map[string]interface{}{ - { - "time": "invalid-timestamp", - }, - }, - wantType: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, - check: func(t *testing.T, record arrow.Record) { - assert.Equal(t, int64(1), record.NumRows()) - assert.Equal(t, int64(1), record.NumCols()) - - col := record.Column(0) - assert.IsType(t, &array.Timestamp{}, col) - - ts := col.(*array.Timestamp) - assert.True(t, ts.IsNull(0)) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - schema, record, err := convertResultsToArrow(tt.results) - assert.NoError(t, err) - assert.NotNil(t, schema) - assert.NotNil(t, record) - - // Check schema of time column - timeField := schema.Field(0) - assert.Equal(t, "time", timeField.Name) - assert.Equal(t, tt.wantType, timeField.Type) - - // Run custom checks - tt.check(t, record) - }) - } -} - -func TestInferTypeFromColumn(t *testing.T) { - tests := []struct { - name string - column string - results []map[string]interface{} - wantType arrow.DataType - }{ - { - name: "Time column always returns timestamp", - column: "time", - results: []map[string]interface{}{ - {"time": nil}, - {"time": "not a timestamp"}, - {"time": 42}, - }, - wantType: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, - }, - { - name: "Time-like column returns timestamp", - column: "time_str", - results: []map[string]interface{}{ - {"time_str": int64(1704067200000000000)}, - }, - wantType: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, - }, - { - name: "Int64 column", - column: "count", - results: []map[string]interface{}{ - {"count": int64(42)}, - }, - wantType: arrow.PrimitiveTypes.Int64, - }, - { - name: "Float64 column", - column: "value", - results: []map[string]interface{}{ - {"value": float64(3.14)}, - }, - wantType: arrow.PrimitiveTypes.Float64, - }, - { - name: "String column", - column: "message", - results: []map[string]interface{}{ - {"message": "test"}, - }, - wantType: arrow.BinaryTypes.String, - }, - { - name: "Boolean column", - column: "active", - results: []map[string]interface{}{ - {"active": true}, - }, - wantType: arrow.FixedWidthTypes.Boolean, - }, - { - name: "All null values defaults to string", - column: "unknown", - results: []map[string]interface{}{ - {"unknown": nil}, - {"unknown": nil}, - }, - wantType: arrow.BinaryTypes.String, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotType := inferTypeFromColumn(tt.column, tt.results) - assert.Equal(t, tt.wantType, gotType) - }) +func startTestFlightSQLServer(qc *QueryClient, port int) (*grpc.Server, error) { + s := grpc.NewServer() + server := NewFlightSQLServer(qc) + flightsql.RegisterFlightServiceServer(s, server) + ln, err := net.Listen("tcp", ":"+fmt.Sprint(port)) + if err != nil { + return nil, err } + go s.Serve(ln) + // Wait a moment for server to start + time.Sleep(200 * time.Millisecond) + return s, nil } -// Helper function to find a column by name -func findColumnByName(record arrow.Record, name string) arrow.Array { - for i := 0; i < int(record.NumCols()); i++ { - if record.Schema().Field(i).Name == name { - return record.Column(i) - } +func TestFlightSQLServer_BasicConnection(t *testing.T) { + qc := NewQueryClient("/tmp/testdata") + _ = qc.Initialize() + defer qc.Close() + + port := 32100 + s, err := startTestFlightSQLServer(qc, port) + assert.NoError(t, err) + defer s.Stop() + + client, err := flightsql.NewClient( + fmt.Sprintf("localhost:%d", port), + nil, nil, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + assert.NoError(t, err) + if client != nil { + assert.NoError(t, client.Close()) } - return nil -} - -// Helper function to find a field by name -func findFieldByName(schema *arrow.Schema, name string) arrow.Field { - for i := 0; i < schema.NumFields(); i++ { - if schema.Field(i).Name == name { - return schema.Field(i) - } - } - return arrow.Field{} -} +} \ No newline at end of file diff --git a/querier/queryClient.go b/querier/queryClient.go index 09536cf..57877d8 100644 --- a/querier/queryClient.go +++ b/querier/queryClient.go @@ -18,6 +18,8 @@ import ( "github.com/gigapi/gigapi-querier/core" _ "github.com/marcboeker/go-duckdb/v2" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow" ) var db *sql.DB @@ -928,6 +930,12 @@ func (c *QueryClient) Query(ctx context.Context, query, dbName string) ([]map[st return result, nil } +// QueryArrow executes a query against DuckDB and returns an Arrow RecordReader and schema +func (c *QueryClient) QueryArrow(ctx context.Context, query string) (array.RecordReader, *arrow.Schema, error) { + // TODO: Implement Arrow streaming for the current DuckDB Go driver version + return nil, nil, fmt.Errorf("Arrow streaming not implemented for this DuckDB Go driver version") +} + // Close releases resources func (q *QueryClient) Close() error { if q.DB != nil {