Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
sprivite committed Jan 23, 2025
1 parent c5eff2f commit 0912071
Show file tree
Hide file tree
Showing 9 changed files with 313 additions and 280 deletions.
9 changes: 8 additions & 1 deletion phenex/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from .relative_time_range_filter import RelativeTimeRangeFilter
from .value import GreaterThan, GreaterThanOrEqualTo, LessThan, LessThanOrEqualTo, EqualTo, Value
from .value import (
GreaterThan,
GreaterThanOrEqualTo,
LessThan,
LessThanOrEqualTo,
EqualTo,
Value,
)
10 changes: 6 additions & 4 deletions phenex/filters/categorical_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ def __init__(
self.domain = domain
super(CategoricalFilter, self).__init__()

def _filter(self, table: 'PhenexTable'):
def _filter(self, table: "PhenexTable"):
return table.filter(table[self.column_name].isin(self.allowed_values))

def autojoin_filter(self, table: 'PhenexTable', tables:dict = None):
def autojoin_filter(self, table: "PhenexTable", tables: dict = None):
if self.column_name not in table.columns:
if self.domain not in tables.keys():
raise ValueError(f"Table required for categorical filter ({self.domain}) does not exist within domains dicitonary")
table = table.join(tables[self.domain], domains = tables)
raise ValueError(
f"Table required for categorical filter ({self.domain}) does not exist within domains dicitonary"
)
table = table.join(tables[self.domain], domains=tables)
# TODO downselect to original columns
return table.filter(table[self.column_name].isin(self.allowed_values))
135 changes: 74 additions & 61 deletions phenex/ibis_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,29 @@ class SnowflakeConnector:
Methods:
connect_dest() -> BaseBackend:
Establishes and returns an Ibis backend connection to the destination Snowflake database and schema.
connect_source() -> BaseBackend:
Establishes and returns an Ibis backend connection to the source Snowflake database and schema.
get_source_table(name_table: str) -> Table:
Retrieves a table from the source Snowflake database.
get_dest_table(name_table: str) -> Table:
Retrieves a table from the destination Snowflake database.
create_view(table: Table, name_table: Optional[str] = None, overwrite: bool = False) -> View:
Create a view of a table in the destination Snowflake database.
create_table(table: Table, name_table: Optional[str] = None, overwrite: bool = False) -> Table:
Materialize a table in the destination Snowflake database.
drop_table(name_table: str) -> None:
Drop a table from the destination Snowflake database.
drop_view(name_table: str) -> None:
Drop a view from the destination Snowflake database.
"""

def __init__(
self,
SNOWFLAKE_USER: Optional[str] = None,
Expand All @@ -73,22 +74,36 @@ def __init__(
SNOWFLAKE_DEST_DATABASE: Optional[str] = None,
):
self.SNOWFLAKE_USER = SNOWFLAKE_USER or os.environ.get("SNOWFLAKE_USER")
self.SNOWFLAKE_ACCOUNT = SNOWFLAKE_ACCOUNT or os.environ.get("SNOWFLAKE_ACCOUNT")
self.SNOWFLAKE_WAREHOUSE = SNOWFLAKE_WAREHOUSE or os.environ.get("SNOWFLAKE_WAREHOUSE")
self.SNOWFLAKE_ACCOUNT = SNOWFLAKE_ACCOUNT or os.environ.get(
"SNOWFLAKE_ACCOUNT"
)
self.SNOWFLAKE_WAREHOUSE = SNOWFLAKE_WAREHOUSE or os.environ.get(
"SNOWFLAKE_WAREHOUSE"
)
self.SNOWFLAKE_ROLE = SNOWFLAKE_ROLE or os.environ.get("SNOWFLAKE_ROLE")
self.SNOWFLAKE_PASSWORD = SNOWFLAKE_PASSWORD or os.environ.get("SNOWFLAKE_PASSWORD")
self.SNOWFLAKE_SOURCE_DATABASE = SNOWFLAKE_SOURCE_DATABASE or os.environ.get("SNOWFLAKE_SOURCE_DATABASE")
self.SNOWFLAKE_DEST_DATABASE = SNOWFLAKE_DEST_DATABASE or os.environ.get("SNOWFLAKE_DEST_DATABASE")
self.SNOWFLAKE_PASSWORD = SNOWFLAKE_PASSWORD or os.environ.get(
"SNOWFLAKE_PASSWORD"
)
self.SNOWFLAKE_SOURCE_DATABASE = SNOWFLAKE_SOURCE_DATABASE or os.environ.get(
"SNOWFLAKE_SOURCE_DATABASE"
)
self.SNOWFLAKE_DEST_DATABASE = SNOWFLAKE_DEST_DATABASE or os.environ.get(
"SNOWFLAKE_DEST_DATABASE"
)

try:
_, _ = self.SNOWFLAKE_SOURCE_DATABASE.split('.')
_, _ = self.SNOWFLAKE_SOURCE_DATABASE.split(".")
except:
raise ValueError('Use a fully qualified database name (e.g. CATALOG.DATABASE).')
raise ValueError(
"Use a fully qualified database name (e.g. CATALOG.DATABASE)."
)
try:
_, _ = self.SNOWFLAKE_SOURCE_DATABASE.split('.')
_, _ = self.SNOWFLAKE_SOURCE_DATABASE.split(".")
except:
raise ValueError('Use a fully qualified database name (e.g. CATALOG.DATABASE).')

raise ValueError(
"Use a fully qualified database name (e.g. CATALOG.DATABASE)."
)

required_vars = [
"SNOWFLAKE_USER",
"SNOWFLAKE_ACCOUNT",
Expand All @@ -102,22 +117,25 @@ def __init__(
self.source_connection = self.connect_source()
self.dest_connection = self.connect_dest()


def _check_env_vars(self, required_vars: List[str]):
for var in required_vars:
if not getattr(self, var):
raise ValueError(f"Missing required variable: {var}. Set in the environment or pass through __init__().")
raise ValueError(
f"Missing required variable: {var}. Set in the environment or pass through __init__()."
)

def _check_source_dest(self):
if (self.SNOWFLAKE_SOURCE_DATABASE == self.SNOWFLAKE_DEST_DATABASE and
self.SNOWFLAKE_SOURCE_SCHEMA == self.SNOWFLAKE_DEST_SCHEMA):
if (
self.SNOWFLAKE_SOURCE_DATABASE == self.SNOWFLAKE_DEST_DATABASE
and self.SNOWFLAKE_SOURCE_SCHEMA == self.SNOWFLAKE_DEST_SCHEMA
):
raise ValueError("Source and destination locations cannot be the same.")

def _connect(self, database) -> BaseBackend:
'''
"""
Private method to get a database connection. End users should use connect_source() and connect_dest() to get connections to source and destination databases.
'''
database, schema = database.split('.')
"""
database, schema = database.split(".")
#
# In Ibis speak: catalog = collection of databases
# database = collection of tables
Expand All @@ -127,7 +145,7 @@ def _connect(self, database) -> BaseBackend:
#
# In the below connect method, the arguments are the SNOWFLAKE terms.
#

if self.SNOWFLAKE_PASSWORD:
return ibis.snowflake.connect(
user=self.SNOWFLAKE_USER,
Expand All @@ -136,7 +154,7 @@ def _connect(self, database) -> BaseBackend:
warehouse=self.SNOWFLAKE_WAREHOUSE,
role=self.SNOWFLAKE_ROLE,
database=database,
schema=schema
schema=schema,
)
else:
return ibis.snowflake.connect(
Expand All @@ -146,9 +164,9 @@ def _connect(self, database) -> BaseBackend:
warehouse=self.SNOWFLAKE_WAREHOUSE,
role=self.SNOWFLAKE_ROLE,
database=database,
schema=schema
schema=schema,
)

def connect_dest(self) -> BaseBackend:
"""
Establishes and returns an Ibis backend connection to the destination Snowflake database.
Expand All @@ -157,8 +175,8 @@ def connect_dest(self) -> BaseBackend:
BaseBackend: Ibis backend connection to the destination Snowflake database.
"""
return self._connect(
database=self.SNOWFLAKE_DEST_DATABASE,
)
database=self.SNOWFLAKE_DEST_DATABASE,
)

def connect_source(self) -> BaseBackend:
"""
Expand All @@ -168,8 +186,8 @@ def connect_source(self) -> BaseBackend:
BaseBackend: Ibis backend connection to the source Snowflake database.
"""
return self._connect(
database=self.SNOWFLAKE_SOURCE_DATABASE,
)
database=self.SNOWFLAKE_SOURCE_DATABASE,
)

def get_source_table(self, name_table):
"""
Expand All @@ -182,8 +200,7 @@ def get_source_table(self, name_table):
Table: Ibis table object from the source Snowflake database.
"""
return self.dest_connection.table(
name_table,
database=self.SNOWFLAKE_SOURCE_DATABASE
name_table, database=self.SNOWFLAKE_SOURCE_DATABASE
)

def get_dest_table(self, name_table):
Expand All @@ -197,16 +214,16 @@ def get_dest_table(self, name_table):
Table: Ibis table object from the destination Snowflake database.
"""
return self.dest_connection.table(
name_table,
database=self.SNOWFLAKE_DEST_DATABASE
name_table, database=self.SNOWFLAKE_DEST_DATABASE
)

def _get_output_table_name(self, table):
if table.has_name:
name_table = table.get_name().split('.')[-1]
name_table = table.get_name().split(".")[-1]
else:
raise ValueError('Must specify name_table!')
raise ValueError("Must specify name_table!")
return name_table

def create_view(self, table, name_table=None, overwrite=False):
"""
Create a view of a table in the destination Snowflake database.
Expand All @@ -220,17 +237,14 @@ def create_view(self, table, name_table=None, overwrite=False):
View: Ibis view object created in the destination Snowflake database.
"""
name_table = name_table or self._get_output_table_name(table)

# Check if the destination database exists, if not, create it
catalog, database = self.SNOWFLAKE_DEST_DATABASE.split('.')
catalog, database = self.SNOWFLAKE_DEST_DATABASE.split(".")
if not database in self.dest_connection.list_databases():
self.dest_connection.create_database(name=database, catalog=catalog)

return self.dest_connection.create_view(
name=name_table,
database=database,
obj=table,
overwrite=overwrite
name=name_table, database=database, obj=table, overwrite=overwrite
)

def create_table(self, table, name_table=None, overwrite=False):
Expand All @@ -246,17 +260,14 @@ def create_table(self, table, name_table=None, overwrite=False):
Table: Ibis table object created in the destination Snowflake database.
"""
name_table = name_table or self._get_output_table_name(table)

# Check if the destination database exists, if not, create it
catalog, database = self.SNOWFLAKE_DEST_DATABASE.split('.')
catalog, database = self.SNOWFLAKE_DEST_DATABASE.split(".")
if not database in self.dest_connection.list_databases():
self.dest_connection.create_database(name=database, catalog=catalog)

return self.dest_connection.create_table(
name=name_table,
database=database,
obj=table,
overwrite=overwrite
name=name_table, database=database, obj=table, overwrite=overwrite
)

def drop_table(self, name_table):
Expand All @@ -270,10 +281,9 @@ def drop_table(self, name_table):
None
"""
return self.dest_connection.drop_table(
name=name_table,
database=self.SNOWFLAKE_DEST_DATABASE
)

name=name_table, database=self.SNOWFLAKE_DEST_DATABASE
)

def drop_view(self, name_table):
"""
Drop a view from the destination Snowflake database.
Expand All @@ -285,9 +295,9 @@ def drop_view(self, name_table):
None
"""
return self.dest_connection.drop_view(
name=name_table,
database=self.SNOWFLAKE_DEST_DATABASE
)
name=name_table, database=self.SNOWFLAKE_DEST_DATABASE
)


class DuckDBConnector:
"""
Expand All @@ -300,6 +310,7 @@ class DuckDBConnector:
connect() -> BaseBackend:
Establishes and returns an Ibis backend connection to the DuckDB database.
"""

def __init__(self, DUCKDB_PATH: Optional[str] = ":memory"):
"""
Initializes the DuckDBConnector with the specified path.
Expand All @@ -318,4 +329,6 @@ def connect(self) -> BaseBackend:
"""
required_vars = ["DUCKDB_PATH"]
_check_env_vars(*required_vars)
return ibis.connect(backend="duckdb", path=self.DUCKDB_PATH or os.getenv("DUCKDB_PATH"))
return ibis.connect(
backend="duckdb", path=self.DUCKDB_PATH or os.getenv("DUCKDB_PATH")
)
Loading

0 comments on commit 0912071

Please sign in to comment.