feat: Implement Ollama Streaming Chat API with FastAPI #26

Closed · wants to merge 4 commits
4 changes: 2 additions & 2 deletions data_source/DriveParsers/requirements.txt
@@ -1,10 +1,10 @@
PyPDF2>=3.0.1
"unstructured"
# "unstructured"
langchain>=0.3.4
langchain_community>=0.3.3
langchain_google_community>=2.0.1
google-api-python-client
google-auth-httplib2
google-auth-oauthlib
googleapiclient
# googleapiclient
python-dotenv
2 changes: 1 addition & 1 deletion data_source/webscraper/index.py
@@ -310,7 +310,7 @@ def save_results(self):
"websiteDepth": self.depth,
"websiteMaxNumberOfPages": self.max_pages,
"lastScrapedDate": timestamp,
"filePath": output_dir
"filePath": metadata_filename
}


4 changes: 4 additions & 0 deletions db/actions/embeddings/__init__.py
@@ -0,0 +1,4 @@
from .save_embeddings import save_embeddings


__all__ = ["save_embeddings"]
18 changes: 18 additions & 0 deletions db/actions/embeddings/save_embeddings.py
@@ -0,0 +1,18 @@
from src.embeddings.service import EmbeddingService
from db.index import UserSession
from db.schema import Orgnization


async def save_embeddings(data: Orgnization, session: UserSession) -> None:
"""
Process the file to generate embeddings and associate them with the given organization metadata.

    :param data: The Orgnization record whose filePath and id are used for the embeddings
    :param session: A session instance for interacting with the database
:return: None
"""
# Step: Call the EmbeddingService to process the file and generate embeddings
embedding_service = EmbeddingService()
print(data)
await embedding_service.process_file(data.filePath, session, data.id)
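
EmbeddingService.process_file is called here but is not included in this diff. A minimal sketch of what it plausibly does, assuming naive fixed-size chunking and a hypothetical embed() helper that returns 1024-dimensional vectors (matching Vector(1024) in db/schema.py); the real service may work differently:

from db.schema import OrgDataEmbedding


class EmbeddingService:
    async def process_file(self, file_path: str, session, org_id: int) -> None:
        # Read the scraped file saved by the web crawler.
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        # Naive fixed-size chunking; the real service may split differently.
        chunks = [text[i:i + 1000] for i in range(0, len(text), 1000)]

        for chunk in chunks:
            vector = await self.embed(chunk)  # hypothetical helper returning a 1024-float list
            session.add(OrgDataEmbedding(
                metaData={"text": chunk, "source": file_path},
                embedding=vector,
                org_id=org_id,
            ))
        session.commit()
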
4 changes: 2 additions & 2 deletions db/actions/vectors/similarity_search.py
@@ -1,11 +1,11 @@
from typing import List, Any
from db.schema import VectorData
from db.schema import OrgDataEmbedding
from sqlmodel import select
from db.index import UserSession

async def get_similar_vectors(query_vector: List[float], session: UserSession, top_k: int = 5) -> List[Any]:
try:
result = session.exec(select(VectorData).order_by(VectorData.embedding.l2_distance(query_vector)).limit(top_k))
result = session.exec(select(OrgDataEmbedding).order_by(OrgDataEmbedding.embedding.l2_distance(query_vector)).limit(top_k))
rows = result.all()
formatted_rows = [
{
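
Usage sketch for the updated search (names taken from this diff): pgvector's l2_distance ordering returns the top_k OrgDataEmbedding rows nearest to the query vector, and callers such as llm/OllamaService.py read each row's metaData:

from db.actions.vectors import get_similar_vectors


async def show_nearest(session, query_vector: list[float]) -> None:
    # Fetch the five rows closest to the query embedding by L2 distance.
    rows = await get_similar_vectors(query_vector, session, top_k=5)
    for row in rows:
        print(row["metaData"])  # each formatted row carries its stored metaData
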
6 changes: 3 additions & 3 deletions db/actions/web_scrapper/list_webscraps.py
@@ -1,4 +1,4 @@
from db.schema import ScrapData
from db.schema import Orgnization
from db.index import UserSession
from typing import Annotated
from fastapi import Query
@@ -9,6 +9,6 @@ def list_webscraps(
session: UserSession,
offset: int = 0,
limit: Annotated[int, Query(le=100)] = 100,
) -> list[ScrapData]:
webscraps = session.exec(select(ScrapData).offset(offset).limit(limit)).all()
) -> list[Orgnization]:
webscraps = session.exec(select(Orgnization).offset(offset).limit(limit)).all()
return webscraps
12 changes: 6 additions & 6 deletions db/actions/web_scrapper/save_webscrap.py
@@ -1,17 +1,17 @@
from db.schema import ScrapData
from db.schema import Orgnization
from db.index import UserSession


def save_webscrap(data: dict, session: UserSession) -> ScrapData:
scrap_data = ScrapData(
def save_webscrap(data: dict, session: UserSession) -> Orgnization:
org_data = Orgnization(
websiteUrl=data['websiteUrl'],
websiteDepth=data['websiteDepth'],
websiteMaxNumberOfPages=data['websiteMaxNumberOfPages'],
lastScrapedDate=data['lastScrapedDate'],
filePath=data['filePath']
)

session.add(scrap_data)
session.add(org_data)
session.commit()
session.refresh(scrap_data)
return scrap_data
session.refresh(org_data)
return org_data
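
For reference, the dict passed to save_webscrap matches what WebCrawler.save_results() builds (see the data_source/webscraper/index.py hunk above); a usage sketch with illustrative values:

record = save_webscrap(
    {
        "websiteUrl": "https://example.com",
        "websiteDepth": 3,
        "websiteMaxNumberOfPages": 50,
        "lastScrapedDate": "2025-01-01T00:00:00",  # illustrative timestamp
        "filePath": "scraped/example_metadata.json",  # illustrative path; the crawler supplies the real one
    },
    session,
)
print(record.id)  # primary key assigned after session.refresh()
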
2 changes: 2 additions & 0 deletions db/index.py
@@ -3,6 +3,7 @@
import os
from dotenv import load_dotenv
from typing import Annotated
from llm.ChatHistory import ChatHistory
load_dotenv()

DATABASE_URL = os.getenv("DATABASE_URL")
@@ -14,6 +15,7 @@ def create_db_and_tables():

def get_session():
with Session(engine) as session:
session.chat_history = ChatHistory()
yield session

UserSession = Annotated[Session, Depends(get_session)]
7 changes: 4 additions & 3 deletions db/schema.py
@@ -10,14 +10,15 @@
from typing import Any
from pgvector.sqlalchemy import Vector

class ScrapData(SQLModel, table=True):
class Orgnization(SQLModel, table=True):
id: int = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
websiteUrl: str = Field(sa_column=Column(String(255)))
websiteDepth: int
websiteMaxNumberOfPages: int
lastScrapedDate: str
filePath: str
class VectorData(SQLModel, table=True):
class OrgDataEmbedding(SQLModel, table=True):
id: int = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
metaData: dict = Field(sa_column=Column(JSON))
embedding: Any = Field(sa_column=Column(Vector(3)))
embedding: Any = Field(sa_column=Column(Vector(1024)))
org_id: int = Field(default=None, foreign_key="orgnization.id")
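
OrgDataEmbedding stores its embedding in a pgvector Vector(1024) column, so the Postgres database needs the vector extension before create_db_and_tables() runs. A minimal setup sketch (not part of this diff), reusing the DATABASE_URL example from env.sample:

from sqlalchemy import text
from sqlmodel import SQLModel, create_engine

engine = create_engine("postgresql://user:password@localhost:5432/mydatabase")

with engine.connect() as conn:
    # pgvector must be enabled once per database before Vector columns can be created.
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    conn.commit()

SQLModel.metadata.create_all(engine)
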
25 changes: 25 additions & 0 deletions env.sample
@@ -0,0 +1,25 @@
GOOGLE_ACCOUNT_FILE=''

#WebScraping Variables
MAX_PAGES=
MIN_PAGES=
MAX_PAGES_DEFAULT=
DEPTH_MIN=
DEPTH_MAX=
DEPTH_DEFAULT=

#AppConfig
DATABASE_URL=postgresql://user:password@localhost:5432/mydatabase


#SQLParser
DB_TYPE=
DB_HOST=
DB_PORT=
DB_NAME=
DB_USER=
DB_PASSWORD=

#Embedding
VECTOR_DIM=1024
EMBEDDING_MODEL_PATH=''
9 changes: 9 additions & 0 deletions llm/ChatHistory.py
@@ -0,0 +1,9 @@
class ChatHistory:
def __init__(self):
self.history = []

def add_message(self, message):
self.history.append(message)

def get_history(self):
return self.history
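
Usage sketch: the history entries use the same role/content dicts that llm/OllamaService.py later forwards to Ollama's chat API.

history = ChatHistory()
history.add_message({"role": "user", "content": "Hello, Ollama!"})
history.add_message({"role": "system", "content": "retrieved context goes here"})
print(history.get_history())  # [{'role': 'user', ...}, {'role': 'system', ...}]
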
42 changes: 42 additions & 0 deletions llm/OllamaService.py
@@ -0,0 +1,42 @@
from ollama import AsyncClient
from typing import AsyncGenerator
from db.index import UserSession
from db.actions.vectors import get_similar_vectors
from src.embeddings.service import EmbeddingService
from .ChatHistory import ChatHistory


async def ollama_client(query: str, session: UserSession) -> AsyncGenerator[dict, None]:
if not hasattr(session, 'chat_history'):
session.chat_history = ChatHistory()

session.chat_history.add_message({'role': 'user', 'content': query})

query_vector = await EmbeddingService().get_query_vector(query)

#print("Generated Vector\n", query_vector)

similar_vectors = await get_similar_vectors(query_vector, session)

#print("Similar Vectors", similar_vectors)
for vector in similar_vectors:
metadata_str = str(vector['metaData'])
session.chat_history.add_message({'role': 'system', 'content': metadata_str})

context = session.chat_history.get_history()
messages = [{'role': msg['role'], 'content': msg['content']} for msg in context]

print("Messages", messages)

try:
client = AsyncClient()
async for chunk in await client.chat(
model='llama2',
messages=messages,
stream=True
):
yield chunk
except Exception as e:
yield {'message': {'content': f"Error: {str(e)}"}}
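
A standalone consumption sketch for ollama_client (outside FastAPI), assuming a local Ollama server with the llama2 model pulled and a session taken from db.index.get_session(); each streamed chunk is read the same way this file reads it, via message.content:

import asyncio

from db.index import get_session
from llm.OllamaService import ollama_client


async def demo() -> None:
    session = next(get_session())  # get_session() is a generator dependency
    async for chunk in ollama_client("What does this organization do?", session):
        print(chunk["message"]["content"], end="", flush=True)


# asyncio.run(demo())
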
59 changes: 56 additions & 3 deletions main.py
@@ -1,12 +1,18 @@
from fastapi import FastAPI, Depends
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from data_source.webscraper.index import WebCrawler
from db.index import create_db_and_tables, UserSession
from pydantic import BaseModel, Field, HttpUrl
from contextlib import asynccontextmanager
from db.actions.vectors import get_similar_vectors
from db.actions.web_scrapper import list_webscraps, save_webscrap
from db.actions.embeddings import save_embeddings
from llm.OllamaService import ollama_client
from llm.ChatHistory import ChatHistory
import json

from typing import List

from typing import List, AsyncGenerator
class ScrapModel(BaseModel):
base_url: HttpUrl = Field(..., example="https://example.com")
depth: int = Field(..., ge=1, le=10, example=3)
@@ -15,6 +21,9 @@ class ScrapModel(BaseModel):
class VectorQueryModel(BaseModel):
query_vector: List[float] = Field(..., example=[0.1, 0.2, 0.3])
top_k: int = Field(5, example=5)

class ChatModel(BaseModel):
message: str = Field(..., example="Hello, Ollama!")

@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -37,5 +46,49 @@ async def scrap_website(scrap_model: ScrapModel, session: UserSession):
crawler = WebCrawler(str(scrap_model.base_url), depth=scrap_model.depth, max_pages=scrap_model.max_pages)
crawler.crawl()
data = crawler.save_results()
save_webscrap(data, session)
org_data = save_webscrap(data, session)
print("data: ", data)
await save_embeddings(org_data, session)
return {"message": "Crawling completed successfully"}

@app.post("/chat")
async def chat_endpoint(request: ChatModel, session: UserSession):
async def response_stream() -> AsyncGenerator[bytes, None]:
buffer = ""
async for chunk in ollama_client(request.message, session):
if chunk and 'message' in chunk and 'content' in chunk['message']:
# Accumulate content in buffer
buffer += chunk['message']['content']
# If we have a complete word or punctuation, yield it
if buffer.endswith((' ', '.', '!', '?', '\n')):
response_json = {
"content": buffer,
"isFinished": False
}
yield f"{json.dumps(response_json)}\n".encode('utf-8')
buffer = ""

# Yield any remaining content with isFinished flag
if buffer:
response_json = {
"content": buffer,
"isFinished": True
}
yield f"{json.dumps(response_json)}\n".encode('utf-8')
else:
# Send final empty message with isFinished flag if buffer is empty
response_json = {
"content": "",
"isFinished": True
}
yield f"{json.dumps(response_json)}\n".encode('utf-8')

return StreamingResponse(
response_stream(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked"
}
)
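
Client-side sketch for the new /chat endpoint: read the NDJSON stream line by line until isFinished arrives. Assumes the API runs locally on port 8000; httpx is used only for illustration and is not a dependency added by this PR.

import asyncio
import json

import httpx


async def stream_chat(message: str) -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/chat", json={"message": message}
        ) as response:
            async for line in response.aiter_lines():
                if not line:
                    continue
                payload = json.loads(line)
                print(payload["content"], end="", flush=True)
                if payload["isFinished"]:
                    break


asyncio.run(stream_chat("Hello, Ollama!"))
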