-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
37de337
commit fa8cf90
Showing
10 changed files
with
806 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import Optional, List | ||
from uuid import uuid4 | ||
from hubs.hub import NovaHub | ||
|
||
class NovaBot: | ||
def __init__(self, hub: NovaHub): | ||
self.hub = hub | ||
self.session_id: Optional[str] = None | ||
self.system_prompt = """You are a helpful AI assistant. | ||
You provide clear, concise, and accurate responses.""" | ||
self.message_history: List[dict] = [] | ||
|
||
async def initialize(self) -> str: | ||
"""Initialize the bot and return a session ID""" | ||
self.session_id = str(uuid4()) | ||
self.message_history = [] | ||
return self.session_id | ||
|
||
async def process_message(self, message: str, model: str = "mistral") -> str: | ||
"""Process a user message and return a response""" | ||
if not self.session_id: | ||
raise Exception("Bot not initialized") | ||
|
||
# Add user message to history | ||
self.message_history.append({ | ||
"role": "user", | ||
"content": message | ||
}) | ||
|
||
response = await self.hub.generate_response( | ||
prompt=message, | ||
system=self.system_prompt, | ||
model=model | ||
) | ||
|
||
# Add assistant response to history | ||
self.message_history.append({ | ||
"role": "assistant", | ||
"content": response | ||
}) | ||
|
||
return response | ||
|
||
async def get_history(self) -> List[dict]: | ||
"""Get the chat history""" | ||
if not self.session_id: | ||
raise Exception("Bot not initialized") | ||
return self.message_history | ||
|
||
async def cleanup(self): | ||
"""Cleanup bot resources""" | ||
self.session_id = None | ||
self.message_history = [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Standard library | ||
from typing import Optional | ||
import os | ||
from pathlib import Path | ||
|
||
# Third party | ||
from openai import OpenAI, AsyncOpenAI | ||
from dotenv import load_dotenv | ||
from utils.connection_manager import ConnectionManager | ||
from utils.openai_key_validator import OpenAIKeyValidator | ||
|
||
class NovaHub: | ||
def __init__(self, host: str = "http://localhost:11434"): | ||
self.connection_manager = ConnectionManager() | ||
# Get and validate API key | ||
self.key_validator = OpenAIKeyValidator() | ||
|
||
# Initialize OpenAI client with validated key | ||
self.openai_client = AsyncOpenAI(api_key=self.key_validator.key) | ||
|
||
# Initialize Ollama client | ||
self.ollama_client = AsyncOpenAI( | ||
base_url=f"{host}/v1", | ||
api_key="ollama" | ||
) | ||
|
||
async def generate_response(self, prompt: str, model: str = "gpt-4", system: Optional[str] = None) -> str: | ||
async with self.connection_manager.get_connection() as client: | ||
messages = self.connection_manager.format_messages(prompt, system) | ||
response = await client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
max_tokens=50 | ||
) | ||
return response.choices[0].message.content | ||
|
||
async def cleanup(self): | ||
pass # OpenAI clients don't need cleanup |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
# Standard library | ||
from typing import Dict, Optional, List | ||
|
||
# Third party | ||
from fastapi import FastAPI, HTTPException, Depends | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from pydantic import BaseModel | ||
from fastapi.responses import HTMLResponse | ||
from fastapi.staticfiles import StaticFiles | ||
from fastapi.templating import Jinja2Templates | ||
|
||
# Local | ||
from hubs.hub import NovaHub | ||
from bots.bot import NovaBot | ||
import logging | ||
|
||
# Configure logging | ||
logging.basicConfig( | ||
level=logging.INFO, | ||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||
) | ||
logger = logging.getLogger(__name__) | ||
|
||
app = FastAPI(title="NovaSystem LITE") | ||
|
||
# Add CORS middleware | ||
app.add_middleware( | ||
CORSMiddleware, | ||
allow_origins=["*"], # In production, replace with specific origins | ||
allow_credentials=True, | ||
allow_methods=["*"], | ||
allow_headers=["*"], | ||
) | ||
|
||
# Global instances | ||
hub = NovaHub() | ||
active_sessions: Dict[str, NovaBot] = {} | ||
|
||
class ChatMessage(BaseModel): | ||
message: str | ||
session_id: Optional[str] = None | ||
model: Optional[str] = "gpt-4o" | ||
|
||
class ChatResponse(BaseModel): | ||
response: str | ||
session_id: str | ||
|
||
class ChatHistory(BaseModel): | ||
messages: List[dict] | ||
session_id: str | ||
|
||
async def get_bot(session_id: str) -> NovaBot: | ||
"""Dependency to get bot instance""" | ||
if session_id not in active_sessions: | ||
raise HTTPException(status_code=404, detail="Session not found") | ||
return active_sessions[session_id] | ||
|
||
@app.on_event("shutdown") | ||
async def shutdown_event(): | ||
"""Cleanup resources on shutdown""" | ||
logger.info("Shutting down application...") | ||
for session_id in list(active_sessions.keys()): | ||
await active_sessions[session_id].cleanup() | ||
await hub.cleanup() | ||
|
||
@app.post("/chat/", response_model=ChatResponse) | ||
async def create_chat(): | ||
"""Create a new chat session""" | ||
try: | ||
bot = NovaBot(hub) | ||
session_id = await bot.initialize() | ||
active_sessions[session_id] = bot | ||
logger.info(f"Created new chat session: {session_id}") | ||
return ChatResponse(response="Chat session created", session_id=session_id) | ||
except Exception as e: | ||
logger.error(f"Error creating chat session: {str(e)}") | ||
raise HTTPException(status_code=500, detail="Failed to create chat session") | ||
|
||
@app.post("/chat/{session_id}/message", response_model=ChatResponse) | ||
async def send_message( | ||
chat_message: ChatMessage, | ||
bot: NovaBot = Depends(get_bot) | ||
): | ||
"""Send a message in an existing chat session""" | ||
try: | ||
response = await bot.process_message( | ||
chat_message.message, | ||
model=chat_message.model | ||
) | ||
return ChatResponse(response=response, session_id=bot.session_id) | ||
except Exception as e: | ||
logger.error(f"Error processing message: {str(e)}") | ||
raise HTTPException(status_code=500, detail="Failed to process message") | ||
|
||
@app.get("/chat/{session_id}/history", response_model=ChatHistory) | ||
async def get_history(bot: NovaBot = Depends(get_bot)): | ||
"""Get chat history for a session""" | ||
try: | ||
history = await bot.get_history() | ||
return ChatHistory(messages=history, session_id=bot.session_id) | ||
except Exception as e: | ||
logger.error(f"Error fetching history: {str(e)}") | ||
raise HTTPException(status_code=500, detail="Failed to fetch chat history") | ||
|
||
@app.post("/chat/{session_id}/end") | ||
async def end_chat(bot: NovaBot = Depends(get_bot)): | ||
"""End a chat session""" | ||
try: | ||
session_id = bot.session_id | ||
await bot.cleanup() | ||
del active_sessions[session_id] | ||
logger.info(f"Ended chat session: {session_id}") | ||
return {"message": "Chat session ended"} | ||
except Exception as e: | ||
logger.error(f"Error ending chat session: {str(e)}") | ||
raise HTTPException(status_code=500, detail="Failed to end chat session") | ||
|
||
@app.get("/", response_class=HTMLResponse) | ||
async def root(): | ||
return """ | ||
<!DOCTYPE html> | ||
<html> | ||
<head> | ||
<title>NovaSystem LITE Chat</title> | ||
<style> | ||
body { max-width: 800px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif; } | ||
#chat-box { height: 400px; border: 1px solid #ccc; overflow-y: auto; padding: 10px; margin: 20px 0; } | ||
#message-input { width: 80%; padding: 5px; } | ||
button { padding: 5px 15px; } | ||
</style> | ||
</head> | ||
<body> | ||
<h1>NovaSystem LITE Chat</h1> | ||
<select id="model-select"> | ||
<option value="gpt-4o" selected>GPT-4o</option> | ||
<option value="llama3.2">Llama 3.2</option> | ||
</select> | ||
<div id="chat-box"></div> | ||
<input type="text" id="message-input" placeholder="Type your message..."> | ||
<button onclick="sendMessage()">Send</button> | ||
<script> | ||
let sessionId = null; | ||
async function createSession() { | ||
const response = await fetch('/chat/', { method: 'POST' }); | ||
const data = await response.json(); | ||
sessionId = data.session_id; | ||
appendMessage('System', 'Chat session created'); | ||
} | ||
async function sendMessage() { | ||
if (!sessionId) await createSession(); | ||
const input = document.getElementById('message-input'); | ||
const model = document.getElementById('model-select').value; | ||
const message = input.value; | ||
if (!message) return; | ||
appendMessage('You', message); | ||
input.value = ''; | ||
try { | ||
const response = await fetch(`/chat/${sessionId}/message`, { | ||
method: 'POST', | ||
headers: { 'Content-Type': 'application/json' }, | ||
body: JSON.stringify({ message, model }) | ||
}); | ||
if (!response.ok) { | ||
const error = await response.json(); | ||
throw new Error(error.detail || 'Failed to get response'); | ||
} | ||
const data = await response.json(); | ||
appendMessage('Assistant', data.response); | ||
} catch (error) { | ||
appendMessage('System', 'Error: ' + error.message); | ||
} | ||
} | ||
function appendMessage(sender, message) { | ||
const chatBox = document.getElementById('chat-box'); | ||
chatBox.innerHTML += `<p><strong>${sender}:</strong> ${message}</p>`; | ||
chatBox.scrollTop = chatBox.scrollHeight; | ||
} | ||
// Handle Enter key | ||
document.getElementById('message-input').addEventListener('keypress', function(e) { | ||
if (e.key === 'Enter') sendMessage(); | ||
}); | ||
</script> | ||
</body> | ||
</html> | ||
""" | ||
|
||
if __name__ == "__main__": | ||
import uvicorn | ||
uvicorn.run(app, host="0.0.0.0", port=8000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
fastapi>=0.68.0 | ||
uvicorn>=0.15.0 | ||
pydantic>=1.8.0 | ||
aiohttp>=3.8.0 | ||
python-dotenv>=1.0.0 | ||
openai>=1.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Standard library | ||
import os | ||
import unittest | ||
from pathlib import Path | ||
from unittest.mock import patch, MagicMock | ||
|
||
# Local imports | ||
from utils.openai_key_validator import OpenAIKeyValidator | ||
|
||
class TestOpenAIKeyValidator(unittest.TestCase): | ||
"""Test cases for OpenAIKeyValidator.""" | ||
|
||
def setUp(self): | ||
"""Set up test cases.""" | ||
# Save original environment | ||
self.original_key = os.environ.get('OPENAI_API_KEY') | ||
# Create a mock env path | ||
self.mock_env_path = MagicMock(spec=Path) | ||
self.mock_env_path.exists.return_value = False | ||
|
||
def tearDown(self): | ||
"""Clean up after tests.""" | ||
# Restore original environment | ||
if self.original_key: | ||
os.environ['OPENAI_API_KEY'] = self.original_key | ||
elif 'OPENAI_API_KEY' in os.environ: | ||
del os.environ['OPENAI_API_KEY'] | ||
|
||
def test_valid_key(self): | ||
"""Test validator with a valid API key.""" | ||
test_key = "sk-test123validkey456" | ||
with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True): | ||
validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False) | ||
self.assertTrue(validator.is_valid) | ||
self.assertEqual(validator.key, test_key) | ||
|
||
def test_invalid_key_format(self): | ||
"""Test validator with invalid key format.""" | ||
test_key = "invalid-key-format" | ||
with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True): | ||
validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False) | ||
self.assertFalse(validator.is_valid) | ||
with self.assertRaises(ValueError): | ||
_ = validator.key | ||
|
||
def test_placeholder_key(self): | ||
"""Test validator with placeholder key.""" | ||
test_key = "your-actual-api-key" | ||
with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True): | ||
validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False) | ||
self.assertFalse(validator.is_valid) | ||
with self.assertRaises(ValueError): | ||
_ = validator.key | ||
|
||
def test_missing_key(self): | ||
"""Test validator with no API key.""" | ||
with patch.dict(os.environ, {}, clear=True): | ||
validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False) | ||
self.assertFalse(validator.is_valid) | ||
with self.assertRaises(ValueError): | ||
_ = validator.key | ||
|
||
def test_env_file_loading(self): | ||
"""Test .env file loading functionality.""" | ||
test_key = "sk-test123validkey456" | ||
mock_env_path = MagicMock(spec=Path) | ||
mock_env_path.exists.return_value = True | ||
|
||
with patch('utils.openai_key_validator.load_dotenv', return_value=True) as mock_load_dotenv: | ||
with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True): | ||
validator = OpenAIKeyValidator(env_path=mock_env_path, search_tree=False) | ||
|
||
# Verify load_dotenv was called with the correct arguments | ||
mock_load_dotenv.assert_called_once_with(mock_env_path, override=True) | ||
self.assertTrue(validator.is_valid) | ||
self.assertEqual(validator.key, test_key) | ||
|
||
if __name__ == '__main__': | ||
unittest.main(verbosity=2) |
Oops, something went wrong.