Bugfix company name #9

Merged 2 commits on Jan 28, 2025
83 changes: 52 additions & 31 deletions inventory_foundation_sdk/custom_datasets.py
@@ -25,13 +25,14 @@

# %% ../nbs/50_custom_datasets.ipynb 5
import logging

logger = logging.getLogger(__name__)


class AddRowDataset(AbstractDataset):

"""
Adds one row to a SQL table, or updates it if it already exists.

"""

def __init__(
@@ -40,8 +41,8 @@ def __init__(
column_names: t.List,
credentials: str,
unique_columns: t.List,
load_args = None,
save_args = None
load_args=None,
save_args=None,
):

self.unique_columns = unique_columns
@@ -50,7 +51,7 @@ def __init__(
self.db_credentials = credentials
self.save_args = save_args or {}
self.load_args = load_args or {}

def _describe(self) -> t.Dict[str, t.Any]:
"""Returns a dict that describes the attributes of the dataset."""
return dict(
@@ -66,15 +67,13 @@ def _load(self) -> pd.DataFrame:
return_all_columns = self.load_args.get("return_all_columns", False)

try:
with psycopg2.connect(self.db_credentials['con']) as conn:
with psycopg2.connect(self.db_credentials["con"]) as conn:
with conn.cursor() as cursor:

if return_all_columns:

# Fetch all rows
cursor.execute(
f"SELECT * FROM {self.table}"
)
cursor.execute(f"SELECT * FROM {self.table}")
data = cursor.fetchall()

# Fetch column names in the correct order from the database
@@ -87,7 +86,7 @@ def _load(self) -> pd.DataFrame:
WHERE c.relname = %s AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
""",
(self.table,)
(self.table,),
)
columns = [row[0] for row in cursor.fetchall()]

@@ -116,22 +115,28 @@ def _save(self, data: pd.DataFrame) -> None:
"""

verbose = self.save_args.get("verbose", 1)

try:
# Connect to the database
with psycopg2.connect(self.db_credentials['con']) as conn:
with psycopg2.connect(self.db_credentials["con"]) as conn:
with conn.cursor() as cursor:
# Prepare data insertion
for _, row in data.iterrows():
# Ensure all data is properly converted to standard Python types
row_data = tuple(
row[col].item() if isinstance(row[col], (np.generic, np.ndarray)) else row[col]
(
row[col].item()
if isinstance(row[col], (np.generic, np.ndarray))
else row[col]
)
for col in self.column_names
)

# Determine the update clause (exclude unique columns)
updatable_columns = [
col for col in self.column_names if col not in self.unique_columns
col
for col in self.column_names
if col not in self.unique_columns
]

# Only create an update clause if there are columns to update
@@ -149,30 +154,42 @@ def _save(self, data: pd.DataFrame) -> None:
)

# Build the SQL query dynamically
query = sql.SQL("""
query = sql.SQL(
"""
INSERT INTO {table} ({columns})
VALUES ({values})
ON CONFLICT ({conflict_clause}) DO UPDATE SET
{update_clause}
RETURNING xmax = 0 AS is_inserted
""").format(
"""
).format(
table=sql.Identifier(self.table),
columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.column_names),
values=sql.SQL(", ").join(sql.Placeholder() for _ in self.column_names),
columns=sql.SQL(", ").join(
sql.Identifier(col) for col in self.column_names
),
values=sql.SQL(", ").join(
sql.Placeholder() for _ in self.column_names
),
conflict_clause=conflict_clause,
update_clause=update_clause
update_clause=update_clause,
)
else:
# Build the SQL query for insertion without an update clause
query = sql.SQL("""
query = sql.SQL(
"""
INSERT INTO {table} ({columns})
VALUES ({values})
ON CONFLICT DO NOTHING
RETURNING xmax = 0 AS is_inserted
""").format(
"""
).format(
table=sql.Identifier(self.table),
columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.column_names),
values=sql.SQL(", ").join(sql.Placeholder() for _ in self.column_names)
columns=sql.SQL(", ").join(
sql.Identifier(col) for col in self.column_names
),
values=sql.SQL(", ").join(
sql.Placeholder() for _ in self.column_names
),
)

# Execute the query with properly cast values
Expand All @@ -183,9 +200,13 @@ def _save(self, data: pd.DataFrame) -> None:

if verbose > 0:
if is_inserted:
logger.info(f"Inserted new row: {dict(zip(self.column_names, row_data))}")
logger.info(
f"Inserted new row: {dict(zip(self.column_names, row_data))}"
)
else:
logger.info(f"Updated row (or skipped due to conflict): {dict(zip(self.column_names, row_data))}")
logger.info(
f"Updated row (or skipped due to conflict): {dict(zip(self.column_names, row_data))}"
)

# Commit the transaction
conn.commit()
@@ -196,16 +217,14 @@ def _save(self, data: pd.DataFrame) -> None:

# %% ../nbs/50_custom_datasets.ipynb 6
class DynamicPathJSONDataset(AbstractDataset):

"""
Custom dataset to dynamically resolve a JSON file path from parameters.
"""

def __init__(self, path_param: str):

"""
Initializes the DynamicPathJSONDataset.

Args:
path_param (str): The parameter key that contains the file path.
"""
@@ -220,10 +239,12 @@ def _load(self) -> dict:
"""
# Load parameters
params_path = self.config_loader["parameters"][self.path_param]

# Resolve the file path from parameters
if not params_path:
raise ValueError(f"Path parameter '{self.path_param}' not found in parameters.")
raise ValueError(
f"Path parameter '{self.path_param}' not found in parameters."
)

# Load and return JSON data
full_path = Path(params_path)
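
The _save method above composes its upsert with psycopg2.sql and uses PostgreSQL's "RETURNING xmax = 0" trick to tell freshly inserted rows apart from conflict updates. The minimal sketch below shows the same pattern in isolation; the companies table, its name/city columns, the upsert_company helper, and the DSN argument are illustrative assumptions, not part of this repository.

import psycopg2
from psycopg2 import sql


def upsert_company(dsn: str, name: str, city: str) -> bool:
    """Insert or update a single row; return True if the row was newly inserted."""
    query = sql.SQL(
        """
        INSERT INTO {table} ({columns})
        VALUES ({values})
        ON CONFLICT ({conflict}) DO UPDATE SET {updates}
        RETURNING xmax = 0 AS is_inserted
        """
    ).format(
        table=sql.Identifier("companies"),  # hypothetical table name
        columns=sql.SQL(", ").join(sql.Identifier(c) for c in ("name", "city")),
        values=sql.SQL(", ").join(sql.Placeholder() for _ in range(2)),
        conflict=sql.Identifier("name"),  # unique column
        updates=sql.SQL("city = EXCLUDED.city"),  # refresh the non-unique column
    )
    with psycopg2.connect(dsn) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, (name, city))
            # xmax = 0 only holds for rows created by this INSERT, not for conflict updates
            inserted = cursor.fetchone()[0]
        conn.commit()
    return inserted

As in the dataset above, dropping the update clause and using ON CONFLICT DO NOTHING is the variant to reach for when no non-unique columns should be refreshed.
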
44 changes: 25 additions & 19 deletions inventory_foundation_sdk/db_mgmt.py
@@ -14,10 +14,8 @@
import numpy as np
from tqdm import tqdm


# %% ../nbs/10_db_mgmt.ipynb 5
def get_db_credentials():

"""
Fetch PostgreSQL database credentials from the configuration file of the kedro project.

@@ -35,8 +33,10 @@ def get_db_credentials():

# %% ../nbs/10_db_mgmt.ipynb 6
import logging

logger = logging.getLogger(__name__)


def insert_multi_rows(
data_to_insert: pd.DataFrame,
table_name: str,
@@ -47,7 +47,6 @@ def insert_multi_rows(
return_with_ids: bool = False,
unique_columns: list = None, # mandatory if return_with_ids is True
) -> pd.DataFrame | None:

"""
Inserts data into the specified database table, with an optional return of database-assigned IDs.

@@ -68,31 +67,40 @@ def insert_multi_rows(
# Check for NaN values and log a warning if any are found
if data_to_insert.isnull().values.any():
logger.warning("There are NaNs in the data")

# Ensure the DataFrame has the correct number of columns
if len(column_names) != data_to_insert.shape[1]:
raise ValueError("Number of column names does not match the number of columns in the DataFrame.")
raise ValueError(
"Number of column names does not match the number of columns in the DataFrame."
)
if len(types) != data_to_insert.shape[1]:
raise ValueError("Number of types does not match the number of columns in the DataFrame.")

raise ValueError(
"Number of types does not match the number of columns in the DataFrame."
)

logger.info("-- in insert multi rows -- converting data to list of tuples")
# Convert to list of tuples and apply type casting

data_values = data_to_insert.values.tolist()
data_values = [tuple(typ(val) for typ, val in zip(types, row)) for row in data_values]

data_values = [
tuple(typ(val) for typ, val in zip(types, row)) for row in data_values
]

logger.info("-- in insert multi rows -- preparing SQL")
# Create SQL placeholders and query
placeholders = ", ".join(["%s"] * len(column_names))
column_names_str = ", ".join(f'"{col}"' for col in column_names)


batch_size_for_commit = 1_000_000 # Adjust this based on your dataset size and transaction tolerance
batch_size_for_commit = (
1_000_000 # Adjust this based on your dataset size and transaction tolerance
)
row_count = 0

if return_with_ids:
if not unique_columns:
raise ValueError("unique_columns must be provided when return_with_ids is True")
raise ValueError(
"unique_columns must be provided when return_with_ids is True"
)

unique_columns_str = ", ".join(f'"{col}"' for col in unique_columns)
insert_query = f"""
@@ -104,8 +112,6 @@ def insert_multi_rows(
"""
ids = []



# Insert row by row and collect IDs
with tqdm(total=len(data_values), desc="Inserting rows") as pbar:
for row in data_values:
@@ -115,12 +121,12 @@ def insert_multi_rows(
ids.append(row_id[0])
row_count += 1
pbar.update(1) # Update progress bar for each row

# Commit every batch_size_for_commit rows
if row_count % batch_size_for_commit == 0:
conn.commit() # Commit the transaction
conn.commit()
conn.commit()

# Add IDs back to the original DataFrame
data_with_ids = data_to_insert.copy()
data_with_ids["ID"] = ids
@@ -132,7 +138,7 @@ def insert_multi_rows(
VALUES ({placeholders})
ON CONFLICT DO NOTHING;
"""

# Insert row by row without returning IDs
with tqdm(total=len(data_values), desc="Inserting rows") as pbar:
for row in data_values:
@@ -141,7 +147,7 @@ def insert_multi_rows(
pbar.update(1) # Update progress bar for each row
if row_count % batch_size_for_commit == 0:
conn.commit() # Commit the transaction

conn.commit() # Commit all changes after processing

return None
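
insert_multi_rows above inserts row by row under a tqdm progress bar and commits every batch_size_for_commit rows so that no single transaction grows unbounded. The stripped-down sketch below shows only that commit cadence; the events table, its name/value columns, the insert_with_periodic_commits helper, and the DSN argument are illustrative assumptions rather than this library's API.

import psycopg2
from tqdm import tqdm


def insert_with_periodic_commits(dsn: str, rows: list, batch_size: int = 1_000_000) -> None:
    """Insert (name, value) tuples one by one, committing every batch_size rows."""
    query = 'INSERT INTO "events" ("name", "value") VALUES (%s, %s) ON CONFLICT DO NOTHING;'
    with psycopg2.connect(dsn) as conn:
        with conn.cursor() as cur:
            row_count = 0
            with tqdm(total=len(rows), desc="Inserting rows") as pbar:
                for row in rows:
                    cur.execute(query, row)
                    row_count += 1
                    pbar.update(1)
                    if row_count % batch_size == 0:
                        conn.commit()  # keep the open transaction bounded
            conn.commit()  # commit whatever remains after the loop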