Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 69 additions & 43 deletions src/postgres_mcp/top_queries/top_queries_calc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass
from typing import Literal
from typing import LiteralString
from typing import Union
Expand Down Expand Up @@ -29,6 +30,55 @@
)


@dataclass
class PgStatStatementsColumns:
"""Column names for pg_stat_statements view, which vary by PostgreSQL version."""

total_time: str
mean_time: str
stddev_time: str
wal_bytes_select: str # Full SELECT expression (handles missing column in PG12)
wal_bytes_frac: str # Full fraction expression


def _get_pg_stat_statements_columns(pg_version: int) -> PgStatStatementsColumns:
"""Get pg_stat_statements column names based on PostgreSQL version.

PostgreSQL 13 introduced pg_stat_statements v2.0 with breaking changes:
- Renamed timing columns: *_time → *_exec_time (total_time → total_exec_time, etc.)
- Added wal_bytes column for write-ahead log tracking

This function provides version-appropriate column names to ensure compatibility
with both old (PG ≤ 12) and new (PG ≥ 13) versions.

See: https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.4

Args:
pg_version: PostgreSQL major version number

Returns:
PgStatStatementsColumns with version-appropriate column names
"""
if pg_version >= 13:
# PostgreSQL 13+ with pg_stat_statements v2.0
return PgStatStatementsColumns(
total_time="total_exec_time",
mean_time="mean_exec_time",
stddev_time="stddev_exec_time",
wal_bytes_select="wal_bytes",
wal_bytes_frac="wal_bytes / NULLIF(SUM(wal_bytes) OVER (), 0) AS total_wal_bytes_frac",
)

# PostgreSQL 12 and older with pg_stat_statements v1.x
return PgStatStatementsColumns(
total_time="total_time",
mean_time="mean_time",
stddev_time="stddev_time",
wal_bytes_select="0 AS wal_bytes", # Column doesn't exist in PG12
wal_bytes_frac="0 AS total_wal_bytes_frac",
)


class TopQueriesCalc:
"""Tool for retrieving the slowest SQL queries."""

Expand Down Expand Up @@ -59,36 +109,28 @@ async def get_top_queries_by_time(self, limit: int = 10, sort_by: Literal["total
# Return installation instructions if the extension is not installed
return install_pg_stat_statements_message

# Check PostgreSQL version to determine column names
# Get version-appropriate column names
pg_version = await get_postgres_version(self.sql_driver)
logger.debug(f"PostgreSQL version: {pg_version}")
cols = _get_pg_stat_statements_columns(pg_version)

# Column names changed in PostgreSQL 13
if pg_version >= 13:
# PostgreSQL 13 and newer
total_time_col = "total_exec_time"
mean_time_col = "mean_exec_time"
else:
# PostgreSQL 12 and older
total_time_col = "total_time"
mean_time_col = "mean_time"

logger.debug(f"Using time columns: total={total_time_col}, mean={mean_time_col}")

# Determine which column to sort by based on sort_by parameter and version
order_by_column = total_time_col if sort_by == "total" else mean_time_col
# Determine which column to sort by based on sort_by parameter
order_by_column = cols.total_time if sort_by == "total" else cols.mean_time

query = f"""
query = cast(
LiteralString,
f"""
SELECT
query,
calls,
{total_time_col},
{mean_time_col},
{cols.total_time},
{cols.mean_time},
rows
FROM pg_stat_statements
ORDER BY {order_by_column} DESC
LIMIT {{}};
"""
""",
)
logger.debug(f"Executing query: {query}")
slow_query_rows = await SafeSqlDriver.execute_param_query(
self.sql_driver,
Expand Down Expand Up @@ -134,26 +176,10 @@ async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str:
# Return installation instructions if the extension is not installed
return install_pg_stat_statements_message

# Check PostgreSQL version to determine column names
# Get version-appropriate column names
pg_version = await get_postgres_version(self.sql_driver)
logger.debug(f"PostgreSQL version: {pg_version}")

# Column names changed in PostgreSQL 13
# Also, wal_bytes was added in PostgreSQL 13
if pg_version >= 13:
# PostgreSQL 13 and newer
total_time_col = "total_exec_time"
mean_time_col = "mean_exec_time"
stddev_time_col = "stddev_exec_time"
wal_bytes_col = "wal_bytes"
wal_bytes_frac = "wal_bytes / NULLIF(SUM(wal_bytes) OVER (), 0) AS total_wal_bytes_frac"
else:
# PostgreSQL 12 and older
total_time_col = "total_time"
mean_time_col = "mean_time"
stddev_time_col = "stddev_time"
wal_bytes_col = "0 AS wal_bytes" # Column doesn't exist in PG12
wal_bytes_frac = "0 AS total_wal_bytes_frac"
cols = _get_pg_stat_statements_columns(pg_version)

query = cast(
LiteralString,
Expand All @@ -163,14 +189,14 @@ async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str:
query,
calls,
rows,
{total_time_col} AS total_exec_time,
{mean_time_col} AS mean_exec_time,
{stddev_time_col} AS stddev_exec_time,
{cols.total_time} AS total_exec_time,
{cols.mean_time} AS mean_exec_time,
{cols.stddev_time} AS stddev_exec_time,
shared_blks_hit,
shared_blks_read,
shared_blks_dirtied,
{wal_bytes_col},
{total_time_col} / NULLIF(SUM({total_time_col}) OVER (), 0)
{cols.wal_bytes_select},
{cols.total_time} / NULLIF(SUM({cols.total_time}) OVER (), 0)
AS total_exec_time_frac,
(shared_blks_hit + shared_blks_read)
/ NULLIF(SUM(shared_blks_hit + shared_blks_read) OVER (), 0)
Expand All @@ -179,7 +205,7 @@ async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str:
AS shared_blks_read_frac,
shared_blks_dirtied / NULLIF(SUM(shared_blks_dirtied) OVER (), 0)
AS shared_blks_dirtied_frac,
{wal_bytes_frac}
{cols.wal_bytes_frac}
FROM pg_stat_statements
)
SELECT
Expand Down