Upserting capabilities added at creation #108

Open · wants to merge 16 commits into main
102 changes: 71 additions & 31 deletions src/api/create.py
@@ -9,12 +9,9 @@
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
-from sqlalchemy import MetaData, create_engine, text
-from sqlalchemy_utils.functions import (
-    create_database,
-    database_exists,
-    drop_database,
-)
+from sqlalchemy import MetaData, create_engine, inspect, text
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy_utils.functions import create_database, database_exists
from sqlmodel import SQLModel
from tqdm import tqdm

@@ -70,17 +67,25 @@ def normalize_signal_name(name):
    )
    return signal_name

+def get_primary_keys(table_name, engine):
+    # Look up the primary-key columns of an existing table via the inspector.
+    inspector = inspect(engine)
+    pk_columns = inspector.get_pk_constraint(table_name).get("constrained_columns", [])
+    if not pk_columns:
+        raise ValueError(f"No primary key found for table {table_name}")
+    return pk_columns

class DBCreationClient:
    def __init__(self, uri: str, db_name: str):
        self.uri = uri
        self.db_name = db_name

    def create_database(self):
-        if database_exists(self.uri):
-            drop_database(self.uri)
-
-        create_database(self.uri)
-        print(database_exists(self.uri))
+        if not database_exists(self.uri):
+            logging.info("Database does not exist. Creating.")
+            create_database(self.uri)
+        else:
+            logging.info("Database exists. Skipping creation.")

        self.metadata_obj, self.engine = connect(self.uri)
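Note for reviewers: the helper should return the column list as-is (joining the names into one string, as an earlier revision did, would break composite keys). `get_pk_constraint` returns a dict whose `constrained_columns` entry is the ordered list of key columns, which can be handed straight to `on_conflict_do_update(index_elements=...)`. A minimal, illustrative sketch (SQLite only for brevity; not part of this PR):

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, inspect

# Illustrative only: create a table, then read back its primary-key columns,
# which is exactly what get_primary_keys relies on.
engine = create_engine("sqlite://")
metadata = MetaData()
Table("shots", metadata, Column("shot_id", Integer, primary_key=True))
metadata.create_all(engine)

print(inspect(engine).get_pk_constraint("shots")["constrained_columns"])
# -> ['shot_id']
```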

@@ -93,26 +98,61 @@ def create_database(self):
    def create_user(self):
        engine = create_engine(self.uri, echo=True)
        name = password = "public_user"
-        drop_user = text(f"DROP USER IF EXISTS {name}")
-        create_user_query = text(f"CREATE USER {name} WITH PASSWORD :password;")
-        grant_privledges = text(f"GRANT CONNECT ON DATABASE {self.db_name} TO {name};")
-        grant_public_schema = text(f"GRANT USAGE ON SCHEMA public TO {name};")
-        grant_public_schema_tables = text(
-            f"GRANT SELECT ON ALL TABLES IN SCHEMA public TO {name};"
-        )

+        create_user_query = text(f"DO $$ BEGIN \
+            IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{name}') THEN \
+                CREATE USER {name} WITH PASSWORD :password; \
+            END IF; \
+        END $$;")

+        grant_privileges = [
+            text(f"GRANT CONNECT ON DATABASE {self.db_name} TO {name};"),
+            text(f"GRANT USAGE ON SCHEMA public TO {name};"),
+            text(f"GRANT SELECT ON ALL TABLES IN SCHEMA public TO {name};"),
+            text(f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO {name};"),
+        ]

        with engine.connect() as conn:
-            conn.execute(drop_user)
+            # Ensure the user exists
            conn.execute(create_user_query, {"password": password})
-            conn.execute(grant_privledges)
-            conn.execute(grant_public_schema)
-            conn.execute(grant_public_schema_tables)

+            # Grant necessary privileges
+            for grant_query in grant_privileges:
+                conn.execute(grant_query)
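One subtlety in the `DO $$ ... $$` block: the `:password` bind sits inside a dollar-quoted string, so it only takes effect because the driver (presumably psycopg2 here) interpolates bound parameters client-side before the statement reaches the server. A standalone sketch of the same idempotent-role pattern (URI and names are placeholders):

```python
from sqlalchemy import create_engine, text

# Placeholder URI; illustrative only. psycopg2 substitutes :pw into the SQL
# text client-side, which is why a bind inside $$...$$ still works.
engine = create_engine("postgresql+psycopg2://admin@localhost/example")
stmt = text(
    "DO $$ BEGIN "
    "IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'reader') THEN "
    "CREATE USER reader WITH PASSWORD :pw; "
    "END IF; "
    "END $$;"
)
with engine.begin() as conn:
    conn.execute(stmt, {"pw": "s3cret"})
```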

+    def create_or_upsert_table(self, table_name: str, df: pd.DataFrame):
Review comment (Member):
Could we also have unit testing for this method?
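A sketch of what such a test could look like, assuming a pytest fixture that provides a `DBCreationClient` bound to a scratch database in which an `example` table with primary key `uuid` already exists (fixture and table names are hypothetical):

```python
import pandas as pd

def test_upsert_updates_existing_rows(client):
    # Hypothetical fixture: `client` is a DBCreationClient whose scratch
    # database already contains table `example` with primary key `uuid`.
    client.create_or_upsert_table("example", pd.DataFrame([{"uuid": "a", "quality": 1}]))

    # Upserting the same key again must update in place, not duplicate.
    client.create_or_upsert_table("example", pd.DataFrame([{"uuid": "a", "quality": 2}]))

    result = pd.read_sql_table("example", client.engine)
    assert len(result) == 1
    assert result.loc[0, "quality"] == 2
```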


+        self.metadata_obj.reflect(bind=self.engine)

+        if table_name not in self.metadata_obj.tables:
+            logging.info(f"Creating table {table_name} from scratch.")
+            df.to_sql(table_name, self.engine, if_exists="append")
+            return

+        table = self.metadata_obj.tables[table_name]

+        insert_stmt = insert(table).values(df.to_dict(orient="records"))
+        update_column_stmt = {
+            col.name: insert_stmt.excluded[col.name]
+            for col in table.columns
+            if col.name != "uuid"
+        }

+        primary_key = get_primary_keys(table_name, self.engine)

+        stmt = insert_stmt.on_conflict_do_update(
+            index_elements=primary_key, set_=update_column_stmt
+        )

+        with self.engine.begin() as conn:
+            conn.execute(stmt)
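For reviewers unfamiliar with the construct: the statement assembled above compiles to a single PostgreSQL `INSERT ... ON CONFLICT ... DO UPDATE`, where `excluded` refers to the row that collided. A self-contained sketch of what it emits (table and columns are illustrative):

```python
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()
shots = Table(
    "shots",
    metadata,
    Column("shot_id", Integer, primary_key=True),
    Column("quality", Integer),
)

insert_stmt = insert(shots).values([{"shot_id": 1, "quality": 2}])
stmt = insert_stmt.on_conflict_do_update(
    index_elements=["shot_id"],
    set_={"quality": insert_stmt.excluded["quality"]},
)
print(stmt.compile(dialect=postgresql.dialect()))
# Roughly: INSERT INTO shots (shot_id, quality) VALUES (...)
#          ON CONFLICT (shot_id) DO UPDATE SET quality = excluded.quality
```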

    def create_cpf_summary(self, data_path: Path):
        """Create the CPF summary table"""
        paths = data_path.glob("cpf/*_cpf_columns.parquet")

+        columns = []
        for path in paths:
            df = pd.read_parquet(path)
-            df.to_sql("cpf_summary", self.uri, if_exists="replace")
+            columns.append(df)

+        cpf_sum = pd.concat(columns, axis=0)
+        cpf_sum = cpf_sum.drop_duplicates(subset=["name", "description"]).reset_index(drop=True)

+        self.create_or_upsert_table("cpf_summary", cpf_sum)

def create_scenarios(self, data_path: Path):
"""Create the scenarios metadata table"""
@@ -123,7 +163,7 @@ def create_scenarios(self, data_path: Path):

        data = pd.DataFrame(dict(id=ids, name=scenarios)).set_index("id")
        data = data.dropna()
-        data.to_sql("scenarios", self.uri, if_exists="append")
+        self.create_or_upsert_table("scenarios", data.reset_index())

    def create_shots(self, data_path: Path):
        """Create the shot metadata table"""
@@ -157,7 +197,7 @@ def create_shots(self, data_path: Path):
        cpfs = pd.concat(cpfs, axis=0)
        cpfs = cpfs.reset_index()
        cpfs = cpfs.loc[cpfs.shot_id <= LAST_MAST_SHOT]
-        cpfs = cpfs.drop_duplicates(subset="shot_id")
+        cpfs = cpfs.drop_duplicates(subset="shot_id").sort_values(by="shot_id")
        cpfs = cpfs.set_index("shot_id")

        shot_metadata = pd.merge(
@@ -167,11 +207,12 @@ def create_shots(self, data_path: Path):
            right_on="shot_id",
            how="left",
        )
-
-        shot_metadata.to_sql("shots", self.uri, if_exists="append")
+        shot_metadata = shot_metadata.reset_index()
+        shot_metadata = shot_metadata.replace(np.nan, None)
+        self.create_or_upsert_table("shots", shot_metadata)

    def create_signals(self, data_path: Path):
        logging.info(f"Loading signals from {data_path}")
        file_name = data_path / "signals.parquet"

        parquet_file = pq.ParquetFile(file_name)
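The loop that consumes `parquet_file` falls outside this diff; presumably it walks the file in record batches so the whole signals table never sits in memory at once, along these lines (batch size illustrative):

```python
import pyarrow.parquet as pq

# Illustrative pattern: stream a large parquet file batch by batch,
# converting each chunk to a DataFrame before it is upserted.
parquet_file = pq.ParquetFile("signals.parquet")
for batch in parquet_file.iter_batches(batch_size=10_000):
    df = batch.to_pandas()
    # ... normalize columns, then create_or_upsert_table("signals", df)
```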
@@ -201,9 +242,8 @@ def create_signals(self, data_path: Path):
        uda_attributes = ["uda_name", "mds_name", "file_name", "format"]
        df = df.drop(uda_attributes, axis=1)
        df["shot_id"] = df.shot_id.astype(int)
-        df = df.set_index("shot_id", drop=True)
        df["description"] = df.description.map(lambda x: "" if x is None else x)
-        df.to_sql("signals", self.uri, if_exists="append")
+        self.create_or_upsert_table("signals", df)

    def create_sources(self, data_path: Path):
        source_metadata = pd.read_parquet(data_path / "sources.parquet")
@@ -217,7 +257,7 @@ def create_sources(self, data_path: Path):
        )
        column_names = ["uuid", "shot_id", "name", "description", "quality", "url"]
        source_metadata = source_metadata[column_names]
-        source_metadata.to_sql("sources", self.uri, if_exists="append", index=False)
+        self.create_or_upsert_table("sources", source_metadata)


def read_cpf_metadata(cpf_file_name: Path) -> pd.DataFrame:
@@ -243,7 +283,7 @@ def create_db_and_tables(data_path):

    # populate the database tables
    logging.info("Create CPF summary")
-    client.create_cpf_summary(data_path / "cpf")
+    client.create_cpf_summary(data_path)

    logging.info("Create Scenarios")
    client.create_scenarios(data_path)
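Net effect of the change set: every table write now routes through `create_or_upsert_table`, so rebuilding the database becomes safe to re-run. For example (path hypothetical):

```python
from pathlib import Path

# Running the build twice should leave row counts unchanged: the second
# pass updates existing rows in place instead of raising duplicate-key errors.
create_db_and_tables(Path("/data/mast"))
create_db_and_tables(Path("/data/mast"))
```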