diff --git a/README.md b/README.md index 7674ede..040a15d 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ path to `uvx`. ```bash # Claude CLI -claude mcp add massive -e MASSIVE_API_KEY=your_api_key_here -- uvx --from git+https://github.com/massive-com/mcp_massive@v0.6.0 mcp_massive +claude mcp add massive -e MASSIVE_API_KEY=your_api_key_here -- uvx --from git+https://github.com/massive-com/mcp_massive@v0.7.0 mcp_massive ``` This command will install the MCP server in your current project. @@ -83,7 +83,7 @@ Make sure you complete the various fields. "command": "/uvx", "args": [ "--from", - "git+https://github.com/massive-com/mcp_massive@v0.6.0", + "git+https://github.com/massive-com/mcp_massive@v0.7.0", "mcp_massive" ], "env": { @@ -130,12 +130,36 @@ This MCP server implements all Massive.com API endpoints as tools, including: - `get_last_trade` - Latest trade for a symbol - `list_ticker_news` - Recent news articles for tickers - `get_snapshot_ticker` - Current market snapshot for a ticker +- `list_snapshot_options_chain` - Option chain snapshot with greeks and market data - `get_market_status` - Current market status and trading hours - `list_stock_financials` - Fundamental financial data - And many more... Each tool follows the Massive.com SDK parameter structure while converting responses to standard JSON that LLMs can easily process. +### Output Filtering + +Some tools support output filtering to reduce response size and token usage. These tools accept additional parameters: + +| Parameter | Description | +|-----------|-------------| +| `fields` | Comma-separated field names or a preset (e.g., `"ticker,close"` or `"preset:greeks"`) | +| `output_format` | Output format: `"csv"` (default), `"json"`, or `"compact"` | +| `aggregate` | Return only `"first"` or `"last"` record | + +**Available field presets:** + +| Preset | Fields | +|--------|--------| +| `price` | ticker, close, timestamp | +| `ohlcv` | ticker, open, high, low, close, volume, timestamp | +| `summary` | ticker, close, volume, change_percent | +| `greeks` | details_ticker, details_strike_price, details_expiration_date, details_contract_type, greeks_delta, greeks_gamma, greeks_theta, greeks_vega, implied_volatility | +| `options_summary` | details_ticker, details_strike_price, details_expiration_date, details_contract_type, day_close, day_open, day_volume, open_interest, implied_volatility | +| `options_quote` | details_ticker, details_strike_price, details_contract_type, last_quote_bid, last_quote_ask, last_quote_bid_size, last_quote_ask_size | + +Example: `fields="preset:greeks"` returns only the greek values for options contracts. + ## Development ### Running Locally diff --git a/entrypoint.py b/entrypoint.py index b7968a6..27f2e73 100644 --- a/entrypoint.py +++ b/entrypoint.py @@ -3,6 +3,7 @@ Backwards-compatible entrypoint script. This script delegates to the main package CLI entry point. """ + from mcp_massive import main if __name__ == "__main__": diff --git a/src/mcp_massive/__init__.py b/src/mcp_massive/__init__.py index 989ca37..1910062 100644 --- a/src/mcp_massive/__init__.py +++ b/src/mcp_massive/__init__.py @@ -33,8 +33,12 @@ def main() -> None: if massive_api_key: print("Starting Massive MCP server with API key configured.") elif polygon_api_key: - print("Warning: POLYGON_API_KEY is deprecated. Please migrate to MASSIVE_API_KEY.") - print("Starting Massive MCP server with API key configured (using deprecated POLYGON_API_KEY).") + print( + "Warning: POLYGON_API_KEY is deprecated. 
Please migrate to MASSIVE_API_KEY." + ) + print( + "Starting Massive MCP server with API key configured (using deprecated POLYGON_API_KEY)." + ) # Set MASSIVE_API_KEY from POLYGON_API_KEY for backward compatibility os.environ["MASSIVE_API_KEY"] = polygon_api_key else: diff --git a/src/mcp_massive/filters.py b/src/mcp_massive/filters.py new file mode 100644 index 0000000..7b9bd75 --- /dev/null +++ b/src/mcp_massive/filters.py @@ -0,0 +1,222 @@ +""" +Output filtering module for MCP Massive server. + +This module provides server-side filtering capabilities to reduce context token usage +by allowing field selection, output format selection, and row aggregation. +""" + +import json +from dataclasses import dataclass +from typing import Optional, List, Dict, Any, Literal + + +# Field presets for common use cases +FIELD_PRESETS = { + # Price presets + "price": ["ticker", "close", "timestamp"], + "last_price": ["close"], + # OHLC presets + "ohlc": ["ticker", "open", "high", "low", "close", "timestamp"], + "ohlcv": ["ticker", "open", "high", "low", "close", "volume", "timestamp"], + # Summary presets + "summary": ["ticker", "close", "volume", "change_percent"], + "minimal": ["ticker", "close"], + # Volume presets + "volume": ["ticker", "volume", "timestamp"], + # Details presets + "details": ["ticker", "name", "market", "locale", "primary_exchange"], + "info": ["ticker", "name", "description", "homepage_url"], + # News presets + "news_headlines": ["title", "published_utc", "author"], + "news_summary": ["title", "description", "published_utc", "article_url"], + # Trade presets + "trade": ["price", "size", "timestamp"], + "quote": ["bid", "ask", "bid_size", "ask_size", "timestamp"], + # Options presets (field names are flattened from nested API response) + "greeks": [ + "details_ticker", + "details_strike_price", + "details_expiration_date", + "details_contract_type", + "greeks_delta", + "greeks_gamma", + "greeks_theta", + "greeks_vega", + "implied_volatility", + ], + "options_summary": [ + "details_ticker", + "details_strike_price", + "details_expiration_date", + "details_contract_type", + "day_close", + "day_open", + "day_volume", + "open_interest", + "implied_volatility", + ], + "options_quote": [ + "details_ticker", + "details_strike_price", + "details_contract_type", + "last_quote_bid", + "last_quote_ask", + "last_quote_bid_size", + "last_quote_ask_size", + ], +} + + +@dataclass +class FilterOptions: + """Options for filtering MCP tool outputs.""" + + # Field selection + fields: Optional[List[str]] = None # Include only these fields + exclude_fields: Optional[List[str]] = None # Exclude these fields + + # Output format + format: Literal["csv", "json", "compact"] = "csv" + + # Aggregation + aggregate: Optional[Literal["first", "last"]] = None + + # Row filtering (future enhancement) + conditions: Optional[Dict[str, Any]] = None # {"volume_gt": 1000000} + + +def parse_filter_params( + fields: Optional[str] = None, + output_format: str = "csv", + aggregate: Optional[str] = None, +) -> FilterOptions: + """ + Parse tool parameters into FilterOptions. 
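+
+    Example (illustrative; preset contents come from FIELD_PRESETS above):
+
+        >>> opts = parse_filter_params(fields="preset:price", output_format="json")
+        >>> opts.fields
+        ['ticker', 'close', 'timestamp']
+        >>> opts.format
+        'json'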
+ + Args: + fields: Comma-separated field names or preset name (e.g., "ticker,close" or "preset:price") + output_format: Desired output format ("csv", "json", or "compact") + aggregate: Aggregation method ("first", "last", or None) + + Returns: + FilterOptions instance + """ + # Parse fields parameter + field_list = None + if fields: + # Check if it's a preset + if fields.startswith("preset:"): + preset_name = fields[7:] # Remove "preset:" prefix + field_list = FIELD_PRESETS.get(preset_name) + if field_list is None: + raise ValueError( + f"Unknown preset: {preset_name}. Available presets: {', '.join(FIELD_PRESETS.keys())}" + ) + else: + # Parse comma-separated fields + field_list = [f.strip() for f in fields.split(",") if f.strip()] + + # Validate output format + if output_format not in ["csv", "json", "compact"]: + raise ValueError( + f"Invalid output_format: {output_format}. Must be 'csv', 'json', or 'compact'" + ) + + # Validate aggregate + if aggregate and aggregate not in ["first", "last"]: + raise ValueError( + f"Invalid aggregate: {aggregate}. Must be 'first', 'last', or None" + ) + + return FilterOptions( + fields=field_list, + format=output_format, + aggregate=aggregate, + ) + + +def apply_filters(data: dict | str, options: FilterOptions) -> str: + """ + Apply filtering to API response data. + + Args: + data: JSON string or dict from Massive API + options: Filtering options to apply + + Returns: + Filtered and formatted string response + """ + # Import formatters here to avoid circular imports + from .formatters import ( + json_to_csv_filtered, + json_to_compact, + json_to_json_filtered, + ) + + # Parse JSON if it's a string + if isinstance(data, str): + parsed_data = json.loads(data) + else: + parsed_data = data + + # Apply aggregation if specified + if options.aggregate: + parsed_data = _apply_aggregation(parsed_data, options.aggregate) + + # Route to appropriate formatter based on output format + if options.format == "csv": + return json_to_csv_filtered( + parsed_data, + fields=options.fields, + exclude_fields=options.exclude_fields, + ) + elif options.format == "json": + return json_to_json_filtered( + parsed_data, + fields=options.fields, + ) + elif options.format == "compact": + return json_to_compact( + parsed_data, + fields=options.fields, + ) + else: + raise ValueError(f"Unsupported format: {options.format}") + + +def _apply_aggregation(data: dict | list, method: str) -> dict | list: + """ + Apply aggregation to extract a single record. 
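+
+    Example (illustrative):
+
+        >>> _apply_aggregation({"results": [{"c": 1}, {"c": 2}]}, "last")
+        {'results': [{'c': 2}]}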
+ + Args: + data: JSON data (dict or list) + method: Aggregation method ("first" or "last") + + Returns: + Aggregated data + """ + # Extract records + if isinstance(data, dict) and "results" in data: + records = data["results"] + elif isinstance(data, list): + records = data + else: + # Single record, return as-is + return data + + if not records: + return data + + # Apply aggregation + if method == "first": + aggregated_record = records[0] + elif method == "last": + aggregated_record = records[-1] + else: + raise ValueError(f"Unknown aggregation method: {method}") + + # Preserve structure + if isinstance(data, dict) and "results" in data: + return {**data, "results": [aggregated_record]} + else: + return [aggregated_record] diff --git a/src/mcp_massive/formatters.py b/src/mcp_massive/formatters.py index 55fe5f3..68ebc61 100644 --- a/src/mcp_massive/formatters.py +++ b/src/mcp_massive/formatters.py @@ -1,7 +1,7 @@ import json import csv import io -from typing import Any +from typing import Any, Optional, List def json_to_csv(json_input: str | dict) -> str: @@ -102,3 +102,187 @@ def _flatten_dict( items.append((new_key, v)) return dict(items) + + +def json_to_csv_filtered( + json_input: str | dict, + fields: Optional[List[str]] = None, + exclude_fields: Optional[List[str]] = None, +) -> str: + """ + Convert JSON to CSV with optional field filtering. + + Args: + json_input: JSON string or dict + fields: Include only these fields (None = all) + exclude_fields: Exclude these fields + + Returns: + CSV string with selected fields only + """ + # Parse JSON + if isinstance(json_input, str): + try: + data = json.loads(json_input) + except json.JSONDecodeError: + return "" + else: + data = json_input + + # Extract records + if isinstance(data, dict) and "results" in data: + results_value = data["results"] + if isinstance(results_value, list): + records = results_value + elif isinstance(results_value, dict): + records = [results_value] + else: + records = [results_value] + elif isinstance(data, dict) and "last" in data: + records = [data["last"]] if isinstance(data["last"], dict) else [data] + elif isinstance(data, list): + records = data + else: + records = [data] + + # Flatten records + flattened = [] + for record in records: + if isinstance(record, dict): + flattened.append(_flatten_dict(record)) + else: + flattened.append({"value": str(record)}) + + # Apply field filtering + if fields: + flattened = [ + {k: v for k, v in record.items() if k in fields} for record in flattened + ] + elif exclude_fields: + flattened = [ + {k: v for k, v in record.items() if k not in exclude_fields} + for record in flattened + ] + + # Convert to CSV + if not flattened: + return "" + + # Get all unique keys across all records (for consistent column ordering) + all_keys = [] + seen = set() + for record in flattened: + for key in record.keys(): + if key not in seen: + all_keys.append(key) + seen.add(key) + + output = io.StringIO() + writer = csv.DictWriter(output, fieldnames=all_keys, lineterminator="\n") + writer.writeheader() + writer.writerows(flattened) + + return output.getvalue() + + +def json_to_compact(json_input: str | dict, fields: Optional[List[str]] = None) -> str: + """ + Convert JSON to minimal compact format. + Best for single-record responses. 
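+
+    Example (illustrative):
+
+        >>> json_to_compact('{"results": [{"close": 185.92, "volume": 52165200}]}')
+        '{"close":185.92,"volume":52165200}'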
+ + Args: + json_input: JSON string or dict + fields: Include only these fields + + Returns: + Compact JSON string (e.g., '{"close": 185.92, "volume": 52165200}') + """ + if isinstance(json_input, str): + try: + data = json.loads(json_input) + except json.JSONDecodeError: + return "{}" + else: + data = json_input + + # Extract single record + if isinstance(data, dict) and "results" in data: + results = data["results"] + if isinstance(results, list): + record = results[0] if results else {} + else: + record = results + elif isinstance(data, dict) and "last" in data: + record = data["last"] if isinstance(data["last"], dict) else {} + elif isinstance(data, list): + record = data[0] if data else {} + else: + record = data + + # Flatten + if isinstance(record, dict): + flattened = _flatten_dict(record) + else: + flattened = {"value": str(record)} + + # Apply field filtering + if fields: + flattened = {k: v for k, v in flattened.items() if k in fields} + + return json.dumps(flattened, separators=(",", ":")) + + +def json_to_json_filtered( + json_input: str | dict, + fields: Optional[List[str]] = None, + preserve_structure: bool = False, +) -> str: + """ + Convert to JSON with optional field filtering. + + Args: + json_input: JSON string or dict + fields: Include only these fields + preserve_structure: Keep nested structure (don't flatten) + + Returns: + JSON string + """ + if isinstance(json_input, str): + try: + data = json.loads(json_input) + except json.JSONDecodeError: + return "[]" + else: + data = json_input + + if isinstance(data, dict) and "results" in data: + results_value = data["results"] + if isinstance(results_value, list): + records = results_value + elif isinstance(results_value, dict): + records = [results_value] + else: + records = [results_value] + elif isinstance(data, dict) and "last" in data: + records = [data["last"]] if isinstance(data["last"], dict) else [data] + elif isinstance(data, list): + records = data + else: + records = [data] + + if not preserve_structure: + flattened = [] + for record in records: + if isinstance(record, dict): + flattened.append(_flatten_dict(record)) + else: + flattened.append({"value": str(record)}) + records = flattened + + if fields: + records = [ + {k: v for k, v in record.items() if k in fields} for record in records + ] + + return json.dumps(records, indent=2) diff --git a/src/mcp_massive/server.py b/src/mcp_massive/server.py index 6ec2528..f423c39 100644 --- a/src/mcp_massive/server.py +++ b/src/mcp_massive/server.py @@ -15,7 +15,9 @@ MASSIVE_API_KEY = os.environ.get("MASSIVE_API_KEY", "") if not MASSIVE_API_KEY: print("Warning: MASSIVE_API_KEY environment variable not set.") - print("Please set it in your environment or create a .env file with MASSIVE_API_KEY=your_key") + print( + "Please set it in your environment or create a .env file with MASSIVE_API_KEY=your_key" + ) version_number = "MCP-Massive/unknown" try: @@ -29,6 +31,37 @@ poly_mcp = FastMCP("Massive") +def _apply_output_filtering( + raw_data: bytes, + fields: Optional[str] = None, + output_format: str = "csv", + aggregate: Optional[str] = None, +) -> str: + """ + Helper function to apply output filtering to API responses. 
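+
+    Example (illustrative; assumes a results-shaped JSON body):
+
+        >>> _apply_output_filtering(
+        ...     b'{"results": [{"ticker": "AAPL", "close": 150.5}]}',
+        ...     fields="ticker,close",
+        ...     output_format="compact",
+        ... )
+        '{"ticker":"AAPL","close":150.5}'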
+ + Args: + raw_data: Raw bytes from API response + fields: Field selection (comma-separated or preset like "preset:greeks") + output_format: Output format (csv, json, compact) + aggregate: Aggregation method (first, last) + + Returns: + Filtered and formatted string response + """ + if fields or output_format != "csv" or aggregate: + from .filters import parse_filter_params, apply_filters + + filter_options = parse_filter_params( + fields=fields, + output_format=output_format, + aggregate=aggregate, + ) + return apply_filters(raw_data.decode("utf-8"), filter_options) + else: + return json_to_csv(raw_data.decode("utf-8")) + + @poly_mcp.tool(annotations=ToolAnnotations(readOnlyHint=True)) async def get_aggs( ticker: str, @@ -466,6 +499,109 @@ async def get_snapshot_crypto_book( return f"Error: {e}" +@poly_mcp.tool(annotations=ToolAnnotations(readOnlyHint=True)) +async def list_snapshot_options_chain( + underlying_asset: str, + strike_price: Optional[float] = None, + strike_price_lt: Optional[float] = None, + strike_price_lte: Optional[float] = None, + strike_price_gt: Optional[float] = None, + strike_price_gte: Optional[float] = None, + expiration_date: Optional[str] = None, + expiration_date_lt: Optional[str] = None, + expiration_date_lte: Optional[str] = None, + expiration_date_gt: Optional[str] = None, + expiration_date_gte: Optional[str] = None, + contract_type: Optional[str] = None, + limit: Optional[int] = 250, + sort: Optional[str] = None, + order: Optional[str] = None, + fields: Optional[str] = None, + output_format: str = "csv", + aggregate: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, +) -> str: + """ + Get option chain snapshot for an underlying asset with greeks and market data. + + Returns all option contracts for the specified underlying asset, including + current prices, volume, open interest, and greeks (delta, gamma, theta, vega). 
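+
+    Example (illustrative call; the ticker and expiration date are placeholders):
+
+        list_snapshot_options_chain(
+            "AAPL",
+            expiration_date="2025-06-20",
+            contract_type="call",
+            fields="preset:greeks",
+        )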
+ + Args: + underlying_asset: The underlying ticker symbol (e.g., "AAPL") + strike_price: Filter by exact strike price + strike_price_lt/lte/gt/gte: Filter by strike price range + expiration_date: Filter by exact expiration date (YYYY-MM-DD) + expiration_date_lt/lte/gt/gte: Filter by expiration date range + contract_type: Filter by contract type ("call" or "put") + limit: Maximum number of results (default 250) + sort: Sort field + order: Sort order ("asc" or "desc") + fields: Field selection - comma-separated field names or preset + (e.g., "preset:greeks", "preset:options_summary", or "ticker,strike_price,greeks_delta") + output_format: Output format - "csv" (default), "json", or "compact" + aggregate: Aggregation method - "first" or "last" (for single record) + params: Additional API parameters + + Available presets for fields: + - "preset:greeks" - Greek values (delta, gamma, theta, vega) with IV + - "preset:options_summary" - Price, volume, open interest summary + - "preset:options_quote" - Bid/ask quote data + """ + try: + # Build params dict with all filter parameters + api_params = params.copy() if params else {} + + # Add strike price filters + if strike_price is not None: + api_params["strike_price"] = strike_price + if strike_price_lt is not None: + api_params["strike_price.lt"] = strike_price_lt + if strike_price_lte is not None: + api_params["strike_price.lte"] = strike_price_lte + if strike_price_gt is not None: + api_params["strike_price.gt"] = strike_price_gt + if strike_price_gte is not None: + api_params["strike_price.gte"] = strike_price_gte + + # Add expiration date filters + if expiration_date is not None: + api_params["expiration_date"] = expiration_date + if expiration_date_lt is not None: + api_params["expiration_date.lt"] = expiration_date_lt + if expiration_date_lte is not None: + api_params["expiration_date.lte"] = expiration_date_lte + if expiration_date_gt is not None: + api_params["expiration_date.gt"] = expiration_date_gt + if expiration_date_gte is not None: + api_params["expiration_date.gte"] = expiration_date_gte + + # Add other filters + if contract_type is not None: + api_params["contract_type"] = contract_type + if limit is not None: + api_params["limit"] = limit + if sort is not None: + api_params["sort"] = sort + if order is not None: + api_params["order"] = order + + results = massive_client.list_snapshot_options_chain( + underlying_asset=underlying_asset, + params=api_params if api_params else None, + raw=True, + ) + + return _apply_output_filtering( + results.data, + fields=fields, + output_format=output_format, + aggregate=aggregate, + ) + except Exception as e: + return f"Error: {e}" + + @poly_mcp.tool(annotations=ToolAnnotations(readOnlyHint=True)) async def get_market_holidays( params: Optional[Dict[str, Any]] = None, @@ -1430,25 +1566,25 @@ async def list_benzinga_news( sort: Optional[str] = None, ) -> str: """ - Retrieve real-time structured, timestamped news articles from Benzinga v2 API, including headlines, - full-text content, tickers, categories, and more. Each article entry contains metadata such as author, - publication time, and topic channels, as well as optional elements like teaser summaries, article body text, - and images. Articles can be filtered by ticker and time, and are returned in a consistent format for easy - parsing and integration. 
This endpoint is ideal for building alerting systems, autonomous risk analysis,
+    Retrieve real-time structured, timestamped news articles from Benzinga v2 API, including headlines,
+    full-text content, tickers, categories, and more. Each article entry contains metadata such as author,
+    publication time, and topic channels, as well as optional elements like teaser summaries, article body text,
+    and images. Articles can be filtered by ticker and time, and are returned in a consistent format for easy
+    parsing and integration. This endpoint is ideal for building alerting systems, autonomous risk analysis,
     and sentiment-driven trading strategies.
-
+
     Args:
-        published: The timestamp (formatted as an ISO 8601 timestamp) when the news article was originally
+        published: The timestamp (formatted as an ISO 8601 timestamp) when the news article was originally
         published. Value must be an integer timestamp in seconds or formatted 'yyyy-mm-dd'.
         channels: Filter for arrays that contain the value (e.g., 'News', 'Price Target').
         tags: Filter for arrays that contain the value.
         author: The name of the journalist or entity that authored the news article.
         stocks: Filter for arrays that contain the value.
         tickers: Filter for arrays that contain the value.
-        limit: Limit the maximum number of results returned. Defaults to 100 if not specified.
+        limit: Limit the maximum number of results returned. Defaults to 100 if not specified.
         The maximum allowed limit is 50000.
-        sort: A comma separated list of sort columns. For each column, append '.asc' or '.desc' to specify
-        the sort direction. The sort column defaults to 'published' if not specified.
+        sort: A comma separated list of sort columns. For each column, append '.asc' or '.desc' to specify
+        the sort direction. The sort column defaults to 'published' if not specified.
         The sort order defaults to 'desc' if not specified.
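+
+    Example (illustrative; the ticker is a placeholder):
+
+        list_benzinga_news(tickers="AAPL", limit=10, sort="published.desc")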
""" try: diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 0000000..40cf77f --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,237 @@ +"""Tests for the filters module.""" + +import json +import pytest + +from mcp_massive.filters import ( + FIELD_PRESETS, + FilterOptions, + parse_filter_params, + apply_filters, + _apply_aggregation, +) + + +class TestFieldPresets: + """Tests for the FIELD_PRESETS configuration.""" + + def test_greeks_preset_exists(self): + """Test that greeks preset exists with expected fields.""" + assert "greeks" in FIELD_PRESETS + preset = FIELD_PRESETS["greeks"] + assert "details_strike_price" in preset + assert "details_contract_type" in preset + assert "greeks_delta" in preset + assert "greeks_gamma" in preset + assert "greeks_theta" in preset + assert "greeks_vega" in preset + assert "implied_volatility" in preset + + def test_options_summary_preset_exists(self): + """Test that options_summary preset exists with expected fields.""" + assert "options_summary" in FIELD_PRESETS + preset = FIELD_PRESETS["options_summary"] + assert "details_strike_price" in preset + assert "details_expiration_date" in preset + assert "details_contract_type" in preset + assert "day_volume" in preset + assert "open_interest" in preset + + def test_options_quote_preset_exists(self): + """Test that options_quote preset exists with expected fields.""" + assert "options_quote" in FIELD_PRESETS + preset = FIELD_PRESETS["options_quote"] + assert "details_strike_price" in preset + assert "last_quote_bid" in preset + assert "last_quote_ask" in preset + + def test_price_presets_exist(self): + """Test that basic price presets exist.""" + assert "price" in FIELD_PRESETS + assert "ohlc" in FIELD_PRESETS + assert "ohlcv" in FIELD_PRESETS + + +class TestParseFilterParams: + """Tests for the parse_filter_params function.""" + + def test_parse_comma_separated_fields(self): + """Test parsing comma-separated field names.""" + options = parse_filter_params(fields="ticker,close,volume") + assert options.fields == ["ticker", "close", "volume"] + + def test_parse_preset_fields(self): + """Test parsing preset field names.""" + options = parse_filter_params(fields="preset:greeks") + assert options.fields == FIELD_PRESETS["greeks"] + + def test_parse_unknown_preset_raises_error(self): + """Test that unknown presets raise ValueError.""" + with pytest.raises(ValueError, match="Unknown preset"): + parse_filter_params(fields="preset:unknown_preset") + + def test_parse_output_format_csv(self): + """Test parsing CSV output format.""" + options = parse_filter_params(output_format="csv") + assert options.format == "csv" + + def test_parse_output_format_json(self): + """Test parsing JSON output format.""" + options = parse_filter_params(output_format="json") + assert options.format == "json" + + def test_parse_output_format_compact(self): + """Test parsing compact output format.""" + options = parse_filter_params(output_format="compact") + assert options.format == "compact" + + def test_parse_invalid_output_format_raises_error(self): + """Test that invalid output formats raise ValueError.""" + with pytest.raises(ValueError, match="Invalid output_format"): + parse_filter_params(output_format="xml") + + def test_parse_aggregate_first(self): + """Test parsing 'first' aggregation.""" + options = parse_filter_params(aggregate="first") + assert options.aggregate == "first" + + def test_parse_aggregate_last(self): + """Test parsing 'last' aggregation.""" + options = 
parse_filter_params(aggregate="last") + assert options.aggregate == "last" + + def test_parse_invalid_aggregate_raises_error(self): + """Test that invalid aggregation raises ValueError.""" + with pytest.raises(ValueError, match="Invalid aggregate"): + parse_filter_params(aggregate="average") + + def test_parse_none_fields(self): + """Test that None fields is preserved.""" + options = parse_filter_params(fields=None) + assert options.fields is None + + def test_parse_whitespace_in_fields(self): + """Test that whitespace is stripped from field names.""" + options = parse_filter_params(fields=" ticker , close , volume ") + assert options.fields == ["ticker", "close", "volume"] + + +class TestApplyAggregation: + """Tests for the _apply_aggregation helper function.""" + + def test_aggregate_first_from_results(self): + """Test extracting first record from results list.""" + data = {"results": [{"a": 1}, {"a": 2}, {"a": 3}]} + result = _apply_aggregation(data, "first") + assert result == {"results": [{"a": 1}]} + + def test_aggregate_last_from_results(self): + """Test extracting last record from results list.""" + data = {"results": [{"a": 1}, {"a": 2}, {"a": 3}]} + result = _apply_aggregation(data, "last") + assert result == {"results": [{"a": 3}]} + + def test_aggregate_first_from_list(self): + """Test extracting first record from plain list.""" + data = [{"a": 1}, {"a": 2}] + result = _apply_aggregation(data, "first") + assert result == [{"a": 1}] + + def test_aggregate_last_from_list(self): + """Test extracting last record from plain list.""" + data = [{"a": 1}, {"a": 2}] + result = _apply_aggregation(data, "last") + assert result == [{"a": 2}] + + def test_aggregate_single_record_unchanged(self): + """Test that single non-list record is unchanged.""" + data = {"ticker": "AAPL", "price": 150} + result = _apply_aggregation(data, "first") + assert result == data + + def test_aggregate_empty_results(self): + """Test aggregation of empty results.""" + data = {"results": []} + result = _apply_aggregation(data, "first") + assert result == {"results": []} + + +class TestApplyFilters: + """Tests for the apply_filters function.""" + + def test_apply_field_filter_csv(self): + """Test applying field filter with CSV output.""" + data = {"results": [{"ticker": "AAPL", "price": 150, "volume": 1000}]} + options = FilterOptions(fields=["ticker", "price"], format="csv") + result = apply_filters(data, options) + assert "ticker" in result + assert "price" in result + assert "volume" not in result + + def test_apply_field_filter_json(self): + """Test applying field filter with JSON output.""" + data = {"results": [{"ticker": "AAPL", "price": 150, "volume": 1000}]} + options = FilterOptions(fields=["ticker", "price"], format="json") + result = apply_filters(data, options) + parsed = json.loads(result) + assert len(parsed) == 1 + assert parsed[0]["ticker"] == "AAPL" + assert parsed[0]["price"] == 150 + assert "volume" not in parsed[0] + + def test_apply_compact_format(self): + """Test applying compact format.""" + data = {"results": [{"ticker": "AAPL", "price": 150}]} + options = FilterOptions(format="compact") + result = apply_filters(data, options) + # Compact format should be a compact JSON string + assert '"ticker"' in result or '"price"' in result + + def test_apply_aggregation_and_filter(self): + """Test combining aggregation with field filtering.""" + data = { + "results": [ + {"ticker": "AAPL", "price": 150}, + {"ticker": "GOOGL", "price": 2800}, + ] + } + options = FilterOptions(fields=["ticker"], 
aggregate="last", format="json") + result = apply_filters(data, options) + parsed = json.loads(result) + assert len(parsed) == 1 + assert parsed[0]["ticker"] == "GOOGL" + + def test_apply_filters_from_json_string(self): + """Test applying filters to JSON string input.""" + data = '{"results": [{"ticker": "AAPL", "price": 150}]}' + options = FilterOptions(fields=["ticker"], format="json") + result = apply_filters(data, options) + parsed = json.loads(result) + assert len(parsed) == 1 + assert parsed[0]["ticker"] == "AAPL" + + +class TestFilterOptionsDataclass: + """Tests for the FilterOptions dataclass.""" + + def test_default_values(self): + """Test default values of FilterOptions.""" + options = FilterOptions() + assert options.fields is None + assert options.exclude_fields is None + assert options.format == "csv" + assert options.aggregate is None + assert options.conditions is None + + def test_custom_values(self): + """Test custom values of FilterOptions.""" + options = FilterOptions( + fields=["a", "b"], + exclude_fields=["c"], + format="json", + aggregate="first", + ) + assert options.fields == ["a", "b"] + assert options.exclude_fields == ["c"] + assert options.format == "json" + assert options.aggregate == "first" diff --git a/tests/test_formatters.py b/tests/test_formatters.py index 5013d24..bdcd8f3 100644 --- a/tests/test_formatters.py +++ b/tests/test_formatters.py @@ -1,8 +1,6 @@ -import json import csv import io -import pytest from mcp_massive.formatters import json_to_csv, _flatten_dict
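The pieces above compose as follows — a minimal usage sketch (illustrative, not part of the diff; it assumes only the `filters` module added by this change):

```python
import json

from mcp_massive.filters import parse_filter_params, apply_filters

# A payload shaped like the Massive API "results" responses used above.
payload = json.dumps(
    {
        "results": [
            {"ticker": "AAPL", "close": 185.92, "volume": 52165200},
            {"ticker": "MSFT", "close": 420.1, "volume": 18034500},
        ]
    }
)

# "preset:summary" expands to ticker/close/volume/change_percent; fields
# missing from a record (change_percent here) are simply dropped, not errors.
options = parse_filter_params(fields="preset:summary", output_format="csv")
print(apply_filters(payload, options))
# ticker,close,volume
# AAPL,185.92,52165200
# MSFT,420.1,18034500

# Aggregation reduces to a single record before formatting.
last_only = parse_filter_params(
    fields="ticker,close", output_format="compact", aggregate="last"
)
print(apply_filters(payload, last_only))
# {"ticker":"MSFT","close":420.1}
```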