diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee98295..1c1b807 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 ### Added
 - Add `GenericOpenSearchApiTool` - A flexible, general-purpose tool that can interact with any OpenSearch API endpoint, addressing tool explosion and reducing context size. Supports all HTTP methods with write operation protection via `OPENSEARCH_SETTINGS_ALLOW_WRITE` environment variable. Closes [#109](https://github.com/opensearch-project/opensearch-mcp-server-py/issues/109)
 - Add header-based authentication + Code Clean up ([#117](https://github.com/opensearch-project/opensearch-mcp-server-py/pull/117))
+- Add skills tools integration ([#121](https://github.com/opensearch-project/opensearch-mcp-server-py/pull/121))
 
 ### Fixed
 - Fix Concurrency: Use Async OpenSearch client to improve concurrency ([#125](https://github.com/opensearch-project/opensearch-mcp-server-py/pull/125))
diff --git a/src/tools/skills_tools.py b/src/tools/skills_tools.py
new file mode 100644
index 0000000..2760749
--- /dev/null
+++ b/src/tools/skills_tools.py
@@ -0,0 +1,104 @@
+# Copyright OpenSearch Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+from typing import Dict, Any, List
+from .tool_params import baseToolArgs
+from pydantic import Field
+from opensearch.client import initialize_client
+
+logger = logging.getLogger(__name__)
+
+class DataDistributionToolArgs(baseToolArgs):
+    index: str = Field(description="Target OpenSearch index name")
+    selectionTimeRangeStart: str = Field(description="Start time for analysis period")
+    selectionTimeRangeEnd: str = Field(description="End time for analysis period")
+    timeField: str = Field(description="Date/time field for filtering (required)")
+    baselineTimeRangeStart: str = Field(default="", description="Start time for baseline period (optional)")
+    baselineTimeRangeEnd: str = Field(default="", description="End time for baseline period (optional)")
+    size: int = Field(default=1000, description="Maximum number of documents to analyze")
+
+class LogPatternAnalysisToolArgs(baseToolArgs):
+    index: str = Field(description="Target OpenSearch index name containing log data")
+    logFieldName: str = Field(description="Field containing raw log messages to analyze")
+    selectionTimeRangeStart: str = Field(description="Start time for analysis target period")
+    selectionTimeRangeEnd: str = Field(description="End time for analysis target period")
+    timeField: str = Field(description="Date/time field for time-based filtering (required)")
+    traceFieldName: str = Field(default="", description="Field for trace/correlation ID (optional)")
+    baseTimeRangeStart: str = Field(default="", description="Start time for baseline comparison period (optional)")
+    baseTimeRangeEnd: str = Field(default="", description="End time for baseline comparison period (optional)")
+
+async def call_opensearch_tool(tool_name: str, parameters: Dict[str, Any], args: baseToolArgs) -> List[Dict]:
+    """Call the OpenSearch ML tools execute API."""
+    try:
+        client = initialize_client(args)
+
+        # Call OpenSearch ML tools execute API
+        response = client.transport.perform_request(
+            'POST',
+            f'/_plugins/_ml/tools/_execute/{tool_name}',
+            body={'parameters': parameters}
+        )
+
+        logger.info(f"Tool {tool_name} result: {json.dumps(response, indent=2)}")
+        formatted_result = json.dumps(response, indent=2)
+        return [{'type': 'text', 'text': f'{tool_name} result:\n{formatted_result}'}]
+
+    except Exception as e:
+        return [{'type': 'text', 'text': f'Error executing {tool_name}: {str(e)}'}]
+
+async def data_distribution_tool(args: DataDistributionToolArgs) -> List[Dict]:
+    params = {
+        'index': args.index,
+        'timeField': args.timeField,
+        'selectionTimeRangeStart': args.selectionTimeRangeStart,
+        'selectionTimeRangeEnd': args.selectionTimeRangeEnd,
+        'size': args.size
+    }
+    if args.baselineTimeRangeStart:
+        params['baselineTimeRangeStart'] = args.baselineTimeRangeStart
+    if args.baselineTimeRangeEnd:
+        params['baselineTimeRangeEnd'] = args.baselineTimeRangeEnd
+
+    result = await call_opensearch_tool('DataDistributionTool', params, args)
+    return result
+
+async def log_pattern_analysis_tool(args: LogPatternAnalysisToolArgs) -> List[Dict]:
+    params = {
+        'index': args.index,
+        'timeField': args.timeField,
+        'logFieldName': args.logFieldName,
+        'selectionTimeRangeStart': args.selectionTimeRangeStart,
+        'selectionTimeRangeEnd': args.selectionTimeRangeEnd
+    }
+    if args.traceFieldName:
+        params['traceFieldName'] = args.traceFieldName
+    if args.baseTimeRangeStart:
+        params['baseTimeRangeStart'] = args.baseTimeRangeStart
+    if args.baseTimeRangeEnd:
+        params['baseTimeRangeEnd'] = args.baseTimeRangeEnd
+
+    result = await call_opensearch_tool('LogPatternAnalysisTool', params, args)
+    return result
+
+SKILLS_TOOLS_REGISTRY = {
+    'DataDistributionTool': {
+        'display_name': 'DataDistributionTool',
+        'description': 'Analyzes data distribution patterns and field value frequencies within OpenSearch indices. Supports both single dataset analysis for understanding data characteristics and comparative analysis between two time periods to identify distribution changes. Automatically detects useful fields, calculates value distributions, groups numeric data, and computes divergence metrics. Useful for anomaly detection, data quality assessment, and trend analysis. For example, it can analyze the distribution of failures over time.',
+        'input_schema': DataDistributionToolArgs.model_json_schema(),
+        'function': data_distribution_tool,
+        'args_model': DataDistributionToolArgs,
+        'min_version': '3.3.0',
+        'http_methods': 'POST',
+    },
+    'LogPatternAnalysisTool': {
+        'display_name': 'LogPatternAnalysisTool',
+        'description': 'Intelligent log pattern analysis tool for troubleshooting and anomaly detection in application logs. Use this tool when you need to: analyze error patterns in logs, identify unusual log sequences, compare log patterns between time periods, find root causes of system issues, detect anomalous behavior in application traces, or investigate performance problems. The tool automatically extracts meaningful patterns from raw log messages, groups similar patterns, identifies outliers, and provides insights for debugging. Essential for log-based troubleshooting, incident analysis, and proactive monitoring of system health.',
+        'input_schema': LogPatternAnalysisToolArgs.model_json_schema(),
+        'function': log_pattern_analysis_tool,
+        'args_model': LogPatternAnalysisToolArgs,
+        'min_version': '3.3.0',
+        'http_methods': 'POST',
+    },
+}
\ No newline at end of file
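Review note: call_opensearch_tool above is a thin wrapper over the ML Commons tool-execute endpoint. For reviewers who want to reproduce a call outside the MCP server, here is a minimal sketch using a raw opensearch-py client; the client construction, host, index, and parameter values are illustrative assumptions, not part of this PR.

    # Minimal sketch: exercise the same endpoint skills_tools.py calls.
    # Assumes opensearch-py and a local cluster; values are illustrative.
    from opensearchpy import OpenSearch

    client = OpenSearch(hosts=[{'host': 'localhost', 'port': 9200}])

    # Same request shape as call_opensearch_tool() above:
    response = client.transport.perform_request(
        'POST',
        '/_plugins/_ml/tools/_execute/LogPatternAnalysisTool',
        body={
            'parameters': {
                'index': 'logs-index',        # target index (illustrative)
                'logFieldName': 'message',    # raw log message field
                'timeField': '@timestamp',
                'selectionTimeRangeStart': '2023-01-01T00:00:00Z',
                'selectionTimeRangeEnd': '2023-01-02T00:00:00Z',
            }
        },
    )
    print(response)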
diff --git a/src/tools/tool_filter.py b/src/tools/tool_filter.py
index b4a7854..533b3fa 100644
--- a/src/tools/tool_filter.py
+++ b/src/tools/tool_filter.py
@@ -171,6 +171,8 @@ def process_tool_filter(
         'ExplainTool',
         'MsearchTool',
         'GenericOpenSearchApiTool',
+        'DataDistributionTool',
+        'LogPatternAnalysisTool',
     ]
 
     # Build core tools list using display names
diff --git a/src/tools/tools.py b/src/tools/tools.py
index e644ba4..7a2e32d 100644
--- a/src/tools/tools.py
+++ b/src/tools/tools.py
@@ -38,6 +38,7 @@
     list_indices,
     search_index,
 )
+from .skills_tools import SKILLS_TOOLS_REGISTRY
 
 
 async def check_tool_compatibility(tool_name: str, args: baseToolArgs = None):
@@ -492,6 +493,7 @@ async def get_long_running_tasks_tool(args: GetLongRunningTasksArgs) -> list[dic
 
 # Registry of available OpenSearch tools with their metadata
 TOOL_REGISTRY = {
+    **SKILLS_TOOLS_REGISTRY,
    'ListIndexTool': {
        'display_name': 'ListIndexTool',
        'description': 'Lists indices in the OpenSearch cluster. By default, returns a filtered list of index names only to minimize response size. Set include_detail=true to return full metadata from cat.indices (docs.count, store.size, etc.). If an index parameter is provided, returns detailed information for that specific index including mappings and settings.',
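Review note: TOOL_REGISTRY absorbs the skills tools via dict unpacking. Because **SKILLS_TOOLS_REGISTRY comes first, any later literal entry with the same key would silently shadow a skills tool. A tiny sketch of that merge semantics, with made-up entries for illustration:

    # Dict-unpacking merge semantics used by TOOL_REGISTRY above.
    skills = {'DataDistributionTool': {'min_version': '3.3.0'}}
    registry = {**skills, 'ListIndexTool': {'min_version': '1.0.0'}}

    assert set(registry) == {'DataDistributionTool', 'ListIndexTool'}
    assert {**{'A': 1}, 'A': 2}['A'] == 2  # a later duplicate key wins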
"""Test call_opensearch_tool successful execution.""" + # Setup + mock_response = { + 'status': 'success', + 'result': {'analysis': 'data distribution complete'} + } + self.mock_client.transport.perform_request.return_value = mock_response + + args = self.DataDistributionToolArgs( + index='test-index', + selectionTimeRangeStart='2023-01-01T00:00:00Z', + selectionTimeRangeEnd='2023-01-02T00:00:00Z', + timeField='@timestamp', + opensearch_cluster_name='' + ) + + # Execute + result = await self._call_opensearch_tool('DataDistributionTool', {'index': 'test-index'}, args) + + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + assert 'DataDistributionTool result:' in result[0]['text'] + assert '"status": "success"' in result[0]['text'] + self.mock_client.transport.perform_request.assert_called_once_with( + 'POST', + '/_plugins/_ml/tools/_execute/DataDistributionTool', + body={'parameters': {'index': 'test-index'}} + ) + + @pytest.mark.asyncio + async def test_call_opensearch_tool_error(self): + """Test call_opensearch_tool exception handling.""" + # Setup + self.mock_client.transport.perform_request.side_effect = Exception('Test error') + + args = self.DataDistributionToolArgs( + index='test-index', + selectionTimeRangeStart='2023-01-01T00:00:00Z', + selectionTimeRangeEnd='2023-01-02T00:00:00Z', + timeField='@timestamp', + opensearch_cluster_name='' + ) + + # Execute + result = await self._call_opensearch_tool('DataDistributionTool', {'index': 'test-index'}, args) + + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + assert 'Error executing DataDistributionTool: Test error' in result[0]['text'] + + @pytest.mark.asyncio + async def test_data_distribution_tool_minimal_params(self): + """Test data_distribution_tool with minimal required parameters.""" + # Setup + mock_response = { + 'status': 'success', + 'result': {'field_distributions': {'field1': {'count': 100}}} + } + self.mock_client.transport.perform_request.return_value = mock_response + + args = self.DataDistributionToolArgs( + index='test-index', + selectionTimeRangeStart='2023-01-01T00:00:00Z', + selectionTimeRangeEnd='2023-01-02T00:00:00Z', + timeField='@timestamp', + opensearch_cluster_name='' + ) + + # Execute + result = await self._data_distribution_tool(args) + + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + assert 'DataDistributionTool result:' in result[0]['text'] + + # Verify the correct parameters were passed + expected_params = { + 'index': 'test-index', + 'timeField': '@timestamp', + 'selectionTimeRangeStart': '2023-01-01T00:00:00Z', + 'selectionTimeRangeEnd': '2023-01-02T00:00:00Z', + 'size': 1000 + } + self.mock_client.transport.perform_request.assert_called_once_with( + 'POST', + '/_plugins/_ml/tools/_execute/DataDistributionTool', + body={'parameters': expected_params} + ) + + @pytest.mark.asyncio + async def test_data_distribution_tool_all_params(self): + """Test data_distribution_tool with all parameters.""" + # Setup + mock_response = {'status': 'success', 'result': {}} + self.mock_client.transport.perform_request.return_value = mock_response + + args = self.DataDistributionToolArgs( + index='test-index', + selectionTimeRangeStart='2023-01-01T00:00:00Z', + selectionTimeRangeEnd='2023-01-02T00:00:00Z', + timeField='@timestamp', + baselineTimeRangeStart='2022-12-01T00:00:00Z', + baselineTimeRangeEnd='2022-12-02T00:00:00Z', + size=500, + opensearch_cluster_name='' + ) + + # Execute + result = await self._data_distribution_tool(args) + + # Assert + 
+    @pytest.mark.asyncio
+    async def test_data_distribution_tool_minimal_params(self):
+        """Test data_distribution_tool with minimal required parameters."""
+        # Setup
+        mock_response = {
+            'status': 'success',
+            'result': {'field_distributions': {'field1': {'count': 100}}}
+        }
+        self.mock_client.transport.perform_request.return_value = mock_response
+
+        args = self.DataDistributionToolArgs(
+            index='test-index',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            opensearch_cluster_name=''
+        )
+
+        # Execute
+        result = await self._data_distribution_tool(args)
+
+        # Assert
+        assert len(result) == 1
+        assert result[0]['type'] == 'text'
+        assert 'DataDistributionTool result:' in result[0]['text']
+
+        # Verify the correct parameters were passed
+        expected_params = {
+            'index': 'test-index',
+            'timeField': '@timestamp',
+            'selectionTimeRangeStart': '2023-01-01T00:00:00Z',
+            'selectionTimeRangeEnd': '2023-01-02T00:00:00Z',
+            'size': 1000
+        }
+        self.mock_client.transport.perform_request.assert_called_once_with(
+            'POST',
+            '/_plugins/_ml/tools/_execute/DataDistributionTool',
+            body={'parameters': expected_params}
+        )
+
+    @pytest.mark.asyncio
+    async def test_data_distribution_tool_all_params(self):
+        """Test data_distribution_tool with all parameters."""
+        # Setup
+        mock_response = {'status': 'success', 'result': {}}
+        self.mock_client.transport.perform_request.return_value = mock_response
+
+        args = self.DataDistributionToolArgs(
+            index='test-index',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            baselineTimeRangeStart='2022-12-01T00:00:00Z',
+            baselineTimeRangeEnd='2022-12-02T00:00:00Z',
+            size=500,
+            opensearch_cluster_name=''
+        )
+
+        # Execute
+        result = await self._data_distribution_tool(args)
+
+        # Assert
+        expected_params = {
+            'index': 'test-index',
+            'timeField': '@timestamp',
+            'selectionTimeRangeStart': '2023-01-01T00:00:00Z',
+            'selectionTimeRangeEnd': '2023-01-02T00:00:00Z',
+            'size': 500,
+            'baselineTimeRangeStart': '2022-12-01T00:00:00Z',
+            'baselineTimeRangeEnd': '2022-12-02T00:00:00Z'
+        }
+        self.mock_client.transport.perform_request.assert_called_once_with(
+            'POST',
+            '/_plugins/_ml/tools/_execute/DataDistributionTool',
+            body={'parameters': expected_params}
+        )
+
+    @pytest.mark.asyncio
+    async def test_log_pattern_analysis_tool_minimal_params(self):
+        """Test log_pattern_analysis_tool with minimal required parameters."""
+        # Setup
+        mock_response = {
+            'status': 'success',
+            'result': {'patterns': [{'pattern': 'ERROR', 'count': 10}]}
+        }
+        self.mock_client.transport.perform_request.return_value = mock_response
+
+        args = self.LogPatternAnalysisToolArgs(
+            index='logs-index',
+            logFieldName='message',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            opensearch_cluster_name=''
+        )
+
+        # Execute
+        result = await self._log_pattern_analysis_tool(args)
+
+        # Assert
+        assert len(result) == 1
+        assert result[0]['type'] == 'text'
+        assert 'LogPatternAnalysisTool result:' in result[0]['text']
+
+        expected_params = {
+            'index': 'logs-index',
+            'timeField': '@timestamp',
+            'logFieldName': 'message',
+            'selectionTimeRangeStart': '2023-01-01T00:00:00Z',
+            'selectionTimeRangeEnd': '2023-01-02T00:00:00Z'
+        }
+        self.mock_client.transport.perform_request.assert_called_once_with(
+            'POST',
+            '/_plugins/_ml/tools/_execute/LogPatternAnalysisTool',
+            body={'parameters': expected_params}
+        )
+
+    @pytest.mark.asyncio
+    async def test_log_pattern_analysis_tool_all_params(self):
+        """Test log_pattern_analysis_tool with all parameters."""
+        # Setup
+        mock_response = {'status': 'success', 'result': {}}
+        self.mock_client.transport.perform_request.return_value = mock_response
+
+        args = self.LogPatternAnalysisToolArgs(
+            index='logs-index',
+            logFieldName='message',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            traceFieldName='trace_id',
+            baseTimeRangeStart='2022-12-01T00:00:00Z',
+            baseTimeRangeEnd='2022-12-02T00:00:00Z',
+            opensearch_cluster_name=''
+        )
+
+        # Execute
+        result = await self._log_pattern_analysis_tool(args)
+
+        # Assert
+        expected_params = {
+            'index': 'logs-index',
+            'timeField': '@timestamp',
+            'logFieldName': 'message',
+            'selectionTimeRangeStart': '2023-01-01T00:00:00Z',
+            'selectionTimeRangeEnd': '2023-01-02T00:00:00Z',
+            'traceFieldName': 'trace_id',
+            'baseTimeRangeStart': '2022-12-01T00:00:00Z',
+            'baseTimeRangeEnd': '2022-12-02T00:00:00Z'
+        }
+        self.mock_client.transport.perform_request.assert_called_once_with(
+            'POST',
+            '/_plugins/_ml/tools/_execute/LogPatternAnalysisTool',
+            body={'parameters': expected_params}
+        )
+
+    def test_skills_tools_registry(self):
+        """Test SKILLS_TOOLS_REGISTRY structure."""
+        expected_tools = [
+            'DataDistributionTool',
+            'LogPatternAnalysisTool',
+        ]
+
+        for tool in expected_tools:
+            assert tool in self.SKILLS_TOOLS_REGISTRY
+            assert 'display_name' in self.SKILLS_TOOLS_REGISTRY[tool]
+            assert 'description' in self.SKILLS_TOOLS_REGISTRY[tool]
+            assert 'input_schema' in self.SKILLS_TOOLS_REGISTRY[tool]
+            assert 'function' in self.SKILLS_TOOLS_REGISTRY[tool]
+            assert 'args_model' in self.SKILLS_TOOLS_REGISTRY[tool]
+            assert 'min_version' in self.SKILLS_TOOLS_REGISTRY[tool]
+            assert 'http_methods' in self.SKILLS_TOOLS_REGISTRY[tool]
+
+    def test_data_distribution_tool_args_validation(self):
+        """Test DataDistributionToolArgs validation."""
+        # Test valid inputs
+        args = self.DataDistributionToolArgs(
+            index='test-index',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            opensearch_cluster_name=''
+        )
+        assert args.index == 'test-index'
+        assert args.timeField == '@timestamp'
+        assert args.size == 1000  # default value
+
+        # Test with custom size
+        args_custom = self.DataDistributionToolArgs(
+            index='test-index',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            size=500,
+            opensearch_cluster_name=''
+        )
+        assert args_custom.size == 500
+
+    def test_log_pattern_analysis_tool_args_validation(self):
+        """Test LogPatternAnalysisToolArgs validation."""
+        # Test valid inputs
+        args = self.LogPatternAnalysisToolArgs(
+            index='logs-index',
+            logFieldName='message',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            opensearch_cluster_name=''
+        )
+        assert args.index == 'logs-index'
+        assert args.logFieldName == 'message'
+        assert args.timeField == '@timestamp'
+        assert args.traceFieldName == ''  # default empty value
+
+        # Test with optional fields
+        args_full = self.LogPatternAnalysisToolArgs(
+            index='logs-index',
+            logFieldName='message',
+            selectionTimeRangeStart='2023-01-01T00:00:00Z',
+            selectionTimeRangeEnd='2023-01-02T00:00:00Z',
+            timeField='@timestamp',
+            traceFieldName='trace_id',
+            baseTimeRangeStart='2022-12-01T00:00:00Z',
+            baseTimeRangeEnd='2022-12-02T00:00:00Z',
+            opensearch_cluster_name=''
+        )
+        assert args_full.traceFieldName == 'trace_id'
+        assert args_full.baseTimeRangeStart == '2022-12-01T00:00:00Z'
+
+    def test_input_models_validation(self):
+        """Test input models validation for required fields."""
+        # Test DataDistributionToolArgs - should fail without required fields
+        with pytest.raises(ValueError):
+            self.DataDistributionToolArgs(opensearch_cluster_name='')  # Missing required fields
+
+        # Test LogPatternAnalysisToolArgs - should fail without required fields
+        with pytest.raises(ValueError):
+            self.LogPatternAnalysisToolArgs(opensearch_cluster_name='')  # Missing required fields
\ No newline at end of file
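Review note: the fixture above deliberately patches opensearch.client.initialize_client and evicts tools.skills_tools from sys.modules before importing it. The ordering matters because skills_tools binds initialize_client with a from-import, which copies the reference at import time; a module imported before the patch would keep the real function. A condensed sketch of the pattern, assuming the repo's test layout puts src/ on sys.path:

    # `from x import f` copies the reference, so patch first, then (re)import.
    import sys
    from unittest.mock import Mock, patch

    patcher = patch('opensearch.client.initialize_client', return_value=Mock())
    patcher.start()
    sys.modules.pop('tools.skills_tools', None)   # drop any cached import
    from tools.skills_tools import call_opensearch_tool  # binds the patched factory
    patcher.stop()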
diff --git a/tests/tools/test_tool_filters.py b/tests/tools/test_tool_filters.py
index 2096109..f0422ca 100644
--- a/tests/tools/test_tool_filters.py
+++ b/tests/tools/test_tool_filters.py
@@ -32,6 +32,36 @@
         'min_version': '2.0.0',
         'max_version': '3.0.0',
     },
+    'DataDistributionTool': {
+        'display_name': 'DataDistributionTool',
+        'description': 'Analyze data distribution patterns',
+        'input_schema': {
+            'type': 'object',
+            'properties': {
+                'opensearch_cluster_name': {'type': 'string'},
+                'index': {'type': 'string'},
+            },
+        },
+        'function': MagicMock(),
+        'args_model': MagicMock(),
+        'min_version': '3.3.0',
+        'http_methods': 'POST',
+    },
+    'LogPatternAnalysisTool': {
+        'display_name': 'LogPatternAnalysisTool',
+        'description': 'Analyze log patterns',
+        'input_schema': {
+            'type': 'object',
+            'properties': {
+                'opensearch_cluster_name': {'type': 'string'},
+                'index': {'type': 'string'},
+            },
+        },
+        'function': MagicMock(),
+        'args_model': MagicMock(),
+        'min_version': '3.3.0',
+        'http_methods': 'POST',
+    },
 }
 
 
@@ -213,6 +243,51 @@ async def test_get_tools_default_mode_is_single(self, mock_tool_registry, mock_p
             not in result['SearchIndexTool']['input_schema']['properties']
         )
 
+    @pytest.mark.asyncio
+    async def test_get_tools_skills_tools_version_filtering(self, mock_tool_registry, mock_patches):
+        """Test that skills tools are filtered based on version compatibility."""
+        mock_get_version, mock_is_compatible = mock_patches
+
+        # Setup mocks - simulate OpenSearch 2.5.0 (below skills tools min version 3.3.0)
+        mock_get_version.return_value = Version.parse('2.5.0')
+
+        # Mock compatibility: skills tools should be incompatible with 2.5.0
+        def mock_compatibility(version, tool_info):
+            min_version = tool_info.get('min_version', '1.0.0')
+            return version >= Version.parse(min_version)
+
+        mock_is_compatible.side_effect = mock_compatibility
+
+        # Patch TOOL_REGISTRY to use our mock registry
+        with patch('tools.tool_filter.TOOL_REGISTRY', mock_tool_registry):
+            result = await get_tools(mock_tool_registry)
+
+        # Skills tools should be filtered out due to version incompatibility
+        assert 'DataDistributionTool' not in result
+        assert 'LogPatternAnalysisTool' not in result
+        # Other tools should still be present
+        assert 'ListIndexTool' in result
+        assert 'SearchIndexTool' in result
+
+    @pytest.mark.asyncio
+    async def test_get_tools_skills_tools_compatible_version(self, mock_tool_registry, mock_patches):
+        """Test that skills tools are included when OpenSearch version is compatible."""
+        mock_get_version, mock_is_compatible = mock_patches
+
+        # Setup mocks - simulate OpenSearch 3.5.0 (above skills tools min version 3.3.0)
+        mock_get_version.return_value = Version.parse('3.5.0')
+        mock_is_compatible.return_value = True  # All tools compatible
+
+        # Patch TOOL_REGISTRY to use our mock registry
+        with patch('tools.tool_filter.TOOL_REGISTRY', mock_tool_registry):
+            result = await get_tools(mock_tool_registry)
+
+        # All tools should be present including skills tools
+        assert 'DataDistributionTool' in result
+        assert 'LogPatternAnalysisTool' in result
+        assert 'ListIndexTool' in result
+        assert 'SearchIndexTool' in result
+
     @pytest.mark.asyncio
     async def test_get_tools_logs_version_info(self, mock_tool_registry, mock_patches, caplog):
         """Test that get_tools logs version information in single mode."""
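Review note: the wiring across registry, filter, and tests looks consistent. For completeness, a hypothetical end-to-end driver for the new coroutine; the index and time-range values are made up, opensearch_cluster_name='' follows the tests above, and this assumes a cluster >= 3.3.0 reachable via initialize_client.

    # Hypothetical smoke run of the new tool coroutine (values illustrative).
    import asyncio
    from tools.skills_tools import DataDistributionToolArgs, data_distribution_tool

    args = DataDistributionToolArgs(
        index='test-index',
        timeField='@timestamp',
        selectionTimeRangeStart='2023-01-01T00:00:00Z',
        selectionTimeRangeEnd='2023-01-02T00:00:00Z',
        opensearch_cluster_name='',  # as passed in the unit tests above
    )
    print(asyncio.run(data_distribution_tool(args)))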