Skip to content

Commit

Permalink
Updated some deltalake schema conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Mar 25, 2024
1 parent 1001fb4 commit e82df1c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 54 deletions.
18 changes: 16 additions & 2 deletions aligned/schemas/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ def is_datetime(self) -> bool:
def is_array(self) -> bool:
return self.name.startswith('array')

def array_subtype(self) -> 'FeatureType | None':
    """Return the element type encoded in an array type name.

    E.g. a type named ``'array-int32'`` yields ``FeatureType('int32')``.

    Returns:
        The decoded element ``FeatureType``, or ``None`` when this type is
        not an array, or when no subtype is encoded in the name (a bare
        ``'array'`` or a degenerate ``'array-'``).
    """
    # NOTE(review): assumes `is_array` is a @property like `datetime_timezone`
    # below — confirm; if it were a plain method, `not self.is_array` would
    # test a bound-method object and always be False.
    if not self.is_array or '-' not in self.name:
        return None

    # Strip the 'array-' prefix; the remainder names the element type.
    sub = self.name[len('array-') :]
    if not sub:
        # 'array-' with nothing after it carries no usable subtype.
        return None
    return FeatureType(sub)

@property
def datetime_timezone(self) -> str | None:
if not self.is_datetime:
Expand Down Expand Up @@ -115,10 +122,17 @@ def pandas_type(self) -> str | type:

@property
def polars_type(self) -> type:
if self.name.startswith('datetime-'):
time_zone = self.name.split('-')[1]
if self.is_datetime:
time_zone = self.datetime_timezone
return pl.Datetime(time_zone=time_zone) # type: ignore

if self.is_array:
sub_type = self.array_subtype()
if sub_type:
return pl.List(sub_type.polars_type) # type: ignore
else:
return pl.List(pl.Utf8) # type: ignore

for name, dtype in NAME_POLARS_MAPPING:
if name == self.name:
return dtype
Expand Down
88 changes: 37 additions & 51 deletions aligned/sources/azure_blob_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,17 +442,14 @@ async def write_polars(self, df: pl.LazyFrame) -> None:
mode='append',
)

async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
def df_to_deltalake_compatible(
self, df: pl.DataFrame, requests: list[RetrivalRequest]
) -> tuple[pl.DataFrame, dict]:
import pyarrow as pa
from aligned.schemas.constraints import Optional
from aligned.schemas.feature import Feature

df = await job.to_polars()
url = f"az://{self.path}"

def pa_field(feature: Feature) -> pa.Field:
is_nullable = Optional() in (feature.constraints or set())

def pa_dtype(dtype: FeatureType) -> pa.DataType:
pa_types = {
'int8': pa.int8(),
'int16': pa.int16(),
Expand All @@ -462,35 +459,60 @@ def pa_field(feature: Feature) -> pa.Field:
'double': pa.float64(),
'string': pa.large_string(),
'date': pa.date64(),
'embedding': pa.large_list(pa.float32()),
'datetime': pa.float64(),
'list': pa.large_list(pa.int32()),
'array': pa.large_list(pa.int32()),
'bool': pa.bool_(),
}

if feature.dtype.name in pa_types:
return pa.field(feature.name, pa_types[feature.dtype.name], nullable=is_nullable)
if dtype.name in pa_types:
return pa_types[dtype.name]

if dtype.is_datetime:
return pa.float64()

if dtype.is_array:
array_sub_dtype = dtype.array_subtype()
if array_sub_dtype:
return pa.large_list(pa_dtype(array_sub_dtype))

return pa.large_list(pa.string())

raise ValueError(f"Unsupported dtype: {dtype}")

def pa_field(feature: Feature) -> pa.Field:
is_nullable = Optional() in (feature.constraints or set())

raise ValueError(f"Unsupported dtype: {feature.dtype}")
pa_type = pa_dtype(feature.dtype)
return pa.field(feature.name, pa_type, nullable=is_nullable)

dtypes = dict(zip(df.columns, df.dtypes, strict=False))
schemas = {}

for request in requests:
for feature in request.all_features.union(request.entities):
schemas[feature.name] = pa_field(feature)

if dtypes[feature.name] == pl.Null:
df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))
elif feature.dtype.name == 'array':
df = df.with_columns(pl.col(feature.name).cast(pl.List(pl.Int32())))
elif feature.dtype.name == 'datetime':
elif feature.dtype.is_datetime:
df = df.with_columns(pl.col(feature.name).dt.timestamp('ms').cast(pl.Float64()))
else:
df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))

return df, schemas

async def insert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
import pyarrow as pa

df = await job.to_polars()
url = f"az://{self.path}"

df, schemas = self.df_to_deltalake_compatible(df, requests)

orderd_schema = OrderedDict(sorted(schemas.items()))
schema = list(orderd_schema.values())

df.select(list(orderd_schema.keys())).write_delta(
url,
storage_options=self.config.read_creds(),
Expand All @@ -500,58 +522,22 @@ def pa_field(feature: Feature) -> pa.Field:

async def upsert(self, job: RetrivalJob, requests: list[RetrivalRequest]) -> None:
import pyarrow as pa
from aligned.schemas.constraints import Optional
from aligned.schemas.feature import Feature
from deltalake.exceptions import TableNotFoundError

df = await job.to_polars()

url = f"az://{self.path}"
merge_on = set()

def pa_field(feature: Feature) -> pa.Field:
is_nullable = Optional() in (feature.constraints or set())

pa_types = {
'int8': pa.int8(),
'int16': pa.int16(),
'int32': pa.int32(),
'int64': pa.int64(),
'float': pa.float64(),
'double': pa.float64(),
'string': pa.large_string(),
'date': pa.date64(),
'datetime': pa.float64(),
'list': pa.large_list(pa.int32()),
'array': pa.large_list(pa.int32()),
'bool': pa.bool_(),
}

if feature.dtype.name in pa_types:
return pa.field(feature.name, pa_types[feature.dtype.name], nullable=is_nullable)

raise ValueError(f"Unsupported dtype: {feature.dtype}")

dtypes = dict(zip(df.columns, df.dtypes, strict=False))
schemas = {}

for request in requests:
merge_on.update(request.entity_names)

for feature in request.all_features.union(request.entities):
schemas[feature.name] = pa_field(feature)
if dtypes[feature.name] == pl.Null:
df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))
elif feature.dtype.name == 'array':
df = df.with_columns(pl.col(feature.name).cast(pl.List(pl.Int32())))
elif feature.dtype.name == 'datetime':
df = df.with_columns(pl.col(feature.name).dt.timestamp('ms').cast(pl.Float64()))
else:
df = df.with_columns(pl.col(feature.name).cast(feature.dtype.polars_type))
df, schemas = self.df_to_deltalake_compatible(df, requests)

orderd_schema = OrderedDict(sorted(schemas.items()))
schema = list(orderd_schema.values())

predicate = ' AND '.join([f"s.{key} = t.{key}" for key in merge_on])

try:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.84"
version = "0.0.85"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <mats@mollestad.no>"]
license = "Apache-2.0"
Expand Down

0 comments on commit e82df1c

Please sign in to comment.