diff --git a/session_tracking.py b/session_tracking.py index ac510fd..71d4385 100644 --- a/session_tracking.py +++ b/session_tracking.py @@ -1,71 +1,55 @@ -import sqlite3 +import unittest +from session_tracking import get_db_connection, create_sessions_table, insert_session_data, retrieve_session_data, handle_new_session from datetime import datetime -# Connect to the database -def get_db_connection(): - return sqlite3.connect('venice.db') - -# Create the table for session tracking -def create_sessions_table(cursor): - """Create the sessions table if it doesn't exist.""" - cursor.execute(''' - CREATE TABLE IF NOT EXISTS sessions ( - session_id INTEGER PRIMARY KEY, - flag_state TEXT, - conversation_log TEXT, - timestamp TEXT - ); - ''') - -# Function to insert session data -def insert_session_data(cursor, session_id, flag_state, conversation_log, timestamp): - """Insert session data into the sessions table.""" - cursor.execute(''' - INSERT INTO sessions (session_id, flag_state, conversation_log, timestamp) - VALUES (?, ?, ?, ?); - ''', (session_id, flag_state, conversation_log, timestamp)) - -# Function to retrieve session data based on session ID -def retrieve_session_data(cursor, session_id): - """Retrieve session data from the sessions table based on session ID.""" - cursor.execute(''' - SELECT flag_state, conversation_log, timestamp - FROM sessions - WHERE session_id = ?; - ''', (session_id,)) - return cursor.fetchone() - -# Function to simulate a conversation and store session data -def handle_new_session(cursor, session_id, flag_state, conversation_log): - """Simulate a conversation and store session data.""" - timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - insert_session_data(cursor, session_id, flag_state, conversation_log, timestamp) - -# Function to check and update session logic -def session_logic(cursor, session_id): - """Check and update session logic.""" - session_data = retrieve_session_data(cursor, session_id) - if session_data: - flag_state, conversation_log, timestamp = session_data - print(f"Restored session {session_id}:") - print(f"Flag State: {flag_state}") - print(f"Last Conversation: {conversation_log}") - print(f"Timestamp: {timestamp}") - else: - print(f"No previous session found for session ID {session_id}. Starting fresh.") - -# Example session simulation -def main(): - session_id = 1 - flag_state = "reset_context=False, check_flags=False" - conversation_log = "Started the conversation with some context." - - with get_db_connection() as conn: - cursor = conn.cursor() - create_sessions_table(cursor) - handle_new_session(cursor, session_id, flag_state, conversation_log) - session_logic(cursor, session_id) - -# Run the example -if __name__ == "__main__": - main() +class TestSessionTracking(unittest.TestCase): + + def setUp(self): + """Shared setup method to initialize the database and table.""" + self.conn = get_db_connection() + self.cursor = self.conn.cursor() + create_sessions_table(self.cursor) + + def tearDown(self): + """Clean up method to delete data after each test.""" + self.cursor.execute('DELETE FROM sessions') + self.conn.commit() + + def test_create_sessions_table(self): + """Test if the sessions table is created.""" + self.assertTrue(self.cursor.execute('SELECT * FROM sessions').fetchone() is None) + + def test_insert_session_data(self): + """Test if session data is inserted correctly.""" + insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00') + self.assertTrue(self.cursor.execute('SELECT * FROM sessions').fetchone() is not None) + + def test_retrieve_session_data(self): + """Test if session data is retrieved correctly.""" + insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00') + session_data = retrieve_session_data(self.cursor, 1) + self.assertEqual(session_data[0], 'reset_context=False, check_flags=False') + self.assertEqual(session_data[1], 'Started the conversation with some context.') + + def test_retrieve_non_existent_session(self): + """Test retrieving a session that doesn't exist.""" + session_data = retrieve_session_data(self.cursor, 999) # assuming 999 is an invalid ID + self.assertIsNone(session_data) + + def test_timestamp_format(self): + """Test if the timestamp is correctly formatted.""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', timestamp) + session_data = retrieve_session_data(self.cursor, 1) + self.assertRegex(session_data[2], r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}') + + def test_update_session_data(self): + """Test if session data is updated correctly.""" + insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00') + new_conversation_log = 'Added new conversation log.' + handle_new_session(self.cursor, 1, 'reset_context=False, check_flags=False', new_conversation_log) + session_data = retrieve_session_data(self.cursor, 1) + self.assertEqual(session_data[1], new_conversation_log) + +if __name__ == '__main__': + unittest.main()