diff --git a/README.rst b/README.rst
index c4130e36..5f044aab 100644
--- a/README.rst
+++ b/README.rst
@@ -104,9 +104,10 @@ Client specific loads the database each test
.. code-block:: python
+ from pathlib import Path
postgresql_my_with_schema = factories.postgresql(
'postgresql_my_proc',
- load=["schemafile.sql", "otherschema.sql", "import.path.to.function", "import.path.to:otherfunction", load_this]
+ load=[Path("schemafile.sql"), Path("otherschema.sql"), "import.path.to.function", "import.path.to:otherfunction", load_this]
)
.. warning::
@@ -115,12 +116,13 @@ Client specific loads the database each test
The process fixture performs the load once per test session, and loads the data into the template database.
-Client fixture then creates test database out of the template database each test, which significantly speeds up the tests.
+Client fixture then creates test database out of the template database each test, which significantly **speeds up the tests**.
.. code-block:: python
+ from pathlib import Path
postgresql_my_proc = factories.postgresql_proc(
- load=["schemafile.sql", "otherschema.sql", "import.path.to.function", "import.path.to:otherfunction", load_this]
+ load=[Path("schemafile.sql"), Path("otherschema.sql"), "import.path.to.function", "import.path.to:otherfunction", load_this]
)
diff --git a/newsfragments/638.feature.rst b/newsfragments/638.feature.rst
new file mode 100644
index 00000000..e30e1537
--- /dev/null
+++ b/newsfragments/638.feature.rst
@@ -0,0 +1,3 @@
+Now all sql files used to initialise database for tests, has to be passed as pathlib.Path instance.
+
+This helps the DatabaseJanitor choose correct behaviour based on parameter.
diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py
index 0175767b..dfebb1cf 100644
--- a/pytest_postgresql/factories/client.py
+++ b/pytest_postgresql/factories/client.py
@@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with pytest-dbfixtures. If not, see .
"""Fixture factory for postgresql client."""
+from pathlib import Path
from typing import Callable, Iterator, List, Optional, Union
import psycopg
@@ -31,7 +32,7 @@
def postgresql(
process_fixture_name: str,
dbname: Optional[str] = None,
- load: Optional[List[Union[Callable, str]]] = None,
+ load: Optional[List[Union[Callable, str, Path]]] = None,
isolation_level: "Optional[psycopg.IsolationLevel]" = None,
) -> Callable[[FixtureRequest], Iterator[Connection]]:
"""Return connection fixture factory for PostgreSQL.
diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py
index e61bb31d..638868a0 100644
--- a/pytest_postgresql/factories/noprocess.py
+++ b/pytest_postgresql/factories/noprocess.py
@@ -17,6 +17,7 @@
# along with pytest-dbfixtures. If not, see .
"""Fixture factory for existing postgresql server."""
import os
+from pathlib import Path
from typing import Callable, Iterator, List, Optional, Union
import pytest
@@ -42,7 +43,7 @@ def postgresql_noproc(
password: Optional[str] = None,
dbname: Optional[str] = None,
options: str = "",
- load: Optional[List[Union[Callable, str]]] = None,
+ load: Optional[List[Union[Callable, str, Path]]] = None,
) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]:
"""Postgresql noprocess factory.
diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py
index 095004fd..28fd0b54 100644
--- a/pytest_postgresql/factories/process.py
+++ b/pytest_postgresql/factories/process.py
@@ -19,6 +19,7 @@
import os.path
import platform
import subprocess
+from pathlib import Path
from typing import Callable, Iterator, List, Optional, Set, Tuple, Union
import pytest
@@ -54,7 +55,7 @@ def postgresql_proc(
startparams: Optional[str] = None,
unixsocketdir: Optional[str] = None,
postgres_options: Optional[str] = None,
- load: Optional[List[Union[Callable, str]]] = None,
+ load: Optional[List[Union[Callable, str, Path]]] = None,
) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]:
"""Postgresql process factory.
diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py
index 3d838836..e9731278 100644
--- a/pytest_postgresql/janitor.py
+++ b/pytest_postgresql/janitor.py
@@ -2,7 +2,7 @@
import re
from contextlib import contextmanager
-from functools import partial
+from pathlib import Path
from types import TracebackType
from typing import Callable, Iterator, Optional, Type, TypeVar, Union
@@ -10,8 +10,8 @@
from packaging.version import parse
from psycopg import Connection, Cursor
+from pytest_postgresql.loader import build_loader
from pytest_postgresql.retry import retry
-from pytest_postgresql.sql import loader
Version = type(parse("1"))
@@ -104,23 +104,17 @@ def _terminate_connection(cur: Cursor, dbname: str) -> None:
(dbname,),
)
- def load(self, load: Union[Callable, str]) -> None:
+ def load(self, load: Union[Callable, str, Path]) -> None:
"""Load data into a database.
- Either runs a passed loader if it's callback,
- or runs predefined loader if it's sql file.
+ Expects:
+
+ * a Path to sql file, that'll be loaded
+ * an import path to import callable
+ * a callable that expects: host, port, user, dbname and password arguments.
+
"""
- if isinstance(load, str):
- if "/" in load:
- _loader: Callable = partial(loader, load)
- else:
- loader_parts = re.split("[.:]", load, 2)
- import_path = ".".join(loader_parts[:-1])
- loader_name = loader_parts[-1]
- _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name])
- _loader = getattr(_temp_import, loader_name)
- else:
- _loader = load
+ _loader = build_loader(load)
_loader(
host=self.host,
port=self.port,
diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py
new file mode 100644
index 00000000..9358d4dc
--- /dev/null
+++ b/pytest_postgresql/loader.py
@@ -0,0 +1,32 @@
+"""Loader helper functions."""
+
+import re
+from functools import partial
+from pathlib import Path
+from typing import Any, Callable, Union
+
+import psycopg
+
+
+def build_loader(load: Union[Callable, str, Path]) -> Callable:
+ """Build a loader callable."""
+ if isinstance(load, Path):
+ return partial(sql, load)
+ elif isinstance(load, str):
+ loader_parts = re.split("[.:]", load, 2)
+ import_path = ".".join(loader_parts[:-1])
+ loader_name = loader_parts[-1]
+ _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name])
+ _loader: Callable = getattr(_temp_import, loader_name)
+ return _loader
+ else:
+ return load
+
+
+def sql(sql_filename: Path, **kwargs: Any) -> None:
+ """Database loader for sql files."""
+ db_connection = psycopg.connect(**kwargs)
+ with open(sql_filename, "r") as _fd:
+ with db_connection.cursor() as cur:
+ cur.execute(_fd.read())
+ db_connection.commit()
diff --git a/pytest_postgresql/sql.py b/pytest_postgresql/sql.py
deleted file mode 100644
index 24c12a9b..00000000
--- a/pytest_postgresql/sql.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""SQL Loader function."""
-
-from typing import Any
-
-import psycopg
-
-
-def loader(sql_filename: str, **kwargs: Any) -> None:
- """Database loader for sql files."""
- db_connection = psycopg.connect(**kwargs)
- with open(sql_filename, "r") as _fd:
- with db_connection.cursor() as cur:
- cur.execute(_fd.read())
- db_connection.commit()
diff --git a/tests/conftest.py b/tests/conftest.py
index c3140d48..a75e526e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,7 @@
"""Tests main conftest file."""
import os
+from pathlib import Path
from pytest_postgresql import factories
from pytest_postgresql.plugin import * # noqa: F403,F401
@@ -10,6 +11,8 @@
TEST_SQL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/test_sql/"
+TEST_SQL_FILE = Path(TEST_SQL_DIR + "test.sql")
+TEST_SQL_FILE2 = Path(TEST_SQL_DIR + "test2.sql")
postgresql_proc2 = factories.postgresql_proc(port=None)
postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db")
@@ -17,11 +20,11 @@
"postgresql_proc2",
dbname="test-load-db",
load=[
- TEST_SQL_DIR + "test.sql",
+ TEST_SQL_FILE,
],
)
postgresql_load_2 = factories.postgresql(
"postgresql_proc2",
dbname="test-load-moredb",
- load=[TEST_SQL_DIR + "test.sql", TEST_SQL_DIR + "test2.sql"],
+ load=[TEST_SQL_FILE, TEST_SQL_FILE2],
)
diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py
index c621b8bc..79b8c1eb 100644
--- a/tests/docker/test_noproc_docker.py
+++ b/tests/docker/test_noproc_docker.py
@@ -1,5 +1,7 @@
"""Noproc fixture tests."""
+import pathlib
+
import pytest
from psycopg import Connection
@@ -9,7 +11,7 @@
postgresql_my_proc = pytest_postgresql.factories.noprocess.postgresql_noproc()
postgres_with_schema = pytest_postgresql.factories.client.postgresql(
- "postgresql_my_proc", dbname="test", load=["tests/test_sql/eidastats.sql"]
+ "postgresql_my_proc", dbname="test", load=[pathlib.Path("tests/test_sql/eidastats.sql")]
)
postgresql_my_proc_template = pytest_postgresql.factories.noprocess.postgresql_noproc(
diff --git a/tests/examples/test.sql b/tests/examples/test.sql
new file mode 100644
index 00000000..edd768c9
--- /dev/null
+++ b/tests/examples/test.sql
@@ -0,0 +1,2 @@
+CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);
+INSERT INTO test VALUES(1, 2, 'c');
diff --git a/tests/examples/test_load.py b/tests/examples/test_load.py
new file mode 100644
index 00000000..b78a5280
--- /dev/null
+++ b/tests/examples/test_load.py
@@ -0,0 +1,12 @@
+"""All tests for pytest-postgresql."""
+
+from psycopg import Connection
+
+
+def test_postgres_load_one_file(postgresql: Connection) -> None:
+ """Check postgresql fixture can load one file."""
+ cur = postgresql.cursor()
+ cur.execute("SELECT * FROM test;")
+ results = cur.fetchall()
+ assert len(results) == 1
+ cur.close()
diff --git a/tests/test_loader.py b/tests/test_loader.py
new file mode 100644
index 00000000..c03f8a55
--- /dev/null
+++ b/tests/test_loader.py
@@ -0,0 +1,20 @@
+"""Tests for the `build_loader` function."""
+
+from pathlib import Path
+
+from pytest_postgresql.loader import build_loader, sql
+from tests.loader import load_database
+
+
+def test_loader_callables() -> None:
+ """Test handling callables in build_loader."""
+ assert load_database == build_loader(load_database)
+ assert load_database == build_loader("tests.loader:load_database")
+
+
+def test_loader_sql() -> None:
+ """Test returning partial running sql for the sql file path."""
+ sql_path = Path("test_sql/eidastats.sql")
+ loader_func = build_loader(sql_path)
+ assert loader_func.args == (sql_path,) # type: ignore
+ assert loader_func.func == sql # type: ignore
diff --git a/tests/test_postgres_options_plugin.py b/tests/test_postgres_options_plugin.py
index fbc000a5..a1a1bec1 100644
--- a/tests/test_postgres_options_plugin.py
+++ b/tests/test_postgres_options_plugin.py
@@ -33,3 +33,13 @@ def test_postgres_options_config_in_ini(pointed_pytester: Pytester) -> None:
pointed_pytester.makefile(".ini", pytest="[pytest]\npostgresql_postgres_options = -N 16\n")
ret = pointed_pytester.runpytest("test_postgres_options.py")
ret.assert_outcomes(passed=1)
+
+
+def test_postgres_loader_in_cli(pointed_pytester: Pytester) -> None:
+ """Check that command line arguments are honored."""
+ pointed_pytester.copy_example("test_load.py")
+ pointed_pytester.copy_example("test.sql")
+ ret = pointed_pytester.runpytest(
+ "--postgresql-postgres-options", "--postgresql-load test.sql", "test_load.py"
+ )
+ ret.assert_outcomes(passed=1)