Update it with API Key validation
ctavolazzi committed Dec 7, 2024
1 parent 37de337 commit fa8cf90
Showing 10 changed files with 806 additions and 0 deletions.
53 changes: 53 additions & 0 deletions dev/NS-bytesize/bots/bot.py
@@ -0,0 +1,53 @@
from typing import Optional, List
from uuid import uuid4
from hubs.hub import NovaHub

class NovaBot:
    def __init__(self, hub: NovaHub):
        self.hub = hub
        self.session_id: Optional[str] = None
        self.system_prompt = """You are a helpful AI assistant.
        You provide clear, concise, and accurate responses."""
        self.message_history: List[dict] = []

    async def initialize(self) -> str:
        """Initialize the bot and return a session ID"""
        self.session_id = str(uuid4())
        self.message_history = []
        return self.session_id

    async def process_message(self, message: str, model: str = "mistral") -> str:
        """Process a user message and return a response"""
        if not self.session_id:
            raise Exception("Bot not initialized")

        # Add user message to history
        self.message_history.append({
            "role": "user",
            "content": message
        })

        response = await self.hub.generate_response(
            prompt=message,
            system=self.system_prompt,
            model=model
        )

        # Add assistant response to history
        self.message_history.append({
            "role": "assistant",
            "content": response
        })

        return response

    async def get_history(self) -> List[dict]:
        """Get the chat history"""
        if not self.session_id:
            raise Exception("Bot not initialized")
        return self.message_history

    async def cleanup(self):
        """Cleanup bot resources"""
        self.session_id = None
        self.message_history = []
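
For orientation, a minimal driver sketch of how NovaBot is intended to be used (hypothetical, not part of this commit; it assumes the dev/NS-bytesize directory is on the import path and an asyncio entry point):

import asyncio

from hubs.hub import NovaHub
from bots.bot import NovaBot

async def main():
    # Create the shared hub and a bot bound to it
    hub = NovaHub()
    bot = NovaBot(hub)
    session_id = await bot.initialize()

    reply = await bot.process_message("Hello there!", model="mistral")
    print(f"[{session_id}] {reply}")

    # Inspect the accumulated history, then release resources
    print(await bot.get_history())
    await bot.cleanup()
    await hub.cleanup()

if __name__ == "__main__":
    asyncio.run(main())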
38 changes: 38 additions & 0 deletions dev/NS-bytesize/hubs/hub.py
@@ -0,0 +1,38 @@
# Standard library
from typing import Optional
import os
from pathlib import Path

# Third party
from openai import OpenAI, AsyncOpenAI
from dotenv import load_dotenv

# Local
from utils.connection_manager import ConnectionManager
from utils.openai_key_validator import OpenAIKeyValidator

class NovaHub:
    def __init__(self, host: str = "http://localhost:11434"):
        self.connection_manager = ConnectionManager()
        # Get and validate API key
        self.key_validator = OpenAIKeyValidator()

        # Initialize OpenAI client with validated key
        self.openai_client = AsyncOpenAI(api_key=self.key_validator.key)

        # Initialize Ollama client (OpenAI-compatible endpoint)
        self.ollama_client = AsyncOpenAI(
            base_url=f"{host}/v1",
            api_key="ollama"
        )

    async def generate_response(self, prompt: str, model: str = "gpt-4", system: Optional[str] = None) -> str:
        async with self.connection_manager.get_connection() as client:
            messages = self.connection_manager.format_messages(prompt, system)
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=50
            )
            return response.choices[0].message.content

    async def cleanup(self):
        pass  # OpenAI clients don't need cleanup
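
Note that generate_response routes every call through ConnectionManager (utils/connection_manager.py, not shown in this diff), so the openai_client and ollama_client attributes are initialized here but not referenced directly in that method. A rough call sketch, assuming an async context and a valid OPENAI_API_KEY in the environment:

# Hypothetical usage, not part of this commit
hub = NovaHub()
answer = await hub.generate_response(
    prompt="Say hello in one sentence.",
    system="You are a helpful AI assistant.",
    model="gpt-4",
)
print(answer)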
197 changes: 197 additions & 0 deletions dev/NS-bytesize/main.py
@@ -0,0 +1,197 @@
# Standard library
from typing import Dict, Optional, List

# Third party
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

import logging

# Local
from hubs.hub import NovaHub
from bots.bot import NovaBot

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(title="NovaSystem LITE")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global instances
hub = NovaHub()
active_sessions: Dict[str, NovaBot] = {}

class ChatMessage(BaseModel):
    message: str
    session_id: Optional[str] = None
    model: Optional[str] = "gpt-4o"

class ChatResponse(BaseModel):
    response: str
    session_id: str

class ChatHistory(BaseModel):
    messages: List[dict]
    session_id: str

async def get_bot(session_id: str) -> NovaBot:
    """Dependency to get bot instance"""
    if session_id not in active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    return active_sessions[session_id]

@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup resources on shutdown"""
logger.info("Shutting down application...")
for session_id in list(active_sessions.keys()):
await active_sessions[session_id].cleanup()
await hub.cleanup()

@app.post("/chat/", response_model=ChatResponse)
async def create_chat():
"""Create a new chat session"""
try:
bot = NovaBot(hub)
session_id = await bot.initialize()
active_sessions[session_id] = bot
logger.info(f"Created new chat session: {session_id}")
return ChatResponse(response="Chat session created", session_id=session_id)
except Exception as e:
logger.error(f"Error creating chat session: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to create chat session")

@app.post("/chat/{session_id}/message", response_model=ChatResponse)
async def send_message(
chat_message: ChatMessage,
bot: NovaBot = Depends(get_bot)
):
"""Send a message in an existing chat session"""
try:
response = await bot.process_message(
chat_message.message,
model=chat_message.model
)
return ChatResponse(response=response, session_id=bot.session_id)
except Exception as e:
logger.error(f"Error processing message: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to process message")

@app.get("/chat/{session_id}/history", response_model=ChatHistory)
async def get_history(bot: NovaBot = Depends(get_bot)):
"""Get chat history for a session"""
try:
history = await bot.get_history()
return ChatHistory(messages=history, session_id=bot.session_id)
except Exception as e:
logger.error(f"Error fetching history: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to fetch chat history")

@app.post("/chat/{session_id}/end")
async def end_chat(bot: NovaBot = Depends(get_bot)):
"""End a chat session"""
try:
session_id = bot.session_id
await bot.cleanup()
del active_sessions[session_id]
logger.info(f"Ended chat session: {session_id}")
return {"message": "Chat session ended"}
except Exception as e:
logger.error(f"Error ending chat session: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to end chat session")

@app.get("/", response_class=HTMLResponse)
async def root():
return """
<!DOCTYPE html>
<html>
<head>
<title>NovaSystem LITE Chat</title>
<style>
body { max-width: 800px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif; }
#chat-box { height: 400px; border: 1px solid #ccc; overflow-y: auto; padding: 10px; margin: 20px 0; }
#message-input { width: 80%; padding: 5px; }
button { padding: 5px 15px; }
</style>
</head>
<body>
<h1>NovaSystem LITE Chat</h1>
<select id="model-select">
<option value="gpt-4o" selected>GPT-4o</option>
<option value="llama3.2">Llama 3.2</option>
</select>
<div id="chat-box"></div>
<input type="text" id="message-input" placeholder="Type your message...">
<button onclick="sendMessage()">Send</button>
<script>
let sessionId = null;
async function createSession() {
const response = await fetch('/chat/', { method: 'POST' });
const data = await response.json();
sessionId = data.session_id;
appendMessage('System', 'Chat session created');
}
async function sendMessage() {
if (!sessionId) await createSession();
const input = document.getElementById('message-input');
const model = document.getElementById('model-select').value;
const message = input.value;
if (!message) return;
appendMessage('You', message);
input.value = '';
try {
const response = await fetch(`/chat/${sessionId}/message`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ message, model })
});
if (!response.ok) {
const error = await response.json();
throw new Error(error.detail || 'Failed to get response');
}
const data = await response.json();
appendMessage('Assistant', data.response);
} catch (error) {
appendMessage('System', 'Error: ' + error.message);
}
}
function appendMessage(sender, message) {
const chatBox = document.getElementById('chat-box');
chatBox.innerHTML += `<p><strong>${sender}:</strong> ${message}</p>`;
chatBox.scrollTop = chatBox.scrollHeight;
}
// Handle Enter key
document.getElementById('message-input').addEventListener('keypress', function(e) {
if (e.key === 'Enter') sendMessage();
});
</script>
</body>
</html>
"""

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
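
Besides the built-in browser UI, the endpoints can be exercised from a small script. A rough client sketch against a locally running server (hypothetical; assumes the requests package, which is not listed in requirements.txt):

import requests

BASE = "http://localhost:8000"

# Create a new chat session
session_id = requests.post(f"{BASE}/chat/").json()["session_id"]

# Send a message in that session
reply = requests.post(
    f"{BASE}/chat/{session_id}/message",
    json={"message": "Hello!", "model": "gpt-4o"},
).json()["response"]
print(reply)

# Fetch the history, then end the session
print(requests.get(f"{BASE}/chat/{session_id}/history").json())
requests.post(f"{BASE}/chat/{session_id}/end")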
6 changes: 6 additions & 0 deletions dev/NS-bytesize/requirements.txt
@@ -0,0 +1,6 @@
fastapi>=0.68.0
uvicorn>=0.15.0
pydantic>=1.8.0
aiohttp>=3.8.0
python-dotenv>=1.0.0
openai>=1.0.0
79 changes: 79 additions & 0 deletions dev/NS-bytesize/tests/test_openai_key_validator.py
@@ -0,0 +1,79 @@
# Standard library
import os
import unittest
from pathlib import Path
from unittest.mock import patch, MagicMock

# Local imports
from utils.openai_key_validator import OpenAIKeyValidator

class TestOpenAIKeyValidator(unittest.TestCase):
    """Test cases for OpenAIKeyValidator."""

    def setUp(self):
        """Set up test cases."""
        # Save original environment
        self.original_key = os.environ.get('OPENAI_API_KEY')
        # Create a mock env path
        self.mock_env_path = MagicMock(spec=Path)
        self.mock_env_path.exists.return_value = False

    def tearDown(self):
        """Clean up after tests."""
        # Restore original environment
        if self.original_key:
            os.environ['OPENAI_API_KEY'] = self.original_key
        elif 'OPENAI_API_KEY' in os.environ:
            del os.environ['OPENAI_API_KEY']

    def test_valid_key(self):
        """Test validator with a valid API key."""
        test_key = "sk-test123validkey456"
        with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True):
            validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False)
            self.assertTrue(validator.is_valid)
            self.assertEqual(validator.key, test_key)

    def test_invalid_key_format(self):
        """Test validator with invalid key format."""
        test_key = "invalid-key-format"
        with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True):
            validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False)
            self.assertFalse(validator.is_valid)
            with self.assertRaises(ValueError):
                _ = validator.key

    def test_placeholder_key(self):
        """Test validator with placeholder key."""
        test_key = "your-actual-api-key"
        with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True):
            validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False)
            self.assertFalse(validator.is_valid)
            with self.assertRaises(ValueError):
                _ = validator.key

    def test_missing_key(self):
        """Test validator with no API key."""
        with patch.dict(os.environ, {}, clear=True):
            validator = OpenAIKeyValidator(env_path=self.mock_env_path, search_tree=False)
            self.assertFalse(validator.is_valid)
            with self.assertRaises(ValueError):
                _ = validator.key

    def test_env_file_loading(self):
        """Test .env file loading functionality."""
        test_key = "sk-test123validkey456"
        mock_env_path = MagicMock(spec=Path)
        mock_env_path.exists.return_value = True

        with patch('utils.openai_key_validator.load_dotenv', return_value=True) as mock_load_dotenv:
            with patch.dict(os.environ, {'OPENAI_API_KEY': test_key}, clear=True):
                validator = OpenAIKeyValidator(env_path=mock_env_path, search_tree=False)

                # Verify load_dotenv was called with the correct arguments
                mock_load_dotenv.assert_called_once_with(mock_env_path, override=True)
                self.assertTrue(validator.is_valid)
                self.assertEqual(validator.key, test_key)

if __name__ == '__main__':
    unittest.main(verbosity=2)
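
The validator under test lives in utils/openai_key_validator.py, presumably among the 10 changed files but not rendered above. Inferring only from these tests, a rough sketch of what the class might look like (an assumption, not the committed implementation):

# Hypothetical reconstruction inferred from the tests above; the committed
# utils/openai_key_validator.py may differ.
import os
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

class OpenAIKeyValidator:
    def __init__(self, env_path: Optional[Path] = None, search_tree: bool = True):
        # search_tree (walking parent directories for a .env file) is omitted in this sketch
        if env_path is not None and env_path.exists():
            load_dotenv(env_path, override=True)
        self._key = os.environ.get("OPENAI_API_KEY")

    @property
    def is_valid(self) -> bool:
        # A usable key must exist, start with "sk-", and not be a placeholder value
        return bool(
            self._key
            and self._key.startswith("sk-")
            and self._key != "your-actual-api-key"
        )

    @property
    def key(self) -> str:
        if not self.is_valid:
            raise ValueError("No valid OpenAI API key available")
        return self._key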