diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 50c661bb..56a66ae0 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -402,6 +402,7 @@ def test_json_column(trino_connection, json_object): ins = table_with_json.insert() conn.execute(ins, {"id": 1, "json_column": json_object}) query = sqla.select(table_with_json) + assert isinstance(table_with_json.c.json_column.type, JSON) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 1 @@ -410,6 +411,71 @@ def test_json_column(trino_connection, json_object): metadata.drop_all(engine) +@pytest.mark.skipif( + sqlalchemy_version() < "1.4", + reason="columns argument to select() must be a Python list or other iterable" +) +@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +def test_json_column_operations(trino_connection): + engine, conn = trino_connection + + metadata = sqla.MetaData() + + json_object = { + "a": {"c": 1}, + 100: {"z": 200}, + "b": 2, + 10: 20, + "foo-bar": {"z": 200} + } + + try: + table_with_json = sqla.Table( + 'table_with_json', + metadata, + sqla.Column('json_column', JSON), + schema="default" + ) + metadata.create_all(engine) + ins = table_with_json.insert() + conn.execute(ins, {"json_column": json_object}) + + # JSONPathType + query = sqla.select(table_with_json.c.json_column["a", "c"]) + conn.execute(query) + result = conn.execute(query) + assert result.fetchall()[0][0] == 1 + + query = sqla.select(table_with_json.c.json_column[100, "z"]) + conn.execute(query) + result = conn.execute(query) + assert result.fetchall()[0][0] == 200 + + query = sqla.select(table_with_json.c.json_column["foo-bar", "z"]) + conn.execute(query) + result = conn.execute(query) + assert result.fetchall()[0][0] == 200 + + # JSONIndexType + query = sqla.select(table_with_json.c.json_column["b"]) + conn.execute(query) + result = conn.execute(query) + assert result.fetchall()[0][0] == 2 + + query = sqla.select(table_with_json.c.json_column[10]) + conn.execute(query) + result = conn.execute(query) + assert result.fetchall()[0][0] == 20 + + query = sqla.select(table_with_json.c.json_column["foo-bar"]) + conn.execute(query) + result = conn.execute(query) + assert result.fetchall()[0][0] == {'z': 200} + + finally: + metadata.drop_all(engine) + + @pytest.mark.parametrize('trino_connection', ['system'], indirect=True) def test_get_catalog_names(trino_connection): engine, conn = trino_connection diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 8747d190..fef6beb1 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, sqltypes from sqlalchemy.sql.base import DialectKWArgs # https://trino.io/docs/current/language/reserved.html @@ -125,6 +125,19 @@ def add_catalog(sql, table): sql = f'"{catalog}".{sql}' return sql + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def _render_json_extract_from_binary(self, binary, operator, **kw): + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_EXTRACT(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + class TrinoDDLCompiler(compiler.DDLCompiler): pass diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 9c1eed28..5bcacc3b 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -9,15 +9,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import re from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union import sqlalchemy -from sqlalchemy import util +from sqlalchemy import func, util from sqlalchemy.sql import sqltypes from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine -from sqlalchemy.types import String +from sqlalchemy.types import JSON SQLType = Union[TypeEngine, Type[TypeEngine]] @@ -75,16 +74,59 @@ def __init__(self, precision=None, timezone=False): class JSON(TypeDecorator): - impl = String + impl = JSON - def process_bind_param(self, value, dialect): - return json.dumps(value) + def bind_expression(self, bindvalue): + return func.JSON_PARSE(bindvalue) - def process_result_value(self, value, dialect): - return json.loads(value) - def get_col_spec(self, **kw): - return 'JSON' +class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class _JSONFormatter: + @staticmethod + def format_index(value): + return "$[\"%s\"]" % value + + @staticmethod + def format_path(value): + return "$%s" % ( + "".join(["[\"%s\"]" % elem for elem in value]) + ) + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + return _JSONFormatter.format_index(value) + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return _JSONFormatter.format_path(value) # https://trino.io/docs/current/language/types.html diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 9a401d4a..52da4ac3 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -19,6 +19,7 @@ from sqlalchemy.engine.base import Connection from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext from sqlalchemy.engine.url import URL +from sqlalchemy.sql import sqltypes from trino import dbapi as trino_dbapi from trino import logging @@ -31,10 +32,25 @@ from trino.dbapi import Cursor from trino.sqlalchemy import compiler, datatype, error +from .datatype import JSONIndexType, JSONPathType + logger = logging.get_logger(__name__) +colspecs = { + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, +} + class TrinoDialect(DefaultDialect): + def __init__(self, + json_serializer=None, + json_deserializer=None, + **kwargs): + DefaultDialect.__init__(self, **kwargs) + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + name = "trino" driver = "rest" @@ -70,6 +86,7 @@ class TrinoDialect(DefaultDialect): # Support proper ordering of CTEs in regard to an INSERT statement cte_follows_insert = True + colspecs = colspecs @classmethod def dbapi(cls):