From eb2cd958f32dd66c9cd8da1e7d5ef56c66fbeec1 Mon Sep 17 00:00:00 2001 From: Simon Karan Date: Fri, 20 Mar 2026 12:22:37 +0100 Subject: [PATCH 1/4] [DBA-94] Fix plot propagation to previous messages Restrict the shared thread._visualization_result fallback to the latest message only, preventing older messages from rendering a newer query's plot. Add allow_thread_fallback guard to render_visualization_section as defense-in-depth, fix persistence round-trip when _has_spec_df marker is present but parquet file is missing, and add diagnostic logging. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/databao_cli/ui/components/results.py | 18 +- .../ui/services/chat_persistence.py | 16 +- tests/test_plot_propagation.py | 249 ++++++++++++++++++ 3 files changed, 275 insertions(+), 8 deletions(-) create mode 100644 tests/test_plot_propagation.py diff --git a/src/databao_cli/ui/components/results.py b/src/databao_cli/ui/components/results.py index 3a386690..ade628a3 100644 --- a/src/databao_cli/ui/components/results.py +++ b/src/databao_cli/ui/components/results.py @@ -190,7 +190,12 @@ def render_dataframe_section(result: "ExecutionResult", has_visualization: bool) st.dataframe(df, width="stretch") -def render_visualization_section(thread: "Thread", visualization_data: dict[str, Any] | None = None) -> None: +def render_visualization_section( + thread: "Thread", + visualization_data: dict[str, Any] | None = None, + *, + allow_thread_fallback: bool = False, +) -> None: """Render the visualization section. Follows the same rendering logic as Jupyter notebooks: @@ -221,7 +226,7 @@ def render_visualization_section(thread: "Thread", visualization_data: dict[str, return return - if vis_result is None: + if vis_result is None or not allow_thread_fallback: return with st.expander("📈 Visualization", expanded=True): @@ -349,8 +354,10 @@ def render_visualization_and_actions( st.info("Generating visualization...", icon="📈") elif viz_error: st.error(f"Failed to generate visualization: {viz_error}") - elif has_visualization or visualization_data is not None or (is_latest and thread._visualization_result is not None): + elif visualization_data is not None: render_visualization_section(thread, visualization_data) + elif is_latest and (has_visualization or thread._visualization_result is not None): + render_visualization_section(thread, None, allow_thread_fallback=True) if is_latest and not viz_pending and not viz_error: _render_and_handle_action_buttons(result, current_chat, message_index, has_visualization) @@ -411,6 +418,11 @@ def execute_pending_plot(chat: "ChatSession") -> None: if message_index < len(messages): messages[message_index].has_visualization = True messages[message_index].visualization_data = _extract_visualization_data(thread) + if messages[message_index].visualization_data is None: + logger.warning( + "visualization_data is None after successful thread.plot() for message %d", + message_index, + ) save_current_chat() except Exception as e: logger.exception("Failed to generate visualization") diff --git a/src/databao_cli/ui/services/chat_persistence.py b/src/databao_cli/ui/services/chat_persistence.py index d410019e..383f8243 100644 --- a/src/databao_cli/ui/services/chat_persistence.py +++ b/src/databao_cli/ui/services/chat_persistence.py @@ -169,11 +169,17 @@ def load_chat(chat_id: str) -> ChatSession | None: vis_data = msg_data.get("visualization_data") if vis_data is not None: vis_df_path = visualizations_dir / f"{i}_spec_df.parquet" - if vis_df_path.exists(): - try: - vis_data["spec_df"] = pd.read_parquet(vis_df_path) - except Exception as e: - logger.warning(f"Failed to load visualization spec_df {i} for chat {chat_id}: {e}") + if vis_data.get("_has_spec_df"): + if vis_df_path.exists(): + try: + vis_data["spec_df"] = pd.read_parquet(vis_df_path) + except Exception as e: + logger.warning(f"Failed to load visualization spec_df {i} for chat {chat_id}: {e}") + vis_data["spec_df"] = None + else: + logger.warning(f"Missing visualization spec_df parquet for message {i} in chat {chat_id}") + vis_data["spec_df"] = None + vis_data.pop("_has_spec_df", None) chat = ChatSession.from_dict(session_data, results) logger.debug(f"Chat loaded: {chat_id}") diff --git a/tests/test_plot_propagation.py b/tests/test_plot_propagation.py new file mode 100644 index 00000000..f4651378 --- /dev/null +++ b/tests/test_plot_propagation.py @@ -0,0 +1,249 @@ +"""Tests for DBA-94: plot propagation to previous messages. + +Verifies that: +1. render_visualization_section refuses the shared thread fallback unless + explicitly allowed via allow_thread_fallback=True. +2. The persistence round-trip correctly handles missing parquet files when + _has_spec_df is True. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + + +@pytest.fixture(autouse=True, scope="module") +def _preload_results_module() -> None: + """Pre-import the results module to resolve the circular import chain. + + results → chat_persistence → __init__ → query_executor → results. + Importing query_executor first breaks the cycle. + """ + import databao_cli.ui.services.query_executor # noqa: F401 + + +def _get_render_fn() -> Any: + """Import render_visualization_section lazily to avoid circular import at collection time.""" + from databao_cli.ui.components.results import render_visualization_section + + return render_visualization_section + + +# --------------------------------------------------------------------------- +# render_visualization_section tests +# --------------------------------------------------------------------------- + + +class TestRenderVisualizationSection: + """Ensure the allow_thread_fallback guard works correctly.""" + + def test_thread_fallback_blocked_by_default(self) -> None: + """Without allow_thread_fallback the shared thread result is ignored.""" + render_visualization_section = _get_render_fn() + with patch("databao_cli.ui.components.results.st") as mock_st: + thread = MagicMock() + thread._visualization_result = MagicMock() + + render_visualization_section(thread, visualization_data=None) + + mock_st.expander.assert_not_called() + + @patch("databao_cli.ui.components.results.st") + def test_thread_fallback_allowed_renders(self, mock_st: MagicMock) -> None: + """With allow_thread_fallback=True the shared thread result is used.""" + render_visualization_section = _get_render_fn() + thread = MagicMock() + vis = MagicMock() + vis.plot = None + thread._visualization_result = vis + + mock_expander = MagicMock() + mock_st.expander.return_value.__enter__ = MagicMock(return_value=mock_expander) + mock_st.expander.return_value.__exit__ = MagicMock(return_value=False) + + render_visualization_section(thread, visualization_data=None, allow_thread_fallback=True) + + mock_st.expander.assert_called_once() + + @patch("databao_cli.ui.components.results.st") + def test_visualization_data_takes_priority(self, mock_st: MagicMock) -> None: + """Per-message visualization_data is used even when thread has a result.""" + render_visualization_section = _get_render_fn() + thread = MagicMock() + thread._visualization_result = MagicMock() + + spec = {"mark": "bar"} + spec_df = pd.DataFrame({"x": [1], "y": [2]}) + vis_data: dict[str, Any] = {"spec": spec, "spec_df": spec_df} + + mock_expander = MagicMock() + mock_st.expander.return_value.__enter__ = MagicMock(return_value=mock_expander) + mock_st.expander.return_value.__exit__ = MagicMock(return_value=False) + + render_visualization_section(thread, visualization_data=vis_data) + + mock_st.vega_lite_chart.assert_called_once() + + @patch("databao_cli.ui.components.results.st") + def test_visualization_data_missing_spec_df_returns_early(self, mock_st: MagicMock) -> None: + """If visualization_data has spec but no spec_df, returns without rendering.""" + render_visualization_section = _get_render_fn() + thread = MagicMock() + thread._visualization_result = MagicMock() + + vis_data: dict[str, Any] = {"spec": {"mark": "bar"}, "spec_df": None} + + render_visualization_section(thread, visualization_data=vis_data) + + mock_st.expander.assert_not_called() + + @patch("databao_cli.ui.components.results.st") + def test_no_vis_result_no_data_returns_early(self, mock_st: MagicMock) -> None: + """No visualization_data and no thread result -> nothing rendered.""" + render_visualization_section = _get_render_fn() + thread = MagicMock() + thread._visualization_result = None + + render_visualization_section(thread, visualization_data=None, allow_thread_fallback=True) + + mock_st.expander.assert_not_called() + + +# --------------------------------------------------------------------------- +# Persistence round-trip tests for _has_spec_df handling +# --------------------------------------------------------------------------- + + +class TestVisualizationPersistenceRoundTrip: + """Test the _has_spec_df loading logic by calling the real load_chat function. + + We mock get_chats_dir to point at a tmp directory and build a minimal + on-disk chat structure so load_chat exercises the actual persistence code. + """ + + @staticmethod + def _write_chat( + tmp_path: Path, + chat_id: str, + vis_data: dict[str, Any] | None, + write_parquet: bool = False, + ) -> None: + """Write a minimal chat directory that load_chat can read.""" + import json + + chat_dir = tmp_path / chat_id + chat_dir.mkdir(parents=True, exist_ok=True) + + session = { + "id": chat_id, + "created_at": "2026-01-01T00:00:00", + "messages": [ + { + "role": "assistant", + "content": "test answer", + "has_visualization": True, + "visualization_data": vis_data, + }, + ], + } + (chat_dir / "session.json").write_text(json.dumps(session)) + + if write_parquet: + vis_dir = chat_dir / "visualizations" + vis_dir.mkdir(exist_ok=True) + df = pd.DataFrame({"x": [1, 2], "y": [3, 4]}) + df.to_parquet(vis_dir / "0_spec_df.parquet") + + @patch("databao_cli.ui.services.chat_persistence.get_chats_dir") + def test_has_spec_df_true_with_parquet(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None: + """When _has_spec_df=True and parquet exists, spec_df is loaded.""" + from databao_cli.ui.services.chat_persistence import load_chat + + mock_chats_dir.return_value = tmp_path + chat_id = "00000000-0000-0000-0000-000000000001" + self._write_chat( + tmp_path, + chat_id, + vis_data={"spec": {"mark": "bar"}, "_has_spec_df": True}, + write_parquet=True, + ) + + chat = load_chat(chat_id) + + assert chat is not None + msg = chat.messages[0] + assert msg.visualization_data is not None + assert msg.visualization_data["spec_df"] is not None + assert len(msg.visualization_data["spec_df"]) == 2 + assert "_has_spec_df" not in msg.visualization_data + + @patch("databao_cli.ui.services.chat_persistence.get_chats_dir") + def test_has_spec_df_true_without_parquet(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None: + """When _has_spec_df=True but parquet is missing, spec_df is explicitly None.""" + from databao_cli.ui.services.chat_persistence import load_chat + + mock_chats_dir.return_value = tmp_path + chat_id = "00000000-0000-0000-0000-000000000002" + self._write_chat( + tmp_path, + chat_id, + vis_data={"spec": {"mark": "bar"}, "_has_spec_df": True}, + write_parquet=False, + ) + + chat = load_chat(chat_id) + + assert chat is not None + msg = chat.messages[0] + assert msg.visualization_data is not None + # Key assertion: spec_df must be explicitly None, not absent + assert "spec_df" in msg.visualization_data + assert msg.visualization_data["spec_df"] is None + assert "_has_spec_df" not in msg.visualization_data + + @patch("databao_cli.ui.services.chat_persistence.get_chats_dir") + def test_has_spec_df_false_skips_parquet_load(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None: + """When _has_spec_df is absent, spec_df is not loaded even if parquet exists.""" + from databao_cli.ui.services.chat_persistence import load_chat + + mock_chats_dir.return_value = tmp_path + chat_id = "00000000-0000-0000-0000-000000000003" + self._write_chat( + tmp_path, + chat_id, + vis_data={"spec": {"mark": "bar"}}, + write_parquet=True, # parquet exists but should NOT be loaded + ) + + chat = load_chat(chat_id) + + assert chat is not None + msg = chat.messages[0] + assert msg.visualization_data is not None + assert "spec_df" not in msg.visualization_data + + @patch("databao_cli.ui.services.chat_persistence.get_chats_dir") + def test_has_spec_df_marker_always_removed(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None: + """The _has_spec_df marker is always stripped from the loaded data.""" + from databao_cli.ui.services.chat_persistence import load_chat + + mock_chats_dir.return_value = tmp_path + chat_id = "00000000-0000-0000-0000-000000000004" + self._write_chat( + tmp_path, + chat_id, + vis_data={"spec": {"mark": "bar"}, "_has_spec_df": False}, + write_parquet=False, + ) + + chat = load_chat(chat_id) + + assert chat is not None + msg = chat.messages[0] + assert msg.visualization_data is not None + assert "_has_spec_df" not in msg.visualization_data From bb68723e7673bbdbae5b746e29413a846f2dc0a2 Mon Sep 17 00:00:00 2001 From: Simon Karan Date: Fri, 20 Mar 2026 14:21:14 +0100 Subject: [PATCH 2/4] [DBA-94] Fix legacy visualization reload compatibility --- src/databao_cli/ui/services/chat_persistence.py | 8 +++++--- tests/test_plot_propagation.py | 9 +++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/databao_cli/ui/services/chat_persistence.py b/src/databao_cli/ui/services/chat_persistence.py index 383f8243..815ed89e 100644 --- a/src/databao_cli/ui/services/chat_persistence.py +++ b/src/databao_cli/ui/services/chat_persistence.py @@ -169,14 +169,16 @@ def load_chat(chat_id: str) -> ChatSession | None: vis_data = msg_data.get("visualization_data") if vis_data is not None: vis_df_path = visualizations_dir / f"{i}_spec_df.parquet" - if vis_data.get("_has_spec_df"): - if vis_df_path.exists(): + expects_spec_df = bool(vis_data.get("_has_spec_df")) + has_spec_df_parquet = vis_df_path.exists() + if expects_spec_df or has_spec_df_parquet: + if has_spec_df_parquet: try: vis_data["spec_df"] = pd.read_parquet(vis_df_path) except Exception as e: logger.warning(f"Failed to load visualization spec_df {i} for chat {chat_id}: {e}") vis_data["spec_df"] = None - else: + elif expects_spec_df: logger.warning(f"Missing visualization spec_df parquet for message {i} in chat {chat_id}") vis_data["spec_df"] = None vis_data.pop("_has_spec_df", None) diff --git a/tests/test_plot_propagation.py b/tests/test_plot_propagation.py index f4651378..6d4a02d9 100644 --- a/tests/test_plot_propagation.py +++ b/tests/test_plot_propagation.py @@ -207,8 +207,8 @@ def test_has_spec_df_true_without_parquet(self, mock_chats_dir: MagicMock, tmp_p assert "_has_spec_df" not in msg.visualization_data @patch("databao_cli.ui.services.chat_persistence.get_chats_dir") - def test_has_spec_df_false_skips_parquet_load(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None: - """When _has_spec_df is absent, spec_df is not loaded even if parquet exists.""" + def test_legacy_chat_without_marker_still_loads_parquet(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None: + """Older chats without _has_spec_df still reload spec_df from parquet.""" from databao_cli.ui.services.chat_persistence import load_chat mock_chats_dir.return_value = tmp_path @@ -217,7 +217,7 @@ def test_has_spec_df_false_skips_parquet_load(self, mock_chats_dir: MagicMock, t tmp_path, chat_id, vis_data={"spec": {"mark": "bar"}}, - write_parquet=True, # parquet exists but should NOT be loaded + write_parquet=True, ) chat = load_chat(chat_id) @@ -225,7 +225,8 @@ def test_has_spec_df_false_skips_parquet_load(self, mock_chats_dir: MagicMock, t assert chat is not None msg = chat.messages[0] assert msg.visualization_data is not None - assert "spec_df" not in msg.visualization_data + assert msg.visualization_data["spec_df"] is not None + assert len(msg.visualization_data["spec_df"]) == 2 @patch("databao_cli.ui.services.chat_persistence.get_chats_dir") def test_has_spec_df_marker_always_removed(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None: From 1d23b12e69d5e49e9812884e3bb400294799b2e6 Mon Sep 17 00:00:00 2001 From: Simon Karan Date: Fri, 20 Mar 2026 15:08:47 +0100 Subject: [PATCH 3/4] [DBA-94] Fix docstrings per PR review feedback Co-Authored-By: Claude Opus 4.6 (1M context) --- src/databao_cli/ui/components/results.py | 4 ++++ tests/test_plot_propagation.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/databao_cli/ui/components/results.py b/src/databao_cli/ui/components/results.py index ade628a3..f41a4b96 100644 --- a/src/databao_cli/ui/components/results.py +++ b/src/databao_cli/ui/components/results.py @@ -210,6 +210,10 @@ def render_visualization_section( Args: thread: The Thread object (may have _visualization_result) visualization_data: Optional persisted visualization data (takes priority over thread result) + allow_thread_fallback: If True, fall back to the thread-level + ``_visualization_result`` when *visualization_data* is ``None``. + Disabled by default to prevent a plot from propagating to + messages that did not produce it. """ vis_result = thread._visualization_result diff --git a/tests/test_plot_propagation.py b/tests/test_plot_propagation.py index 6d4a02d9..f63808b0 100644 --- a/tests/test_plot_propagation.py +++ b/tests/test_plot_propagation.py @@ -19,7 +19,7 @@ @pytest.fixture(autouse=True, scope="module") def _preload_results_module() -> None: - """Pre-import the results module to resolve the circular import chain. + """Pre-import the query_executor module to resolve the circular import chain. results → chat_persistence → __init__ → query_executor → results. Importing query_executor first breaks the cycle. From 0ba37937de24b9624337ba6a6d33d6daf8239e17 Mon Sep 17 00:00:00 2001 From: Simon Karan Date: Fri, 20 Mar 2026 16:41:57 +0100 Subject: [PATCH 4/4] [DBA-94] Set has_visualization based on extracted data Only mark has_visualization=True when _extract_visualization_data actually returns data, preventing collapsed sections with no content. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/databao_cli/ui/components/results.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databao_cli/ui/components/results.py b/src/databao_cli/ui/components/results.py index f41a4b96..25dfa895 100644 --- a/src/databao_cli/ui/components/results.py +++ b/src/databao_cli/ui/components/results.py @@ -420,9 +420,9 @@ def execute_pending_plot(chat: "ChatSession") -> None: messages = chat.messages if message_index < len(messages): - messages[message_index].has_visualization = True messages[message_index].visualization_data = _extract_visualization_data(thread) - if messages[message_index].visualization_data is None: + messages[message_index].has_visualization = messages[message_index].visualization_data is not None + if not messages[message_index].has_visualization: logger.warning( "visualization_data is None after successful thread.plot() for message %d", message_index,