diff --git a/pyproject.toml b/pyproject.toml index 3b6f569..f2dc660 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "moto", "polars", "s3fs", + "setuptools", "tenacity", ] name = "dri-utils" diff --git a/src/driutils/benchmarking/__init__.py b/src/driutils/benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/driutils/test_data/create_test_cosmos_data.py b/src/driutils/benchmarking/create_test_cosmos_data.py similarity index 83% rename from src/driutils/test_data/create_test_cosmos_data.py rename to src/driutils/benchmarking/create_test_cosmos_data.py index 0699eba..fe4275a 100644 --- a/src/driutils/test_data/create_test_cosmos_data.py +++ b/src/driutils/benchmarking/create_test_cosmos_data.py @@ -20,45 +20,14 @@ You (might) need extended permissions to write the test data to s3. """ -import datetime import random -from datetime import date, datetime, timedelta -from typing import Optional, Tuple, Union +from datetime import date, timedelta import duckdb import polars as pl import s3fs - -def steralize_dates( - start_date: Union[date, datetime], end_date: Optional[Union[date, datetime]] -) -> Tuple[Union[date, datetime], datetime]: - """ - Configures and validates start and end dates. - - Args: - start_date: The start date. - end_date: The end date. - - Returns: - A tuple containing the start date and the end date. - - Raises: - UserWarning: If the start date is after the end date. - """ - # Ensure the start_date is not after the end_date - if start_date > end_date: - raise UserWarning(f"Start date must come before end date: {start_date} > {end_date}") - - # If start_date is of type date, convert it to datetime with time at start of the day - if isinstance(start_date, date): - start_date = datetime.combine(start_date, datetime.min.time()) - - # If end_date is of type date, convert it to datetime to include the entire day - if isinstance(end_date, date): - end_date = datetime.combine(end_date, datetime.max.time()) - - return start_date, end_date +from driutils.datetime import steralize_date_range def write_parquet_s3(bucket: str, key: str, data: pl.DataFrame) -> None: @@ -94,7 +63,7 @@ def build_test_precip_data( test_data = pl.DataFrame(schema=schema) # Format dates - start_date, end_date = steralize_dates(start_date, end_date) + start_date, end_date = steralize_date_range(start_date, end_date) # Build datetime range series datetime_range = pl.datetime_range(start_date, end_date, interval, eager=True).alias("time") @@ -124,6 +93,7 @@ def build_test_precip_data( if isinstance(dtype, pl.Int64): col_values = pl.Series(column, [random.randrange(1, 255, 1) for i in range(required_rows)]) + col_values.round(3) test_data.replace_column(test_data.get_column_index(column), col_values)