From 39dccf3f1e3830c7718b36ab2f85c02875cabfa7 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 19 Jul 2023 20:31:56 -0400 Subject: [PATCH] save --- src/data/dataset.py | 13 +++++----- src/data/dataset_duckdb.py | 6 ++++- src/data/dataset_select_groups_test.py | 30 +++++++++++++++++++++++ src/data/dataset_stats_test.py | 34 ++++++++++++++++++++++++++ src/data/dataset_test_utils.py | 3 +++ 5 files changed, 79 insertions(+), 7 deletions(-) diff --git a/src/data/dataset.py b/src/data/dataset.py index 46a011102..b900b09cd 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -1,8 +1,8 @@ """The interface for the database.""" import abc -import datetime import enum from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from typing import Any, Iterator, Literal, Optional, Sequence, Union import pandas as pd @@ -11,7 +11,7 @@ from pydantic import StrictBool, StrictBytes, StrictFloat, StrictInt, StrictStr, validator from ..embeddings.vector_store import VectorStore -from ..schema import VALUE_KEY, Bin, Path, PathTuple, Schema, normalize_path +from ..schema import VALUE_KEY, Bin, DataType, Path, PathTuple, Schema, normalize_path from ..signals.signal import Signal, resolve_signal from ..tasks import TaskStepId @@ -44,8 +44,8 @@ class StatsResult(BaseModel): approx_count_distinct: int # Defined for ordinal features. - min_val: Optional[Union[float, datetime.date, datetime.datetime]] - max_val: Optional[Union[float, datetime.date, datetime.datetime]] + min_val: Optional[Union[float, datetime]] + max_val: Optional[Union[float, datetime]] # Defined for text features. 
avg_text_length: Optional[float] @@ -187,7 +187,7 @@ def column_from_identifier(column: ColumnId) -> Column: return Column(path=column) -FeatureValue = Union[StrictInt, StrictFloat, StrictBool, StrictStr, StrictBytes] +FeatureValue = Union[StrictInt, StrictFloat, StrictBool, StrictStr, StrictBytes, datetime] FeatureListValue = list[StrictStr] BinaryFilterTuple = tuple[Path, BinaryOp, FeatureValue] ListFilterTuple = tuple[Path, ListOp, FeatureListValue] @@ -409,7 +409,8 @@ def media(self, item_id: str, leaf_path: Path) -> MediaResult: def default_settings(dataset: Dataset) -> DatasetSettings: """Gets the default settings for a dataset.""" - leaf_paths = dataset.manifest().data_schema.leafs.keys() + schema = dataset.manifest().data_schema + leaf_paths = [path for path, field in schema.leafs.items() if field.dtype == DataType.STRING] pool = ThreadPoolExecutor() stats: list[StatsResult] = list(pool.map(lambda leaf: dataset.stats(leaf), leaf_paths)) sorted_stats = sorted([stat for stat in stats if stat.avg_text_length], diff --git a/src/data/dataset_duckdb.py b/src/data/dataset_duckdb.py index 953816fcc..79b9b5433 100644 --- a/src/data/dataset_duckdb.py +++ b/src/data/dataset_duckdb.py @@ -41,6 +41,7 @@ is_float, is_integer, is_ordinal, + is_temporal, normalize_path, signal_compute_type_supports_dtype, ) @@ -602,7 +603,7 @@ def stats(self, leaf_path: Path) -> StatsResult: min_max_query = f""" SELECT MIN(val) AS minVal, MAX(val) AS maxVal FROM (SELECT {inner_select} as val FROM t) - WHERE NOT isnan(val) + {'WHERE NOT isnan(val)' if is_float(leaf.dtype) else ''} """ row = self._query(min_max_query)[0] result.min_val, result.max_val = row @@ -692,6 +693,9 @@ def select_groups( """ df = self._query_df(query) counts = list(df.itertuples(index=False, name=None)) + if is_temporal(leaf.dtype): + # Replace any NaT with None and pd.Timestamp to native datetime objects. 
+ counts = [(None if pd.isnull(val) else val.to_pydatetime(), count) for val, count in counts] return SelectGroupsResult(too_many_distinct=False, counts=counts, bins=named_bins) def _topk_udf_to_sort_by( diff --git a/src/data/dataset_select_groups_test.py b/src/data/dataset_select_groups_test.py index 8ddbbed15..9962965bb 100644 --- a/src/data/dataset_select_groups_test.py +++ b/src/data/dataset_select_groups_test.py @@ -1,6 +1,7 @@ """Tests for dataset.select_groups().""" import re +from datetime import datetime import pytest from pytest_mock import MockerFixture @@ -255,6 +256,35 @@ def test_filters(make_test_data: TestDataMaker) -> None: assert result.counts == [(None, 1)] +def test_datetime(make_test_data: TestDataMaker) -> None: + items: list[Item] = [ + { + UUID_COLUMN: '1', + 'date': datetime(2023, 1, 1) + }, + { + UUID_COLUMN: '2', + 'date': datetime(2023, 1, 15) + }, + { + UUID_COLUMN: '3', + 'date': datetime(2023, 2, 1) + }, + { + UUID_COLUMN: '4', + 'date': datetime(2023, 3, 1) + }, + { + UUID_COLUMN: '5', + # Missing datetime. 
+ } + ] + dataset = make_test_data(items) + result = dataset.select_groups('date') + assert result.counts == [(datetime(2023, 1, 1), 1), (datetime(2023, 1, 15), 1), + (datetime(2023, 2, 1), 1), (datetime(2023, 3, 1), 1), (None, 1)] + + def test_invalid_leaf(make_test_data: TestDataMaker) -> None: items: list[Item] = [ { diff --git a/src/data/dataset_stats_test.py b/src/data/dataset_stats_test.py index 14290f578..57cb3e236 100644 --- a/src/data/dataset_stats_test.py +++ b/src/data/dataset_stats_test.py @@ -1,5 +1,6 @@ """Tests for dataset.stats().""" +from datetime import datetime from typing import Any, cast import pytest @@ -123,3 +124,36 @@ def test_error_handling(make_test_data: TestDataMaker) -> None: with pytest.raises(ValueError, match="Path \\('unknown',\\) not found in schema"): dataset.stats(leaf_path='unknown') + + +def test_datetime(make_test_data: TestDataMaker) -> None: + items: list[Item] = [ + { + UUID_COLUMN: '1', + 'date': datetime(2023, 1, 1) + }, + { + UUID_COLUMN: '2', + 'date': datetime(2023, 1, 15) + }, + { + UUID_COLUMN: '3', + 'date': datetime(2023, 2, 1) + }, + { + UUID_COLUMN: '4', + 'date': datetime(2023, 3, 1) + }, + { + UUID_COLUMN: '5', + # Missing datetime. 
+ } + ] + dataset = make_test_data(items) + result = dataset.stats('date') + assert result == StatsResult( + path=('date',), + total_count=4, + approx_count_distinct=4, + min_val=datetime(2023, 1, 1), + max_val=datetime(2023, 3, 1)) diff --git a/src/data/dataset_test_utils.py b/src/data/dataset_test_utils.py index 3e55f1c6a..9b1958952 100644 --- a/src/data/dataset_test_utils.py +++ b/src/data/dataset_test_utils.py @@ -1,6 +1,7 @@ """Tests utils of for dataset_test.""" import os import pathlib +from datetime import datetime from typing import Optional, Type, cast from typing_extensions import Protocol @@ -36,6 +37,8 @@ def _infer_dtype(value: Item) -> DataType: return DataType.FLOAT32 elif isinstance(value, int): return DataType.INT32 + elif isinstance(value, datetime): + return DataType.TIMESTAMP else: raise ValueError(f'Cannot infer dtype of primitive value: {value}')