Skip to content

Commit

Permalink
Add flags if present in query
Browse files Browse the repository at this point in the history
Signed-off-by: Olga Bulat <obulat@gmail.com>
  • Loading branch information
obulat committed Nov 27, 2023
1 parent e8cb711 commit c6f9a28
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
26 changes: 22 additions & 4 deletions api/api/controllers/search_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import logging as log
import re
from math import ceil
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -46,7 +47,6 @@

module_logger = logging.getLogger(__name__)


NESTING_THRESHOLD = config("POST_PROCESS_NESTING_THRESHOLD", cast=int, default=5)
SOURCE_CACHE_TIMEOUT = 60 * 60 * 4 # 4 hours
FILTER_CACHE_TIMEOUT = 30
Expand Down Expand Up @@ -284,9 +284,11 @@ def build_search_query(
# individual field-level queries specified.
if "q" in search_params.data:
query = _quote_escape(search_params.data["q"])
sqs_flags = extract_flags_from_query(query, query_name="q")

base_query_kwargs = {
"query": query,
"flags": DEFAULT_SQS_FLAGS,
"flags": sqs_flags,
"fields": DEFAULT_SEARCH_FIELDS,
"default_operator": "AND",
}
Expand All @@ -299,7 +301,7 @@ def build_search_query(
quotes_stripped = query.replace('"', "")
exact_match_boost = Q(
"simple_query_string",
flags=DEFAULT_SQS_FLAGS,
flags=sqs_flags,
fields=["title"],
query=f"{quotes_stripped}",
boost=10000,
Expand All @@ -312,10 +314,11 @@ def build_search_query(
("tags", "tags.name"),
]:
if field_value := search_params.data.get(field):
sqs_flags = extract_flags_from_query(field_value, query_name="field")
search_queries["must"].append(
Q(
"simple_query_string",
flags=DEFAULT_SQS_FLAGS,
flags=sqs_flags,
query=_quote_escape(field_value),
fields=[field_name],
)
Expand All @@ -339,6 +342,21 @@ def build_search_query(
)


def extract_flags_from_query(query: str, query_name) -> str:
sqs_flags = DEFAULT_SQS_FLAGS
flags = [
("PRECEDENCE", r"\(.*\)"),
("ESCAPE", r"\\"),
("FUZZY|SLOP", r"~\d"),
("PREFIX", r"\*"),
]
for flag, pattern in flags:
if bool(re.search(pattern, query)):
log.info(f"Special feature in `{query_name}` query string. {flag}: {query}")
sqs_flags += f"|{flag}"
return sqs_flags


def build_collection_query(
search_params: MediaListRequestSerializer,
collection_params: dict[str, str],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from elasticsearch_dsl import Q

from api.controllers import search_controller
from api.controllers.search_controller import DEFAULT_SQS_FLAGS, FILTERED_PROVIDERS_CACHE_KEY
from api.controllers.search_controller import (
DEFAULT_SQS_FLAGS,
FILTERED_PROVIDERS_CACHE_KEY,
)


pytestmark = pytest.mark.django_db
Expand Down

0 comments on commit c6f9a28

Please sign in to comment.