diff --git a/README.rst b/README.rst index c4130e36..5f044aab 100644 --- a/README.rst +++ b/README.rst @@ -104,9 +104,10 @@ Client specific loads the database each test .. code-block:: python + from pathlib import Path postgresql_my_with_schema = factories.postgresql( 'postgresql_my_proc', - load=["schemafile.sql", "otherschema.sql", "import.path.to.function", "import.path.to:otherfunction", load_this] + load=[Path("schemafile.sql"), Path("otherschema.sql"), "import.path.to.function", "import.path.to:otherfunction", load_this] ) .. warning:: @@ -115,12 +116,13 @@ Client specific loads the database each test The process fixture performs the load once per test session, and loads the data into the template database. -Client fixture then creates test database out of the template database each test, which significantly speeds up the tests. +Client fixture then creates test database out of the template database each test, which significantly **speeds up the tests**. .. code-block:: python + from pathlib import Path postgresql_my_proc = factories.postgresql_proc( - load=["schemafile.sql", "otherschema.sql", "import.path.to.function", "import.path.to:otherfunction", load_this] + load=[Path("schemafile.sql"), Path("otherschema.sql"), "import.path.to.function", "import.path.to:otherfunction", load_this] ) diff --git a/newsfragments/638.feature.rst b/newsfragments/638.feature.rst new file mode 100644 index 00000000..e30e1537 --- /dev/null +++ b/newsfragments/638.feature.rst @@ -0,0 +1,3 @@ +Now all sql files used to initialise database for tests, has to be passed as pathlib.Path instance. + +This helps the DatabaseJanitor choose correct behaviour based on parameter. diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index 0175767b..dfebb1cf 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with pytest-dbfixtures. If not, see . """Fixture factory for postgresql client.""" +from pathlib import Path from typing import Callable, Iterator, List, Optional, Union import psycopg @@ -31,7 +32,7 @@ def postgresql( process_fixture_name: str, dbname: Optional[str] = None, - load: Optional[List[Union[Callable, str]]] = None, + load: Optional[List[Union[Callable, str, Path]]] = None, isolation_level: "Optional[psycopg.IsolationLevel]" = None, ) -> Callable[[FixtureRequest], Iterator[Connection]]: """Return connection fixture factory for PostgreSQL. diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py index e61bb31d..638868a0 100644 --- a/pytest_postgresql/factories/noprocess.py +++ b/pytest_postgresql/factories/noprocess.py @@ -17,6 +17,7 @@ # along with pytest-dbfixtures. If not, see . """Fixture factory for existing postgresql server.""" import os +from pathlib import Path from typing import Callable, Iterator, List, Optional, Union import pytest @@ -42,7 +43,7 @@ def postgresql_noproc( password: Optional[str] = None, dbname: Optional[str] = None, options: str = "", - load: Optional[List[Union[Callable, str]]] = None, + load: Optional[List[Union[Callable, str, Path]]] = None, ) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]: """Postgresql noprocess factory. diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 095004fd..28fd0b54 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -19,6 +19,7 @@ import os.path import platform import subprocess +from pathlib import Path from typing import Callable, Iterator, List, Optional, Set, Tuple, Union import pytest @@ -54,7 +55,7 @@ def postgresql_proc( startparams: Optional[str] = None, unixsocketdir: Optional[str] = None, postgres_options: Optional[str] = None, - load: Optional[List[Union[Callable, str]]] = None, + load: Optional[List[Union[Callable, str, Path]]] = None, ) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]: """Postgresql process factory. diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index 3d838836..e9731278 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -2,7 +2,7 @@ import re from contextlib import contextmanager -from functools import partial +from pathlib import Path from types import TracebackType from typing import Callable, Iterator, Optional, Type, TypeVar, Union @@ -10,8 +10,8 @@ from packaging.version import parse from psycopg import Connection, Cursor +from pytest_postgresql.loader import build_loader from pytest_postgresql.retry import retry -from pytest_postgresql.sql import loader Version = type(parse("1")) @@ -104,23 +104,17 @@ def _terminate_connection(cur: Cursor, dbname: str) -> None: (dbname,), ) - def load(self, load: Union[Callable, str]) -> None: + def load(self, load: Union[Callable, str, Path]) -> None: """Load data into a database. - Either runs a passed loader if it's callback, - or runs predefined loader if it's sql file. + Expects: + + * a Path to sql file, that'll be loaded + * an import path to import callable + * a callable that expects: host, port, user, dbname and password arguments. + """ - if isinstance(load, str): - if "/" in load: - _loader: Callable = partial(loader, load) - else: - loader_parts = re.split("[.:]", load, 2) - import_path = ".".join(loader_parts[:-1]) - loader_name = loader_parts[-1] - _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name]) - _loader = getattr(_temp_import, loader_name) - else: - _loader = load + _loader = build_loader(load) _loader( host=self.host, port=self.port, diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py new file mode 100644 index 00000000..9358d4dc --- /dev/null +++ b/pytest_postgresql/loader.py @@ -0,0 +1,32 @@ +"""Loader helper functions.""" + +import re +from functools import partial +from pathlib import Path +from typing import Any, Callable, Union + +import psycopg + + +def build_loader(load: Union[Callable, str, Path]) -> Callable: + """Build a loader callable.""" + if isinstance(load, Path): + return partial(sql, load) + elif isinstance(load, str): + loader_parts = re.split("[.:]", load, 2) + import_path = ".".join(loader_parts[:-1]) + loader_name = loader_parts[-1] + _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name]) + _loader: Callable = getattr(_temp_import, loader_name) + return _loader + else: + return load + + +def sql(sql_filename: Path, **kwargs: Any) -> None: + """Database loader for sql files.""" + db_connection = psycopg.connect(**kwargs) + with open(sql_filename, "r") as _fd: + with db_connection.cursor() as cur: + cur.execute(_fd.read()) + db_connection.commit() diff --git a/pytest_postgresql/sql.py b/pytest_postgresql/sql.py deleted file mode 100644 index 24c12a9b..00000000 --- a/pytest_postgresql/sql.py +++ /dev/null @@ -1,14 +0,0 @@ -"""SQL Loader function.""" - -from typing import Any - -import psycopg - - -def loader(sql_filename: str, **kwargs: Any) -> None: - """Database loader for sql files.""" - db_connection = psycopg.connect(**kwargs) - with open(sql_filename, "r") as _fd: - with db_connection.cursor() as cur: - cur.execute(_fd.read()) - db_connection.commit() diff --git a/tests/conftest.py b/tests/conftest.py index c3140d48..a75e526e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ """Tests main conftest file.""" import os +from pathlib import Path from pytest_postgresql import factories from pytest_postgresql.plugin import * # noqa: F403,F401 @@ -10,6 +11,8 @@ TEST_SQL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/test_sql/" +TEST_SQL_FILE = Path(TEST_SQL_DIR + "test.sql") +TEST_SQL_FILE2 = Path(TEST_SQL_DIR + "test2.sql") postgresql_proc2 = factories.postgresql_proc(port=None) postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db") @@ -17,11 +20,11 @@ "postgresql_proc2", dbname="test-load-db", load=[ - TEST_SQL_DIR + "test.sql", + TEST_SQL_FILE, ], ) postgresql_load_2 = factories.postgresql( "postgresql_proc2", dbname="test-load-moredb", - load=[TEST_SQL_DIR + "test.sql", TEST_SQL_DIR + "test2.sql"], + load=[TEST_SQL_FILE, TEST_SQL_FILE2], ) diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py index c621b8bc..79b8c1eb 100644 --- a/tests/docker/test_noproc_docker.py +++ b/tests/docker/test_noproc_docker.py @@ -1,5 +1,7 @@ """Noproc fixture tests.""" +import pathlib + import pytest from psycopg import Connection @@ -9,7 +11,7 @@ postgresql_my_proc = pytest_postgresql.factories.noprocess.postgresql_noproc() postgres_with_schema = pytest_postgresql.factories.client.postgresql( - "postgresql_my_proc", dbname="test", load=["tests/test_sql/eidastats.sql"] + "postgresql_my_proc", dbname="test", load=[pathlib.Path("tests/test_sql/eidastats.sql")] ) postgresql_my_proc_template = pytest_postgresql.factories.noprocess.postgresql_noproc( diff --git a/tests/examples/test.sql b/tests/examples/test.sql new file mode 100644 index 00000000..edd768c9 --- /dev/null +++ b/tests/examples/test.sql @@ -0,0 +1,2 @@ +CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar); +INSERT INTO test VALUES(1, 2, 'c'); diff --git a/tests/examples/test_load.py b/tests/examples/test_load.py new file mode 100644 index 00000000..b78a5280 --- /dev/null +++ b/tests/examples/test_load.py @@ -0,0 +1,12 @@ +"""All tests for pytest-postgresql.""" + +from psycopg import Connection + + +def test_postgres_load_one_file(postgresql: Connection) -> None: + """Check postgresql fixture can load one file.""" + cur = postgresql.cursor() + cur.execute("SELECT * FROM test;") + results = cur.fetchall() + assert len(results) == 1 + cur.close() diff --git a/tests/test_loader.py b/tests/test_loader.py new file mode 100644 index 00000000..c03f8a55 --- /dev/null +++ b/tests/test_loader.py @@ -0,0 +1,20 @@ +"""Tests for the `build_loader` function.""" + +from pathlib import Path + +from pytest_postgresql.loader import build_loader, sql +from tests.loader import load_database + + +def test_loader_callables() -> None: + """Test handling callables in build_loader.""" + assert load_database == build_loader(load_database) + assert load_database == build_loader("tests.loader:load_database") + + +def test_loader_sql() -> None: + """Test returning partial running sql for the sql file path.""" + sql_path = Path("test_sql/eidastats.sql") + loader_func = build_loader(sql_path) + assert loader_func.args == (sql_path,) # type: ignore + assert loader_func.func == sql # type: ignore diff --git a/tests/test_postgres_options_plugin.py b/tests/test_postgres_options_plugin.py index fbc000a5..a1a1bec1 100644 --- a/tests/test_postgres_options_plugin.py +++ b/tests/test_postgres_options_plugin.py @@ -33,3 +33,13 @@ def test_postgres_options_config_in_ini(pointed_pytester: Pytester) -> None: pointed_pytester.makefile(".ini", pytest="[pytest]\npostgresql_postgres_options = -N 16\n") ret = pointed_pytester.runpytest("test_postgres_options.py") ret.assert_outcomes(passed=1) + + +def test_postgres_loader_in_cli(pointed_pytester: Pytester) -> None: + """Check that command line arguments are honored.""" + pointed_pytester.copy_example("test_load.py") + pointed_pytester.copy_example("test.sql") + ret = pointed_pytester.runpytest( + "--postgresql-postgres-options", "--postgresql-load test.sql", "test_load.py" + ) + ret.assert_outcomes(passed=1)