Skip to content

Commit

Permalink
solved merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
majoma7 committed Jan 28, 2025
2 parents 172d295 + 5a5e299 commit 75510db
Show file tree
Hide file tree
Showing 7 changed files with 751 additions and 255 deletions.
33 changes: 29 additions & 4 deletions inventory_foundation_sdk/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,23 @@
'inventory_foundation_sdk/custom_datasets.py'),
'inventory_foundation_sdk.custom_datasets.DynamicPathJSONDataset._save': ( 'custom_datasets.html#dynamicpathjsondataset._save',
'inventory_foundation_sdk/custom_datasets.py')},
<<<<<<< HEAD
'inventory_foundation_sdk.db_mgmt': { 'inventory_foundation_sdk.db_mgmt.check_in_scope_entries': ( 'db_mgmt.html#check_in_scope_entries',
'inventory_foundation_sdk/db_mgmt.py'),
=======
'inventory_foundation_sdk.db_mgmt': { 'inventory_foundation_sdk.db_mgmt.SQLDatabase': ( 'db_mgmt.html#sqldatabase',
'inventory_foundation_sdk/db_mgmt.py'),
'inventory_foundation_sdk.db_mgmt.SQLDatabase.__init__': ( 'db_mgmt.html#sqldatabase.__init__',
'inventory_foundation_sdk/db_mgmt.py'),
'inventory_foundation_sdk.db_mgmt.SQLDatabase.close': ( 'db_mgmt.html#sqldatabase.close',
'inventory_foundation_sdk/db_mgmt.py'),
'inventory_foundation_sdk.db_mgmt.SQLDatabase.connect': ( 'db_mgmt.html#sqldatabase.connect',
'inventory_foundation_sdk/db_mgmt.py'),
'inventory_foundation_sdk.db_mgmt.SQLDatabase.execute_multiple_queries': ( 'db_mgmt.html#sqldatabase.execute_multiple_queries',
'inventory_foundation_sdk/db_mgmt.py'),
'inventory_foundation_sdk.db_mgmt.SQLDatabase.execute_query': ( 'db_mgmt.html#sqldatabase.execute_query',
'inventory_foundation_sdk/db_mgmt.py'),
>>>>>>> origin/refactor-if-sdk
'inventory_foundation_sdk.db_mgmt.get_db_credentials': ( 'db_mgmt.html#get_db_credentials',
'inventory_foundation_sdk/db_mgmt.py'),
'inventory_foundation_sdk.db_mgmt.insert_multi_rows': ( 'db_mgmt.html#insert_multi_rows',
Expand Down Expand Up @@ -84,7 +99,17 @@
'inventory_foundation_sdk/etl_nodes.py'),
'inventory_foundation_sdk.etl_nodes.input_output_node': ( 'etl_nodes.html#input_output_node',
'inventory_foundation_sdk/etl_nodes.py')},
'inventory_foundation_sdk.kedro_orchestration': { 'inventory_foundation_sdk.kedro_orchestration.verify_db_write_status': ( 'kedro_orchestration.html#verify_db_write_status',
'inventory_foundation_sdk/kedro_orchestration.py')},
'inventory_foundation_sdk.test': { 'inventory_foundation_sdk.test.write_company_name': ( 'core copy.html#write_company_name',
'inventory_foundation_sdk/test.py')}}}
'inventory_foundation_sdk.state_mgmnt': { 'inventory_foundation_sdk.state_mgmnt.Flag': ( 'state_mgmt.html#flag',
'inventory_foundation_sdk/state_mgmnt.py'),
'inventory_foundation_sdk.state_mgmnt.Flag.__init__': ( 'state_mgmt.html#flag.__init__',
'inventory_foundation_sdk/state_mgmnt.py'),
'inventory_foundation_sdk.state_mgmnt.Flag.check': ( 'state_mgmt.html#flag.check',
'inventory_foundation_sdk/state_mgmnt.py'),
'inventory_foundation_sdk.state_mgmnt.Flag.get': ( 'state_mgmt.html#flag.get',
'inventory_foundation_sdk/state_mgmnt.py'),
'inventory_foundation_sdk.state_mgmnt.Flag.set': ( 'state_mgmt.html#flag.set',
'inventory_foundation_sdk/state_mgmnt.py'),
'inventory_foundation_sdk.state_mgmnt.States': ( 'state_mgmt.html#states',
'inventory_foundation_sdk/state_mgmnt.py'),
'inventory_foundation_sdk.state_mgmnt.States.are_verified': ( 'state_mgmt.html#states.are_verified',
'inventory_foundation_sdk/state_mgmnt.py')}}}
174 changes: 174 additions & 0 deletions inventory_foundation_sdk/db_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_db_mgmt.ipynb.

# %% auto 0
# Public API of this module. The merge left conflict markers here (a SyntaxError
# on import); both branches' additions are defined below, so export the union.
__all__ = ['get_db_credentials', 'insert_multi_rows', 'check_in_scope_entries', 'SQLDatabase']

# %% ../nbs/10_db_mgmt.ipynb 3
from kedro.config import OmegaConfigLoader
Expand All @@ -14,10 +18,15 @@
import numpy as np
from tqdm import tqdm

import psycopg2

<<<<<<< HEAD
import psycopg2

# %% ../nbs/10_db_mgmt.ipynb 5
=======
# %% ../nbs/10_db_mgmt.ipynb 4
>>>>>>> origin/refactor-if-sdk
def get_db_credentials():
"""
Fetch PostgreSQL database credentials from the configuration file of the kedro project.
Expand All @@ -34,7 +43,11 @@ def get_db_credentials():

return db_credentials

<<<<<<< HEAD
# %% ../nbs/10_db_mgmt.ipynb 6
=======
# %% ../nbs/10_db_mgmt.ipynb 5
>>>>>>> origin/refactor-if-sdk
def insert_multi_rows(
data_to_insert: pd.DataFrame,
table_name: str,
Expand Down Expand Up @@ -76,15 +89,23 @@ def insert_multi_rows(
"Number of types does not match the number of columns in the DataFrame."
)

<<<<<<< HEAD
# logger.info("-- in insert multi rows -- converting data to list of tuples")
=======
logger.info("-- in insert multi rows -- converting data to list of tuples")
>>>>>>> origin/refactor-if-sdk
# Convert to list of tuples and apply type casting

data_values = data_to_insert.values.tolist()
data_values = [
tuple(typ(val) for typ, val in zip(types, row)) for row in data_values
]

<<<<<<< HEAD
# logger.info("-- in insert multi rows -- preparing SQL")
=======
logger.info("-- in insert multi rows -- preparing SQL")
>>>>>>> origin/refactor-if-sdk
# Create SQL placeholders and query
placeholders = ", ".join(["%s"] * len(column_names))
column_names_str = ", ".join(f'"{col}"' for col in column_names)
Expand Down Expand Up @@ -150,6 +171,7 @@ def insert_multi_rows(

return None

<<<<<<< HEAD
# %% ../nbs/10_db_mgmt.ipynb 7
def check_in_scope_entries(
target_table,
Expand Down Expand Up @@ -264,3 +286,155 @@ def check_in_scope_entries(
except Exception as e:
logger.error(f"Error checking in-scope entries for {target_table}: {e}")
raise e
=======
# %% ../nbs/10_db_mgmt.ipynb 6
class SQLDatabase:
"""
A class to represent a SQL database.
Attributes:
----------
credentials : dict
A dictionary containing the database connection credentials.
connection : psycopg2.connection
The database connection object.
Methods:
-------
connect():
Connects to the database using the provided credentials.
close():
Closes the database connection.
execute_query(query: str):
Executes the given SQL query.
execute_multiple_queries(queries: list, params: list = None, fetchrows: bool = False):
Executes a list or iterable of SQL queries.
If fetchrows is True, iterates over queries and fetches rows.
If fetchrows is False, uses executemany for batch execution.
"""

def __init__(self):
"""
Initializes the SQLDatabase object with the provided database connection credentials.
"""
self._credentials = get_db_credentials()["con"]
self.connection = None

def connect(self):
"""
Connects to the database using the provided credentials.
"""
if not self.connection:
self.connection = psycopg2.connect(self._credentials)

def close(self):
"""
Closes the database connection.
"""
if self.connection:
self.connection.close()
self.connection = None

def execute_query(
self, query: str, params: tuple = None, fetchall: bool = False, fetchone=False
):
"""
Executes the given SQL query with optional parameters.
Parameters:
-----------
query : str
The SQL query to be executed.
params : tuple, optional
The parameters to be passed to the query. Defaults to None.
fetchall : bool, optional
Whether to fetch all rows from the query result. Defaults to False.
fetchone : bool, optional
Whether to fetch only one row from the query result. Defaults to False.
Returns:
--------
result : list
A list of tuples representing the query result rows.
"""

# Check if not both fetchall and fetchone are True
if fetchall and fetchone:
raise ValueError("Both fetchall and fetchone cannot be True")

if not self.connection:
self.connect()

with self.connection.cursor() as cur:
cur.execute(query, params)
if fetchall:
result = cur.fetchall()
elif fetchone:
result = cur.fetchone()
else:
result = None

self.connection.commit()
return result

def execute_multiple_queries(
self, queries: list | str, params: list = None, fetchrows: bool = False
):
"""
Executes a list or iterable of SQL queries.
If fetchrows is True, iterates over queries and fetches rows.
If fetchrows is False, tries to use executemany for batch execution.
Parameters:
-----------
queries : list or str
A list of SQL queries to be executed, or a single query as a string.
params : list, optional
A list of tuples containing parameters for each query. Defaults to None.
fetchrows : bool, optional
Whether to fetch rows from the queries. Defaults to False (use executemany).
Returns:
--------
results : list
A list of results for each executed query, or None if using executemany.
"""
if not self.connection:
self.connect()

results = []
with self.connection.cursor() as cur:
if fetchrows:

if isinstance(queries, str):
# Convert single query to a list of same length as params
queries = [queries] * len(params)

# Iterate over queries and fetch rows
for idx, query in enumerate(queries):
query_params = params[idx] if params else None
cur.execute(query, query_params)
result = cur.fetchone()
results.append(result) # Collect the result of each query
else:

if not isinstance(queries, str):
# In this case only one query with multiple params can be executed (raise error)
raise ValueError(
"Multiple queries with multiple params are not supported when using executemany. Set fetchrows=True"
)

cur.executemany(queries, params)

self.connection.commit()

return results if fetchrows else None
>>>>>>> origin/refactor-if-sdk
85 changes: 85 additions & 0 deletions inventory_foundation_sdk/state_mgmnt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Classes to manage the state of nodes within kedro pipelines"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20_state_mgmt.ipynb.

# %% auto 0
__all__ = ['Flag', 'States']

# %% ../nbs/20_state_mgmt.ipynb 3
import logging


class Flag:
    """
    A named marker whose state is either unset (None), truthy, or falsy.

    Attributes:
        name (str): Label used in the log and error messages.
        _state (bool or None): The stored state; None means it was never set.

    Methods:
        check(): Raise unless the state is truthy; log an info message otherwise.
        set(state): Replace the stored state.
        get(): Return the stored state unchanged.
    """

    def __init__(self, name, state=None):
        """
        Create a flag.

        Args:
            name (str): The name of the flag.
            state (bool or None): Starting state. Defaults to None (unset).
        """
        self.name, self._state = name, state

    def check(self):
        """
        Verify that the flag is set and truthy.

        Raises:
            ValueError: If the state is None.
            AssertionError: If the state is falsy but not None (False, 0, '', ...).

        Logs:
            An info message when the state is truthy.
        """
        current = self._state

        # Guard clauses: "never set" is reported differently from "set to a falsy value".
        if current is None:
            raise ValueError(f"Flag '{self.name}' is not set.")
        if not current:
            raise AssertionError(f"Flag '{self.name}' is False.")

        logging.info(f"Flag '{self.name}' is True. Everything is good.")

    def set(self, state):
        """
        Store a new state.

        Args:
            state (bool): The state to store.
        """
        self._state = state

    def get(self):
        """
        Return the stored state.

        Returns:
            bool or None: The state exactly as stored.
        """
        return self._state

# %% ../nbs/20_state_mgmt.ipynb 4
class States:
    """
    Helpers for inspecting the state flags of nodes in the ETL pipeline.

    ``are_verified`` reports whether every given flag currently holds a truthy state.
    """

    @staticmethod
    def are_verified(*states: Flag):
        """
        Return True when every given flag's state is truthy.

        Stops at the first flag whose state is falsy or None. With no flags
        given the result is True.
        """
        for flag in states:
            if not flag.get():
                return False
        return True
Loading

0 comments on commit 75510db

Please sign in to comment.