diff --git a/README.md b/README.md index 81c6d18..93e502e 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ ChatGPT clone in htmx, Python, and SQLite * Python * ChatGPT API * SQLite + # Goal * Create a ChatGPT like chat interface using HTMX and Tailwind * Use Python to interact with the ChatGPT API @@ -21,6 +22,7 @@ ChatGPT clone in htmx, Python, and SQLite - [x] Add loading display when waiting for response from API - [ ] Add ability to use multiple LLMs - [ ] Dockerize the application +- [ ] File uploads for conversations # Possible Future Goals * User accounts & authentication diff --git a/ai/split.py b/ai/split.py new file mode 100644 index 0000000..476270e --- /dev/null +++ b/ai/split.py @@ -0,0 +1,17 @@ +from ai.tokens import tkm + +def split_file_by_tokens(content: str, token_limit: int) -> list[str]: + '''Split a file into multiple parts, each within the token limit.''' + current_part = '' + current_token_count = 0 + parts = [] + for line in content.split('\n'): + if current_token_count + len(tkm.encode(line)) > token_limit: + parts.append(current_part) + current_part = '' + current_token_count = 0 + current_part += line + '\n' + current_token_count += len(tkm.encode(line)) + 1 + if current_part != '': + parts.append(current_part) + return parts diff --git a/app.py b/app.py index 03d08e9..98811d4 100644 --- a/app.py +++ b/app.py @@ -2,22 +2,24 @@ from typing import Annotated from dotenv import load_dotenv -from fastapi import FastAPI, Request, Form +from fastapi import FastAPI, Request, Form, File from fastapi.responses import RedirectResponse, Response from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles from sqlmodel import Session, select import openai -from db import new_engine, Message, Conversation +from ai.split import split_file_by_tokens +from db import new_engine, Message, Conversation, FileUpload, PartialFile load_dotenv() MODEL = 'gpt-3.5-turbo' +MAX_TOKEN_CHUNK = 512 # this should be changed depending on model openai.api_key = os.getenv('OPENAI_API_KEY') -engine = new_engine(os.getenv('DB_URI', 'sqlite:///htmx.gpt')) +engine = new_engine(os.getenv('DB_URI', 'sqlite:///htmxgpt.db')) app = FastAPI() app.mount("/static", StaticFiles(directory="static"), name="static") @@ -142,6 +144,38 @@ async def send_form(request: Request, convo_id: int, prompt: Annotated[str, Form # as HTML return templates.TemplateResponse("form_response.html", {"request": request, "messages": [new_message.to_html()]}) +@app.post('/conversation/{convo_id}/upload') +async def upload_file(convo_id: int, file: Annotated[bytes, File()]): + # if the file is not utf-8 encoded, return an error + try: + file = file.decode('utf-8') + except UnicodeDecodeError: + return Response(status_code=400, content="File must be UTF-8 encoded", media_type="text/plain") + + # split the file into chunks + file_parts = split_file_by_tokens(file, MAX_TOKEN_CHUNK) + with Session(engine) as session: + file_upload = FileUpload(filename="test", conversation_id=convo_id) + session.add(file_upload) + session.commit() + session.refresh(file_upload) + for part in file_parts: + embedding_resp = openai.Embedding.create( + model="text-embedding-ada-002", + input=part + ) + data = embedding_resp['data'][0] + embedding = data['embedding'] + partial_file = PartialFile(content=part, file_upload_id=file_upload.id, embeddings=embedding) + session.add(partial_file) + session.commit() + session.refresh(file_upload) + return { + 'file_upload_id': file_upload.id, + 'filename': file_upload.filename, + 'parts': file_upload.parts + } + @app.post('/message') async def create_message(message: Message): @@ -158,10 +192,12 @@ async def get_message(convo_id: int): message = session.get(Message, convo_id) return message + @app.get("/favicon.ico") async def favicon(): return RedirectResponse(url='/static/favicon.ico', status_code=302) + @app.get("/site.webmanifest") async def manifest(): return RedirectResponse(url='/static/site.webmanifest', status_code=302) diff --git a/db/__init__.py b/db/__init__.py index c6a1843..b909013 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -1,6 +1,6 @@ from sqlmodel import create_engine, SQLModel -from db.models import Message, Conversation +from db.models import Message, Conversation, FileUpload, PartialFile def new_engine(uri: str = "sqlite:///htmxgpt.db"): diff --git a/db/models.py b/db/models.py index ed3dfb6..67c3b85 100644 --- a/db/models.py +++ b/db/models.py @@ -2,7 +2,7 @@ from datetime import datetime from markdown import markdown -from sqlmodel import SQLModel, Field, Relationship +from sqlmodel import SQLModel, Field, Relationship, Column, JSON from ai import MODEL_TOKEN_LIMIT from ai.tokens import token_count_single_message @@ -34,7 +34,7 @@ def to_html(self): 'role': self.role, # TODO fix later # this was written in 2023, comment what year it was when you found this - 'content': escape_lt_gt_inside_code_tags(markdown(replace_code_blocks(self.content)).replace('
', '