Skip to content

Commit

Permalink
feat(catalog): Standardize Catalog create table function (#245)
Browse files Browse the repository at this point in the history
* standardize CreateTable

* update catalog impl

* add test for table.NewMetadata and AssignFresh* functions

* add docstrings for the new functions

* use proper type for return of With helpers

* fix lint, missing func

* Update catalog/catalog.go

Co-authored-by: Kevin Liu <kevinjqliu@users.noreply.github.com>

---------

Co-authored-by: Kevin Liu <kevinjqliu@users.noreply.github.com>
  • Loading branch information
zeroshade and kevinjqliu authored Jan 15, 2025
1 parent e7d5d8a commit 85238d2
Show file tree
Hide file tree
Showing 9 changed files with 511 additions and 40 deletions.
40 changes: 38 additions & 2 deletions catalog/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"fmt"
"maps"
"net/url"
"strings"

"github.com/apache/iceberg-go"
"github.com/apache/iceberg-go/table"
Expand Down Expand Up @@ -156,6 +157,10 @@ type Catalog interface {
// CatalogType returns the type of the catalog.
CatalogType() CatalogType

// CreateTable creates a new iceberg table in the catalog using the provided identifier
// and schema. Options may be supplied to set the location, partition spec, sort order,
// and custom properties.
CreateTable(ctx context.Context, identifier table.Identifier, schema *iceberg.Schema, opts ...createTableOpt) (*table.Table, error)
// ListTables returns a list of table identifiers in the catalog, with the returned
// identifiers containing the information required to load the table via that catalog.
ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error)
Expand Down Expand Up @@ -217,7 +222,6 @@ func getUpdatedPropsAndUpdateSummary(currentProps iceberg.Properties, removals [
if err := checkForOverlap(removals, updates); err != nil {
return nil, PropertiesUpdateSummary{}, err
}

var (
updatedProps = maps.Clone(currentProps)
removed = make([]string, 0, len(removals))
Expand All @@ -243,6 +247,38 @@ func getUpdatedPropsAndUpdateSummary(currentProps iceberg.Properties, removals [
Updated: updated,
Missing: iceberg.Difference(removals, removed),
}

return updatedProps, summary, nil
}

// createTableOpt is a functional option that mutates a createTableCfg.
// Values are produced by the With* helpers and accepted by the variadic
// opts parameter of CreateTable.
type createTableOpt func(*createTableCfg)

// createTableCfg collects the optional settings for creating a table.
// Each field is populated by the matching With* option; a field that is
// never set keeps its zero value.
type createTableCfg struct {
	// location is set by WithLocation, which strips trailing "/" characters.
	location string
	// partitionSpec is set by WithPartitionSpec; nil when no option was given.
	partitionSpec *iceberg.PartitionSpec
	// sortOrder is set by WithSortOrder.
	sortOrder table.SortOrder
	// properties is set by WithProperties; the caller's map is stored as-is.
	properties iceberg.Properties
}

// WithLocation returns an option that sets the table location. Any trailing
// "/" characters are stripped from location before it is stored.
func WithLocation(location string) createTableOpt {
	trimmed := strings.TrimRight(location, "/")

	return func(c *createTableCfg) {
		c.location = trimmed
	}
}

// WithPartitionSpec returns an option that records spec as the partition
// spec to use for the new table. The pointer is stored without copying.
func WithPartitionSpec(spec *iceberg.PartitionSpec) createTableOpt {
	apply := func(c *createTableCfg) {
		c.partitionSpec = spec
	}

	return apply
}

// WithSortOrder returns an option that records order as the sort order to
// use for the new table.
func WithSortOrder(order table.SortOrder) createTableOpt {
	apply := func(c *createTableCfg) {
		c.sortOrder = order
	}

	return apply
}

// WithProperties returns an option that records props as the custom
// properties for the new table. The map is stored as given, not cloned.
func WithProperties(props iceberg.Properties) createTableOpt {
	apply := func(c *createTableCfg) {
		c.properties = props
	}

	return apply
}
4 changes: 4 additions & 0 deletions catalog/glue.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ func (c *GlueCatalog) CatalogType() CatalogType {
return Glue
}

// CreateTable is not implemented for the Glue catalog. It always returns a
// nil table and a non-nil error; schema and opts are ignored.
func (c *GlueCatalog) CreateTable(ctx context.Context, identifier table.Identifier, schema *iceberg.Schema, opts ...createTableOpt) (*table.Table, error) {
	// Report the missing feature through the error return instead of
	// panicking, so callers can handle it like any other failure.
	return nil, fmt.Errorf("create table %v: not implemented for Glue Catalog", identifier)
}

// DropTable deletes an Iceberg table from the Glue catalog.
func (c *GlueCatalog) DropTable(ctx context.Context, identifier table.Identifier) error {
database, tableName, err := identifierToGlueTable(identifier)
Expand Down
66 changes: 28 additions & 38 deletions catalog/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,38 +134,6 @@ func (t *loadTableResponse) UnmarshalJSON(b []byte) (err error) {
return
}

// createTableOption is a functional option that mutates a createTableRequest
// before it is sent.
type createTableOption func(*createTableRequest)

// WithLocation returns an option that sets the request's Location to loc
// with any trailing "/" characters removed.
func WithLocation(loc string) createTableOption {
	trimmed := strings.TrimRight(loc, "/")

	return func(r *createTableRequest) {
		r.Location = trimmed
	}
}

// WithPartitionSpec returns an option that sets the request's PartitionSpec
// to spec.
func WithPartitionSpec(spec *iceberg.PartitionSpec) createTableOption {
	apply := func(r *createTableRequest) {
		r.PartitionSpec = spec
	}

	return apply
}

// WithWriteOrder returns an option that sets the request's WriteOrder to
// order.
func WithWriteOrder(order *table.SortOrder) createTableOption {
	apply := func(r *createTableRequest) {
		r.WriteOrder = order
	}

	return apply
}

// WithStageCreate returns an option that sets the request's StageCreate
// flag to true.
func WithStageCreate() createTableOption {
	apply := func(r *createTableRequest) {
		r.StageCreate = true
	}

	return apply
}

// WithProperties returns an option that sets the request's Props to props.
// The map is stored as given, not cloned.
func WithProperties(props iceberg.Properties) createTableOption {
	apply := func(r *createTableRequest) {
		r.Props = props
	}

	return apply
}

type createTableRequest struct {
Name string `json:"name"`
Schema *iceberg.Schema `json:"schema"`
Expand Down Expand Up @@ -700,18 +668,40 @@ func splitIdentForPath(ident table.Identifier) (string, string, error) {
return strings.Join(NamespaceFromIdent(ident), namespaceSeparator), TableNameFromIdent(ident), nil
}

func (r *RestCatalog) CreateTable(ctx context.Context, identifier table.Identifier, schema *iceberg.Schema, opts ...createTableOption) (*table.Table, error) {
func (r *RestCatalog) CreateTable(ctx context.Context, identifier table.Identifier, schema *iceberg.Schema, opts ...createTableOpt) (*table.Table, error) {
ns, tbl, err := splitIdentForPath(identifier)
if err != nil {
return nil, err
}

payload := createTableRequest{
Name: tbl,
Schema: schema,
}
var cfg createTableCfg
for _, o := range opts {
o(&payload)
o(&cfg)
}

freshSchema, err := iceberg.AssignFreshSchemaIDs(schema, nil)
if err != nil {
return nil, err
}

freshPartitionSpec, err := iceberg.AssignFreshPartitionSpecIDs(cfg.partitionSpec, schema, freshSchema)
if err != nil {
return nil, err
}

freshSortOrder, err := table.AssignFreshSortOrderIDs(cfg.sortOrder, schema, freshSchema)
if err != nil {
return nil, err
}

payload := createTableRequest{
Name: tbl,
Schema: freshSchema,
Location: cfg.location,
PartitionSpec: &freshPartitionSpec,
WriteOrder: &freshSortOrder,
StageCreate: false,
Props: cfg.properties,
}

ret, err := doPost[createTableRequest, loadTableResponse](ctx, r.baseURI, []string{"namespaces", ns, "tables"}, payload,
Expand Down
31 changes: 31 additions & 0 deletions partitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,34 @@ func (ps *PartitionSpec) PartitionType(schema *Schema) *StructType {
}
return &StructType{FieldList: nestedFields}
}

// AssignFreshPartitionSpecIDs builds a new PartitionSpec whose fields point
// at the fresh schema instead of the old one.
//
// For every partition field, the source column name is looked up in old by
// the field's SourceID, and that name is then resolved in fresh to obtain the
// new SourceID. Each field's FieldID is reassigned to partitionDataIDStart
// plus the field's position in the spec; Name and Transform are carried over.
//
// A nil spec yields an empty PartitionSpec and no error. An error is returned
// if a source column cannot be found in either schema.
func AssignFreshPartitionSpecIDs(spec *PartitionSpec, old, fresh *Schema) (PartitionSpec, error) {
	if spec == nil {
		return PartitionSpec{}, nil
	}

	remapped := make([]PartitionField, len(spec.fields))
	for i := range spec.fields {
		src := spec.fields[i]

		// Translate old source ID -> column name -> fresh field.
		colName, found := old.FindColumnName(src.SourceID)
		if !found {
			return PartitionSpec{}, fmt.Errorf("could not find field in old schema: %s", src.Name)
		}

		target, found := fresh.FindFieldByName(colName)
		if !found {
			return PartitionSpec{}, fmt.Errorf("could not find field in fresh schema: %s", src.Name)
		}

		remapped[i] = PartitionField{
			Name:      src.Name,
			SourceID:  target.ID,
			FieldID:   partitionDataIDStart + i,
			Transform: src.Transform,
		}
	}

	return NewPartitionSpec(remapped...), nil
}
160 changes: 160 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,79 @@ func visitField[T any](f NestedField, visitor SchemaVisitor[T]) T {
}
}

// PreOrderSchemaVisitor is implemented by visitors used with PreOrderVisit.
//
// Instead of already-computed child results, each method receives functions
// that visit the corresponding children when invoked, so the implementation
// decides when (and whether) the children are traversed.
type PreOrderSchemaVisitor[T any] interface {
	// Schema is called with the schema and a function that visits its
	// top-level struct.
	Schema(sc *Schema, structResult func() T) T
	// Struct is called with one function per entry of the struct's field
	// list, in field order.
	Struct(st StructType, fieldResults []func() T) T
	// Field is called with a function that visits the field's type.
	Field(field NestedField, fieldResult func() T) T
	// List is called with a function that visits the list's element field.
	List(list ListType, elemResult func() T) T
	// Map is called with functions that visit the key field and the value
	// field, in that order.
	Map(mapType MapType, keyResult, valueResult func() T) T
	// Primitive is called for any non-nested type.
	Primitive(p PrimitiveType) T
}

// PreOrderVisit walks the schema with the given pre-order visitor and
// returns whatever visitor.Schema returns.
//
// A nil schema yields an error wrapping ErrInvalidArgument. Any panic raised
// while visiting is recovered and converted into the returned error; in that
// case res is the zero value of T.
func PreOrderVisit[T any](sc *Schema, visitor PreOrderSchemaVisitor[T]) (res T, err error) {
	if sc == nil {
		return res, fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument)
	}

	// Named results are required so the deferred recover can set err.
	defer func() {
		if r := recover(); r != nil {
			switch e := r.(type) {
			case string:
				err = fmt.Errorf("error encountered during schema visitor: %s", e)
			case error:
				err = fmt.Errorf("error encountered during schema visitor: %w", e)
			default:
				// Panic values of any other type must not be swallowed;
				// otherwise the caller would see a zero result with a nil error.
				err = fmt.Errorf("error encountered during schema visitor: %v", e)
			}
		}
	}()

	return visitor.Schema(sc, func() T {
		return visitStructPreOrder(sc.AsStruct(), visitor)
	}), nil
}

// visitStructPreOrder builds one deferred visit function per field of obj,
// in field order, and hands them to visitor.Struct. No field is visited
// until the visitor invokes the corresponding function.
func visitStructPreOrder[T any](obj StructType, visitor PreOrderSchemaVisitor[T]) T {
	results := make([]func() T, len(obj.FieldList))

	for i := range obj.FieldList {
		// Bind the field to a per-iteration variable: the closures run after
		// the loop has finished, so they must not share a single loop
		// variable (which is what happens before Go 1.22).
		field := obj.FieldList[i]
		results[i] = func() T {
			return visitFieldPreOrder(field, visitor)
		}
	}

	return visitor.Struct(obj, results)
}

// visitListPreOrder calls visitor.List with a function that visits the
// list's element field when invoked.
func visitListPreOrder[T any](obj ListType, visitor PreOrderSchemaVisitor[T]) T {
	visitElement := func() T {
		return visitFieldPreOrder(obj.ElementField(), visitor)
	}

	return visitor.List(obj, visitElement)
}

// visitMapPreOrder calls visitor.Map with two functions: the first visits
// the map's key field and the second its value field, each only when invoked.
func visitMapPreOrder[T any](obj MapType, visitor PreOrderSchemaVisitor[T]) T {
	visitKey := func() T {
		return visitFieldPreOrder(obj.KeyField(), visitor)
	}
	visitValue := func() T {
		return visitFieldPreOrder(obj.ValueField(), visitor)
	}

	return visitor.Map(obj, visitKey, visitValue)
}

// visitFieldPreOrder calls visitor.Field with a function that, when invoked,
// dispatches on the field's type: struct, list and map types go to their
// respective pre-order helpers, and anything else is asserted to be a
// PrimitiveType and passed to visitor.Primitive.
func visitFieldPreOrder[T any](f NestedField, visitor PreOrderSchemaVisitor[T]) T {
	return visitor.Field(f, func() T {
		switch typ := f.Type.(type) {
		case *StructType:
			return visitStructPreOrder(*typ, visitor)
		case *ListType:
			return visitListPreOrder(*typ, visitor)
		case *MapType:
			return visitMapPreOrder(*typ, visitor)
		default:
			// The assertion panics for a non-primitive type.
			return visitor.Primitive(typ.(PrimitiveType))
		}
	})
}

// IndexByID performs a post-order traversal of the given schema and
// returns a mapping from field ID to field.
func IndexByID(schema *Schema) (map[int]NestedField, error) {
Expand Down Expand Up @@ -1069,6 +1142,93 @@ func buildAccessors(schema *Schema) (map[int]accessor, error) {
return Visit(schema, buildPosAccessors{})
}

// setFreshIDs is the pre-order visitor used by AssignFreshSchemaIDs to
// rebuild a schema's types with newly generated IDs.
type setFreshIDs struct {
	// oldIdToNew records, for every ID replaced via getAndInc, the new ID
	// that was assigned to it.
	oldIdToNew map[int]int
	// nextIDFunc supplies the next ID each time it is called.
	nextIDFunc func() int
}

// getAndInc pulls the next ID from nextIDFunc, records it as the replacement
// for currentID in oldIdToNew, and returns it.
func (s *setFreshIDs) getAndInc(currentID int) int {
	freshID := s.nextIDFunc()
	s.oldIdToNew[currentID] = freshID

	return freshID
}

// Schema visits the schema's top-level struct and returns its result
// unchanged; the schema argument itself is not used.
func (s *setFreshIDs) Schema(_ *Schema, structResult func() Type) Type {
	root := structResult()

	return root
}

// Struct returns a new *StructType whose fields copy Name, Doc and Required
// from st, with a fresh ID and a freshly visited Type for each field.
func (s *setFreshIDs) Struct(st StructType, fieldResults []func() Type) Type {
	rebuilt := make([]NestedField, len(st.FieldList))
	for i := range st.FieldList {
		orig := st.FieldList[i]

		// Order matters: the field takes its new ID first, and only then is
		// its type visited (which assigns IDs to anything nested inside it).
		id := s.getAndInc(orig.ID)
		typ := fieldResults[i]()

		rebuilt[i] = NestedField{
			ID:       id,
			Name:     orig.Name,
			Type:     typ,
			Doc:      orig.Doc,
			Required: orig.Required,
		}
	}

	return &StructType{FieldList: rebuilt}
}

// Field visits the field's type and returns the result unchanged; the field
// argument itself is not used here.
func (s *setFreshIDs) Field(_ NestedField, fieldResult func() Type) Type {
	typ := fieldResult()

	return typ
}

// List returns a new *ListType with a fresh element ID, the freshly visited
// element type, and the original ElementRequired flag.
func (s *setFreshIDs) List(list ListType, elemResult func() Type) Type {
	fresh := &ListType{ElementRequired: list.ElementRequired}

	// The element ID is assigned before the element type is visited.
	fresh.ElementID = s.getAndInc(list.ElementID)
	fresh.Element = elemResult()

	return fresh
}

// Map returns a new *MapType with fresh key and value IDs, the freshly
// visited key and value types, and the original ValueRequired flag.
func (s *setFreshIDs) Map(mapType MapType, keyResult, valueResult func() Type) Type {
	fresh := &MapType{ValueRequired: mapType.ValueRequired}

	// Both IDs are assigned (key, then value) before either type is visited.
	fresh.KeyID = s.getAndInc(mapType.KeyID)
	fresh.ValueID = s.getAndInc(mapType.ValueID)
	fresh.KeyType = keyResult()
	fresh.ValueType = valueResult()

	return fresh
}

// Primitive returns the primitive type unchanged; it carries no IDs to
// reassign here.
func (s *setFreshIDs) Primitive(p PrimitiveType) Type {
	var unchanged Type = p

	return unchanged
}

// AssignFreshSchemaIDs returns a copy of sc in which every field ID has been
// replaced by one drawn from nextID. When nextID is nil, a counter yielding
// 1, 2, 3, ... is used instead.
//
// The schema's identifier field IDs are translated through the old-to-new
// mapping built during the traversal, and the resulting schema is created
// with schema ID 0. Any error from visiting the schema (including a nil sc)
// is returned as-is.
func AssignFreshSchemaIDs(sc *Schema, nextID func() int) (*Schema, error) {
	if nextID == nil {
		counter := 0
		nextID = func() int {
			counter++

			return counter
		}
	}

	idAssigner := &setFreshIDs{oldIdToNew: map[int]int{}, nextIDFunc: nextID}
	freshType, err := PreOrderVisit(sc, idAssigner)
	if err != nil {
		return nil, err
	}

	freshFields := freshType.(*StructType).FieldList

	// Leave the slice nil when the source schema has no identifier fields.
	var identifierIDs []int
	if n := len(sc.IdentifierFieldIDs); n != 0 {
		identifierIDs = make([]int, 0, n)
		for _, oldID := range sc.IdentifierFieldIDs {
			identifierIDs = append(identifierIDs, idAssigner.oldIdToNew[oldID])
		}
	}

	return NewSchemaWithIdentifiers(0, identifierIDs, freshFields...), nil
}

type SchemaWithPartnerVisitor[T, P any] interface {
Schema(sc *Schema, schemaPartner P, structResult T) T
Struct(st StructType, structPartner P, fieldResults []T) T
Expand Down
Loading

0 comments on commit 85238d2

Please sign in to comment.