Skip to content

Commit

Permalink
Merge pull request #123 from datastax/bugfix/#122-vector-find-when-nulls
Browse files Browse the repository at this point in the history
Fix bug in vector_ methods and rearrange tests
  • Loading branch information
hemidactylus authored Nov 24, 2023
2 parents 25c8c74 + 5e4f1a5 commit 40ff784
Show file tree
Hide file tree
Showing 11 changed files with 1,344 additions and 872 deletions.
15 changes: 13 additions & 2 deletions README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -378,16 +378,27 @@ black --check tests && ruff tests && mypy tests

### Testing

Ensure you provide all required environment variables:
Ensure you provide all required environment variables (you can do so by editing `tests/.env` after `tests/.env.template`):

```bash
export ASTRA_DB_ID="..."
export ASTRA_DB_APPLICATION_TOKEN="..."
export ASTRA_DB_API_ENDPOINT="..."
export ASTRA_DB_KEYSPACE="..." # Optional

export ASTRA_DB_ID="..." # For the Ops testing only
export ASTRA_DB_OPS_APPLICATION_TOKEN="..." # Ops-only, falls back to the other token
```

then you can run:

```bash
pytest
```

To remove the noise from the logs (on by default), run `pytest -o log_cli=0`.

To enable the `AstraDBOps` testing (off by default):

```bash
TEST_ASTRADBOPS=1 pytest
```
2 changes: 1 addition & 1 deletion astrapy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.6.1"
__version__ = "0.6.2"
60 changes: 12 additions & 48 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
import json
from functools import partial
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Union

import httpx

Expand Down Expand Up @@ -148,17 +148,6 @@ def _pre_process_find(

return sort, projection

def _finalize_document_from_find(
self, itm: API_DOC, include_similarity: bool = True
) -> API_DOC:
# Clean away the returned similarity score if so desired
if include_similarity:
if "$similarity" not in itm:
raise ValueError("Expected '$similarity' not found in document.")
return itm
else:
return {k: v for k, v in itm.items() if k != "$similarity"}

def get(self, path: Optional[str] = None) -> Optional[API_RESPONSE]:
"""
Retrieve a document from the collection by its path.
Expand Down Expand Up @@ -241,15 +230,7 @@ def vector_find(
},
)

# Post-process the return
find_result = [
self._finalize_document_from_find(
raw_doc, include_similarity=include_similarity
)
for raw_doc in raw_find_result["data"]["documents"]
]

return find_result
return cast(List[API_DOC], raw_find_result["data"]["documents"])

@staticmethod
def paginate(
Expand Down Expand Up @@ -398,7 +379,7 @@ def vector_find_one_and_replace(
*,
filter: Optional[Dict[str, Any]] = None,
fields: Optional[List[str]] = None,
) -> API_DOC:
) -> Union[API_DOC, None]:
"""
Perform a vector-based search and replace the first matched document.
Args:
Expand All @@ -407,7 +388,7 @@ def vector_find_one_and_replace(
filter (dict, optional): Criteria to filter documents.
fields (list, optional): Specifies the fields to return in the result.
Returns:
dict: The result of the vector find and replace operation.
dict or None: either the matched document or None if nothing found
"""
# Pre-process the included arguments
sort, _ = self._pre_process_find(
Expand All @@ -422,13 +403,7 @@ def vector_find_one_and_replace(
sort=sort,
)

# Post-process the return
find_result = self._finalize_document_from_find(
raw_find_result["data"]["document"],
include_similarity=False,
)

return find_result
return cast(Union[API_DOC, None], raw_find_result["data"]["document"])

def find_one_and_update(
self,
Expand Down Expand Up @@ -470,7 +445,7 @@ def vector_find_one_and_update(
*,
filter: Optional[Dict[str, Any]] = None,
fields: Optional[List[str]] = None,
) -> API_DOC:
) -> Union[API_DOC, None]:
"""
Perform a vector-based search and update the first matched document.
Args:
Expand All @@ -479,7 +454,8 @@ def vector_find_one_and_update(
filter (dict, optional): Criteria to filter documents before applying the vector search.
fields (list, optional): Specifies the fields to return in the updated document.
Returns:
dict: The result of the vector-based find and update operation.
dict or None: The result of the vector-based find and
update operation, or None if nothing found
"""
# Pre-process the included arguments
sort, _ = self._pre_process_find(
Expand All @@ -494,13 +470,7 @@ def vector_find_one_and_update(
sort=sort,
)

# Post-process the return
find_result = self._finalize_document_from_find(
raw_find_result["data"]["document"],
include_similarity=False,
)

return find_result
return cast(Union[API_DOC, None], raw_find_result["data"]["document"])

def find_one(
self,
Expand Down Expand Up @@ -544,7 +514,7 @@ def vector_find_one(
filter: Optional[Dict[str, Any]] = None,
fields: Optional[List[str]] = None,
include_similarity: bool = True,
) -> API_DOC:
) -> Union[API_DOC, None]:
"""
Perform a vector-based search to find a single document in the collection.
Args:
Expand All @@ -553,7 +523,7 @@ def vector_find_one(
fields (list, optional): Specifies the fields to return in the result.
include_similarity (bool, optional): Whether to include similarity score in the result.
Returns:
dict: The found document or None if no matching document is found.
dict or None: The found document or None if no matching document is found.
"""
# Pre-process the included arguments
sort, projection = self._pre_process_find(
Expand All @@ -569,13 +539,7 @@ def vector_find_one(
options={"includeSimilarity": include_similarity},
)

# Post-process the return
find_result = self._finalize_document_from_find(
raw_find_result["data"]["document"],
include_similarity=include_similarity,
)

return find_result
return cast(Union[API_DOC, None], raw_find_result["data"]["document"])

def insert_one(
self, document: API_DOC, failures_allowed: bool = False
Expand Down
16 changes: 16 additions & 0 deletions tests/.env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
########################
# FOR THE REGULAR TESTS:
########################
ASTRA_DB_APPLICATION_TOKEN="AstraCS:..."
ASTRA_DB_API_ENDPOINT="https://<DB_ID>-<DB_REGION>.apps.astra.datastax.com"
#
# OPTIONAL:
# ASTRA_DB_KEYSPACE="..."


###################
# FOR THE OPS TEST:
###################
ASTRA_DB_ID="..."
# OPTIONAL (falls back to the token above)
ASTRA_DB_OPS_APPLICATION_TOKEN="..."
Loading

0 comments on commit 40ff784

Please sign in to comment.