Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/databao_cli/ui/components/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,12 @@ def render_dataframe_section(result: "ExecutionResult", has_visualization: bool)
st.dataframe(df, width="stretch")


def render_visualization_section(thread: "Thread", visualization_data: dict[str, Any] | None = None) -> None:
def render_visualization_section(
thread: "Thread",
visualization_data: dict[str, Any] | None = None,
*,
allow_thread_fallback: bool = False,
) -> None:
"""Render the visualization section.

Follows the same rendering logic as Jupyter notebooks:
Expand All @@ -205,6 +210,10 @@ def render_visualization_section(thread: "Thread", visualization_data: dict[str,
Args:
thread: The Thread object (may have _visualization_result)
visualization_data: Optional persisted visualization data (takes priority over thread result)
allow_thread_fallback: If True, fall back to the thread-level
``_visualization_result`` when *visualization_data* is ``None``.
Disabled by default to prevent a plot from propagating to
messages that did not produce it.
"""
vis_result = thread._visualization_result

Expand All @@ -221,7 +230,7 @@ def render_visualization_section(thread: "Thread", visualization_data: dict[str,
return
return

if vis_result is None:
if vis_result is None or not allow_thread_fallback:
return

with st.expander("📈 Visualization", expanded=True):
Expand Down Expand Up @@ -349,8 +358,10 @@ def render_visualization_and_actions(
st.info("Generating visualization...", icon="📈")
elif viz_error:
st.error(f"Failed to generate visualization: {viz_error}")
elif has_visualization or visualization_data is not None or (is_latest and thread._visualization_result is not None):
elif visualization_data is not None:
render_visualization_section(thread, visualization_data)
elif is_latest and (has_visualization or thread._visualization_result is not None):
render_visualization_section(thread, None, allow_thread_fallback=True)
Comment on lines +361 to +364
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In `render_visualization_and_actions`, the `elif visualization_data is not None:` branch will run even when the dict can’t actually be rendered (e.g., missing `spec`/`spec_df`), which prevents the latest-message thread fallback branch from ever running. This can lead to the latest message showing no visualization despite `thread._visualization_result` being available. Consider gating this branch on persisted data being renderable (e.g., both `spec` and `spec_df` present), or allowing `render_visualization_section` to fall back to the thread result when `allow_thread_fallback=True` and persisted data is incomplete.

Copilot uses AI. Check for mistakes.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`_extract_visualization_data` is all-or-nothing: it either returns a complete dict with both `spec` and `spec_df`, or `None`. There is no code path that produces an incomplete dict, so in practice this branch cannot run with unrenderable data. The early return at L231 in `render_visualization_section` is already a safety net for that theoretical case. Adding an extra renderability gate here would be defensive code for a scenario that cannot currently occur.


if is_latest and not viz_pending and not viz_error:
_render_and_handle_action_buttons(result, current_chat, message_index, has_visualization)
Expand Down Expand Up @@ -409,8 +420,13 @@ def execute_pending_plot(chat: "ChatSession") -> None:

messages = chat.messages
if message_index < len(messages):
messages[message_index].has_visualization = True
messages[message_index].visualization_data = _extract_visualization_data(thread)
messages[message_index].has_visualization = messages[message_index].visualization_data is not None
if not messages[message_index].has_visualization:
logger.warning(
"visualization_data is None after successful thread.plot() for message %d",
message_index,
)
save_current_chat()
except Exception as e:
logger.exception("Failed to generate visualization")
Expand Down
18 changes: 13 additions & 5 deletions src/databao_cli/ui/services/chat_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,19 @@ def load_chat(chat_id: str) -> ChatSession | None:
vis_data = msg_data.get("visualization_data")
if vis_data is not None:
vis_df_path = visualizations_dir / f"{i}_spec_df.parquet"
if vis_df_path.exists():
try:
vis_data["spec_df"] = pd.read_parquet(vis_df_path)
except Exception as e:
logger.warning(f"Failed to load visualization spec_df {i} for chat {chat_id}: {e}")
expects_spec_df = bool(vis_data.get("_has_spec_df"))
has_spec_df_parquet = vis_df_path.exists()
if expects_spec_df or has_spec_df_parquet:
if has_spec_df_parquet:
try:
vis_data["spec_df"] = pd.read_parquet(vis_df_path)
except Exception as e:
logger.warning(f"Failed to load visualization spec_df {i} for chat {chat_id}: {e}")
vis_data["spec_df"] = None
elif expects_spec_df:
logger.warning(f"Missing visualization spec_df parquet for message {i} in chat {chat_id}")
vis_data["spec_df"] = None
vis_data.pop("_has_spec_df", None)

chat = ChatSession.from_dict(session_data, results)
logger.debug(f"Chat loaded: {chat_id}")
Expand Down
250 changes: 250 additions & 0 deletions tests/test_plot_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
"""Tests for DBA-94: plot propagation to previous messages.

Verifies that:
1. render_visualization_section refuses the shared thread fallback unless
explicitly allowed via allow_thread_fallback=True.
2. The persistence round-trip correctly handles missing parquet files when
_has_spec_df is True.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest


@pytest.fixture(autouse=True, scope="module")
def _preload_results_module() -> None:
    """Pre-import the query_executor module to resolve the circular import chain.

    results → chat_persistence → __init__ → query_executor → results.
    Importing query_executor first breaks the cycle.

    The fixture is declared ``autouse`` with module scope, so tests in this
    module do not need to request it by name.
    """
    # Imported purely for the side effect of loading the module; the bound
    # name is never used, hence the F401 (unused import) suppression.
    import databao_cli.ui.services.query_executor  # noqa: F401


def _get_render_fn() -> Any:
    """Return ``render_visualization_section``, importing it on demand.

    The import lives inside the function body so that it happens when a test
    calls this helper rather than when the test module is collected, which
    avoids the circular import at collection time.
    """
    from databao_cli.ui.components.results import render_visualization_section as render_fn

    return render_fn


# ---------------------------------------------------------------------------
# render_visualization_section tests
# ---------------------------------------------------------------------------


class TestRenderVisualizationSection:
    """Exercise the ``allow_thread_fallback`` guard of ``render_visualization_section``.

    Every test replaces the ``st`` object referenced by the results module
    with a ``MagicMock`` and inspects which of its attributes were called.
    """

    @staticmethod
    def _stub_expander(mock_st: MagicMock) -> None:
        """Configure ``mock_st.expander(...)`` so it can be used in a ``with`` statement."""
        expander_cm = mock_st.expander.return_value
        expander_cm.__enter__ = MagicMock(return_value=MagicMock())
        # A falsy __exit__ result means exceptions raised inside the block propagate.
        expander_cm.__exit__ = MagicMock(return_value=False)

    @patch("databao_cli.ui.components.results.st")
    def test_thread_fallback_blocked_by_default(self, mock_st: MagicMock) -> None:
        """A thread-level result alone must not render when the fallback flag is omitted."""
        render = _get_render_fn()
        shared_thread = MagicMock()
        shared_thread._visualization_result = MagicMock()

        render(shared_thread, visualization_data=None)

        mock_st.expander.assert_not_called()

    @patch("databao_cli.ui.components.results.st")
    def test_thread_fallback_allowed_renders(self, mock_st: MagicMock) -> None:
        """Passing ``allow_thread_fallback=True`` renders the thread-level result."""
        render = _get_render_fn()
        thread_result = MagicMock()
        thread_result.plot = None
        shared_thread = MagicMock()
        shared_thread._visualization_result = thread_result
        self._stub_expander(mock_st)

        render(shared_thread, visualization_data=None, allow_thread_fallback=True)

        mock_st.expander.assert_called_once()

    @patch("databao_cli.ui.components.results.st")
    def test_visualization_data_takes_priority(self, mock_st: MagicMock) -> None:
        """Per-message data is charted even though the thread also holds a result."""
        render = _get_render_fn()
        shared_thread = MagicMock()
        shared_thread._visualization_result = MagicMock()
        persisted: dict[str, Any] = {
            "spec": {"mark": "bar"},
            "spec_df": pd.DataFrame({"x": [1], "y": [2]}),
        }
        self._stub_expander(mock_st)

        render(shared_thread, visualization_data=persisted)

        mock_st.vega_lite_chart.assert_called_once()

    @patch("databao_cli.ui.components.results.st")
    def test_visualization_data_missing_spec_df_returns_early(self, mock_st: MagicMock) -> None:
        """A spec whose ``spec_df`` is ``None`` opens no expander."""
        render = _get_render_fn()
        shared_thread = MagicMock()
        shared_thread._visualization_result = MagicMock()
        incomplete: dict[str, Any] = {"spec": {"mark": "bar"}, "spec_df": None}

        render(shared_thread, visualization_data=incomplete)

        mock_st.expander.assert_not_called()

    @patch("databao_cli.ui.components.results.st")
    def test_no_vis_result_no_data_returns_early(self, mock_st: MagicMock) -> None:
        """With neither per-message data nor a thread result, nothing is rendered."""
        render = _get_render_fn()
        empty_thread = MagicMock()
        empty_thread._visualization_result = None

        render(empty_thread, visualization_data=None, allow_thread_fallback=True)

        mock_st.expander.assert_not_called()


# ---------------------------------------------------------------------------
# Persistence round-trip tests for _has_spec_df handling
# ---------------------------------------------------------------------------


class TestVisualizationPersistenceRoundTrip:
    """Drive the real ``load_chat`` to check how the ``_has_spec_df`` marker is handled.

    ``get_chats_dir`` is patched to return pytest's ``tmp_path``; each test
    writes a minimal chat directory there and then loads it back through the
    actual persistence code.
    """

    @staticmethod
    def _write_chat(
        tmp_path: Path,
        chat_id: str,
        vis_data: dict[str, Any] | None,
        write_parquet: bool = False,
    ) -> None:
        """Create ``<tmp_path>/<chat_id>/session.json`` holding one assistant message.

        When *write_parquet* is true, a two-row ``0_spec_df.parquet`` is also
        written under the chat's ``visualizations`` directory.
        """
        import json

        chat_dir = tmp_path / chat_id
        chat_dir.mkdir(parents=True, exist_ok=True)

        assistant_message = {
            "role": "assistant",
            "content": "test answer",
            "has_visualization": True,
            "visualization_data": vis_data,
        }
        session = {
            "id": chat_id,
            "created_at": "2026-01-01T00:00:00",
            "messages": [assistant_message],
        }
        (chat_dir / "session.json").write_text(json.dumps(session))

        if not write_parquet:
            return

        vis_dir = chat_dir / "visualizations"
        vis_dir.mkdir(exist_ok=True)
        pd.DataFrame({"x": [1, 2], "y": [3, 4]}).to_parquet(vis_dir / "0_spec_df.parquet")

    def _round_trip(
        self,
        mock_chats_dir: MagicMock,
        tmp_path: Path,
        chat_id: str,
        vis_data: dict[str, Any] | None,
        *,
        write_parquet: bool,
    ) -> dict[str, Any]:
        """Write a chat, load it with ``load_chat`` and return message 0's visualization data.

        Asserts along the way that the chat loaded and that the first message
        still carries a ``visualization_data`` dict.
        """
        from databao_cli.ui.services.chat_persistence import load_chat

        mock_chats_dir.return_value = tmp_path
        self._write_chat(tmp_path, chat_id, vis_data=vis_data, write_parquet=write_parquet)

        chat = load_chat(chat_id)

        assert chat is not None
        loaded = chat.messages[0].visualization_data
        assert loaded is not None
        return loaded

    @patch("databao_cli.ui.services.chat_persistence.get_chats_dir")
    def test_has_spec_df_true_with_parquet(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None:
        """Marker set and parquet present: the data frame is loaded."""
        loaded = self._round_trip(
            mock_chats_dir,
            tmp_path,
            "00000000-0000-0000-0000-000000000001",
            {"spec": {"mark": "bar"}, "_has_spec_df": True},
            write_parquet=True,
        )

        assert loaded["spec_df"] is not None
        assert len(loaded["spec_df"]) == 2
        assert "_has_spec_df" not in loaded

    @patch("databao_cli.ui.services.chat_persistence.get_chats_dir")
    def test_has_spec_df_true_without_parquet(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None:
        """Marker set but parquet missing: ``spec_df`` is present and explicitly ``None``."""
        loaded = self._round_trip(
            mock_chats_dir,
            tmp_path,
            "00000000-0000-0000-0000-000000000002",
            {"spec": {"mark": "bar"}, "_has_spec_df": True},
            write_parquet=False,
        )

        # The key must exist with a None value, not merely be absent.
        assert "spec_df" in loaded
        assert loaded["spec_df"] is None
        assert "_has_spec_df" not in loaded

    @patch("databao_cli.ui.services.chat_persistence.get_chats_dir")
    def test_legacy_chat_without_marker_still_loads_parquet(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None:
        """Chats saved without the marker still pick up ``spec_df`` from parquet."""
        loaded = self._round_trip(
            mock_chats_dir,
            tmp_path,
            "00000000-0000-0000-0000-000000000003",
            {"spec": {"mark": "bar"}},
            write_parquet=True,
        )

        assert loaded["spec_df"] is not None
        assert len(loaded["spec_df"]) == 2

    @patch("databao_cli.ui.services.chat_persistence.get_chats_dir")
    def test_has_spec_df_marker_always_removed(self, mock_chats_dir: MagicMock, tmp_path: Path) -> None:
        """A ``_has_spec_df`` value of ``False`` is also stripped from the loaded data."""
        loaded = self._round_trip(
            mock_chats_dir,
            tmp_path,
            "00000000-0000-0000-0000-000000000004",
            {"spec": {"mark": "bar"}, "_has_spec_df": False},
            write_parquet=False,
        )

        assert "_has_spec_df" not in loaded
Loading