Skip to content

Commit

Permalink
🚨(scripts) lint and fix scripts
Browse files Browse the repository at this point in the history
Scripts should follow the same quality standards as the package sources and tests.
  • Loading branch information
jmaupetit committed Jul 16, 2024
1 parent b2f1c8c commit afed7e2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 51 deletions.
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,27 @@ lint: \

lint-black: ## lint python sources with black
@echo 'lint:black started…'
poetry run black src/data7 tests
poetry run black src/data7 tests scripts
.PHONY: lint-black

lint-black-check: ## check python sources with black
@echo 'lint:black check started…'
poetry run black --check src/data7 tests
poetry run black --check src/data7 tests scripts
.PHONY: lint-black-check

lint-ruff: ## lint python sources with ruff
@echo 'lint:ruff started…'
poetry run ruff check src/data7 tests
poetry run ruff check src/data7 tests scripts
.PHONY: lint-ruff

lint-ruff-fix: ## lint and fix python sources with ruff
@echo 'lint:ruff-fix started…'
poetry run ruff check --fix src/data7 tests
poetry run ruff check --fix src/data7 tests scripts
.PHONY: lint-ruff-fix

lint-mypy: ## lint python sources with mypy
@echo 'lint:mypy started…'
poetry run mypy src/data7 tests
poetry run mypy src/data7 tests scripts
.PHONY: lint-mypy


Expand Down
78 changes: 32 additions & 46 deletions scripts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,27 @@
poetry run python scripts/benchmark.py
For performance testing, we use the imdb-sql project to seed the database:
https://github.com/jmaupetit/imdb-sql
"""

import asyncio
import csv
from dataclasses import dataclass
from io import StringIO
from typing import Any, AsyncGenerator, List
from typing import Any, AsyncGenerator

import databases
from data7.app import Dataset
from data7.app import sql2csv as pd_sql2csv
from data7.config import settings
from pyinstrument import Profiler
from rich.console import Console
from rich.live import Live
from rich.pretty import Pretty, pprint
from rich.table import Table
from sqlalchemy import create_engine

from data7.app import Extension, sql2any
from data7.config import settings


# Models
@dataclass
class Dataset:
"""Dataset model."""

basename: str
query: str
fields: List[str]


# Databases
database = databases.Database("postgresql+asyncpg://imdb:pass@localhost:5432/imdb")
engine = create_engine("postgresql://imdb:pass@localhost:5432/imdb")
Expand All @@ -49,17 +40,24 @@ class Dataset:
console = Console()


title_fields = [
"tconst",
"titleType",
"primaryTitle",
"originalTitle",
"isAdult",
"startYear",
"endYear",
"runtimeMinutes",
"genres",
]


# Legacy
async def sql2csv(dataset: Dataset) -> AsyncGenerator[str, Any]:
"""Stream SQL rows to CSV."""
output = StringIO()

if dataset.fields is None:
raise ValueError(
f"Requested dataset '{dataset.basename}' has no defined fields"
)

writer = csv.DictWriter(output, fieldnames=dataset.fields)
writer = csv.DictWriter(output, fieldnames=title_fields)

# Header
writer.writeheader()
Expand All @@ -75,7 +73,7 @@ async def sql2csv(dataset: Dataset) -> AsyncGenerator[str, Any]:


# Wrappers
async def _sql2csv(dataset: Dataset):
async def _sql2csv(dataset: Dataset) -> float:
"""sql2csv wrapper to handle database connection."""
await database.connect()
aprofiler.reset()
Expand All @@ -84,18 +82,17 @@ async def _sql2csv(dataset: Dataset):
pass
aprofiler.stop()
await database.disconnect()
return aprofiler.last_session.duration
return aprofiler.last_session.duration if aprofiler.last_session else 0.0


def _sql2any(dataset: Dataset, extension: Extension, chunksize: int = 5000):
def _pd_sql2csv(dataset: Dataset, chunksize: int = 5000) -> float:
"""sql2any wrapper to handle database connection."""
with engine.connect() as conn:
profiler.reset()
profiler.start()
for _ in sql2any(dataset, extension, conn, chunksize):
pass
profiler.stop()
return profiler.last_session.duration
profiler.reset()
profiler.start()
for _ in pd_sql2csv(dataset, chunksize):
pass
profiler.stop()
return profiler.last_session.duration if profiler.last_session else 0.0


# Output
Expand All @@ -114,22 +111,11 @@ def render_float(value):
table.add_column("🐼 (10000)", justify="right")
table.add_column("Ratio (10000)", justify="right")

title_fields = [
"tconst",
"titleType",
"primaryTitle",
"originalTitle",
"isAdult",
"startYear",
"endYear",
"runtimeMinutes",
"genres",
]
# ruff: noqa: S608
datasets = [
Dataset(
basename=f"{rows}",
query=f'SELECT * FROM "title.basics" LIMIT {rows}',
fields=title_fields,
)
for rows in (500, 1000, 5000, 10000, 50000, 100000, 500000, 1000000, 5000000)
]
Expand All @@ -143,7 +129,7 @@ def render_float(value):
pandas = []
ratios = []
for chunksize in (1000, 5000, 10000):
p = _sql2any(dataset, Extension.CSV, chunksize)
p = _pd_sql2csv(dataset, chunksize)
pandas.append(p)
ratios.append(legacy / p)
values = (v for couple in zip(pandas, ratios) for v in couple)
Expand Down

0 comments on commit afed7e2

Please sign in to comment.