Skip to content

Commit

Permalink
housekeeping
Browse files Browse the repository at this point in the history
Co-Authored-By: Glib <71976818+GLEF1X@users.noreply.github.com>
  • Loading branch information
isafulf and GLEF1X committed Mar 28, 2023
1 parent eef0916 commit 958bb78
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 61 deletions.
12 changes: 5 additions & 7 deletions datastore/providers/milvus_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _create_collection(self, create_new: bool) -> None:

# If no index on the collection, create one
if len(self.col.indexes) == 0:
if self.index_params != None:
if self.index_params is not None:
# Create an index on the 'embedding' field with the index params found in init
self.col.create_index("embedding", index_params=self.index_params)
else:
Expand All @@ -215,9 +215,7 @@ def _create_collection(self, create_new: bool) -> None:
print("Creation of Zilliz Cloud default index successful")
# If an index already exists, grab its params
else:
self.index_params = self.col.indexes[0].to_dict()['index_param']


self.index_params = self.col.indexes[0].to_dict()["index_param"]

self.col.load()

Expand Down Expand Up @@ -327,7 +325,7 @@ async def _single_query(query: QueryWithEmbedding) -> QueryResult:

filter = None
# Set the filter to expression that is valid for Milvus
if query.filter != None:
if query.filter is not None:
# Either a valid filter or None will be returned
filter = self._get_filter(query.filter)

Expand Down Expand Up @@ -404,7 +402,7 @@ async def delete(
delete_count = 0

# Check if empty ids
if ids != None:
if ids is not None:
if len(ids) != 0:
# Add quotation marks around the string format id
ids = ['"' + str(id) + '"' for id in ids]
Expand All @@ -420,7 +418,7 @@ async def delete(
delete_count += int(res.delete_count) # type: ignore

# Check if empty filter
if filter != None:
if filter is not None:
# Convert filter to milvus expression
filter = self._get_filter(filter) # type: ignore
# Check if there is anything to filter
Expand Down
4 changes: 2 additions & 2 deletions datastore/providers/pinecone_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def delete(
Removes vectors by ids, filter, or everything from the index.
"""
# Delete all vectors from the index if delete_all is True
if delete_all == True:
if delete_all:
try:
print(f"Deleting all vectors from index")
self.index.delete(delete_all=True)
Expand All @@ -205,7 +205,7 @@ async def delete(
raise e

# Delete vectors that match the document ids from the index if the ids list is not empty
if ids != None and len(ids) > 0:
if ids is not None and len(ids) > 0:
try:
print(f"Deleting vectors with ids {ids}")
pinecone_filter = {"document_id": {"$in": ids}}
Expand Down
14 changes: 9 additions & 5 deletions datastore/providers/redis_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ def unpack_schema(d: dict):
else:
yield v


async def _check_redis_module_exist(client: "redis.Redis", modules: List[str]) -> bool:
    """Return True iff every named Redis module is installed on the server.

    Args:
        client: Connected async Redis client (annotation kept as a string so
            the signature does not require `redis` to be importable here).
        modules: Module names that must all be present (e.g. ["search", "ReJSON"]).

    Returns:
        True when every entry of `modules` appears among the server's
        installed module names, False otherwise.
    """
    # INFO exposes installed modules as a list of dicts under "modules";
    # fall back to a harmless placeholder when the key is absent.
    installed = (await client.info()).get("modules", {"name": ""})
    installed_names = [m["name"] for m in installed]  # type: ignore
    # Generator (not list) is enough for all(); short-circuits on first miss.
    return all(module in installed_names for module in modules)


Expand Down Expand Up @@ -102,7 +103,7 @@ async def init(cls):
logging.error(f"Error setting up Redis: {e}")
raise e

if not await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES):
if not await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES): # type: ignore
raise ValueError(
"You must add the search and json modules in Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
Expand Down Expand Up @@ -194,7 +195,6 @@ def _get_redis_query(self, query: QueryWithEmbedding) -> RediSearchQuery:
Returns:
RediSearchQuery: Query for RediSearch.
"""
query_str: str = ""
filter_str: str = ""

# RediSearch field type to query string
Expand Down Expand Up @@ -368,7 +368,9 @@ async def delete(
# TODO - extend this to work with other metadata filters?
if filter.document_id:
try:
keys = await self._find_keys(f"{REDIS_DOC_PREFIX}:{filter.document_id}:*")
keys = await self._find_keys(
f"{REDIS_DOC_PREFIX}:{filter.document_id}:*"
)
await self._redis_delete(keys)
logging.info(f"Deleted document {filter.document_id} successfully")
except Exception as e:
Expand All @@ -382,7 +384,9 @@ async def delete(
keys = []
# find all keys associated with the document ids
for document_id in ids:
doc_keys = await self._find_keys(pattern=f"{REDIS_DOC_PREFIX}:{document_id}:*")
doc_keys = await self._find_keys(
pattern=f"{REDIS_DOC_PREFIX}:{document_id}:*"
)
keys.extend(doc_keys)
# delete all keys
logging.info(f"Deleting {len(keys)} keys from Redis")
Expand Down
5 changes: 1 addition & 4 deletions examples/authentication-methods/no-auth/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# This is a version of the main.py file found in ../../../server/main.py without authentication.
# Copy and paste this into the main file at ../../../server/main.py if you choose to use no authentication for your retrieval plugin.

import os
import uvicorn
from fastapi import FastAPI, File, HTTPException, Depends, Body, UploadFile
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import FastAPI, File, HTTPException, Body, UploadFile
from fastapi.staticfiles import StaticFiles

from models.api import (
Expand Down
22 changes: 12 additions & 10 deletions examples/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@
from services.file import get_document_from_file


# Authentication setup: endpoints that depend on validate_token require a
# matching Bearer token in the Authorization header.
bearer_scheme = HTTPBearer()
# Shared-secret token supplied via the environment.
BEARER_TOKEN = os.environ.get("BEARER_TOKEN")
assert BEARER_TOKEN is not None  # fail fast at import time if unset


def validate_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Validate the request's Authorization header against BEARER_TOKEN.

    Raises:
        HTTPException: 401 when the scheme is not "Bearer" or the token
            does not match BEARER_TOKEN.

    Returns:
        The unchanged credentials on success.
    """
    if credentials.scheme != "Bearer" or credentials.credentials != BEARER_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid or missing token")
    return credentials


app = FastAPI()
app.mount("/.well-known", StaticFiles(directory=".well-known"), name="static")

Expand All @@ -29,19 +40,10 @@
description="A retrieval API for querying and filtering documents based on natural language queries and metadata",
version="1.0.0",
servers=[{"url": "https://your-app-url.com"}],
dependencies=[Depends(validate_token)],
)
app.mount("/sub", sub_app)

bearer_scheme = HTTPBearer()
BEARER_TOKEN = os.environ.get("BEARER_TOKEN")
assert BEARER_TOKEN is not None


def validate_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Check the request's Bearer token; raise 401 on any mismatch.

    Returns the credentials unchanged when the scheme is "Bearer" and the
    token equals BEARER_TOKEN.
    """
    if credentials.scheme != "Bearer" or credentials.credentials != BEARER_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid or missing token")
    return credentials


@app.post(
"/upsert-file",
Expand Down
1 change: 0 additions & 1 deletion models/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from models.models import (
Document,
DocumentChunkWithScore,
DocumentMetadataFilter,
Query,
QueryResult,
Expand Down
2 changes: 1 addition & 1 deletion scripts/process_json/process_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import asyncio

from models.models import Document, DocumentMetadata, Source
from models.models import Document, DocumentMetadata
from datastore.datastore import DataStore
from datastore.factory import get_datastore
from services.extract_metadata import extract_metadata_from_document
Expand Down
2 changes: 1 addition & 1 deletion scripts/process_jsonl/process_jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import asyncio

from models.models import Document, DocumentMetadata, Source
from models.models import Document, DocumentMetadata
from datastore.datastore import DataStore
from datastore.factory import get_datastore
from services.extract_metadata import extract_metadata_from_document
Expand Down
28 changes: 12 additions & 16 deletions server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,18 @@
from datastore.factory import get_datastore
from services.file import get_document_from_file

bearer_scheme = HTTPBearer()
BEARER_TOKEN = os.environ.get("BEARER_TOKEN")
assert BEARER_TOKEN is not None


def validate_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Authorize the request via its Bearer token.

    Accepts the request only when the Authorization scheme is "Bearer" and
    the presented token equals BEARER_TOKEN; otherwise responds 401.
    """
    is_valid = (
        credentials.scheme == "Bearer" and credentials.credentials == BEARER_TOKEN
    )
    if not is_valid:
        raise HTTPException(status_code=401, detail="Invalid or missing token")
    return credentials


app = FastAPI()
app = FastAPI(dependencies=[Depends(validate_token)])
app.mount("/.well-known", StaticFiles(directory=".well-known"), name="static")

# Create a sub-application, in order to access just the query endpoint in an OpenAPI schema, found at http://0.0.0.0:8000/sub/openapi.json when the app is running locally
Expand All @@ -25,27 +35,17 @@
description="A retrieval API for querying and filtering documents based on natural language queries and metadata",
version="1.0.0",
servers=[{"url": "https://your-app-url.com"}],
dependencies=[Depends(validate_token)],
)
app.mount("/sub", sub_app)

bearer_scheme = HTTPBearer()
BEARER_TOKEN = os.environ.get("BEARER_TOKEN")
assert BEARER_TOKEN is not None


def validate_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Reject the request with 401 unless it carries the configured Bearer token."""
    if credentials.scheme != "Bearer" or credentials.credentials != BEARER_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid or missing token")
    return credentials


@app.post(
"/upsert-file",
response_model=UpsertResponse,
)
async def upsert_file(
file: UploadFile = File(...),
token: HTTPAuthorizationCredentials = Depends(validate_token),
):
document = await get_document_from_file(file)

Expand All @@ -63,7 +63,6 @@ async def upsert_file(
)
async def upsert(
request: UpsertRequest = Body(...),
token: HTTPAuthorizationCredentials = Depends(validate_token),
):
try:
ids = await datastore.upsert(request.documents)
Expand All @@ -79,7 +78,6 @@ async def upsert(
)
async def query_main(
request: QueryRequest = Body(...),
token: HTTPAuthorizationCredentials = Depends(validate_token),
):
try:
results = await datastore.query(
Expand All @@ -99,7 +97,6 @@ async def query_main(
)
async def query(
request: QueryRequest = Body(...),
token: HTTPAuthorizationCredentials = Depends(validate_token),
):
try:
results = await datastore.query(
Expand All @@ -117,7 +114,6 @@ async def query(
)
async def delete(
request: DeleteRequest = Body(...),
token: HTTPAuthorizationCredentials = Depends(validate_token),
):
if not (request.ids or request.filter or request.delete_all):
raise HTTPException(
Expand Down
15 changes: 1 addition & 14 deletions tests/datastore/providers/milvus/test_milvus_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
DocumentChunkMetadata,
DocumentMetadataFilter,
DocumentChunk,
Query,
QueryWithEmbedding,
Source,
)
Expand Down Expand Up @@ -151,7 +150,6 @@ async def test_upsert(milvus_datastore, document_chunk_one):
assert res == list(document_chunk_one.keys())
milvus_datastore.col.flush()
assert 3 == milvus_datastore.col.num_entities



@pytest.mark.asyncio
Expand All @@ -163,7 +161,7 @@ async def test_reload(milvus_datastore, document_chunk_one, document_chunk_two):
milvus_datastore.col.flush()
assert 3 == milvus_datastore.col.num_entities
new_store = MilvusDataStore()
another_in = {i:document_chunk_two[i] for i in document_chunk_two if i!=res[0]}
another_in = {i: document_chunk_two[i] for i in document_chunk_two if i != res[0]}
res = await new_store._upsert(another_in)
new_store.col.flush()
assert 6 == new_store.col.num_entities
Expand All @@ -175,9 +173,6 @@ async def test_reload(milvus_datastore, document_chunk_one, document_chunk_two):
query_results = await milvus_datastore._query(queries=[query])
assert 1 == len(query_results)





@pytest.mark.asyncio
async def test_upsert_query_all(milvus_datastore, document_chunk_two):
Expand All @@ -197,8 +192,6 @@ async def test_upsert_query_all(milvus_datastore, document_chunk_two):
assert 1 == len(query_results)
assert 6 == len(query_results[0].results)




@pytest.mark.asyncio
async def test_query_accuracy(milvus_datastore, document_chunk_one):
Expand All @@ -217,7 +210,6 @@ async def test_query_accuracy(milvus_datastore, document_chunk_one):
assert 1 == len(query_results[0].results)
assert 0 == query_results[0].results[0].score
assert "abc_123" == query_results[0].results[0].id



@pytest.mark.asyncio
Expand All @@ -240,7 +232,6 @@ async def test_query_filter(milvus_datastore, document_chunk_one):
assert 1 == len(query_results[0].results)
assert 0 != query_results[0].results[0].score
assert "def_456" == query_results[0].results[0].id



@pytest.mark.asyncio
Expand All @@ -265,7 +256,6 @@ async def test_delete_with_date_filter(milvus_datastore, document_chunk_one):
assert 1 == len(query_results)
assert 1 == len(query_results[0].results)
assert "ghi_789" == query_results[0].results[0].id



@pytest.mark.asyncio
Expand All @@ -290,7 +280,6 @@ async def test_delete_with_source_filter(milvus_datastore, document_chunk_one):
assert 1 == len(query_results)
assert 2 == len(query_results[0].results)
assert "def_456" == query_results[0].results[0].id



@pytest.mark.asyncio
Expand All @@ -313,7 +302,6 @@ async def test_delete_with_document_id_filter(milvus_datastore, document_chunk_o

assert 1 == len(query_results)
assert 0 == len(query_results[0].results)



@pytest.mark.asyncio
Expand All @@ -333,7 +321,6 @@ async def test_delete_with_document_id(milvus_datastore, document_chunk_one):

assert 1 == len(query_results)
assert 0 == len(query_results[0].results)



# if __name__ == '__main__':
Expand Down

0 comments on commit 958bb78

Please sign in to comment.