Skip to content

Commit

Permalink
only report tool traces in code generation time (#338)
Browse files Browse the repository at this point in the history
* only report tool traces in code generation time
  • Loading branch information
humpydonkey authored Jan 7, 2025
1 parent 41dc84b commit f9bce5a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
6 changes: 5 additions & 1 deletion vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
_LND_API_URL_v2 = f"{_LND_BASE_URL}/v1/tools"


def should_report_tool_traces() -> bool:
return bool(os.environ.get("REPORT_TOOL_TRACES", False))


class ToolCallTrace(BaseModel):
endpoint_url: str
type: str
Expand Down Expand Up @@ -251,7 +255,7 @@ def _call_post(
tool_call_trace.response = result
return result
finally:
if tool_call_trace is not None:
if tool_call_trace is not None and should_report_tool_traces():
trace = tool_call_trace.model_dump()
display({MimeType.APPLICATION_JSON: trace}, raw=True)

Expand Down
17 changes: 13 additions & 4 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
nms,
send_inference_request,
send_task_inference_request,
should_report_tool_traces,
single_nms,
)
from vision_agent.tools.tools_types import JobStatus
Expand Down Expand Up @@ -94,6 +95,9 @@ def _display_tool_trace(
# such as video bytes, which can be slow. Since this is calculated inside the
# function we can't capture it with a decarator without adding it as a return value
# which would change the function signature and affect the agent.
if not should_report_tool_traces():
return

files_in_b64: List[Tuple[str, str]]
if isinstance(files, str):
files_in_b64 = [("images", files)]
Expand Down Expand Up @@ -264,7 +268,7 @@ def od_sam2_video_tracking(
image_size = frames[0].shape[:2]

def _transform_detections(
input_list: List[Optional[List[Dict[str, Any]]]]
input_list: List[Optional[List[Dict[str, Any]]]],
) -> List[Optional[Dict[str, Any]]]:
output_list: List[Optional[Dict[str, Any]]] = []

Expand Down Expand Up @@ -2243,15 +2247,17 @@ def save_image(image: np.ndarray, file_path: str) -> None:
>>> save_image(image)
"""
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
from IPython.display import display

if not isinstance(image, np.ndarray) or (
image.shape[0] == 0 and image.shape[1] == 0
):
raise ValueError("The image is not a valid NumPy array with shape (H, W, C)")

pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
display(pil_image)
if should_report_tool_traces():
from IPython.display import display

display(pil_image)

pil_image.save(file_path)


Expand Down Expand Up @@ -2302,6 +2308,9 @@ def save_video(

def _save_video_to_result(video_uri: str) -> None:
"""Saves a video into the result of the code execution (as an intermediate output)."""
if not should_report_tool_traces():
return

from IPython.display import display

serializer = FileSerializer(video_uri)
Expand Down

0 comments on commit f9bce5a

Please sign in to comment.