diff --git a/tiled/_tests/test_queries.py b/tiled/_tests/test_queries.py
index 6914fd34c..7af19eebb 100644
--- a/tiled/_tests/test_queries.py
+++ b/tiled/_tests/test_queries.py
@@ -45,6 +45,9 @@
 mapping["does_not_contain_z"] = ArrayAdapter.from_array(
     numpy.ones(10), metadata={"letters": list(string.ascii_lowercase[:-1])}
 )
+mapping["full_text_test_case"] = ArrayAdapter.from_array(
+    numpy.ones(10), metadata={"color": "purple"}
+)
 mapping["specs_foo_bar"] = ArrayAdapter.from_array(numpy.ones(10), specs=["foo", "bar"])
 mapping["specs_foo_bar_baz"] = ArrayAdapter.from_array(
@@ -159,7 +162,7 @@ def test_contains(client):
 
 
 def test_full_text(client):
-    if client.metadata["backend"] in {"postgresql", "sqlite"}:
+    if client.metadata["backend"] in {"sqlite"}:
 
         def cm():
             return fail_with_status_code(400)
@@ -168,6 +171,9 @@ def cm():
         cm = nullcontext
     with cm():
         assert list(client.search(FullText("z"))) == ["z", "does_contain_z"]
+        # plainto_tsquery oddly fails to find certain words, so this usefully
+        # verifies that we are using to_tsquery rather than plainto_tsquery.
+        assert list(client.search(FullText("purple"))) == ["full_text_test_case"]
 
 
 def test_regex(client):
diff --git a/tiled/adapters/mapping.py b/tiled/adapters/mapping.py
index 037a20c44..6fe2ab7fa 100644
--- a/tiled/adapters/mapping.py
+++ b/tiled/adapters/mapping.py
@@ -75,10 +75,10 @@ def __init__(
         specs : List[str], optional
         access_policy : AccessPolicy, optional
         entries_stale_after: timedelta
-            This server uses this to communite to the client how long
+            This server uses this to communicate to the client how long
             it should rely on a local cache before checking back for changes.
         metadata_stale_after: timedelta
-            This server uses this to communite to the client how long
+            This server uses this to communicate to the client how long
             it should rely on a local cache before checking back for changes.
         must_revalidate : bool
             Whether the client should strictly refresh stale cache items.
@@ -336,20 +336,12 @@ def iter_child_metadata(query_key, tree):
 def full_text_search(query, tree):
     matches = {}
     text = query.text
-    if query.case_sensitive:
-
-        def maybe_lower(s):
-            # no-op
-            return s
-
-    else:
-        maybe_lower = str.lower
     query_words = set(text.split())
     for key, value in tree.items():
         words = set(
             word
             for s in walk_string_values(value.metadata())
-            for word in maybe_lower(s).split()
+            for word in s.lower().split()
         )
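+        # With the case_sensitive option removed (see tiled/queries.py below),
+        # metadata strings are always lowercased before matching.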
         # Note that `not set.isdisjoint` is faster than `set.intersection`. At
         # the C level, `isdisjoint` loops over the set until it finds one match,
diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py
index 60421d18a..30fa83cab 100644
--- a/tiled/catalog/adapter.py
+++ b/tiled/catalog/adapter.py
@@ -13,13 +13,16 @@
 import anyio
 from fastapi import HTTPException
 from sqlalchemy import delete, event, func, not_, or_, select, text, type_coerce, update
+from sqlalchemy.dialects.postgresql import JSONB, REGCONFIG
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.sql.expression import cast
 
 from tiled.queries import (
     Comparison,
     Contains,
     Eq,
+    FullText,
     In,
     KeysFilter,
     NotEq,
@@ -997,6 +1000,20 @@ def contains(query, tree):
     return tree.new_variation(conditions=tree.conditions + [condition])
 
 
+def full_text(query, tree):
+    dialect_name = tree.engine.url.get_dialect().name
+    if dialect_name == "sqlite":
+        raise UnsupportedQueryType("full_text")
+    elif dialect_name == "postgresql":
+        tsvector = func.jsonb_to_tsvector(
+            cast("simple", REGCONFIG), orm.Node.metadata_, cast(["string"], JSONB)
+        )
+        condition = tsvector.op("@@")(func.to_tsquery("simple", query.text))
+    else:
+        raise UnsupportedQueryType("full_text")
+    return tree.new_variation(conditions=tree.conditions + [condition])
+
+
 def specs(query, tree):
     dialect_name = tree.engine.url.get_dialect().name
     conditions = []
@@ -1068,7 +1085,8 @@ def structure_family(query, tree):
 CatalogNodeAdapter.register_query(KeysFilter, keys_filter)
 CatalogNodeAdapter.register_query(StructureFamilyQuery, structure_family)
 CatalogNodeAdapter.register_query(SpecsQuery, specs)
-# TODO: FullText, Regex
+CatalogNodeAdapter.register_query(FullText, full_text)
+# TODO: Regex
 
 
 def in_memory(
diff --git a/tiled/catalog/core.py b/tiled/catalog/core.py
index 992c8a1da..e4d7fcb1b 100644
--- a/tiled/catalog/core.py
+++ b/tiled/catalog/core.py
@@ -5,6 +5,7 @@
 
 # This is list of all valid revisions (from current to oldest).
 ALL_REVISIONS = [
+    "1cd99c02d0c7",
     "a66028395cab",
     "3db11ff95b6c",
     "0b033e7fbe30",
diff --git a/tiled/catalog/migrations/versions/1cd99c02d0c7_create_index_for_fulltext_search.py b/tiled/catalog/migrations/versions/1cd99c02d0c7_create_index_for_fulltext_search.py
new file mode 100644
index 000000000..dfc8a16fe
--- /dev/null
+++ b/tiled/catalog/migrations/versions/1cd99c02d0c7_create_index_for_fulltext_search.py
@@ -0,0 +1,39 @@
+"""Create index for fulltext search
+
+Revision ID: 1cd99c02d0c7
+Revises: a66028395cab
+Create Date: 2024-01-24 15:53:12.348880
+
+"""
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects.postgresql import JSONB
+
+# revision identifiers, used by Alembic.
+revision = "1cd99c02d0c7"
+down_revision = "a66028395cab"
+branch_labels = None
+depends_on = None
+
+# Make JSONB available in column
+JSONVariant = sa.JSON().with_variant(JSONB(), "postgresql")
+
+
+def upgrade():
+    connection = op.get_bind()
+    if connection.engine.dialect.name == "postgresql":
+        with op.get_context().autocommit_block():
+            # There is no sane way to perform this using op.create_index()
+            op.execute(
+                """
+                CREATE INDEX metadata_tsvector_search
+                ON nodes
+                USING gin (jsonb_to_tsvector('simple', metadata, '["string"]'))
+                """
+            )
+
+
+def downgrade():
+    # This _could_ be implemented but we will wait for a need since we are
+    # still in alpha releases.
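+    # (If it is ever needed, it amounts to DROP INDEX metadata_tsvector_search.)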
+    raise NotImplementedError
diff --git a/tiled/catalog/migrations/versions/a66028395cab_enrich_datasource_asset_association.py b/tiled/catalog/migrations/versions/a66028395cab_enrich_datasource_asset_association.py
index f2fbae9cf..82a7fd902 100644
--- a/tiled/catalog/migrations/versions/a66028395cab_enrich_datasource_asset_association.py
+++ b/tiled/catalog/migrations/versions/a66028395cab_enrich_datasource_asset_association.py
@@ -8,11 +8,7 @@
 import sqlalchemy as sa
 from alembic import op
 
-from tiled.catalog.orm import (
-    DataSourceAssetAssociation,
-    JSONVariant,
-    unique_parameter_num_null_check,
-)
+from tiled.catalog.orm import JSONVariant
 
 # revision identifiers, used by Alembic.
 revision = "a66028395cab"
@@ -32,7 +28,7 @@ def upgrade():
         sa.Column("structure", JSONVariant),
     )
     data_source_asset_association = sa.Table(
-        DataSourceAssetAssociation.__tablename__,
+        "data_source_asset_association",
         sa.MetaData(),
         sa.Column("asset_id", sa.Integer),
         sa.Column("data_source_id", sa.Integer),
@@ -67,11 +63,11 @@ def upgrade():
 
     # Add columns 'parameter' and 'num' to association table.
     op.add_column(
-        DataSourceAssetAssociation.__tablename__,
+        "data_source_asset_association",
         sa.Column("parameter", sa.Unicode(255), nullable=True),
     )
     op.add_column(
-        DataSourceAssetAssociation.__tablename__,
+        "data_source_asset_association",
         sa.Column("num", sa.Integer, nullable=True),
     )
 
@@ -162,7 +158,7 @@ def upgrade():
     if connection.engine.dialect.name == "sqlite":
         # SQLite does not supported adding constraints to an existing table.
         # We invoke its 'copy and move' functionality.
-        with op.batch_alter_table(DataSourceAssetAssociation.__tablename__) as batch_op:
+        with op.batch_alter_table("data_source_asset_association") as batch_op:
             # Gotcha: This does not take table_name because it is bound into batch_op.
             batch_op.create_unique_constraint(
                 "parameter_num_unique_constraint",
@@ -172,11 +168,15 @@ def upgrade():
                     "num",
                 ],
             )
+        # This creates a pair of triggers on the data_source_asset_association
+        # table. The pair includes one trigger that runs when NEW.num IS NULL
+        # and one trigger that runs when NEW.num IS NOT NULL. Thus, for a
+        # given insert, only one of these triggers is run.
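+        # Note that the SQLite trigger renamed below now shares its name with
+        # the PostgreSQL trigger created later in this migration.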
         with op.get_context().autocommit_block():
             connection.execute(
                 sa.text(
                     """
-                    CREATE TRIGGER cannot_insert_num_null_if_num_int_exists
+                    CREATE TRIGGER cannot_insert_num_null_if_num_exists
                     BEFORE INSERT ON data_source_asset_association
                     WHEN NEW.num IS NULL
                     BEGIN
@@ -214,14 +214,72 @@ def upgrade():
         # PostgreSQL
         op.create_unique_constraint(
             "parameter_num_unique_constraint",
-            DataSourceAssetAssociation.__tablename__,
+            "data_source_asset_association",
             [
                 "data_source_id",
                 "parameter",
                 "num",
             ],
         )
-        unique_parameter_num_null_check(data_source_asset_association, connection)
+        connection.execute(
+            sa.text(
+                """
+CREATE OR REPLACE FUNCTION raise_if_parameter_exists()
+RETURNS TRIGGER AS $$
+BEGIN
+    IF EXISTS (
+        SELECT 1
+        FROM data_source_asset_association
+        WHERE parameter = NEW.parameter
+        AND data_source_id = NEW.data_source_id
+    ) THEN
+        RAISE EXCEPTION 'Can only insert num=NULL if no other row exists for the same parameter';
+    END IF;
+    RETURN NEW;
+END;
+$$ LANGUAGE plpgsql;"""
+            )
+        )
+        connection.execute(
+            sa.text(
+                """
+CREATE TRIGGER cannot_insert_num_null_if_num_exists
+BEFORE INSERT ON data_source_asset_association
+FOR EACH ROW
+WHEN (NEW.num IS NULL)
+EXECUTE FUNCTION raise_if_parameter_exists();"""
+            )
+        )
+        connection.execute(
+            sa.text(
+                """
+CREATE OR REPLACE FUNCTION raise_if_null_parameter_exists()
+RETURNS TRIGGER AS $$
+BEGIN
+    IF EXISTS (
+        SELECT 1
+        FROM data_source_asset_association
+        WHERE parameter = NEW.parameter
+        AND data_source_id = NEW.data_source_id
+        AND num IS NULL
+    ) THEN
+        RAISE EXCEPTION 'Can only insert INTEGER num if no NULL row exists for the same parameter';
+    END IF;
+    RETURN NEW;
+END;
+$$ LANGUAGE plpgsql;"""
+            )
+        )
+        connection.execute(
+            sa.text(
+                """
+CREATE TRIGGER cannot_insert_num_int_if_num_null_exists
+BEFORE INSERT ON data_source_asset_association
+FOR EACH ROW
+WHEN (NEW.num IS NOT NULL)
+EXECUTE FUNCTION raise_if_null_parameter_exists();"""
+            )
+        )
 
 
 def downgrade():
diff --git a/tiled/catalog/orm.py b/tiled/catalog/orm.py
index bda9d0309..4beacd0b2 100644
--- a/tiled/catalog/orm.py
+++ b/tiled/catalog/orm.py
@@ -94,7 +94,7 @@ class Node(Timestamped, Base):
             "id",
             "metadata",
             postgresql_using="gin",
-        ),
+        )
         # This is used by ORDER BY with the default sorting.
         # Index("ancestors_time_created", "ancestors", "time_created"),
     )
@@ -149,9 +149,12 @@ class DataSourceAssetAssociation(Base):
 
 @event.listens_for(DataSourceAssetAssociation.__table__, "after_create")
 def unique_parameter_num_null_check(target, connection, **kw):
-    # Ensure that we cannot mix NULL and INTEGER values of num for
-    # a given data_source_id and parameter, and that there cannot be multiple
-    # instances of NULL.
+    # This creates a pair of triggers on the data_source_asset_association
+    # table. (There are a total of four defined below, two for the SQLite
+    # branch and two for the PostgreSQL branch.) Each pair includes one
+    # trigger that runs when NEW.num IS NULL and one trigger that runs when
+    # NEW.num IS NOT NULL. Thus, for a given insert, only one of these
+    # triggers is run.
     if connection.engine.dialect.name == "sqlite":
         connection.execute(
             text(
@@ -252,6 +255,22 @@ def unique_parameter_num_null_check(target, connection, **kw):
     )
 
 
+@event.listens_for(DataSourceAssetAssociation.__table__, "after_create")
+def create_index_metadata_tsvector_search(target, connection, **kw):
+    # This creates a tsvector-based metadata index for full-text search.
+    # It is a PostgreSQL-only feature.
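+    # (Migration 1cd99c02d0c7 adds the same index to existing databases.)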
+    if connection.engine.dialect.name == "postgresql":
+        connection.execute(
+            text(
+                """
+                CREATE INDEX metadata_tsvector_search
+                ON nodes
+                USING gin (jsonb_to_tsvector('simple', metadata, '["string"]'))
+                """
+            )
+        )
+
+
 class DataSource(Timestamped, Base):
     """
     The describes how to open one or more file/blobs to extract data for a Node.
diff --git a/tiled/queries.py b/tiled/queries.py
index af5851a5b..10d871e4d 100644
--- a/tiled/queries.py
+++ b/tiled/queries.py
@@ -42,23 +42,16 @@ class FullText(NoBool):
     Parameters
     ----------
     text : str
-    case_sensitive : bool, optional
-        Default False (case-insensitive).
     """
 
     text: str
-    case_sensitive: bool = False
 
     def encode(self):
-        return {"text": self.text, "case_sensitive": json.dumps(self.case_sensitive)}
+        return {"text": self.text}
 
     @classmethod
-    def decode(cls, *, text, case_sensitive=False):
-        # Note: FastAPI decodes case_sensitive into a boolean for us.
-        return cls(
-            text=text,
-            case_sensitive=case_sensitive,
-        )
+    def decode(cls, *, text):
+        return cls(text=text)
 
 
 @register(name="lookup")
diff --git a/web-frontend/src/openapi_schemas.ts b/web-frontend/src/openapi_schemas.ts
index 238e19ce2..ec494cdad 100644
--- a/web-frontend/src/openapi_schemas.ts
+++ b/web-frontend/src/openapi_schemas.ts
@@ -409,7 +409,6 @@ export interface operations {
         sort?: string;
         omit_links?: boolean;
         "filter[fulltext][condition][text]"?: string[];
-        "filter[fulltext][condition][case_sensitive]"?: boolean[];
         "filter[lookup][condition][key]"?: string[];
       };
     };