-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from milistu/vector-database
Vector database
- Loading branch information
Showing
4 changed files
with
365 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
openai: | ||
embedding_model: | ||
name: "text-embedding-3-small" | ||
embeddings: | ||
model: "text-embedding-3-small" | ||
dimensions: 1536 | ||
gpt_model: | ||
llm: "gpt-4o" | ||
router: "gpt-3.5-turbo" | ||
chat: | ||
model: "gpt-4-turbo-preview" | ||
temperature: 0 | ||
max_conversation: 100 | ||
collection: | ||
name: "labor_law" | ||
router: | ||
model: "gpt-3.5-turbo" | ||
temperature: 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Database Module | ||
|
||
This directory contains the scripts and utilities for processing and embedding scraped data and upserting it into a vector database. | ||
|
||
## Overview | ||
|
||
- `utils.py`: Utility functions for embedding text, managing collections in the Qdrant vector database, and handling data files. | ||
- `vector_database.py`: **Main** script for creating embeddings from scraped data and storing them in a vector database. | ||
- `api_request_parallel_processor.py`: Handles parallel API requests to the OpenAI API for text embedding, ensuring efficient usage of API rate limits. | ||
|
||
## Setup | ||
|
||
Refer to the main README for details on setting up the environment and dependencies using Poetry. | ||
|
||
### Qdrant Setup | ||
To use Qdrant for storing embeddings: | ||
|
||
- Create a Free Cluster: | ||
|
||
Visit [Qdrant Cloud](https://cloud.qdrant.io/) and sign up for an account. | ||
Follow the prompts to create a new cluster. Select the free tier if available. | ||
|
||
- Get API Key and Cluster URL: | ||
|
||
Once your cluster is ready, navigate to the dashboard. | ||
Find your cluster's URL and API key under the 'Settings' or 'API' section. | ||
|
||
### Environment Configuration | ||
|
||
Before running the scripts, ensure you have set the necessary environment variables in your `.env` file: | ||
```dotenv | ||
QDRANT_CLUSTER_URL= | ||
QDRANT_API_KEY= | ||
OPENAI_API_KEY= | ||
``` | ||
|
||
## Usage | ||
|
||
To process and embed scraped data and upsert it to the vector database, use the following command from your project root: | ||
|
||
```bash | ||
python -m database.vector_database --scraped_dir ./scraper/test_laws/ --model text-embedding-3-small | ||
``` | ||
|
||
This command automatically handles every step, from data preparation and embedding to upserting the results into the Qdrant vector database. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import argparse | ||
import json | ||
import os | ||
from pathlib import Path | ||
|
||
from loguru import logger | ||
from qdrant_client import QdrantClient | ||
from tqdm.auto import tqdm | ||
|
||
from database.utils import ( | ||
create_collection, | ||
create_embeddings, | ||
get_count, | ||
load_and_process_embeddings, | ||
upsert, | ||
) | ||
|
||
|
||
def main(args: argparse.Namespace) -> None:
    """Create embeddings for the scraped files and upsert them into Qdrant.

    Calls ``create_embeddings`` with the directories and model taken from
    ``args``, then walks every entry of ``args.embeddings_dir``. For each
    entry the points are loaded, ``create_collection`` and ``upsert`` are
    called for a collection named after the entry, and the count reported by
    ``get_count`` is compared with the number of loaded points.

    Args:
        args: Parsed command-line arguments. Reads ``scraped_dir``,
            ``to_process_dir``, ``embeddings_dir`` and ``model``.

    Raises:
        KeyError: If ``QDRANT_CLUSTER_URL`` or ``QDRANT_API_KEY`` is missing
            from the environment.
    """
    logger.info("Creating embeddings.")
    create_embeddings(
        scraped_dir=args.scraped_dir,
        to_process_dir=args.to_process_dir,
        embeddings_dir=args.embeddings_dir,
        model=args.model,
    )

    logger.info("Creating vector database.")
    qdrant_client = QdrantClient(
        url=os.environ["QDRANT_CLUSTER_URL"],
        api_key=os.environ["QDRANT_API_KEY"],
    )
    # iterdir() yields entries in arbitrary order; sort so runs are reproducible.
    data_paths = sorted(args.embeddings_dir.iterdir())
    for path in tqdm(data_paths, desc="Creating collections"):
        # Collection name is the file stem with dashes turned into underscores.
        # NOTE(review): "_TESTIC" looks like a leftover test suffix - confirm
        # whether it should remain before using this for real collections.
        collection_name = path.stem.replace("-", "_") + "_TESTIC"
        points = load_and_process_embeddings(path=path)

        create_collection(client=qdrant_client, name=collection_name)
        upsert(client=qdrant_client, collection=collection_name, points=points)

        point_num = get_count(client=qdrant_client, collection=collection_name)
        if point_num != len(points):
            # Report the mismatch with both counts and do not announce success.
            logger.error(
                f"There are missing points in {collection_name} collection: "
                f"expected {len(points)}, found {point_num}."
            )
            continue

        logger.info(
            f'Created "{collection_name}" collection with {point_num} data points.'
        )
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse the CLI options, resolve the embedding model
    # (falling back to config.json when --model is omitted) and run main().
    parser = argparse.ArgumentParser(
        description="Create embeddings and vector database for scraped files."
    )
    # NOTE(review): --scraped_dir has no default and is not required, so it is
    # None when omitted - confirm create_embeddings accepts that.
    parser.add_argument(
        "--scraped_dir", type=Path, help="Directory to the scraped files."
    )
    parser.add_argument(
        "--to_process_dir",
        type=Path,
        default=Path("./database/to_process"),
        help="Directory to process files.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=Path,
        default=Path("./database/embeddings"),
        help="Directory for storing embeddings.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="The embedding model to be used. If not set, it will be loaded from the config file.",
    )

    args = parser.parse_args()

    # Load model from config file if not explicitly set
    if args.model is None:
        # "config.json" is resolved relative to the current working directory.
        # Read it as UTF-8 explicitly instead of relying on the platform's
        # locale encoding.
        with open("config.json", "r", encoding="utf-8") as config_file:
            config = json.load(config_file)
        # NOTE(review): the "default_model" fallback is presumably not a real
        # embedding model name - verify how create_embeddings handles it.
        args.model = config.get("embedding_model", "default_model")

    main(args=args)