Skip to content
Open
21 changes: 21 additions & 0 deletions runtime/drivers/athena/olap.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@ func (c *Connection) Query(ctx context.Context, stmt *drivers.Statement) (*drive
}, nil
}

// Head returns up to limit rows from the given table. A non-positive limit
// omits the LIMIT clause entirely, i.e. all rows are returned.
func (c *Connection) Head(ctx context.Context, db, schema, table string, limit int64) (*drivers.Result, error) {
	meta, err := c.InformationSchema().Lookup(ctx, db, schema, table)
	if err != nil {
		return nil, err
	}

	// Select columns explicitly (escaped) rather than using SELECT *.
	dialect := c.Dialect()
	cols := make([]string, 0, len(meta.Schema.Fields))
	for _, f := range meta.Schema.Fields {
		cols = append(cols, dialect.EscapeIdentifier(f.Name))
	}

	sql := fmt.Sprintf("SELECT %s FROM %s", strings.Join(cols, ", "), dialect.EscapeTable(db, schema, table))
	if limit > 0 {
		sql += fmt.Sprintf(" LIMIT %d", limit)
	}

	return c.Query(ctx, &drivers.Statement{Query: sql})
}

// QuerySchema implements drivers.OLAPStore.
func (c *Connection) QuerySchema(ctx context.Context, query string, args []any) (*runtimev1.StructType, error) {
return nil, drivers.ErrNotImplemented
Expand Down
15 changes: 11 additions & 4 deletions runtime/drivers/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,19 @@ var spec = drivers.Spec{
},
},
ImplementsWarehouse: true,
ImplementsOLAP: true,
}

type driver struct{}

// configProperties holds the decoded connector configuration for BigQuery.
// Fix: the extracted diff left both the old and new field declarations in
// place, duplicating SecretJSON/ProjectID/AllowHostAccess; this is the
// resolved (post-change) struct with each field declared exactly once.
type configProperties struct {
	// SecretJSON holds Google application credentials as a JSON string.
	SecretJSON string `mapstructure:"google_application_credentials"`
	// ProjectID is the Google Cloud project to bill/run queries in.
	ProjectID string `mapstructure:"project_id"`
	// MaxBytesBilled is the maximum number of bytes billed for a query. This is a safety mechanism to prevent accidentally running large queries.
	// Set this to 0 for project defaults.
	// Only applies to dashboard queries and does not apply when ingesting data from BigQuery into Rill.
	MaxBytesBilled int64 `mapstructure:"max_bytes_billed"`
	// AllowHostAccess — NOTE(review): presumably permits falling back to
	// host-level credentials/environment; confirm against driver docs.
	AllowHostAccess bool `mapstructure:"allow_host_access"`
	// LogQueries controls whether to log the raw SQL passed to OLAP.
	LogQueries bool `mapstructure:"log_queries"`
}
Expand All @@ -64,7 +69,9 @@ func (d driver) Open(_, instanceID string, config map[string]any, st *storage.Cl
return nil, errors.New("bigquery driver can't be shared")
}

conf := &configProperties{}
conf := &configProperties{
MaxBytesBilled: 0, // 0 defaults to project default set directly in BigQuery
}
err := mapstructure.WeakDecode(config, conf)
if err != nil {
return nil, err
Expand Down
152 changes: 115 additions & 37 deletions runtime/drivers/bigquery/olap.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"math/big"
"reflect"
"strings"
"time"

"cloud.google.com/go/bigquery"
Expand Down Expand Up @@ -67,8 +69,19 @@ func (c *Connection) Query(ctx context.Context, stmt *drivers.Statement) (res *d
return nil, err
}

wrapMaxBytesBilledError := func(err error, maxBytesBilled int64) error {
if err != nil && maxBytesBilled >= 0 && strings.Contains(strings.ToLower(err.Error()), "query exceeded limit for bytes billed") {
return fmt.Errorf("bigquery query exceeds configured max_bytes_billed limit (%d bytes). Increase `max_bytes_billed` in connector config or set it to -1 to disable the limit: %w", maxBytesBilled, err)
}
return err
}

q := client.Query(stmt.Query)
q.Parameters = make([]bigquery.QueryParameter, len(stmt.Args))
q.ConnectionProperties = []*bigquery.ConnectionProperty{
{Key: "time_zone", Value: "UTC"},
}
q.MaxBytesBilled = c.config.MaxBytesBilled
for i, arg := range stmt.Args {
q.Parameters[i] = bigquery.QueryParameter{
Value: arg,
Expand All @@ -79,43 +92,47 @@ func (c *Connection) Query(ctx context.Context, stmt *drivers.Statement) (res *d
// Can not use q.Read for dry run so must trigger the job and check status
j, err := q.Run(ctx)
if err != nil {
return nil, err
return nil, wrapMaxBytesBilledError(err, c.config.MaxBytesBilled)
}
// Dry run is not asynchronous so no need to call Wait
status := j.LastStatus()
return nil, status.Err()
}
it, err := q.Read(ctx)
if err != nil {
return nil, err
}
stats, ok := status.Statistics.Details.(*bigquery.QueryStatistics)
if !ok {
return nil, fmt.Errorf("unexpected statistics type")
}

// We use query.Read which can also use fast path when results are small.
// In fast path schema is only available after fetching the first row.
var firstRow []bigquery.Value
for i := 0; i < len(it.Schema); i++ {
firstRow = append(firstRow, new(bigquery.Value))
}
rowErr := it.Next(&firstRow)
if rowErr != nil && !errors.Is(rowErr, iterator.Done) {
return nil, err
// extract schema
schema, err := fromBQSchema(stats.Schema)
if err != nil {
return nil, err
}
res := &drivers.Result{
Schema: schema,
}
return res, wrapMaxBytesBilledError(status.Err(), c.config.MaxBytesBilled)
}
// schema is returned even if there are no rows
schema, err := fromBQSchema(it.Schema)
it, err := q.Read(ctx)
if err != nil {
return nil, err
return nil, wrapMaxBytesBilledError(err, c.config.MaxBytesBilled)
}
row := newRows(it, firstRow, errors.Is(rowErr, iterator.Done))
res = &drivers.Result{
Rows: row,
Schema: schema,
}
return res, nil
return toResult(it, math.MaxInt64)
}

// QuerySchema implements drivers.OLAPStore.
// It resolves the result schema of query by executing it as a dry run and
// returning only the schema from the result.
// Fix: the extracted diff left the stale removed line
// `return nil, drivers.ErrNotImplemented` interleaved with the new body;
// this is the resolved (post-change) implementation.
func (c *Connection) QuerySchema(ctx context.Context, query string, args []any) (*runtimev1.StructType, error) {
	ctx, cancel := context.WithTimeout(ctx, drivers.DefaultQuerySchemaTimeout)
	defer cancel()

	res, err := c.Query(ctx, &drivers.Statement{
		Query:  query,
		Args:   args,
		DryRun: true,
	})
	if err != nil {
		return nil, err
	}
	defer res.Close()
	return res.Schema, nil
}

// WithConnection implements drivers.OLAPStore.
Expand Down Expand Up @@ -169,7 +186,13 @@ func (c *Connection) Lookup(ctx context.Context, db, schema, name string) (*driv
return nil, fmt.Errorf("failed to get BigQuery client: %w", err)
}

table := client.Dataset(schema).Table(name)
var table *bigquery.Table
if db != "" {
table = client.DatasetInProject(db, schema).Table(name)
} else {
table = client.Dataset(schema).Table(name)
}

meta, err := table.Metadata(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get table metadata: %w", err)
Expand All @@ -178,29 +201,66 @@ func (c *Connection) Lookup(ctx context.Context, db, schema, name string) (*driv
if err != nil {
return nil, err
}
return &drivers.OlapTable{
tbl := &drivers.OlapTable{
Database: db,
DatabaseSchema: schema,
Name: name,
View: meta.Type == bigquery.ViewTable,
Schema: runtimeSchema,
UnsupportedCols: nil, // all columns are currently being mapped though may not be as specific as in BigQuery
PhysicalSizeBytes: 0,
}, nil
}
return tbl, nil
}

// Head returns up to limit rows from the given table by reading the table
// directly (no SQL query is issued). db is the project ID; schema is the
// dataset; table is the table name.
//
// Fix: mirror Lookup's handling of an empty db — fall back to the client's
// default project instead of calling DatasetInProject with an empty project,
// keeping the two methods consistent.
func (c *Connection) Head(ctx context.Context, db, schema, table string, limit int64) (*drivers.Result, error) {
	client, err := c.getClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get BigQuery client: %w", err)
	}

	var tbl *bigquery.Table
	if db != "" {
		tbl = client.DatasetInProject(db, schema).Table(table)
	} else {
		tbl = client.Dataset(schema).Table(table)
	}

	return toResult(tbl.Read(ctx), limit)
}

// toResult wraps a BigQuery row iterator in a drivers.Result that yields at
// most limit rows.
//
// The iterator's fast path only exposes the schema after the first row has
// been fetched, so one row is eagerly read here; newRows replays it before
// draining the rest of the iterator.
func toResult(it *bigquery.RowIterator, limit int64) (*drivers.Result, error) {
	buffered := make([]bigquery.Value, 0, len(it.Schema))
	for range it.Schema {
		buffered = append(buffered, new(bigquery.Value))
	}

	fetchErr := it.Next(&buffered)
	exhausted := errors.Is(fetchErr, iterator.Done)
	if fetchErr != nil && !exhausted {
		return nil, fetchErr
	}

	// The schema is populated even when the result has no rows.
	schema, err := fromBQSchema(it.Schema)
	if err != nil {
		return nil, err
	}

	return &drivers.Result{
		Rows:   newRows(it, buffered, exhausted, limit),
		Schema: schema,
	}, nil
}

// rows implements a driver row cursor over a BigQuery row iterator, capping
// the number of rows surfaced at limit.
// Fix: the extracted diff left both the old and new declarations of `ri` in
// place, duplicating the field; this is the resolved (post-change) struct.
type rows struct {
	ri    *bigquery.RowIterator
	limit int64 // maximum number of rows Next will yield

	// firstRow is the row eagerly fetched during query execution to obtain
	// the schema; it is replayed before reading further rows from ri.
	firstRow        []bigquery.Value
	canScanFirstRow bool

	scanned    int64            // rows yielded so far, compared against limit in Next
	lastRow    []bigquery.Value // last scanned row from ri in Next
	lastErr    error
	canScanRow bool
}

func newRows(ri *bigquery.RowIterator, firstRow []bigquery.Value, noRows bool) *rows {
func newRows(ri *bigquery.RowIterator, firstRow []bigquery.Value, noRows bool, limit int64) *rows {
if noRows {
return &rows{
lastErr: iterator.Done,
Expand All @@ -210,9 +270,10 @@ func newRows(ri *bigquery.RowIterator, firstRow []bigquery.Value, noRows bool) *
ri: ri,
firstRow: firstRow,
canScanFirstRow: true,
limit: limit,
}
r.lastRow = make([]bigquery.Value, len(firstRow))
for i := range len(firstRow) {
for i := range firstRow {
r.lastRow[i] = new(bigquery.Value)
}
return r
Expand Down Expand Up @@ -245,10 +306,17 @@ func (r *rows) MapScan(dest map[string]any) error {
return err
}
for i, col := range r.ri.Schema {
dest[col.Name], err = convertValue(r.ri.Schema[i], row[i])
v, err := convertValue(r.ri.Schema[i], row[i])
if err != nil {
return err
}
if val, ok := v.(sqldriver.Valuer); ok {
v, err = val.Value()
if err != nil {
return err
}
}
dest[col.Name] = v
}
return nil
}
Expand All @@ -259,10 +327,15 @@ func (r *rows) Next() bool {
return false
}

if r.scanned >= r.limit {
return false
}

// first row was already fetched during query execution to get schema
if r.canScanFirstRow {
r.canScanRow = true
r.canScanFirstRow = false
r.scanned++
return true
}

Expand All @@ -275,6 +348,7 @@ func (r *rows) Next() bool {
return false
}
r.canScanRow = true
r.scanned++
return true
}

Expand Down Expand Up @@ -356,11 +430,11 @@ func toPB(field *bigquery.FieldSchema) (*runtimev1.Type, error) {
case bigquery.TimestampFieldType:
t.Code = runtimev1.Type_CODE_TIMESTAMP
case bigquery.DateTimeFieldType:
t.Code = runtimev1.Type_CODE_STRING
t.Code = runtimev1.Type_CODE_TIMESTAMP
case bigquery.TimeFieldType:
t.Code = runtimev1.Type_CODE_STRING
case bigquery.DateFieldType:
t.Code = runtimev1.Type_CODE_STRING
t.Code = runtimev1.Type_CODE_DATE
case bigquery.BooleanFieldType:
t.Code = runtimev1.Type_CODE_BOOL
case bigquery.IntegerFieldType:
Expand All @@ -387,6 +461,10 @@ func convertValue(field *bigquery.FieldSchema, value bigquery.Value) (any, error
return val, nil
}

if _, ok := value.(sqldriver.Valuer); ok {
return value, nil
}

// Marshal ARRAY and RECORD types to JSON, since arrays/maps aren't
// valid driver.Value types.
out, err := json.Marshal(val)
Expand Down Expand Up @@ -418,11 +496,11 @@ func convertUnitType(field *bigquery.FieldSchema, value bigquery.Value) (any, er
case bigquery.TimestampFieldType:
return convertBasicType[time.Time](field, value)
case bigquery.DateFieldType:
return convertStringerType[civil.Date](field, value)
return value, nil // no conversion for civil.Date type
case bigquery.TimeFieldType:
return convertStringerType[civil.Time](field, value)
case bigquery.DateTimeFieldType:
return convertStringerType[civil.DateTime](field, value)
return value, nil // no conversion for civil.DateTime type
case bigquery.NumericFieldType:
return convertRationalType(field, value, bigquery.NumericString)
case bigquery.BigNumericFieldType:
Expand Down
Loading
Loading