Vector database #43

Merged (9 commits) on May 27, 2024
14 changes: 7 additions & 7 deletions config.yaml
@@ -1,11 +1,11 @@
 openai:
-  embedding_model:
-    name: "text-embedding-3-small"
+  embeddings:
+    model: "text-embedding-3-small"
     dimensions: 1536
-  gpt_model:
-    llm: "gpt-4o"
-    router: "gpt-3.5-turbo"
+  chat:
+    model: "gpt-4-turbo-preview"
+    temperature: 0
 max_conversation: 100
 collection:
   name: "labor_law"
+router:
+  model: "gpt-3.5-turbo"
+  temperature: 0
45 changes: 45 additions & 0 deletions database/README.md
@@ -0,0 +1,45 @@
# Database Module

This directory contains the scripts and utilities for processing and embedding scraped data and upserting it into a vector database.

## Overview

- `utils.py`: Utility functions for embedding text, managing collections in the Qdrant vector database, and handling data files.
- `vector_database.py`: **Main** script for creating embeddings from scraped data and storing them in a vector database.
- `api_request_parallel_processor.py`: Handles parallel API requests to the OpenAI API for text embedding, ensuring efficient usage of API rate limits.

## Setup

Refer to the main README for details on setting up the environment and dependencies using Poetry.

### Qdrant Setup
To use Qdrant for storing embeddings:

- **Create a free cluster:** Visit [Qdrant Cloud](https://cloud.qdrant.io/accounts/530e9933-88c7-42b7-a027-734ec6f5eb57/overview) and sign up for an account, then follow the prompts to create a new cluster, selecting the free tier if available.

- **Get the API key and cluster URL:** Once your cluster is ready, navigate to the dashboard and find your cluster's URL and API key under the 'Settings' or 'API' section (a quick connectivity check is sketched below).
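
Optionally, you can confirm the credentials work before running the pipeline by connecting with the `qdrant-client` package and listing collections. This is only a sanity check, not part of the pipeline; the URL and key below are placeholders and must be replaced with your own values.

```python
from qdrant_client import QdrantClient

# Placeholders: use the cluster URL and API key from your Qdrant dashboard.
client = QdrantClient(
    url="https://YOUR-CLUSTER-ID.YOUR-REGION.cloud.qdrant.io",
    api_key="YOUR-API-KEY",
)
print(client.get_collections())  # a fresh cluster should report no collections
```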

### Environment Configuration

Before running the scripts, ensure you have set the necessary environment variables in your `.env` file:
```yaml
QDRANT_CLUSTER_URL=
QDRANT_API_KEY=
OPENAI_API_KEY=
```
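
The database scripts read these values from the process environment (for example, `vector_database.py` uses `os.environ["QDRANT_CLUSTER_URL"]`), so the `.env` file needs to be loaded before they run. A minimal sketch of one way to do that, assuming the `python-dotenv` package is installed (an assumption; your setup may export the variables some other way):

```python
import os

from dotenv import load_dotenv  # assumes python-dotenv is installed

load_dotenv()  # loads variables from the .env file in the current directory

# The variables the database scripts expect to find in the environment.
for var in ("QDRANT_CLUSTER_URL", "QDRANT_API_KEY", "OPENAI_API_KEY"):
    if not os.environ.get(var):
        raise RuntimeError(f"{var} is not set")
```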

## Usage

To process and embed scraped data and upsert it to the vector database, use the following command from your project root:

```bash
python -m database.vector_database --scraped_dir ./scraper/test_laws/ --model text-embedding-3-small
```

This command handles all steps automatically: preparing the data, creating the embeddings, and upserting them into the Qdrant vector database.
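
After the upsert finishes, you can sanity-check a collection by embedding a query with the same model and running a similarity search. A minimal sketch, assuming the `openai` and `qdrant-client` packages and a collection named `labor_law`; the actual collection names are derived from the embedded file names, so adjust the name accordingly.

```python
import os

from openai import OpenAI
from qdrant_client import QdrantClient

openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
qdrant_client = QdrantClient(
    url=os.environ["QDRANT_CLUSTER_URL"], api_key=os.environ["QDRANT_API_KEY"]
)

# Embed the query with the same model used for the documents.
query = "How many vacation days am I entitled to?"
vector = (
    openai_client.embeddings.create(model="text-embedding-3-small", input=query)
    .data[0]
    .embedding
)

# Retrieve the most similar points from the collection (name is an assumption).
hits = qdrant_client.search(collection_name="labor_law", query_vector=vector, limit=3)
for hit in hits:
    print(hit.score, hit.payload["title"])
```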
227 changes: 227 additions & 0 deletions database/utils.py
@@ -1,3 +1,6 @@
import json
import subprocess
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
@@ -14,6 +17,7 @@
    UpdateResult,
    VectorParams,
)
from tqdm.auto import tqdm


def create_collection(
@@ -106,3 +110,226 @@ def get_context(search_results: List[ScoredPoint], top_k: int = None) -> str:
        :top_k
    ]
    return "\n".join([format_context(point.payload) for point in search_results])


def load_json(path: Path) -> List[Dict]:
    """
    Load JSON data from a file.

    Args:
        path (Path): The path to the JSON file.

    Returns:
        List[Dict]: The JSON data loaded from the file.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if not path.exists():
        logger.error(f"File: {path} does not exist.")
        raise FileNotFoundError(f"File: {path} does not exist.")

    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data


def prepare_for_embedding(
    output_path: Path, scraped_data: List[Dict], model: str
) -> None:
    """
    Prepare data for embedding and save to a file.

    Args:
        output_path (Path): The path to save the prepared data.
        scraped_data (List[Dict]): The scraped data to be prepared.
        model (str): The embedding model to be used.

    Returns:
        None
    """
    jobs = [
        {
            "model": model,
            "id": id,
            "title": sample["title"],
            "link": sample["link"],
            "input": f"{sample['title']}: {' '.join(sample['texts'])}",
        }
        for id, sample in enumerate(scraped_data)
    ]
    with open(output_path, "w", encoding="utf-8") as file:
        for job in jobs:
            json_string = json.dumps(job)
            file.write(json_string + "\n")


def get_token_num(text: str, model_name: str) -> int:
    """
    Get the number of tokens in a text for a given model.

    Args:
        text (str): The input text.
        model_name (str): The name of the model.

    Returns:
        int: The number of tokens in the text.
    """
    enc = tiktoken.encoding_for_model(model_name)
    return len(enc.encode(text))


def run_api_request_processor(
    requests_filepath: Path,
    save_path: Path,
    max_requests_per_minute: int = 2500,
    max_tokens_per_minute: int = 900000,
    token_encoding_name: str = "cl100k_base",
    max_attempts: int = 5,
    logging_level: int = 20,
) -> None:
    """
    Run the API request processor to call the OpenAI API in parallel, creating embeddings with the specified model.

    Args:
        requests_filepath (Path): The path to the requests file.
        save_path (Path): The path to save the results.
        max_requests_per_minute (int): Maximum number of requests per minute.
        max_tokens_per_minute (int): Maximum number of tokens per minute.
        token_encoding_name (str): The name of the token encoding.
        max_attempts (int): Maximum number of attempts for each request.
        logging_level (int): Logging level.

    Returns:
        None
    """
    if not requests_filepath.exists():
        logger.error(f"File {requests_filepath} does not exist.")
        raise FileNotFoundError(f"File {requests_filepath} does not exist.")
    if save_path.suffix != ".jsonl":
        logger.error(f"Save path {save_path} must be JSONL.")
        raise ValueError(f"Save path {save_path} must be JSONL.")

    command = [
        "python",
        "database/api_request_parallel_processor.py",
        "--requests_filepath",
        str(requests_filepath),
        "--save_filepath",
        str(save_path),
        "--request_url",
        "https://api.openai.com/v1/embeddings",
        "--max_requests_per_minute",
        str(max_requests_per_minute),
        "--max_tokens_per_minute",
        str(max_tokens_per_minute),
        "--token_encoding_name",
        token_encoding_name,
        "--max_attempts",
        str(max_attempts),
        "--logging_level",
        str(logging_level),
    ]
    result = subprocess.run(command, text=True, capture_output=True)

    if result.returncode == 0:
        logger.info(f"Embeddings saved to: {save_path}")
    else:
        logger.error("Error in Embedding execution!")
        logger.error(f"Error: {result.stderr}")


# Eliminate this or make it more general
def validate_path(path: Path) -> None:
    """Ensure the path is a Path object and create the directory if it does not exist."""
    if not isinstance(path, Path):
        logger.error(f'"{path}" must be a valid Path object')
        raise ValueError(f'"{path}" must be a valid Path object')
    path.mkdir(parents=True, exist_ok=True)


def create_embeddings(
    scraped_dir: Path, to_process_dir: Path, embeddings_dir: Path, model: str
) -> None:
    """
    Embed scraped law files by preparing the data and running the request processor
    to call the OpenAI API in parallel, creating embeddings with the specified model.

    Args:
        scraped_dir (Path): Directory to the law files.
        to_process_dir (Path): Directory to process files.
        embeddings_dir (Path): Directory for storing embeddings.
        model (str): The embedding model to be used.

    Raises:
        ValueError: If any of the provided paths are invalid.
    """
    # Validate input paths
    validate_path(scraped_dir)
    validate_path(to_process_dir)
    validate_path(embeddings_dir)

    scraped_paths = list(scraped_dir.iterdir())

    for file_path in tqdm(
        scraped_paths, desc="Embedding scraped files", total=len(scraped_paths)
    ):
        scraped_data = load_json(path=file_path)

        requests_filepath = to_process_dir / (file_path.stem + ".jsonl")
        prepare_for_embedding(
            output_path=requests_filepath,
            scraped_data=scraped_data,
            model=model,
        )

        processed_filepath = embeddings_dir / requests_filepath.name
        run_api_request_processor(
            requests_filepath=requests_filepath, save_path=processed_filepath
        )


def load_and_process_embeddings(path: Path) -> List[PointStruct]:
    """
    Load embeddings from a JSON lines file and process them into data points.

    Args:
        path (Path): The path to the JSON lines file containing embeddings.

    Returns:
        List[PointStruct]: A list of PointStruct objects containing the processed data.

    Raises:
        FileNotFoundError: If the file does not exist.
        IOError: If there is an error reading the file.
        json.JSONDecodeError: If there is an error parsing the JSON.
    """
    if not path.exists():
        logger.error(f"File: {path} does not exist.")
        raise FileNotFoundError(f"File: {path} does not exist.")

    try:
        with open(path, "r", encoding="utf-8") as file:
            embedding_data = [json.loads(line) for line in file]
    except (IOError, json.JSONDecodeError) as e:
        logger.error(f"Error reading or parsing file: {e}")
        raise

    points = []
    for item in embedding_data:
        try:
            points.append(
                PointStruct(
                    id=item[0]["id"],
                    vector=item[1]["data"][0]["embedding"],
                    payload={
                        "title": item[0]["title"],
                        "text": item[0]["input"],
                        "link": item[0]["link"],
                    },
                )
            )
        except KeyError as e:
            logger.error(f"Missing key in embedded data: {e}")
            continue
    return points
86 changes: 86 additions & 0 deletions database/vector_database.py
@@ -0,0 +1,86 @@
import argparse
import os
from pathlib import Path

import yaml
from loguru import logger
from qdrant_client import QdrantClient
from tqdm.auto import tqdm

from database.utils import (
    create_collection,
    create_embeddings,
    get_count,
    load_and_process_embeddings,
    upsert,
)


def main(args: argparse.Namespace) -> None:
    logger.info("Creating embeddings.")
    create_embeddings(
        scraped_dir=args.scraped_dir,
        to_process_dir=args.to_process_dir,
        embeddings_dir=args.embeddings_dir,
        model=args.model,
    )

    logger.info("Creating vector database.")
    qdrant_client = QdrantClient(
        url=os.environ["QDRANT_CLUSTER_URL"],
        api_key=os.environ["QDRANT_API_KEY"],
    )
    data_paths = list(args.embeddings_dir.iterdir())
    for path in tqdm(data_paths, total=len(data_paths), desc="Creating collections"):
        # Check if this is necessary
        collection_name = path.stem.replace("-", "_")
        collection_name = collection_name + "_TESTIC"
        points = load_and_process_embeddings(path=path)

        create_collection(client=qdrant_client, name=collection_name)
        upsert(client=qdrant_client, collection=collection_name, points=points)

        point_num = get_count(client=qdrant_client, collection=collection_name)
        if point_num != len(points):
            logger.error(f"There are missing points in {collection_name} collection.")

        logger.info(
            f'Created "{collection_name}" collection with {point_num} data points.'
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create embeddings and vector database for scraped files."
    )
    parser.add_argument(
        "--scraped_dir", type=Path, help="Directory to the scraped files."
    )
    parser.add_argument(
        "--to_process_dir",
        type=Path,
        default=Path("./database/to_process"),
        help="Directory to process files.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=Path,
        default=Path("./database/embeddings"),
        help="Directory for storing embeddings.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="The embedding model to be used. If not set, it will be loaded from the config file.",
    )

    args = parser.parse_args()

    # Load the model from the config file if not explicitly set
    if args.model is None:
        with open("config.yaml", "r", encoding="utf-8") as config_file:
            config = yaml.safe_load(config_file)
        args.model = config["openai"]["embeddings"]["model"]

    main(args=args)