Skip to content

Commit

Permalink
Added changes for dbt-oracle 1.8
Browse files Browse the repository at this point in the history
  • Loading branch information
aosingh committed Jun 5, 2024
1 parent 7cd6f7c commit 1a7706e
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Configuration variables
VERSION=1.7.5
VERSION=1.8.0
PROJ_DIR?=$(shell pwd)
VENV_DIR?=${PROJ_DIR}/.bldenv
BUILD_DIR=${PROJ_DIR}/build
Expand Down
10 changes: 5 additions & 5 deletions dbt/adapters/oracle/connection_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import enum
import os

import dbt.exceptions
from dbt.events import AdapterLogger
import dbt_common.exceptions
from dbt.adapters.events.logging import AdapterLogger

from dbt.ui import warning_tag, yellow, red
from dbt_common.ui import warning_tag, yellow, red

logger = AdapterLogger("oracle")

Expand Down Expand Up @@ -129,5 +129,5 @@ class OracleDriverType(str, enum.Enum):
SQLNET_ORA_CONFIG = OracleNetConfig.from_env()
logger.info("Running in thin mode")
else:
raise dbt.exceptions.DbtRuntimeError("Invalid value set for ORA_PYTHON_DRIVER_TYPE\n"
"Use any one of 'cx', 'thin', or 'thick'")
raise dbt_common.exceptions.DbtRuntimeError("Invalid value set for ORA_PYTHON_DRIVER_TYPE\n"
"Use any one of 'cx', 'thin', or 'thick'")
25 changes: 13 additions & 12 deletions dbt/adapters/oracle/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
import uuid
import platform

import dbt.exceptions
from dbt.adapters.base import Credentials
import dbt_common.exceptions
from dbt.adapters.contracts.connection import AdapterResponse, Credentials
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse
from dbt.events.functions import fire_event
from dbt.events.types import ConnectionUsed, SQLQuery, SQLCommit, SQLQueryStatus
from dbt.events import AdapterLogger
from dbt.events.contextvars import get_node_info
from dbt.utils import cast_to_str
from dbt.adapters.events.types import ConnectionUsed, SQLQuery, SQLCommit, SQLQueryStatus
from dbt.adapters.events.logging import AdapterLogger

from dbt_common.events.functions import fire_event
from dbt_common.events.contextvars import get_node_info
from dbt_common.utils import cast_to_str

from dbt.version import __version__ as dbt_version
from dbt.adapters.oracle.connection_helper import oracledb, SQLNET_ORA_CONFIG
Expand Down Expand Up @@ -256,7 +257,7 @@ def open(cls, connection):
connection.handle = None
connection.state = 'fail'

raise dbt.exceptions.FailedToConnectError(str(e))
raise FailedToConnectError(str(e))

return connection

Expand Down Expand Up @@ -302,18 +303,18 @@ def exception_handler(self, sql):
logger.info("Failed to release connection!")
pass

raise dbt.exceptions.DbtDatabaseError(str(e).strip()) from e
raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e

except Exception as e:
logger.info("Rolling back transaction.")
self.release()
if isinstance(e, dbt.exceptions.DbtRuntimeError):
if isinstance(e, dbt_common.exceptions.DbtRuntimeError):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
# useful information, so raise it without modification.
raise e

raise dbt.exceptions.DbtRuntimeError(str(e)) from e
raise dbt_common.exceptions.DbtRuntimeError(str(e)) from e

@classmethod
def get_credentials(cls, credentials):
Expand Down
86 changes: 37 additions & 49 deletions dbt/adapters/oracle/impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (c) 2023, Oracle and/or its affiliates.
Copyright (c) 2024, Oracle and/or its affiliates.
Copyright (c) 2020, Vitor Avancini
Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,7 +16,7 @@
"""
import datetime
from typing import (
Optional, List, Set
Optional, List, Set, FrozenSet, Tuple, Iterable
)
from itertools import chain
from typing import (
Expand All @@ -27,24 +27,23 @@
import agate
import requests

import dbt.exceptions
import dbt_common.exceptions
from dbt_common.contracts.constraints import ConstraintType
from dbt_common.utils import filter_null_values

from dbt.adapters.base.relation import BaseRelation, InformationSchema
from dbt.adapters.base.impl import GET_CATALOG_MACRO_NAME, ConstraintSupport, GET_CATALOG_RELATIONS_MACRO_NAME, _expect_row_value
from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.base.meta import available
from dbt.adapters.capability import CapabilityDict, CapabilitySupport, Support, Capability

from dbt.adapters.oracle import OracleAdapterConnectionManager
from dbt.adapters.oracle.column import OracleColumn
from dbt.adapters.oracle.relation import OracleRelation
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ConstraintType
from dbt.events import AdapterLogger

from dbt.utils import filter_null_values

from dbt.adapters.oracle.keyword_catalog import KEYWORDS
from dbt.adapters.oracle.python_submissions import OracleADBSPythonJob
from dbt.adapters.oracle.connections import AdapterResponse

logger = AdapterLogger("oracle")

Expand Down Expand Up @@ -139,7 +138,7 @@ def verify_database(self, database):
database = database.strip('"')
expected = self.config.credentials.database
if expected and database.lower() != expected.lower():
raise dbt.exceptions.DbtRuntimeError(
raise dbt_common.exceptions.DbtRuntimeError(
'Cross-db references not allowed in {} ({} vs {})'
.format(self.type(), database, expected)
)
Expand Down Expand Up @@ -210,67 +209,56 @@ def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Set[str],
manifest: Manifest,
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:

logger.info(f"GET ONE CATALOG =====> {schemas}")
logger.info(f"GET ONE CATALOG =====> {information_schema}")
logger.info(f"GET ONE CATALOG =====> {used_schemas}")
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
# pass in the full manifest so we get any local project
# overrides
manifest=manifest,
)
# In case database is not defined, we can use the the configured database which we set as part of credentials
for node in chain(manifest.nodes.values(), manifest.sources.values()):
if not node.database or node.database == 'None':
node.database = self.config.credentials.database

results = self._catalog_filter_table(table, manifest)
table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs)
results = self._catalog_filter_table(table, used_schemas=used_schemas)
logger.info(f"GET ONE CATALOG =====> {results}")
return results

def _get_one_catalog_by_relations(
self,
information_schema: InformationSchema,
relations: List[BaseRelation],
manifest: Manifest,
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:

logger.info(f"GET ONE _get_one_catalog_by_relations =====> {relations}")
logger.info(f"GET ONE _get_one_catalog_by_relations =====> {information_schema}")
logger.info(f"GET ONE _get_one_catalog_by_relations =====> {used_schemas}")
kwargs = {
"information_schema": information_schema,
"relations": relations,
}
table = self.execute_macro(
GET_CATALOG_RELATIONS_MACRO_NAME,
kwargs=kwargs,
# pass in the full manifest, so we get any local project
# overrides
manifest=manifest,
)

# In case database is not defined, we can use the the configured database which we set as part of credentials
for node in chain(manifest.nodes.values(), manifest.sources.values()):
if not node.database or node.database == 'None':
node.database = self.config.credentials.database

results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type]
table = self.execute_macro(GET_CATALOG_RELATIONS_MACRO_NAME, kwargs=kwargs)
results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
logger.info(f"GET ONE _get_one_catalog_by_relations =====> {results}")
return results

def get_filtered_catalog(
self, manifest: Manifest, relations: Optional[Set[BaseRelation]] = None
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
relations: Optional[Set[BaseRelation]] = None
):
logger.info(f"GET ONE get_filtered_catalog =====> {relations}")
logger.info(f"GET ONE get_filtered_catalog =====> {relations}")
logger.info(f"GET ONE get_filtered_catalog =====> {used_schemas}")
catalogs: agate.Table
if (
relations is None
or len(relations) > 100
or not self.supports(Capability.SchemaMetadataByRelations)
):
# Do it the traditional way. We get the full catalog.
catalogs, exceptions = self.get_catalog(manifest)
catalogs, exceptions = self.get_catalog(relation_configs, used_schemas)
else:
# Do it the new way. We try to save time by selecting information
# only for the exact set of relations we are interested in.
catalogs, exceptions = self.get_catalog_by_relations(manifest, relations)
catalogs, exceptions = self.get_catalog_by_relations(used_schemas, relations)

if relations and catalogs:
relation_map = {
Expand Down Expand Up @@ -388,8 +376,8 @@ def quote_seed_column(
elif quote_config is None:
pass
else:
raise dbt.exceptions.CompilationError(f'The seed configuration value of "quote_columns" '
f'has an invalid type {type(quote_config)}')
raise dbt_common.exceptions.CompilationError(f'The seed configuration value of "quote_columns" '
f'has an invalid type {type(quote_config)}')

if quote_columns:
return self.quote(column)
Expand Down Expand Up @@ -417,7 +405,7 @@ def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]])

def get_oml_auth_token(self) -> str:
if self.config.credentials.oml_auth_token_uri is None:
raise dbt.exceptions.DbtRuntimeError("oml_auth_token_uri should be set to run dbt-py models")
raise dbt_common.exceptions.DbtRuntimeError("oml_auth_token_uri should be set to run dbt-py models")
data = {
"grant_type": "password",
"username": self.config.credentials.user,
Expand All @@ -428,7 +416,7 @@ def get_oml_auth_token(self) -> str:
json=data)
r.raise_for_status()
except requests.exceptions.RequestException:
raise dbt.exceptions.DbtRuntimeError("Error getting OML OAuth2.0 token")
raise dbt_common.exceptions.DbtRuntimeError("Error getting OML OAuth2.0 token")
else:
return r.json()["accessToken"]

Expand Down
22 changes: 11 additions & 11 deletions dbt/adapters/oracle/python_submissions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (c) 2023, Oracle and/or its affiliates.
Copyright (c) 2024, Oracle and/or its affiliates.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -23,10 +23,10 @@
import requests
import time

import dbt.exceptions
import dbt_common.exceptions
from dbt.adapters.oracle import OracleAdapterCredentials
from dbt.events import AdapterLogger
from dbt.ui import red, green
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.ui import red, green
from dbt.version import __version__ as dbt_version

# ADB-S OML Rest API minimum timeout is 1800 seconds
Expand Down Expand Up @@ -170,7 +170,7 @@ def schedule_async_job_and_wait_for_completion(self, data):
r.raise_for_status()
except requests.exceptions.RequestException as e:
logger.error(red(f"Error {e} scheduling async Python job for model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Error scheduling Python model {self.identifier}")
raise dbt_common.exceptions.DbtRuntimeError(f"Error scheduling Python model {self.identifier}")

job_location = r.headers["location"]
logger.info(f"Started async job {job_location}")
Expand All @@ -192,24 +192,24 @@ def schedule_async_job_and_wait_for_completion(self, data):
job_result_json = job_result.json()
if 'errorMessage' in job_result_json:
logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result_json}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
raise dbt_common.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
job_result.raise_for_status()
logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result_json}"))
return
elif job_status_code == http.HTTPStatus.INTERNAL_SERVER_ERROR:
logger.error(red(f"FAILURE - Job status is: {job_status.json()}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
raise dbt_common.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
else:
logger.debug(f"Python model {self.identifier} job status is: {job_status.json()}")
job_status.raise_for_status()

except requests.exceptions.RequestException as e:
logger.error(red(f"Error {e} checking status of Python job {job_location} for model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Error checking status for job {job_location}")
raise dbt_common.exceptions.DbtRuntimeError(f"Error checking status for job {job_location}")

time.sleep(DEFAULT_DELAY_BETWEEN_POLL_IN_SECONDS)
logger.error(red(f"Timeout error for Python model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Timeout error for Python model {self.identifier}")
raise dbt_common.exceptions.DbtRuntimeError(f"Timeout error for Python model {self.identifier}")

def __call__(self, *args, **kwargs):
data = {
Expand All @@ -234,10 +234,10 @@ def __call__(self, *args, **kwargs):
job_result = r.json()
if 'errorMessage' in job_result:
logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
raise dbt_common.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
r.raise_for_status()
logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result}"))
except requests.exceptions.RequestException as e:
logger.error(red(f"Error {e} running Python model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
raise dbt_common.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")

5 changes: 3 additions & 2 deletions dbt/adapters/oracle/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
from dataclasses import dataclass, field

from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.relation_configs import (
RelationConfigBase,
RelationConfigChangeAction,
RelationResults,
)
from dbt.context.providers import RuntimeConfigObject
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import RelationType
from dbt.adapters.base import RelationType
from dbt.exceptions import DbtRuntimeError

from dbt.adapters.oracle.relation_configs import (
Expand All @@ -40,7 +41,7 @@
OracleQuotePolicy,
OracleIncludePolicy)

from dbt.events import AdapterLogger


logger = AdapterLogger("oracle")

Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/oracle/relation_configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
RelationResults,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName
from dbt.adapters.contracts.relation import ComponentName

from dbt.adapters.oracle.relation_configs.policies import (
OracleQuotePolicy,
Expand Down
Loading

0 comments on commit 1a7706e

Please sign in to comment.