Skip to content

Commit

Permalink
Start file uploads
Browse files Browse the repository at this point in the history
  • Loading branch information
chand1012 committed Sep 15, 2023
1 parent b612221 commit bc6eb8b
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ChatGPT clone in htmx, Python, and SQLite
* Python
* ChatGPT API
* SQLite

# Goal
* Create a ChatGPT like chat interface using HTMX and Tailwind
* Use Python to interact with the ChatGPT API
Expand All @@ -21,6 +22,7 @@ ChatGPT clone in htmx, Python, and SQLite
- [x] Add loading display when waiting for response from API
- [ ] Add ability to use multiple LLMs
- [ ] Dockerize the application
- [ ] File uploads for conversations

# Possible Future Goals
* User accounts & authentication
Expand Down
17 changes: 17 additions & 0 deletions ai/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from ai.tokens import tkm

def split_file_by_tokens(content: str, token_limit: int, encode=None) -> list[str]:
    '''Split *content* into newline-preserving parts, each within *token_limit* tokens.

    Lines are never broken: each part is a run of whole lines (with their
    trailing newlines) whose combined token count stays at or under
    ``token_limit``. A single line that alone exceeds the limit becomes its
    own oversized part rather than being truncated.

    Args:
        content: The full text to split.
        token_limit: Maximum tokens per part (best effort; see above).
        encode: Optional tokenizer callable ``str -> sequence`` used to count
            tokens. Defaults to ``tkm.encode`` for backward compatibility.

    Returns:
        List of string parts; concatenating them reproduces ``content`` with
        a trailing newline appended.
    '''
    if encode is None:
        encode = tkm.encode
    parts: list[str] = []
    current_part = ''
    current_token_count = 0
    for line in content.split('\n'):
        # +1 accounts for the newline appended below. The original check
        # omitted it (while the accumulator counted it), letting parts
        # creep past the limit. Encode once per line instead of twice.
        line_tokens = len(encode(line)) + 1
        # Flush only a non-empty part: previously an oversized first line
        # caused an empty string '' to be appended to parts.
        if current_part and current_token_count + line_tokens > token_limit:
            parts.append(current_part)
            current_part = ''
            current_token_count = 0
        current_part += line + '\n'
        current_token_count += line_tokens
    if current_part:
        parts.append(current_part)
    return parts
42 changes: 39 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@
from typing import Annotated

from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form
from fastapi import FastAPI, Request, Form, File
from fastapi.responses import RedirectResponse, Response
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from sqlmodel import Session, select
import openai

from db import new_engine, Message, Conversation
from ai.split import split_file_by_tokens
from db import new_engine, Message, Conversation, FileUpload, PartialFile

load_dotenv()

MODEL = 'gpt-3.5-turbo'
MAX_TOKEN_CHUNK = 512 # this should be changed depending on model

openai.api_key = os.getenv('OPENAI_API_KEY')

engine = new_engine(os.getenv('DB_URI', 'sqlite:///htmx.gpt'))
engine = new_engine(os.getenv('DB_URI', 'sqlite:///htmxgpt.db'))

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
Expand Down Expand Up @@ -142,6 +144,38 @@ async def send_form(request: Request, convo_id: int, prompt: Annotated[str, Form
# as HTML
return templates.TemplateResponse("form_response.html", {"request": request, "messages": [new_message.to_html()]})

@app.post('/conversation/{convo_id}/upload')
async def upload_file(convo_id: int, file: Annotated[bytes, File()]):
    """Attach an uploaded text file to a conversation.

    Decodes the upload as UTF-8, splits it into token-bounded chunks,
    embeds each chunk via the OpenAI embeddings API, and persists a
    FileUpload row plus one PartialFile row per chunk.

    Args:
        convo_id: ID of the conversation the file belongs to.
        file: Raw uploaded bytes (must be valid UTF-8 text).

    Returns:
        A 400 plain-text Response if the bytes are not UTF-8; otherwise a
        dict with the new upload's id, filename, and its parts.
    """
    # if the file is not utf-8 encoded, return an error
    try:
        file = file.decode('utf-8')
    except UnicodeDecodeError:
        return Response(status_code=400, content="File must be UTF-8 encoded", media_type="text/plain")

    # split the file into chunks
    file_parts = split_file_by_tokens(file, MAX_TOKEN_CHUNK)
    with Session(engine) as session:
        # NOTE(review): filename is a hard-coded placeholder — presumably it
        # should come from the upload metadata (e.g. UploadFile.filename);
        # confirm and fix.
        file_upload = FileUpload(filename="test", conversation_id=convo_id)
        session.add(file_upload)
        # Commit + refresh now so file_upload.id exists for the parts below.
        session.commit()
        session.refresh(file_upload)
        for part in file_parts:
            # One synchronous embeddings call per chunk; blocks the event
            # loop for large files — consider batching.
            embedding_resp = openai.Embedding.create(
                model="text-embedding-ada-002",
                input=part
            )
            data = embedding_resp['data'][0]
            embedding = data['embedding']
            partial_file = PartialFile(content=part, file_upload_id=file_upload.id, embeddings=embedding)
            session.add(partial_file)
        session.commit()
        # Refresh so the lazy `parts` relationship reflects the new rows
        # before it is serialized in the response below.
        session.refresh(file_upload)
        return {
            'file_upload_id': file_upload.id,
            'filename': file_upload.filename,
            'parts': file_upload.parts
        }


@app.post('/message')
async def create_message(message: Message):
Expand All @@ -158,10 +192,12 @@ async def get_message(convo_id: int):
message = session.get(Message, convo_id)
return message


@app.get("/favicon.ico")
async def favicon():
return RedirectResponse(url='/static/favicon.ico', status_code=302)


@app.get("/site.webmanifest")
async def manifest():
return RedirectResponse(url='/static/site.webmanifest', status_code=302)
Expand Down
2 changes: 1 addition & 1 deletion db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlmodel import create_engine, SQLModel

from db.models import Message, Conversation
from db.models import Message, Conversation, FileUpload, PartialFile


def new_engine(uri: str = "sqlite:///htmxgpt.db"):
Expand Down
34 changes: 32 additions & 2 deletions db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime

from markdown import markdown
from sqlmodel import SQLModel, Field, Relationship
from sqlmodel import SQLModel, Field, Relationship, Column, JSON

from ai import MODEL_TOKEN_LIMIT
from ai.tokens import token_count_single_message
Expand Down Expand Up @@ -34,7 +34,7 @@ def to_html(self):
'role': self.role,
# TODO fix later
# this was written in 2023, comment what year it was when you found this
'content': escape_lt_gt_inside_code_tags(markdown(replace_code_blocks(self.content)).replace('</p>', '</p><br/>').replace('<ol>','<ol class="list-disc">')),
'content': escape_lt_gt_inside_code_tags(markdown(replace_code_blocks(self.content)).replace('</p>', '</p><br/>').replace('<ol>', '<ol class="list-disc">')),
'id': self.id,
}

Expand All @@ -46,12 +46,42 @@ def from_chatgpt(role, content):
)


class PartialFile(SQLModel, table=True):
    """One token-bounded chunk of an uploaded file, with its embedding vector."""
    __tablename__ = "partial_files"
    # Auto-increment primary key (assigned on insert).
    id: Optional[int] = Field(default=None, primary_key=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # Raw text of this chunk.
    content: str
    # FK to the owning FileUpload row.
    file_upload_id: Optional[int] = Field(
        default=None, foreign_key="file_uploads.id")
    # Embedding vector stored as a JSON column (no native vector type here).
    embeddings: Optional[List[float]] = Field(
        default=None, sa_column=Column(JSON))
    # Many-to-one back-reference to the parent upload (attribute name is
    # plural but holds a single FileUpload — mirrors back_populates="parts").
    file_uploads: Optional["FileUpload"] = Relationship(
        back_populates="parts")


class FileUpload(SQLModel, table=True):
    """A file uploaded to a conversation, stored as a set of PartialFile chunks."""
    __tablename__ = "file_uploads"
    # Auto-increment primary key (assigned on insert).
    id: Optional[int] = Field(default=None, primary_key=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # Original file name as recorded at upload time.
    filename: str
    # One-to-many: the token-bounded chunks this upload was split into.
    parts: Optional[List[PartialFile]] = Relationship(
        back_populates="file_uploads")
    # FK to the conversation this upload belongs to.
    conversation_id: Optional[int] = Field(
        default=None, foreign_key="conversations.id")
    # Many-to-one back-reference to the owning conversation (attribute name
    # is plural but holds a single Conversation).
    conversations: Optional["Conversation"] = Relationship(
        back_populates="file_uploads")


class Conversation(SQLModel, table=True):
    """A chat conversation: a titled container for messages and file uploads."""
    __tablename__ = "conversations"
    # Auto-increment primary key (assigned on insert).
    id: Optional[int] = Field(default=None, primary_key=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # Display title; every conversation starts as "New Chat".
    title: str = Field(default="New Chat")
    # One-to-many: files attached to this conversation.
    file_uploads: Optional[List[FileUpload]] = Relationship(
        back_populates="conversations")
    # One-to-many: chat messages belonging to this conversation.
    messages: Optional[List[Message]] = Relationship(
        back_populates="conversations")

Expand Down

0 comments on commit bc6eb8b

Please sign in to comment.