diff --git a/tap_snowflake/client.py b/tap_snowflake/client.py index 977c8d6..3c22601 100644 --- a/tap_snowflake/client.py +++ b/tap_snowflake/client.py @@ -11,14 +11,36 @@ from pathlib import Path from typing import Any, Iterable, List, Tuple from uuid import uuid4 +import datetime +import re import sqlalchemy from singer_sdk import SQLConnector, SQLStream, metrics from singer_sdk.helpers._batch import BaseBatchFileEncoding, BatchConfig from singer_sdk.streams.core import REPLICATION_FULL_TABLE, REPLICATION_INCREMENTAL +import singer_sdk.helpers._typing from snowflake.sqlalchemy import URL from sqlalchemy.sql import text +unpatched_conform = singer_sdk.helpers._typing._conform_primitive_property + + +def patched_conform( + elem: Any, + property_schema: dict, +) -> Any: + """Overrides Singer SDK type conformance to prevent dates turning into datetimes. + Converts a primitive (i.e. not object or array) to a json compatible type. + Returns: + The appropriate json compatible type. + """ + if isinstance(elem, datetime.date): + return elem.isoformat() + return unpatched_conform(elem=elem, property_schema=property_schema) + + +singer_sdk.helpers._typing._conform_primitive_property = patched_conform + class ProfileStats(Enum): """Profile Statistics Enum.""" @@ -78,7 +100,7 @@ def create_engine(self) -> sqlalchemy.engine.Engine: self.sqlalchemy_url, echo=False, pool_timeout=10, - ) + ) # overridden to filter out the information_schema from catalog discovery def discover_catalog_entries(self) -> list[dict]: @@ -91,13 +113,22 @@ def discover_catalog_entries(self) -> list[dict]: tables = [t.lower() for t in self.config.get("tables", [])] engine = self.create_sqlalchemy_engine() inspected = sqlalchemy.inspect(engine) - schema_names = [ - schema_name - for schema_name in self.get_schema_names(engine, inspected) - if schema_name.lower() != "information_schema" - ] + + if self.config.get("schema"): + schema_names = [] + schema_names.append(self.config.get("schema")) + + else: + schema_names = [ + schema_name + for schema_name in self.get_schema_names(engine, inspected) + if schema_name.lower() != "information_schema" + ] + for schema_name in schema_names: # Iterate through each table and view + # We shouldn't have to iterate through every table in a schema if tables are provided + # However, the only way to get is_view for tables is self.get_object_names with schema for table_name, is_view in self.get_object_names( engine, inspected, schema_name ): diff --git a/tests/catalog.json b/tests/catalog.json index c3f4c5f..62e5e69 100644 --- a/tests/catalog.json +++ b/tests/catalog.json @@ -378,7 +378,7 @@ "type": ["number"] }, "o_orderdate": { - "format": "date-time", + "format": "date", "type": ["string"] }, "o_orderpriority": {