diff --git a/.github/workflows/stage_unit_tests.yml b/.github/workflows/stage_unit_tests.yml index e5f8769..e5badf3 100644 --- a/.github/workflows/stage_unit_tests.yml +++ b/.github/workflows/stage_unit_tests.yml @@ -2,7 +2,7 @@ name: Unit Tests - Stage Branch on: pull_request: - types: [opened , reopene, edited] + types: [opened , reopened, edited] branches: - 'stage' @@ -32,5 +32,5 @@ jobs: run: uv run ruff check - name: Run Unit Tests - working-directory: ./src + working-directory: . run: uv run pytest diff --git a/src/config_files/logging_config.yaml b/api/config_files/logging_config.yaml similarity index 95% rename from src/config_files/logging_config.yaml rename to api/config_files/logging_config.yaml index 3d8b5b4..2a8783d 100644 --- a/src/config_files/logging_config.yaml +++ b/api/config_files/logging_config.yaml @@ -18,7 +18,7 @@ handlers: class: logging.handlers.RotatingFileHandler level: DEBUG formatter: withdate - filename: "../logs/verity.log" + filename: "./logs/verity.log" maxBytes: 10485760 # 10MB backupCount: 5 queue_handler: diff --git a/src/config_files/verity_schema.yaml b/api/config_files/verity_schema.yaml similarity index 100% rename from src/config_files/verity_schema.yaml rename to api/config_files/verity_schema.yaml diff --git a/src/__init__.py b/api/src/__init__.py similarity index 100% rename from src/__init__.py rename to api/src/__init__.py diff --git a/api/src/category.py b/api/src/category.py new file mode 100644 index 0000000..65bf6c4 --- /dev/null +++ b/api/src/category.py @@ -0,0 +1,112 @@ +import logging +from typing import Optional + +from api.src.data_handler import Database + +# from api.src.user import User + +logger = logging.getLogger(__name__) + + +class Category: + "Main category class, for anything related to the category" + + def __init__( + self, + database: Database, + user_id: int, + category_name: str = "", + budget_value: int = 0, + id: int = 0, + parent: Optional["Category"] = None, + ): + self.database: Database = database + self.name: str = category_name + self.budget_value: int = budget_value + self.id: int = id + self.children: list[Category] = [] + self.parent: Category = parent + self.user_id = user_id + if not self.name: + self.get_name() + logger.info(f"{self.name} / {self.id} initialised") + + def __repr__(self): + return f"""Category:( + Name: {self.name} + Id: {self.id} + Budget Value: {self.budget_value} + Number of Childen: {len(self.children)} + Parent:{self.parent.name if self.parent else "Top Level Category"} + )""" + + def __str__(self): + return f"Category: {self.name}" + + def add(self): + logger.info(f"Adding {self.name} to user id {self.user_id}") + sql_statement = """ + INSERT INTO category (user_id, name, budget_value, parent_id) + VALUES (?, ?, ?, ?) + """ + if not self.parent: + logger.debug("No parent linked to category, linking now.") + self.get_default_category() + params = (self.user_id, self.name, self.budget_value, int(self.parent.id)) + success, self.id = self.database.execute(sql_statement, params, return_id=True) + if not success: + logger.error(f"Failed to add Category {self.name} Check the logs") + return self.id + + def get_default_category(self): + logger.info("Getting Default Category") + self.parent = Category(self.database, self.user_id, "internal_master_category") + logger.info(f"master category = {self.parent}") + self.parent.get_id() + logger.info(f"parent id = {self.parent.id}") + + def get_id(self): + logger.info(f"Getting Category id for {self.name}") + category_id_sql = "SELECT id FROM category WHERE user_id = ? AND name = ?" + category_id_params = (self.user_id, self.name) + category_id = self.database.read(category_id_sql, category_id_params) + logger.info(f"Category id is {category_id}") + try: + self.id = category_id[0][0] + except IndexError: + self.id = 0 + + def get_name(self): + logger.info(f"Getting Name for Category {self.id}") + cat_name_sql = "SELECT name FROM category WHERE user_id = ? and id = ?" + cat_name_params = (self.user_id, self.id) + cat_name = self.database.read(cat_name_sql, cat_name_params) + logger.debug(f"Category name returned {cat_name}") + try: + self.name = cat_name[0][0] + except IndexError: + self.name = "Unknown" + + def get_children(self): + logger.info(f"Getting Child Categories for {self.name}") + sql = """ + SELECT id, name, budget_value + FROM category + WHERE user_id = ? + AND parent_id = ? + """ + params = (self.user_id, self.id) + child_categories = self.database.read(sql, params) + for child in child_categories: + category = Category(self.database, self.user_id, child[1], child[2], child[0]) + self.children.append(category) + + def get(self): + logger.info(f"Getting details for {self.name} from database") + sql = """SELECT user_id, name, budget_value, id, parent_id + FROM category + WHERE id = ? + AND user_id = ? + """ + params = (self.id, self.user_id) + return self.database.read(sql, params) diff --git a/src/config.py b/api/src/config.py similarity index 82% rename from src/config.py rename to api/src/config.py index 851119d..8ab3f8d 100644 --- a/src/config.py +++ b/api/src/config.py @@ -6,11 +6,10 @@ class VerityConfig: def __init__(self): self.SECRET_KEY = os.environ.get("SECRET_KEY") or "super_secret_key" - self.DATABASE = "Verity.db" - self.CONFIG_FILE_DIRECTORY = "config_files" + self.DATABASE = "api/data/verity.db" + self.CONFIG_FILE_DIRECTORY = "api/config_files" self.LOGGING_CONFIG = self.load_config_file("logging_config.yaml") self.DATABASE_SCHEMA = self.load_config_file("verity_schema.yaml") - self.DEFAULT_DATA = self.load_config_file("default_data.yaml") def load_config_file(self, file): config = "" diff --git a/api/src/currency_handler.py b/api/src/currency_handler.py new file mode 100644 index 0000000..6afbe12 --- /dev/null +++ b/api/src/currency_handler.py @@ -0,0 +1,23 @@ +import logging + +logger = logging.getLogger(__name__) + +# TODO: Make this into a proper class so we can scale currrency handling + + +class CurrencyBrain: + def convert_to_universal_currency(input_value: float) -> int: + """ + Converts the the input value to remove all decimal places and return an int. + This will be the starting point for our universal currency, + (see docs/data_dictionary). + for now, we will just focus on making this an int. + it will need change later once we have the basics done + """ + logger.info(f"received {input_value} to convert to universal currency") + input_value = float(input_value) + while input_value % 1 != 0: + logger.debug(f"input value is not a whole number {input_value}") + input_value = input_value * 10 + logger.info(f"returning {int(input_value)}") + return int(input_value) diff --git a/src/data_handler.py b/api/src/data_handler.py similarity index 70% rename from src/data_handler.py rename to api/src/data_handler.py index c1d970d..6a0f44b 100644 --- a/src/data_handler.py +++ b/api/src/data_handler.py @@ -4,7 +4,7 @@ logger = logging.getLogger(__name__) -class database: +class Database: """basic Database class to start some development Will need a proper refactor once basic functions are in and working This is POC @@ -14,9 +14,8 @@ def __init__(self, config) -> None: self.verity_config = config self.schema = self.verity_config.DATABASE_SCHEMA self.database = self.verity_config.DATABASE - self.default_data = self.verity_config.DEFAULT_DATA - def execute_sql( + def execute( self, sql_statement: str, params: tuple = (), return_id: bool = False, seed: bool = False ) -> (bool, int): "send the query here, returns true if successful, false if fail" @@ -49,13 +48,15 @@ def execute_sql( connection.close() logger.info("closed connection to database") if return_id: + logger.debug(f"Returning tuple (success), (new id) ({is_success},{new_id})") return (is_success, new_id) else: + logger.debug(f"Returning success value {is_success}") return is_success - def read_database(self, sql_statement: str, params: tuple = ()) -> list: + def read(self, sql_statement: str, params: tuple = ()) -> list: "reads the database query and returns the results" - logger.debug(f"received request to read {sql_statement} with params {params}") + logger.info(f"received request to read {sql_statement} with params {params}") results = [] try: connection = sqlite3.connect(self.database) @@ -151,53 +152,7 @@ def print_table_schema(self, table_name): except Exception as e: logger.error(e) - def add_user_name(self, user_name: str) -> int: - "takes user name string, returns user id" - logger.debug(f"attempting to insert values into user table {user_name}") - sql_statement = """ - INSERT INTO user (name) - VALUES (?) - """ - params = (user_name,) - success, user_id = self.execute_sql(sql_statement, params, True) - if not success: - logger.error("Failed to execute sql, check the logs") - self.add_category(user_id, "internal_master_category", seed=True) - return user_id - def get_users(self) -> list: """Returns all users in the database.""" get_user_sql = "SELECT id, name FROM user" - return self.read_database(get_user_sql) - - def add_category( - self, user_id: int, category_name: str, budget_value: int = 0, parent_id=None, seed=False - ) -> int: - """Inserts a new category. Returns the category id.""" - logger.debug(f"attempting to insert category '{category_name}' for user {user_id}") - sql_statement = """ - INSERT INTO category (user_id, name, budget_value, parent_id) - VALUES (?, ?, ?, ?)""" - if not seed: - if not parent_id: - parent_id = self.read_database( - "SELECT id FROM category WHERE user_id = ? AND name = ?", - (user_id, "internal_master_category"), - ) - try: - parent_id = parent_id[0][0] - except IndexError: - parent_id = None - params = (user_id, category_name, budget_value, parent_id) - - success, category_id = self.execute_sql(sql_statement, params, True, seed) - if not success: - logger.error("Failed to execute sql, check the logs") - return category_id - - def get_categories(self, user_id: int) -> list: - """Returns all categories for a given user.""" - sql = ( - "SELECT id, name, budget_value, parent_id FROM category WHERE user_id = ? and name != ?" - ) - return self.read_database(sql_statement=sql, params=(user_id, "internal_master_category")) + return self.read(get_user_sql) diff --git a/api/src/user.py b/api/src/user.py new file mode 100644 index 0000000..d670ea5 --- /dev/null +++ b/api/src/user.py @@ -0,0 +1,100 @@ +import logging +from typing import List + +from api.src.category import Category +from api.src.data_handler import Database + +logger = logging.getLogger(__name__) + + +class User: + "Main user class, for anything related to the user" + + def __init__(self, database: Database, user_name: str = "", id: int = 0): + self.database: Database = database + self.name: str = user_name + self.id: int = id + self.categories: List = [] + self.internal_category_id = 0 + logger.info(f"{self.name} initialised.") + + def __str__(self): + return f"User is {self.name}" + + def add(self) -> int: + logger.info(f"adding user {self.name}") + sql = "INSERT INTO USER(name) VALUES (?)" + params = (self.name,) + success, self.id = self.database.execute(sql, params, return_id=True) + if not success: + logger.error(f"Failed to insert user {self.name} Please check the logs") + return 0 + category_sql = """INSERT INTO category (user_id, name, budget_value, parent_id) + VALUES (?, ?, ?, ?)""" + category_params = (self.id, "internal_master_category", 0, None) + category_success, default_category_id = self.database.execute( + category_sql, category_params, return_id=True, seed=True + ) + if not category_success: + logger.error(f"Failed to add default category for {self.name} Please check the logs") + # if category fails to add, + # we probably need to delete the user, so we dont lock the username + logger.info(f"master_category id = {default_category_id}") + self.internal_category_id = default_category_id + return self.id + + def exists(self) -> bool: + logger.info(f"Checking to see if {self.name} already exists!") + sql = "SELECT 1 FROM user where name = ?" + params = (self.name,) + user_in_db = self.database.read(sql, params) + logger.info(user_in_db) + if user_in_db: + logger.warning(f"{self.name} already exists in the database") + return True + return False + + def get_categories(self): + "Gets all top level categories for the user" + logger.info(f"Getting Categories for {self.name}") + sql = """ + SELECT id, name, budget_value + FROM category + WHERE user_id = ? + AND parent_id = ? + """ + if self.internal_category_id == 0: + self.get_internal_master_category() + params = (self.id, self.internal_category_id) + categories = self.database.read(sql, params) + logger.info(f"categories: {categories}") + for category in categories: + cat = Category(self.database, self.id, category[1], category[2], category[0]) + self.categories.append(cat) + return self.categories + + def get(self): + logger.info("Getting details for User") + sql = """ + SELECT name + FROM user + WHERE id = ? + """ + params = (self.id,) + name = self.database.read(sql, params) + logger.info(name) + name = name[0][0] + logger.info(name) + self.name = name + + def get_internal_master_category(self): + sql = """ + SELECT id FROM category WHERE name = ? and user_id = ? + """ + params = ("internal_master_category", self.id) + master_id = self.database.read(sql, params) + try: + master_id = master_id[0][0] + except Exception: + master_id = 0 + self.internal_category_id = master_id diff --git a/src/front/__init__.py b/front/__init__.py similarity index 100% rename from src/front/__init__.py rename to front/__init__.py diff --git a/src/front/home.py b/front/home.py similarity index 68% rename from src/front/home.py rename to front/home.py index 0bf3eb0..b5a0690 100644 --- a/src/front/home.py +++ b/front/home.py @@ -2,30 +2,37 @@ from flask import Blueprint, flash, redirect, render_template, request, session, url_for -import currency_handler -import data_handler -from config import VerityConfig +from api.src.category import Category +from api.src.config import VerityConfig +from api.src.currency_handler import CurrencyBrain +from api.src.data_handler import Database +from api.src.user import User logger = logging.getLogger(__name__) home_bp = Blueprint("home", __name__, template_folder="templates") -verity_config = VerityConfig() +# verity_config = VerityConfig() @home_bp.route("/") def home_page(): logger.info("home page hit") - db_call = data_handler.database(verity_config) + db_call = Database(VerityConfig()) users = db_call.get_users() # Get user info directly from session user_id = session.get("user_id") selected_user_name = session.get("user_name") - # Debug log to see what's being passed to the template logger.debug(f"User ID from session: {user_id}, User Name from session: {selected_user_name}") logger.info(f"{selected_user_name} is logged in") - - categories = db_call.get_categories(user_id) if user_id else [] + verity_user = User(db_call, selected_user_name, user_id) + logger.info(f"user: {verity_user}") + if not verity_user: + categories = [] + else: + categories = verity_user.get_categories() + for category in categories: + category.get_children() return render_template( "home.html", users=users, @@ -43,15 +50,13 @@ def submit_user_name(): flash("Please enter a user name.", "danger") return redirect(url_for("home.home_page")) - db_call = data_handler.database(verity_config) - # More efficient check for existing user - exists = db_call.read_database("SELECT 1 FROM user WHERE name = ?", (user_name,)) - if exists: - logger.warning(f"username already exists in the database: {user_name}") + verity_user = User(Database(VerityConfig()), user_name) + logger.info(f"does user '{user_name}' exist?") + if verity_user.exists(): flash("A user with that name already exists.", "danger") return redirect(url_for("home.home_page")) logger.info(f"User submitted new user name: {user_name}") - user_id = db_call.add_user_name(user_name) + user_id = verity_user.add() if user_id == 0: flash("User Name not saved, please check the logs", "danger") return redirect(url_for("home.home_page")) @@ -79,13 +84,14 @@ def select_user(): return redirect(url_for("home.home_page")) # Get the user details directly from the database using the ID - db_call = data_handler.database(verity_config) - result = db_call.read_database("SELECT name FROM user WHERE id = ?", (selected_user_id,)) + database = Database(VerityConfig()) + verity_user = User(database, id=selected_user_id) + verity_user.get() - if result and result[0]: + if verity_user.name and verity_user.id: # Store both ID and name in the session session["user_id"] = selected_user_id - session["user_name"] = result[0][0] + session["user_name"] = verity_user.name flash("User selected!", "success") else: session.pop("user_id", None) @@ -119,15 +125,32 @@ def submit_category(): logger.info("user is stupid and tried to assign a negative amount to the category") flash("Negative amounts don`t really make sense here, removed budget amount", "danger") budget_value_input = 0 - budget_value = currency_handler.convert_to_universal_currency(budget_value_input) + budget_value = CurrencyBrain.convert_to_universal_currency(budget_value_input) # Convert parent_id to int if provided parent_id = request.form.get("parentId", "0").strip() - - db_call = data_handler.database(verity_config) - if parent_id == 0: - category_id = db_call.add_category(user_id, category_name, budget_value) + database = Database(VerityConfig()) + logger.debug(f"parent_id for new category is {parent_id}") + if int(parent_id) == int(0): + logger.info("user submited category with no parent, class should get default category") + new_category = Category( + database=database, + user_id=user_id, + category_name=category_name, + budget_value=budget_value, + ) else: - category_id = db_call.add_category(user_id, category_name, budget_value, parent_id) + logger.info("user submited category with parent, class should use that id") + parent_category = Category(database=database, user_id=user_id, id=parent_id) + new_category = Category( + database=database, + user_id=user_id, + category_name=category_name, + budget_value=budget_value, + parent=parent_category, + ) + logger.info(repr(new_category)) + category_id = new_category.add() + logger.info(repr(new_category)) if category_id == 0: flash("Category not saved, please check the logs", "danger") else: diff --git a/src/front/static/verity_style.css b/front/static/verity_style.css similarity index 100% rename from src/front/static/verity_style.css rename to front/static/verity_style.css diff --git a/src/front/templates/base.html b/front/templates/base.html similarity index 100% rename from src/front/templates/base.html rename to front/templates/base.html diff --git a/src/front/templates/home.html b/front/templates/home.html similarity index 79% rename from src/front/templates/home.html rename to front/templates/home.html index 667cec8..c16c850 100644 --- a/src/front/templates/home.html +++ b/front/templates/home.html @@ -41,9 +41,9 @@

Home of Verity

@@ -51,15 +51,17 @@

Home of Verity

- Your Categories:
- {% for category in categories %} - - {{ category[1] }}{% if category[3] %} (Sub of ID {{ category[3] }}){% endif %}{% if category[2] %} [Budget: {{ category[2] }}]{% endif %}
- {% endfor %} -

+Your Categories:

+
{% else %} {% endif %} -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/src/config_files/default_data.yaml b/src/config_files/default_data.yaml deleted file mode 100644 index fb9cac2..0000000 --- a/src/config_files/default_data.yaml +++ /dev/null @@ -1,4 +0,0 @@ -category: - - name: 'internal_master_category' - - name: 'tesing more categories' - budget_value: '100' diff --git a/src/currency_handler.py b/src/currency_handler.py deleted file mode 100644 index cb92036..0000000 --- a/src/currency_handler.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) - -# TODO: Make this into a proper class so we can scale currrency handling - - -def convert_to_universal_currency(input_value: float) -> int: - """ - Converts the the input value to remove all decimal places and return an int. - This will be the starting point for our universal currency, - (see docs/data_dictionary). - for now, we will just focus on making this an int. - it will need change later once we have the basics done - """ - logger.info(f"received {input_value} to convert to universal currency") - input_value = float(input_value) - while input_value % 1 != 0: - logger.debug(f"input value is not a whole number {input_value}") - input_value = input_value * 10 - logger.info(f"returning {int(input_value)}") - return int(input_value) - diff --git a/src/tests/test_data_handler.py b/src/tests/test_data_handler.py deleted file mode 100644 index 283ae80..0000000 --- a/src/tests/test_data_handler.py +++ /dev/null @@ -1,186 +0,0 @@ -import os -import tempfile - -import pytest - -from src import data_handler -from src.config import VerityConfig - - -@pytest.fixture -def test_db_call(): - """Create a test database instance with a unique test database file""" - config = VerityConfig() - # Use a test-specific database file - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_db: - test_db_file = temp_db.name - - # Create a config with the test DB - config.DATABASE = test_db_file - db_call = data_handler.database(config) - db_call.build_database() - yield db_call - # Cleanup after tests - if os.path.exists(test_db_file): - os.remove(test_db_file) - - -def test_execute_sql_success(test_db_call): - """Test executing valid SQL statements""" - sql_statement = """INSERT INTO user ( - name - ) - VALUES ('a_user_name') - """ - result = test_db_call.execute_sql(sql_statement) - assert result is True - - -def test_execute_sql_error(test_db_call): - """Test executing invalid SQL statements""" - sql_statement = "SELECT * FROM non_existent_table" - result = test_db_call.execute_sql(sql_statement) - assert result is False - - -def test_execute_sql_with_return_id(test_db_call): - """Test executing SQL with return_id flag""" - sql_statement = """INSERT INTO user ( - name - ) - VALUES ('return_id_test') - """ - result, new_id = test_db_call.execute_sql(sql_statement, return_id=True) - assert result is True - assert new_id > 0 - - -def test_read_database(test_db_call): - """Test reading database data""" - # First insert a user - test_db_call.add_user_name("read_test_user") - # Then read it back - sql_statement = "SELECT id, name FROM user WHERE name = 'read_test_user'" - results = test_db_call.read_database(sql_statement) - assert isinstance(results, list) - assert len(results) > 0 - assert results[0][1] == "read_test_user" - - -def test_read_database_with_params(test_db_call): - """Test reading database with parameterized query""" - # First insert a user - test_db_call.add_user_name("param_test_user") - # Then read it back with params - sql_statement = "SELECT id, name FROM user WHERE name = ?" - params = ("param_test_user",) - results = test_db_call.read_database(sql_statement, params) - assert isinstance(results, list) - assert len(results) > 0 - assert results[0][1] == "param_test_user" - - -def test_build_database(test_db_call): - """Test that the database builds correctly with all required tables""" - # Check if all tables have been created - config = VerityConfig() - for table in config.DATABASE_SCHEMA["tables"]: - table_name = table["table_name"] - sql = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" - results = test_db_call.read_database(sql) - assert len(results) == 1, f"Table '{table}' should exist" - - -def test_add_user_name(test_db_call): - """Test adding a new user name""" - user_name = "Test user" - user_id = test_db_call.add_user_name(user_name) - assert user_id > 0 - - # Check if the user name was actually added - sql = "SELECT id, name FROM user WHERE name = ?" - results = test_db_call.read_database(sql, (user_name,)) - assert len(results) == 1 - assert results[0][1] == user_name - - -def test_add_duplicate_user_name(test_db_call): - """Test adding a duplicate user name (should succeed at DB level without constraints)""" - user_name = "Duplicate User" - # Add first user - user_id1 = test_db_call.add_user_name(user_name) - assert user_id1 > 0 - - # Add second user with same name - user_id2 = test_db_call.add_user_name(user_name) - assert user_id2 > 0 - assert user_id2 != user_id1 # They should have different IDs - - -def test_get_users(test_db_call): - """Test getting all users""" - # Add some test users first - test_db_call.add_user_name("User One") - test_db_call.add_user_name("User Two") - - # Get users - users = test_db_call.get_users() - assert isinstance(users, list) - assert len(users) >= 2 - - # Check that users are returned as (id, name) tuples - for user in users: - assert len(user) == 2 - assert isinstance(user[0], int) # ID - assert isinstance(user[1], str) # Name - - -def test_add_category_success(test_db_call): - # Add a user first, since category requires a user_id - user_id = test_db_call.add_user_name("CategoryTestUser") - assert user_id != 0 - # Add a category with all fields - category_id = test_db_call.add_category(user_id, "Groceries", 200.0, None) - assert category_id != 0 - # Add a category with only required fields - category_id2 = test_db_call.add_category(user_id, "Utilities") - assert category_id2 != 0 - - -def test_add_category_null_budget_and_parent(test_db_call): - user_id = test_db_call.add_user_name("NullBudgetParentUser") - assert user_id != 0 - # Add a category with None for budget_value and parent_id - category_id = test_db_call.add_category(user_id, "NoBudgetOrParent", None, None) - assert category_id != 0 - - -def test_add_category_and_get_categories(test_db_call): - # Add a user - user_name = "CategoryTestUser" - user_id = test_db_call.add_user_name(user_name) - assert user_id != 0 - # Add a category for this user - category_name = "Groceries" - budget_value = 100.0 - parent_id = None - category_id = test_db_call.add_category(user_id, category_name, budget_value, parent_id) - assert category_id != 0 - # Add a subcategory - subcategory_name = "Supermarket" - subcategory_id = test_db_call.add_category(user_id, subcategory_name, 50.0, category_id) - assert subcategory_id != 0 - # Retrieve categories for this user - categories = test_db_call.get_categories(user_id) - assert isinstance(categories, list) - names = [cat[1] for cat in categories] - assert category_name in names - assert subcategory_name in names - - -def test_add_category_invalid_user(test_db_call): - # Try to add a category with a non-existent user_id - # (should still succeed in SQLite unless foreign keys are enforced) - category_id = test_db_call.add_category(99999, "InvalidUserCategory") - assert category_id == 0 - assert isinstance(category_id, int) diff --git a/src/tests/__init__.py b/tests/__init__.py similarity index 100% rename from src/tests/__init__.py rename to tests/__init__.py diff --git a/tests/test_category.py b/tests/test_category.py new file mode 100644 index 0000000..13717bb --- /dev/null +++ b/tests/test_category.py @@ -0,0 +1,5 @@ + + + +def test_add_category(): + pass diff --git a/src/tests/test_config.py b/tests/test_config.py similarity index 91% rename from src/tests/test_config.py rename to tests/test_config.py index 56d14bd..5e28183 100644 --- a/src/tests/test_config.py +++ b/tests/test_config.py @@ -2,7 +2,7 @@ import yaml -from src import config +from api.src import config def test_verity_config_default_secret_key(monkeypatch): @@ -23,13 +23,13 @@ def test_verity_config_secret_key_from_env(monkeypatch): def test_verity_config_database_name(monkeypatch): """Test that the database name is correctly set.""" testing_config = config.VerityConfig() - assert testing_config.DATABASE == "Verity.db" + assert testing_config.DATABASE == "api/data/verity.db" def test_verity_config_config_file_directory(monkeypatch): """Test that the config file directory is set.""" testing_config = config.VerityConfig() - assert testing_config.CONFIG_FILE_DIRECTORY == "config_files" + assert testing_config.CONFIG_FILE_DIRECTORY == "api/config_files" def test_verity_config_load_config_file(monkeypatch): diff --git a/src/tests/test_currency_handler.py b/tests/test_currency_handler.py similarity index 51% rename from src/tests/test_currency_handler.py rename to tests/test_currency_handler.py index 8111018..a5eaab5 100644 --- a/src/tests/test_currency_handler.py +++ b/tests/test_currency_handler.py @@ -1,21 +1,20 @@ - -from src import currency_handler +from api.src.currency_handler import CurrencyBrain def test_convert_to_universal_currency(): input = 87.82 - result = currency_handler.convert_to_universal_currency(input) + result = CurrencyBrain.convert_to_universal_currency(input) assert result == 8782 def test_whole_number(): input = 123 - result = currency_handler.convert_to_universal_currency(input) + result = CurrencyBrain.convert_to_universal_currency(input) assert result == input def test_negative_number(): input = -123.45 - result = currency_handler.convert_to_universal_currency(input) + result = CurrencyBrain.convert_to_universal_currency(input) print(result) assert result == -12345 diff --git a/tests/test_data_handler.py b/tests/test_data_handler.py new file mode 100644 index 0000000..cb232c5 --- /dev/null +++ b/tests/test_data_handler.py @@ -0,0 +1,197 @@ +import os +import tempfile + +import pytest + +from api.src.config import VerityConfig +from api.src.data_handler import Database + + +@pytest.fixture +def test_db_call(): + """Create a test database instance with a unique test database file""" + config = VerityConfig() + # Use a test-specific database file + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_db: + test_db_file = temp_db.name + + # Create a config with the test DB + config.DATABASE = test_db_file + db_call = Database(config) + db_call.build_database() + yield db_call + # Cleanup after tests + if os.path.exists(test_db_file): + os.remove(test_db_file) + + +def test_execute_success(test_db_call): + """Test executing valid SQL statements""" + sql_statement = """INSERT INTO user ( + name + ) + VALUES ('a_user_name') + """ + result = test_db_call.execute(sql_statement) + assert result is True + + +def test_execute_error(test_db_call): + """Test executing invalid SQL statements""" + sql_statement = "SELECT * FROM non_existent_table" + result = test_db_call.execute(sql_statement) + assert result is False + + +def test_execute_with_return_id(test_db_call): + """Test executing SQL with return_id flag""" + sql_statement = """INSERT INTO user ( + name + ) + VALUES ('return_id_test') + """ + result, new_id = test_db_call.execute(sql_statement, return_id=True) + assert result is True + assert new_id > 0 + + +def test_read(test_db_call): + """Test reading database data""" + # First insert a user + insert_sql = """INSERT INTO user ( + name + ) + VALUES ('read_test_user') + """ + _ = test_db_call.execute(insert_sql) + # Then read it back + sql_statement = "SELECT id, name FROM user WHERE name = 'read_test_user'" + results = test_db_call.read(sql_statement) + assert isinstance(results, list) + assert len(results) > 0 + assert results[0][1] == "read_test_user" + + +def test_read_with_params(test_db_call): + """Test reading database with parameterized query""" + # First insert a user + insert_sql = """INSERT INTO user ( + name + ) + VALUES ('param_test_user') + """ + _ = test_db_call.execute(insert_sql) # Then read it back with params + sql_statement = "SELECT id, name FROM user WHERE name = ?" + params = ("param_test_user",) + results = test_db_call.read(sql_statement, params) + assert isinstance(results, list) + assert len(results) > 0 + assert results[0][1] == "param_test_user" + + +def test_build_database(test_db_call): + """Test that the database builds correctly with all required tables""" + # Check if all tables have been created + config = VerityConfig() + for table in config.DATABASE_SCHEMA["tables"]: + table_name = table["table_name"] + sql = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';" + results = test_db_call.read(sql) + assert len(results) == 1, f"Table '{table}' should exist" + + +# TODO: Put the below tests in their respective test files (user and category) + +# def test_add_user_name(test_db_call): +# """Test adding a new user name""" +# user_name = "Test user" +# user_id = test_db_call.add_user_name(user_name) +# assert user_id > 0 +# +# # Check if the user name was actually added +# sql = "SELECT id, name FROM user WHERE name = ?" +# results = test_db_call.read(sql, (user_name,)) +# assert len(results) == 1 +# assert results[0][1] == user_name +# +# +# def test_add_duplicate_user_name(test_db_call): +# """Test adding a duplicate user name (should succeed at DB level without constraints)""" +# user_name = "Duplicate User" +# # Add first user +# user_id1 = test_db_call.add_user_name(user_name) +# assert user_id1 > 0 +# +# # Add second user with same name +# user_id2 = test_db_call.add_user_name(user_name) +# assert user_id2 > 0 +# assert user_id2 != user_id1 # They should have different IDs +# +# +# def test_get_users(test_db_call): +# """Test getting all users""" +# # Add some test users first +# test_db_call.add_user_name("User One") +# test_db_call.add_user_name("User Two") +# +# # Get users +# users = test_db_call.get_users() +# assert isinstance(users, list) +# assert len(users) >= 2 +# +# # Check that users are returned as (id, name) tuples +# for user in users: +# assert len(user) == 2 +# assert isinstance(user[0], int) # ID +# assert isinstance(user[1], str) # Name +# +# +# def test_add_category_success(test_db_call): +# # Add a user first, since category requires a user_id +# user_id = test_db_call.add_user_name("CategoryTestUser") +# assert user_id != 0 +# # Add a category with all fields +# category_id = test_db_call.add_category(user_id, "Groceries", 200.0, None) +# assert category_id != 0 +# # Add a category with only required fields +# category_id2 = test_db_call.add_category(user_id, "Utilities") +# assert category_id2 != 0 +# +# +# def test_add_category_null_budget_and_parent(test_db_call): +# user_id = test_db_call.add_user_name("NullBudgetParentUser") +# assert user_id != 0 +# # Add a category with None for budget_value and parent_id +# category_id = test_db_call.add_category(user_id, "NoBudgetOrParent", None, None) +# assert category_id != 0 +# +# +# def test_add_category_and_get_categories(test_db_call): +# # Add a user +# user_name = "CategoryTestUser" +# user_id = test_db_call.add_user_name(user_name) +# assert user_id != 0 +# # Add a category for this user +# category_name = "Groceries" +# budget_value = 100.0 +# parent_id = None +# category_id = test_db_call.add_category(user_id, category_name, budget_value, parent_id) +# assert category_id != 0 +# # Add a subcategory +# subcategory_name = "Supermarket" +# subcategory_id = test_db_call.add_category(user_id, subcategory_name, 50.0, category_id) +# assert subcategory_id != 0 +# # Retrieve categories for this user +# categories = test_db_call.get_categories(user_id) +# assert isinstance(categories, list) +# names = [cat[1] for cat in categories] +# assert category_name in names +# assert subcategory_name in names +# +# +# def test_add_category_invalid_user(test_db_call): +# # Try to add a category with a non-existent user_id +# # (should still succeed in SQLite unless foreign keys are enforced) +# category_id = test_db_call.add_category(99999, "InvalidUserCategory") +# assert category_id == 0 +# assert isinstance(category_id, int) diff --git a/tests/test_user.py b/tests/test_user.py new file mode 100644 index 0000000..778183c --- /dev/null +++ b/tests/test_user.py @@ -0,0 +1,5 @@ + + + +def test_add_user(): + pass diff --git a/src/verity.py b/verity.py similarity index 85% rename from src/verity.py rename to verity.py index 0bf866c..c475100 100644 --- a/src/verity.py +++ b/verity.py @@ -5,8 +5,8 @@ from flask import Flask -from config import VerityConfig -from data_handler import database +from api.src.config import VerityConfig +from api.src.data_handler import Database from front.home import home_bp @@ -27,8 +27,8 @@ def set_up_logging(config): logger.info("app starting") # database initialise - verity = database(verity_config) - logger.debug(verity_config.DEFAULT_DATA) + verity = Database(verity_config) + logger.debug(verity_config) verity.build_database() # app initialise