From 9b881b6c168f63dc641057efec14a28fcfc54449 Mon Sep 17 00:00:00 2001
From: Sourabh-Bharale
Date: Fri, 22 Nov 2024 16:52:53 +0530
Subject: [PATCH 1/4] feat[ollama]: streaming chat with history context

---
 data_source/DriveParsers/requirements.txt |   4 +-
 db/index.py                               |   2 +
 llm/ChatHistory.py                        |   9 ++
 llm/OllamaService.py                      |  37 +++++
 main.py                                   |  54 ++++++-
 ollama.sh                                 | 176 ++++++++++++++++++++++
 requirements.txt                          |   5 +-
 7 files changed, 281 insertions(+), 6 deletions(-)
 create mode 100644 llm/ChatHistory.py
 create mode 100644 llm/OllamaService.py
 create mode 100755 ollama.sh

diff --git a/data_source/DriveParsers/requirements.txt b/data_source/DriveParsers/requirements.txt
index dc4541f..73ceec9 100644
--- a/data_source/DriveParsers/requirements.txt
+++ b/data_source/DriveParsers/requirements.txt
@@ -1,10 +1,10 @@
 PyPDF2>=3.0.1
-"unstructured"
+# "unstructured"
 langchain>=0.3.4
 langchain_community>=0.3.3
 langchain_google_community>=2.0.1
 google-api-python-client
 google-auth-httplib2
 google-auth-oauthlib
-googleapiclient
+# googleapiclient
 python-dotenv
\ No newline at end of file
diff --git a/db/index.py b/db/index.py
index 62ac7c6..865f3ca 100644
--- a/db/index.py
+++ b/db/index.py
@@ -3,6 +3,7 @@
 import os
 from dotenv import load_dotenv
 from typing import Annotated
+from llm.ChatHistory import ChatHistory
 
 load_dotenv()
 DATABASE_URL = os.getenv("DATABASE_URL")
@@ -14,6 +15,7 @@ def create_db_and_tables():
 
 def get_session():
     with Session(engine) as session:
+        session.chat_history = ChatHistory()
         yield session
 
 UserSession = Annotated[Session, Depends(get_session)]
diff --git a/llm/ChatHistory.py b/llm/ChatHistory.py
new file mode 100644
index 0000000..2a00484
--- /dev/null
+++ b/llm/ChatHistory.py
@@ -0,0 +1,9 @@
+class ChatHistory:
+    def __init__(self):
+        self.history = []
+
+    def add_message(self, message):
+        self.history.append(message)
+
+    def get_history(self):
+        return self.history
\ No newline at end of file
diff --git a/llm/OllamaService.py b/llm/OllamaService.py
new file mode 100644
index 0000000..bb70bb4
--- /dev/null
+++ b/llm/OllamaService.py
@@ -0,0 +1,37 @@
+from ollama import AsyncClient
+from typing import List
+from db.index import UserSession
+from db.actions.vectors import get_similar_vectors
+from .ChatHistory import ChatHistory
+from ollama import AsyncClient
+from typing import AsyncGenerator
+
+
+async def ollama_client(query: str, session: UserSession) -> AsyncGenerator[dict, None]:
+    if not hasattr(session, 'chat_history'):
+        session.chat_history = ChatHistory()
+
+    session.chat_history.add_message({'role': 'user', 'content': query})
+
+    # need to convert the query text to an embedding vector for searching
+    # when implemented, replace the line below with the actual generated vector
+    query_vector = [0.1, 0.2, 0.3]
+    similar_vectors = await get_similar_vectors(query_vector, session)
+
+    for vector in similar_vectors:
+        metadata_str = str(vector['metaData'])
+        session.chat_history.add_message({'role': 'system', 'content': metadata_str})
+
+    context = session.chat_history.get_history()
+    messages = [{'role': msg['role'], 'content': msg['content']} for msg in context]
+
+    try:
+        client = AsyncClient()
+        async for chunk in await client.chat(
+            model='llama2',
+            messages=messages,
+            stream=True
+        ):
+            yield chunk
+    except Exception as e:
+        yield {'message': {'content': f"Error: {str(e)}"}}
\ No newline at end of file
diff --git a/main.py b/main.py
index d8da937..eaea668 100644
--- a/main.py
+++ b/main.py
@@ -1,12 +1,17 @@
-from fastapi import FastAPI, Depends
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
 from data_source.webscraper.index import WebCrawler
 from db.index import create_db_and_tables, UserSession
 from pydantic import BaseModel, Field, HttpUrl
 from contextlib import asynccontextmanager
 from db.actions.vectors import get_similar_vectors
 from db.actions.web_scrapper import list_webscraps, save_webscrap
+from llm.OllamaService import ollama_client
+from llm.ChatHistory import ChatHistory
+import json
 
-from typing import List
+
+from typing import List, AsyncGenerator
 
 class ScrapModel(BaseModel):
     base_url: HttpUrl = Field(..., example="https://example.com")
     depth: int = Field(..., ge=1, le=10, example=3)
@@ -15,6 +20,9 @@ class ScrapModel(BaseModel):
 class VectorQueryModel(BaseModel):
     query_vector: List[float] = Field(..., example=[0.1, 0.2, 0.3])
     top_k: int = Field(5, example=5)
+
+class ChatModel(BaseModel):
+    message: str = Field(..., example="Hello, Ollama!")
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -39,3 +47,45 @@ async def scrap_website(scrap_model: ScrapModel, session: UserSession):
     data = crawler.save_results()
     save_webscrap(data, session)
     return {"message": "Crawling completed successfully"}
+
+@app.post("/chat")
+async def chat_endpoint(request: ChatModel, session: UserSession):
+    async def response_stream() -> AsyncGenerator[bytes, None]:
+        buffer = ""
+        async for chunk in ollama_client(request.message, session):
+            if chunk and 'message' in chunk and 'content' in chunk['message']:
+                # Accumulate content in buffer
+                buffer += chunk['message']['content']
+                # If we have a complete word or punctuation, yield it
+                if buffer.endswith((' ', '.', '!', '?', '\n')):
+                    response_json = {
+                        "content": buffer,
+                        "isFinished": False
+                    }
+                    yield f"{json.dumps(response_json)}\n".encode('utf-8')
+                    buffer = ""
+
+        # Yield any remaining content with isFinished flag
+        if buffer:
+            response_json = {
+                "content": buffer,
+                "isFinished": True
+            }
+            yield f"{json.dumps(response_json)}\n".encode('utf-8')
+        else:
+            # Send final empty message with isFinished flag if buffer is empty
+            response_json = {
+                "content": "",
+                "isFinished": True
+            }
+            yield f"{json.dumps(response_json)}\n".encode('utf-8')
+
+    return StreamingResponse(
+        response_stream(),
+        media_type="application/x-ndjson",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "Transfer-Encoding": "chunked"
+        }
+    )
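Illustrative example (not part of the patch): the /chat endpoint above streams newline-delimited JSON, one {"content": ..., "isFinished": ...} object per line. A minimal Python consumer could look like the sketch below; it assumes the app is served at http://localhost:8000 and uses the third-party httpx library, neither of which this patch pins down.

    import json
    import httpx

    def stream_chat(message: str, base_url: str = "http://localhost:8000") -> str:
        """Collect a streamed /chat reply into a single string."""
        parts = []
        with httpx.stream("POST", f"{base_url}/chat",
                          json={"message": message}, timeout=None) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line:
                    continue
                chunk = json.loads(line)  # {"content": "...", "isFinished": bool}
                parts.append(chunk["content"])
                if chunk["isFinished"]:
                    break
        return "".join(parts)

    if __name__ == "__main__":
        print(stream_chat("Hello, Ollama!"))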
diff --git a/ollama.sh b/ollama.sh
new file mode 100755
index 0000000..ea74fb6
--- /dev/null
+++ b/ollama.sh
@@ -0,0 +1,176 @@
+#!/bin/bash
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m' # No Color
+
+# Function to print colored messages
+print_message() {
+    local color=$1
+    local message=$2
+    echo -e "${color}${message}${NC}"
+}
+
+# Function to check if Docker is installed
+check_docker() {
+    if ! command -v docker &> /dev/null; then
+        print_message "$RED" "Error: Docker is not installed. Please install Docker first."
+        exit 1
+    fi
+}
+
+# Function to check if Nvidia GPU is available
+check_nvidia_gpu() {
+    if command -v nvidia-smi &> /dev/null; then
+        if nvidia-smi &> /dev/null; then
+            return 0  # GPU available
+        fi
+    fi
+    return 1  # GPU not available
+}
+
+# Function to check if Nvidia Container Toolkit is installed
+check_nvidia_toolkit() {
+    if ! docker info | grep -i "nvidia" &> /dev/null; then
+        print_message "$YELLOW" "Warning: Nvidia Container Toolkit not detected. GPU support may not work."
+        print_message "$YELLOW" "To install, visit: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html"
+        return 1
+    fi
+    return 0
+}
+
+# Function to stop and remove existing container
+cleanup_existing() {
+    if docker ps -a | grep -q "ollama"; then
+        print_message "$YELLOW" "Stopping and removing existing Ollama container..."
+        docker stop ollama &> /dev/null
+        docker rm ollama &> /dev/null
+    fi
+}
+
+# Function to start Ollama
+start_ollama() {
+    local use_gpu=$1
+
+    if [ "$use_gpu" = true ] && check_nvidia_gpu && check_nvidia_toolkit; then
+        print_message "$GREEN" "Setting up Ollama with GPU support..."
+        docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
+    else
+        if [ "$use_gpu" = true ]; then
+            print_message "$YELLOW" "GPU setup requested but not available. Falling back to CPU..."
+        else
+            print_message "$GREEN" "Setting up Ollama with CPU..."
+        fi
+        docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
+    fi
+}
+
+# Function to wait for Ollama to be ready
+wait_for_ollama() {
+    print_message "$GREEN" "Waiting for Ollama to start..."
+    local max_attempts=30
+    local attempt=1
+
+    while [ $attempt -le $max_attempts ]; do
+        if curl -s http://localhost:11434/api/tags &> /dev/null; then
+            print_message "$GREEN" "Ollama is ready!"
+            return 0
+        fi
+        sleep 1
+        attempt=$((attempt + 1))
+    done
+
+    print_message "$RED" "Timeout waiting for Ollama to start"
+    return 1
+}
+
+# Function to exit Ollama
+exit_ollama() {
+    print_message "$YELLOW" "Stopping Ollama container..."
+    docker stop ollama &> /dev/null
+}
+
+# Function to chat with Ollama
+chat_ollama() {
+    local message=$1
+    local model="llama2"
+    # /api/chat expects a "messages" array; request a non-streamed reply
+    response=$(curl -s -X POST http://localhost:11434/api/chat -d "{\"model\": \"$model\", \"messages\": [{\"role\": \"user\", \"content\": \"$message\"}], \"stream\": false}" -H "Content-Type: application/json")
+    if echo "$response" | grep -q "error"; then
+        print_message "$RED" "Failed to send message to Ollama: $response"
+    else
+        print_message "$GREEN" "Message sent to Ollama: $response"
+    fi
+}
+
+# Function to pull Ollama Docker images
+pull_ollama_images() {
+    print_message "$GREEN" "Pulling Ollama Docker images..."
+    docker pull ollama/ollama
+}
+
+# Function to delete Ollama Docker images
+delete_ollama_images() {
+    print_message "$YELLOW" "Stopping and removing existing Ollama container..."
+    docker stop ollama &> /dev/null
+    docker rm ollama &> /dev/null
+    print_message "$YELLOW" "Deleting Ollama Docker images..."
+    docker rmi ollama/ollama
+}
+
+# Function to display help message
+display_help() {
+    echo "Usage: $0 {start|exit|chat <message>|pull|delete|--help}"
+    echo
+    echo "Commands:"
+    echo "  start            Start the Ollama container"
+    echo "  exit             Stop the Ollama container"
+    echo "  chat <message>   Send a chat message to Ollama"
+    echo "  pull             Pull the Ollama Docker images"
+    echo "  delete           Delete the Ollama Docker images"
+    echo "  --help           Display this help message"
+}
+
+# Main script execution
+main() {
+    check_docker
+
+    case "$1" in
+        start)
+            cleanup_existing
+            local use_gpu=false
+            if check_nvidia_gpu; then
+                use_gpu=true
+            fi
+            start_ollama $use_gpu
+            wait_for_ollama
+            ;;
+        exit)
+            exit_ollama
+            ;;
+        chat)
+            if [ -z "$2" ]; then
+                print_message "$RED" "Error: No message provided for chat"
+                exit 1
+            fi
+            chat_ollama "$2"
+            ;;
+        pull)
+            pull_ollama_images
+            ;;
+        delete)
+            delete_ollama_images
+            ;;
+        --help)
+            display_help
+            ;;
+        *)
+            print_message "$YELLOW" "Invalid command. Use --help to see the available commands."
+            exit 1
+            ;;
+    esac
+}
+
+# Run the script
+main "$@"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index cdf4cf0..798e779 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,8 @@ requests>=2.31.0
 beautifulsoup4>=4.12.0
 html2text>=2024.2.9
 urllib3>=2.0.0
-"fastapi[standard]"
+# "fastapi[standard]"
 sqlmodel>=0.0.22
 psycopg2>=2.9.10
-pgvector
\ No newline at end of file
+pgvector
+ollama>=0.4.0
\ No newline at end of file
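Illustrative example (not part of the patch): ollama.sh talks to the containerized Ollama daemon over its HTTP API, and the `ollama` package added to requirements.txt wraps that same API. A hedged Python counterpart of `./ollama.sh chat <message>` — the host URL and model name simply mirror the script's defaults:

    from ollama import Client

    # Talk to the container started by `./ollama.sh start`.
    client = Client(host="http://localhost:11434")
    reply = client.chat(
        model="llama2",
        messages=[{"role": "user", "content": "Hello, Ollama!"}],
    )
    print(reply["message"]["content"])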
From 31f702c8846849d94a89a0067a76e3aa4972a76e Mon Sep 17 00:00:00 2001
From: Vishwajeetsingh Desurkar
Date: Tue, 26 Nov 2024 12:24:33 +0530
Subject: [PATCH 2/4] Update table names and switch the LLM model to llama3.2

---
 db/actions/vectors/similarity_search.py   | 4 ++--
 db/actions/web_scrapper/list_webscraps.py | 6 +++---
 db/actions/web_scrapper/save_webscrap.py  | 6 +++---
 db/schema.py                              | 7 ++++---
 llm/OllamaService.py                      | 2 +-
 5 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/db/actions/vectors/similarity_search.py b/db/actions/vectors/similarity_search.py
index c71117f..71c3c3a 100644
--- a/db/actions/vectors/similarity_search.py
+++ b/db/actions/vectors/similarity_search.py
@@ -1,11 +1,11 @@
 from typing import List, Any
-from db.schema import VectorData
+from db.schema import OrgDataEmbedding
 from sqlmodel import select
 from db.index import UserSession
 
 async def get_similar_vectors(query_vector: List[float], session: UserSession, top_k: int = 5) -> List[Any]:
     try:
-        result = session.exec(select(VectorData).order_by(VectorData.embedding.l2_distance(query_vector)).limit(top_k))
+        result = session.exec(select(OrgDataEmbedding).order_by(OrgDataEmbedding.embedding.l2_distance(query_vector)).limit(top_k))
         rows = result.all()
         formatted_rows = [
             {
diff --git a/db/actions/web_scrapper/list_webscraps.py b/db/actions/web_scrapper/list_webscraps.py
index 235f3c6..67942d3 100644
--- a/db/actions/web_scrapper/list_webscraps.py
+++ b/db/actions/web_scrapper/list_webscraps.py
@@ -1,4 +1,4 @@
-from db.schema import ScrapData
+from db.schema import Orgnization
 from db.index import UserSession
 from typing import Annotated
 from fastapi import Query
@@ -9,6 +9,6 @@ def list_webscraps(
     session: UserSession,
     offset: int = 0,
     limit: Annotated[int, Query(le=100)] = 100,
-) -> list[ScrapData]:
-    webscraps = session.exec(select(ScrapData).offset(offset).limit(limit)).all()
+) -> list[Orgnization]:
+    webscraps = session.exec(select(Orgnization).offset(offset).limit(limit)).all()
     return webscraps
diff --git a/db/actions/web_scrapper/save_webscrap.py b/db/actions/web_scrapper/save_webscrap.py
index fa44475..8e7bb37 100644
--- a/db/actions/web_scrapper/save_webscrap.py
+++ b/db/actions/web_scrapper/save_webscrap.py
@@ -1,9 +1,9 @@
-from db.schema import ScrapData
+from db.schema import Orgnization
 from db.index import UserSession
 
 
-def save_webscrap(data: dict, session: UserSession) -> ScrapData:
-    scrap_data = ScrapData(
+def save_webscrap(data: dict, session: UserSession) -> Orgnization:
+    scrap_data = Orgnization(
         websiteUrl=data['websiteUrl'],
         websiteDepth=data['websiteDepth'],
         websiteMaxNumberOfPages=data['websiteMaxNumberOfPages'],
diff --git a/db/schema.py b/db/schema.py
index cc53f1b..4f239a4 100644
--- a/db/schema.py
+++ b/db/schema.py
@@ -10,14 +10,15 @@
 from typing import Any
 from pgvector.sqlalchemy import Vector
 
 
-class ScrapData(SQLModel, table=True):
+class Orgnization(SQLModel, table=True):
     id: int = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
     websiteUrl: str = Field(sa_column=Column(String(255)))
     websiteDepth: int
     websiteMaxNumberOfPages: int
     lastScrapedDate: str
     filePath: str
 
-class VectorData(SQLModel, table=True):
+class OrgDataEmbedding(SQLModel, table=True):
     id: int = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
     metaData: dict = Field(sa_column=Column(JSON))
-    embedding: Any = Field(sa_column=Column(Vector(3)))
\ No newline at end of file
+    embedding: Any = Field(sa_column=Column(Vector(1024)))
+    org_id: int = Field(default=None, foreign_key="orgnization.id")
\ No newline at end of file
diff --git a/llm/OllamaService.py b/llm/OllamaService.py
index bb70bb4..e82652b 100644
--- a/llm/OllamaService.py
+++ b/llm/OllamaService.py
@@ -28,7 +28,7 @@ async def ollama_client(query: str, session: UserSession) -> AsyncGenerator[dict
     try:
         client = AsyncClient()
         async for chunk in await client.chat(
-            model='llama2',
+            model='llama3.2',
             messages=messages,
             stream=True
        ):
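Worth noting for reviewers: `OrgDataEmbedding.embedding.l2_distance(...)` in similarity_search.py compiles to pgvector's `<->` (L2 distance) operator, so the nearest-neighbour ordering runs inside Postgres. A rough sketch of the equivalent raw query — the table name (SQLModel's default lowercased class name) and the quoting of "metaData" are assumptions on my part, not something the patch states:

    # Hedged raw-SQL equivalent of the ORM call in similarity_search.py.
    from sqlalchemy import text
    from sqlmodel import Session, create_engine

    engine = create_engine("postgresql://user:password@localhost:5432/mydatabase")  # placeholder DSN

    def similar_rows(query_vector: list[float], top_k: int = 5):
        # pgvector accepts a '[x,y,z]' literal cast to the vector type.
        vec_literal = "[" + ",".join(map(str, query_vector)) + "]"
        stmt = text(
            'SELECT id, "metaData" FROM orgdataembedding '
            "ORDER BY embedding <-> CAST(:q AS vector) LIMIT :k"
        )
        with Session(engine) as session:
            return session.execute(stmt, {"q": vec_literal, "k": top_k}).all()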
From cc6030689df7d52009edfc39a1e3fd130902e1f8 Mon Sep 17 00:00:00 2001
From: Vishwajeetsingh Desurkar
Date: Tue, 26 Nov 2024 12:30:12 +0530
Subject: [PATCH 3/4] Add embedding creation service, along with db actions to
 call it

---
 db/actions/embeddings/__init__.py        |   4 +
 db/actions/embeddings/save_embeddings.py |  16 +++
 env.sample                               |  25 ++++
 requirements.txt                         |   7 +-
 src/embeddings/createEmbeddings.py       | 150 +++++++++++++++++++++
 src/embeddings/sentenceSegmentation.py   | 127 +++++++++++++++++++
 src/embeddings/service.py                |  30 +++++
 7 files changed, 357 insertions(+), 2 deletions(-)
 create mode 100644 db/actions/embeddings/__init__.py
 create mode 100644 db/actions/embeddings/save_embeddings.py
 create mode 100644 env.sample
 create mode 100644 src/embeddings/createEmbeddings.py
 create mode 100644 src/embeddings/sentenceSegmentation.py
 create mode 100644 src/embeddings/service.py

diff --git a/db/actions/embeddings/__init__.py b/db/actions/embeddings/__init__.py
new file mode 100644
index 0000000..9018add
--- /dev/null
+++ b/db/actions/embeddings/__init__.py
@@ -0,0 +1,4 @@
+from .save_embeddings import save_embeddings
+
+
+__all__ = ["save_embeddings"]
\ No newline at end of file
diff --git a/db/actions/embeddings/save_embeddings.py b/db/actions/embeddings/save_embeddings.py
new file mode 100644
index 0000000..2e524e2
--- /dev/null
+++ b/db/actions/embeddings/save_embeddings.py
@@ -0,0 +1,16 @@
+from src.embeddings.service import EmbeddingService
+from db.index import UserSession
+
+
+def save_embeddings(data: dict, session: UserSession, org_meta: dict) -> None:
+    """
+    Process the file to generate embeddings and associate them with the given organization metadata.
+
+    :param data: A dictionary containing data like file path
+    :param session: A session instance for interacting with the database
+    :param org_meta: Organization metadata to associate with the embeddings
+    :return: None
+    """
+    # Step: Call the EmbeddingService to process the file and generate embeddings
+    embedding_service = EmbeddingService()
+    embedding_service.process_file(data['filePath'], session, org_meta)
\ No newline at end of file
diff --git a/env.sample b/env.sample
new file mode 100644
index 0000000..a22545d
--- /dev/null
+++ b/env.sample
@@ -0,0 +1,25 @@
+GOOGLE_ACCOUNT_FILE=''
+
+# Web scraping variables
+MAX_PAGES=
+MIN_PAGES=
+MAX_PAGES_DEFAULT=
+DEPTH_MIN=
+DEPTH_MAX=
+DEPTH_DEFAULT=
+
+# App config
+DATABASE_URL=postgresql://user:password@localhost:5432/mydatabase
+
+
+# SQL parser
+DB_TYPE=
+DB_HOST=
+DB_PORT=
+DB_NAME=
+DB_USER=
+DB_PASSWORD=
+
+# Embedding
+VECTOR_DIM=1024
+EMBEDDING_MODEL_PATH=''
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 798e779..34c3756 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,11 @@ requests>=2.31.0
 beautifulsoup4>=4.12.0
 html2text>=2024.2.9
 urllib3>=2.0.0
-# "fastapi[standard]"
+fastapi[standard]
 sqlmodel>=0.0.22
 psycopg2>=2.9.10
 pgvector
-ollama>=0.4.0
\ No newline at end of file
+ollama>=0.4.0
+spacy>=3.8.2
+transformers>=4.46.3
+numpy>=2.0
\ No newline at end of file
diff --git a/src/embeddings/createEmbeddings.py b/src/embeddings/createEmbeddings.py
new file mode 100644
index 0000000..9f4af66
--- /dev/null
+++ b/src/embeddings/createEmbeddings.py
@@ -0,0 +1,150 @@
+import os
+import torch
+from transformers import AutoModel, AutoTokenizer
+from sklearn.preprocessing import normalize
+from sqlmodel import SQLModel, Field, Column, JSON, Integer, Float
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.dialects.postgresql import ARRAY
+from db.schema import OrgDataEmbedding
+from db.index import UserSession
+import logging
+
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+VECTOR_DIM = int(os.getenv("VECTOR_DIM", 1024))
+
+# Your model directory path
+model_dir = os.getenv("EMBEDDING_MODEL_PATH")
+
+# Initialize model and tokenizer
+model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).eval()
+
+# Explicitly move model to CPU (this will work on Mac since CUDA is not supported)
+model.to(torch.device("cpu"))
+
+tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+
+# Linear layer for transforming the embeddings to the desired vector dimension
+vector_linear = torch.nn.Linear(in_features=model.config.hidden_size, out_features=VECTOR_DIM)
+vector_linear.load_state_dict({
+    k.replace("linear.", ""): v for k, v in
+    torch.load(os.path.join(model_dir, f"2_Dense_{VECTOR_DIM}/pytorch_model.bin"), map_location=torch.device("cpu")).items()
+})
+
+# Function to get embeddings from transformer model
+async def get_embeddings_transformer(sentences: list) -> list[list[float]]:
+    """
+    Generate embeddings for the given sentences using a transformer model.
+    Args:
+        sentences (list[str]): List of sentences to generate embeddings for.
+    Returns:
+        list[list[float]]: List of embeddings for each sentence.
+    """
+ """ + embeddings = [] + + # Flatten the list if input is a list of lists + if isinstance(sentences[0], list): + sentences = [item for sublist in sentences for item in sublist] + for sentence in sentences: + try: + if not isinstance(sentence, str): + logger.warning(f"Skipping non-string input: {sentence}") + continue + + # Log the sentence being processed + logger.info(f"Processing sentence: {sentence[:50]}...") + + # Truncate overly long sentences + sentence = sentence[:512] + + # Tokenize the sentence + input_data = tokenizer( + sentence, + padding="longest", + truncation=True, + max_length=512, + return_tensors="pt" + ) + input_data = {k: v.to(torch.device("cpu")) for k, v in input_data.items()} + + # Get model's last hidden state + attention_mask = input_data["attention_mask"] + last_hidden_state = model(**input_data)[0] + + # Compute the embeddings by averaging the hidden states + last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) + sentence_embedding = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + # Detach, convert to NumPy, and normalize + detached_embedding = vector_linear(sentence_embedding).detach().cpu().numpy() + normalized_embedding = normalize(detached_embedding, axis=1) + + embeddings.append(normalized_embedding[0].tolist()) + logger.info(f"Successfully generated embedding for: {sentence[:50]}...") + + except Exception as e: + logger.error(f"Error while embedding sentence: {sentence}\n{e}") + continue + + return embeddings + +# Function to process and store embeddings +async def process_sentences_and_store( + sentences: list[str], org_meta: dict, session: UserSession +) -> int: + + if not sentences: + raise ValueError("No sentences provided for embedding.") + + if not org_meta or 'id' not in org_meta: + raise ValueError("Invalid organization metadata. 
+
+    successful_count = 0  # Initialize the count of successfully stored embeddings
+    print("Sentence Count1:", len(sentences))
+
+    try:
+        # Generate embeddings using the transformer model
+        embeddings = await get_embeddings_transformer(sentences)
+        print("Embedding Count:", len(embeddings))
+        if not embeddings:
+            logger.error("No valid embeddings were generated.")
+            return successful_count
+        logger.info("Embeddings successfully generated.")
+
+        if not session.is_active:
+            session.begin()
+
+        # Store embeddings and metadata in the database
+        for idx, embedding in enumerate(embeddings):
+            try:
+                # Get the corresponding sentence based on the current index
+                current_sentence = sentences[idx] if isinstance(sentences[idx], str) else str(sentences[idx])
+
+                embedding_entry = OrgDataEmbedding(
+                    metaData={"sentence": current_sentence},
+                    embedding=embedding,
+                    org_id=org_meta["id"],
+                )
+
+                session.add(embedding_entry)
+                successful_count += 1
+                logger.info(f"Successfully stored embedding for sentence at index {idx}.")
+
+            except Exception as e:
+                logger.error(f"Error storing embedding for sentence at index {idx}: {str(e)}")
+                continue
+
+        if successful_count > 0:
+            session.commit()
+            logger.info(f"Successfully stored {successful_count} embeddings.")
+        else:
+            logger.error("No embeddings were stored in the database.")
+
+    except Exception as e:
+        logger.error(f"Error processing sentences: {str(e)}")
+        session.rollback()
+        raise
+
+    return successful_count
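A note on the pooling above: padding positions are zeroed through the attention mask before averaging, so only real tokens contribute to the sentence embedding; the pooled vector is then projected to VECTOR_DIM and L2-normalized. A self-contained toy illustration of just that pooling arithmetic (shapes invented for the example, no model download needed):

    import torch

    # batch=1, seq_len=4, hidden=3; the last two positions are padding
    last_hidden_state = torch.tensor([[[1., 1., 1.],
                                       [3., 3., 3.],
                                       [9., 9., 9.],
                                       [7., 7., 7.]]])
    attention_mask = torch.tensor([[1, 1, 0, 0]])

    # Zero the padded rows, then average over the real tokens only.
    masked = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    pooled = masked.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    print(pooled)  # tensor([[2., 2., 2.]]) -- mean of the two unmasked rows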
+ """ + # Remove URLs, emails, and mentions + sentence = re.sub(r'http\S+|www\S+|https\S+', '', sentence) + sentence = re.sub(r'\S+@\S+', '', sentence) + sentence = re.sub(r'@[\w]+', '', sentence) # Remove mentions (e.g., @user) + + # Apply SpaCy processing + doc = self.nlp(sentence) + + # Tokenize and clean text by removing stopwords, punctuation, and lemmatization + cleaned_tokens = [ + token.lemma_ for token in doc + if not token.is_stop and not token.is_punct and token.is_alpha + ] + + return ' '.join(cleaned_tokens) + + def segment_text(self, text: str) -> List[str]: + """ + Segment the given text into sentences using SpaCy's NLP model. + + Args: + text (str): The input text to segment. + + Returns: + List[str]: A list of sentences extracted from the text. + """ + doc = self.nlp(text) + return [sent.text.strip() for sent in doc.sents] + + def process_text(self, text: str) -> List[str]: + """ + Segment the given text into sentences and clean each sentence. + + Args: + text (str): The input text to process. + + Returns: + List[str]: A list of cleaned sentences. + """ + # Segment the text into sentences + sentences = self.segment_text(text) + print("Sentence Count:", len(sentences)) + # Clean each sentence + processed_sentences = [self.clean_sentence(sentence) for sentence in sentences] + print("Sentence Count:", len(processed_sentences)) + return processed_sentences + + def segment_content(self, file_path: str) -> List[List[str]]: + """ + Read a JSON file, extract the "content" attribute from each JSON object, + segment it into sentences, and clean each sentence. + + Args: + file_path (str): The path to the JSON file. + + Returns: + List[List[str]]: A list where each element is a list of cleaned sentences + from the "content" attribute of the corresponding JSON object. + """ + content = self.read_file(file_path) + + processed_content = [] + for obj in content: + if "content" not in obj: + raise KeyError(f"Key 'content' not found in one of the JSON objects.") + text = obj["content"] + if not isinstance(text, str): + raise ValueError(f"Expected a string for key 'content', got {type(text).__name__}.") + + # Segment and clean the "content" value + processed_content.append(self.process_text(text)) + + + sentences = [item for sublist in processed_content for item in sublist] + + return sentences diff --git a/src/embeddings/service.py b/src/embeddings/service.py new file mode 100644 index 0000000..363e252 --- /dev/null +++ b/src/embeddings/service.py @@ -0,0 +1,30 @@ +from src.embeddings.sentenceSegmentation import SentenceSegmentationService +from src.embeddings.createEmbeddings import process_sentences_and_store + +class EmbeddingService: + def __init__(self): + # Initialize the segmentation and embedding services + self.segmentation_service = SentenceSegmentationService() + + async def process_file(self, file_path, session, org_meta): + """ + Process the given file, segment it into sentences, and store embeddings. 
diff --git a/src/embeddings/service.py b/src/embeddings/service.py
new file mode 100644
index 0000000..363e252
--- /dev/null
+++ b/src/embeddings/service.py
@@ -0,0 +1,30 @@
+from src.embeddings.sentenceSegmentation import SentenceSegmentationService
+from src.embeddings.createEmbeddings import process_sentences_and_store
+
+class EmbeddingService:
+    def __init__(self):
+        # Initialize the segmentation and embedding services
+        self.segmentation_service = SentenceSegmentationService()
+
+    async def process_file(self, file_path, session, org_meta):
+        """
+        Process the given file, segment it into sentences, and store embeddings.
+
+        :param file_path: Path to the file containing text to segment and process
+        :param session: Database or session instance used for storing the embeddings
+        :param org_meta: Organization metadata (e.g., {'id': 123})
+        :return: bool - True for success, False for failure
+        """
+        try:
+            # Segment the file content into sentences
+            sentences = self.segmentation_service.segment_content(file_path)
+
+            # Process the sentences and store embeddings in the database
+            await process_sentences_and_store(sentences, org_meta, session)
+
+            # If processing is successful, return True
+            return True
+        except Exception as e:
+            # If an error occurs, log and return False
+            print(f"Error processing file: {str(e)}")
+            return False

From 5f5b7df316dabaf54d3d37ded36f79d64104d976 Mon Sep 17 00:00:00 2001
From: Vishwajeetsingh Desurkar
Date: Tue, 26 Nov 2024 14:29:06 +0530
Subject: [PATCH 4/4] Update embedding model for the query and web-scraped
 data

---
 data_source/webscraper/index.py          |  2 +-
 db/actions/embeddings/save_embeddings.py |  6 ++++--
 db/actions/web_scrapper/save_webscrap.py |  8 ++++----
 llm/OllamaService.py                     | 13 +++++++++----
 main.py                                  |  5 ++++-
 src/embeddings/createEmbeddings.py       |  7 ++-----
 src/embeddings/service.py                | 13 +++++++++----
 7 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/data_source/webscraper/index.py b/data_source/webscraper/index.py
index 786c794..0865ef9 100644
--- a/data_source/webscraper/index.py
+++ b/data_source/webscraper/index.py
@@ -310,7 +310,7 @@ def save_results(self):
             "websiteDepth": self.depth,
             "websiteMaxNumberOfPages": self.max_pages,
             "lastScrapedDate": timestamp,
-            "filePath": output_dir
+            "filePath": metadata_filename
         }
 
 
diff --git a/db/actions/embeddings/save_embeddings.py b/db/actions/embeddings/save_embeddings.py
index 2e524e2..0bcce68 100644
--- a/db/actions/embeddings/save_embeddings.py
+++ b/db/actions/embeddings/save_embeddings.py
@@ -1,8 +1,9 @@
 from src.embeddings.service import EmbeddingService
 from db.index import UserSession
+from db.schema import Orgnization
 
 
-def save_embeddings(data: dict, session: UserSession, org_meta: dict) -> None:
+async def save_embeddings(data: Orgnization, session: UserSession) -> None:
     """
     Process the file to generate embeddings and associate them with the given organization metadata.
@@ -13,4 +14,5 @@ def save_embeddings(data: dict, session: UserSession, org_meta: dict) -> None:
     """
     # Step: Call the EmbeddingService to process the file and generate embeddings
     embedding_service = EmbeddingService()
-    embedding_service.process_file(data['filePath'], session, org_meta)
\ No newline at end of file
+    print(data)
+    await embedding_service.process_file(data.filePath, session, data.id)
\ No newline at end of file
diff --git a/db/actions/web_scrapper/save_webscrap.py b/db/actions/web_scrapper/save_webscrap.py
index 8e7bb37..3f516fc 100644
--- a/db/actions/web_scrapper/save_webscrap.py
+++ b/db/actions/web_scrapper/save_webscrap.py
@@ -3,7 +3,7 @@
 
 
 def save_webscrap(data: dict, session: UserSession) -> Orgnization:
-    scrap_data = Orgnization(
+    org_data = Orgnization(
         websiteUrl=data['websiteUrl'],
         websiteDepth=data['websiteDepth'],
         websiteMaxNumberOfPages=data['websiteMaxNumberOfPages'],
@@ -11,7 +11,7 @@ def save_webscrap(data: dict, session: UserSession) -> Orgnization:
         filePath=data['filePath']
     )
 
-    session.add(scrap_data)
+    session.add(org_data)
     session.commit()
-    session.refresh(scrap_data)
-    return scrap_data
+    session.refresh(org_data)
+    return org_data
diff --git a/llm/OllamaService.py b/llm/OllamaService.py
index e82652b..0a4c183 100644
--- a/llm/OllamaService.py
+++ b/llm/OllamaService.py
@@ -5,6 +5,7 @@
 from .ChatHistory import ChatHistory
 from ollama import AsyncClient
 from typing import AsyncGenerator
+from src.embeddings.service import EmbeddingService
 
 
 async def ollama_client(query: str, session: UserSession) -> AsyncGenerator[dict, None]:
@@ -13,22 +14,26 @@ async def ollama_client(query: str, session: UserSession) -> AsyncGenerator[dict
 
     session.chat_history.add_message({'role': 'user', 'content': query})
 
-    # need to convert the query text to an embedding vector for searching
-    # when implemented, replace the line below with the actual generated vector
-    query_vector = [0.1, 0.2, 0.3]
+    query_vector = await EmbeddingService().get_query_vector(query)
+
+    #print("Generated Vector\n", query_vector)
+
     similar_vectors = await get_similar_vectors(query_vector, session)
+    #print("Similar Vectors", similar_vectors)
 
     for vector in similar_vectors:
         metadata_str = str(vector['metaData'])
         session.chat_history.add_message({'role': 'system', 'content': metadata_str})
 
     context = session.chat_history.get_history()
     messages = [{'role': msg['role'], 'content': msg['content']} for msg in context]
+
+    print("Messages", messages)
 
     try:
         client = AsyncClient()
         async for chunk in await client.chat(
-            model='llama3.2',
+            model='llama2',
             messages=messages,
             stream=True
         ):
diff --git a/main.py b/main.py
index eaea668..a1069cf 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@
 from contextlib import asynccontextmanager
 from db.actions.vectors import get_similar_vectors
 from db.actions.web_scrapper import list_webscraps, save_webscrap
+from db.actions.embeddings import save_embeddings
 from llm.OllamaService import ollama_client
 from llm.ChatHistory import ChatHistory
 import json
@@ -45,7 +46,9 @@ async def scrap_website(scrap_model: ScrapModel, session: UserSession):
     crawler = WebCrawler(str(scrap_model.base_url), depth=scrap_model.depth, max_pages=scrap_model.max_pages)
     crawler.crawl()
     data = crawler.save_results()
-    save_webscrap(data, session)
+    org_data = save_webscrap(data, session)
+    print("data: ", data)
+    await save_embeddings(org_data, session)
     return {"message": "Crawling completed successfully"}
 
 @app.post("/chat")
diff --git a/src/embeddings/createEmbeddings.py b/src/embeddings/createEmbeddings.py
index 9f4af66..5454ba1 100644
--- a/src/embeddings/createEmbeddings.py
+++ b/src/embeddings/createEmbeddings.py
@@ -92,14 +92,11 @@ async def get_embeddings_transformer(sentences: list) -> list[list[float]]:
 
 # Function to process and store embeddings
 async def process_sentences_and_store(
-    sentences: list[str], org_meta: dict, session: UserSession
+    sentences: list[str], org_id: int, session: UserSession
 ) -> int:
 
     if not sentences:
         raise ValueError("No sentences provided for embedding.")
-
-    if not org_meta or 'id' not in org_meta:
-        raise ValueError("Invalid organization metadata. Must include 'id'.")
 
     successful_count = 0  # Initialize the count of successfully stored embeddings
     print("Sentence Count1:", len(sentences))
@@ -125,7 +122,7 @@ async def process_sentences_and_store(
                 embedding_entry = OrgDataEmbedding(
                     metaData={"sentence": current_sentence},
                     embedding=embedding,
-                    org_id=org_meta["id"],
+                    org_id=org_id
                 )
 
                 session.add(embedding_entry)
diff --git a/src/embeddings/service.py b/src/embeddings/service.py
index 363e252..136b753 100644
--- a/src/embeddings/service.py
+++ b/src/embeddings/service.py
@@ -1,12 +1,12 @@
 from src.embeddings.sentenceSegmentation import SentenceSegmentationService
-from src.embeddings.createEmbeddings import process_sentences_and_store
+from src.embeddings.createEmbeddings import process_sentences_and_store, get_embeddings_transformer
 
 class EmbeddingService:
     def __init__(self):
         # Initialize the segmentation and embedding services
         self.segmentation_service = SentenceSegmentationService()
 
-    async def process_file(self, file_path, session, org_meta):
+    async def process_file(self, file_path, session, org_id):
         """
         Process the given file, segment it into sentences, and store embeddings.
@@ -18,9 +18,9 @@ async def process_file(self, file_path, session, org_meta):
         try:
             # Segment the file content into sentences
             sentences = self.segmentation_service.segment_content(file_path)
-
+
             # Process the sentences and store embeddings in the database
-            await process_sentences_and_store(sentences, org_meta, session)
+            await process_sentences_and_store(sentences, org_id, session)
 
             # If processing is successful, return True
             return True
@@ -28,3 +28,7 @@ async def process_file(self, file_path, session, org_meta):
         # If an error occurs, log and return False
         print(f"Error processing file: {str(e)}")
         return False
+
+    async def get_query_vector(self, sentence: str) -> list[float]:
+        embeddings = await get_embeddings_transformer([sentence])
+        return embeddings[0]
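Taken together, patch 4 closes the retrieval loop: a query is embedded with the same model that indexed the scraped pages, matched against OrgDataEmbedding rows by L2 distance, and the retrieved sentences are injected as system messages before the Ollama call. A hedged end-to-end smoke-test sketch — it assumes a running Postgres with pgvector, DATABASE_URL and EMBEDDING_MODEL_PATH set as in env.sample, and an Ollama container started via ollama.sh:

    import asyncio

    from db.actions.vectors import get_similar_vectors
    from db.index import get_session
    from llm.OllamaService import ollama_client
    from src.embeddings.service import EmbeddingService

    async def main() -> None:
        session = next(get_session())  # fresh session with chat_history attached

        # 1. Embed the query with the indexing model; dim must match Vector(1024).
        vector = await EmbeddingService().get_query_vector("What does the site say about pricing?")
        assert len(vector) == 1024

        # 2. Peek at the nearest stored sentences.
        for row in await get_similar_vectors(vector, session, top_k=3):
            print(row['metaData'])

        # 3. Stream a grounded answer from Ollama.
        async for chunk in ollama_client("What does the site say about pricing?", session):
            print(chunk['message']['content'], end="", flush=True)

    asyncio.run(main())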