Merge pull request #43 from milistu/vector-database
Vector database
milistu authored May 27, 2024
2 parents e53e3be + e0cbd6c commit 21d9062
Showing 4 changed files with 365 additions and 7 deletions.
14 changes: 7 additions & 7 deletions config.yaml
@@ -1,11 +1,11 @@
 openai:
-  embedding_model:
-    name: "text-embedding-3-small"
+  embeddings:
+    model: "text-embedding-3-small"
     dimensions: 1536
-  gpt_model:
-    llm: "gpt-4o"
-    router: "gpt-3.5-turbo"
+  chat:
+    model: "gpt-4-turbo-preview"
     temperature: 0
     max_conversation: 100
 collection:
   name: "labor_law"
+router:
+  model: "gpt-3.5-turbo"
+  temperature: 0
45 changes: 45 additions & 0 deletions database/README.md
@@ -0,0 +1,45 @@
# Database Module

This directory contains the scripts and utilities for processing scraped data, embedding it, and upserting the embeddings into a vector database.

## Overview

- `utils.py`: Utility functions for embedding text, managing collections in the Qdrant vector database, and handling data files.
- `vector_database.py`: **Main** script for creating embeddings from scraped data and storing them in the vector database (sketched below).
- `api_request_parallel_processor.py`: Sends parallel requests to the OpenAI API for text embedding while staying within the API rate limits.
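
For orientation, the end-to-end flow implemented by `vector_database.py` is roughly: prepare embedding requests from the scraped JSON, run them through the parallel processor, then load the resulting embeddings and upsert them into Qdrant. A condensed sketch of that flow (the paths and the collection-naming scheme here are illustrative):

```python
import os
from pathlib import Path

from qdrant_client import QdrantClient

from database.utils import (
    create_collection,
    create_embeddings,
    load_and_process_embeddings,
    upsert,
)

# Prepare request files, call the OpenAI embeddings endpoint in parallel,
# and write the responses to ./database/embeddings/.
create_embeddings(
    scraped_dir=Path("./scraper/test_laws"),
    to_process_dir=Path("./database/to_process"),
    embeddings_dir=Path("./database/embeddings"),
    model="text-embedding-3-small",
)

# Upsert each embeddings file into its own Qdrant collection.
client = QdrantClient(
    url=os.environ["QDRANT_CLUSTER_URL"],
    api_key=os.environ["QDRANT_API_KEY"],
)
for path in Path("./database/embeddings").iterdir():
    points = load_and_process_embeddings(path=path)
    create_collection(client=client, name=path.stem)
    upsert(client=client, collection=path.stem, points=points)
```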

## Setup

Refer to the main README for details on setting up the environment and dependencies using Poetry.
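
If the environment is not set up yet, the usual Poetry workflow looks roughly like this (a sketch that assumes Poetry itself is already installed):

```bash
poetry install   # install the project dependencies declared in pyproject.toml
poetry shell     # activate the project's virtual environment
```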

### Qdrant Setup
To use Qdrant for storing embeddings:

- **Create a free cluster:**

  Visit [Qdrant Cloud](https://cloud.qdrant.io/accounts/530e9933-88c7-42b7-a027-734ec6f5eb57/overview) and sign up for an account, then follow the prompts to create a new cluster. Select the free tier if available.

- **Get the API key and cluster URL:**

  Once your cluster is ready, navigate to the dashboard and find your cluster's URL and API key under the 'Settings' or 'API' section.

### Environment Configuration

Before running the scripts, ensure you have set the necessary environment variables in your `.env` file:
```bash
QDRANT_CLUSTER_URL=
QDRANT_API_KEY=
OPENAI_API_KEY=
```
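
Once these are set (exported in your shell or loaded from `.env`), you can sanity-check the connection with a short snippet like the one below (a sketch; it assumes the `qdrant-client` package is installed and both variables are available in the environment):

```python
import os

from qdrant_client import QdrantClient

client = QdrantClient(
    url=os.environ["QDRANT_CLUSTER_URL"],
    api_key=os.environ["QDRANT_API_KEY"],
)
# A fresh cluster should simply return an empty list of collections.
print(client.get_collections())
```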

## Usage

To process and embed scraped data and upsert it to the vector database, use the following command from your project root:

```bash
python -m database.vector_database --scraped_dir ./scraper/test_laws/ --model text-embedding-3-small
```

This command handles every step automatically: data preparation, embedding, and upserting into the Qdrant vector database.
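
Each embedded file ends up in its own Qdrant collection, and the script logs the collection names and point counts as it goes. To spot-check a collection afterwards, something like the following works (a sketch; `labor_law` is an illustrative collection name, so substitute one of the names logged by the script):

```python
import os

from qdrant_client import QdrantClient

client = QdrantClient(
    url=os.environ["QDRANT_CLUSTER_URL"],
    api_key=os.environ["QDRANT_API_KEY"],
)
# Exact point count for one collection; the name below is just an example.
print(client.count(collection_name="labor_law", exact=True))
```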
227 changes: 227 additions & 0 deletions database/utils.py
@@ -1,3 +1,6 @@
import json
import subprocess
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
@@ -14,6 +17,7 @@
    UpdateResult,
    VectorParams,
)
from tqdm.auto import tqdm


def create_collection(
@@ -106,3 +110,226 @@ def get_context(search_results: List[ScoredPoint], top_k: int = None) -> str:
        :top_k
    ]
    return "\n".join([format_context(point.payload) for point in search_results])


def load_json(path: Path) -> List[Dict]:
    """
    Load JSON data from a file.
    Args:
        path (Path): The path to the JSON file.
    Returns:
        List[Dict]: The JSON data loaded from the file.
    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if not path.exists():
        logger.error(f"File: {path} does not exist.")
        raise FileNotFoundError(f"File: {path} does not exist.")

    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data


def prepare_for_embedding(
    output_path: Path, scraped_data: List[Dict], model: str
) -> None:
    """
    Prepare data for embedding and save to a file.
    Args:
        output_path (Path): The path to save the prepared data.
        scraped_data (List[Dict]): The scraped data to be prepared.
        model (str): The embedding model to be used.
    Returns:
        None
    """
    jobs = [
        {
            "model": model,
            "id": id,
            "title": sample["title"],
            "link": sample["link"],
            "input": f"{sample['title']}: {' '.join(sample['texts'])}",
        }
        for id, sample in enumerate(scraped_data)
    ]
    with open(output_path, "w", encoding="utf-8") as file:
        for job in jobs:
            json_string = json.dumps(job)
            file.write(json_string + "\n")


def get_token_num(text: str, model_name: str) -> int:
    """
    Get the number of tokens in a text for a given model.
    Args:
        text (str): The input text.
        model_name (str): The name of the model.
    Returns:
        int: The number of tokens in the text.
    """
    enc = tiktoken.encoding_for_model(model_name)
    return len(enc.encode(text))


def run_api_request_processor(
    requests_filepath: Path,
    save_path: Path,
    max_requests_per_minute: int = 2500,
    max_tokens_per_minute: int = 900000,
    token_encoding_name: str = "cl100k_base",
    max_attempts: int = 5,
    logging_level: int = 20,
) -> None:
    """
    Run the API request processor to call the OpenAI API in parallel, creating embeddings with the specified model.
    Args:
        requests_filepath (Path): The path to the requests file.
        save_path (Path): The path to save the results.
        max_requests_per_minute (int): Maximum number of requests per minute.
        max_tokens_per_minute (int): Maximum number of tokens per minute.
        token_encoding_name (str): The name of the token encoding.
        max_attempts (int): Maximum number of attempts for each request.
        logging_level (int): Logging level.
    Returns:
        None
    """
    if not requests_filepath.exists():
        logger.error(f"File {requests_filepath} does not exist.")
        raise FileNotFoundError(f"File {requests_filepath} does not exist.")
    if save_path.suffix != ".jsonl":
        logger.error(f"Save path {save_path} must be JSONL.")
        raise ValueError(f"Save path {save_path} must be JSONL.")

    command = [
        "python",
        "database/api_request_parallel_processor.py",
        "--requests_filepath",
        requests_filepath,
        "--save_filepath",
        save_path,
        "--request_url",
        "https://api.openai.com/v1/embeddings",
        "--max_requests_per_minute",
        str(max_requests_per_minute),
        "--max_tokens_per_minute",
        str(max_tokens_per_minute),
        "--token_encoding_name",
        token_encoding_name,
        "--max_attempts",
        str(max_attempts),
        "--logging_level",
        str(logging_level),
    ]
    result = subprocess.run(command, text=True, capture_output=True)

    if result.returncode == 0:
        logger.info(f"Embeddings saved to: {save_path}")
    else:
        logger.error("Error in Embedding execution!")
        logger.error(f"Error: {result.stderr}")


# Eliminate this or make it more general
def validate_path(path: Path) -> None:
    if not isinstance(path, Path):
        logger.error(f'"{path}" must be a valid Path object')
        raise ValueError(f'"{path}" must be a valid Path object')
    path.mkdir(parents=True, exist_ok=True)


def create_embeddings(
    scraped_dir: Path, to_process_dir: Path, embeddings_dir: Path, model: str
) -> None:
    """
    Embed scraped law files by preparing the data and running the request processor
    to call the OpenAI API in parallel, creating embeddings with the specified model.
    Args:
        scraped_dir (Path): Directory to the law files.
        to_process_dir (Path): Directory to process files.
        embeddings_dir (Path): Directory for storing embeddings.
        model (str): The embedding model to be used.
    Raises:
        ValueError: If any of the provided paths are invalid.
    """
    # Validate input paths
    validate_path(scraped_dir)
    validate_path(to_process_dir)
    validate_path(embeddings_dir)

    scraped_paths = list(scraped_dir.iterdir())

    for file_path in tqdm(
        scraped_paths, desc="Embedding scraped files", total=len(scraped_paths)
    ):
        scraped_data = load_json(path=file_path)

        requests_filepath = to_process_dir / (file_path.stem + ".jsonl")
        prepare_for_embedding(
            output_path=requests_filepath,
            scraped_data=scraped_data,
            model=model,
        )

        processed_filepath = embeddings_dir / requests_filepath.name
        run_api_request_processor(
            requests_filepath=requests_filepath, save_path=processed_filepath
        )


def load_and_process_embeddings(path: Path) -> List[PointStruct]:
    """
    Load embeddings from a JSON lines file and process them into data points.
    Args:
        path (Path): The path to the JSON lines file containing embeddings.
    Returns:
        List[PointStruct]: A list of PointStruct objects containing the processed data.
    Raises:
        FileNotFoundError: If the file does not exist.
        IOError: If there is an error reading the file.
        json.JSONDecodeError: If there is an error parsing the JSON.
    """
    if not path.exists():
        logger.error(f"File: {path} does not exist.")
        raise FileNotFoundError(f"File: {path} does not exist.")

    try:
        with open(path, "r", encoding="utf-8") as file:
            embedding_data = [json.loads(line) for line in file]
    except (IOError, json.JSONDecodeError) as e:
        logger.error(f"Error reading or parsing file: {e}")
        raise

    points = []
    for item in embedding_data:
        try:
            points.append(
                PointStruct(
                    id=item[0]["id"],
                    vector=item[1]["data"][0]["embedding"],
                    payload={
                        "title": item[0]["title"],
                        "text": item[0]["input"],
                        "link": item[0]["link"],
                    },
                )
            )
        except KeyError as e:
            logger.error(f"Missing key in embedded data: {e}")
            continue
    return points
86 changes: 86 additions & 0 deletions database/vector_database.py
@@ -0,0 +1,86 @@
import argparse
import os
from pathlib import Path

import yaml
from loguru import logger
from qdrant_client import QdrantClient
from tqdm.auto import tqdm

from database.utils import (
    create_collection,
    create_embeddings,
    get_count,
    load_and_process_embeddings,
    upsert,
)


def main(args: argparse.Namespace) -> None:
    logger.info("Creating embeddings.")
    create_embeddings(
        scraped_dir=args.scraped_dir,
        to_process_dir=args.to_process_dir,
        embeddings_dir=args.embeddings_dir,
        model=args.model,
    )

    logger.info("Creating vector database.")
    qdrant_client = QdrantClient(
        url=os.environ["QDRANT_CLUSTER_URL"],
        api_key=os.environ["QDRANT_API_KEY"],
    )
    data_paths = list(args.embeddings_dir.iterdir())
    for path in tqdm(data_paths, total=len(data_paths), desc="Creating collections"):
        # Check if this is necessary
        collection_name = path.stem.replace("-", "_")
        collection_name = collection_name + "_TESTIC"
        points = load_and_process_embeddings(path=path)

        create_collection(client=qdrant_client, name=collection_name)
        upsert(client=qdrant_client, collection=collection_name, points=points)

        point_num = get_count(client=qdrant_client, collection=collection_name)
        if point_num != len(points):
            logger.error(f"There are missing points in {collection_name} collection.")

        logger.info(
            f'Created "{collection_name}" collection with {point_num} data points.'
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create embeddings and vector database for scraped files."
    )
    parser.add_argument(
        "--scraped_dir", type=Path, help="Directory to the scraped files."
    )
    parser.add_argument(
        "--to_process_dir",
        type=Path,
        default=Path("./database/to_process"),
        help="Directory to process files.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=Path,
        default=Path("./database/embeddings"),
        help="Directory for storing embeddings.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="The embedding model to be used. If not set, it will be loaded from the config file.",
    )

    args = parser.parse_args()

    # Load model from config file if not explicitly set
    if args.model is None:
        with open("config.yaml", "r") as config_file:
            config = yaml.safe_load(config_file)
        args.model = config["openai"]["embeddings"]["model"]

    main(args=args)
