diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3797005e9..6ea015941 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,6 +3,28 @@ Changelog ========= +0.12.1 +------ +* Notable efficiency improvement for regular inserts +* Bulk insert operation: + + .. note:: + The bulk insert operation will do the minimum to ensure that the object + created in the DB has all the defaults and generated fields set, + but may be an incomplete reference in Python. + + e.g. ``IntField`` primary keys will not be populated. + + This is recommended only for throw away inserts where you want to ensure optimal + insert performance. + + .. code-block:: python3 + + User.bulk_create([ + User(name="...", email="..."), + User(name="...", email="...") + ]) + 0.12.0 ------ * Tortoise ORM now supports non-autonumber primary keys. @@ -36,8 +58,6 @@ Changelog guid = fields.UUIDField(pk=True) - For more info, please have a look at :ref:`init_app` - 0.11.13 ------- @@ -69,12 +89,12 @@ Changelog 0.11.7 ------ -- Fixed 'unique_together' for foreign keys (#114) +- Fixed ``unique_together`` for foreign keys (#114) - Fixed Field.to_db_value method to handle Enum (#113 #115 #116) 0.11.6 ------ -- Added ability to use "unique_together" meta Model option +- Added ability to use ``unique_together`` meta Model option 0.11.5 ------ diff --git a/docs/models.rst b/docs/models.rst index 445bfa779..179718fe0 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -99,6 +99,33 @@ Any of these are valid primary key definitions in a Model: guid = fields.UUIDField(pk=True) +The ``Meta`` class +------------------ + +.. autoclass:: tortoise.models.Model.Meta + + .. attribute:: abstract + :annotation: = False + + Set to ``True`` to indicate this is an abstract class + + .. attribute:: table + :annotation: = "" + + Set this to configure a manual table name, instead of a generated one + + .. attribute:: unique_together + :annotation: = None + + Specify ``unique_together`` to set up compound unique indexes for sets of columns. 
+ + It should be a tuple of tuples (lists are fine) in the format of: + + .. code-block:: python3 + + unique_together=("field_a", "field_b") + unique_together=(("field_a", "field_b"), ) + unique_together=(("field_a", "field_b"), ("field_c", "field_d", "field_e")) ``ForeignKeyField`` ------------------- @@ -195,7 +222,6 @@ The reverse lookup of ``team.event_team`` works exactly the same way. Reference ========= -.. autoclass:: tortoise.models.Model - :members: +.. automodule:: tortoise.models + :members: Model :undoc-members: - diff --git a/tortoise/__init__.py b/tortoise/__init__.py index e41e33467..89d73b51b 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -401,4 +401,4 @@ async def do_stuff(): loop.run_until_complete(Tortoise.close_connections()) -__version__ = "0.12.0" +__version__ = "0.12.1" diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index 377465709..1aa0037b9 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -158,6 +158,14 @@ async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Rec stmt = await connection.prepare(query) return await stmt.fetchrow(*values) + @translate_exceptions + @retry_connection + async def execute_many(self, query: str, values: list) -> None: + async with self.acquire_connection() as connection: + self.log.debug("%s: %s", query, values) + # TODO: Consider using copy_records_to_table instead + await connection.executemany(query, values) + @translate_exceptions @retry_connection async def execute_query(self, query: str) -> List[dict]: diff --git a/tortoise/backends/base/config_generator.py b/tortoise/backends/base/config_generator.py index 517e01baa..39dcde229 100644 --- a/tortoise/backends/base/config_generator.py +++ b/tortoise/backends/base/config_generator.py @@ -34,8 +34,8 @@ "engine": "tortoise.backends.sqlite", "skip_first_char": False, "vmap": {"path": "file_path"}, - "defaults": {}, - "cast": {}, + 
"defaults": {"journal_mode": "WAL", "journal_size_limit": 16384}, + "cast": {"journal_size_limit": int}, }, "mysql": { "engine": "tortoise.backends.mysql", diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 4b9166cc0..3ad772485 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -1,3 +1,4 @@ +from functools import partial from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type # noqa from pypika import JoinType, Table @@ -9,7 +10,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model -INSERT_CACHE = {} # type: Dict[str, Tuple[list, list, str]] +INSERT_CACHE = {} # type: Dict[str, Tuple[list, str, Dict[str, Callable]]] class BaseExecutor: @@ -23,6 +24,24 @@ def __init__(self, model, db=None, prefetch_map=None, prefetch_queries=None): self.prefetch_map = prefetch_map if prefetch_map else {} self._prefetch_queries = prefetch_queries if prefetch_queries else {} + key = "{}:{}".format(self.db.connection_name, self.model._meta.table) + if key not in INSERT_CACHE: + self.regular_columns, columns = self._prepare_insert_columns() + self.query = self._prepare_insert_statement(columns) + + self.column_map = {} # type: Dict[str, Callable] + for column in self.regular_columns: + field_object = self.model._meta.fields_map[column] + if field_object.__class__ in self.TO_DB_OVERRIDE: + func = partial(self.TO_DB_OVERRIDE[field_object.__class__], field_object) + else: + func = field_object.to_db_value + self.column_map[column] = func + + INSERT_CACHE[key] = self.regular_columns, self.query, self.column_map + else: + self.regular_columns, self.query, self.column_map = INSERT_CACHE[key] + async def execute_explain(self, query) -> Any: sql = " ".join(((self.EXPLAIN_PREFIX, query.get_sql()))) return await self.db.execute_query(sql) @@ -54,14 +73,6 @@ def _field_to_db(cls, field_object: fields.Field, attr: Any, instance) -> Any: return 
cls.TO_DB_OVERRIDE[field_object.__class__](field_object, attr, instance) return field_object.to_db_value(attr, instance) - def _prepare_insert_values(self, instance, regular_columns: List[str]) -> list: - return [ - self._field_to_db( - self.model._meta.fields_map[column], getattr(instance, column), instance - ) - for column in regular_columns - ] - def _prepare_insert_statement(self, columns: List[str]) -> str: # Insert should implement returning new id to saved object # Each db has it's own methods for it, so each implementation should @@ -72,19 +83,24 @@ async def _process_insert_result(self, instance: "Model", results: Any): raise NotImplementedError() # pragma: nocoverage async def execute_insert(self, instance): - key = "{}:{}".format(self.db.connection_name, self.model._meta.table) - if key not in INSERT_CACHE: - regular_columns, columns = self._prepare_insert_columns() - query = self._prepare_insert_statement(columns) - INSERT_CACHE[key] = regular_columns, columns, query - else: - regular_columns, columns, query = INSERT_CACHE[key] - - values = self._prepare_insert_values(instance=instance, regular_columns=regular_columns) - insert_result = await self.db.execute_insert(query, values) + values = [ + self.column_map[column](getattr(instance, column), instance) + for column in self.regular_columns + ] + insert_result = await self.db.execute_insert(self.query, values) await self._process_insert_result(instance, insert_result) return instance + async def execute_bulk_insert(self, instances): + values_lists = [ + [ + self.column_map[column](getattr(instance, column), instance) + for column in self.regular_columns + ] + for instance in instances + ] + await self.db.execute_many(self.query, values_lists) + async def execute_update(self, instance): table = Table(self.model._meta.table) query = self.db.query_class.update(table) @@ -92,7 +108,7 @@ async def execute_update(self, instance): field_object = self.model._meta.fields_map[field] if not field_object.generated: 
query = query.set( - db_field, self._field_to_db(field_object, getattr(instance, field), instance) + db_field, self.column_map[field](getattr(instance, field), instance) ) query = query.where( getattr(table, self.model._meta.db_pk_field) diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index 5b3f85b18..2d58f17ee 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -159,10 +159,17 @@ async def execute_insert(self, query: str, values: list) -> int: async with self.acquire_connection() as connection: self.log.debug("%s: %s", query, values) async with connection.cursor() as cursor: - # TODO: Use prepared statement, and cache it await cursor.execute(query, values) return cursor.lastrowid # return auto-generated id + @translate_exceptions + @retry_connection + async def execute_many(self, query: str, values: list) -> None: + async with self.acquire_connection() as connection: + self.log.debug("%s: %s", query, values) + async with connection.cursor() as cursor: + await cursor.executemany(query, values) + @translate_exceptions @retry_connection async def execute_query(self, query: str) -> List[aiomysql.DictCursor]: diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 5da4599ce..48f959f7e 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -40,6 +40,11 @@ class SqliteClient(BaseDBAsyncClient): def __init__(self, file_path: str, **kwargs) -> None: super().__init__(**kwargs) self.filename = file_path + + self.pragmas = kwargs.copy() + self.pragmas.pop("connection_name", None) + self.pragmas.pop("fetch_inserted", None) + self._transaction_class = type( "TransactionWrapper", (TransactionWrapper, self.__class__), {} ) @@ -52,15 +57,24 @@ async def create_connection(self, with_db: bool) -> None: self._connection.start() await self._connection._connect() self._connection._conn.row_factory = sqlite3.Row + for pragma, val in 
self.pragmas.items(): + cursor = await self._connection.execute("PRAGMA {}={}".format(pragma, val)) + await cursor.close() self.log.debug( - "Created connection %s with params: filename=%s", self._connection, self.filename + "Created connection %s with params: filename=%s %s", + self._connection, + self.filename, + " ".join(["{}={}".format(k, v) for k, v in self.pragmas.items()]), ) async def close(self) -> None: if self._connection: await self._connection.close() self.log.debug( - "Closed connection %s with params: filename=%s", self._connection, self.filename + "Closed connection %s with params: filename=%s %s", + self._connection, + self.filename, + " ".join(["{}={}".format(k, v) for k, v in self.pragmas.items()]), ) self._connection = None @@ -91,6 +105,13 @@ async def execute_insert(self, query: str, values: list) -> int: self.log.debug("%s: %s", query, values) return (await connection.execute_insert(query, values))[0] + @translate_exceptions + async def execute_many(self, query: str, values: List[list]) -> None: + async with self.acquire_connection() as connection: + self.log.debug("%s: %s", query, values) + # TODO: Ensure that this is wrapped by a transaction, will provide a big speedup + await connection.executemany(query, values) + @translate_exceptions async def execute_query(self, query: str) -> List[dict]: async with self.acquire_connection() as connection: diff --git a/tortoise/models.py b/tortoise/models.py index b1ad64bd3..4db3b44fb 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -306,7 +306,7 @@ def __init__(self, *args, _from_db: bool = False, **kwargs) -> None: # Assign values and do type conversions passed_fields = set(kwargs.keys()) passed_fields.update(meta.fetch_fields) - passed_fields |= self.set_field_values(kwargs) + passed_fields |= self._set_field_values(kwargs) # Assign defaults for missing fields for key in meta.fields.difference(passed_fields): @@ -316,7 +316,7 @@ def __init__(self, *args, _from_db: bool = False, **kwargs) 
-> None: else: setattr(self, key, field_object.default) - def set_field_values(self, values_map: Dict[str, Any]) -> Set[str]: + def _set_field_values(self, values_map: Dict[str, Any]) -> Set[str]: """ Sets values for fields honoring type transformations and return list of fields that were set additionally @@ -354,39 +354,6 @@ def set_field_values(self, values_map: Dict[str, Any]) -> Set[str]: return passed_fields - def _get_pk_val(self): - return getattr(self, self._meta.pk_attr) - - def _set_pk_val(self, value): - setattr(self, self._meta.pk_attr, value) - - pk = property(_get_pk_val, _set_pk_val) - - async def _insert_instance(self, using_db=None) -> None: - db = using_db if using_db else self._meta.db - await db.executor_class(model=self.__class__, db=db).execute_insert(self) - self._saved_in_db = True - - async def _update_instance(self, using_db=None) -> None: - db = using_db if using_db else self._meta.db - await db.executor_class(model=self.__class__, db=db).execute_update(self) - - async def save(self, *args, **kwargs) -> None: - if not self._saved_in_db: - await self._insert_instance(*args, **kwargs) - else: - await self._update_instance(*args, **kwargs) - - async def delete(self, using_db=None) -> None: - db = using_db if using_db else self._meta.db - if not self._saved_in_db: - raise OperationalError("Can't delete unpersisted record") - await db.executor_class(model=self.__class__, db=db).execute_delete(self) - - async def fetch_related(self, *args, using_db=None): - db = using_db if using_db else self._meta.db - await db.executor_class(model=self.__class__, db=db).fetch_for_list([self], *args) - def __str__(self) -> str: return "<{}>".format(self.__class__.__name__) @@ -406,6 +373,37 @@ def __eq__(self, other) -> bool: return True return False + def _get_pk_val(self): + return getattr(self, self._meta.pk_attr) + + def _set_pk_val(self, value): + setattr(self, self._meta.pk_attr, value) + + pk = property(_get_pk_val, _set_pk_val) + """ + Alias to the 
models Primary Key. + Can be used as a field name when doing filtering e.g. ``.filter(pk=...)`` etc... + """ + + async def save(self, using_db=None) -> None: + db = using_db or self._meta.db + executor = db.executor_class(model=self.__class__, db=db) + if self._saved_in_db: + await executor.execute_update(self) + else: + await executor.execute_insert(self) + self._saved_in_db = True + + async def delete(self, using_db=None) -> None: + db = using_db or self._meta.db + if not self._saved_in_db: + raise OperationalError("Can't delete unpersisted record") + await db.executor_class(model=self.__class__, db=db).execute_delete(self) + + async def fetch_related(self, *args, using_db=None): + db = using_db or self._meta.db + await db.executor_class(model=self.__class__, db=db).fetch_for_list([self], *args) + @classmethod async def get_or_create( cls: Type[MODEL_TYPE], using_db=None, defaults=None, **kwargs @@ -420,9 +418,38 @@ async def get_or_create( @classmethod async def create(cls: Type[MODEL_TYPE], **kwargs) -> MODEL_TYPE: instance = cls(**kwargs) - await instance.save(using_db=kwargs.get("using_db")) + db = kwargs.get("using_db") or cls._meta.db + await db.executor_class(model=cls, db=db).execute_insert(instance) + instance._saved_in_db = True return instance + @classmethod + async def bulk_create(cls: Type[MODEL_TYPE], objects: List[MODEL_TYPE], using_db=None) -> None: + """ + Bulk insert operation: + + .. note:: + The bulk insert operation will do the minimum to ensure that the object + created in the DB has all the defaults and generated fields set, + but may be an incomplete reference in Python. + + e.g. ``IntField`` primary keys will not be populated. + + This is recommended only for throw away inserts where you want to ensure optimal + insert performance. + + .. 
code-block:: python3 + + User.bulk_create([ + User(name="...", email="..."), + User(name="...", email="...") + ]) + + :param objects: List of objects to bulk create + """ + db = using_db or cls._meta.db + await db.executor_class(model=cls, db=db).execute_bulk_insert(objects) + @classmethod def first(cls) -> QuerySet: return QuerySet(cls).first() @@ -449,7 +476,7 @@ def get(cls, *args, **kwargs) -> QuerySet: @classmethod async def fetch_for_list(cls, instance_list, *args, using_db=None): - db = using_db if using_db else cls._meta.db + db = using_db or cls._meta.db await db.executor_class(model=cls, db=db).fetch_for_list(instance_list, *args) @classmethod @@ -493,4 +520,19 @@ def _check_unique_together(cls): ) class Meta: + """ + The ``Meta`` class is used to configure metadata for the Model. + + Usage: + + .. code-block:: python3 + + class Foo(Model): + ... + + class Meta: + table="custom_table" + unique_together=(("field_a", "field_b"), ) + """ + pass diff --git a/tortoise/queryset.py b/tortoise/queryset.py index f5fcf029b..d267be07d 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -236,6 +236,9 @@ def offset(self, offset: int) -> "QuerySet": def distinct(self) -> "QuerySet": """ Make QuerySet distinct. + + Only makes sense in combination with a .values() or .values_list() as it + precedes all the fetched fields with a distinct. 
""" queryset = self._clone() queryset._distinct = True diff --git a/tortoise/tests/test_bulk.py b/tortoise/tests/test_bulk.py new file mode 100644 index 000000000..be4365403 --- /dev/null +++ b/tortoise/tests/test_bulk.py @@ -0,0 +1,19 @@ +from uuid import UUID + +from tortoise.contrib import test +from tortoise.tests.testmodels import NoID, UUIDPkModel + + +class TestBasic(test.TestCase): + async def test_bulk_create(self): + await NoID.bulk_create([NoID() for _ in range(1, 1000)]) + self.assertEqual( + await NoID.all().values("id", "name"), + [{"id": val, "name": None} for val in range(1, 1000)], + ) + + async def test_bulk_create_uuidpk(self): + await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)]) + res = await UUIDPkModel.all().values_list("id", flat=True) + self.assertEqual(len(res), 1000) + self.assertIsInstance(res[0], UUID) diff --git a/tortoise/tests/test_db_url.py b/tortoise/tests/test_db_url.py index 3e17c3768..e1b1896e4 100644 --- a/tortoise/tests/test_db_url.py +++ b/tortoise/tests/test_db_url.py @@ -14,21 +14,40 @@ def test_sqlite_basic(self): res, { "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": "/some/test.sqlite"}, + "credentials": { + "file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, }, ) def test_sqlite_relative(self): res = expand_db_url("sqlite://test.sqlite") self.assertDictEqual( - res, {"engine": "tortoise.backends.sqlite", "credentials": {"file_path": "test.sqlite"}} + res, + { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": "test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, + }, ) def test_sqlite_relative_with_subdir(self): res = expand_db_url("sqlite://data/db.sqlite") self.assertDictEqual( res, - {"engine": "tortoise.backends.sqlite", "credentials": {"file_path": "data/db.sqlite"}}, + { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": "data/db.sqlite", + "journal_mode": "WAL", + 
"journal_size_limit": 16384, + }, + }, ) def test_sqlite_testing(self): @@ -38,16 +57,30 @@ def test_sqlite_testing(self): self.assertIn(".sqlite", file_path) self.assertNotEqual("sqlite:///some/test-{}.sqlite", file_path) self.assertDictEqual( - res, {"engine": "tortoise.backends.sqlite", "credentials": {"file_path": file_path}} + res, + { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": file_path, + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, + }, ) def test_sqlite_params(self): - res = expand_db_url("sqlite:///some/test.sqlite?AHA=5&moo=yes") + res = expand_db_url("sqlite:///some/test.sqlite?AHA=5&moo=yes&journal_mode=TRUNCATE") self.assertDictEqual( res, { "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": "/some/test.sqlite", "AHA": "5", "moo": "yes"}, + "credentials": { + "file_path": "/some/test.sqlite", + "AHA": "5", + "moo": "yes", + "journal_mode": "TRUNCATE", + "journal_size_limit": 16384, + }, }, ) @@ -217,7 +250,11 @@ def test_generate_config_basic(self): { "connections": { "default": { - "credentials": {"file_path": "/some/test.sqlite"}, + "credentials": { + "file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, "engine": "tortoise.backends.sqlite", } }, @@ -242,7 +279,11 @@ def test_generate_config_explicit(self): { "connections": { "models": { - "credentials": {"file_path": "/some/test.sqlite"}, + "credentials": { + "file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, "engine": "tortoise.backends.sqlite", } }, @@ -265,7 +306,11 @@ def test_generate_config_many_apps(self): { "connections": { "default": { - "credentials": {"file_path": "/some/test.sqlite"}, + "credentials": { + "file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, "engine": "tortoise.backends.sqlite", } },