-
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Josef Edwards <joed6834@colorado.edu>
- Loading branch information
1 parent
cdf8d85
commit 63929c6
Showing
1 changed file
with
53 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,55 @@ | ||
import sqlite3 | ||
import unittest | ||
from session_tracking import get_db_connection, create_sessions_table, insert_session_data, retrieve_session_data, handle_new_session | ||
from datetime import datetime | ||
|
||
# Connect to the database | ||
def get_db_connection(): | ||
return sqlite3.connect('venice.db') | ||
|
||
# Create the table for session tracking | ||
def create_sessions_table(cursor): | ||
"""Create the sessions table if it doesn't exist.""" | ||
cursor.execute(''' | ||
CREATE TABLE IF NOT EXISTS sessions ( | ||
session_id INTEGER PRIMARY KEY, | ||
flag_state TEXT, | ||
conversation_log TEXT, | ||
timestamp TEXT | ||
); | ||
''') | ||
|
||
# Function to insert session data | ||
def insert_session_data(cursor, session_id, flag_state, conversation_log, timestamp): | ||
"""Insert session data into the sessions table.""" | ||
cursor.execute(''' | ||
INSERT INTO sessions (session_id, flag_state, conversation_log, timestamp) | ||
VALUES (?, ?, ?, ?); | ||
''', (session_id, flag_state, conversation_log, timestamp)) | ||
|
||
# Function to retrieve session data based on session ID | ||
def retrieve_session_data(cursor, session_id): | ||
"""Retrieve session data from the sessions table based on session ID.""" | ||
cursor.execute(''' | ||
SELECT flag_state, conversation_log, timestamp | ||
FROM sessions | ||
WHERE session_id = ?; | ||
''', (session_id,)) | ||
return cursor.fetchone() | ||
|
||
# Function to simulate a conversation and store session data | ||
def handle_new_session(cursor, session_id, flag_state, conversation_log): | ||
"""Simulate a conversation and store session data.""" | ||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') | ||
insert_session_data(cursor, session_id, flag_state, conversation_log, timestamp) | ||
|
||
# Function to check and update session logic | ||
def session_logic(cursor, session_id): | ||
"""Check and update session logic.""" | ||
session_data = retrieve_session_data(cursor, session_id) | ||
if session_data: | ||
flag_state, conversation_log, timestamp = session_data | ||
print(f"Restored session {session_id}:") | ||
print(f"Flag State: {flag_state}") | ||
print(f"Last Conversation: {conversation_log}") | ||
print(f"Timestamp: {timestamp}") | ||
else: | ||
print(f"No previous session found for session ID {session_id}. Starting fresh.") | ||
|
||
# Example session simulation | ||
def main(): | ||
session_id = 1 | ||
flag_state = "reset_context=False, check_flags=False" | ||
conversation_log = "Started the conversation with some context." | ||
|
||
with get_db_connection() as conn: | ||
cursor = conn.cursor() | ||
create_sessions_table(cursor) | ||
handle_new_session(cursor, session_id, flag_state, conversation_log) | ||
session_logic(cursor, session_id) | ||
|
||
# Run the example | ||
if __name__ == "__main__": | ||
main() | ||
class TestSessionTracking(unittest.TestCase): | ||
|
||
def setUp(self): | ||
"""Shared setup method to initialize the database and table.""" | ||
self.conn = get_db_connection() | ||
self.cursor = self.conn.cursor() | ||
create_sessions_table(self.cursor) | ||
|
||
def tearDown(self): | ||
"""Clean up method to delete data after each test.""" | ||
self.cursor.execute('DELETE FROM sessions') | ||
self.conn.commit() | ||
|
||
def test_create_sessions_table(self): | ||
"""Test if the sessions table is created.""" | ||
self.assertTrue(self.cursor.execute('SELECT * FROM sessions').fetchone() is None) | ||
|
||
def test_insert_session_data(self): | ||
"""Test if session data is inserted correctly.""" | ||
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00') | ||
self.assertTrue(self.cursor.execute('SELECT * FROM sessions').fetchone() is not None) | ||
|
||
def test_retrieve_session_data(self): | ||
"""Test if session data is retrieved correctly.""" | ||
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00') | ||
session_data = retrieve_session_data(self.cursor, 1) | ||
self.assertEqual(session_data[0], 'reset_context=False, check_flags=False') | ||
self.assertEqual(session_data[1], 'Started the conversation with some context.') | ||
|
||
def test_retrieve_non_existent_session(self): | ||
"""Test retrieving a session that doesn't exist.""" | ||
session_data = retrieve_session_data(self.cursor, 999) # assuming 999 is an invalid ID | ||
self.assertIsNone(session_data) | ||
|
||
def test_timestamp_format(self): | ||
"""Test if the timestamp is correctly formatted.""" | ||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') | ||
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', timestamp) | ||
session_data = retrieve_session_data(self.cursor, 1) | ||
self.assertRegex(session_data[2], r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}') | ||
|
||
def test_update_session_data(self): | ||
"""Test if session data is updated correctly.""" | ||
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00') | ||
new_conversation_log = 'Added new conversation log.' | ||
handle_new_session(self.cursor, 1, 'reset_context=False, check_flags=False', new_conversation_log) | ||
session_data = retrieve_session_data(self.cursor, 1) | ||
self.assertEqual(session_data[1], new_conversation_log) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |