diff --git a/CHANGELOG.md b/CHANGELOG.md index bae4f98b..11981d33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # WIP: 0.15.0 +## Important new stuff: + +### MetaPlane + +See [design-docs/2025-12-meta-plane.md](meta-plane.md) for motivation + * Introduced `MetaPlane`/`TableMeta`/`TransformMeta` interfaces to decouple metadata management from the compute plane * Added SQL reference implementation (`SQLMetaPlane`, `SQLTableMeta`, @@ -7,6 +13,17 @@ steps to consume the new meta plane API * Added meta-plane design doc and removed legacy `MetaTable` plumbing in lints, migrations, and tests + +### InputSpec and key mapping + +See [design-docs/2025-12-key-mapping.md](key-mapping.md) for motivation + +* Renamed `JoinSpec` to `InputSpec` +* Added `keys` parameter to `InputSpec` and `ComputeInput` to support + joining tables with different key names +* Added `DataField` accessor for `InputSpec.keys` + +## CLI improvements: * Make CLI accept multiple `--name` values # 0.14.6 diff --git a/datapipe/compute.py b/datapipe/compute.py index d5e8c314..27b6d057 100644 --- a/datapipe/compute.py +++ b/datapipe/compute.py @@ -5,13 +5,14 @@ from typing import Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union from opentelemetry import trace +from sqlalchemy import Column from datapipe.datatable import DataStore, DataTable from datapipe.executor import Executor, ExecutorConfig from datapipe.run_config import RunConfig from datapipe.store.database import TableStoreDB from datapipe.store.table_store import TableStore -from datapipe.types import ChangeList, IndexDF, Labels, TableOrName +from datapipe.types import ChangeList, DataField, FieldAccessor, IndexDF, Labels, MetaSchema, TableOrName logger = logging.getLogger("datapipe.compute") tracer = trace.get_tracer("datapipe.compute") @@ -86,6 +87,51 @@ class ComputeInput: dt: DataTable join_type: Literal["inner", "full"] = "full" + # If provided, this dict tells how to get key columns from meta and data tables + # + # Example: {"idx_col": DataField("data_col")} means that to get idx_col value + # we should read data_col from data table + # + # Example: {"idx_col": "meta_col"} means that to get idx_col value + # we should read meta_col from meta table + keys: Optional[Dict[str, FieldAccessor]] = None + + @property + def primary_keys(self) -> List[str]: + if self.keys: + return list(self.keys.keys()) + else: + return self.dt.primary_keys + + @property + def primary_schema(self) -> MetaSchema: + if self.keys: + primary_schema_dict = {col.name: col for col in self.dt.primary_schema} + data_schema_dict = {col.name: col for col in self.dt.table_store.get_schema()} + + schema = [] + for k, accessor in self.keys.items(): + if isinstance(accessor, str): + source_column = primary_schema_dict[accessor] + column_alias = k + elif isinstance(accessor, DataField): + source_column = data_schema_dict[accessor.field_name] + column_alias = k + schema.append(data_schema_dict[accessor.field_name]) + else: + raise ValueError(f"Unknown accessor type: {type(accessor)}") + + schema.append( + Column( + column_alias, + source_column.type, + primary_key=source_column.primary_key, + ) + ) + return schema + else: + return self.dt.primary_schema + class ComputeStep: """ @@ -114,8 +160,7 @@ def __init__( self._name = name # Нормализация input_dts: автоматически оборачиваем DataTable в ComputeInput self.input_dts = [ - inp if isinstance(inp, ComputeInput) else ComputeInput(dt=inp, join_type="full") - for inp in input_dts + inp if isinstance(inp, ComputeInput) else 
ComputeInput(dt=inp, join_type="full") for inp in input_dts ] self.output_dts = output_dts self._labels = labels diff --git a/datapipe/meta/base.py b/datapipe/meta/base.py index 5283cc4e..72a084ad 100644 --- a/datapipe/meta/base.py +++ b/datapipe/meta/base.py @@ -1,8 +1,11 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Iterable, Iterator, List, Literal, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Literal, Optional, Sequence, Tuple + +import pandas as pd +from sqlalchemy import Column from datapipe.run_config import RunConfig -from datapipe.types import ChangeList, DataSchema, HashDF, IndexDF, MetadataDF, MetaSchema +from datapipe.types import ChangeList, DataSchema, FieldAccessor, HashDF, IndexDF, MetadataDF, MetaSchema if TYPE_CHECKING: from datapipe.compute import ComputeInput @@ -166,10 +169,72 @@ def reset_metadata( # with self.meta_dbconn.con.begin() as con: # con.execute(self.meta.sql_table.update().values(process_ts=0, update_ts=0)) + def transform_idx_to_table_idx( + self, + transform_idx: IndexDF, + keys: Optional[Dict[str, FieldAccessor]] = None, + ) -> IndexDF: + """ + Given an index dataframe with transform keys, return an index dataframe + with table keys, applying `keys` aliasing if provided. + + * `keys` is a mapping from table key to transform key + """ + + if keys is None: + return transform_idx + + table_key_cols: Dict[str, pd.Series] = {} + for transform_col in transform_idx.columns: + accessor = keys.get(transform_col) if keys is not None else transform_col + if isinstance(accessor, str): + table_key_cols[accessor] = transform_idx[transform_col] + else: + pass # skip non-meta fields + + return IndexDF(pd.DataFrame(table_key_cols)) + class TransformMeta: - primary_schema: DataSchema - primary_keys: List[str] + transform_keys_schema: DataSchema + transform_keys: List[str] + + @classmethod + def compute_transform_schema( + cls, + input_cis: Sequence["ComputeInput"], + output_dts: Sequence["DataTable"], + transform_keys: Optional[List[str]], + ) -> Tuple[List[str], MetaSchema]: + # Hacky way to collect all the primary keys into a single set. Possible + # problem that is not handled here is that theres a possibility that the + # same key is defined differently in different input tables. 
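+        #
+        # Outline of the inference implemented below (descriptive note only):
+        # if transform_keys is passed explicitly it is used as-is; otherwise we
+        # take the intersection of primary keys across all inputs and, when
+        # outputs are present, intersect that with the primary keys shared by
+        # all outputs.  E.g. inputs keyed by {a, b} and {b, c} with an output
+        # keyed by {b, d} yield transform_keys == ["b"].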
+ all_keys: Dict[str, Column] = {} + + for ci in input_cis: + all_keys.update({col.name: col for col in ci.primary_schema}) + + for dt in output_dts: + all_keys.update({col.name: col for col in dt.primary_schema}) + + if transform_keys is not None: + return (transform_keys, [all_keys[k] for k in transform_keys]) + + assert len(input_cis) > 0, "At least one input table is required to infer transform keys" + + inp_p_keys = set.intersection(*[set(inp.primary_keys) for inp in input_cis]) + assert len(inp_p_keys) > 0 + + if len(output_dts) == 0: + return (list(inp_p_keys), [all_keys[k] for k in inp_p_keys]) + + out_p_keys = set.intersection(*[set(out.primary_keys) for out in output_dts]) + assert len(out_p_keys) > 0 + + inp_out_p_keys = set.intersection(inp_p_keys, out_p_keys) + assert len(inp_out_p_keys) > 0 + + return (list(inp_out_p_keys), [all_keys[k] for k in inp_out_p_keys]) def get_changed_idx_count( self, diff --git a/datapipe/meta/sql_meta.py b/datapipe/meta/sql_meta.py index 4e8f20b3..a1fa0787 100644 --- a/datapipe/meta/sql_meta.py +++ b/datapipe/meta/sql_meta.py @@ -26,10 +26,13 @@ from datapipe.run_config import LabelDict, RunConfig from datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filter from datapipe.store.database import DBConn, MetaKey +from datapipe.store.table_store import TableStore from datapipe.types import ( ChangeList, DataDF, + DataField, DataSchema, + FieldAccessor, HashDF, IndexDF, MetadataDF, @@ -357,35 +360,71 @@ def get_stale_idx( def get_agg_cte( self, transform_keys: List[str], + table_store: TableStore, + keys: Dict[str, FieldAccessor], filters_idx: Optional[IndexDF] = None, run_config: Optional[RunConfig] = None, ) -> Tuple[List[str], Any]: """ - Create a CTE that aggregates the table by transform keys and returns the - maximum update_ts for each group. + Create a CTE that aggregates the table by transform keys, applies keys + aliasing and returns the maximum update_ts for each group. + + * `keys` is a mapping from transform key to table key accessor + (can be string for meta table column or DataField for data table + column) + * `transform_keys` is a list of keys used in the transformation CTE has the following columns: * transform keys which are present in primary keys * update_ts - Returns a tuple of (keys, CTE). 
+ Returns a tuple of (keys, CTE) where * keys is a list of transform keys + present in primary keys of this CTE """ - tbl = self.sql_table + from datapipe.store.database import TableStoreDB + + meta_table = self.sql_table + data_table = None + + key_cols: list[Any] = [] + cte_transform_keys: List[str] = [] + should_join_data_table = False + + for transform_key in transform_keys: + # TODO convert to match when we deprecate Python 3.9 + accessor = keys.get(transform_key, transform_key) + if isinstance(accessor, str): + if accessor in self.primary_keys: + key_cols.append(meta_table.c[accessor].label(transform_key)) + cte_transform_keys.append(transform_key) + elif isinstance(accessor, DataField): + should_join_data_table = True + assert isinstance(table_store, TableStoreDB) + data_table = table_store.data_table - keys = [k for k in transform_keys if k in self.primary_keys] - key_cols: List[Any] = [sa.column(k) for k in keys] + key_cols.append(data_table.c[accessor.field_name].label(transform_key)) + cte_transform_keys.append(transform_key) - sql: Any = sa.select(*key_cols + [sa.func.max(tbl.c["update_ts"]).label("update_ts")]).select_from(tbl) + sql: Any = sa.select(*key_cols + [sa.func.max(meta_table.c["update_ts"]).label("update_ts")]).select_from( + meta_table + ) + + if should_join_data_table: + assert data_table is not None + sql = sql.join( + data_table, + sa.and_(*[meta_table.c[pk] == data_table.c[pk] for pk in self.primary_keys]), + ) if len(key_cols) > 0: sql = sql.group_by(*key_cols) - sql = sql_apply_filters_idx_to_subquery(sql, keys, filters_idx) - sql = sql_apply_runconfig_filter(sql, tbl, self.primary_keys, run_config) + sql = sql_apply_filters_idx_to_subquery(sql, cte_transform_keys, filters_idx) # ??? + sql = sql_apply_runconfig_filter(sql, meta_table, self.primary_keys, run_config) # ??? 
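+        # Rough shape of the resulting CTE when at least one accessor is a
+        # DataField (names in angle brackets are placeholders, not literal SQL):
+        #
+        #   SELECT <data_tbl>.<field> AS <transform_key>, ...,
+        #          max(<meta_tbl>.update_ts) AS update_ts
+        #   FROM <meta_tbl>
+        #   JOIN <data_tbl> ON <meta_tbl>.<pk> = <data_tbl>.<pk>  -- per primary key
+        #   GROUP BY <aliased key columns>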
- return (keys, sql.cte(name=f"{tbl.name}__update")) + return (cte_transform_keys, sql.cte(name=f"{meta_table.name}__update")) TRANSFORM_META_SCHEMA: DataSchema = [ @@ -396,19 +435,13 @@ def get_agg_cte( ] -@dataclass -class SQLMetaComputeInput: - table: "SQLTableMeta" - join_type: Literal["inner", "full"] = "full" - - class SQLTransformMeta(TransformMeta): def __init__( self, dbconn: DBConn, name: str, - input_mts: Sequence[SQLMetaComputeInput], - output_mts: Sequence[SQLTableMeta], + input_cis: Sequence[ComputeInput], + output_dts: Sequence[DataTable], transform_keys: Optional[List[str]], order_by: Optional[List[str]] = None, order: Literal["asc", "desc"] = "asc", @@ -417,16 +450,16 @@ def __init__( self.dbconn = dbconn self.name = name - self.input_mts = input_mts - self.output_mts = output_mts + self.input_cis = input_cis + self.output_dts = output_dts - self.primary_keys, self.primary_schema = compute_transform_schema( - input_mts=self.input_mts, - output_mts=self.output_mts, + self.transform_keys, self.transform_keys_schema = self.compute_transform_schema( + input_cis=self.input_cis, + output_dts=self.output_dts, transform_keys=transform_keys, ) - self.sql_schema = [i._copy() for i in self.primary_schema + TRANSFORM_META_SCHEMA] + self.sql_schema = [i._copy() for i in self.transform_keys_schema + TRANSFORM_META_SCHEMA] self.sql_table = sa.Table( name, @@ -447,9 +480,9 @@ def __reduce__(self) -> Tuple[Any, ...]: return self.__class__, ( self.dbconn, self.name, - self.input_mts, - self.output_mts, - self.primary_keys, + self.input_cis, + self.output_dts, + self.transform_keys, self.order_by, self.order, ) @@ -459,7 +492,7 @@ def get_changed_idx_count( ds: "DataStore", run_config: Optional[RunConfig] = None, ) -> int: - _, sql = self._build_changed_idx_sql(ds=ds, run_config=run_config) + sql = self._build_changed_idx_sql(ds=ds, run_config=run_config) with ds.meta_dbconn.con.begin() as con: idx_count = con.execute( @@ -475,7 +508,7 @@ def get_full_process_ids( run_config: Optional[RunConfig] = None, ) -> Tuple[int, Iterable[IndexDF]]: with tracer.start_as_current_span("compute ids to process"): - if len(self.input_mts) == 0: + if len(self.input_cis) == 0: return (0, iter([])) idx_count = self.get_changed_idx_count( @@ -483,7 +516,7 @@ def get_full_process_ids( run_config=run_config, ) - join_keys, u1 = self._build_changed_idx_sql( + u1 = self._build_changed_idx_sql( ds=ds, run_config=run_config, order_by=self.order_by, @@ -493,7 +526,7 @@ def get_full_process_ids( # Список ключей из фильтров, которые нужно добавить в результат extra_filters: LabelDict if run_config is not None: - extra_filters = {k: v for k, v in run_config.filters.items() if k not in join_keys} + extra_filters = {k: v for k, v in run_config.filters.items() if k not in self.transform_keys} else: extra_filters = {} @@ -501,7 +534,7 @@ def alter_res_df(): with ds.meta_dbconn.con.begin() as con: for df in pd.read_sql_query(u1, con=con, chunksize=chunk_size): assert isinstance(df, pd.DataFrame) - df = df[self.primary_keys] + df = df[self.transform_keys] for k, v in extra_filters.items(): df[k] = v @@ -518,16 +551,16 @@ def get_change_list_process_ids( run_config: Optional[RunConfig] = None, ) -> Tuple[int, Iterable[IndexDF]]: with tracer.start_as_current_span("compute ids to process"): - changes = [pd.DataFrame(columns=self.primary_keys)] + changes = [pd.DataFrame(columns=self.transform_keys)] - for inp in self.input_mts: - if inp.table.name in change_list.changes: - idx = change_list.changes[inp.table.name] - if any([key 
not in idx.columns for key in self.primary_keys]): + for inp in self.input_cis: + if inp.dt.name in change_list.changes: + idx = change_list.changes[inp.dt.name] + if any([key not in idx.columns for key in self.transform_keys]): # TODO пересмотреть эту логику, выглядит избыточной # (возможно, достаточно посчитать один раз для всех # input таблиц) - _, sql = self._build_changed_idx_sql( + sql = self._build_changed_idx_sql( ds=ds, filters_idx=idx, run_config=run_config, @@ -537,14 +570,14 @@ def get_change_list_process_ids( sql, con=con, ) - table_changes_df = table_changes_df[self.primary_keys] + table_changes_df = table_changes_df[self.transform_keys] changes.append(table_changes_df) else: - changes.append(data_to_index(idx, self.primary_keys)) + changes.append(data_to_index(idx, self.transform_keys)) - idx_df = pd.concat(changes).drop_duplicates(subset=self.primary_keys) - idx = IndexDF(idx_df[self.primary_keys]) + idx_df = pd.concat(changes).drop_duplicates(subset=self.transform_keys) + idx = IndexDF(idx_df[self.transform_keys]) chunk_count = math.ceil(len(idx) / chunk_size) @@ -558,7 +591,7 @@ def insert_rows( self, idx: IndexDF, ) -> None: - idx = cast(IndexDF, idx[self.primary_keys]) + idx = cast(IndexDF, idx[self.transform_keys]) insert_sql = self.dbconn.insert(self.sql_table).values( [ @@ -573,7 +606,7 @@ def insert_rows( ] ) - sql = insert_sql.on_conflict_do_nothing(index_elements=self.primary_keys) + sql = insert_sql.on_conflict_do_nothing(index_elements=self.transform_keys) with self.dbconn.con.begin() as con: con.execute(sql) @@ -585,7 +618,7 @@ def mark_rows_processed_success( run_config: Optional[RunConfig] = None, ) -> None: idx = cast( - IndexDF, idx[self.primary_keys].drop_duplicates().dropna() + IndexDF, idx[self.transform_keys].drop_duplicates().dropna() ) # FIXME: сделать в основном запросе distinct if len(idx) == 0: return @@ -628,7 +661,7 @@ def mark_rows_processed_success( ) sql = insert_sql.on_conflict_do_update( - index_elements=self.primary_keys, + index_elements=self.transform_keys, set_={ "process_ts": process_ts, "is_success": True, @@ -648,7 +681,7 @@ def mark_rows_processed_error( run_config: Optional[RunConfig] = None, ) -> None: idx = cast( - IndexDF, idx[self.primary_keys].drop_duplicates().dropna() + IndexDF, idx[self.transform_keys].drop_duplicates().dropna() ) # FIXME: сделать в основном запросе distinct if len(idx) == 0: return @@ -667,7 +700,7 @@ def mark_rows_processed_error( ) sql = insert_sql.on_conflict_do_update( - index_elements=self.primary_keys, + index_elements=self.transform_keys, set_={ "process_ts": process_ts, "is_success": False, @@ -703,7 +736,7 @@ def mark_all_rows_unprocessed( .where(self.sql_table.c.is_success == True) # noqa: E712 ) - sql = sql_apply_runconfig_filter(update_sql, self.sql_table, self.primary_keys, run_config) + sql = sql_apply_runconfig_filter(update_sql, self.sql_table, self.transform_keys, run_config) # execute with self.dbconn.con.begin() as con: @@ -716,15 +749,20 @@ def _build_changed_idx_sql( order_by: Optional[List[str]] = None, order: Literal["asc", "desc"] = "asc", run_config: Optional[RunConfig] = None, # TODO remove - ) -> Tuple[Iterable[str], Any]: + ) -> Any: all_input_keys_counts: Dict[str, int] = {} - for col in itertools.chain(*[inp.table.primary_schema for inp in self.input_mts]): + for col in itertools.chain(*[inp.dt.primary_schema for inp in self.input_cis]): all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1 inp_ctes = [] - for inp in self.input_mts: - keys, cte = 
inp.table.get_agg_cte( - transform_keys=self.primary_keys, + for inp in self.input_cis: + inp_meta = inp.dt.meta + assert isinstance(inp_meta, SQLTableMeta) + + keys, cte = inp_meta.get_agg_cte( + transform_keys=self.transform_keys, + table_store=inp.dt.table_store, + keys=inp.keys or {}, filters_idx=filters_idx, run_config=run_config, ) @@ -732,7 +770,7 @@ def _build_changed_idx_sql( agg_of_aggs = _make_agg_of_agg( ds=ds, - transform_keys=self.primary_keys, + transform_keys=self.transform_keys, ctes=inp_ctes, agg_col="update_ts", ) @@ -740,30 +778,30 @@ def _build_changed_idx_sql( tr_tbl = self.sql_table out: Any = ( sa.select( - *[sa.column(k) for k in self.primary_keys] + *[sa.column(k) for k in self.transform_keys] + [tr_tbl.c.process_ts, tr_tbl.c.priority, tr_tbl.c.is_success] ) .select_from(tr_tbl) - .group_by(*[sa.column(k) for k in self.primary_keys]) + .group_by(*[sa.column(k) for k in self.transform_keys]) ) - out = sql_apply_filters_idx_to_subquery(out, self.primary_keys, filters_idx) + out = sql_apply_filters_idx_to_subquery(out, self.transform_keys, filters_idx) out = out.cte(name="transform") - if len(self.primary_keys) == 0: + if len(self.transform_keys) == 0: join_onclause_sql: Any = sa.literal(True) - elif len(self.primary_keys) == 1: - join_onclause_sql = agg_of_aggs.c[self.primary_keys[0]] == out.c[self.primary_keys[0]] + elif len(self.transform_keys) == 1: + join_onclause_sql = agg_of_aggs.c[self.transform_keys[0]] == out.c[self.transform_keys[0]] else: # len(transform_keys) > 1: - join_onclause_sql = sa.and_(*[agg_of_aggs.c[key] == out.c[key] for key in self.primary_keys]) + join_onclause_sql = sa.and_(*[agg_of_aggs.c[key] == out.c[key] for key in self.transform_keys]) sql = ( sa.select( # Нам нужно выбирать хотя бы что-то, чтобы не было ошибки при # пустом transform_keys sa.literal(1).label("_datapipe_dummy"), - *[sa.func.coalesce(agg_of_aggs.c[key], out.c[key]).label(key) for key in self.primary_keys], + *[sa.func.coalesce(agg_of_aggs.c[key], out.c[key]).label(key) for key in self.transform_keys], ) .select_from(agg_of_aggs) .outerjoin( @@ -785,7 +823,7 @@ def _build_changed_idx_sql( if order_by is None: sql = sql.order_by( out.c.priority.desc().nullslast(), - *[sa.column(k) for k in self.primary_keys], + *[sa.column(k) for k in self.transform_keys], ) else: if order == "desc": @@ -798,7 +836,7 @@ def _build_changed_idx_sql( *[sa.asc(sa.column(k)) for k in order_by], out.c.priority.desc().nullslast(), ) - return (self.primary_keys, sql) + return sql def sql_apply_filters_idx_to_subquery( @@ -890,41 +928,6 @@ def _make_agg_of_agg( return sql.cte(name=f"all__{agg_col}") -def compute_transform_schema( - input_mts: Sequence[SQLMetaComputeInput], - output_mts: Sequence[SQLTableMeta], - transform_keys: Optional[List[str]], -) -> Tuple[List[str], MetaSchema]: - # Hacky way to collect all the primary keys into a single set. Possible - # problem that is not handled here is that theres a possibility that the - # same key is defined differently in different input tables. 
- all_keys = { - col.name: col - for col in itertools.chain( - *([inp.table.primary_schema for inp in input_mts] + [dt.primary_schema for dt in output_mts]) - ) - } - - if transform_keys is not None: - return (transform_keys, [all_keys[k] for k in transform_keys]) - - assert len(input_mts) > 0 - - inp_p_keys = set.intersection(*[set(inp.table.primary_keys) for inp in input_mts]) - assert len(inp_p_keys) > 0 - - if len(output_mts) == 0: - return (list(inp_p_keys), [all_keys[k] for k in inp_p_keys]) - - out_p_keys = set.intersection(*[set(out.primary_keys) for out in output_mts]) - assert len(out_p_keys) > 0 - - inp_out_p_keys = set.intersection(inp_p_keys, out_p_keys) - assert len(inp_out_p_keys) > 0 - - return (list(inp_out_p_keys), [all_keys[k] for k in inp_out_p_keys]) - - class SQLMetaPlane(MetaPlane): def __init__(self, dbconn: DBConn, create_meta_table: bool = False) -> None: self.dbconn = dbconn @@ -953,26 +956,11 @@ def create_transform_meta( order_by: Optional[List[str]] = None, order: Literal["asc", "desc"] = "asc", ) -> TransformMeta: - input_mts = [] - for inp in input_dts: - assert isinstance(inp.dt.meta, SQLTableMeta) - input_mts.append( - SQLMetaComputeInput( - table=inp.dt.meta, - join_type=inp.join_type, - ) - ) - - output_mts = [] - for out in output_dts: - assert isinstance(out.meta, SQLTableMeta) - output_mts.append(out.meta) - return SQLTransformMeta( dbconn=self.dbconn, name=name, - input_mts=input_mts, - output_mts=output_mts, + input_cis=input_dts, + output_dts=output_dts, transform_keys=transform_keys, order_by=order_by, order=order, diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py index c35ca08f..6d550b38 100644 --- a/datapipe/step/batch_transform.py +++ b/datapipe/step/batch_transform.py @@ -34,7 +34,7 @@ ChangeList, DataDF, IndexDF, - JoinSpec, + InputSpec, Labels, PipelineInput, Required, @@ -117,7 +117,7 @@ def __init__( order=order, ) - self.transform_keys, self.transform_schema = self.meta.primary_keys, self.meta.primary_schema + self.transform_keys, self.transform_schema = self.meta.transform_keys, self.meta.transform_keys_schema self.filters = filters self.order_by = order_by @@ -271,7 +271,8 @@ def get_batch_input_dfs( idx: IndexDF, run_config: Optional[RunConfig] = None, ) -> List[DataDF]: - return [inp.dt.get_data(idx) for inp in self.input_dts] + # TODO consider parallel fetch through executor + return [inp.dt.get_data(inp.dt.meta.transform_idx_to_table_idx(idx, inp.keys)) for inp in self.input_dts] def process_batch_dfs( self, @@ -498,12 +499,13 @@ def pipeline_input_to_compute_input(self, ds: DataStore, catalog: Catalog, input return ComputeInput( dt=catalog.get_datatable(ds, input.table), join_type="inner", + keys=input.keys, ) - elif isinstance(input, JoinSpec): - # This should not happen, but just in case + elif isinstance(input, InputSpec): return ComputeInput( dt=catalog.get_datatable(ds, input.table), join_type="full", + keys=input.keys, ) else: return ComputeInput(dt=catalog.get_datatable(ds, input), join_type="full") diff --git a/datapipe/types.py b/datapipe/types.py index f0af4044..5fb418a2 100644 --- a/datapipe/types.py +++ b/datapipe/types.py @@ -9,6 +9,7 @@ Dict, List, NewType, + Optional, Set, Tuple, Type, @@ -61,16 +62,33 @@ @dataclass -class JoinSpec: +class DataField: + field_name: str + + +FieldAccessor = Union[str, DataField] + + +@dataclass +class InputSpec: table: TableOrName + # If provided, this dict tells how to get key columns from meta and data tables + # + # Example: {"idx_col": 
DataField("data_col")} means that to get idx_col value
+    # we should read data_col from data table
+    #
+    # Example: {"idx_col": "meta_col"} means that to get idx_col value
+    # we should read meta_col from meta table
+    keys: Optional[Dict[str, FieldAccessor]] = None
+

 @dataclass
-class Required(JoinSpec):
+class Required(InputSpec):
     pass


-PipelineInput = Union[TableOrName, JoinSpec]
+PipelineInput = Union[TableOrName, InputSpec]


 @dataclass
diff --git a/design-docs/2025-12-key-mapping.md b/design-docs/2025-12-key-mapping.md
new file mode 100644
index 00000000..3bfc1a54
--- /dev/null
+++ b/design-docs/2025-12-key-mapping.md
@@ -0,0 +1,105 @@
+# Open questions
+
+* How to deal with ChangeList filter_idx? Before this change we assumed that an
+  idx means the same thing everywhere and just filtered by the given idx name
+  and values; now we need to understand how table keys are propagated to the join
+* How to deal with RunConfig filters? This feature is in general not very
+  transparent to me; we have to understand it and do something, probably
+  similarly to filter_idx
+
+# Goal
+
+Make it possible to join tables in a transformation when the key columns do not
+match by name across tables.
+
+# Use case
+
+You have tables `User (id: PK)` and `Subscription (id: PK, user_id: DATA, sub_user_id: DATA)`.
+You need to enrich both sides of `Subscription` with information from `User`.
+
+You might write:
+
+```
+BatchTransform(
+    process_func,
+
+    # defines ["user_id", "sub_user_id"] as the keys that identify each transform task;
+    # every table should have a way to join to these keys
+    transform_keys=["user_id", "sub_user_id"],
+    inputs=[
+        # Subscription has the needed columns in its data table, we fetch them from there
+        InputSpec(Subscription, keys={"user_id": DataField("user_id"), "sub_user_id": DataField("sub_user_id")}),
+
+        # matches tr.user_id = User.id
+        InputSpec(User, keys={"user_id": "id"}),
+
+        # matches tr.sub_user_id = User.id
+        InputSpec(User, keys={"sub_user_id": "id"}),
+    ],
+    outputs=[...],
+)
+```
+
+At each execution `process_func` will receive three dataframes:
+
+* `subscription_df` - a chunk of `Subscription`
+* `user_df` - a chunk of `User` matched by `user_id`
+* `sub_user_df` - a chunk of `User` matched by `sub_user_id`
+
+Both `user_df` and `sub_user_df` have columns aligned with the `User` table,
+i.e. without renamings; it is up to the end user to interpret the data.
+
+# InputSpec
+
+We introduce the `InputSpec` qualifier for `BatchTransform` inputs.
+
+The `keys` parameter defines which columns to use for this input table and where
+to get them from. `keys` is a dict of the form `{"{transform_key}": key_accessor}`,
+where `key_accessor` can be:
+* a string, in which case a column of the meta-table is used, with possible renaming
+* `DataField("data_col")`, in which case `data_col` from the data-table is used
+  instead of the meta-table
+
+If a table is provided as is, without an `InputSpec` wrapper, it is equivalent to
+`InputSpec(Table, join_type="outer", keys={"id1": "id1", ...})`: the join type is
+an outer join and all keys are mapped to themselves.
+
+## DataField limitations
+
+The `DataField` accessor serves as an ad-hoc solution for situations where, for
+some reason, a data field cannot be promoted to a meta-field.
+
+Data fields are not used when retrieving a chunk of data, so it is possible to
+over-fetch data; the sketch below makes this concrete.
+
+Data fields are not required to have indices in the DB, so using them can be
+very heavy for the database.
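+
+To make the over-fetch point concrete, here is a minimal, self-contained sketch
+of how a transform-index chunk is converted into a per-table index under a
+`keys` mapping. It mirrors the behaviour described above rather than calling
+datapipe itself; the `DataField` class below is a stand-in, not an import.
+
+```
+from typing import Dict, Union
+
+import pandas as pd
+
+
+class DataField:
+    def __init__(self, field_name: str) -> None:
+        self.field_name = field_name
+
+
+FieldAccessor = Union[str, DataField]
+
+
+def transform_idx_to_table_idx(transform_idx: pd.DataFrame, keys: Dict[str, FieldAccessor]) -> pd.DataFrame:
+    cols = {}
+    for transform_col in transform_idx.columns:
+        accessor = keys.get(transform_col)
+        if isinstance(accessor, str):
+            # a plain string renames the transform key back to a meta-table column
+            cols[accessor] = transform_idx[transform_col]
+        # DataField accessors (and unmapped columns) are skipped: data columns
+        # cannot narrow the fetch, which is where over-fetching comes from
+    return pd.DataFrame(cols)
+
+
+idx = pd.DataFrame({"user_id": ["1"], "sub_user_id": ["2"]})
+
+# InputSpec(User, keys={"user_id": "id"}): the User chunk is fetched by `id`
+print(transform_idx_to_table_idx(idx, {"user_id": "id"}))
+
+# InputSpec(Subscription, keys={... DataField ...}): no meta-table columns remain,
+# so the Subscription chunk cannot be narrowed by the transform index
+print(transform_idx_to_table_idx(idx, {"user_id": DataField("user_id"), "sub_user_id": DataField("sub_user_id")}))
+```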
+ + +# Implementation + +## DX + +* [x] `datapipe.types.JoinSpec` is renamed to `InputSpec` and receives `keys` + parameter + +## Compute + +* [x] `datapipe.compute.ComputeInput` receives `keys` parameter + +`datapipe.meta.sql_meta.SQLTableMeta`: +* [x] new method `transform_idx_to_table_idx` which should be used to convert + transform keys to table keys +* [x] `get_agg_cte` receives `keys` parameter and starts producing subquery with + renamed keys +* [ ] `get_agg_cte` correctly applies `keys` to `filter_idx` parameter +* [ ] `get_agg_cte` correctly applies `keys` to `RunConfig` filters + +`BatchTransform`: +* [x] correctly converts transform idx to table idx in `get_batch_input_dfs` +* [x] inputs and outputs are stored as `ComputeInput` lists, because we need + data table for `DataField` + +`DataTable`: +* [x] `DataTable.get_data` accepts `table_idx` which is acquired by applying + `tranform_idx_to_table_idx` diff --git a/design-docs/2025-12-meta-plane.md b/design-docs/2025-12-meta-plane.md index 0a6ca4de..deaea61d 100644 --- a/design-docs/2025-12-meta-plane.md +++ b/design-docs/2025-12-meta-plane.md @@ -1,4 +1,6 @@ -Main idea: separate all metadata management to separate package so, that +# Goal + +Separate all metadata management to separate package so, that inteface between compute/execution-plane and meta would be relatively narrow and we could create alternative meta-plane implementation diff --git a/tests/test_meta_transform_keys.py b/tests/test_meta_transform_keys.py new file mode 100644 index 00000000..6b25ca10 --- /dev/null +++ b/tests/test_meta_transform_keys.py @@ -0,0 +1,164 @@ +import time + +import pandas as pd +from sqlalchemy import Column, String + +from datapipe.compute import ComputeInput +from datapipe.datatable import DataStore +from datapipe.step.batch_transform import BatchTransformStep +from datapipe.store.database import DBConn, TableStoreDB +from datapipe.tests.util import assert_datatable_equal +from datapipe.types import DataField + + +def test_transform_keys(dbconn: DBConn): + """ + Проверяет что трансформация с keys (InputSpec) корректно отрабатывает. + + Сценарий: + 1. Создаём posts и profiles (profiles с keys={'user_id': 'id'}) + """ + ds = DataStore(dbconn, create_meta_table=True) + + # 1. Создать posts таблицу (используем String для id чтобы совпадать с мета-таблицей) + posts_store = TableStoreDB( + dbconn, + "posts", + [ + Column("id", String, primary_key=True), + Column("user_id", String), + Column("content", String), + ], + create_table=True, + ) + posts = ds.create_table("posts", posts_store) + + # 2. Создать profiles таблицу (справочник) + profiles_store = TableStoreDB( + dbconn, + "profiles", + [ + Column("id", String, primary_key=True), + Column("username", String), + ], + create_table=True, + ) + profiles = ds.create_table("profiles", profiles_store) + + # 3. Создать output таблицу (id - primary key, остальное - данные) + output_store = TableStoreDB( + dbconn, + "posts_with_username", + [ + Column("id", String, primary_key=True), + Column("user_id", String), # Обычная колонка, не primary key + Column("content", String), + Column("username", String), + ], + create_table=True, + ) + output_dt = ds.create_table("posts_with_username", output_store) + + # 4. 
Добавить данные + process_ts = time.time() + + # 3 поста от 2 пользователей + posts_df = pd.DataFrame( + [ + {"id": "1", "user_id": "1", "content": "Post 1"}, + {"id": "2", "user_id": "1", "content": "Post 2"}, + {"id": "3", "user_id": "2", "content": "Post 3"}, + ] + ) + posts.store_chunk(posts_df, now=process_ts) + + # 2 профиля + profiles_df = pd.DataFrame( + [ + {"id": "1", "username": "alice"}, + {"id": "2", "username": "bob"}, + ] + ) + profiles.store_chunk(profiles_df, now=process_ts) + + # 5. Создать трансформацию с keys + def transform_func(posts_df, profiles_df): + # JOIN posts + profiles + result = posts_df.merge(profiles_df, left_on="user_id", right_on="id", suffixes=("", "_profile")) + return result[["id", "user_id", "content", "username"]] + + step = BatchTransformStep( + ds=ds, + name="test_transform", + func=transform_func, + input_dts=[ + ComputeInput( + dt=posts, + join_type="full", + keys={ + "post_id": "id", + "user_id": DataField("user_id"), + }, + ), + ComputeInput( + dt=profiles, + join_type="inner", + keys={ + "user_id": "id", + }, + ), + ], + output_dts=[output_dt], + transform_keys=["post_id", "user_id"], + ) + + # 6. Запустить трансформацию + print("\n🚀 Running initial transformation...") + step.run_full(ds) + + # Проверяем результаты трансформации + assert_datatable_equal( + output_dt, + pd.DataFrame( + [ + {"id": "1", "user_id": "1", "content": "Post 1", "username": "alice"}, + {"id": "2", "user_id": "1", "content": "Post 2", "username": "alice"}, + {"id": "3", "user_id": "2", "content": "Post 3", "username": "bob"}, + ] + ), + ) + + # 8. Добавим новые данные и проверим инкрементальную обработку + time.sleep(0.01) # Небольшая задержка для различения timestamp'ов + process_ts2 = time.time() + + # Добавляем 1 новый пост + new_posts_df = pd.DataFrame( + [ + {"id": "4", "user_id": "1", "content": "New Post 4"}, + ] + ) + posts.store_chunk(new_posts_df, now=process_ts2) + + # Добавляем 1 новый профиль + new_profiles_df = pd.DataFrame( + [ + {"id": "3", "username": "charlie"}, + ] + ) + profiles.store_chunk(new_profiles_df, now=process_ts2) + + # 9. 
Запускаем инкрементальную обработку + step.run_full(ds) + + assert_datatable_equal( + output_dt, + pd.DataFrame( + [ + {"id": "1", "user_id": "1", "content": "Post 1", "username": "alice"}, + {"id": "2", "user_id": "1", "content": "Post 2", "username": "alice"}, + {"id": "3", "user_id": "2", "content": "Post 3", "username": "bob"}, + {"id": "4", "user_id": "1", "content": "New Post 4", "username": "alice"}, + ] + ), + ) diff --git a/tests/test_transform_meta.py b/tests/test_transform_meta.py index 8f2652d1..d9a9f297 100644 --- a/tests/test_transform_meta.py +++ b/tests/test_transform_meta.py @@ -4,16 +4,31 @@ from pytest_cases import parametrize from sqlalchemy import Column, Integer -from datapipe.meta.sql_meta import SQLMetaComputeInput, SQLTableMeta, compute_transform_schema -from datapipe.store.database import DBConn +from datapipe.compute import ComputeInput +from datapipe.datatable import DataTable +from datapipe.event_logger import EventLogger +from datapipe.meta.base import TransformMeta +from datapipe.meta.sql_meta import SQLTableMeta +from datapipe.store.database import DBConn, TableStoreDB from datapipe.types import MetaSchema -def make_mt(name, dbconn, schema_keys) -> SQLTableMeta: - return SQLTableMeta( - dbconn=dbconn, +def make_dt(name, dbconn, schema_keys) -> DataTable: + schema = [Column(key, Integer(), primary_key=True) for key in schema_keys] + + return DataTable( name=name, - primary_schema=[Column(key, Integer(), primary_key=True) for key in schema_keys], + meta=SQLTableMeta( + dbconn=dbconn, + name=f"{name}_meta", + primary_schema=schema, + ), + table_store=TableStoreDB( + dbconn=dbconn, + name=name, + data_sql_schema=schema, + ), + event_logger=EventLogger(), ) @@ -66,10 +81,13 @@ def test_compute_transform_schema_success( transform_keys, expected_keys, ): - inp_mts = [SQLMetaComputeInput(make_mt(f"inp_{i}", dbconn, keys)) for (i, keys) in enumerate(input_keys_list)] - out_mts = [make_mt(f"out_{i}", dbconn, keys) for (i, keys) in enumerate(output_keys_list)] + inp_cis = [ + ComputeInput(make_dt(f"inp_{i}", dbconn, keys), join_type="full", keys=None) + for (i, keys) in enumerate(input_keys_list) + ] + out_dts = [make_dt(f"out_{i}", dbconn, keys) for (i, keys) in enumerate(output_keys_list)] - _, sch = compute_transform_schema(inp_mts, out_mts, transform_keys=transform_keys) + _, sch = TransformMeta.compute_transform_schema(inp_cis, out_dts, transform_keys=transform_keys) assert_schema_equals(sch, expected_keys) @@ -81,8 +99,10 @@ def test_compute_transform_schema_fail( output_keys_list, transform_keys, ): - inp_mts = [SQLMetaComputeInput(make_mt(f"inp_{i}", dbconn, keys)) for (i, keys) in enumerate(input_keys_list)] - out_mts = [make_mt(f"out_{i}", dbconn, keys) for (i, keys) in enumerate(output_keys_list)] - + inp_cis = [ + ComputeInput(make_dt(f"inp_{i}", dbconn, keys), join_type="full", keys=None) + for (i, keys) in enumerate(input_keys_list) + ] + out_dts = [make_dt(f"out_{i}", dbconn, keys) for (i, keys) in enumerate(output_keys_list)] with pytest.raises(AssertionError): - compute_transform_schema(inp_mts, out_mts, transform_keys=transform_keys) + TransformMeta.compute_transform_schema(inp_cis, out_dts, transform_keys=transform_keys)