Skip to content

Commit

Permalink
Bugfix string escaping of SQL statements (#123)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Lukas Fehring <lukas.ferhing@stud.uni-hannover.de>
  • Loading branch information
tornede and Lukas Fehring authored Jun 16, 2023
1 parent 6d1159c commit d183496
Show file tree
Hide file tree
Showing 11 changed files with 2,225 additions and 2,379 deletions.
2,343 changes: 1,179 additions & 1,164 deletions docs/source/examples/example_conditional_grid.ipynb

Large diffs are not rendered by default.

1,370 changes: 688 additions & 682 deletions docs/source/examples/example_general_usage.ipynb

Large diffs are not rendered by default.

727 changes: 285 additions & 442 deletions docs/source/examples/example_logtables.ipynb

Large diffs are not rendered by default.

58 changes: 23 additions & 35 deletions py_experimenter/database_connector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import abc
import itertools
import logging
from configparser import ConfigParser
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

from py_experimenter import utils
Expand Down Expand Up @@ -47,9 +49,12 @@ def commit(self, connection) -> None:
except Exception as e:
raise DatabaseConnectionError(f'error \n{e}\n raised when committing to database.')

def execute(self, cursor, sql_statement):
def execute(self, cursor, sql_statement, values=None) -> None:
try:
cursor.execute(sql_statement)
if values is None:
cursor.execute(sql_statement)
else:
cursor.execute(sql_statement, values)
except Exception as e:
raise DatabaseConnectionError(f'error \n{e}\n raised when executing sql statement.')

Expand Down Expand Up @@ -122,9 +127,9 @@ def _create_table(self, cursor, columns: List[Tuple['str']], table_name: str, lo
raise CreatingTableError(f'Error when creating table: {err}')

def _get_create_table_query(self, columns: List[Tuple['str']], table_name: str, logtable: bool):
columns = ['%s %s DEFAULT NULL' % (self.escape_sql_chars(field)[0], datatype) for field, datatype in columns]
columns = ','.join(self.escape_sql_chars(*columns))
query = f"CREATE TABLE {self.escape_sql_chars(table_name)[0]} (ID INTEGER PRIMARY KEY {self.get_autoincrement()}"
columns = ['%s %s DEFAULT NULL' % (field, datatype) for field, datatype in columns]
columns = ','.join(columns)
query = f"CREATE TABLE {table_name} (ID INTEGER PRIMARY KEY {self.get_autoincrement()}"
if logtable:
query += f", experiment_id INTEGER, timestamp DATETIME, {columns}, FOREIGN KEY (experiment_id) REFERENCES {self.table_name}(ID) ON DELETE CASCADE"
else:
Expand Down Expand Up @@ -169,7 +174,7 @@ def fill_table(self, parameters=None, fixed_parameter_combinations=None) -> None

if rows:
logging.debug(f"Now adding {len(rows)} rows to database. {rows_skipped} rows were skipped.")
self._write_to_database(pd.DataFrame(rows, columns=column_names + ["status", "creation_date"]))
self._write_to_database(rows, column_names + ["status", "creation_date"])
logging.info(f"{len(rows)} rows successfully added to database. {rows_skipped} rows were skipped.")
else:
logging.info(f"No rows to add. All the {len(combinations)} experiments already exist.")
Expand Down Expand Up @@ -200,7 +205,7 @@ def _execute_queries(self, connection, cursor) -> Tuple[int, List, List]:
self.execute(cursor, f"SELECT id FROM {self.table_name} WHERE status = 'created' ORDER BY {order_by} LIMIT 1;")
experiment_id = self.fetchall(cursor)[0][0]
self.execute(
cursor, f"UPDATE {self.table_name} SET status = '{ExperimentStatus.RUNNING.value}', start_date = '{time}' WHERE id = {experiment_id};")
cursor, f"UPDATE {self.table_name} SET status = {self._prepared_statement_placeholder}, start_date = {self._prepared_statement_placeholder} WHERE id = {self._prepared_statement_placeholder};", (ExperimentStatus.RUNNING.value, time, experiment_id))
keyfields = ','.join(utils.get_keyfield_names(self.config))
self.execute(cursor, f"SELECT {keyfields} FROM {self.table_name} WHERE id = {experiment_id};")
values = self.fetchall(cursor)
Expand All @@ -212,49 +217,32 @@ def _execute_queries(self, connection, cursor) -> Tuple[int, List, List]:
def _pull_open_experiment(self) -> Tuple[int, List, List]:
pass

def _write_to_database(self, df) -> None:
keys = ", ".join(df.columns)
values = df.apply(lambda row: "('" + self.__class__._write_to_database_separator.join([str(value) for value in row]) + "')", axis=1)

stmt = f"INSERT INTO {self.table_name} ({keys}) VALUES {', '.join(values)}"
def _write_to_database(self, values: List, columns=List[str]) -> None:
values_prepared = ','.join([f"({', '.join([self._prepared_statement_placeholder] * len(columns))})"] * len(values))
values = list(map(lambda x: str(x) if x is not None else x, itertools.chain(*values)))
stmt = f"INSERT INTO {self.table_name} ({','.join(columns)}) VALUES {values_prepared}"

connection = self.connect()
cursor = self.cursor(connection)
self.execute(cursor, stmt)
self.execute(cursor, stmt, values)
self.commit(connection)
self.close_connection(connection)

def prepare_write_query(self, table_name: str, keys) -> str:
    """Build a parameterized single-row ``INSERT`` statement.

    One ``self._prepared_statement_placeholder`` is emitted per entry of
    ``keys``; the actual values are not part of the returned string and have
    to be supplied separately when the statement is executed.

    :param table_name: Table to insert into (interpolated verbatim).
    :param keys: Sized iterable of column names (interpolated verbatim).
    :return: The ``INSERT INTO ... VALUES (...)`` statement.
    """
    column_list = ', '.join(keys)
    placeholders = ','.join(self._prepared_statement_placeholder for _ in range(len(keys)))
    return f"INSERT INTO {table_name} ({column_list}) VALUES ({placeholders})"

def update_database(self, table_name: str, values: Dict[str, Union[str, int, object]], condition: str):
connection = self.connect()
cursor = self.cursor(connection)
cursor.execute(self._prepare_update_query(table_name, values.keys(), condition), list(values.values()))
self.execute(cursor, self._prepare_update_query(table_name, values.keys(), condition),
list(values.values()))
self.commit(connection)
self.close_connection(connection)

def _prepare_update_query(self, table_name: str, values: Dict[str, Union[str, int, object]], condition: str) -> str:
return (f"UPDATE {table_name} SET {', '.join(f'{key} = {self._prepared_statement_placeholder}' for key in values)}"
f" WHERE {condition}")

def not_executed_yet(self, where) -> bool:
    """Return whether any row matching ``where`` still has status ``'created'``.

    Any exception raised while connecting, querying or reading the result is
    logged and swallowed; in that case the value determined so far (``False``
    unless a ``'created'`` row was already seen) is returned.

    :param where: SQL condition placed after ``WHERE``.
    :return: ``True`` if at least one matching row has status ``'created'``.
    """
    not_executed = False
    connection = None

    try:
        connection = self.connect()
        cursor = self.cursor(connection)

        # NOTE(review): `where` and the table name are interpolated verbatim,
        # so callers must only pass trusted SQL fragments here.
        stmt = "SELECT status FROM %s WHERE %s" % (self.table_name, where)

        self.execute(cursor, stmt)
        # The whole result set is consumed on purpose (no early exit), so no
        # unread rows are left on the cursor when the connection is closed.
        for result in cursor:
            if result[0] == 'created':
                not_executed = True

    except Exception as err:
        logging.error(err)
        # Previously the connection was only closed on success and leaked
        # whenever the query failed; release it on the error path as well.
        if connection is not None:
            try:
                connection.close()
            except Exception as close_err:
                logging.error(close_err)
    else:
        connection.close()
    return not_executed

def reset_experiments(self, *states: str) -> None:
def get_dict_for_keyfields_and_rows(keyfields: List[str], rows: List[List[str]]) -> List[dict]:
return [{key: value for key, value in zip(keyfields, row)} for row in rows]
Expand Down Expand Up @@ -311,7 +299,7 @@ def execute_queries(self, queries: List[str]):
connection = self.connect()
cursor = self.cursor(connection)
for query in queries:
self.execute(cursor, query)
self.execute(cursor, query[0], tuple(query[1]))
self.commit(connection)
self.close_connection(connection)

Expand Down
4 changes: 2 additions & 2 deletions py_experimenter/database_connector_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def _table_has_correct_structure(self, cursor, typed_fields) -> List[str]:
config_columns = [k[0] for k in typed_fields]
return set(columns) == set(config_columns)

def _get_existing_rows(self, column_names):
def _get_existing_rows(self, column_names: List[str]):
def _remove_string_markers(row):
return row.replace("'", "")
connection = self.connect()
cursor = self.cursor(connection)
self.execute(cursor, f"SELECT {', '.join(column_names)} FROM {self.table_name}")
self.execute(cursor, f"SELECT {','.join(column_names)} FROM {self.table_name}")
existing_rows = list(map(np.array2string, np.array(self.fetchall(cursor))))
existing_rows = [' '.join(_remove_string_markers(row).split()) for row in existing_rows]
self.close_connection(connection)
Expand Down
16 changes: 3 additions & 13 deletions py_experimenter/database_connector_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@


class DatabaseConnectorMYSQL(DatabaseConnector):
_write_to_database_separator = "', '"
_prepared_statement_placeholder = '%s'

def __init__(self, experiment_configuration_file_path: ConfigParser, database_credential_file_path):
Expand Down Expand Up @@ -75,7 +74,8 @@ def get_autoincrement():

def _table_has_correct_structure(self, cursor, typed_fields):
self.execute(cursor,
f"SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{self.table_name}' AND TABLE_SCHEMA = '{self.database_name}'")
f"SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = {self._prepared_statement_placeholder} AND TABLE_SCHEMA = {self._prepared_statement_placeholder}",
(self.table_name, self.database_name))

columns = self._exclude_fixed_columns([k[0] for k in self.fetchall(cursor)])
config_columns = [k[0] for k in typed_fields]
Expand All @@ -93,16 +93,6 @@ def _pull_open_experiment(self) -> Tuple[int, List, List]:
self.close_connection(connection)

return experiment_id, description, values

@staticmethod
def escape_sql_chars(*args):
escaped_args = []
for arg in args:
if isinstance(arg, str):
escaped_args.append(arg.replace("'", "''").replace('"', '""').replace('`', '``'))
else:
escaped_args.append(arg)
return escaped_args

def _get_existing_rows(self, column_names):
def _remove_double_whitespaces(existing_rows):
Expand All @@ -113,7 +103,7 @@ def _remove_string_markers(existing_rows):

connection = self.connect()
cursor = self.cursor(connection)
self.execute(cursor, f"SELECT {', '.join(column_names)} FROM {self.table_name}")
self.execute(cursor, f"SELECT {','.join(column_names)} FROM {self.table_name}")
existing_rows = list(map(np.array2string, np.array(self.fetchall(cursor))))
existing_rows = _remove_string_markers(existing_rows)
existing_rows = _remove_double_whitespaces(existing_rows)
Expand Down
7 changes: 3 additions & 4 deletions py_experimenter/result_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def process_results(self, results: dict) -> None:
want to write results to the database.
:param results: Dictionary with result field name and result value pairs.
"""
time = utils.get_timestamp_representation()
if not self._valid_result_fields(list(results.keys())):
invalid_result_keys = set(list(results.keys())) - set(self._result_fields)
raise InvalidResultFieldError(f"Invalid result keys: {invalid_result_keys}")
Expand All @@ -71,9 +70,9 @@ def process_logs(self, logs: Dict[str, Dict[str, str]]) -> None:
for logtable_identifier, log_entries in logs.items():
logtable_name = f'{self._table_name}__{logtable_identifier}'
log_entries['experiment_id'] = str(self._experiment_id)
log_entries['timestamp'] = f"'{time}'"
queries.append(
f"INSERT INTO {logtable_name} ({', '.join(log_entries.keys())}) VALUES ({', '.join(map(lambda x: str(x), log_entries.values()))})")
log_entries['timestamp'] = f"{time}"
stmt = self._dbconnector.prepare_write_query(logtable_name, log_entries.keys())
queries.append((stmt, log_entries.values()))
self._dbconnector.execute_queries(queries)

def _change_status(self, status: str):
Expand Down
2 changes: 1 addition & 1 deletion py_experimenter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def load_config(path):
raise NoConfigFileError(f'Configuration file missing! Please add file: {path}')

return config


def get_keyfield_data(config):
keyfields = get_keyfields(config)
Expand Down
46 changes: 23 additions & 23 deletions test/test_database_connector.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import datetime
import pandas as pd
import numpy as np
import os

import mock
import numpy as np
import pandas as pd
import pytest
from mock import patch

Expand Down Expand Up @@ -74,37 +74,37 @@ def test_create_table_if_not_exists(create_database_if_not_existing_mock, test_c
[],
['value', 'exponent', 'status', 'creation_date'],
[
[1, 3, ExperimentStatus.CREATED.value],
[1, 4, ExperimentStatus.CREATED.value],
[2, 3, ExperimentStatus.CREATED.value],
[2, 4, ExperimentStatus.CREATED.value]
[1, 3, str(ExperimentStatus.CREATED.value)],
[1, 4, str(ExperimentStatus.CREATED.value)],
[2, 3, str(ExperimentStatus.CREATED.value)],
[2, 4, str(ExperimentStatus.CREATED.value)]
]),
(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file.cfg'),
{},
[{'value': 1, 'exponent': 3}, {'value': 1, 'exponent': 4}],
['value', 'exponent', 'status', 'creation_date'],
[
[1, 3, ExperimentStatus.CREATED.value],
[1, 4, ExperimentStatus.CREATED.value],
[1, 3, str(ExperimentStatus.CREATED.value)],
[1, 4, str(ExperimentStatus.CREATED.value)],
]),
(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file_3_parameters.cfg'),
{'value': [1, 2], },
[{'exponent': 3, 'other_value': 5}],
['value', 'exponent', 'other_value', 'status', 'creation_date'],
[
[1, 3, 5, ExperimentStatus.CREATED.value],
[2, 3, 5, ExperimentStatus.CREATED.value],
[1, 3, 5, str(ExperimentStatus.CREATED.value)],
[2, 3, 5, str(ExperimentStatus.CREATED.value)],
]
),
(os.path.join('test', 'test_config_files', 'load_config_test_file', 'my_sql_test_file_3_parameters.cfg'),
{'value': [1, 2], 'exponent': [3, 4], },
[{'other_value': 5}],
['value', 'exponent', 'other_value', 'status', 'creation_date'],
[
[1, 3, 5, ExperimentStatus.CREATED.value],
[1, 4, 5, ExperimentStatus.CREATED.value],
[2, 3, 5, ExperimentStatus.CREATED.value],
[2, 4, 5, ExperimentStatus.CREATED.value],
[1, 3, 5, str(ExperimentStatus.CREATED.value)],
[1, 4, 5, str(ExperimentStatus.CREATED.value)],
[2, 3, 5, str(ExperimentStatus.CREATED.value)],
[2, 4, 5, str(ExperimentStatus.CREATED.value)],
]
),
]
Expand Down Expand Up @@ -133,15 +133,15 @@ def test_fill_table(
experiment_configuration,
database_credential_file_path=os.path.join('test', 'test_config_files', 'load_config_test_file', 'mysql_fake_credentials.cfg'))
database_connector.fill_table(parameters, fixed_parameter_combination)
args_of_first_call = write_to_database_mock.call_args_list[0][0]

assert type(args_of_first_call[0]) == pd.DataFrame
df = args_of_first_call[0]
assert len(df) == len(write_to_database_values)
assert write_to_database_keys == list(df.columns)
for expected_row, row in zip(write_to_database_values, df.values):
assert expected_row == row[:-1].tolist()
datetime_from_string_argument = datetime.datetime.strptime(row[-1], '%Y-%m-%d %H:%M:%S')
values, columns = write_to_database_mock.call_args_list[0][0]

assert isinstance(values, list)
assert len(values) == len(write_to_database_values)
assert write_to_database_keys == columns
for expected_entry, entry in zip(write_to_database_values, values):
assert isinstance(entry, list)
assert expected_entry == entry[:-1]
datetime_from_string_argument = datetime.datetime.strptime(entry[-1], '%Y-%m-%d %H:%M:%S')
assert datetime_from_string_argument.day == datetime.datetime.now().day
assert datetime_from_string_argument.hour == datetime.datetime.now().hour
assert datetime_from_string_argument.minute - datetime.datetime.now().minute <= 2
Expand Down
16 changes: 9 additions & 7 deletions test/test_logtables/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,19 @@ def test_tables_created(execute_mock, close_connection_mock, fetchall_mock, curs
@freeze_time("2012-01-14 03:21:34")
@patch('py_experimenter.result_processor.DatabaseConnectorMYSQL')
def test_logtable_insertion(database_connector_mock):
fixed_time = '2012-01-14 03:21:34'
config = ConfigParser()
config.read(os.path.join('test', 'test_logtables', 'mysql_logtables.cfg'))
result_processor = ResultProcessor(config, None, None, None, 0)
result_processor._table_name = 'some_table_name'
result_processor.process_logs({'test_table_0': {'test0': 'test', 'test1': 'test'},
'test_table_1': {'test0': 'test'}})
result_processor._table_name = 'table_name'
table_0_logs = {'test0': 'test', 'test1': 'test'}
table_1_logs = {'test0': 'test'}
result_processor.process_logs({'test_table_0': table_0_logs,
'test_table_1': table_1_logs})
result_processor._dbconnector.prepare_write_query.assert_any_call(
'table_name__test_table_1', table_1_logs.keys())
result_processor._dbconnector.prepare_write_query.assert_any_call(
'table_name__test_table_0', table_0_logs.keys())
result_processor._dbconnector.execute_queries.assert_called()
result_processor._dbconnector.execute_queries.assert_called_with(
[f'INSERT INTO some_table_name__test_table_0 (test0, test1, experiment_id, timestamp) VALUES (test, test, 0, \'{fixed_time}\')',
f'INSERT INTO some_table_name__test_table_1 (test0, experiment_id, timestamp) VALUES (test, 0, \'{fixed_time}\')'])


@patch('py_experimenter.experimenter.DatabaseConnectorMYSQL._create_database_if_not_existing')
Expand Down
15 changes: 9 additions & 6 deletions test/test_logtables/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,23 @@ def test_tables_created(execute_mock, close_connection_mock, fetchall_mock, curs
assert execute_mock.mock_calls[1][1][1] == ('CREATE TABLE test_sqlite_logtables__test_sqlite_log (ID INTEGER PRIMARY KEY AUTOINCREMENT, experiment_id INTEGER,'
' timestamp DATETIME, test int DEFAULT NULL, FOREIGN KEY (experiment_id) REFERENCES test_sqlite_logtables(ID) ON DELETE CASCADE);')


@freeze_time("2012-01-14 03:21:34")
@patch('py_experimenter.result_processor.DatabaseConnectorLITE')
def test_logtable_insertion(database_connector_mock):
config = ConfigParser()
config.read(os.path.join('test', 'test_logtables', 'sqlite_logtables.cfg'))
result_processor = ResultProcessor(config, None, None, None, 0)
result_processor._table_name = 'table_name'
result_processor.process_logs({'test_table_0': {'test0': 'test', 'test1': 'test'},
'test_table_1': {'test0': 'test'}})
table_0_logs = {'test0': 'test', 'test1': 'test'}
table_1_logs = {'test0': 'test'}
result_processor.process_logs({'test_table_0': table_0_logs,
'test_table_1': table_1_logs})
# result_processor._dbconnector.prepare_write_query.
result_processor._dbconnector.prepare_write_query.assert_any_call(
'table_name__test_table_1', table_1_logs.keys())
result_processor._dbconnector.prepare_write_query.assert_any_call(
'table_name__test_table_0', table_0_logs.keys())
result_processor._dbconnector.execute_queries.assert_called()
result_processor._dbconnector.execute_queries.assert_called_with(
['INSERT INTO table_name__test_table_0 (test0, test1, experiment_id, timestamp) VALUES (test, test, 0, \'2012-01-14 03:21:34\')',
'INSERT INTO table_name__test_table_1 (test0, experiment_id, timestamp) VALUES (test, 0, \'2012-01-14 03:21:34\')'])


@patch('py_experimenter.experimenter.DatabaseConnectorLITE._test_connection')
Expand Down

0 comments on commit d183496

Please sign in to comment.