Skip to content

Commit

Permalink
use local functions; add rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
nkshaw23 committed Oct 25, 2024
1 parent b2f1cf2 commit 0d49be4
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 34 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"moto",
"polars",
"s3fs",
"setuptools",
"tenacity",
]
name = "dri-utils"
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -20,45 +20,14 @@
You might need extended permissions to write the test data to S3.
"""

import datetime
import random
from datetime import date, datetime, timedelta
from typing import Optional, Tuple, Union
from datetime import date, timedelta

import duckdb
import polars as pl
import s3fs


def steralize_dates(
    start_date: Union[date, datetime], end_date: Optional[Union[date, datetime]] = None
) -> Tuple[datetime, datetime]:
    """
    Configures and validates start and end dates.

    Plain ``date`` values are widened to ``datetime``: the start date becomes
    the first instant of its day and the end date the last instant of its day.
    Values that are already ``datetime`` are returned unchanged, so any time
    of day supplied by the caller is preserved.

    Args:
        start_date: The start date.
        end_date: The end date. If None, ``start_date`` is used as the end date.

    Returns:
        A tuple containing the start datetime and the end datetime.

    Raises:
        UserWarning: If the start date is after the end date.
    """
    # A missing end date means "same as the start"; for a plain date this
    # expands to the whole of that day below.
    effective_end = start_date if end_date is None else end_date

    # datetime is a subclass of date, so `isinstance(x, date)` is also true for
    # datetimes and would discard their time component. Only widen values that
    # are NOT already datetimes.
    if isinstance(start_date, datetime):
        range_start = start_date
    else:
        range_start = datetime.combine(start_date, datetime.min.time())

    if isinstance(effective_end, datetime):
        range_end = effective_end
    else:
        range_end = datetime.combine(effective_end, datetime.max.time())

    # Validate after normalising: comparing a date with a datetime raises
    # TypeError, so the check must run on two datetimes.
    if range_start > range_end:
        raise UserWarning(f"Start date must come before end date: {start_date} > {end_date}")

    return range_start, range_end
from driutils.datetime import steralize_date_range


def write_parquet_s3(bucket: str, key: str, data: pl.DataFrame) -> None:
Expand Down Expand Up @@ -94,7 +63,7 @@ def build_test_precip_data(
test_data = pl.DataFrame(schema=schema)

# Format dates
start_date, end_date = steralize_dates(start_date, end_date)
start_date, end_date = steralize_date_range(start_date, end_date)

# Build datetime range series
datetime_range = pl.datetime_range(start_date, end_date, interval, eager=True).alias("time")
Expand Down Expand Up @@ -124,6 +93,7 @@ def build_test_precip_data(

if isinstance(dtype, pl.Int64):
col_values = pl.Series(column, [random.randrange(1, 255, 1) for i in range(required_rows)])
col_values.round(3)

test_data.replace_column(test_data.get_column_index(column), col_values)

Expand Down

0 comments on commit 0d49be4

Please sign in to comment.