Skip to content
This repository was archived by the owner on Sep 12, 2022. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions executor/hnswlib_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from jina import DocumentArray, Document
from jina.logging.logger import JinaLogger

import warnings

GENERATOR_DELTA = Generator[
Tuple[str, Optional[np.ndarray], Optional[datetime]], None, None
]
Expand All @@ -36,7 +38,8 @@ def __init__(
ef_query: int = 50,
max_connection: int = 16,
dump_path: Optional[str] = None,
traversal_paths: str = '@r',
access_paths: str = '@r',
traversal_paths: Optional[str] = None,
is_distance: bool = True,
last_timestamp: datetime = datetime.fromtimestamp(0, timezone.utc),
num_threads: int = -1,
Expand All @@ -56,8 +59,9 @@ def __init__(
graph (the "M" parameter)
:param dump_path: The path to the directory from where to load, and where to
save the index state
:param traversal_paths: The default traversal path on docs (used for
:param access_paths: The default traversal path on docs (used for
indexing, search and update), e.g. '@r', '@c', '@r,c'
:param traversal_paths: please use access_paths
:param is_distance: Boolean flag that describes if distance metric need to
be reinterpreted as similarities.
:param last_timestamp: the last time we synced into this HNSW index
Expand All @@ -67,7 +71,13 @@ def __init__(
self.metric = metric
self.dim = dim
self.max_elements = max_elements
self.traversal_paths = traversal_paths
if traversal_paths is not None:
self.access_paths = traversal_paths
warnings.warn("'traversal_paths' will be deprecated in the future, please use 'access_paths'.",
DeprecationWarning,
stacklevel=2)
else:
self.access_paths = access_paths
self.ef_construction = ef_construction
self.ef_query = ef_query
self.max_connection = max_connection
Expand Down Expand Up @@ -115,13 +125,13 @@ def search(
of the same dimension as vectors in the index
:param parameters: Dictionary with optional parameters that can be used to
override the parameters set at initialization. Supported keys are
`traversal_paths`, `limit` and `ef_query`.
`access_paths`, `limit` and `ef_query`.
"""
if docs is None:
return

traversal_paths = parameters.get('traversal_paths', self.traversal_paths)
docs_search = docs[traversal_paths]
access_paths = parameters.get('access_paths', self.access_paths)
docs_search = docs[access_paths]
if len(docs_search) == 0:
return

Expand Down Expand Up @@ -166,13 +176,13 @@ def index(
:param docs: Documents whose `embedding` to index.
:param parameters: Dictionary with optional parameters that can be used to
override the parameters set at initialization. The only supported key is
`traversal_paths`.
`access_paths`.
"""
traversal_paths = parameters.get('traversal_paths', self.traversal_paths)
access_paths = parameters.get('access_paths', self.access_paths)
if docs is None:
return

docs_to_index = docs[traversal_paths]
docs_to_index = docs[access_paths]
if len(docs_to_index) == 0:
return

Expand Down Expand Up @@ -212,13 +222,13 @@ def update(
:param docs: Documents whose `embedding` to update.
:param parameters: Dictionary with optional parameters that can be used to
override the parameters set at initialization. The only supported key is
`traversal_paths`.
`access_paths`.
"""
traversal_paths = parameters.get('traversal_paths', self.traversal_paths)
access_paths = parameters.get('access_paths', self.access_paths)
if docs is None:
return

docs_to_update = docs[traversal_paths]
docs_to_update = docs[access_paths]
if len(docs_to_update) == 0:
return

Expand Down
18 changes: 10 additions & 8 deletions executor/hnswpsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(
max_connection: int = 64,
is_distance: bool = True,
num_threads: int = -1,
traversal_paths: str = '@r',
access_paths: str = '@r',
traversal_paths: Optional[str] = None,
hostname: str = '127.0.0.1',
port: int = 5432,
username: str = 'postgres',
Expand Down Expand Up @@ -85,8 +86,9 @@ def __init__(
similarity
:param last_timestamp: (HNSW) the last time we synced into this HNSW index
:param num_threads: (HNSW) nr of threads to use during indexing. -1 is default
:param traversal_paths: (PSQL) default traversal paths on docs
:param access_paths: (PSQL) default traversal paths on docs
(used for indexing, delete and update), e.g. '@r', '@c', '@r,c'
:param traversal_paths: please use access_paths
:param hostname: (PSQL) hostname of the machine
:param port: (PSQL) the port
:param username: (PSQL) the username to authenticate
Expand Down Expand Up @@ -242,7 +244,7 @@ def index(self, docs: DocumentArray, parameters: Dict, **kwargs):

Keys accepted:

- 'traversal_paths' (str): traversal path for the docs
- 'access_paths' (str): traversal path for the docs
"""
self._kv_indexer.add(docs, parameters, **kwargs)

Expand All @@ -255,7 +257,7 @@ def update(self, docs: DocumentArray, parameters: Dict, **kwargs):

Keys accepted:

- 'traversal_paths' (str): traversal path for the docs
- 'access_paths' (str): traversal path for the docs
"""
self._kv_indexer.update(docs, parameters, **kwargs)

Expand All @@ -268,7 +270,7 @@ def delete(self, docs: DocumentArray, parameters: Dict, **kwargs):

Keys accepted:

- 'traversal_paths' (str): traversal path for the docs
- 'access_paths' (str): traversal path for the docs
- 'soft_delete' (bool, default `True`): whether to perform soft delete
(doc is marked as empty but still exists in db, for retrieval purposes)
"""
Expand Down Expand Up @@ -351,7 +353,7 @@ def search(self, docs: 'DocumentArray', parameters: Dict = None, **kwargs):
:param parameters: dictionary for parameters for the search operation


- 'traversal_paths' (str): traversal paths for the docs
- 'access_paths' (str): traversal paths for the docs
- 'limit' (int): nr of matches to get per Document
- 'ef_query' (int): query time accuracy/speed trade-off. High is more
accurate but slower
Expand All @@ -362,10 +364,10 @@ def search(self, docs: 'DocumentArray', parameters: Dict = None, **kwargs):
self._vec_indexer.search(docs, parameters)

kv_parameters = copy.deepcopy(parameters)
kv_parameters['traversal_paths'] = ','.join(
kv_parameters['access_paths'] = ','.join(
[
path + 'm'
for path in kv_parameters.get('traversal_paths', '@r').split(',')
for path in kv_parameters.get('access_paths', '@r').split(',')
]
)
self._kv_indexer.search(docs, kv_parameters)
Expand Down
40 changes: 26 additions & 14 deletions executor/postgres_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from jina import Document, DocumentArray
from jina.logging.logger import JinaLogger

from typing import Optional
import warnings

from .commons import export_dump_streaming # this is for local testing
from .postgreshandler import PostgreSQLHandler

Expand All @@ -36,7 +39,8 @@ def __init__(
database: str = 'postgres',
table: str = 'default_table',
max_connections=5,
traversal_paths: str = '@r',
access_paths: str = '@r',
traversal_paths: Optional[str] = None,
return_embeddings: bool = True,
dry_run: bool = False,
partitions: int = 128,
Expand All @@ -63,7 +67,15 @@ def __init__(
ids constraint failing (useful when indexing with shards and
polling = 'all')
"""
self.default_traversal_paths = traversal_paths

if traversal_paths is not None:
warnings.warn("'traversal_paths' will be deprecated in the future, please use 'access_paths'.",
DeprecationWarning,
stacklevel=2)
self.default_access_paths = traversal_paths
else:
self.default_access_paths = access_paths

self.hostname = hostname
self.port = port
self.username = username
Expand Down Expand Up @@ -115,10 +127,10 @@ def add(self, docs: DocumentArray, parameters: Dict, **kwargs):
"""
if docs is None:
return
traversal_paths = parameters.get(
'traversal_paths', self.default_traversal_paths
access_paths = parameters.get(
'access_paths', self.default_access_paths
)
self.handler.add(docs[traversal_paths])
self.handler.add(docs[access_paths])

def update(self, docs: DocumentArray, parameters: Dict, **kwargs):
"""Updated document from the database.
Expand All @@ -128,10 +140,10 @@ def update(self, docs: DocumentArray, parameters: Dict, **kwargs):
"""
if docs is None:
return
traversal_paths = parameters.get(
'traversal_paths', self.default_traversal_paths
access_paths = parameters.get(
'access_paths', self.default_access_paths
)
self.handler.update(docs[traversal_paths])
self.handler.update(docs[access_paths])

def cleanup(self, **kwargs):
"""
Expand All @@ -153,11 +165,11 @@ def delete(self, docs: DocumentArray, parameters: Dict, **kwargs):
"""
if docs is None:
return
traversal_paths = parameters.get(
'traversal_paths', self.default_traversal_paths
access_paths = parameters.get(
'access_paths', self.default_access_paths
)
soft_delete = parameters.get('soft_delete', False)
self.handler.delete(docs[traversal_paths], soft_delete)
self.handler.delete(docs[access_paths], soft_delete)

def dump(self, parameters: Dict, **kwargs):
"""Dump the index
Expand Down Expand Up @@ -198,12 +210,12 @@ def search(self, docs: DocumentArray, parameters: Dict, **kwargs):
"""
if docs is None:
return
traversal_paths = parameters.get(
'traversal_paths', self.default_traversal_paths
access_paths = parameters.get(
'access_paths', self.default_access_paths
)

self.handler.search(
docs[traversal_paths],
docs[access_paths],
return_embeddings=parameters.get(
'return_embeddings', self.default_return_embeddings
),
Expand Down