Skip to content

Commit

Permalink
Update session_tracking.py
Browse files Browse the repository at this point in the history
Signed-off-by: Josef Edwards <joed6834@colorado.edu>
  • Loading branch information
bearycool11 authored Nov 11, 2024
1 parent cdf8d85 commit 63929c6
Showing 1 changed file with 53 additions and 69 deletions.
122 changes: 53 additions & 69 deletions session_tracking.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,55 @@
import sqlite3
import unittest
from session_tracking import get_db_connection, create_sessions_table, insert_session_data, retrieve_session_data, handle_new_session
from datetime import datetime

# Connect to the database
def get_db_connection():
return sqlite3.connect('venice.db')

# Create the table for session tracking
def create_sessions_table(cursor):
"""Create the sessions table if it doesn't exist."""
cursor.execute('''
CREATE TABLE IF NOT EXISTS sessions (
session_id INTEGER PRIMARY KEY,
flag_state TEXT,
conversation_log TEXT,
timestamp TEXT
);
''')

# Function to insert session data
def insert_session_data(cursor, session_id, flag_state, conversation_log, timestamp):
"""Insert session data into the sessions table."""
cursor.execute('''
INSERT INTO sessions (session_id, flag_state, conversation_log, timestamp)
VALUES (?, ?, ?, ?);
''', (session_id, flag_state, conversation_log, timestamp))

# Function to retrieve session data based on session ID
def retrieve_session_data(cursor, session_id):
"""Retrieve session data from the sessions table based on session ID."""
cursor.execute('''
SELECT flag_state, conversation_log, timestamp
FROM sessions
WHERE session_id = ?;
''', (session_id,))
return cursor.fetchone()

# Function to simulate a conversation and store session data
def handle_new_session(cursor, session_id, flag_state, conversation_log):
"""Simulate a conversation and store session data."""
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
insert_session_data(cursor, session_id, flag_state, conversation_log, timestamp)

# Function to check and update session logic
def session_logic(cursor, session_id):
"""Check and update session logic."""
session_data = retrieve_session_data(cursor, session_id)
if session_data:
flag_state, conversation_log, timestamp = session_data
print(f"Restored session {session_id}:")
print(f"Flag State: {flag_state}")
print(f"Last Conversation: {conversation_log}")
print(f"Timestamp: {timestamp}")
else:
print(f"No previous session found for session ID {session_id}. Starting fresh.")

# Example session simulation
def main():
session_id = 1
flag_state = "reset_context=False, check_flags=False"
conversation_log = "Started the conversation with some context."

with get_db_connection() as conn:
cursor = conn.cursor()
create_sessions_table(cursor)
handle_new_session(cursor, session_id, flag_state, conversation_log)
session_logic(cursor, session_id)

# Run the example
if __name__ == "__main__":
main()
class TestSessionTracking(unittest.TestCase):

def setUp(self):
"""Shared setup method to initialize the database and table."""
self.conn = get_db_connection()
self.cursor = self.conn.cursor()
create_sessions_table(self.cursor)

def tearDown(self):
"""Clean up method to delete data after each test."""
self.cursor.execute('DELETE FROM sessions')
self.conn.commit()

def test_create_sessions_table(self):
"""Test if the sessions table is created."""
self.assertTrue(self.cursor.execute('SELECT * FROM sessions').fetchone() is None)

def test_insert_session_data(self):
"""Test if session data is inserted correctly."""
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00')
self.assertTrue(self.cursor.execute('SELECT * FROM sessions').fetchone() is not None)

def test_retrieve_session_data(self):
"""Test if session data is retrieved correctly."""
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00')
session_data = retrieve_session_data(self.cursor, 1)
self.assertEqual(session_data[0], 'reset_context=False, check_flags=False')
self.assertEqual(session_data[1], 'Started the conversation with some context.')

def test_retrieve_non_existent_session(self):
"""Test retrieving a session that doesn't exist."""
session_data = retrieve_session_data(self.cursor, 999) # assuming 999 is an invalid ID
self.assertIsNone(session_data)

def test_timestamp_format(self):
"""Test if the timestamp is correctly formatted."""
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', timestamp)
session_data = retrieve_session_data(self.cursor, 1)
self.assertRegex(session_data[2], r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')

def test_update_session_data(self):
"""Test if session data is updated correctly."""
insert_session_data(self.cursor, 1, 'reset_context=False, check_flags=False', 'Started the conversation with some context.', '2023-03-01 12:00:00')
new_conversation_log = 'Added new conversation log.'
handle_new_session(self.cursor, 1, 'reset_context=False, check_flags=False', new_conversation_log)
session_data = retrieve_session_data(self.cursor, 1)
self.assertEqual(session_data[1], new_conversation_log)

if __name__ == '__main__':
unittest.main()

0 comments on commit 63929c6

Please sign in to comment.