Skip to content

Commit

Permalink
Adding schema param to from_records (#248)
Browse files Browse the repository at this point in the history
* adding schema param to from_records

* fixing the logic

* removed print

* fix test

* updated docstring

* updated docstring and error message

* changed behavior of empty records and no explicit schema

---------

Co-authored-by: ivan <ilongin@iterative.ai>
  • Loading branch information
ilongin and ivan authored Aug 20, 2024
1 parent 6fdc261 commit 70fe5a1
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
27 changes: 22 additions & 5 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,7 @@ def from_records(
to_insert: Optional[Union[dict, list[dict]]],
session: Optional[Session] = None,
in_memory: bool = False,
schema: Optional[dict[str, DataType]] = None,
) -> "DataChain":
"""Create a DataChain from the provided records. This method can be used for
programmatically generating a chain in contrast of reading data from storages
Expand All @@ -1532,22 +1533,38 @@ def from_records(
Parameters:
to_insert : records (or a single record) to insert. Each record is
a dictionary of signals and theirs values.
schema : describes chain signals and their corresponding types
Example:
```py
empty = DataChain.from_records()
single_record = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD)
```
"""
session = Session.get(session, in_memory=in_memory)
catalog = session.catalog

name = session.generate_temp_dataset_name()
columns: tuple[sqlalchemy.Column[Any], ...] = tuple(
sqlalchemy.Column(name, typ)
for name, typ in File._datachain_column_types.items()
signal_schema = None
columns: list[sqlalchemy.Column] = []

if schema:
signal_schema = SignalSchema(schema)
columns = signal_schema.db_signals(as_columns=True) # type: ignore[assignment]
else:
columns = [
sqlalchemy.Column(name, typ)
for name, typ in File._datachain_column_types.items()
]

dsr = catalog.create_dataset(
name,
columns=columns,
feature_schema=(
signal_schema.clone_without_sys_signals().serialize()
if signal_schema
else None
),
)
dsr = catalog.create_dataset(name, columns=columns)

if isinstance(to_insert, dict):
to_insert = [to_insert]
Expand Down
3 changes: 2 additions & 1 deletion src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_origin,
)

import sqlalchemy as sa
from pydantic import BaseModel, create_model
from typing_extensions import Literal as LiteralEx

Expand Down Expand Up @@ -232,7 +233,7 @@ def db_signals(
signals = [
DEFAULT_DELIMITER.join(path)
if not as_columns
else Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type))
else sa.Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type))
for path, _type, has_subtree, _ in self.get_flat_tree()
if not has_subtree
]
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,52 @@ def test_from_features(test_session):
assert t1 == features[i]


def test_from_records_empty_chain_with_schema(test_session):
schema = {"my_file": File, "my_col": int}
ds = DataChain.from_records([], schema=schema, session=test_session)
ds_sys = ds.settings(sys=True)

ds_name = "my_ds"
ds.save(ds_name)
ds = DataChain(name=ds_name)

assert isinstance(ds.feature_schema, dict)
assert isinstance(ds.signals_schema, SignalSchema)
assert ds.schema.keys() == {"my_file", "my_col"}
assert set(ds.schema.values()) == {File, int}
assert ds.count() == 0

# check that columns have actually been created from schema
dr = ds_sys.catalog.warehouse.dataset_rows(ds_sys.catalog.get_dataset(ds_name))
assert sorted([c.name for c in dr.c]) == sorted(ds.signals_schema.db_signals())


def test_from_records_empty_chain_without_schema(test_session):
ds = DataChain.from_records([], schema=None, session=test_session)
ds_sys = ds.settings(sys=True)

ds_name = "my_ds"
ds.save(ds_name)
ds = DataChain(name=ds_name)

assert ds.schema.keys() == {
"source",
"path",
"size",
"version",
"etag",
"is_latest",
"last_modified",
"location",
"vtype",
}
assert ds.count() == 0

# check that columns have actually been created from schema
dr = ds_sys.catalog.warehouse.dataset_rows(ds_sys.catalog.get_dataset(ds_name))
assert sorted([c.name for c in dr.c]) == sorted(ds.signals_schema.db_signals())


def test_datasets(test_session):
ds = DataChain.datasets(session=test_session)
datasets = [d for d in ds.collect("dataset") if d.name == "fibonacci"]
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/lib/test_signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import Optional, Union

import pytest
from sqlalchemy import Column

from datachain import Column, DataModel
from datachain import DataModel
from datachain.lib.convert.flatten import flatten
from datachain.lib.file import File
from datachain.lib.signal_schema import (
Expand Down

0 comments on commit 70fe5a1

Please sign in to comment.