Skip to content

Commit

Permalink
Merge pull request #2 from hookdeck/chore/refactor
Browse files Browse the repository at this point in the history
Refactor to follow better practice
  • Loading branch information
leggetter authored Nov 12, 2024
2 parents 21b1822 + 8a7ba98 commit 5fb280c
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 111 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ Run the following to create Hookdeck connections to receive webhooks from Replic
poetry run python create-hookdeck-connections.py
```

Run the following to create a Vector Index within MongoDB:
Run the following to create a search indexes within MongoDB:

> [!WARNING]
> You may need some data within MongoDB before you can create the vector index.
> You may need some data within MongoDB before you can create the indexes.
```sh
poetry run python create-vector-index.py
poetry run python create-indexes.py
```

### Run the app
Expand Down
10 changes: 1 addition & 9 deletions lib/generators.py → allthethings/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
from config import Config


def get_embedding_generator():
return EmbeddingsGenerator()


def get_sync_embedding_generator():
return SyncEmbeddingsGenerator()


class EmbeddingsGenerator:
class AsyncEmbeddingsGenerator:
def __init__(self):
self.WEBHOOK_URL = Config.EMBEDDINGS_WEBHOOK_URL
self.model = replicate.models.get("replicate/all-mpnet-base-v2")
Expand Down
22 changes: 22 additions & 0 deletions allthethings/mongo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

from config import Config


class Database:

def __init__(self):
MONGODB_CONNECTION_URI = Config.MONGODB_CONNECTION_URI

self.client = MongoClient(MONGODB_CONNECTION_URI, server_api=ServerApi("1"))

self.client.admin.command("ping")

def get_client(self):
return self.client

def get_collection(self):
return self.client.get_database(Config.DB_NAME).get_collection(
Config.COLLECTION_NAME
)
File renamed without changes.
77 changes: 31 additions & 46 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from config import Config

from lib.mongo import get_mongo_client
from lib.processors import get_asset_processor
from lib.generators import get_embedding_generator, get_sync_embedding_generator
from allthethings.mongo import Database
from allthethings.processors import get_asset_processor
from allthethings.generators import (
AsyncEmbeddingsGenerator,
SyncEmbeddingsGenerator,
)

app = Flask(
__name__, static_url_path="", template_folder="templates", static_folder="static"
Expand Down Expand Up @@ -35,12 +38,9 @@ def format_results(results):

@app.route("/")
def index():
client = get_mongo_client()
if client is None:
flash("Failed to connect to MongoDB")
return redirect(url_for("index"))
database = Database()

indexes = client[Config.DB_NAME][Config.COLLECTION_NAME].find({})
indexes = database.get_collection().find({})
results = format_results(indexes)

app.logger.info("Homepage loading")
Expand All @@ -56,11 +56,6 @@ def search():

@app.route("/search", methods=["POST"])
def search_post():
client = get_mongo_client()
if client is None:
flash("Failed to connect to MongoDB")
return redirect(url_for("index"))

query = request.form["query"]

app.logger.info("Query submitted")
Expand All @@ -80,12 +75,10 @@ def search_post():
def process():
url = request.form["url"]

client = get_mongo_client()
if client is None:
flash("Failed to connect to MongoDB")
return redirect(url_for("index"))
database = Database()
collection = database.get_collection()

exists = client[Config.DB_NAME][Config.COLLECTION_NAME].find_one({"url": url})
exists = collection.find_one({"url": url})
if exists is not None:
flash("URL has already been indexed")
return redirect(url_for("index"))
Expand All @@ -108,7 +101,7 @@ def process():
flash('Unsupported content type "' + content_type + '"')
return redirect(url_for("index"))

client[Config.DB_NAME][Config.COLLECTION_NAME].insert_one(
collection.insert_one(
{
"url": url,
"content_type": content_type,
Expand All @@ -119,7 +112,7 @@ def process():

prediction = processor.process(url)

client[Config.DB_NAME][Config.COLLECTION_NAME].update_one(
collection.update_one(
filter={"url": url},
update={
"$set": {
Expand Down Expand Up @@ -149,23 +142,22 @@ def process():
def request_embeddings(id):
app.logger.info("Requesting embeddings for %s", id)

client = get_mongo_client()
if client is None:
raise RuntimeError("Failed to connect to MongoDB")
database = Database()
collection = database.get_collection()

asset = client[Config.DB_NAME][Config.COLLECTION_NAME].find_one({"_id": id})
asset = collection.find_one({"_id": id})

if asset is None:
raise RuntimeError("Asset not found")

if asset["status"] != "PROCESSED":
raise RuntimeError("Asset has not been processed")

generator = get_embedding_generator()
generator = AsyncEmbeddingsGenerator()

generate_request = generator.generate(asset["text"])

client[Config.DB_NAME][Config.COLLECTION_NAME].update_one(
collection.update_one(
filter={"_id": id},
update={
"$set": {
Expand All @@ -178,14 +170,8 @@ def request_embeddings(id):

# Inspiration https://www.mongodb.com/developer/products/atlas/how-use-cohere-embeddings-rerank-modules-mongodb-atlas/#query-mongodb-vector-index-using--vectorsearch
def query_vector_search(q, prefilter={}, postfilter={}, path="embedding", topK=2):
client = get_mongo_client()
if client is None:
raise RuntimeError("Failed to connect to MongoDB")

asset_collection = client[Config.DB_NAME][Config.COLLECTION_NAME]

# Because the search is user-driven, we use the synchronous generator
generator = get_sync_embedding_generator()
generator = SyncEmbeddingsGenerator()

generate_response = generator.generate(q)

Expand Down Expand Up @@ -221,13 +207,16 @@ def query_vector_search(q, prefilter={}, postfilter={}, path="embedding", topK=2
}
}

database = Database()
collection = database.get_collection()

if len(postfilter.keys()) > 0:
app.logger.info("Vector search query with post filter")
postFilter = {"$match": postfilter}
res = list(asset_collection.aggregate([new_search_query, project, postFilter]))
res = list(collection.aggregate([new_search_query, project, postFilter]))
else:
app.logger.info("Vector search query without post filter")
res = list(asset_collection.aggregate([new_search_query, project]))
res = list(collection.aggregate([new_search_query, project]))

app.logger.info("Vector search query run")
app.logger.debug(res)
Expand All @@ -240,16 +229,14 @@ def webhook_audio():
app.logger.info("Audio payload recieved")
app.logger.debug(payload)

client = get_mongo_client()

if client is None:
return jsonify({"error": "Database connection failed"}), 500
database = Database()
collection = database.get_collection()

status = (
"PROCESSING_ERROR" if "error" in payload and payload["error"] else "PROCESSED"
)

result = client[Config.DB_NAME][Config.COLLECTION_NAME].find_one_and_update(
result = collection.find_one_and_update(
filter={"replicate_process_id": payload["id"]},
update={
"$set": {
Expand Down Expand Up @@ -281,18 +268,16 @@ def webhook_embeddings():
app.logger.info("Embeddings payload recieved")
app.logger.debug(payload)

client = get_mongo_client()

if client is None:
return jsonify({"error": "Database connection failed"}), 500

status = (
"EMBEDDINGS_ERROR" if "error" in payload and payload["error"] else "SEARCHABLE"
)

embedding = payload["output"][0]["embedding"]

result = client[Config.DB_NAME][Config.COLLECTION_NAME].update_one(
database = Database()
collection = database.get_collection()

result = collection.update_one(
filter={"replicate_embedding_id": payload["id"]},
update={
"$set": {
Expand Down
59 changes: 59 additions & 0 deletions create-indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Inspiration https://www.mongodb.com/developer/products/atlas/how-use-cohere-embeddings-rerank-modules-mongodb-atlas/#programmatically-create-vector-search-and-full-text-search-index

from allthethings.mongo import Database
from pymongo.operations import SearchIndexModel

database = Database()
collection = database.get_collection()


def create_or_update_search_index(index_name, index_definition, index_type):
indexes = list(collection.list_search_indexes(index_name))
if len(indexes) == 0:
print(f'Creating search index: "{index_name}"')
index_model = SearchIndexModel(
definition=index_definition,
name=index_name,
type=index_type,
)
result = collection.create_search_index(model=index_model)

else:
print(f'Search index "{index_name}" already exists. Updating.')
result = collection.update_search_index(
name=index_name, definition=index_definition
)

return result


vector_result = create_or_update_search_index(
"vector_index",
{
"fields": [
{
"type": "vector",
"path": "embedding",
"numDimensions": 768,
"similarity": "euclidean",
}
]
},
"vectorSearch",
)
print(vector_result)

index_result = create_or_update_search_index(
"replicate_by_embedding_id_index",
{
"mappings": {"dynamic": True},
"fields": [
{
"type": "string",
"path": "replicate_embedding_id",
}
],
},
"search",
)
print(index_result)
33 changes: 0 additions & 33 deletions create-vector-index.py

This file was deleted.

20 changes: 0 additions & 20 deletions lib/mongo.py

This file was deleted.

0 comments on commit 5fb280c

Please sign in to comment.