From 44940dbe0ec9af18755cc6a94f7be48b3e6fa1dd Mon Sep 17 00:00:00 2001 From: Tirthanu Ghosh Date: Wed, 19 Nov 2025 15:39:51 -0800 Subject: [PATCH] Add hybrid search support with overrideSearchType parameter - Add overrideSearchType parameter to tool input schema with HYBRID/SEMANTIC enum validation - Implement parameter validation and configuration in retrieve function - Add comprehensive tests for hybrid search functionality including error cases - Support both HYBRID and SEMANTIC search types as per AWS Bedrock API documentation - Fix line length violations for code quality compliance - Preserve all existing tests including test_retrieve_via_agent_with_enable_metadata --- src/strands_tools/retrieve.py | 21 +++++ tests/test_retrieve.py | 167 ++++++++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+) diff --git a/src/strands_tools/retrieve.py b/src/strands_tools/retrieve.py index 3d903aca..9eb56e03 100644 --- a/src/strands_tools/retrieve.py +++ b/src/strands_tools/retrieve.py @@ -166,6 +166,14 @@ ), "default": False, }, + "overrideSearchType": { + "type": "string", + "description": ( + "Override the search type for the knowledge base query. Supported values: 'HYBRID', " + "'SEMANTIC'. Default behavior uses the knowledge base's configured search type." + ), + "enum": ["HYBRID", "SEMANTIC"], + }, }, "required": ["text"], } @@ -306,6 +314,16 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult: min_score = tool_input.get("score", default_min_score) enable_metadata = tool_input.get("enableMetadata", default_enable_metadata) retrieve_filter = tool_input.get("retrieveFilter") + override_search_type = tool_input.get("overrideSearchType") + + # Validate overrideSearchType if provided + if override_search_type and override_search_type not in ["HYBRID", "SEMANTIC"]: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Invalid overrideSearchType: {override_search_type}. " + f"Supported values: HYBRID, SEMANTIC"}], + } # Initialize Bedrock client with optional profile name profile_name = tool_input.get("profile_name") @@ -321,6 +339,9 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult: # Default retrieval configuration retrieval_config = {"vectorSearchConfiguration": {"numberOfResults": number_of_results}} + if override_search_type: + retrieval_config["vectorSearchConfiguration"]["overrideSearchType"] = override_search_type + if retrieve_filter: try: if _validate_filter(retrieve_filter): diff --git a/tests/test_retrieve.py b/tests/test_retrieve.py index 97add221..0a44ec5f 100644 --- a/tests/test_retrieve.py +++ b/tests/test_retrieve.py @@ -656,6 +656,171 @@ def test_retrieve_with_environment_variable_default(mock_boto3_client): assert "test-source-1" not in result_text +def test_retrieve_with_override_search_type_hybrid(mock_boto3_client): + """Test retrieve with overrideSearchType set to HYBRID.""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "text": "test query", + "knowledgeBaseId": "test-kb-id", + "overrideSearchType": "HYBRID", + }, + } + + result = retrieve.retrieve(tool=tool_use) + + # Verify the result is successful + assert result["status"] == "success" + assert "Retrieved 2 results with score >= 0.4" in result["content"][0]["text"] + + # Verify that boto3 client was called with overrideSearchType + mock_boto3_client.return_value.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="test-kb-id", + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 10, + "overrideSearchType": "HYBRID" + } + }, + ) + + +def test_retrieve_with_override_search_type_semantic(mock_boto3_client): + """Test retrieve with overrideSearchType set to SEMANTIC.""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "text": "test query", + "knowledgeBaseId": "test-kb-id", + "overrideSearchType": "SEMANTIC", + }, + } + + result = retrieve.retrieve(tool=tool_use) + + # Verify the result is successful + assert result["status"] == "success" + + # Verify that boto3 client was called with overrideSearchType + mock_boto3_client.return_value.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="test-kb-id", + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 10, + "overrideSearchType": "SEMANTIC" + } + }, + ) + + +def test_retrieve_with_invalid_override_search_type(mock_boto3_client): + """Test retrieve with invalid overrideSearchType.""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "text": "test query", + "knowledgeBaseId": "test-kb-id", + "overrideSearchType": "INVALID_TYPE", + }, + } + + result = retrieve.retrieve(tool=tool_use) + + # Verify the result is an error + assert result["status"] == "error" + assert "Invalid overrideSearchType: INVALID_TYPE" in result["content"][0]["text"] + assert "Supported values: HYBRID, SEMANTIC" in result["content"][0]["text"] + + # Verify that boto3 client was not called + mock_boto3_client.return_value.retrieve.assert_not_called() + + +def test_retrieve_without_override_search_type(mock_boto3_client): + """Test retrieve without overrideSearchType (default behavior).""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "text": "test query", + "knowledgeBaseId": "test-kb-id", + }, + } + + result = retrieve.retrieve(tool=tool_use) + + # Verify the result is successful + assert result["status"] == "success" + + # Verify that boto3 client was called without overrideSearchType + mock_boto3_client.return_value.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="test-kb-id", + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 10 + } + }, + ) + + +def test_retrieve_with_override_search_type_and_filter(mock_boto3_client): + """Test retrieve with both overrideSearchType and retrieveFilter.""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "text": "test query", + "knowledgeBaseId": "test-kb-id", + "overrideSearchType": "HYBRID", + "retrieveFilter": {"equals": {"key": "category", "value": "security"}}, + }, + } + + result = retrieve.retrieve(tool=tool_use) + + # Verify the result is successful + assert result["status"] == "success" + + # Verify that boto3 client was called with both overrideSearchType and filter + mock_boto3_client.return_value.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="test-kb-id", + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 10, + "overrideSearchType": "HYBRID", + "filter": {"equals": {"key": "category", "value": "security"}} + } + }, + ) + + +def test_retrieve_via_agent_with_override_search_type(agent, mock_boto3_client): + """Test retrieving via the agent interface with overrideSearchType.""" + with mock.patch.dict(os.environ, {"KNOWLEDGE_BASE_ID": "agent-kb-id"}): + result = agent.tool.retrieve( + text="agent query", + knowledgeBaseId="test-kb-id", + overrideSearchType="HYBRID" + ) + + result_text = extract_result_text(result) + assert "Retrieved" in result_text + assert "results with score >=" in result_text + + # Verify the boto3 client was called with overrideSearchType + mock_boto3_client.return_value.retrieve.assert_called_once_with( + retrievalQuery={"text": "agent query"}, + knowledgeBaseId="test-kb-id", + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 10, + "overrideSearchType": "HYBRID" + } + }, + ) + + def test_retrieve_via_agent_with_enable_metadata(agent, mock_boto3_client): """Test retrieving via the agent interface with enableMetadata.""" with mock.patch.dict(os.environ, {"KNOWLEDGE_BASE_ID": "agent-kb-id"}): @@ -677,3 +842,5 @@ def test_retrieve_via_agent_with_enable_metadata(agent, mock_boto3_client): assert "results with score >=" in result_text assert "Metadata:" not in result_text assert "test-source" not in result_text + +