From 798b85bdad5ba96035ea0363a01857c471c31695 Mon Sep 17 00:00:00 2001 From: Johann Schleier-Smith Date: Tue, 20 Jan 2026 00:43:03 +0000 Subject: [PATCH] refactor: Extract pg_stat_statements column logic to helper Improvements: 1. Add _get_pg_stat_statements_columns() helper function - Eliminates duplication between get_top_queries_by_time and get_top_resource_queries - Single source of truth for version-dependent column names - Returns PgStatStatementsColumns dataclass for type safety 2. Add comprehensive documentation - Documents PostgreSQL 13 breaking changes (v2.0 of pg_stat_statements) - Links to official PG13 release notes - Explains why version branching is necessary Benefits: - More maintainable: one place to update if future PG versions change - More testable: helper can be unit tested independently - Self-documenting: clear dataclass shows all version differences - Type-safe: dataclass prevents typos in column names Co-Authored-By: Claude Sonnet 4.5 --- .../top_queries/top_queries_calc.py | 112 +++++++++++------- 1 file changed, 69 insertions(+), 43 deletions(-) diff --git a/src/postgres_mcp/top_queries/top_queries_calc.py b/src/postgres_mcp/top_queries/top_queries_calc.py index 54e4413..f3fcba1 100644 --- a/src/postgres_mcp/top_queries/top_queries_calc.py +++ b/src/postgres_mcp/top_queries/top_queries_calc.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from typing import Literal from typing import LiteralString from typing import Union @@ -29,6 +30,55 @@ ) +@dataclass +class PgStatStatementsColumns: + """Column names for pg_stat_statements view, which vary by PostgreSQL version.""" + + total_time: str + mean_time: str + stddev_time: str + wal_bytes_select: str # Full SELECT expression (handles missing column in PG12) + wal_bytes_frac: str # Full fraction expression + + +def _get_pg_stat_statements_columns(pg_version: int) -> PgStatStatementsColumns: + """Get pg_stat_statements column names based on PostgreSQL version. + + PostgreSQL 13 introduced pg_stat_statements v2.0 with breaking changes: + - Renamed timing columns: *_time → *_exec_time (total_time → total_exec_time, etc.) + - Added wal_bytes column for write-ahead log tracking + + This function provides version-appropriate column names to ensure compatibility + with both old (PG ≤ 12) and new (PG ≥ 13) versions. + + See: https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.4 + + Args: + pg_version: PostgreSQL major version number + + Returns: + PgStatStatementsColumns with version-appropriate column names + """ + if pg_version >= 13: + # PostgreSQL 13+ with pg_stat_statements v2.0 + return PgStatStatementsColumns( + total_time="total_exec_time", + mean_time="mean_exec_time", + stddev_time="stddev_exec_time", + wal_bytes_select="wal_bytes", + wal_bytes_frac="wal_bytes / NULLIF(SUM(wal_bytes) OVER (), 0) AS total_wal_bytes_frac", + ) + + # PostgreSQL 12 and older with pg_stat_statements v1.x + return PgStatStatementsColumns( + total_time="total_time", + mean_time="mean_time", + stddev_time="stddev_time", + wal_bytes_select="0 AS wal_bytes", # Column doesn't exist in PG12 + wal_bytes_frac="0 AS total_wal_bytes_frac", + ) + + class TopQueriesCalc: """Tool for retrieving the slowest SQL queries.""" @@ -59,36 +109,28 @@ async def get_top_queries_by_time(self, limit: int = 10, sort_by: Literal["total # Return installation instructions if the extension is not installed return install_pg_stat_statements_message - # Check PostgreSQL version to determine column names + # Get version-appropriate column names pg_version = await get_postgres_version(self.sql_driver) logger.debug(f"PostgreSQL version: {pg_version}") + cols = _get_pg_stat_statements_columns(pg_version) - # Column names changed in PostgreSQL 13 - if pg_version >= 13: - # PostgreSQL 13 and newer - total_time_col = "total_exec_time" - mean_time_col = "mean_exec_time" - else: - # PostgreSQL 12 and older - total_time_col = "total_time" - mean_time_col = "mean_time" - - logger.debug(f"Using time columns: total={total_time_col}, mean={mean_time_col}") - - # Determine which column to sort by based on sort_by parameter and version - order_by_column = total_time_col if sort_by == "total" else mean_time_col + # Determine which column to sort by based on sort_by parameter + order_by_column = cols.total_time if sort_by == "total" else cols.mean_time - query = f""" + query = cast( + LiteralString, + f""" SELECT query, calls, - {total_time_col}, - {mean_time_col}, + {cols.total_time}, + {cols.mean_time}, rows FROM pg_stat_statements ORDER BY {order_by_column} DESC LIMIT {{}}; - """ + """, + ) logger.debug(f"Executing query: {query}") slow_query_rows = await SafeSqlDriver.execute_param_query( self.sql_driver, @@ -134,26 +176,10 @@ async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str: # Return installation instructions if the extension is not installed return install_pg_stat_statements_message - # Check PostgreSQL version to determine column names + # Get version-appropriate column names pg_version = await get_postgres_version(self.sql_driver) logger.debug(f"PostgreSQL version: {pg_version}") - - # Column names changed in PostgreSQL 13 - # Also, wal_bytes was added in PostgreSQL 13 - if pg_version >= 13: - # PostgreSQL 13 and newer - total_time_col = "total_exec_time" - mean_time_col = "mean_exec_time" - stddev_time_col = "stddev_exec_time" - wal_bytes_col = "wal_bytes" - wal_bytes_frac = "wal_bytes / NULLIF(SUM(wal_bytes) OVER (), 0) AS total_wal_bytes_frac" - else: - # PostgreSQL 12 and older - total_time_col = "total_time" - mean_time_col = "mean_time" - stddev_time_col = "stddev_time" - wal_bytes_col = "0 AS wal_bytes" # Column doesn't exist in PG12 - wal_bytes_frac = "0 AS total_wal_bytes_frac" + cols = _get_pg_stat_statements_columns(pg_version) query = cast( LiteralString, @@ -163,14 +189,14 @@ async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str: query, calls, rows, - {total_time_col} AS total_exec_time, - {mean_time_col} AS mean_exec_time, - {stddev_time_col} AS stddev_exec_time, + {cols.total_time} AS total_exec_time, + {cols.mean_time} AS mean_exec_time, + {cols.stddev_time} AS stddev_exec_time, shared_blks_hit, shared_blks_read, shared_blks_dirtied, - {wal_bytes_col}, - {total_time_col} / NULLIF(SUM({total_time_col}) OVER (), 0) + {cols.wal_bytes_select}, + {cols.total_time} / NULLIF(SUM({cols.total_time}) OVER (), 0) AS total_exec_time_frac, (shared_blks_hit + shared_blks_read) / NULLIF(SUM(shared_blks_hit + shared_blks_read) OVER (), 0) @@ -179,7 +205,7 @@ async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str: AS shared_blks_read_frac, shared_blks_dirtied / NULLIF(SUM(shared_blks_dirtied) OVER (), 0) AS shared_blks_dirtied_frac, - {wal_bytes_frac} + {cols.wal_bytes_frac} FROM pg_stat_statements ) SELECT