From afed7e2bc5941d0e2a0d178f1c6a1d5b2ac82816 Mon Sep 17 00:00:00 2001 From: Julien Maupetit Date: Tue, 16 Jul 2024 11:28:23 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8(scripts)=20lint=20and=20fix=20scri?= =?UTF-8?q?pts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Scripts should follow the same quality standards. --- Makefile | 10 +++--- scripts/benchmark.py | 78 ++++++++++++++++++-------------------------- 2 files changed, 37 insertions(+), 51 deletions(-) diff --git a/Makefile b/Makefile index b5461e9..b9d45d1 100644 --- a/Makefile +++ b/Makefile @@ -69,27 +69,27 @@ lint: \ lint-black: ## lint python sources with black @echo 'lint:black started…' - poetry run black src/data7 tests + poetry run black src/data7 tests scripts .PHONY: lint-black lint-black-check: ## check python sources with black @echo 'lint:black check started…' - poetry run black --check src/data7 tests + poetry run black --check src/data7 tests scripts .PHONY: lint-black-check lint-ruff: ## lint python sources with ruff @echo 'lint:ruff started…' - poetry run ruff check src/data7 tests + poetry run ruff check src/data7 tests scripts .PHONY: lint-ruff lint-ruff-fix: ## lint and fix python sources with ruff @echo 'lint:ruff-fix started…' - poetry run ruff check --fix src/data7 tests + poetry run ruff check --fix src/data7 tests scripts .PHONY: lint-ruff-fix lint-mypy: ## lint python sources with mypy @echo 'lint:mypy started…' - poetry run mypy src/data7 tests + poetry run mypy src/data7 tests scripts .PHONY: lint-mypy diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 04dee70..fcb8410 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -7,15 +7,20 @@ poetry run python scripts/benchmark.py +For performance testing, we use the imdb-sql project to seed the database: +https://github.com/jmaupetit/imdb-sql + """ import asyncio import csv -from dataclasses import dataclass from io import StringIO -from typing import Any, AsyncGenerator, List +from typing import Any, AsyncGenerator import databases +from data7.app import Dataset +from data7.app import sql2csv as pd_sql2csv +from data7.config import settings from pyinstrument import Profiler from rich.console import Console from rich.live import Live @@ -23,20 +28,6 @@ from rich.table import Table from sqlalchemy import create_engine -from data7.app import Extension, sql2any -from data7.config import settings - - -# Models -@dataclass -class Dataset: - """Dataset model.""" - - basename: str - query: str - fields: List[str] - - # Databases database = databases.Database("postgresql+asyncpg://imdb:pass@localhost:5432/imdb") engine = create_engine("postgresql://imdb:pass@localhost:5432/imdb") @@ -49,17 +40,24 @@ class Dataset: console = Console() +title_fields = [ + "tconst", + "titleType", + "primaryTitle", + "originalTitle", + "isAdult", + "startYear", + "endYear", + "runtimeMinutes", + "genres", +] + + # Legacy async def sql2csv(dataset: Dataset) -> AsyncGenerator[str, Any]: """Stream SQL rows to CSV.""" output = StringIO() - - if dataset.fields is None: - raise ValueError( - f"Requested dataset '{dataset.basename}' has no defined fields" - ) - - writer = csv.DictWriter(output, fieldnames=dataset.fields) + writer = csv.DictWriter(output, fieldnames=title_fields) # Header writer.writeheader() @@ -75,7 +73,7 @@ async def sql2csv(dataset: Dataset) -> AsyncGenerator[str, Any]: # Wrappers -async def _sql2csv(dataset: Dataset): +async def _sql2csv(dataset: Dataset) -> float: """sql2csv wrapper to handle database connection.""" await database.connect() aprofiler.reset() @@ -84,18 +82,17 @@ async def _sql2csv(dataset: Dataset): pass aprofiler.stop() await database.disconnect() - return aprofiler.last_session.duration + return aprofiler.last_session.duration if aprofiler.last_session else 0.0 -def _sql2any(dataset: Dataset, extension: Extension, chunksize: int = 5000): +def _pd_sql2csv(dataset: Dataset, chunksize: int = 5000) -> float: """sql2any wrapper to handled database connection.""" - with engine.connect() as conn: - profiler.reset() - profiler.start() - for _ in sql2any(dataset, extension, conn, chunksize): - pass - profiler.stop() - return profiler.last_session.duration + profiler.reset() + profiler.start() + for _ in pd_sql2csv(dataset, chunksize): + pass + profiler.stop() + return profiler.last_session.duration if profiler.last_session else 0.0 # Output @@ -114,22 +111,11 @@ def render_float(value): table.add_column("🐼 (10000)", justify="right") table.add_column("Ratio (10000)", justify="right") -title_fields = [ - "tconst", - "titleType", - "primaryTitle", - "originalTitle", - "isAdult", - "startYear", - "endYear", - "runtimeMinutes", - "genres", -] +# ruff: noqa: S608 datasets = [ Dataset( basename=f"{rows}", query=f'SELECT * FROM "title.basics" LIMIT {rows}', - fields=title_fields, ) for rows in (500, 1000, 5000, 10000, 50000, 100000, 500000, 1000000, 5000000) ] @@ -143,7 +129,7 @@ def render_float(value): pandas = [] ratios = [] for chunksize in (1000, 5000, 10000): - p = _sql2any(dataset, Extension.CSV, chunksize) + p = _pd_sql2csv(dataset, chunksize) pandas.append(p) ratios.append(legacy / p) values = (v for couple in zip(pandas, ratios) for v in couple)