diff --git a/.env.example b/.env.example index 637491ebb..698dc586e 100644 --- a/.env.example +++ b/.env.example @@ -29,6 +29,8 @@ SEC_PARSER_MODEL_ID=openai/gpt-4o-mini SEC_ANALYSIS_MODEL_ID=deepseek/deepseek-chat-v3-0324 AI_HEDGE_FUND_PARSER_MODEL_ID=google/gemini-2.5-flash RESEARCH_AGENT_MODEL_ID=google/gemini-2.5-flash +PRODUCT_MODEL_ID=anthropic/claude-haiku-4.5 + # Embedding EMBEDDER_API_KEY= @@ -41,6 +43,10 @@ EMBEDDER_DIMENSION=1568 # Email address for SEC API requests (required by SEC) SEC_EMAIL= +# Set your https://xueqiu.com/ token if Yfinance data fetching is unstable. +# XUEQIU_TOKEN= + + # TradingAgents Configurations # refer to ./python/third_party/TradingAgents/.env.example for details # OpenAI API Key - Required for LLM models and online data search diff --git a/.gitignore b/.gitignore index 0c0b5625b..2b798f7b5 100644 --- a/.gitignore +++ b/.gitignore @@ -226,4 +226,5 @@ lancedb # Local files logs -.knowledge \ No newline at end of file +.knowledge +.txt \ No newline at end of file diff --git a/python/valuecell/adapters/assets/__init__.py b/python/valuecell/adapters/assets/__init__.py index 5403ffdf3..287dc1814 100644 --- a/python/valuecell/adapters/assets/__init__.py +++ b/python/valuecell/adapters/assets/__init__.py @@ -33,13 +33,8 @@ # Base adapter classes from .base import ( - AdapterError, - AuthenticationError, + AdapterCapability, BaseDataAdapter, - DataNotAvailableError, - InvalidTickerError, - RateLimitError, - TickerConverter, ) # Internationalization support @@ -66,6 +61,7 @@ AssetSearchResult, AssetType, DataSource, + Exchange, LocalizedName, MarketInfo, MarketStatus, @@ -90,18 +86,14 @@ "AssetType", "MarketStatus", "DataSource", + "Exchange", "MarketInfo", "LocalizedName", "Watchlist", "WatchlistItem", # Base classes "BaseDataAdapter", - "TickerConverter", - "AdapterError", - "RateLimitError", - "DataNotAvailableError", - "AuthenticationError", - "InvalidTickerError", + "AdapterCapability", # Adapters "YFinanceAdapter", "AKShareAdapter", diff --git a/python/valuecell/adapters/assets/akshare_adapter.py b/python/valuecell/adapters/assets/akshare_adapter.py index 100a18bf8..11e37003a 100644 --- a/python/valuecell/adapters/assets/akshare_adapter.py +++ b/python/valuecell/adapters/assets/akshare_adapter.py @@ -4,14 +4,11 @@ Global financial market data including stocks, funds, bonds, and economic indicators. """ -import decimal import logging -import re -import threading -import time +import os from datetime import datetime, timedelta from decimal import Decimal -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import pandas as pd @@ -20,7 +17,7 @@ except ImportError: ak = None -from .base import BaseDataAdapter +from .base import AdapterCapability, BaseDataAdapter from .types import ( Asset, AssetPrice, @@ -28,6 +25,7 @@ AssetSearchResult, AssetType, DataSource, + Exchange, Interval, LocalizedName, MarketInfo, @@ -57,2264 +55,1178 @@ def _initialize(self) -> None: """Initialize AKShare adapter configuration.""" self.timeout = self.config.get("timeout", 10) # Reduced timeout duration - # Different cache TTLs for different data types - self.price_cache_ttl = self.config.get( - "price_cache_ttl", 30 - ) # 30 seconds for real-time prices - self.info_cache_ttl = self.config.get( - "info_cache_ttl", 3600 - ) # 1 hour for stock info - self.hist_cache_ttl = self.config.get( - "hist_cache_ttl", 1800 - ) # 30 minutes for historical data - - self.max_retries = self.config.get("max_retries", 2) # Maximum retry attempts - - # Data caching with different TTLs - self._cache = {} - self._cache_lock = threading.Lock() - self._last_cache_clear = time.time() - - # Cache statistics for monitoring - self._cache_stats = {"hits": 0, "misses": 0, "evictions": 0} - - # Asset type mapping for AKShare - self.asset_type_mapping = { - "stock": AssetType.STOCK, - "fund": AssetType.ETF, - # "bond": AssetType.BOND, - "index": AssetType.INDEX, - } - # Field mapping - Handle AKShare API field changes self.field_mappings = { "a_shares": { "code": ["代码", "symbol", "ts_code"], "name": ["名称", "name", "short_name"], "price": ["最新价", "close", "price"], - "open": ["今开", "open"], + "open": ["今开", "开盘", "open"], "high": ["最高", "high"], "low": ["最低", "low"], + "close": ["收盘", "close"], "volume": ["成交量", "volume", "vol"], "market_cap": ["总市值", "total_mv"], + "change": ["涨跌额", "change"], + "change_percent": ["涨跌幅", "change_percent", "pct_chg"], + "date": ["日期", "date", "trade_date"], + "time": ["时间", "time", "datetime"], }, "hk_stocks": { "code": ["symbol", "code", "代码"], "name": ["name", "名称", "short_name"], + "open": ["开盘", "open"], + "high": ["最高", "high"], + "low": ["最低", "low"], + "close": ["收盘", "close"], + "volume": ["成交量", "volume", "vol"], + "change": ["涨跌额", "change"], + "change_percent": ["涨跌幅", "change_percent", "pct_chg"], + "date": ["日期", "date", "trade_date"], + "time": ["时间", "time", "datetime"], }, "us_stocks": { "code": ["代码", "symbol", "ticker"], "name": ["名称", "name", "short_name"], + "open": ["开盘", "open"], + "high": ["最高", "high"], + "low": ["最低", "low"], + "close": ["收盘", "close"], + "volume": ["成交量", "volume", "vol"], + "change": ["涨跌额", "change"], + "change_percent": ["涨跌幅", "change_percent", "pct_chg"], + "date": ["日期", "date", "trade_date"], + "time": ["时间", "time", "datetime"], }, } # Exchange mapping for AKShare self.exchange_mapping = { - "SH": "SSE", # Shanghai Stock Exchange - "SZ": "SZSE", # Shenzhen Stock Exchange - "BJ": "BSE", # Beijing Stock Exchange - "HK": "HKEX", # Hong Kong Stock Exchange - "US": "NASDAQ", # US markets (generic) - "NYSE": "NYSE", # New York Stock Exchange - "NASDAQ": "NASDAQ", # NASDAQ + "SH": Exchange.SSE.value, # Shanghai Stock Exchange + "SZ": Exchange.SZSE.value, # Shenzhen Stock Exchange + "BJ": Exchange.BSE.value, # Beijing Stock Exchange + "HK": Exchange.HKEX.value, # Hong Kong Stock Exchange + "NYSE": Exchange.NYSE.value, # New York Stock Exchange + "NASDAQ": Exchange.NASDAQ.value, # NASDAQ Exchange + "AMEX": Exchange.AMEX.value, # AMEX Exchange } + # US exchange codes for AKShare API + # AKShare requires exchange code prefix for US stocks and indices + # Format: exchange_code.SYMBOL (e.g., 105.AAPL for NASDAQ:AAPL, 100.IXIC for INDEX) + # Code 100: US Index data (for INDEX asset type) + # Code 105: NASDAQ stocks + # Code 106: NYSE stocks + # Code 107: AMEX stocks + self.us_exchange_codes = { + Exchange.NASDAQ: "105", + Exchange.NYSE: "106", + Exchange.AMEX: "107", + } + + # Special exchange code for US indices + self.us_index_exchange_code = "100" + + # Reverse mapping for converting AKShare format back to internal format + self.us_exchange_codes_reverse = { + v: k for k, v in self.us_exchange_codes.items() + } + # Add index code to reverse mapping + self.us_exchange_codes_reverse["100"] = None # Special handling for indices + logger.info("AKShare adapter initialized with caching and field mapping") - def _get_cached_data(self, cache_key: str, fetch_func, *args, **kwargs): - """Get cached data or fetch new data with adaptive TTL.""" - current_time = time.time() - - # Determine TTL based on cache key type - ttl = self._get_cache_ttl(cache_key) - - with self._cache_lock: - # Clean up expired cache periodically - if current_time - self._last_cache_clear > min( - self.price_cache_ttl, self.info_cache_ttl - ): - expired_keys = [ - key - for key, (_, timestamp, key_ttl) in self._cache.items() - if current_time - timestamp - > key_ttl * 2 # Keep expired data for fallback - ] - for key in expired_keys: - del self._cache[key] - self._cache_stats["evictions"] += 1 - self._last_cache_clear = current_time - - # Check for valid cache - if cache_key in self._cache: - cached_data, timestamp, key_ttl = self._cache[cache_key] - if current_time - timestamp < key_ttl: - logger.debug(f"Cache hit for {cache_key}") - self._cache_stats["hits"] += 1 - return cached_data - else: - logger.debug(f"Cache expired for {cache_key}") - - # Cache miss - self._cache_stats["misses"] += 1 - - # Fetch new data outside the lock to reduce lock time - try: - logger.debug(f"Fetching new data for {cache_key}") - data = fetch_func(*args, **kwargs) - with self._cache_lock: - self._cache[cache_key] = (data, current_time, ttl) - return data - except Exception as e: - logger.error(f"Failed to fetch data for {cache_key}: {e}") - # Try to return expired data as fallback - with self._cache_lock: - if cache_key in self._cache: - cached_data, _, _ = self._cache[cache_key] - logger.warning(f"Using expired cached data for {cache_key}") - return cached_data - raise - - def _get_cache_ttl(self, cache_key: str) -> int: - """Get appropriate TTL based on cache key type.""" - if "price" in cache_key or "spot" in cache_key: - return self.price_cache_ttl - elif "hist" in cache_key: - return self.hist_cache_ttl - else: - return self.info_cache_ttl - - def get_cache_stats(self) -> dict: - """Get cache statistics for monitoring.""" - with self._cache_lock: - total_requests = self._cache_stats["hits"] + self._cache_stats["misses"] - hit_rate = ( - self._cache_stats["hits"] / total_requests if total_requests > 0 else 0 - ) - return { - "cache_size": len(self._cache), - "hit_rate": hit_rate, - **self._cache_stats, - } + def _get_market_type(self, exchange: Exchange) -> str: + """Get market type identifier for field mapping. - def clear_cache(self) -> None: - """Clear all cached data.""" - with self._cache_lock: - self._cache.clear() - self._cache_stats = {"hits": 0, "misses": 0, "evictions": 0} - logger.info("Cache cleared") + Args: + exchange: Exchange enum - def _safe_get_field(self, data_row, field_type: str, market_type: str = "a_shares"): - """Safely get data field value, handling field name changes.""" - possible_fields = self.field_mappings.get(market_type, {}).get(field_type, []) + Returns: + Market type string ('a_shares', 'hk_stocks', or 'us_stocks') + """ + if exchange in [Exchange.SSE, Exchange.SZSE, Exchange.BSE]: + return "a_shares" + elif exchange == Exchange.HKEX: + return "hk_stocks" + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + return "us_stocks" + else: + return "a_shares" # Default fallback - for field_name in possible_fields: - if field_name in data_row and data_row[field_name] is not None: - return data_row[field_name] + def _get_currency(self, exchange: Exchange) -> str: + """Get currency code based on exchange. - logger.debug(f"Field {field_type} not found in {market_type} data") - return None + Args: + exchange: Exchange enum - def _safe_akshare_call(self, func, *args, **kwargs): - """Safely call AKShare API with retry mechanism.""" - for attempt in range(self.max_retries + 1): - try: - # Set timeout - result = func(*args, **kwargs) - if result is not None and not ( - hasattr(result, "empty") and result.empty - ): - return result - else: - logger.warning( - f"AKShare API returned empty data on attempt {attempt + 1}" - ) - if attempt < self.max_retries: - time.sleep(1) # Wait 1 second before retry - continue - return None - except Exception as e: - logger.warning(f"AKShare API call failed on attempt {attempt + 1}: {e}") - if attempt < self.max_retries: - time.sleep(2**attempt) # Exponential backoff - continue - raise e - return None + Returns: + Currency code (CNY, HKD, or USD) + """ + if exchange in [Exchange.SSE, Exchange.SZSE, Exchange.BSE]: + return "CNY" + elif exchange == Exchange.HKEX: + return "HKD" + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + return "USD" + else: + return "USD" # Default fallback - def _get_exchange_from_a_share_code(self, stock_code: str) -> Optional[tuple]: - """Get exchange and ticker from A-share stock code.""" - if stock_code.startswith("6"): - return ("SSE", f"SSE:{stock_code}") - elif stock_code.startswith(("0", "3")): - return ("SZSE", f"SZSE:{stock_code}") - elif stock_code.startswith("8"): - return ("BSE", f"BSE:{stock_code}") - return None + def _get_field_name( + self, df: pd.DataFrame, field: str, exchange: Exchange + ) -> Optional[str]: + """Get the actual field name from DataFrame based on exchange type. - def _create_stock_search_result( - self, - ticker: str, - asset_type: AssetType, - stock_code: str, - stock_name: str, - exchange: str, - country: str, - currency: str, - search_term: str, - ) -> AssetSearchResult: - """Create a standardized stock search result.""" - names = { - "zh-Hans": stock_name, - "zh-Hant": stock_name, - "en-US": stock_name, - } + Args: + df: DataFrame to search for field + field: Standard field name (e.g., 'open', 'close', 'high') + exchange: Exchange enum to determine which mapping to use - return AssetSearchResult( - ticker=ticker, - asset_type=asset_type, - names=names, - exchange=exchange, - country=country, - currency=currency, - market_status=MarketStatus.UNKNOWN, - relevance_score=self._calculate_relevance( - search_term, stock_code, stock_name - ), - ) + Returns: + Actual field name found in DataFrame, or None if not found + """ + # Get market type for field mapping + market = self._get_market_type(exchange) - def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: - """Search for assets using AKShare direct queries.""" - try: - results = [] - search_term = query.query.strip() + # Get possible field names for this field + possible_names = self.field_mappings.get(market, {}).get(field, []) - # Direct ticker lookup strategy - try to match exact codes first - if self._looks_like_ticker(search_term): - results.extend(self._search_by_direct_ticker_lookup(search_term, query)) - if results: - return results[: query.limit] + # Check which field name exists in the DataFrame + for name in possible_names: + if name in df.columns: + return name - # Determine likely markets based on search term - likely_markets = self._determine_likely_markets(search_term, query) + # If not found, try the standard field name directly + if field in df.columns: + return field - # Search markets by priority using direct queries - if "a_shares" in likely_markets: - results.extend(self._search_a_shares_direct(search_term, query)) + return None - if "hk_stocks" in likely_markets and len(results) < query.limit: - results.extend(self._search_hk_stocks_direct(search_term, query)) + def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: + """AKShare does not support search assets.""" + return [] - if "us_stocks" in likely_markets and len(results) < query.limit: - results.extend(self._search_us_stocks_direct(search_term, query)) + def __get_xq_symbol(self, ticker: str) -> str: + """Get XQ symbol for a specific asset. + Args: + ticker: Asset ticker in internal format (e.g., "SSE:601127", "HKEX:02097", "NASDAQ:NVDA") + Returns: + XQ symbol or None if not found + """ + try: + # Parse ticker to get exchange and symbol + if ":" not in ticker: + logger.warning( + f"Invalid ticker format: {ticker}, expected 'EXCHANGE:SYMBOL'" + ) + return None - if "etfs" in likely_markets and len(results) < query.limit: - results.extend(self._search_etfs_direct(search_term, query)) + exchange_str, symbol = ticker.split(":", 1) - # Apply filters - if query.asset_types: - results = [r for r in results if r.asset_type in query.asset_types] + # Convert exchange string to Exchange enum + try: + exchange = Exchange(exchange_str) + except ValueError: + logger.warning(f"Unknown exchange '{exchange_str}' for ticker {ticker}") + return None - if query.exchanges: - results = [r for r in results if r.exchange in query.exchanges] + # For A-shares (SSE, SZSE, BSE): format is "SH600519", "SZ000001", "BJ430047" + if exchange == Exchange.SSE: + return f"SH{symbol}" + elif exchange == Exchange.SZSE: + return f"SZ{symbol}" + elif exchange == Exchange.BSE: + return f"BJ{symbol}" - if query.countries: - results = [r for r in results if r.country in query.countries] + # For Hong Kong stocks: format is just the symbol without leading zeros (e.g., "00700" -> "700", or keep as is) + elif exchange == Exchange.HKEX: + # Remove leading zeros for XQ format + return symbol.lstrip("0") or "0" # Keep at least one zero if all zeros - # Sort by relevance score - results.sort(key=lambda x: x.relevance_score, reverse=True) + # For US stocks: format is just the symbol without exchange prefix + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + return symbol - return results[: query.limit] + else: + logger.warning( + f"Unsupported exchange for XQ symbol conversion: {exchange}" + ) + return None except Exception as e: - logger.error(f"Error searching assets: {e}") - return [] + logger.error(f"Error converting ticker {ticker} to XQ symbol: {e}") + return None - def _determine_likely_markets( - self, search_term: str, query: AssetSearchQuery - ) -> List[str]: - """Intelligently determine likely markets for search term, reducing unnecessary network requests.""" - likely_markets = set() # Use set to avoid duplicates - search_term_upper = search_term.upper().strip() - - # Mappings for efficient lookup - exchange_market_map = { - "SSE": "a_shares", - "SZSE": "a_shares", - "BSE": "a_shares", - "HKEX": "hk_stocks", - "NASDAQ": "us_stocks", - "NYSE": "us_stocks", - "CRYPTO": "crypto", - } + def get_asset_info(self, ticker: str) -> Optional[Asset]: + """Get detailed information about a specific asset. + Args: + ticker: Asset ticker in internal format (e.g., "SSE:601127", "HKEX:02097", "NASDAQ:NVDA") + Returns: + Asset information or None if not found + """ + try: + # Get XQ symbol for the ticker + xq_symbol = self.__get_xq_symbol(ticker) + if not xq_symbol: + logger.warning(f"Cannot get XQ symbol for ticker: {ticker}") + return None - country_market_map = { - "CN": ["a_shares", "etfs"], - "HK": ["hk_stocks"], - "US": ["us_stocks"], - "GLOBAL": ["crypto"], - } + # Parse ticker to get exchange and symbol + exchange_str, symbol = ticker.split(":", 1) + exchange = Exchange(exchange_str) - # Determine markets based on query filters - if query.asset_types: - type_market_map = { - AssetType.ETF: "etfs", - AssetType.STOCK: ["a_shares", "hk_stocks", "us_stocks"], - } - for asset_type in query.asset_types: - markets = type_market_map.get(asset_type, []) - if isinstance(markets, str): - likely_markets.add(markets) - else: - likely_markets.update(markets) - - if query.exchanges: - for exchange in query.exchanges: - market = exchange_market_map.get(exchange) - if market: - likely_markets.add(market) - - if query.countries: - for country in query.countries: - markets = country_market_map.get(country, []) - likely_markets.update(markets) - - # If no explicit filters, determine based on search term patterns - if not likely_markets: - likely_markets.update(self._analyze_search_term_pattern(search_term_upper)) - - # If still empty, search all markets - if not likely_markets: - likely_markets = {"a_shares", "us_stocks", "hk_stocks", "etfs"} - - # Convert to list with priority order - priority_order = ["a_shares", "us_stocks", "hk_stocks", "etfs"] - result = [market for market in priority_order if market in likely_markets] - - logger.debug(f"Determined likely markets for '{search_term}': {result}") - return result - - def _analyze_search_term_pattern(self, search_term_upper: str) -> set: - """Analyze search term pattern to determine likely markets.""" - markets = set() - - # A-share code pattern (6 digits starting with specific numbers) - if ( - search_term_upper.isdigit() - and len(search_term_upper) == 6 - and search_term_upper.startswith(("6", "0", "3", "8")) - ): - markets.add("a_shares") - - # HK stock code pattern - elif ( - search_term_upper.isdigit() and 1 <= len(search_term_upper) <= 5 - ) or search_term_upper.startswith("00"): - markets.add("hk_stocks") - - # US stock/crypto pattern (letters) - elif search_term_upper.isalpha() and len(search_term_upper) <= 5: - markets.add("us_stocks") - - # Chinese names - prioritize A-shares - elif any("\u4e00" <= char <= "\u9fff" for char in search_term_upper): - markets.update(["a_shares", "hk_stocks"]) - - # Default case - else: - markets.update(["a_shares", "us_stocks", "hk_stocks", "etfs"]) - - return markets - - def _search_a_shares_direct( - self, search_term: str, query: AssetSearchQuery - ) -> List[AssetSearchResult]: - """Search A-share stocks using direct queries.""" - results = [] - - # If search term looks like A-share code, try direct lookup - if self._is_a_share_code(search_term): - result = self._get_a_share_by_code(search_term) - if result: - results.append(result) - return results - - # For name searches, try fuzzy matching with common patterns - # This is a simplified approach - in production, you might want to use - # a search service or maintain a local index - if len(search_term) >= 2: # Only search if term is meaningful - # Try some common A-share codes that might match the search term - candidate_codes = self._generate_a_share_candidates(search_term) - - for code in candidate_codes[: query.limit]: + # Call different AKShare APIs based on the market + df = None + + # A-shares market (SSE, SZSE, BSE) + if exchange in [Exchange.SSE, Exchange.SZSE, Exchange.BSE]: try: - result = self._get_a_share_by_code(code) - if result and self._matches_search_term(result, search_term): - results.append(result) + df = ak.stock_individual_basic_info_xq( + symbol=xq_symbol, token=os.getenv("XUEQIU_TOKEN", None) + ) except Exception as e: - logger.debug(f"Failed to get A-share info for {code}: {e}") - continue + logger.error( + f"Error fetching A-share info for {xq_symbol}: {e}", + exc_info=True, + ) + return None - return results + # Hong Kong stock market + elif exchange == Exchange.HKEX: + try: + df = ak.stock_individual_basic_info_hk_xq( + symbol=xq_symbol, token=os.getenv("XUEQIU_TOKEN", None) + ) + except Exception as e: + logger.error( + f"Error fetching HK stock info for {xq_symbol}: {e}", + exc_info=True, + ) + return None - def _is_a_share_code(self, search_term: str) -> bool: - """Check if search term looks like an A-share code.""" - return ( - search_term.isdigit() - and len(search_term) == 6 - and search_term.startswith(("6", "0", "3", "8")) - ) + # US stock market (NASDAQ, NYSE, AMEX) + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + try: + df = ak.stock_individual_basic_info_us_xq( + symbol=xq_symbol, token=os.getenv("XUEQIU_TOKEN", None) + ) + except Exception as e: + logger.error( + f"Error fetching US stock info for {xq_symbol}: {e}", + exc_info=True, + ) + return None - def _get_a_share_by_code(self, stock_code: str) -> Optional[AssetSearchResult]: - """Get A-share info by stock code using direct query.""" - try: - # Use individual stock info query - cache_key = f"a_share_info_{stock_code}" - df_info = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_individual_info_em, - symbol=stock_code, - ) + else: + logger.warning(f"Unsupported exchange for asset info: {exchange}") + return None - if df_info is None or df_info.empty: + # Check if data was retrieved + if df is None or df.empty: + logger.warning(f"No data found for ticker: {ticker}") return None - # Extract stock name from info + # Convert DataFrame to dictionary for easier access info_dict = {} - for _, row in df_info.iterrows(): - info_dict[row["item"]] = row["value"] + for _, row in df.iterrows(): + item = row.get("item", "") + value = row.get("value", "") + if item and value: + info_dict[item] = value - stock_name = info_dict.get("股票名称", stock_code) - - # Determine exchange from code - exchange_info = self._get_exchange_from_a_share_code(stock_code) - if not exchange_info: - return None - - exchange, internal_ticker = exchange_info - - return self._create_stock_search_result( - internal_ticker, - AssetType.STOCK, - stock_code, - stock_name, - exchange, - "CN", - "CNY", - stock_code, - ) + # Create Asset object based on market type + return self._create_asset_from_info(ticker, exchange, info_dict) except Exception as e: - logger.debug(f"Error getting A-share info for {stock_code}: {e}") + logger.error(f"Error getting asset info for {ticker}: {e}", exc_info=True) return None - def _generate_a_share_candidates(self, search_term: str) -> List[str]: - """Generate candidate A-share codes based on search term.""" - candidates = [] - - # If it's a partial number, try to complete it - if search_term.isdigit() and len(search_term) < 6: - # Try common prefixes - for prefix in ["6", "0", "3"]: - if search_term.startswith(prefix) or not search_term.startswith( - ("6", "0", "3", "8") - ): - padded = search_term.ljust(6, "0") - if not search_term.startswith(("6", "0", "3", "8")): - candidates.extend( - [f"{prefix}{padded[1:]}" for prefix in ["6", "0", "3"]] - ) - else: - candidates.append(padded) - - # For Chinese names, we would need a mapping service - # For now, return some common stocks as examples - common_stocks = [ - "000001", # 平安银行 - "000002", # 万科A - "600000", # 浦发银行 - "600036", # 招商银行 - "600519", # 贵州茅台 - ] - - if not candidates and any("\u4e00" <= char <= "\u9fff" for char in search_term): - candidates.extend(common_stocks) + def _create_asset_from_info( + self, ticker: str, exchange: Exchange, info_dict: Dict[str, Any] + ) -> Optional[Asset]: + """Create Asset object from info dictionary. + Args: + ticker: Asset ticker in internal format + exchange: Exchange enum + info_dict: Dictionary containing asset information + Returns: + Asset object or None if creation fails + """ + try: + # Create localized names + localized_names = LocalizedName() + + # Determine country and currency based on exchange + if exchange in [Exchange.SSE, Exchange.SZSE, Exchange.BSE]: + # A-shares + country = "CN" + currency = info_dict.get("currency", "CNY") + timezone = "Asia/Shanghai" + + # Set Chinese and English names + cn_name = info_dict.get( + "org_short_name_cn", info_dict.get("org_name_cn", "") + ) + en_name = info_dict.get( + "org_short_name_en", info_dict.get("org_name_en", "") + ) - return candidates[:10] # Limit candidates + if cn_name: + localized_names.set_name("zh-Hans", cn_name) + localized_names.set_name("zh-CN", cn_name) + if en_name: + localized_names.set_name("en-US", en_name) + localized_names.set_name("en", en_name) + + # Use Chinese name as fallback if no English name + if not en_name and cn_name: + localized_names.set_name("en-US", cn_name) + + elif exchange == Exchange.HKEX: + # Hong Kong stocks + country = "HK" + currency = "HKD" + timezone = "Asia/Hong_Kong" + + # Set Chinese and English names + cn_name = info_dict.get("comcnname", "") + en_name = info_dict.get("comenname", "") + + if cn_name: + localized_names.set_name("zh-Hant", cn_name) + localized_names.set_name("zh-HK", cn_name) + if en_name: + localized_names.set_name("en-US", en_name) + localized_names.set_name("en", en_name) + + # Use English name as fallback if no Chinese name + if not cn_name and en_name: + localized_names.set_name("zh-Hant", en_name) + + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + # US stocks + country = "US" + currency = "USD" + timezone = "America/New_York" + + # Set English and Chinese names + en_name = info_dict.get( + "org_short_name_en", info_dict.get("org_name_en", "") + ) + cn_name = info_dict.get( + "org_short_name_cn", info_dict.get("org_name_cn", "") + ) - def _matches_search_term(self, result: AssetSearchResult, search_term: str) -> bool: - """Check if search result matches the search term.""" - search_lower = search_term.lower() + if en_name: + localized_names.set_name("en-US", en_name) + localized_names.set_name("en", en_name) + if cn_name: + localized_names.set_name("zh-Hans", cn_name) + localized_names.set_name("zh-CN", cn_name) - # Check ticker - if search_lower in result.ticker.lower(): - return True + # Use symbol as fallback if no names available + if not en_name and not cn_name: + symbol = ticker.split(":")[1] + localized_names.set_name("en-US", symbol) - # Check names - for name in result.names.values(): - if name and search_lower in name.lower(): - return True + else: + logger.warning(f"Unsupported exchange: {exchange}") + return None - return False + # Create market info + market_info = MarketInfo( + exchange=exchange.value, + country=country, + currency=currency, + timezone=timezone, + market_status=MarketStatus.UNKNOWN, + ) - def _search_hk_stocks_direct( - self, search_term: str, query: AssetSearchQuery - ) -> List[AssetSearchResult]: - """Search Hong Kong stocks using direct queries.""" - results = [] + # Create Asset object + asset = Asset( + ticker=ticker, + asset_type=AssetType.STOCK, # Default to stock, can be enhanced later + names=localized_names, + market_info=market_info, + ) - # If search term looks like HK stock code, try direct lookup - if self._is_hk_stock_code(search_term): - result = self._get_hk_stock_by_code(search_term) - if result: - results.append(result) - return results + # Add source mapping for AKShare + asset.set_source_ticker( + DataSource.AKSHARE, self.convert_to_source_ticker(ticker) + ) - # For other searches, try common HK stock codes - candidate_codes = self._generate_hk_stock_candidates(search_term) + # Add additional properties from info_dict + for key, value in info_dict.items(): + if value and str(value).strip() and str(value).lower() != "none": + asset.add_property(key, value) - for code in candidate_codes[: query.limit]: + # Save asset metadata to database try: - result = self._get_hk_stock_by_code(code) - if result and self._matches_search_term(result, search_term): - results.append(result) - except Exception as e: - logger.debug(f"Failed to get HK stock info for {code}: {e}") - continue - - return results + from ...server.db.repositories.asset_repository import ( + get_asset_repository, + ) - def _is_hk_stock_code(self, search_term: str) -> bool: - """Check if search term looks like a HK stock code.""" - return search_term.isdigit() and 1 <= len(search_term) <= 5 + asset_repo = get_asset_repository() - def _get_hk_stock_by_code(self, stock_code: str) -> Optional[AssetSearchResult]: - """Get HK stock info by stock code using direct query.""" - try: - # Format HK stock code - pad to 5 digits - formatted_code = stock_code.zfill(5) + # Get the primary name for the asset + primary_name = ( + localized_names.get_name("en-US") + or localized_names.get_name("zh-Hans") + or localized_names.get_name("zh-Hant") + or ticker + ) - # Validate: HK stock codes should be 5 digits - if not (formatted_code.isdigit() and len(formatted_code) == 5): - return None + asset_repo.upsert_asset( + symbol=ticker, + name=primary_name, + asset_type=asset.asset_type.value, + description=None, # AKShare info_dict doesn't have structured description field + sector=info_dict.get("industry") or info_dict.get("sector"), + asset_metadata={ + "currency": currency, + "country": country, + "timezone": timezone, + "source": "akshare", + "info": { + k: v + for k, v in info_dict.items() + if v and str(v).strip() and str(v).lower() != "none" + }, + }, + ) + logger.debug(f"Saved asset info from AKShare for {ticker}") + except Exception as e: + # Don't fail the info fetch if database save fails + logger.warning( + f"Failed to save asset info from AKShare for {ticker}: {e}" + ) - # Create internal ticker in standard format - internal_ticker = f"HKEX:{formatted_code}" - - # Create basic result - in production, you might want to query actual HK stock info - return self._create_stock_search_result( - internal_ticker, - AssetType.STOCK, - formatted_code, - f"HK{formatted_code}", # Basic name - "HKEX", - "HK", - "HKD", - stock_code, - ) + return asset except Exception as e: - logger.debug(f"Error getting HK stock info for {stock_code}: {e}") + logger.error( + f"Error creating asset from info for {ticker}: {e}", exc_info=True + ) return None - def _generate_hk_stock_candidates(self, search_term: str) -> List[str]: - """Generate candidate HK stock codes based on search term.""" - candidates = [] - - # Common HK stocks - common_hk_stocks = [ - "00700", # 腾讯 - "00941", # 中国移动 - "01299", # 友邦保险 - "02318", # 中国平安 - "03988", # 中国银行 - ] + def get_real_time_price(self, ticker: str) -> Optional[AssetPrice]: + """Get real-time price data for an asset using Eastmoney 1-minute API. - if search_term.isdigit() and len(search_term) <= 5: - candidates.append(search_term.zfill(5)) - else: - candidates.extend(common_hk_stocks) + This method fetches the latest 1-minute price data to get real-time information. + Supports US stocks, Hong Kong stocks, and A-shares. - return candidates[:10] + Args: + ticker: Asset ticker in internal format (e.g., "SSE:600519", "HKEX:00700", "NASDAQ:AAPL") - def _search_us_stocks_direct( - self, search_term: str, query: AssetSearchQuery - ) -> List[AssetSearchResult]: - """Search US stocks using direct queries.""" - results = [] + Returns: + Latest AssetPrice object, or None if data not available + """ + try: + # Parse ticker to get exchange and symbol + if ":" not in ticker: + logger.warning(f"Invalid ticker format: {ticker}") + return None - # If search term looks like US stock symbol, try direct lookup - if self._is_us_stock_symbol(search_term): - result = self._get_us_stock_by_symbol(search_term) - if result: - results.append(result) - return results + exchange_str, symbol = ticker.split(":", 1) + try: + exchange = Exchange(exchange_str) + except ValueError: + logger.warning(f"Unknown exchange: {exchange_str}") + return None - # For other searches, try common US stock symbols - candidate_symbols = self._generate_us_stock_candidates(search_term) + # Convert to AKShare format + source_ticker = self.convert_to_source_ticker(ticker) - for symbol in candidate_symbols[: query.limit]: - try: - result = self._get_us_stock_by_symbol(symbol) - if result and self._matches_search_term(result, search_term): - results.append(result) - except Exception as e: - logger.debug(f"Failed to get US stock info for {symbol}: {e}") - continue + # Use current time as end time, and set start time to 1 day ago to ensure we get recent data + end_date = datetime.now() + start_date = end_date - timedelta(days=1) - return results + df = None - def _is_us_stock_symbol(self, search_term: str) -> bool: - """Check if search term looks like a US stock symbol.""" - return search_term.isalpha() and 1 <= len(search_term) <= 5 + # A-shares (SSE, SZSE, BSE) + if exchange in [Exchange.SSE, Exchange.SZSE, Exchange.BSE]: + try: + # Get 1-minute data (returns recent 5 trading days, no adjustment) + df = ak.stock_zh_a_hist_min_em( + symbol=symbol, + start_date=start_date.strftime("%Y-%m-%d %H:%M:%S"), + end_date=end_date.strftime("%Y-%m-%d %H:%M:%S"), + period="1", + adjust="", # 1-minute data cannot be adjusted + ) + except Exception as e: + logger.error( + f"Error fetching A-share real-time data for {symbol}: {e}" + ) + return None - def _get_us_stock_by_symbol(self, symbol: str) -> Optional[AssetSearchResult]: - """Get US stock info by symbol using direct query.""" - try: - # Create basic result - AKShare may not have direct individual US stock query - exchange = "NASDAQ" # Default to NASDAQ - internal_ticker = f"{exchange}:{symbol.upper()}" - - return self._create_stock_search_result( - internal_ticker, - AssetType.STOCK, - symbol.upper(), - symbol.upper(), # Basic name - exchange, - "US", - "USD", - symbol, - ) + # Hong Kong stocks + elif exchange == Exchange.HKEX: + try: + # Get 1-minute data for HK stocks + df = ak.stock_hk_hist_min_em( + symbol=symbol, + period="1", + adjust="", + start_date=start_date.strftime("%Y-%m-%d %H:%M:%S"), + end_date=end_date.strftime("%Y-%m-%d %H:%M:%S"), + ) + except Exception as e: + logger.error( + f"Error fetching HK stock real-time data for {symbol}: {e}" + ) + return None - except Exception as e: - logger.debug(f"Error getting US stock info for {symbol}: {e}") - return None + # US stocks + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + try: + # US stock minute data API returns latest data + df = ak.stock_us_hist_min_em(symbol=source_ticker) + except Exception as e: + logger.error( + f"Error fetching US stock real-time data for {source_ticker}: {e}" + ) + return None - def _generate_us_stock_candidates(self, search_term: str) -> List[str]: - """Generate candidate US stock symbols based on search term.""" - candidates = [] - - # Common US stocks - common_us_stocks = [ - "AAPL", # Apple - "GOOGL", # Google - "MSFT", # Microsoft - "AMZN", # Amazon - "TSLA", # Tesla - ] + else: + logger.warning(f"Unsupported exchange for real-time data: {exchange}") + return None - if search_term.isalpha() and len(search_term) <= 5: - candidates.append(search_term.upper()) - else: - candidates.extend(common_us_stocks) + # Check if data was retrieved + if df is None or df.empty: + logger.warning(f"No real-time data found for {ticker}") + return None - return candidates[:10] + # Convert DataFrame to AssetPrice list (reuse existing conversion method) + prices = self._convert_intraday_df_to_prices(df, ticker, exchange) - def _search_etfs_direct( - self, search_term: str, query: AssetSearchQuery - ) -> List[AssetSearchResult]: - """Search ETFs using direct queries.""" - results = [] + # Return the most recent price (last entry) + if prices: + return prices[-1] + else: + logger.warning(f"Failed to convert real-time data for {ticker}") + return None - # If search term looks like ETF code, try direct lookup - if self._is_etf_code(search_term): - result = self._get_etf_by_code(search_term) - if result: - results.append(result) - return results + except Exception as e: + logger.error( + f"Error getting real-time price for {ticker}: {e}", exc_info=True + ) + return None - # For other searches, try common ETF codes - candidate_codes = self._generate_etf_candidates(search_term) + def get_historical_prices( + self, + ticker: str, + start_date: datetime, + end_date: datetime, + interval: str = "1d", + ) -> List[AssetPrice]: + """Get historical price data using Eastmoney API. - for code in candidate_codes[: query.limit]: - try: - result = self._get_etf_by_code(code) - if result and self._matches_search_term(result, search_term): - results.append(result) - except Exception as e: - logger.debug(f"Failed to get ETF info for {code}: {e}") - continue + Supports US stocks, Hong Kong stocks, and A-shares with qfq (forward adjusted) data. - return results + Args: + ticker: Asset ticker in internal format (e.g., "SSE:600519", "HKEX:00700", "NASDAQ:AAPL") + start_date: Start date for historical data + end_date: End date for historical data + interval: Data interval using format like "1m", "5m", "15m", "30m", "60m", "1d", "1w", "1mo" + Supported intervals: + - Minute: "1m", "5m", "15m", "30m", "60m" (intraday data) + - Daily: "1d" (default) + - Weekly: "1w" + - Monthly: "1mo" - def _is_etf_code(self, search_term: str) -> bool: - """Check if search term looks like an ETF code.""" - return ( - search_term.isdigit() - and len(search_term) == 6 - and search_term.startswith(("5", "1")) - ) + Returns: + List of historical price data - def _get_etf_by_code(self, fund_code: str) -> Optional[AssetSearchResult]: - """Get ETF info by code using direct query.""" + Note: + - 1-minute data returns only recent 5 trading days and cannot be adjusted + - Intraday data uses separate API endpoints with different limitations per exchange + """ try: - # Determine exchange for funds - exchange = "SSE" if fund_code.startswith("5") else "SZSE" - internal_ticker = f"{exchange}:{fund_code}" - - # Create basic result - in production, you might want to query actual ETF info - names = { - "zh-Hans": f"ETF{fund_code}", - "zh-Hant": f"ETF{fund_code}", - "en-US": f"ETF{fund_code}", - } + # Parse ticker to get exchange and symbol + if ":" not in ticker: + logger.warning(f"Invalid ticker format: {ticker}") + return [] - return AssetSearchResult( - ticker=internal_ticker, - asset_type=AssetType.ETF, - names=names, - exchange=exchange, - country="CN", - currency="CNY", - market_status=MarketStatus.UNKNOWN, - relevance_score=2.0, # High relevance for direct matches - ) + exchange_str, symbol = ticker.split(":", 1) + try: + exchange = Exchange(exchange_str) + except ValueError: + logger.warning(f"Unknown exchange: {exchange_str}") + return [] - except Exception as e: - logger.debug(f"Error getting ETF info for {fund_code}: {e}") - return None + # Convert to AKShare format + source_ticker = self.convert_to_source_ticker(ticker) + + # Map interval to Eastmoney API format + # For minute data: period='1'/'5'/'15'/'30'/'60' + # For daily/weekly/monthly: period='daily'/'weekly'/'monthly' + interval_mapping = { + # Minute intervals (intraday) + f"1{Interval.MINUTE}": "1", + f"5{Interval.MINUTE}": "5", + f"15{Interval.MINUTE}": "15", + f"30{Interval.MINUTE}": "30", + f"60{Interval.MINUTE}": "60", + # Daily/Weekly/Monthly intervals + f"1{Interval.DAY}": "daily", + f"1{Interval.WEEK}": "weekly", + f"1{Interval.MONTH}": "monthly", + } - def _generate_etf_candidates(self, search_term: str) -> List[str]: - """Generate candidate ETF codes based on search term.""" - candidates = [] - - # Common ETFs - common_etfs = [ - "510050", # 50ETF - "510300", # 沪深300ETF - "159919", # 沪深300ETF - "510500", # 中证500ETF - "159915", # 创业板ETF - ] + # Get the period value from mapping + period = interval_mapping.get(interval) + if not period: + logger.warning( + f"Unsupported interval: {interval}. " + f"Supported intervals: {', '.join(interval_mapping.keys())}" + ) + return [] - if search_term.isdigit() and len(search_term) == 6: - candidates.append(search_term) - else: - candidates.extend(common_etfs) + # Determine if this is intraday (minute-level) data + is_intraday = period in ["1", "5", "15", "30", "60"] - return candidates[:10] + if is_intraday: + # Use intraday data method for minute-level intervals + return self._get_intraday_prices( + ticker, exchange, source_ticker, start_date, end_date, period + ) - def _calculate_relevance(self, search_term: str, code: str, name: str) -> float: - """Calculate relevance score for search results.""" - search_term_lower = search_term.lower() - code_lower = code.lower() - name_lower = name.lower() + # Format dates for daily/weekly/monthly data + start_date_str = start_date.strftime("%Y%m%d") + end_date_str = end_date.strftime("%Y%m%d") - # Exact matches get highest score - if search_term_lower == code_lower or search_term_lower == name_lower: - return 2.0 + # Get historical data based on exchange + df = None - # Code starts with search term - if code_lower.startswith(search_term_lower): - return 1.8 + # A-shares (SSE, SZSE, BSE) + if exchange in [Exchange.SSE, Exchange.SZSE, Exchange.BSE]: + try: + df = ak.stock_zh_a_hist( + symbol=symbol, + period=period, + start_date=start_date_str, + end_date=end_date_str, + adjust="qfq", # Forward adjusted + ) + except Exception as e: + logger.error( + f"Error fetching A-share historical data for {symbol} with period {period}: {e}" + ) + return [] - # Name starts with search term - if name_lower.startswith(search_term_lower): - return 1.6 + # Hong Kong stocks + elif exchange == Exchange.HKEX: + try: + df = ak.stock_hk_hist( + symbol=symbol, + period=period, + start_date=start_date_str, + end_date=end_date_str, + adjust="qfq", # Forward adjusted + ) + except Exception as e: + logger.error( + f"Error fetching HK stock historical data for {symbol} with period {period}: {e}" + ) + return [] - # Code contains search term - if search_term_lower in code_lower: - return 1.4 + # US stocks + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + try: + df = ak.stock_us_hist( + symbol=source_ticker, # US stocks need exchange code prefix + period=period, + start_date=start_date_str, + end_date=end_date_str, + adjust="qfq", # Forward adjusted + ) + except Exception as e: + logger.error( + f"Error fetching US stock historical data for {source_ticker} with period {period}: {e}" + ) + return [] - # Name contains search term - if search_term_lower in name_lower: - return 1.2 + else: + logger.warning(f"Unsupported exchange for historical data: {exchange}") + return [] - return 1.0 + # Check if data was retrieved + if df is None or df.empty: + logger.warning(f"No historical data found for {ticker}") + return [] - def get_asset_info(self, ticker: str) -> Optional[Asset]: - """Get detailed asset information from AKShare.""" - try: - exchange, symbol = ticker.split(":") - - # Handle different markets - if exchange in ["SSE", "SZSE", "BSE"]: - return self._get_a_share_info(ticker, exchange, symbol) - elif exchange == "HKEX": - return self._get_hk_stock_info(ticker, exchange, symbol) - elif exchange in ["NASDAQ", "NYSE"]: - return self._get_us_stock_info(ticker, exchange, symbol) - else: - logger.warning(f"Unsupported exchange: {exchange}") - return None + # Convert DataFrame to AssetPrice list + return self._convert_df_to_prices(df, ticker, exchange) except Exception as e: - logger.error(f"Error getting asset info for {ticker}: {e}") - return None + logger.error( + f"Error getting historical prices for {ticker}: {e}", exc_info=True + ) + return [] - def _get_a_share_info( - self, ticker: str, exchange: str, symbol: str - ) -> Optional[Asset]: - """Get A-share stock information.""" + def _get_intraday_prices( + self, + ticker: str, + exchange: Exchange, + source_ticker: str, + start_date: datetime, + end_date: datetime, + period: str, + ) -> List[AssetPrice]: + """Get intraday (minute-level) price data using Eastmoney API. + + Args: + ticker: Asset ticker in internal format + exchange: Exchange enum + source_ticker: Ticker in AKShare format + start_date: Start date and time + end_date: End date and time + period: Period value for Eastmoney API ('1', '5', '15', '30', '60') + + Returns: + List of intraday price data + + Note: + - period='1': 1-minute data, returns only recent 5 trading days, no adjustment + - period='5': 5-minute data + - period='15': 15-minute data + - period='30': 30-minute data + - period='60': 60-minute data (1 hour) + """ try: - # Use the new Snowball API for individual stock basic info - df_info = ak.stock_individual_basic_info_xq(symbol=symbol) + # Validate period value + if period not in ["1", "5", "15", "30", "60"]: + logger.warning( + f"Invalid period for intraday data: {period}. Expected one of: 1, 5, 15, 30, 60" + ) + return [] - if df_info is None or df_info.empty: - return None + # Format datetime strings + start_datetime_str = start_date.strftime("%Y-%m-%d %H:%M:%S") + end_datetime_str = end_date.strftime("%Y-%m-%d %H:%M:%S") - # Convert DataFrame to dict for easier access - info_dict = {} - for _, row in df_info.iterrows(): - info_dict[row["item"]] = row["value"] + # Get symbol from ticker + symbol = ticker.split(":", 1)[1] - # Create localized names - names = LocalizedName() - stock_name_cn = info_dict.get("org_short_name_cn", symbol) - stock_name_en = info_dict.get("org_short_name_en", symbol) - names.set_name("zh-Hans", stock_name_cn) - names.set_name("zh-Hant", stock_name_cn) - names.set_name("en-US", stock_name_en) + df = None - # Create market info - market_info = MarketInfo( - exchange=exchange, - country="CN", - currency="CNY", - timezone="Asia/Shanghai", - ) - - # Create asset - asset = Asset( - ticker=ticker, - asset_type=AssetType.STOCK, - names=names, - market_info=market_info, - ) - - # Set source mapping - asset.set_source_ticker(self.source, symbol) - - # Add additional properties from Snowball API - properties = { - "org_id": info_dict.get("org_id"), - "org_name_cn": info_dict.get("org_name_cn"), - "org_short_name_cn": info_dict.get("org_short_name_cn"), - "org_name_en": info_dict.get("org_name_en"), - "org_short_name_en": info_dict.get("org_short_name_en"), - "main_operation_business": info_dict.get("main_operation_business"), - "operating_scope": info_dict.get("operating_scope"), - "org_cn_introduction": info_dict.get("org_cn_introduction"), - "legal_representative": info_dict.get("legal_representative"), - "general_manager": info_dict.get("general_manager"), - "secretary": info_dict.get("secretary"), - "established_date": info_dict.get("established_date"), - "reg_asset": info_dict.get("reg_asset"), - "staff_num": info_dict.get("staff_num"), - "telephone": info_dict.get("telephone"), - "postcode": info_dict.get("postcode"), - "fax": info_dict.get("fax"), - "email": info_dict.get("email"), - "org_website": info_dict.get("org_website"), - "reg_address_cn": info_dict.get("reg_address_cn"), - "reg_address_en": info_dict.get("reg_address_en"), - "office_address_cn": info_dict.get("office_address_cn"), - "office_address_en": info_dict.get("office_address_en"), - "currency": info_dict.get("currency"), - "listed_date": info_dict.get("listed_date"), - "provincial_name": info_dict.get("provincial_name"), - "actual_controller": info_dict.get("actual_controller"), - "classi_name": info_dict.get("classi_name"), - "pre_name_cn": info_dict.get("pre_name_cn"), - "chairman": info_dict.get("chairman"), - "executives_nums": info_dict.get("executives_nums"), - "actual_issue_vol": info_dict.get("actual_issue_vol"), - "issue_price": info_dict.get("issue_price"), - "actual_rc_net_amt": info_dict.get("actual_rc_net_amt"), - "pe_after_issuing": info_dict.get("pe_after_issuing"), - "online_success_rate_of_issue": info_dict.get( - "online_success_rate_of_issue" - ), - "affiliate_industry": info_dict.get("affiliate_industry"), - } - - # Filter out None values - properties = {k: v for k, v in properties.items() if v is not None} - asset.properties.update(properties) - - return asset - - except Exception as e: - logger.error(f"Error fetching A-share info for {symbol}: {e}") - return None - - def _get_hk_stock_info( - self, ticker: str, exchange: str, symbol: str - ) -> Optional[Asset]: - """Get Hong Kong stock information.""" - try: - # For HK stocks, we'll create basic info since detailed info API may be limited - names = LocalizedName() - names.set_name("zh-Hans", symbol) - names.set_name("zh-Hant", symbol) - names.set_name("en-US", symbol) - - market_info = MarketInfo( - exchange=exchange, - country="HK", - currency="HKD", - timezone="Asia/Hong_Kong", - ) - - asset = Asset( - ticker=ticker, - asset_type=AssetType.STOCK, - names=names, - market_info=market_info, - ) - - asset.set_source_ticker(self.source, symbol) - return asset - - except Exception as e: - logger.error(f"Error creating HK stock info for {symbol}: {e}") - return None - - def _get_us_stock_info( - self, ticker: str, exchange: str, symbol: str - ) -> Optional[Asset]: - """Get US stock information.""" - try: - # For US stocks, we'll create basic info since detailed info API may be limited - names = LocalizedName() - names.set_name("zh-Hans", symbol) - names.set_name("zh-Hant", symbol) - names.set_name("en-US", symbol) - - market_info = MarketInfo( - exchange=exchange, - country="US", - currency="USD", - timezone="America/New_York", - ) - - asset = Asset( - ticker=ticker, - asset_type=AssetType.STOCK, - names=names, - market_info=market_info, - ) - - asset.set_source_ticker(self.source, symbol) - return asset - - except Exception as e: - logger.error(f"Error creating US stock info for {symbol}: {e}") - return None - - def get_real_time_price(self, ticker: str) -> Optional[AssetPrice]: - """Get real-time price data from AKShare.""" - try: - exchange, symbol = ticker.split(":") - - # Handle different markets - if exchange in ["SSE", "SZSE", "BSE"]: - return self._get_a_share_price(ticker, exchange, symbol) - elif exchange == "HKEX": - return self._get_hk_stock_price(ticker, exchange, symbol) - elif exchange in ["NASDAQ", "NYSE"]: - return self._get_us_stock_price(ticker, exchange, symbol) - else: - logger.warning(f"Unsupported exchange for real-time price: {exchange}") - return None - - except Exception as e: - logger.error(f"Error getting real-time price for {ticker}: {e}") - return None - - def _get_a_share_price( - self, ticker: str, exchange: str, symbol: str - ) -> Optional[AssetPrice]: - """Get A-share real-time price using direct query.""" - try: - # Use direct real-time price query - stock_zh_a_spot_em takes no parameters - cache_key = "a_share_price_all" - df_realtime = self._get_cached_data( - cache_key, self._safe_akshare_call, ak.stock_zh_a_spot_em - ) - - if df_realtime is None or df_realtime.empty: - # Fallback to individual stock info if spot price fails - return self._get_a_share_price_from_info(ticker, exchange, symbol) - - # Find the specific stock in the A-share data - # The dataframe contains all A-shares, we need to filter by stock code - stock_data = df_realtime[df_realtime["代码"] == symbol] - if stock_data.empty: - # If not found by exact match, try alternative matching - logger.warning( - f"Stock {symbol} not found in A-share spot data, falling back to individual info" - ) - return self._get_a_share_price_from_info(ticker, exchange, symbol) - - stock_info = stock_data.iloc[0] - - # Extract price information using safe field access - current_price = self._safe_decimal_convert(stock_info.get("最新价", 0)) - open_price = self._safe_decimal_convert(stock_info.get("今开", 0)) - high_price = self._safe_decimal_convert(stock_info.get("最高", 0)) - low_price = self._safe_decimal_convert(stock_info.get("最低", 0)) - pre_close = self._safe_decimal_convert(stock_info.get("昨收", 0)) - - # Calculate change - change = current_price - pre_close if current_price and pre_close else None - change_percent = ( - (change / pre_close) * 100 - if change and pre_close and pre_close != 0 - else None - ) - - # Get volume and market cap - volume = self._safe_decimal_convert(stock_info.get("成交量")) - market_cap = self._safe_decimal_convert(stock_info.get("总市值")) - - return AssetPrice( - ticker=ticker, - price=current_price, - currency="CNY", - timestamp=datetime.now(), - volume=volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=current_price, - change=change, - change_percent=change_percent, - market_cap=market_cap, - source=self.source, - ) - - except Exception as e: - logger.error(f"Error fetching A-share price for {symbol}: {e}") - return None - - def _get_a_share_price_from_info( - self, ticker: str, exchange: str, symbol: str - ) -> Optional[AssetPrice]: - """Get A-share price from individual stock info as fallback.""" - try: - # Try to get basic price info from stock individual info - cache_key = f"a_share_info_price_{symbol}" - df_info = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_individual_info_em, - symbol=symbol, - ) - - if df_info is None or df_info.empty: - logger.warning(f"No individual stock info available for {symbol}") - return None - - # Convert DataFrame to dict for easier access - info_dict = {} - for _, row in df_info.iterrows(): - info_dict[row["item"]] = row["value"] - - # Extract current price from the individual info (if available) - current_price_value = info_dict.get("最新", info_dict.get("现价", 0)) - current_price = self._safe_decimal_convert(current_price_value) - - # Get market cap and other info - market_cap = self._safe_decimal_convert(info_dict.get("总市值")) - - if not current_price or current_price == 0: - logger.warning( - f"No valid current price found for {symbol} in individual info" - ) - return None - - return AssetPrice( - ticker=ticker, - price=current_price, - currency="CNY", - timestamp=datetime.now(), - volume=None, # Not available in individual info - open_price=None, # Not available in individual info - high_price=None, # Not available in individual info - low_price=None, # Not available in individual info - close_price=current_price, - change=None, # Not available in individual info - change_percent=None, # Not available in individual info - market_cap=market_cap, - source=self.source, - ) - - except Exception as e: - logger.error(f"Error fetching A-share info price for {symbol}: {e}") - return None - - def _safe_decimal_convert(self, value) -> Optional[Decimal]: - """Safely convert value to Decimal.""" - if value is None or value == "": - return None - try: - return Decimal(str(value)) - except (ValueError, TypeError, decimal.InvalidOperation): - return None - - def _get_hk_stock_price( - self, ticker: str, exchange: str, symbol: str - ) -> Optional[AssetPrice]: - """Get Hong Kong stock real-time price using individual stock query.""" - try: - # Use individual stock query instead of downloading all HK stocks - # Try to get individual stock info first - try: - # For HK stocks, try to get historical data as a proxy for current price - df_hk_hist = ak.stock_hk_daily(symbol=symbol, adjust="qfq") - if df_hk_hist is not None and not df_hk_hist.empty: - latest = df_hk_hist.iloc[-1] - current_price = Decimal( - str(latest.get("close", latest.get("收盘", 0))) + # A-shares (SSE, SZSE, BSE) + if exchange in [Exchange.SSE, Exchange.SZSE, Exchange.BSE]: + try: + # Note: 1-minute data only returns recent 5 trading days and cannot be adjusted + adjust = "" if period == "1" else "qfq" + df = ak.stock_zh_a_hist_min_em( + symbol=symbol, + start_date=start_datetime_str, + end_date=end_datetime_str, + period=period, + adjust=adjust, ) - - return AssetPrice( - ticker=ticker, - price=current_price, - currency="HKD", - timestamp=datetime.now(), - volume=Decimal( - str(latest.get("volume", latest.get("成交量", 0))) - ) - if latest.get("volume", latest.get("成交量", 0)) - else None, - open_price=Decimal( - str(latest.get("open", latest.get("开盘", 0))) - ), - high_price=Decimal( - str(latest.get("high", latest.get("最高", 0))) - ), - low_price=Decimal( - str(latest.get("low", latest.get("最低", 0))) - ), - close_price=current_price, - change=None, - change_percent=None, - market_cap=None, - source=self.source, + except Exception as e: + logger.error( + f"Error fetching A-share intraday data for {symbol}: {e}" ) - except Exception as e: - logger.debug(f"Individual HK stock query failed for {symbol}: {e}") - - # Fallback: return None instead of downloading all HK stocks - logger.warning( - f"Unable to get HK stock price for {symbol} without full market data download" - ) - return None - - except Exception as e: - logger.error(f"Error fetching HK stock price for {symbol}: {e}") - return None + return [] - def _get_us_stock_price( - self, ticker: str, exchange: str, symbol: str - ) -> Optional[AssetPrice]: - """Get US stock real-time price using individual stock query.""" - try: - # Use individual stock query instead of downloading all US stocks - try: - # For US stocks, try to get historical data as a proxy for current price - df_us_hist = ak.stock_us_daily(symbol=symbol, adjust="qfq") - if df_us_hist is not None and not df_us_hist.empty: - latest = df_us_hist.iloc[-1] - current_price = Decimal( - str(latest.get("close", latest.get("收盘", 0))) + # Hong Kong stocks + elif exchange == Exchange.HKEX: + try: + # Note: HK stock minute data doesn't support adjust parameter + df = ak.stock_hk_hist_min_em( + symbol=symbol, + period=period, + adjust="", # HK stocks don't support adjustment for minute data + start_date=start_datetime_str, + end_date=end_datetime_str, ) - - return AssetPrice( - ticker=ticker, - price=current_price, - currency="USD", - timestamp=datetime.now(), - volume=Decimal( - str(latest.get("volume", latest.get("成交量", 0))) - ) - if latest.get("volume", latest.get("成交量", 0)) - else None, - open_price=Decimal( - str(latest.get("open", latest.get("开盘", 0))) - ), - high_price=Decimal( - str(latest.get("high", latest.get("最高", 0))) - ), - low_price=Decimal( - str(latest.get("low", latest.get("最低", 0))) - ), - close_price=current_price, - change=None, - change_percent=None, - market_cap=None, - source=self.source, + except Exception as e: + logger.error( + f"Error fetching HK stock intraday data for {symbol}: {e}" ) - except Exception as e: - logger.debug(f"Individual US stock query failed for {symbol}: {e}") - - # Fallback: return None instead of downloading all US stocks - logger.warning( - f"Unable to get US stock price for {symbol} without full market data download" - ) - return None + return [] - except Exception as e: - logger.error(f"Error fetching US stock price for {symbol}: {e}") - return None - - def get_historical_prices( - self, - ticker: str, - start_date: datetime, - end_date: datetime, - interval: str = "1d", - ) -> List[AssetPrice]: - """Get historical price data from AKShare.""" - try: - exchange, symbol = ticker.split(":") + # US stocks + elif exchange in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + try: + # Note: US stock minute data API only returns latest data, doesn't support date range + df = ak.stock_us_hist_min_em(symbol=source_ticker) + except Exception as e: + logger.error( + f"Error fetching US stock intraday data for {source_ticker}: {e}" + ) + return [] - # Handle different markets - if exchange in ["SSE", "SZSE", "BSE"]: - return self._get_a_share_historical( - ticker, exchange, symbol, start_date, end_date, interval - ) - elif exchange == "HKEX": - return self._get_hk_stock_historical( - ticker, exchange, symbol, start_date, end_date, interval - ) - elif exchange in ["NASDAQ", "NYSE"]: - return self._get_us_stock_historical( - ticker, exchange, symbol, start_date, end_date, interval - ) else: - logger.warning(f"Unsupported exchange for historical data: {exchange}") + logger.warning(f"Unsupported exchange for intraday data: {exchange}") return [] + # Check if data was retrieved + if df is None or df.empty: + logger.warning(f"No intraday data found for {ticker}") + return [] + + # Convert DataFrame to AssetPrice list + return self._convert_intraday_df_to_prices(df, ticker, exchange) + except Exception as e: - logger.error(f"Error getting historical prices for {ticker}: {e}") + logger.error( + f"Error getting intraday prices for {ticker}: {e}", exc_info=True + ) return [] - def _get_a_share_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - interval: str, + def _convert_df_to_prices( + self, df: pd.DataFrame, ticker: str, exchange: Exchange ) -> List[AssetPrice]: - """Get A-share historical price data using direct query. + """Convert historical price DataFrame to list of AssetPrice objects. Args: + df: DataFrame containing historical price data ticker: Asset ticker in internal format - start_date: Start date for historical data, format: YYYY-MM-DD, timezone: UTC - end_date: End date for historical data, format: YYYY-MM-DD, timezone: UTC - interval: Data interval (e.g., "1d", "1h", "5m") + exchange: Exchange enum Returns: - List of historical price data + List of AssetPrice objects """ - try: - # Map interval to AKShare format and determine if intraday data is needed - akshare_params = self._map_interval_to_akshare_params(interval) - if not akshare_params: - logger.warning(f"Unsupported interval: {interval}") - return [] - - is_intraday = akshare_params["is_intraday"] - period_or_minutes = akshare_params["period"] - - if is_intraday: - return self._get_a_share_intraday_historical( - ticker, exchange, symbol, start_date, end_date, period_or_minutes - ) - else: - return self._get_a_share_daily_historical( - ticker, exchange, symbol, start_date, end_date, period_or_minutes - ) - - except Exception as e: - logger.error(f"Error fetching A-share historical data for {symbol}: {e}") - return [] - - def _map_interval_to_akshare_params(self, interval: str) -> Optional[dict]: - """Map interval to AKShare parameters, similar to yfinance mapping. - - Returns dict with 'is_intraday' and 'period' keys, or None if unsupported. - """ - # Create interval mapping similar to yfinance adapter - interval_mapping = { - # Minute intervals (intraday data) - f"1{Interval.MINUTE}": {"is_intraday": True, "period": "1"}, - f"5{Interval.MINUTE}": {"is_intraday": True, "period": "5"}, - f"15{Interval.MINUTE}": {"is_intraday": True, "period": "15"}, - f"30{Interval.MINUTE}": {"is_intraday": True, "period": "30"}, - f"60{Interval.MINUTE}": {"is_intraday": True, "period": "60"}, - # Daily and higher intervals - f"1{Interval.DAY}": {"is_intraday": False, "period": "daily"}, - f"1{Interval.WEEK}": {"is_intraday": False, "period": "weekly"}, - f"1{Interval.MONTH}": {"is_intraday": False, "period": "monthly"}, - # Common aliases - "1d": {"is_intraday": False, "period": "daily"}, - "daily": {"is_intraday": False, "period": "daily"}, - "1w": {"is_intraday": False, "period": "weekly"}, - "weekly": {"is_intraday": False, "period": "weekly"}, - "1mo": {"is_intraday": False, "period": "monthly"}, - "monthly": {"is_intraday": False, "period": "monthly"}, - "1m": {"is_intraday": True, "period": "1"}, - "5m": {"is_intraday": True, "period": "5"}, - "15m": {"is_intraday": True, "period": "15"}, - "30m": {"is_intraday": True, "period": "30"}, - "60m": {"is_intraday": True, "period": "60"}, - } - - return interval_mapping.get(interval) + prices = [] - def _get_a_share_daily_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - period: str, - ) -> List[AssetPrice]: - """Get A-share daily historical price data.""" try: - # Format dates for AKShare - start_date_str = start_date.strftime("%Y%m%d") - end_date_str = end_date.strftime("%Y%m%d") - - # Use cached data for historical prices - cache_key = ( - f"a_share_hist_{symbol}_{start_date_str}_{end_date_str}_{period}" - ) - df_hist = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_zh_a_hist, - symbol=symbol, - period=period, - start_date=start_date_str, - end_date=end_date_str, - adjust="qfq", # Use forward adjustment - ) - - if df_hist is None or df_hist.empty: - logger.warning(f"No daily historical data available for {symbol}") + # Get currency based on exchange + currency = self._get_currency(exchange) + + # Use field mapping helper to get actual field names + date_field = self._get_field_name(df, "date", exchange) + open_field = self._get_field_name(df, "open", exchange) + close_field = self._get_field_name(df, "close", exchange) + high_field = self._get_field_name(df, "high", exchange) + low_field = self._get_field_name(df, "low", exchange) + volume_field = self._get_field_name(df, "volume", exchange) + change_field = self._get_field_name(df, "change", exchange) + change_pct_field = self._get_field_name(df, "change_percent", exchange) + + # Validate required fields + if not date_field or not close_field: + logger.error( + f"Missing required fields in DataFrame. date_field={date_field}, close_field={close_field}" + ) return [] - return self._process_a_share_daily_data(ticker, df_hist) - - except Exception as e: - logger.error( - f"Error fetching A-share daily historical data for {symbol}: {e}" - ) - return [] + for _, row in df.iterrows(): + try: + # Parse date + date_str = str(row[date_field]) + if len(date_str) == 8: # Format: YYYYMMDD + timestamp = datetime.strptime(date_str, "%Y%m%d") + else: + # Try parsing as standard date format + timestamp = pd.to_datetime(date_str) - def _get_a_share_intraday_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - period: str, - ) -> List[AssetPrice]: - """Get A-share intraday historical price data using minute data.""" - try: - # Format dates for AKShare intraday query - start_date_str = start_date.strftime("%Y-%m-%d %H:%M:%S") - end_date_str = end_date.strftime("%Y-%m-%d %H:%M:%S") - - # Use cached data for intraday historical prices - cache_key = f"a_share_hist_min_{symbol}_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}_{period}" - - # Note: AKShare minute data has limitations - only recent 5 trading days for 1-minute - # and 1-minute data doesn't support forward adjustment - adjust_param = ( - "" if period == "1" else "qfq" - ) # 1-minute data doesn't support adjustment - - df_hist = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_zh_a_hist_min_em, - symbol=symbol, - start_date=start_date_str, - end_date=end_date_str, - period=period, - adjust=adjust_param, - ) + # Create AssetPrice object + price = AssetPrice( + ticker=ticker, + price=Decimal(str(row[close_field])), + currency=currency, + timestamp=timestamp, + open_price=Decimal(str(row[open_field])) + if open_field and pd.notna(row[open_field]) + else None, + high_price=Decimal(str(row[high_field])) + if high_field and pd.notna(row[high_field]) + else None, + low_price=Decimal(str(row[low_field])) + if low_field and pd.notna(row[low_field]) + else None, + close_price=Decimal(str(row[close_field])) + if pd.notna(row[close_field]) + else None, + volume=Decimal(str(row[volume_field])) + if volume_field and pd.notna(row[volume_field]) + else None, + change=Decimal(str(row[change_field])) + if change_field and pd.notna(row[change_field]) + else None, + change_percent=Decimal(str(row[change_pct_field])) + if change_pct_field and pd.notna(row[change_pct_field]) + else None, + source=DataSource.AKSHARE, + ) + prices.append(price) - if df_hist is None or df_hist.empty: - logger.warning(f"No intraday historical data available for {symbol}") - return [] + except Exception as e: + logger.warning(f"Error converting row to AssetPrice: {e}") + continue - return self._process_a_share_intraday_data(ticker, df_hist, period) + return prices except Exception as e: - logger.error( - f"Error fetching A-share intraday historical data for {symbol}: {e}" - ) + logger.error(f"Error converting DataFrame to prices: {e}", exc_info=True) return [] - def _process_a_share_daily_data( - self, ticker: str, df_hist: pd.DataFrame + def _convert_intraday_df_to_prices( + self, df: pd.DataFrame, ticker: str, exchange: Exchange ) -> List[AssetPrice]: - """Process A-share daily historical data.""" - prices = [] - for _, row in df_hist.iterrows(): - try: - # Parse date safely - trade_date = pd.to_datetime(row["日期"]).to_pydatetime() - - # Extract price data safely - open_price = self._safe_decimal_convert(row.get("开盘")) - high_price = self._safe_decimal_convert(row.get("最高")) - low_price = self._safe_decimal_convert(row.get("最低")) - close_price = self._safe_decimal_convert(row.get("收盘")) - volume = self._safe_decimal_convert(row.get("成交量")) - - if not close_price: # Skip if no closing price - continue + """Convert intraday price DataFrame to list of AssetPrice objects. - # Extract change data if available (AKShare provides this directly) - change = self._safe_decimal_convert(row.get("涨跌额")) - change_percent = self._safe_decimal_convert(row.get("涨跌幅")) - - # If change data not available, calculate from previous day - if change is None and len(prices) > 0: - prev_close = prices[-1].close_price - if prev_close and prev_close != 0: - change = close_price - prev_close - change_percent = (change / prev_close) * 100 - - price = AssetPrice( - ticker=ticker, - price=close_price, - currency="CNY", - timestamp=trade_date, - volume=volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=close_price, - change=change, - change_percent=change_percent, - source=self.source, - ) - prices.append(price) - - except Exception as row_error: - logger.warning(f"Error processing daily data row: {row_error}") - continue - - logger.info(f"Retrieved {len(prices)} daily price points") - return prices + Args: + df: DataFrame containing intraday price data + ticker: Asset ticker in internal format + exchange: Exchange enum - def _process_a_share_intraday_data( - self, ticker: str, df_hist: pd.DataFrame, period: str - ) -> List[AssetPrice]: - """Process A-share intraday historical data.""" + Returns: + List of AssetPrice objects + """ prices = [] - for _, row in df_hist.iterrows(): - try: - # Parse timestamp safely - trade_time = pd.to_datetime(row["时间"]).to_pydatetime() - - # Extract price data safely - open_price = self._safe_decimal_convert(row.get("开盘")) - high_price = self._safe_decimal_convert(row.get("最高")) - low_price = self._safe_decimal_convert(row.get("最低")) - close_price = self._safe_decimal_convert(row.get("收盘")) - - # Volume is in 手 (lots), convert to shares (1 lot = 100 shares) - volume_lots = self._safe_decimal_convert(row.get("成交量")) - volume = volume_lots * 100 if volume_lots else None - - if not close_price: # Skip if no closing price - continue - - # For intraday data, calculate change from previous period - change = None - change_percent = None - if len(prices) > 0: - prev_close = prices[-1].close_price - if prev_close and prev_close != 0: - change = close_price - prev_close - change_percent = (change / prev_close) * 100 - - # For periods > 1 minute, AKShare provides change data directly - if period != "1": - akshare_change = self._safe_decimal_convert(row.get("涨跌额")) - akshare_change_percent = self._safe_decimal_convert( - row.get("涨跌幅") - ) - if akshare_change is not None: - change = akshare_change - if akshare_change_percent is not None: - change_percent = akshare_change_percent - - price = AssetPrice( - ticker=ticker, - price=close_price, - currency="CNY", - timestamp=trade_time, - volume=volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=close_price, - change=change, - change_percent=change_percent, - source=self.source, - ) - prices.append(price) - - except Exception as row_error: - logger.warning(f"Error processing intraday data row: {row_error}") - continue - - logger.info(f"Retrieved {len(prices)} intraday ({period}m) price points") - return prices - def _get_hk_stock_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - interval: str, - ) -> List[AssetPrice]: - """Get Hong Kong stock historical price data.""" - try: - # Map interval to AKShare format - akshare_params = self._map_interval_to_akshare_params(interval) - if not akshare_params: - logger.warning(f"Unsupported interval for HK stocks: {interval}") - return [] - - is_intraday = akshare_params["is_intraday"] - period_or_minutes = akshare_params["period"] - - if is_intraday: - return self._get_hk_stock_intraday_historical( - ticker, exchange, symbol, start_date, end_date, period_or_minutes - ) - else: - return self._get_hk_stock_daily_historical( - ticker, exchange, symbol, start_date, end_date, period_or_minutes - ) - - except Exception as e: - logger.error(f"Error fetching HK stock historical data for {symbol}: {e}") - return [] - - def _get_hk_stock_daily_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - period: str, - ) -> List[AssetPrice]: - """Get Hong Kong stock daily historical price data.""" try: - # Format dates for AKShare - start_date_str = start_date.strftime("%Y%m%d") - end_date_str = end_date.strftime("%Y%m%d") - - # Use cached data for historical prices - cache_key = ( - f"hk_stock_hist_{symbol}_{start_date_str}_{end_date_str}_{period}" - ) - df_hist = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_hk_hist, - symbol=symbol, - period=period, - start_date=start_date_str, - end_date=end_date_str, - adjust="qfq", # Use forward adjustment (前复权) - ) - - if df_hist is None or df_hist.empty: - logger.warning( - f"No HK stock daily historical data available for {symbol}" + # Get currency based on exchange + currency = self._get_currency(exchange) + + # Use field mapping helper to get actual field names + time_field = self._get_field_name(df, "time", exchange) + open_field = self._get_field_name(df, "open", exchange) + close_field = self._get_field_name(df, "close", exchange) + high_field = self._get_field_name(df, "high", exchange) + low_field = self._get_field_name(df, "low", exchange) + volume_field = self._get_field_name(df, "volume", exchange) + + # Validate required fields + if not time_field or not close_field: + logger.error( + f"Missing required fields in DataFrame. time_field={time_field}, close_field={close_field}" ) return [] - return self._process_hk_stock_daily_data(ticker, df_hist) - - except Exception as e: - logger.error( - f"Error fetching HK stock daily historical data for {symbol}: {e}" - ) - return [] + for _, row in df.iterrows(): + try: + # Parse datetime + time_str = str(row[time_field]) + timestamp = pd.to_datetime(time_str) - def _get_hk_stock_intraday_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - period: str, - ) -> List[AssetPrice]: - """Get Hong Kong stock intraday historical price data.""" - try: - # Format dates for AKShare intraday query - start_date_str = start_date.strftime("%Y-%m-%d %H:%M:%S") - end_date_str = end_date.strftime("%Y-%m-%d %H:%M:%S") - - # Use cached data for intraday historical prices - cache_key = f"hk_stock_hist_min_{symbol}_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}_{period}" - - # Note: HK stock minute data has limitations - only recent 5 trading days for 1-minute - adjust_param = "" if period == "1" else "qfq" - - df_hist = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_hk_hist_min_em, - symbol=symbol, - start_date=start_date_str, - end_date=end_date_str, - period=period, - adjust=adjust_param, - ) + # Create AssetPrice object + price = AssetPrice( + ticker=ticker, + price=Decimal(str(row[close_field])), + currency=currency, + timestamp=timestamp, + open_price=Decimal(str(row[open_field])) + if open_field + and pd.notna(row[open_field]) + and row[open_field] != 0 + else None, + high_price=Decimal(str(row[high_field])) + if high_field and pd.notna(row[high_field]) + else None, + low_price=Decimal(str(row[low_field])) + if low_field and pd.notna(row[low_field]) + else None, + close_price=Decimal(str(row[close_field])) + if pd.notna(row[close_field]) + else None, + volume=Decimal(str(row[volume_field])) + if volume_field and pd.notna(row[volume_field]) + else None, + source=DataSource.AKSHARE, + ) + prices.append(price) - if df_hist is None or df_hist.empty: - logger.warning( - f"No HK stock intraday historical data available for {symbol}" - ) - return [] + except Exception as e: + logger.warning(f"Error converting intraday row to AssetPrice: {e}") + continue - return self._process_hk_stock_intraday_data(ticker, df_hist, period) + return prices except Exception as e: logger.error( - f"Error fetching HK stock intraday historical data for {symbol}: {e}" + f"Error converting intraday DataFrame to prices: {e}", exc_info=True ) return [] - def _process_hk_stock_daily_data( - self, ticker: str, df_hist: pd.DataFrame - ) -> List[AssetPrice]: - """Process Hong Kong stock daily historical data.""" - prices = [] - for _, row in df_hist.iterrows(): - try: - # Parse date safely - trade_date = pd.to_datetime(row["日期"]).to_pydatetime() + def get_capabilities(self) -> List[AdapterCapability]: + """Get detailed capabilities of AKShare adapter. - # Extract price data safely - open_price = self._safe_decimal_convert(row.get("开盘")) - high_price = self._safe_decimal_convert(row.get("最高")) - low_price = self._safe_decimal_convert(row.get("最低")) - close_price = self._safe_decimal_convert(row.get("收盘")) - volume = self._safe_decimal_convert(row.get("成交量")) + AKShare primarily supports Chinese and Hong Kong markets. - if not close_price: # Skip if no closing price - continue - - # Extract change data if available (AKShare provides this directly) - change = self._safe_decimal_convert(row.get("涨跌额")) - change_percent = self._safe_decimal_convert(row.get("涨跌幅")) - - # If change data not available, calculate from previous day - if change is None and len(prices) > 0: - prev_close = prices[-1].close_price - if prev_close and prev_close != 0: - change = close_price - prev_close - change_percent = (change / prev_close) * 100 - - price = AssetPrice( - ticker=ticker, - price=close_price, - currency="HKD", - timestamp=trade_date, - volume=volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=close_price, - change=change, - change_percent=change_percent, - source=self.source, - ) - prices.append(price) - - except Exception as row_error: - logger.warning(f"Error processing HK stock daily data row: {row_error}") - continue + Returns: + List of capabilities describing supported asset types and exchanges + """ + return [ + AdapterCapability( + asset_type=AssetType.STOCK, + exchanges={ + Exchange.SSE, + Exchange.SZSE, + Exchange.BSE, + Exchange.HKEX, + Exchange.NASDAQ, + Exchange.NYSE, + Exchange.AMEX, + }, + ), + AdapterCapability( + asset_type=AssetType.ETF, + exchanges={ + Exchange.SSE, + Exchange.SZSE, + Exchange.BSE, + Exchange.HKEX, + Exchange.NASDAQ, + Exchange.NYSE, + Exchange.AMEX, + }, + ), + AdapterCapability( + asset_type=AssetType.INDEX, + exchanges={ + Exchange.SSE, + Exchange.SZSE, + Exchange.BSE, + Exchange.HKEX, + Exchange.NASDAQ, + Exchange.NYSE, + Exchange.AMEX, + }, + ), + ] - logger.info(f"Retrieved {len(prices)} HK stock daily price points") - return prices + def convert_to_source_ticker(self, internal_ticker: str) -> str: + """Convert internal ticker to data source format. + Args: + internal_ticker: Ticker in internal format (e.g., "NASDAQ:AAPL", "NYSE:GSPC") + source: Target data source + Returns: + Ticker in data source specific format (e.g., "105.AAPL", "100.GSPC") + """ + try: + exchange, symbol = internal_ticker.split(":", 1) - def _process_hk_stock_intraday_data( - self, ticker: str, df_hist: pd.DataFrame, period: str - ) -> List[AssetPrice]: - """Process Hong Kong stock intraday historical data.""" - prices = [] - for _, row in df_hist.iterrows(): + # Convert exchange string to Exchange enum if needed try: - # Parse timestamp safely - trade_time = pd.to_datetime(row["时间"]).to_pydatetime() - - # Extract price data safely - open_price = self._safe_decimal_convert(row.get("开盘")) - high_price = self._safe_decimal_convert(row.get("最高")) - low_price = self._safe_decimal_convert(row.get("最低")) - close_price = self._safe_decimal_convert(row.get("收盘")) - volume = self._safe_decimal_convert(row.get("成交量")) - - if not close_price: # Skip if no closing price - continue - - # For intraday data, calculate change from previous period - change = None - change_percent = None - if len(prices) > 0: - prev_close = prices[-1].close_price - if prev_close and prev_close != 0: - change = close_price - prev_close - change_percent = (change / prev_close) * 100 - - price = AssetPrice( - ticker=ticker, - price=close_price, - currency="HKD", - timestamp=trade_time, - volume=volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=close_price, - change=change, - change_percent=change_percent, - source=self.source, - ) - prices.append(price) - - except Exception as row_error: + exchange_enum = Exchange(exchange) + except ValueError: logger.warning( - f"Error processing HK stock intraday data row: {row_error}" + f"Unknown exchange '{exchange}' for ticker {internal_ticker}" ) - continue - - logger.info( - f"Retrieved {len(prices)} HK stock intraday ({period}m) price points" - ) - return prices - - def _get_us_stock_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - interval: str, - ) -> List[AssetPrice]: - """Get US stock historical price data.""" - try: - # Map interval to AKShare format - akshare_params = self._map_interval_to_akshare_params(interval) - if not akshare_params: - logger.warning(f"Unsupported interval for US stocks: {interval}") - return [] - - is_intraday = akshare_params["is_intraday"] - period_or_minutes = akshare_params["period"] + return symbol - if is_intraday: - return self._get_us_stock_intraday_historical( - ticker, exchange, symbol, start_date, end_date, period_or_minutes - ) - else: - return self._get_us_stock_daily_historical( - ticker, exchange, symbol, start_date, end_date, period_or_minutes + # Check if this is an INDEX asset type from database + # Use lazy import to avoid circular dependency + try: + from ...server.db.repositories.asset_repository import ( + get_asset_repository, ) - except Exception as e: - logger.error(f"Error fetching US stock historical data for {symbol}: {e}") - return [] - - def _get_us_stock_daily_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - period: str, - ) -> List[AssetPrice]: - """Get US stock daily historical price data.""" - try: - # Format dates for AKShare - start_date_str = start_date.strftime("%Y%m%d") - end_date_str = end_date.strftime("%Y%m%d") - - # Use cached data for historical prices - cache_key = ( - f"us_stock_hist_{symbol}_{start_date_str}_{end_date_str}_{period}" - ) - df_hist = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_us_hist, - symbol=symbol, - period=period, - start_date=start_date_str, - end_date=end_date_str, - adjust="qfq", # Use forward adjustment - ) - - if df_hist is None or df_hist.empty: - logger.warning( - f"No US stock daily historical data available for {symbol}" + asset_repo = get_asset_repository() + asset = asset_repo.get_asset_by_symbol(internal_ticker) + + # For US INDEX assets, use special exchange code 100 + if asset and asset.asset_type == AssetType.INDEX.value: + if exchange_enum in [Exchange.NASDAQ, Exchange.NYSE, Exchange.AMEX]: + return f"{self.us_index_exchange_code}.{symbol}" + except (ImportError, Exception) as e: + # If repository is not available, skip database lookup + logger.debug( + f"Asset repository not available for AKShare, skipping database lookup: {e}" ) - return [] - - return self._process_us_stock_daily_data(ticker, df_hist) - - except Exception as e: - logger.error( - f"Error fetching US stock daily historical data for {symbol}: {e}" + pass + + # Handle US stocks - add exchange code prefix + if exchange_enum in self.us_exchange_codes: + exchange_code = self.us_exchange_codes[exchange_enum] + return f"{exchange_code}.{symbol}" + + # Handle Chinese A-shares and Hong Kong stocks + # For Chinese markets, AKShare uses the symbol directly without suffix + if exchange_enum in [ + Exchange.SSE, + Exchange.SZSE, + Exchange.BSE, + Exchange.HKEX, + ]: + return symbol + + # For other exchanges, return symbol as-is + logger.debug( + f"No specific format mapping for exchange {exchange}, returning symbol: {symbol}" ) - return [] - - def _get_us_stock_intraday_historical( - self, - ticker: str, - exchange: str, - symbol: str, - start_date: datetime, - end_date: datetime, - period: str, - ) -> List[AssetPrice]: - """Get US stock intraday historical price data.""" - try: - # Format dates for AKShare intraday query - start_date_str = start_date.strftime("%Y-%m-%d %H:%M:%S") - end_date_str = end_date.strftime("%Y-%m-%d %H:%M:%S") - - # Use cached data for intraday historical prices - cache_key = f"us_stock_hist_min_{symbol}_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}_{period}" - - # Note: US stock minute data has limitations - only recent 5 trading days - df_hist = self._get_cached_data( - cache_key, - self._safe_akshare_call, - ak.stock_us_hist_min_em, - symbol=symbol, - start_date=start_date_str, - end_date=end_date_str, - ) - - if df_hist is None or df_hist.empty: - logger.warning( - f"No US stock intraday historical data available for {symbol}" - ) - return [] + return symbol - return self._process_us_stock_intraday_data(ticker, df_hist, period) - - except Exception as e: + except ValueError: logger.error( - f"Error fetching US stock intraday historical data for {symbol}: {e}" + f"Invalid ticker format: {internal_ticker}, expected 'EXCHANGE:SYMBOL'" ) - return [] - - def _process_us_stock_daily_data( - self, ticker: str, df_hist: pd.DataFrame - ) -> List[AssetPrice]: - """Process US stock daily historical data.""" - prices = [] - for _, row in df_hist.iterrows(): - try: - # Parse date safely - trade_date = pd.to_datetime(row["日期"]).to_pydatetime() - - # Extract price data safely - open_price = self._safe_decimal_convert(row.get("开盘")) - high_price = self._safe_decimal_convert(row.get("最高")) - low_price = self._safe_decimal_convert(row.get("最低")) - close_price = self._safe_decimal_convert(row.get("收盘")) - volume = self._safe_decimal_convert(row.get("成交量")) - - if not close_price: # Skip if no closing price - continue - - # Extract change data if available (AKShare provides this directly) - change = self._safe_decimal_convert(row.get("涨跌额")) - change_percent = self._safe_decimal_convert(row.get("涨跌幅")) - - # If change data not available, calculate from previous day - if change is None and len(prices) > 0: - prev_close = prices[-1].close_price - if prev_close and prev_close != 0: - change = close_price - prev_close - change_percent = (change / prev_close) * 100 - - price = AssetPrice( - ticker=ticker, - price=close_price, - currency="USD", - timestamp=trade_date, - volume=volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=close_price, - change=change, - change_percent=change_percent, - source=self.source, - ) - prices.append(price) - - except Exception as row_error: - logger.warning(f"Error processing US stock daily data row: {row_error}") - continue - - logger.info(f"Retrieved {len(prices)} US stock daily price points") - return prices + return internal_ticker - def _process_us_stock_intraday_data( - self, ticker: str, df_hist: pd.DataFrame, period: str - ) -> List[AssetPrice]: - """Process US stock intraday historical data.""" - prices = [] - for _, row in df_hist.iterrows(): - try: - # Parse timestamp safely - trade_time = pd.to_datetime(row["时间"]).to_pydatetime() - - # Extract price data safely - open_price = self._safe_decimal_convert(row.get("开盘")) - high_price = self._safe_decimal_convert(row.get("最高")) - low_price = self._safe_decimal_convert(row.get("最低")) - close_price = self._safe_decimal_convert(row.get("收盘")) - volume = self._safe_decimal_convert(row.get("成交量")) - - if not close_price: # Skip if no closing price - continue - - # For intraday data, calculate change from previous period - change = None - change_percent = None - if len(prices) > 0: - prev_close = prices[-1].close_price - if prev_close and prev_close != 0: - change = close_price - prev_close - change_percent = (change / prev_close) * 100 - - price = AssetPrice( - ticker=ticker, - price=close_price, - currency="USD", - timestamp=trade_time, - volume=volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=close_price, - change=change, - change_percent=change_percent, - source=self.source, - ) - prices.append(price) - - except Exception as row_error: - logger.warning( - f"Error processing US stock intraday data row: {row_error}" - ) - continue - - logger.info( - f"Retrieved {len(prices)} US stock intraday ({period}m) price points" - ) - return prices - - def get_supported_asset_types(self) -> List[AssetType]: - """Get asset types supported by AKShare.""" - return [ - AssetType.STOCK, - AssetType.ETF, - AssetType.INDEX, - ] - - def _perform_health_check(self) -> Any: - """Perform health check by testing a simple stock info call instead of full data download.""" - try: - # Test with a simple individual stock info call instead of downloading all market data - # This avoids the expensive full market data download during health checks - try: - # Test A-share with a known stock (Ping An Bank) - df_test = ak.stock_individual_info_em(symbol="000001") - if df_test is not None and not df_test.empty: - return { - "status": "ok", - "test_method": "individual_stock_info", - "test_symbol": "000001", - "response_received": True, - } - except Exception as e: - logger.debug(f"A-share test failed: {e}") - - # Fallback: just check if akshare module is available and importable - import akshare as ak_test - - if ak_test: - return { - "status": "ok", - "test_method": "module_import", - "message": "AKShare module available", - } + def convert_to_internal_ticker( + self, source_ticker: str, default_exchange: Optional[str] = None + ) -> str: + """Convert data source ticker to internal format. + Args: + source_ticker: Ticker in data source format (e.g., "105.AAPL", "00700","600519") + source: Source data provider + default_exchange: Default exchange if cannot be determined from ticker + Returns: + Ticker in internal format (e.g., "NASDAQ:AAPL", "HKEX:00700", "SSE:600519") + """ + # Handle US stocks with exchange code prefix (e.g., "105.AAPL" -> "NASDAQ:AAPL") + if "." in source_ticker: + parts = source_ticker.split(".", 1) + if len(parts) == 2: + exchange_code, symbol = parts + # Check if this is a US exchange code + if exchange_code in self.us_exchange_codes_reverse: + exchange_enum = self.us_exchange_codes_reverse[exchange_code] + return f"{exchange_enum.value}:{symbol}" + + # Handle Chinese A-shares by ticker format + # Shanghai Stock Exchange: 6-digit codes starting with 6 + # Shenzhen Stock Exchange: 6-digit codes starting with 0 or 3 + # Beijing Stock Exchange: 6-digit codes starting with 4 or 8 + if source_ticker.isdigit(): + if len(source_ticker) == 6: + first_digit = source_ticker[0] + if first_digit == "6": + return f"{Exchange.SSE.value}:{source_ticker}" + elif first_digit in ["0", "3"]: + return f"{Exchange.SZSE.value}:{source_ticker}" + elif first_digit in ["4", "8"]: + return f"{Exchange.BSE.value}:{source_ticker}" + + # Handle Hong Kong stocks (5-digit codes, can have leading zeros) + # Hong Kong stocks are typically 5 digits (e.g., "00700", "01810") + elif len(source_ticker) == 5: + return f"{Exchange.HKEX.value}:{source_ticker}" + + # For 4-digit codes, could be simplified HK stocks + elif len(source_ticker) == 4: + # Pad to 5 digits for Hong Kong stocks + padded_symbol = source_ticker.zfill(5) + return f"{Exchange.HKEX.value}:{padded_symbol}" + + # If default exchange is provided, use it + if default_exchange: + # Normalize default_exchange if it's an Exchange enum + if isinstance(default_exchange, Exchange): + exchange_value = default_exchange.value else: - return {"status": "error", "message": "AKShare module not available"} - - except Exception as e: - return {"status": "error", "message": str(e)} - - except Exception as e: - return {"status": "error", "message": str(e)} + exchange_value = default_exchange + return f"{exchange_value}:{source_ticker}" - def _looks_like_ticker(self, search_term: str) -> bool: - """Check if search term looks like a ticker symbol.""" - search_term = search_term.upper().strip() - - # Combined heuristics for ticker-like patterns - return (len(search_term) <= 6 and search_term.isalnum()) or ( - len(search_term) <= 10 and search_term.isalpha() + # Fallback: return with AKSHARE prefix if cannot determine exchange + logger.warning( + f"Cannot determine exchange for ticker '{source_ticker}', using AKSHARE as prefix" ) - - def _search_by_direct_ticker_lookup( - self, search_term: str, query: AssetSearchQuery - ) -> List[AssetSearchResult]: - """Search by direct ticker lookup as fallback for semantic search. - - This method provides a yfinance-like approach for cases where AKShare - doesn't have comprehensive search capabilities. - """ - search_term = search_term.upper().strip() - - # Generate ticker variations based on search term characteristics - ticker_variations = self._generate_ticker_variations(search_term) - - for ticker_format in ticker_variations: - try: - # Try to get asset info to validate the ticker - asset_info = self.get_asset_info(ticker_format) - if asset_info: - # Create search result from asset info - result = AssetSearchResult( - ticker=ticker_format, - asset_type=asset_info.asset_type, - names={ - "zh-Hans": asset_info.names.get_name("zh-Hans") - or search_term, - "zh-Hant": asset_info.names.get_name("zh-Hant") - or search_term, - "en-US": asset_info.names.get_name("en-US") or search_term, - }, - exchange=asset_info.market_info.exchange, - country=asset_info.market_info.country, - currency=asset_info.market_info.currency, - market_status=MarketStatus.UNKNOWN, - relevance_score=2.0, # High relevance for direct matches - ) - return [result] # Return immediately on first match - - except Exception as e: - logger.debug(f"Ticker lookup failed for {ticker_format}: {e}") - continue - - return [] - - def _generate_ticker_variations(self, search_term: str) -> List[str]: - """Generate ticker variations based on search term characteristics.""" - variations = [search_term] # Direct ticker first - - # A-share variations (6 digits) - if search_term.isdigit() and len(search_term) == 6: - if search_term.startswith("6"): - variations.append(f"SSE:{search_term}") - elif search_term.startswith(("0", "3")): - variations.append(f"SZSE:{search_term}") - elif search_term.startswith("8"): - variations.append(f"BSE:{search_term}") - - # HK variations (digits, potentially short) - elif search_term.isdigit() and 1 <= len(search_term) <= 5: - variations.extend( - [ - f"HKEX:{search_term}", - f"HKEX:{search_term.zfill(5)}", # Pad with zeros - ] - ) - - # US/Crypto variations (letters) - elif search_term.isalpha(): - variations.extend( - [ - f"NASDAQ:{search_term}", - f"NYSE:{search_term}", - ] - ) - - return variations - - # Ticker validation patterns - TICKER_VALIDATION_PATTERNS = { - "SSE": re.compile(r"^6\d{5}$"), # Shanghai: 6xxxxx - "SZSE": re.compile(r"^[03]\d{5}$"), # Shenzhen: 0xxxxx or 3xxxxx - "BSE": re.compile(r"^8\d{5}$"), # Beijing: 8xxxxx - "HKEX": re.compile(r"^\d{5}$"), # Hong Kong: 5 digits - "NASDAQ": re.compile( - r"^[A-Z0-9]{1,5}$" - ), # US markets: 1-5 alphanumeric uppercase - "NYSE": re.compile( - r"^[A-Z0-9]{1,5}$" - ), # US markets: 1-5 alphanumeric uppercase - "CRYPTO": re.compile(r"^[A-Z0-9]{1,5}$"), # Crypto: 1-5 alphanumeric uppercase - } + return f"AKSHARE:{source_ticker}" def validate_ticker(self, ticker: str) -> bool: """Validate if ticker is supported by AKShare and matches standard format.""" try: - exchange, symbol = ticker.split(":", 1) - - pattern = self.TICKER_VALIDATION_PATTERNS.get(exchange) - return bool(pattern and pattern.match(symbol)) - - except ValueError: - return False - - def get_market_calendar( - self, start_date: datetime, end_date: datetime - ) -> List[datetime]: - """Get trading calendar for Chinese markets.""" - try: - # Get trading calendar from AKShare - df_calendar = ak.tool_trade_date_hist_sina() - - if df_calendar is None or df_calendar.empty: - return [] - - # Convert to datetime and filter by date range - df_calendar["trade_date"] = pd.to_datetime(df_calendar["trade_date"]) - - mask = (df_calendar["trade_date"] >= start_date) & ( - df_calendar["trade_date"] <= end_date - ) - filtered_dates = df_calendar[mask]["trade_date"] - - return [date.to_pydatetime() for date in filtered_dates] - - except Exception as e: - logger.error(f"Error fetching market calendar: {e}") - return [] - - def get_sector_stocks(self, sector: str) -> List[AssetSearchResult]: - """Get stocks from a specific sector.""" - try: - # Get sector classification - df_industry = ak.stock_board_industry_name_em() - - if df_industry is None or df_industry.empty: - return [] + if ":" not in ticker: + return False - # Find matching sectors - sector_matches = df_industry[ - df_industry["板块名称"].str.contains(sector, na=False) - ] + exchange, _ = ticker.split(":", 1) + capabilities = self.get_capabilities() - results = [] - for _, sector_row in sector_matches.iterrows(): - try: - # Get stocks in this sector - sector_name = sector_row["板块名称"] - df_sector_stocks = ak.stock_board_industry_cons_em( - symbol=sector_name - ) - - if df_sector_stocks is not None and not df_sector_stocks.empty: - for _, stock_row in df_sector_stocks.iterrows(): - stock_code = str(stock_row["代码"]) - stock_name = stock_row["名称"] - - # Determine exchange - exchange_info = self._get_exchange_from_a_share_code( - stock_code - ) - if not exchange_info: - continue - - exchange, internal_ticker = exchange_info - - result = self._create_stock_search_result( - internal_ticker, - AssetType.STOCK, - stock_code, - stock_name, - exchange, - "CN", - "CNY", - "", - ) - result.relevance_score = 1.0 # Override for sector search - results.append(result) - - except Exception as e: - logger.warning( - f"Error processing sector {sector_row.get('板块名称')}: {e}" - ) - continue - - return results - - except Exception as e: - logger.error(f"Error getting sector stocks for {sector}: {e}") - return [] - - def is_market_open(self, exchange: str) -> bool: - """Check if market is currently open.""" - now = datetime.utcnow() - - # Market configurations: (timezone_offset, trading_sessions) - market_config = { - "SSE": (8, [("09:30", "11:30"), ("13:00", "15:00")]), - "SZSE": (8, [("09:30", "11:30"), ("13:00", "15:00")]), - "BSE": (8, [("09:30", "11:30"), ("13:00", "15:00")]), - "HKEX": (8, [("09:30", "12:00"), ("13:00", "16:00")]), - "NASDAQ": (-5, [("09:30", "16:00")]), - "NYSE": (-5, [("09:30", "16:00")]), - "CRYPTO": (0, [("00:00", "23:59")]), # Always open - } - - if exchange not in market_config: + # Check if any capability supports this exchange + return any(cap.supports_exchange(exchange) for cap in capabilities) + except Exception: return False - - if exchange == "CRYPTO": - return True - - timezone_offset, sessions = market_config[exchange] - local_time = now.replace(tzinfo=None) + timedelta(hours=timezone_offset) - - # Check if it's a weekday - if local_time.weekday() >= 5: - return False - - current_time = local_time.time() - - # Check if current time falls within any trading session - for start_str, end_str in sessions: - start_time = datetime.strptime(start_str, "%H:%M").time() - end_time = datetime.strptime(end_str, "%H:%M").time() - if start_time <= current_time <= end_time: - return True - - return False diff --git a/python/valuecell/adapters/assets/base.py b/python/valuecell/adapters/assets/base.py index 5cef4a72e..c1106d363 100644 --- a/python/valuecell/adapters/assets/base.py +++ b/python/valuecell/adapters/assets/base.py @@ -6,8 +6,9 @@ import logging from abc import ABC, abstractmethod +from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional, Set from .types import ( Asset, @@ -16,204 +17,42 @@ AssetSearchResult, AssetType, DataSource, + Exchange, ) logger = logging.getLogger(__name__) -class TickerConverter: - """Utility class for converting between internal ticker format and data source formats.""" - - def __init__(self): - """Initialize ticker converter with mapping rules.""" - # Mapping from internal exchange codes to data source specific formats - self.exchange_mappings: Dict[DataSource, Dict[str, str]] = { - DataSource.YFINANCE: { - "NASDAQ": "", # NASDAQ stocks don't need suffix in yfinance - "NYSE": "", # NYSE stocks don't need suffix in yfinance - "SSE": ".SS", # Shanghai Stock Exchange - "SZSE": ".SZ", # Shenzhen Stock Exchange - "HKEX": ".HK", # Hong Kong Exchange - "TSE": ".T", # Tokyo Stock Exchange - }, - DataSource.AKSHARE: { - "SSE": "", # AKShare uses plain symbols for Chinese stocks - "SZSE": "", # AKShare uses plain symbols for Chinese stocks - "BSE": "", # Beijing Stock Exchange - }, - } - - # Reverse mappings for converting back to internal format - self.reverse_mappings: Dict[DataSource, Dict[str, str]] = {} - for source, mappings in self.exchange_mappings.items(): - self.reverse_mappings[source] = {v: k for k, v in mappings.items() if v} - - def to_source_format(self, internal_ticker: str, source: DataSource) -> str: - """Convert internal ticker format to data source specific format. +@dataclass +class AdapterCapability: + """Describes the asset types and exchanges supported by an adapter. - Args: - internal_ticker: Ticker in internal format (e.g., "NASDAQ:AAPL") - source: Target data source + This provides fine-grained control over adapter routing based on + specific exchange and asset type combinations. + """ - Returns: - Ticker in data source specific format (e.g., "AAPL" for yfinance) - """ - try: - exchange, symbol = internal_ticker.split(":", 1) - - # Special handling for indices in yfinance (use ^ prefix) - if source == DataSource.YFINANCE: - index_mapping = { - # US Indices - "NASDAQ:IXIC": "^IXIC", # NASDAQ Composite - "NYSE:DJI": "^DJI", # Dow Jones Industrial Average - "NYSE:GSPC": "^GSPC", # S&P 500 - "NASDAQ:NDX": "^NDX", # NASDAQ 100 - # Hong Kong Indices - "HKEX:HSI": "^HSI", # Hang Seng Index - "HKEX:HSCEI": "^HSCEI", # Hang Seng China Enterprises Index - # Chinese Indices (already work with .SS/.SZ suffixes) - # European Indices - "LSE:FTSE": "^FTSE", # FTSE 100 - "EURONEXT:FCHI": "^FCHI", # CAC 40 - "XETRA:GDAXI": "^GDAXI", # DAX - } - - if internal_ticker in index_mapping: - return index_mapping[internal_ticker] - - # Special handling for crypto tickers in yfinance - if exchange == "CRYPTO" and source == DataSource.YFINANCE: - # Map common crypto symbols to yfinance format - crypto_mapping = { - "BTC": "BTC-USD", - "ETH": "ETH-USD", - "ADA": "ADA-USD", - "DOT": "DOT-USD", - "SOL": "SOL-USD", - "MATIC": "MATIC-USD", - "LINK": "LINK-USD", - "UNI": "UNI-USD", - "AVAX": "AVAX-USD", - "ATOM": "ATOM-USD", - } - return crypto_mapping.get(symbol, f"{symbol}-USD") - - # Special handling for Hong Kong stocks in yfinance - if exchange == "HKEX" and source == DataSource.YFINANCE: - # Hong Kong stock codes need to be in proper format - # e.g., "700" -> "0700.HK", "00700" -> "0700.HK", "1234" -> "1234.HK" - if symbol.isdigit(): - # Remove leading zeros first, then pad to 4 digits - clean_symbol = str(int(symbol)) # Remove leading zeros - padded_symbol = clean_symbol.zfill(4) # Pad to 4 digits - return f"{padded_symbol}.HK" - else: - # For non-numeric symbols, use as-is with .HK suffix - return f"{symbol}.HK" - - if source not in self.exchange_mappings: - logger.warning(f"No mapping found for data source: {source}") - return symbol - - suffix = self.exchange_mappings[source].get(exchange, "") - return f"{symbol}{suffix}" - - except ValueError: - logger.error(f"Invalid ticker format: {internal_ticker}") - return internal_ticker - - def to_internal_format( - self, - source_ticker: str, - source: DataSource, - default_exchange: Optional[str] = None, - ) -> str: - """Convert data source ticker to internal format. + asset_type: AssetType + exchanges: Set[Exchange] # Supported exchanges + + def supports_exchange(self, exchange: Exchange) -> bool: + """Check if this capability supports the given exchange. Args: - source_ticker: Ticker in data source format (e.g., "000001.SZ") - source: Source data provider - default_exchange: Default exchange if cannot be determined from ticker + exchange: Exchange to check (can be Exchange enum or string) Returns: - Ticker in internal format (e.g., "SZSE:000001") + True if this capability supports the exchange """ - try: - # Special handling for indices from yfinance (reverse ^ prefix mapping) - if source == DataSource.YFINANCE and source_ticker.startswith("^"): - index_reverse_mapping = { - # US Indices - "^IXIC": "NASDAQ:IXIC", # NASDAQ Composite - "^DJI": "NYSE:DJI", # Dow Jones Industrial Average - "^GSPC": "NYSE:GSPC", # S&P 500 - "^NDX": "NASDAQ:NDX", # NASDAQ 100 - # Hong Kong Indices - "^HSI": "HKEX:HSI", # Hang Seng Index - "^HSCEI": "HKEX:HSCEI", # Hang Seng China Enterprises Index - # European Indices - "^FTSE": "LSE:FTSE", # FTSE 100 - "^FCHI": "EURONEXT:FCHI", # CAC 40 - "^GDAXI": "XETRA:GDAXI", # DAX - } - - if source_ticker in index_reverse_mapping: - return index_reverse_mapping[source_ticker] - - # Special handling for crypto from yfinance - remove currency suffix - if source == DataSource.YFINANCE and ( - "-USD" in source_ticker - or "-CAD" in source_ticker - or "-EUR" in source_ticker - ): - # Remove any currency suffix - crypto_symbol = source_ticker.split("-")[0].upper() - return f"CRYPTO:{crypto_symbol}" - - # Special handling for Hong Kong stocks from yfinance - if source == DataSource.YFINANCE and ".HK" in source_ticker: - symbol = source_ticker.replace(".HK", "") # Remove .HK suffix - # Keep as digits only, no leading zero removal for internal format - if symbol.isdigit(): - # Pad to 5 digits for Hong Kong stocks - symbol = symbol.zfill(5) - return f"HKEX:{symbol}" - - # Special handling for Shanghai stocks from yfinance - if source == DataSource.YFINANCE and ".SS" in source_ticker: - symbol = source_ticker.replace(".SS", "") - return f"SSE:{symbol}" - - # Special handling for Shenzhen stocks from yfinance - if source == DataSource.YFINANCE and ".SZ" in source_ticker: - symbol = source_ticker.replace(".SZ", "") - return f"SZSE:{symbol}" - - # Check for known suffixes - if source in self.reverse_mappings: - for suffix, exchange in self.reverse_mappings[source].items(): - if source_ticker.endswith(suffix): - symbol = ( - source_ticker[: -len(suffix)] if suffix else source_ticker - ) - return f"{exchange}:{symbol}" - - # If no suffix found and default exchange provided - if default_exchange: - # For US stocks from yfinance, symbol is already clean - return f"{default_exchange}:{source_ticker}" - - # For other assets without clear exchange mapping - # Fallback to using the source as exchange - return f"{source.value.upper()}:{source_ticker}" - - except Exception as e: - logger.error(f"Error converting ticker {source_ticker}: {e}") - return f"UNKNOWN:{source_ticker}" - - def get_supported_exchanges(self, source: DataSource) -> List[str]: - """Get list of supported exchanges for a data source.""" - return list(self.exchange_mappings.get(source, {}).keys()) + # Support both Exchange enum and string for backward compatibility + if isinstance(exchange, str): + exchange_str = exchange + else: + exchange_str = exchange.value + + return any( + ex.value == exchange_str if isinstance(ex, Exchange) else ex == exchange_str + for ex in self.exchanges + ) class BaseDataAdapter(ABC): @@ -230,7 +69,6 @@ def __init__(self, source: DataSource, api_key: Optional[str] = None, **kwargs): self.source = source self.api_key = api_key self.config = kwargs - self.converter = TickerConverter() self.logger = logging.getLogger(f"{__name__}.{source.value}") # Initialize adapter-specific configuration @@ -322,145 +160,79 @@ def validate_ticker(self, ticker: str) -> bool: """Validate if a ticker format is supported by this adapter. Args: - ticker: Ticker in internal format + ticker: Ticker in internal format (e.g., "NASDAQ:AAPL") Returns: True if ticker is valid for this adapter """ try: + if ":" not in ticker: + return False + exchange, _ = ticker.split(":", 1) - supported_exchanges = self.converter.get_supported_exchanges(self.source) - return exchange in supported_exchanges - except ValueError: + capabilities = self.get_capabilities() + + # Check if any capability supports this exchange + return any(cap.supports_exchange(exchange) for cap in capabilities) + except Exception: return False + @abstractmethod def convert_to_source_ticker(self, internal_ticker: str) -> str: - """Convert internal ticker to data source format.""" - return self.converter.to_source_format(internal_ticker, self.source) - - def convert_to_internal_ticker( - self, source_ticker: str, default_exchange: Optional[str] = None - ) -> str: - """Convert data source ticker to internal format.""" - return self.converter.to_internal_format( - source_ticker, self.source, default_exchange - ) - - def is_market_open(self, exchange: str) -> bool: - """Check if a specific market is currently open. + """Convert internal ticker to data source format. Args: - exchange: Exchange identifier + internal_ticker: Ticker in internal format (e.g., "NASDAQ:AAPL") + source: Target data source Returns: - True if market is open, False otherwise + Ticker in data source specific format (e.g., "AAPL" for yfinance) """ - # This is a basic implementation - subclasses should override - # with more accurate market hours checking - now = datetime.utcnow() - hour = now.hour - - # Basic US market hours (9:30 AM - 4:00 PM EST = 14:30 - 21:00 UTC) - if exchange in ["NASDAQ", "NYSE"]: - return 14 <= hour < 21 - - # Basic Chinese market hours (9:30 AM - 3:00 PM CST = 1:30 - 7:00 UTC) - elif exchange in ["SSE", "SZSE"]: - return 1 <= hour < 7 - - # For crypto markets, assume always open - elif exchange in ["CRYPTO"]: - return True - - return False - - def get_supported_asset_types(self) -> List[AssetType]: - """Get list of asset types supported by this adapter.""" - # Default implementation - subclasses should override - return [AssetType.STOCK] + pass - def health_check(self) -> Dict[str, Any]: - """Perform health check on the data adapter. + @abstractmethod + def convert_to_internal_ticker( + self, source_ticker: str, default_exchange: Optional[str] = None + ) -> str: + """Convert data source ticker to internal format. + Args: + source_ticker: Ticker in data source format (e.g., "000001.SZ") + source: Source data provider + default_exchange: Default exchange if cannot be determined from ticker Returns: - Dictionary containing health status information + Ticker in internal format (e.g., "SZSE:000001") """ - try: - # Try to make a simple API call to test connectivity - test_result = self._perform_health_check() - return { - "source": self.source.value, - "status": "healthy" if test_result else "unhealthy", - "timestamp": datetime.utcnow().isoformat(), - "details": test_result, - } - except Exception as e: - return { - "source": self.source.value, - "status": "error", - "timestamp": datetime.utcnow().isoformat(), - "error": str(e), - } + pass @abstractmethod - def _perform_health_check(self) -> Any: - """Perform adapter-specific health check. + def get_capabilities(self) -> List[AdapterCapability]: + """Get detailed capabilities describing supported asset types and exchanges. Returns: - Health check result (implementation-specific) + List of capabilities describing what this adapter can handle """ pass + def get_supported_asset_types(self) -> List[AssetType]: + """Get list of asset types supported by this adapter. -class AdapterError(Exception): - """Base exception class for adapter-related errors.""" - - def __init__( - self, - message: str, - source: Optional[DataSource] = None, - ticker: Optional[str] = None, - ): - """Initialize adapter error. - - Args: - message: Error message - source: Data source where error occurred - ticker: Asset ticker related to the error + This method extracts asset types from capabilities. """ - self.source = source - self.ticker = ticker - super().__init__(message) - + capabilities = self.get_capabilities() + asset_types = set() + for cap in capabilities: + asset_types.add(cap.asset_type) + return list(asset_types) -class RateLimitError(AdapterError): - """Exception raised when API rate limits are exceeded.""" + def get_supported_exchanges(self) -> Set[Exchange]: + """Get set of all exchanges supported by this adapter. - def __init__(self, message: str, retry_after: Optional[int] = None, **kwargs): - """Initialize rate limit error. - - Args: - message: Error message - retry_after: Seconds to wait before retrying - **kwargs: Additional error context + Returns: + Set of Exchange enums """ - self.retry_after = retry_after - super().__init__(message, **kwargs) - - -class DataNotAvailableError(AdapterError): - """Exception raised when requested data is not available.""" - - pass - - -class AuthenticationError(AdapterError): - """Exception raised when API authentication fails.""" - - pass - - -class InvalidTickerError(AdapterError): - """Exception raised when ticker format is invalid or not supported.""" - - pass + capabilities = self.get_capabilities() + exchanges: Set[Exchange] = set() + for cap in capabilities: + exchanges.update(cap.exchanges) + return exchanges diff --git a/python/valuecell/adapters/assets/i18n_integration.py b/python/valuecell/adapters/assets/i18n_integration.py index 0a3c22b91..7e8df5a1e 100644 --- a/python/valuecell/adapters/assets/i18n_integration.py +++ b/python/valuecell/adapters/assets/i18n_integration.py @@ -284,7 +284,11 @@ def get_market_status_display_name( return t(key, default=status.value.replace("_", " ").title()) def format_currency_amount( - self, amount: float, currency: str, language: Optional[str] = None + self, + amount: float, + currency: str, + language: Optional[str] = None, + asset_type: Optional[str] = None, ) -> str: """Format currency amount according to locale. @@ -292,6 +296,7 @@ def format_currency_amount( amount: Amount to format currency: Currency code language: Target language (uses current i18n config if None) + asset_type: Asset type (e.g., 'stock', 'index', 'etf'). If 'index', no currency symbol is added. Returns: Formatted currency string @@ -301,6 +306,10 @@ def format_currency_amount( else: config = I18nConfig(language=language) + # For index type assets, don't add currency symbol + if asset_type == "index": + return config.format_number(amount, 2) + # Use the existing i18n currency formatting if currency == "USD": return f"${config.format_number(amount, 2)}" diff --git a/python/valuecell/adapters/assets/manager.py b/python/valuecell/adapters/assets/manager.py index 5fcb8f1cf..b24045289 100644 --- a/python/valuecell/adapters/assets/manager.py +++ b/python/valuecell/adapters/assets/manager.py @@ -4,11 +4,15 @@ and routing requests to the appropriate providers based on asset types and availability. """ +import json import logging +import os import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional + +from openai import OpenAI from .akshare_adapter import AKShareAdapter from .base import BaseDataAdapter @@ -19,6 +23,7 @@ AssetSearchResult, AssetType, DataSource, + Exchange, Watchlist, ) from .yfinance_adapter import YFinanceAdapter @@ -32,55 +37,67 @@ class AdapterManager: def __init__(self): """Initialize adapter manager.""" self.adapters: Dict[DataSource, BaseDataAdapter] = {} - self.adapter_priorities: Dict[AssetType, List[DataSource]] = {} - self.lock = threading.RLock() - # Default adapter priorities by asset type - self._set_default_priorities() + # Exchange → Adapters routing table (simplified) + # Note: Keys are Exchange.value strings for efficient lookup + self.exchange_routing: Dict[str, List[BaseDataAdapter]] = {} + + # Ticker → Adapter cache for fast lookups + self._ticker_cache: Dict[str, BaseDataAdapter] = {} + self._cache_lock = threading.Lock() + + self.lock = threading.RLock() logger.info("Asset adapter manager initialized") - def _set_default_priorities(self) -> None: - """Set default adapter priorities for different asset types.""" - self.adapter_priorities = { - AssetType.STOCK: [ - DataSource.YFINANCE, - DataSource.AKSHARE, - ], - AssetType.ETF: [ - DataSource.YFINANCE, - DataSource.AKSHARE, - ], - AssetType.CRYPTO: [ - DataSource.YFINANCE, - ], - AssetType.INDEX: [ - DataSource.YFINANCE, - DataSource.AKSHARE, - ], - } + def _rebuild_routing_table(self) -> None: + """Rebuild routing table based on registered adapters' capabilities. + + Simplified: Only use exchange to determine adapter routing. + """ + with self.lock: + self.exchange_routing.clear() + + # Build routing table: Exchange → List[Adapters] + for adapter in self.adapters.values(): + capabilities = adapter.get_capabilities() + + # Get all exchanges supported by this adapter (across all asset types) + supported_exchanges = set() + for cap in capabilities: + for exchange in cap.exchanges: + exchange_key = ( + exchange.value + if isinstance(exchange, Exchange) + else exchange + ) + supported_exchanges.add(exchange_key) + + # Register adapter for each supported exchange + for exchange_key in supported_exchanges: + if exchange_key not in self.exchange_routing: + self.exchange_routing[exchange_key] = [] + self.exchange_routing[exchange_key].append(adapter) + + # Clear ticker cache when routing table changes + with self._cache_lock: + self._ticker_cache.clear() + + logger.debug( + f"Routing table rebuilt with {len(self.exchange_routing)} exchanges" + ) def register_adapter(self, adapter: BaseDataAdapter) -> None: - """Register a data adapter. + """Register a data adapter and rebuild routing table. Args: adapter: Data adapter instance to register """ with self.lock: self.adapters[adapter.source] = adapter + self._rebuild_routing_table() logger.info(f"Registered adapter: {adapter.source.value}") - def unregister_adapter(self, source: DataSource) -> None: - """Unregister a data adapter. - - Args: - source: Data source to unregister - """ - with self.lock: - if source in self.adapters: - del self.adapters[source] - logger.info(f"Unregistered adapter: {source.value}") - def configure_yfinance(self, **kwargs) -> None: """Configure and register Yahoo Finance adapter.""" try: @@ -106,60 +123,174 @@ def get_available_adapters(self) -> List[DataSource]: with self.lock: return list(self.adapters.keys()) + def get_adapters_for_exchange(self, exchange: str) -> List[BaseDataAdapter]: + """Get list of adapters for a specific exchange. + + Args: + exchange: Exchange identifier (e.g., "NASDAQ", "SSE") + + Returns: + List of adapters that support the exchange + """ + with self.lock: + return self.exchange_routing.get(exchange, []) + def get_adapters_for_asset_type( self, asset_type: AssetType ) -> List[BaseDataAdapter]: - """Get prioritized list of adapters for an asset type. + """Get list of adapters that support a specific asset type. + + Note: This collects adapters across all exchanges. Consider using + get_adapters_for_exchange() for more specific routing. Args: asset_type: Type of asset Returns: - List of adapters in priority order + List of adapters that support this asset type """ with self.lock: - priority_sources = self.adapter_priorities.get(asset_type, []) - adapters = [] - - for source in priority_sources: - if source in self.adapters: - adapters.append(self.adapters[source]) + # Collect all adapters that support this asset type + supporting_adapters = set() + for adapter in self.adapters.values(): + supported_types = adapter.get_supported_asset_types() + if asset_type in supported_types: + supporting_adapters.add(adapter) - return adapters + return list(supporting_adapters) def get_adapter_for_ticker(self, ticker: str) -> Optional[BaseDataAdapter]: - """Get the best adapter for a specific ticker. + """Get the best adapter for a specific ticker (with caching). + + Simplified: Only based on exchange, first adapter that validates wins. Args: - ticker: Asset ticker in internal format + ticker: Asset ticker in internal format (e.g., "NASDAQ:AAPL") Returns: - Best available adapter for the ticker + Best available adapter for the ticker or None if not found """ - with self.lock: - # Try to determine asset type from ticker - exchange = ticker.split(":")[0] if ":" in ticker else "" - - # Map exchanges to likely asset types - exchange_asset_mapping = { - "NASDAQ": AssetType.STOCK, - "NYSE": AssetType.STOCK, - "SSE": AssetType.STOCK, - "SZSE": AssetType.STOCK, - "HKEX": AssetType.STOCK, - "CRYPTO": AssetType.CRYPTO, - } + # Check cache first + with self._cache_lock: + if ticker in self._ticker_cache: + return self._ticker_cache[ticker] + + # Parse ticker + if ":" not in ticker: + logger.warning(f"Invalid ticker format (missing ':'): {ticker}") + return None - asset_type = exchange_asset_mapping.get(exchange, AssetType.STOCK) - adapters = self.get_adapters_for_asset_type(asset_type) + exchange, symbol = ticker.split(":", 1) - # Return first adapter that supports this ticker - for adapter in adapters: - if adapter.validate_ticker(ticker): - return adapter + # Get adapters for this exchange + adapters = self.get_adapters_for_exchange(exchange) + if not adapters: + logger.debug(f"No adapters registered for exchange: {exchange}") return None + # Find first adapter that validates this ticker + for adapter in adapters: + if adapter.validate_ticker(ticker): + # Cache the result + with self._cache_lock: + self._ticker_cache[ticker] = adapter + logger.debug(f"Matched adapter {adapter.source.value} for {ticker}") + return adapter + + logger.warning(f"No suitable adapter found for ticker: {ticker}") + return None + + def _deduplicate_search_results( + self, results: List[AssetSearchResult] + ) -> List[AssetSearchResult]: + """Smart deduplication of search results to handle cross-exchange duplicates. + + This method handles cases where the same asset appears on multiple exchanges + (e.g., AMEX:GORO vs NASDAQ:GORO). It prioritizes certain exchanges and removes + likely duplicates based on symbol matching. + + Args: + results: List of search results to deduplicate + + Returns: + Deduplicated list of search results + """ + # Exchange priority for US stocks (higher number = higher priority) + exchange_priority = { + "NASDAQ": 3, + "NYSE": 2, + "AMEX": 1, + "HKEX": 3, + "SSE": 2, + "SZSE": 2, + "BSE": 1, + } + + seen_tickers = set() + # Map: (symbol, country) -> best result so far + symbol_map: Dict[tuple, AssetSearchResult] = {} + unique_results = [] + + for result in results: + # Skip exact ticker duplicates + if result.ticker in seen_tickers: + continue + + try: + exchange, symbol = result.ticker.split(":", 1) + except ValueError: + # Invalid ticker format, skip + logger.warning( + f"Invalid ticker format in search result: {result.ticker}" + ) + continue + + # Create a key for cross-exchange deduplication + # Group by symbol and country to identify potential duplicates + dedup_key = (symbol.upper(), result.country) + + # Check if we've seen this symbol in the same country before + if dedup_key in symbol_map: + existing_result = symbol_map[dedup_key] + existing_exchange = existing_result.ticker.split(":")[0] + + # Compare exchange priorities + current_priority = exchange_priority.get(exchange, 0) + existing_priority = exchange_priority.get(existing_exchange, 0) + + if current_priority > existing_priority: + # Replace with higher priority exchange + symbol_map[dedup_key] = result + logger.debug( + f"Preferring {result.ticker} over {existing_result.ticker} (priority)" + ) + elif current_priority == existing_priority: + # Same priority, prefer the one with higher relevance score + if result.relevance_score > existing_result.relevance_score: + symbol_map[dedup_key] = result + logger.debug( + f"Preferring {result.ticker} over {existing_result.ticker} (relevance)" + ) + # else: keep existing result (lower priority exchange) + else: + # First time seeing this symbol, add it + symbol_map[dedup_key] = result + + seen_tickers.add(result.ticker) + + # Convert map back to list + unique_results = list(symbol_map.values()) + + # Sort by relevance score (descending) + unique_results.sort(key=lambda x: x.relevance_score, reverse=True) + + logger.info( + f"Deduplicated {len(results)} results to {len(unique_results)} unique assets" + ) + + return unique_results + def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: """Search for assets across all available adapters. @@ -174,13 +305,9 @@ def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: # Determine which adapters to use based on asset types target_adapters = set() - if query.asset_types: - for asset_type in query.asset_types: - target_adapters.update(self.get_adapters_for_asset_type(asset_type)) - else: - # Use all available adapters - with self.lock: - target_adapters.update(self.adapters.values()) + # Use all available adapters + with self.lock: + target_adapters.update(self.adapters.values()) # Search in parallel across adapters if not target_adapters: @@ -195,29 +322,171 @@ def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: for future in as_completed(future_to_adapter): adapter = future_to_adapter[future] try: - results = future.result(timeout=30) # 30 second timeout + results = future.result(timeout=15) all_results.extend(results) except Exception as e: logger.warning( f"Search failed for adapter {adapter.source.value}: {e}" ) - # Deduplicate results by ticker - seen_tickers = set() - unique_results = [] - - # Sort by relevance score first - all_results.sort(key=lambda x: x.relevance_score, reverse=True) + # Smart deduplication of results + unique_results = self._deduplicate_search_results(all_results) - for result in all_results: - if result.ticker not in seen_tickers: - seen_tickers.add(result.ticker) - unique_results.append(result) + # Use fallback search if no results found + if len(unique_results) == 0: + logger.info( + f"No results from adapters, trying fallback search for query: {query.query}" + ) + fallback_results = self._fallback_search_assets(query) + # Deduplicate fallback results with existing results + combined_results = unique_results + fallback_results + unique_results = self._deduplicate_search_results(combined_results) return unique_results[: query.limit] + def _fallback_search_assets( + self, query: AssetSearchQuery + ) -> List[AssetSearchResult]: + """Fallback search assets if no results are found using LLM-based ticker generation. + + This method uses an OpenAI-like API to intelligently generate possible ticker formats + based on the user's search query, then validates each generated ticker. + + Args: + query: Search query parameters + + Returns: + List of validated search results + """ + # Get environment variables + api_key = os.getenv("OPENROUTER_API_KEY") + model_id = os.getenv("PRODUCT_MODEL_ID") + + if not api_key or not model_id: + logger.warning( + "OPENROUTER_API_KEY or PRODUCT_MODEL_ID not configured, skipping fallback search" + ) + return [] + + try: + # Initialize OpenAI client with OpenRouter + client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1") + + # Create prompt to generate possible ticker formats + prompt = f"""Given the user search query: "{query.query}" + +Generate a list of possible internal ticker IDs that match this query. The internal ticker format is: EXCHANGE:SYMBOL + +Supported exchanges and their formats: +- NASDAQ: NASDAQ:SYMBOL (e.g., NASDAQ:AAPL, NASDAQ:MSFT) +- NYSE: NYSE:SYMBOL (e.g., NYSE:JPM, NYSE:BAC) +- AMEX: AMEX:SYMBOL (e.g., AMEX:GORO, AMEX:GLD) +- SSE: SSE:SYMBOL (Shanghai Stock Exchange, 6-digit code, e.g., SSE:601398, SSE:510050) +- SZSE: SZSE:SYMBOL (Shenzhen Stock Exchange, 6-digit code, e.g., SZSE:000001, SZSE:002594, SZSE:300750) +- BSE: BSE:SYMBOL (Beijing Stock Exchange, 6-digit code, e.g., BSE:835368, BSE:560800) +- HKEX: HKEX:SYMBOL (Hong Kong Stock Exchange, 5-digit code with leading zeros, e.g., HKEX:00700, HKEX:03033) +- CRYPTO: CRYPTO:SYMBOL (e.g., CRYPTO:BTC, CRYPTO:ETH) + +Consider: +1. Common stock symbols and company names +2. Chinese company names (if query contains Chinese characters) +3. Cryptocurrency names +4. Index names +5. ETF names + +Return ONLY a JSON array of ticker strings, like: +["NASDAQ:AAPL", "NYSE:AAPL", "HKEX:00700"] + +Generate up to at least 1 possible ticker candidate up to 10. Be creative but realistic.""" + + # Call LLM API + response = client.chat.completions.create( + model=model_id, + messages=[ + { + "role": "system", + "content": "You are a financial data expert that helps map search queries to standardized ticker formats. Always respond with valid JSON arrays only.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.7, + max_tokens=500, + ) + + # Parse response + response_text = response.choices[0].message.content.strip() + logger.debug(f"LLM response for query '{query.query}': {response_text}") + + # Extract JSON array from response (handle cases where LLM adds markdown formatting) + if response_text.startswith("```json"): + response_text = ( + response_text.split("```json")[1].split("```")[0].strip() + ) + elif response_text.startswith("```"): + response_text = response_text.split("```")[1].split("```")[0].strip() + + possible_tickers = json.loads(response_text) + + if not isinstance(possible_tickers, list): + logger.warning(f"LLM response is not a list: {possible_tickers}") + return [] + + # Validate each ticker and convert to search results + results = [] + seen_tickers = set() + + for ticker in possible_tickers: + if not isinstance(ticker, str): + continue + + ticker = ticker.strip().upper() + + # Skip duplicates + if ticker in seen_tickers: + continue + + # Validate ticker format + if ":" not in ticker: + continue + + # Try to get asset info + try: + asset_info = self.get_asset_info(ticker) + + if asset_info: + seen_tickers.add(ticker) + + # Convert Asset to AssetSearchResult + search_result = AssetSearchResult( + ticker=asset_info.ticker, + asset_type=asset_info.asset_type, + names=asset_info.names.names, + exchange=asset_info.market_info.exchange, + country=asset_info.market_info.country, + ) + results.append(search_result) + + logger.info(f"Fallback search found valid asset: {ticker}") + + # Stop if we have enough results + if len(results) >= query.limit: + break + + except Exception as e: + logger.debug(f"Ticker {ticker} validation failed: {e}") + continue + + logger.info( + f"Fallback search returned {len(results)} results for query '{query.query}'" + ) + return results + + except Exception as e: + logger.error(f"Fallback search failed: {e}", exc_info=True) + return [] + def get_asset_info(self, ticker: str) -> Optional[Asset]: - """Get detailed asset information. + """Get detailed asset information with automatic failover. Args: ticker: Asset ticker in internal format @@ -225,19 +494,69 @@ def get_asset_info(self, ticker: str) -> Optional[Asset]: Returns: Asset information or None if not found """ + # Get the primary adapter for this ticker adapter = self.get_adapter_for_ticker(ticker) + if not adapter: logger.warning(f"No suitable adapter found for ticker: {ticker}") return None + # Try the primary adapter try: - return adapter.get_asset_info(ticker) + logger.debug( + f"Fetching asset info for {ticker} from {adapter.source.value}" + ) + asset_info = adapter.get_asset_info(ticker) + if asset_info: + logger.info( + f"Successfully fetched asset info for {ticker} from {adapter.source.value}" + ) + return asset_info + else: + logger.debug( + f"Adapter {adapter.source.value} returned None for {ticker}" + ) except Exception as e: - logger.error(f"Error fetching asset info for {ticker}: {e}") - return None + logger.warning( + f"Primary adapter {adapter.source.value} failed for {ticker}: {e}" + ) + + # Automatic failover: try other adapters for this exchange + exchange = ticker.split(":")[0] if ":" in ticker else "" + fallback_adapters = self.get_adapters_for_exchange(exchange) + + for fallback_adapter in fallback_adapters: + # Skip the primary adapter we already tried + if fallback_adapter.source == adapter.source: + continue + + if not fallback_adapter.validate_ticker(ticker): + continue + + try: + logger.debug( + f"Fallback: trying {fallback_adapter.source.value} for {ticker}" + ) + asset_info = fallback_adapter.get_asset_info(ticker) + if asset_info: + logger.info( + f"Fallback success: fetched asset info for {ticker} from {fallback_adapter.source.value}" + ) + # Update cache to use successful adapter + with self._cache_lock: + self._ticker_cache[ticker] = fallback_adapter + return asset_info + except Exception as e: + logger.warning( + f"Fallback adapter {fallback_adapter.source.value} failed for {ticker}: {e}" + ) + continue + + logger.error(f"All adapters failed for {ticker}") + return None def get_real_time_price(self, ticker: str) -> Optional[AssetPrice]: - """Get real-time price for an asset. + """Get real-time price for an asset with automatic failover. Args: ticker: Asset ticker in internal format @@ -245,21 +564,69 @@ def get_real_time_price(self, ticker: str) -> Optional[AssetPrice]: Returns: Current price data or None if not available """ + # Get the primary adapter for this ticker adapter = self.get_adapter_for_ticker(ticker) + if not adapter: logger.warning(f"No suitable adapter found for ticker: {ticker}") return None + # Try the primary adapter try: - return adapter.get_real_time_price(ticker) + logger.debug(f"Fetching price for {ticker} from {adapter.source.value}") + price = adapter.get_real_time_price(ticker) + if price: + logger.info( + f"Successfully fetched price for {ticker} from {adapter.source.value}" + ) + return price + else: + logger.debug( + f"Adapter {adapter.source.value} returned None for {ticker}" + ) except Exception as e: - logger.error(f"Error fetching real-time price for {ticker}: {e}") - return None + logger.warning( + f"Primary adapter {adapter.source.value} failed for {ticker}: {e}" + ) + + # Automatic failover: try other adapters for this exchange + exchange = ticker.split(":")[0] if ":" in ticker else "" + fallback_adapters = self.get_adapters_for_exchange(exchange) + + for fallback_adapter in fallback_adapters: + # Skip the primary adapter we already tried + if fallback_adapter.source == adapter.source: + continue + + if not fallback_adapter.validate_ticker(ticker): + continue + + try: + logger.debug( + f"Fallback: trying {fallback_adapter.source.value} for {ticker}" + ) + price = fallback_adapter.get_real_time_price(ticker) + if price: + logger.info( + f"Fallback success: fetched price for {ticker} from {fallback_adapter.source.value}" + ) + # Update cache to use successful adapter + with self._cache_lock: + self._ticker_cache[ticker] = fallback_adapter + return price + except Exception as e: + logger.warning( + f"Fallback adapter {fallback_adapter.source.value} failed for {ticker}: {e}" + ) + continue + + logger.error(f"All adapters failed for {ticker}") + return None def get_multiple_prices( self, tickers: List[str] ) -> Dict[str, Optional[AssetPrice]]: - """Get real-time prices for multiple assets efficiently. + """Get real-time prices for multiple assets efficiently with automatic failover. Args: tickers: List of asset tickers @@ -279,6 +646,7 @@ def get_multiple_prices( # Fetch prices in parallel from each adapter all_results = {} + failed_tickers = [] if not adapter_tickers: # If no adapters found for any tickers, return None for all @@ -294,11 +662,29 @@ def get_multiple_prices( adapter = future_to_adapter[future] try: results = future.result(timeout=60) # 60 second timeout - all_results.update(results) + # Separate successful and failed results + for ticker, price in results.items(): + if price is not None: + all_results[ticker] = price + else: + failed_tickers.append(ticker) except Exception as e: logger.warning( f"Batch price fetch failed for adapter {adapter.source.value}: {e}" ) + # Mark all tickers from this adapter as failed + failed_tickers.extend(adapter_tickers[adapter]) + + # Retry failed tickers individually with fallback adapters + if failed_tickers: + logger.info( + f"Retrying {len(failed_tickers)} failed tickers with fallback adapters" + ) + for ticker in failed_tickers: + if ticker not in all_results or all_results[ticker] is None: + # Try to get price with automatic failover + price = self.get_real_time_price(ticker) + all_results[ticker] = price # Ensure all requested tickers are in results for ticker in tickers: @@ -314,7 +700,7 @@ def get_historical_prices( end_date: datetime, interval: str = "1d", ) -> List[AssetPrice]: - """Get historical price data for an asset. + """Get historical price data for an asset with automatic failover. Args: ticker: Asset ticker in internal format @@ -325,83 +711,70 @@ def get_historical_prices( Returns: List of historical price data """ + # Get the primary adapter for this ticker adapter = self.get_adapter_for_ticker(ticker) + if not adapter: logger.warning(f"No suitable adapter found for ticker: {ticker}") return [] + # Try the primary adapter try: - return adapter.get_historical_prices(ticker, start_date, end_date, interval) + logger.debug( + f"Fetching historical data for {ticker} from {adapter.source.value}" + ) + prices = adapter.get_historical_prices( + ticker, start_date, end_date, interval + ) + if prices: + logger.info( + f"Successfully fetched {len(prices)} historical prices for {ticker} from {adapter.source.value}" + ) + return prices + else: + logger.debug( + f"Adapter {adapter.source.value} returned empty historical data for {ticker}" + ) except Exception as e: - logger.error(f"Error fetching historical prices for {ticker}: {e}") - return [] - - def health_check(self) -> Dict[DataSource, Dict[str, Any]]: - """Perform health check on all registered adapters. - - Returns: - Dictionary mapping data sources to health status - """ - health_results = {} - - # If no adapters are registered, return empty results - if not self.adapters: - return health_results - - with ThreadPoolExecutor(max_workers=len(self.adapters)) as executor: - future_to_source = { - executor.submit(adapter.health_check): source - for source, adapter in self.adapters.items() - } - - for future in as_completed(future_to_source): - source = future_to_source[future] - try: - result = future.result(timeout=30) - health_results[source] = result - except Exception as e: - health_results[source] = { - "status": "error", - "message": f"Health check failed: {e}", - "timestamp": datetime.utcnow().isoformat(), - } + logger.warning( + f"Primary adapter {adapter.source.value} failed for historical data of {ticker}: {e}" + ) - return health_results + # Automatic failover: try other adapters for this exchange + exchange = ticker.split(":")[0] if ":" in ticker else "" + fallback_adapters = self.get_adapters_for_exchange(exchange) - def get_supported_asset_types(self) -> Dict[DataSource, List[AssetType]]: - """Get supported asset types for each adapter. + for fallback_adapter in fallback_adapters: + # Skip the primary adapter we already tried + if fallback_adapter.source == adapter.source: + continue - Returns: - Dictionary mapping data sources to supported asset types - """ - supported_types = {} + if not fallback_adapter.validate_ticker(ticker): + continue - with self.lock: - for source, adapter in self.adapters.items(): - try: - supported_types[source] = adapter.get_supported_asset_types() - except Exception as e: - logger.warning( - f"Error getting supported types for {source.value}: {e}" + try: + logger.debug( + f"Fallback: trying {fallback_adapter.source.value} for historical data of {ticker}" + ) + prices = fallback_adapter.get_historical_prices( + ticker, start_date, end_date, interval + ) + if prices: + logger.info( + f"Fallback success: fetched {len(prices)} historical prices for {ticker} from {fallback_adapter.source.value}" ) - supported_types[source] = [] - - return supported_types - - def set_adapter_priority( - self, asset_type: AssetType, sources: List[DataSource] - ) -> None: - """Set adapter priority for an asset type. + # Update cache to use successful adapter + with self._cache_lock: + self._ticker_cache[ticker] = fallback_adapter + return prices + except Exception as e: + logger.warning( + f"Fallback adapter {fallback_adapter.source.value} failed for historical data of {ticker}: {e}" + ) + continue - Args: - asset_type: Asset type to configure - sources: List of data sources in priority order - """ - with self.lock: - self.adapter_priorities[asset_type] = sources - logger.info( - f"Updated adapter priority for {asset_type.value}: {[s.value for s in sources]}" - ) + logger.error(f"All adapters failed for historical data of {ticker}") + return [] class WatchlistManager: diff --git a/python/valuecell/adapters/assets/tests/test_adapters_comparison.py b/python/valuecell/adapters/assets/tests/test_adapters_comparison.py new file mode 100644 index 000000000..ff4cc3da4 --- /dev/null +++ b/python/valuecell/adapters/assets/tests/test_adapters_comparison.py @@ -0,0 +1,409 @@ +"""Test script to compare YFinance and AKShare adapters functionality.""" + +import logging +from datetime import datetime, timedelta +from typing import Dict, Optional + +from valuecell.adapters.assets.yfinance_adapter import YFinanceAdapter +from valuecell.adapters.assets.akshare_adapter import AKShareAdapter +from valuecell.adapters.assets.types import Asset, AssetPrice, Interval + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# Test data organized by asset type +TEST_TICKERS = { + "STOCK": [ + "NASDAQ:AAPL", + "AMEX:GORO", + "NYSE:JPM", + "HKEX:00700", + "SSE:601398", + "SZSE:002594", + "SZSE:300750", + "BSE:835368", + "CRYPTO:BTC", + ], + "ETF": [ + "NASDAQ:QQQ", + "AMEX:GLD", + "NYSE:SPY", + "HKEX:03033", + "SSE:510050", + "SZSE:159919", + "BSE:560800", + ], + "INDEX": [ + "NASDAQ:IXIC", + "AMEX:RUT", + "NYSE:DJI", + "HKEX:HSI", + "SSE:000001", + "SZSE:399001", + "BSE:899050", + ], +} + + +class AdapterTestResult: + """Store test results for a single adapter and ticker.""" + + def __init__(self, ticker: str, adapter_name: str): + self.ticker = ticker + self.adapter_name = adapter_name + self.asset_info_success = False + self.asset_info_data: Optional[Asset] = None + self.asset_info_error: Optional[str] = None + + self.real_time_price_success = False + self.real_time_price_data: Optional[AssetPrice] = None + self.real_time_price_error: Optional[str] = None + + self.historical_prices_success = False + self.historical_prices_count = 0 + self.historical_prices_error: Optional[str] = None + + +class AdapterTester: + """Test adapter functionality and generate comparison reports.""" + + def __init__(self): + self.yfinance_adapter = YFinanceAdapter() + self.akshare_adapter = AKShareAdapter() + self.results: Dict[str, Dict[str, AdapterTestResult]] = {} + + def test_get_asset_info( + self, adapter, ticker: str + ) -> tuple[bool, Optional[Asset], Optional[str]]: + """Test get_asset_info function.""" + try: + asset = adapter.get_asset_info(ticker) + if asset: + return True, asset, None + else: + return False, None, "No data returned" + except Exception as e: + return False, None, str(e) + + def test_get_real_time_price( + self, adapter, ticker: str + ) -> tuple[bool, Optional[AssetPrice], Optional[str]]: + """Test get_real_time_price function.""" + try: + price = adapter.get_real_time_price(ticker) + if price: + return True, price, None + else: + return False, None, "No data returned" + except Exception as e: + return False, None, str(e) + + def test_get_historical_prices( + self, adapter, ticker: str + ) -> tuple[bool, int, Optional[str]]: + """Test get_historical_prices function.""" + try: + end_date = datetime.now() + start_date = end_date - timedelta(days=30) + # Use proper interval format: "1" + Interval.DAY + interval = f"1{Interval.DAY}" + prices = adapter.get_historical_prices( + ticker, start_date, end_date, interval=interval + ) + if prices: + return True, len(prices), None + else: + return False, 0, "No data returned" + except Exception as e: + return False, 0, str(e) + + def test_ticker(self, ticker: str) -> None: + """Test a single ticker with both adapters.""" + logger.info(f"\n{'=' * 80}") + logger.info(f"Testing ticker: {ticker}") + logger.info(f"{'=' * 80}") + + # Initialize results storage + if ticker not in self.results: + self.results[ticker] = {} + + # Test YFinance adapter + logger.info("\n--- Testing YFinance Adapter ---") + yf_result = AdapterTestResult(ticker, "YFinance") + + # Test asset info + logger.info("Testing get_asset_info...") + success, data, error = self.test_get_asset_info(self.yfinance_adapter, ticker) + yf_result.asset_info_success = success + yf_result.asset_info_data = data + yf_result.asset_info_error = error + logger.info(f"Result: {'✓ Success' if success else f'✗ Failed: {error}'}") + + # Test real-time price + logger.info("Testing get_real_time_price...") + success, data, error = self.test_get_real_time_price( + self.yfinance_adapter, ticker + ) + yf_result.real_time_price_success = success + yf_result.real_time_price_data = data + yf_result.real_time_price_error = error + logger.info(f"Result: {'✓ Success' if success else f'✗ Failed: {error}'}") + + # Test historical prices + logger.info("Testing get_historical_prices...") + success, count, error = self.test_get_historical_prices( + self.yfinance_adapter, ticker + ) + yf_result.historical_prices_success = success + yf_result.historical_prices_count = count + yf_result.historical_prices_error = error + logger.info( + f"Result: {'✓ Success' if success else f'✗ Failed: {error}'} (Count: {count})" + ) + + self.results[ticker]["yfinance"] = yf_result + + # Test AKShare adapter + logger.info("\n--- Testing AKShare Adapter ---") + ak_result = AdapterTestResult(ticker, "AKShare") + + # Test asset info + logger.info("Testing get_asset_info...") + success, data, error = self.test_get_asset_info(self.akshare_adapter, ticker) + ak_result.asset_info_success = success + ak_result.asset_info_data = data + ak_result.asset_info_error = error + logger.info(f"Result: {'✓ Success' if success else f'✗ Failed: {error}'}") + + # Test real-time price + logger.info("Testing get_real_time_price...") + success, data, error = self.test_get_real_time_price( + self.akshare_adapter, ticker + ) + ak_result.real_time_price_success = success + ak_result.real_time_price_data = data + ak_result.real_time_price_error = error + logger.info(f"Result: {'✓ Success' if success else f'✗ Failed: {error}'}") + + # Test historical prices + logger.info("Testing get_historical_prices...") + success, count, error = self.test_get_historical_prices( + self.akshare_adapter, ticker + ) + ak_result.historical_prices_success = success + ak_result.historical_prices_count = count + ak_result.historical_prices_error = error + logger.info( + f"Result: {'✓ Success' if success else f'✗ Failed: {error}'} (Count: {count})" + ) + + self.results[ticker]["akshare"] = ak_result + + def run_all_tests(self) -> None: + """Run tests for all tickers.""" + for asset_type, tickers in TEST_TICKERS.items(): + logger.info(f"\n\n{'#' * 80}") + logger.info(f"# Testing {asset_type}") + logger.info(f"{'#' * 80}") + + for ticker in tickers: + try: + self.test_ticker(ticker) + except Exception as e: + logger.error(f"Error testing ticker {ticker}: {e}", exc_info=True) + + def generate_report(self) -> str: + """Generate a comprehensive comparison report.""" + report_lines = [] + report_lines.append("=" * 120) + report_lines.append("ADAPTER COMPARISON REPORT: YFinance vs AKShare") + report_lines.append("=" * 120) + report_lines.append("") + + # Summary statistics + yf_total_success = { + "asset_info": 0, + "real_time_price": 0, + "historical_prices": 0, + } + ak_total_success = { + "asset_info": 0, + "real_time_price": 0, + "historical_prices": 0, + } + total_tests = len(self.results) + + for asset_type, tickers in TEST_TICKERS.items(): + report_lines.append(f"\n{'=' * 120}") + report_lines.append(f"ASSET TYPE: {asset_type}") + report_lines.append(f"{'=' * 120}\n") + + for ticker in tickers: + if ticker not in self.results: + continue + + yf_result = self.results[ticker].get("yfinance") + ak_result = self.results[ticker].get("akshare") + + if not yf_result or not ak_result: + continue + + report_lines.append(f"\nTicker: {ticker}") + report_lines.append("-" * 120) + + # Asset Info comparison + report_lines.append("\n1. GET_ASSET_INFO:") + report_lines.append( + f" YFinance: {'✓ SUCCESS' if yf_result.asset_info_success else f'✗ FAILED - {yf_result.asset_info_error}'}" + ) + if yf_result.asset_info_data: + report_lines.append( + f" - Names: {yf_result.asset_info_data.names.get_name('en-US')}" + ) + report_lines.append( + f" - Exchange: {yf_result.asset_info_data.market_info.exchange}" + ) + report_lines.append( + f" - Currency: {yf_result.asset_info_data.market_info.currency}" + ) + + report_lines.append( + f" AKShare: {'✓ SUCCESS' if ak_result.asset_info_success else f'✗ FAILED - {ak_result.asset_info_error}'}" + ) + if ak_result.asset_info_data: + report_lines.append( + f" - Names: {ak_result.asset_info_data.names.get_name('en-US')}" + ) + report_lines.append( + f" - Exchange: {ak_result.asset_info_data.market_info.exchange}" + ) + report_lines.append( + f" - Currency: {ak_result.asset_info_data.market_info.currency}" + ) + + # Real-time price comparison + report_lines.append("\n2. GET_REAL_TIME_PRICE:") + report_lines.append( + f" YFinance: {'✓ SUCCESS' if yf_result.real_time_price_success else f'✗ FAILED - {yf_result.real_time_price_error}'}" + ) + if yf_result.real_time_price_data: + report_lines.append( + f" - Price: {yf_result.real_time_price_data.price} {yf_result.real_time_price_data.currency}" + ) + report_lines.append( + f" - Timestamp: {yf_result.real_time_price_data.timestamp}" + ) + report_lines.append( + f" - Change: {yf_result.real_time_price_data.change} ({yf_result.real_time_price_data.change_percent}%)" + ) + + report_lines.append( + f" AKShare: {'✓ SUCCESS' if ak_result.real_time_price_success else f'✗ FAILED - {ak_result.real_time_price_error}'}" + ) + if ak_result.real_time_price_data: + report_lines.append( + f" - Price: {ak_result.real_time_price_data.price} {ak_result.real_time_price_data.currency}" + ) + report_lines.append( + f" - Timestamp: {ak_result.real_time_price_data.timestamp}" + ) + report_lines.append( + f" - Change: {ak_result.real_time_price_data.change} ({ak_result.real_time_price_data.change_percent}%)" + ) + + # Historical prices comparison + report_lines.append("\n3. GET_HISTORICAL_PRICES (Last 30 days):") + report_lines.append( + f" YFinance: {'✓ SUCCESS' if yf_result.historical_prices_success else f'✗ FAILED - {yf_result.historical_prices_error}'}" + ) + report_lines.append( + f" - Data Points: {yf_result.historical_prices_count}" + ) + + report_lines.append( + f" AKShare: {'✓ SUCCESS' if ak_result.historical_prices_success else f'✗ FAILED - {ak_result.historical_prices_error}'}" + ) + report_lines.append( + f" - Data Points: {ak_result.historical_prices_count}" + ) + + # Update statistics + if yf_result.asset_info_success: + yf_total_success["asset_info"] += 1 + if yf_result.real_time_price_success: + yf_total_success["real_time_price"] += 1 + if yf_result.historical_prices_success: + yf_total_success["historical_prices"] += 1 + + if ak_result.asset_info_success: + ak_total_success["asset_info"] += 1 + if ak_result.real_time_price_success: + ak_total_success["real_time_price"] += 1 + if ak_result.historical_prices_success: + ak_total_success["historical_prices"] += 1 + + # Summary section + report_lines.append(f"\n\n{'=' * 120}") + report_lines.append("SUMMARY STATISTICS") + report_lines.append(f"{'=' * 120}\n") + report_lines.append(f"Total Tickers Tested: {total_tests}\n") + + report_lines.append("YFinance Adapter Success Rate:") + report_lines.append( + f" - get_asset_info: {yf_total_success['asset_info']}/{total_tests} ({yf_total_success['asset_info'] / total_tests * 100:.1f}%)" + ) + report_lines.append( + f" - get_real_time_price: {yf_total_success['real_time_price']}/{total_tests} ({yf_total_success['real_time_price'] / total_tests * 100:.1f}%)" + ) + report_lines.append( + f" - get_historical_prices: {yf_total_success['historical_prices']}/{total_tests} ({yf_total_success['historical_prices'] / total_tests * 100:.1f}%)" + ) + + report_lines.append("\nAKShare Adapter Success Rate:") + report_lines.append( + f" - get_asset_info: {ak_total_success['asset_info']}/{total_tests} ({ak_total_success['asset_info'] / total_tests * 100:.1f}%)" + ) + report_lines.append( + f" - get_real_time_price: {ak_total_success['real_time_price']}/{total_tests} ({ak_total_success['real_time_price'] / total_tests * 100:.1f}%)" + ) + report_lines.append( + f" - get_historical_prices: {ak_total_success['historical_prices']}/{total_tests} ({ak_total_success['historical_prices'] / total_tests * 100:.1f}%)" + ) + + report_lines.append(f"\n{'=' * 120}") + + return "\n".join(report_lines) + + def save_report(self, filename: str = "adapter_comparison_report.txt") -> None: + """Save the report to a file.""" + report = self.generate_report() + with open(filename, "w", encoding="utf-8") as f: + f.write(report) + logger.info(f"\nReport saved to: {filename}") + + +def main(): + """Main test execution.""" + logger.info("Starting adapter comparison tests...") + + tester = AdapterTester() + + # Run all tests + tester.run_all_tests() + + # Generate and save report + report = tester.generate_report() + print("\n\n") + print(report) + + # Save to file + tester.save_report("adapter_comparison_report.txt") + + logger.info("\nAll tests completed!") + + +if __name__ == "__main__": + main() diff --git a/python/valuecell/adapters/assets/types.py b/python/valuecell/adapters/assets/types.py index f51486821..1f9d3eeb7 100644 --- a/python/valuecell/adapters/assets/types.py +++ b/python/valuecell/adapters/assets/types.py @@ -31,6 +31,19 @@ class AssetType(str, Enum): # FUTURE = "future" +class Exchange(str, Enum): + """Enumeration of supported exchanges.""" + + NASDAQ = "NASDAQ" # NASDAQ Market in the US + NYSE = "NYSE" # NYSE Market in the US + AMEX = "AMEX" # AMEX Market in the US + SSE = "SSE" # Shanghai Stock Exchange + SZSE = "SZSE" # Shenzhen Stock Exchange + BSE = "BSE" # Beijing Stock Exchange + HKEX = "HKEX" # Hong Kong Stock Exchange + CRYPTO = "CRYPTO" # Crypto Market + + class MarketStatus(str, Enum): """Market status enumeration.""" @@ -340,9 +353,15 @@ class AssetSearchResult(BaseModel): names: Dict[str, str] = Field(..., description="Asset names in different languages") exchange: str = Field(..., description="Exchange name") country: str = Field(..., description="Country code") - currency: str = Field(..., description="Currency code") - market_status: MarketStatus = Field(default=MarketStatus.UNKNOWN) - relevance_score: float = Field(default=0.0, description="Search relevance score") + + # Optional fields for enhanced search results + currency: Optional[str] = Field(default=None, description="Currency code") + market_status: Optional[MarketStatus] = Field( + default=None, description="Market status" + ) + relevance_score: float = Field( + default=0.5, description="Search relevance score (0-1)" + ) def get_display_name(self, language: str = "en-US") -> str: """Get display name for specified language.""" @@ -353,13 +372,7 @@ class AssetSearchQuery(BaseModel): """Asset search query parameters.""" query: str = Field(..., description="Search query string") - asset_types: Optional[List[AssetType]] = Field( - None, description="Filter by asset types" - ) - exchanges: Optional[List[str]] = Field(None, description="Filter by exchanges") - countries: Optional[List[str]] = Field(None, description="Filter by countries") - limit: int = Field(default=50, description="Maximum number of results") - language: str = Field(default="en-US", description="Preferred language for results") + limit: int = Field(default=10, description="Maximum number of results") @validator("limit") def validate_limit(cls, v): diff --git a/python/valuecell/adapters/assets/yfinance_adapter.py b/python/valuecell/adapters/assets/yfinance_adapter.py index 9a982fcde..f0f2952b6 100644 --- a/python/valuecell/adapters/assets/yfinance_adapter.py +++ b/python/valuecell/adapters/assets/yfinance_adapter.py @@ -7,14 +7,11 @@ import logging from datetime import datetime from decimal import Decimal -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional -try: - import yfinance as yf -except ImportError: - yf = None +import yfinance as yf -from .base import BaseDataAdapter +from .base import AdapterCapability, BaseDataAdapter from .types import ( Asset, AssetPrice, @@ -22,6 +19,7 @@ AssetSearchResult, AssetType, DataSource, + Exchange, Interval, LocalizedName, MarketInfo, @@ -45,18 +43,36 @@ def __init__(self, **kwargs): def _initialize(self) -> None: """Initialize Yahoo Finance adapter configuration.""" - self.session = None # yfinance handles sessions internally self.timeout = self.config.get("timeout", 30) # Asset type mapping for Yahoo Finance - self.asset_type_mapping = { + self.quote_type_to_asset_type_mapping = { "EQUITY": AssetType.STOCK, "ETF": AssetType.ETF, "INDEX": AssetType.INDEX, "CRYPTOCURRENCY": AssetType.CRYPTO, - # Additional mappings for search results - "STOCK": AssetType.STOCK, - "FUND": AssetType.ETF, + } + + # Map yfinance exchanges to our internal exchanges + self.exchange_mapping = { + "NMS": Exchange.NASDAQ, + "NYQ": Exchange.NYSE, + "ASE": Exchange.AMEX, + "SHH": Exchange.SSE, + "SHZ": Exchange.SZSE, + "HKG": Exchange.HKEX, + "PCX": Exchange.NYSE, + "CCC": Exchange.CRYPTO, + } + + self.yfinance_exchange_suffix_mapping = { + Exchange.NASDAQ.value: "", + Exchange.NYSE.value: "", + Exchange.AMEX.value: "", + Exchange.SSE.value: ".SS", + Exchange.SZSE.value: ".SZ", + Exchange.HKEX.value: ".HK", + Exchange.CRYPTO.value: "-USD", } logger.info("Yahoo Finance adapter initialized") @@ -66,6 +82,8 @@ def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: Uses yfinance.Search for better search results across stocks, ETFs, and other assets. Falls back to direct ticker lookup for specific symbols. + + This method """ results = [] search_term = query.query.strip() @@ -78,13 +96,9 @@ def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: search_quotes = getattr(search_obj, "quotes", []) # Process search results - for quote in search_quotes[ - : query.limit * 2 - ]: # Get more results to filter later + for quote in search_quotes: try: - result = self._create_search_result_from_quote( - quote, query.language - ) + result = self._create_search_result_from_quote(quote) if result: results.append(result) except Exception as e: @@ -92,34 +106,12 @@ def search_assets(self, query: AssetSearchQuery) -> List[AssetSearchResult]: continue except Exception as e: - logger.debug(f"yfinance Search API failed for '{search_term}': {e}") - - # Fallback to direct ticker lookup - results.extend(self._fallback_ticker_search(search_term, query)) - - # If no results from search, try direct ticker lookup as final fallback - if not results: - results.extend(self._fallback_ticker_search(search_term.upper(), query)) - - # Filter by asset types if specified - if query.asset_types: - results = [r for r in results if r.asset_type in query.asset_types] - - # Filter by exchanges if specified - if query.exchanges: - results = [r for r in results if r.exchange in query.exchanges] - - # Filter by countries if specified - if query.countries: - results = [r for r in results if r.country in query.countries] - - # Sort by relevance score (highest first) - results.sort(key=lambda x: x.relevance_score, reverse=True) + logger.error(f"yfinance Search API failed for '{search_term}': {e}") return results[: query.limit] def _create_search_result_from_quote( - self, quote: Dict, language: str + self, quote: Dict ) -> Optional[AssetSearchResult]: """Create search result from Yahoo Finance search quote.""" try: @@ -128,28 +120,14 @@ def _create_search_result_from_quote( return None # Get exchange information first - exchange = quote.get("exchange", "UNKNOWN") - - # Map yfinance exchange codes to our internal format - exchange_mapping = { - "NMS": "NASDAQ", - "NYQ": "NYSE", - "ASE": "AMEX", - "SHH": "SSE", - "SHZ": "SZSE", - "HKG": "HKEX", - "TYO": "TSE", - "LSE": "LSE", - "PAR": "EURONEXT", - "FRA": "XETRA", - "PCX": "NYSE", # Pacific Exchange (for ETFs like SPY) - "CCC": "CRYPTO", # Crypto - } - mapped_exchange = exchange_mapping.get(exchange, exchange) + exchange = quote.get("exchange") + if not exchange: + return None + + mapped_exchange = self.exchange_mapping.get(exchange) # Filter: Only support specific exchanges - supported_exchanges = ["NASDAQ", "NYSE", "SSE", "SZSE", "HKEX", "CRYPTO"] - if mapped_exchange not in supported_exchanges: + if mapped_exchange not in self.exchange_mapping.values(): logger.debug( f"Skipping unsupported exchange: {mapped_exchange} for symbol {symbol}" ) @@ -157,7 +135,9 @@ def _create_search_result_from_quote( # Convert to internal ticker format and normalize # Remove any suffixes that yfinance might include - internal_ticker = self.convert_to_internal_ticker(symbol, mapped_exchange) + internal_ticker = self.convert_to_internal_ticker( + symbol, mapped_exchange.value + ) # Validate the ticker format if not self._is_valid_internal_ticker(internal_ticker): @@ -168,15 +148,17 @@ def _create_search_result_from_quote( # Get asset type from quote type quote_type = quote.get("quoteType", "").upper() - asset_type = self.asset_type_mapping.get(quote_type, AssetType.STOCK) + asset_type = self.quote_type_to_asset_type_mapping.get( + quote_type, AssetType.STOCK + ) # Get country information country = "US" # Default - if mapped_exchange in ["SSE", "SZSE"]: + if mapped_exchange in [Exchange.SSE, Exchange.SZSE]: country = "CN" - elif mapped_exchange == "HKEX": + elif mapped_exchange == Exchange.HKEX: country = "HK" - elif mapped_exchange == "CRYPTO": + elif mapped_exchange == Exchange.CRYPTO: country = "US" # Get names in different languages @@ -193,134 +175,48 @@ def _create_search_result_from_quote( quote, symbol, long_name or short_name ) - return AssetSearchResult( + # Create search result + search_result = AssetSearchResult( ticker=internal_ticker, asset_type=asset_type, names=names, - exchange=mapped_exchange, + exchange=mapped_exchange.value, country=country, currency=quote.get("currency", "USD"), market_status=MarketStatus.UNKNOWN, relevance_score=relevance_score, ) - except Exception as e: - logger.error(f"Error creating search result from quote: {e}") - return None - - def _fallback_ticker_search( - self, search_term: str, query: AssetSearchQuery - ) -> List[AssetSearchResult]: - """Fallback search using direct ticker lookup with common suffixes.""" - results = [] - - # Try direct ticker lookup first - try: - ticker_obj = yf.Ticker(search_term) - info = ticker_obj.info - - if info and "symbol" in info and info.get("symbol"): - result = self._create_search_result_from_info(info, query.language) - if result: - results.append(result) - except Exception as e: - logger.debug(f"Direct ticker lookup failed for {search_term}: {e}") - - # Try with common suffixes for international markets - if not results: - suffixes = [".SS", ".SZ", ".HK", ".T", ".L", ".PA", ".DE", ".TO", ".AX"] - for suffix in suffixes: - try: - test_ticker = f"{search_term}{suffix}" - ticker_obj = yf.Ticker(test_ticker) - info = ticker_obj.info - - if info and "symbol" in info and info.get("symbol"): - result = self._create_search_result_from_info( - info, query.language - ) - if result: - results.append(result) - break # Found one, stop searching - except Exception: - continue - - return results - - def _calculate_search_relevance(self, quote: Dict, symbol: str, name: str) -> float: - """Calculate relevance score for search results.""" - score = 0.0 - - # Base score for having a result - score += 0.5 - - # Higher score for exact symbol matches - if quote.get("symbol", "").upper() == symbol.upper(): - score += 0.3 - - # Score based on market cap (larger companies get higher scores) - market_cap = quote.get("marketCap") - if market_cap and isinstance(market_cap, (int, float)) and market_cap > 0: - # Normalize market cap to 0-0.2 range - score += min( - 0.2, market_cap / 1e12 - ) # Trillion dollar companies get max score - - # Bonus for having complete information - if quote.get("longname"): - score += 0.1 - if quote.get("currency"): - score += 0.05 - if quote.get("exchange"): - score += 0.05 - - return min(1.0, score) # Cap at 1.0 - - def _create_search_result_from_info( - self, info: Dict, language: str - ) -> Optional[AssetSearchResult]: - """Create search result from Yahoo Finance info dictionary.""" - try: - symbol = info.get("symbol", "") - if not symbol: - return None - - # Convert to internal ticker format - internal_ticker = self.convert_to_internal_ticker(symbol) - - # Get asset type - asset_type = self.asset_type_mapping.get( - info.get("quoteType", "").upper(), AssetType.STOCK - ) - - # Get exchange and country - exchange = info.get("exchange", "UNKNOWN") - country = info.get("country", "US") # Default to US - - # Get names in different languages - names = { - "en-US": info.get("longName", info.get("shortName", symbol)), - } + # Save asset metadata to database for future lookups + try: + from ...server.db.repositories.asset_repository import ( + get_asset_repository, + ) - # For Chinese markets, try to get Chinese name - if exchange in ["SSE", "SHE"] and language.startswith("zh"): - # This would require additional API calls or data sources - # For now, use English name as fallback - pass + asset_repo = get_asset_repository() + asset_repo.upsert_asset( + symbol=internal_ticker, + name=long_name or short_name, + asset_type=asset_type.value, + description=quote.get("longname"), + sector=quote.get("sector"), + asset_metadata={ + "currency": quote.get("currency", "USD"), + "exchange_code": exchange, # Original yfinance exchange code + "quote_type": quote_type, + }, + ) + logger.debug(f"Saved asset metadata for {internal_ticker}") + except Exception as e: + # Don't fail the search if database save fails + logger.warning( + f"Failed to save asset metadata for {internal_ticker}: {e}" + ) - return AssetSearchResult( - ticker=internal_ticker, - asset_type=asset_type, - names=names, - exchange=exchange, - country=country, - currency=info.get("currency", "USD"), - market_status=MarketStatus.UNKNOWN, # Would need real-time data - relevance_score=1.0, # Simple relevance scoring - ) + return search_result except Exception as e: - logger.error(f"Error creating search result: {e}") + logger.error(f"Error creating search result from quote: {e}") return None def get_asset_info(self, ticker: str) -> Optional[Asset]: @@ -338,16 +234,19 @@ def get_asset_info(self, ticker: str) -> Optional[Asset]: long_name = info.get("longName", info.get("shortName", ticker)) names.set_name("en-US", long_name) + if info.get("exchange"): + exchange = self.exchange_mapping.get(info.get("exchange")) + # Create market info market_info = MarketInfo( - exchange=info.get("exchange", "UNKNOWN"), + exchange=exchange.value if exchange else "UNKNOWN", country=info.get("country", "US"), currency=info.get("currency", "USD"), timezone=info.get("exchangeTimezoneName", "America/New_York"), ) # Determine asset type - asset_type = self.asset_type_mapping.get( + asset_type = self.quote_type_to_asset_type_mapping.get( info.get("quoteType", "").upper(), AssetType.STOCK ) @@ -378,6 +277,32 @@ def get_asset_info(self, ticker: str) -> Optional[Asset]: properties = {k: v for k, v in properties.items() if v is not None} asset.properties.update(properties) + # Save asset metadata to database + try: + from ...server.db.repositories.asset_repository import ( + get_asset_repository, + ) + + asset_repo = get_asset_repository() + asset_repo.upsert_asset( + symbol=ticker, + name=long_name, + asset_type=asset_type.value, + description=info.get("longBusinessSummary"), + sector=info.get("sector"), + asset_metadata={ + "currency": info.get("currency", "USD"), + "exchange_code": info.get("exchange"), + "quote_type": info.get("quoteType"), + "industry": info.get("industry"), + "market_cap": info.get("marketCap"), + }, + ) + logger.debug(f"Saved asset info for {ticker}") + except Exception as e: + # Don't fail the info fetch if database save fails + logger.warning(f"Failed to save asset info for {ticker}: {e}") + return asset except Exception as e: @@ -631,6 +556,53 @@ def safe_decimal(value, default=None): # Fallback to individual requests return super().get_multiple_prices(tickers) + def get_capabilities(self) -> List[AdapterCapability]: + """Get detailed capabilities of Yahoo Finance adapter. + + Yahoo Finance supports major US, Hong Kong, and Chinese exchanges. + + Returns: + List of capabilities describing supported asset types and exchanges + """ + return [ + AdapterCapability( + asset_type=AssetType.STOCK, + exchanges={ + Exchange.NASDAQ, + Exchange.NYSE, + Exchange.AMEX, + Exchange.SSE, + Exchange.SZSE, + Exchange.HKEX, + }, + ), + AdapterCapability( + asset_type=AssetType.ETF, + exchanges={ + Exchange.NASDAQ, + Exchange.NYSE, + Exchange.AMEX, + Exchange.SSE, + Exchange.SZSE, + Exchange.HKEX, + }, + ), + AdapterCapability( + asset_type=AssetType.INDEX, + exchanges={ + Exchange.NASDAQ, + Exchange.NYSE, + Exchange.SSE, + Exchange.SZSE, + Exchange.HKEX, + }, + ), + AdapterCapability( + asset_type=AssetType.CRYPTO, + exchanges={Exchange.CRYPTO}, + ), + ] + def get_supported_asset_types(self) -> List[AssetType]: """Get asset types supported by Yahoo Finance.""" return [ @@ -640,82 +612,143 @@ def get_supported_asset_types(self) -> List[AssetType]: AssetType.CRYPTO, ] - def _perform_health_check(self) -> Any: - """Perform health check by fetching a known ticker.""" - try: - # Test with Apple stock - ticker_obj = yf.Ticker("AAPL") - info = ticker_obj.info + def validate_ticker(self, ticker: str) -> bool: + """Validate if ticker is supported by Yahoo Finance. + Args: + ticker: Ticker in internal format, suppose the ticker has been validated before by the caller. + (e.g., "NASDAQ:AAPL", "HKEX:00700", "CRYPTO:BTC") + Returns: + True if ticker is supported + """ - if info and "symbol" in info: - return { - "status": "ok", - "test_ticker": "AAPL", - "response_received": True, - } - else: - return {"status": "error", "message": "No data received"} + if ":" not in ticker: + return False - except Exception as e: - return {"status": "error", "message": str(e)} + exchange, symbol = ticker.split(":", 1) + + # Validate exchange + if exchange not in [ + exchange.value for exchange in self.exchange_mapping.values() + ]: + return False - def _is_valid_internal_ticker(self, ticker: str) -> bool: - """Validate if internal ticker format is correct and supported. + return True - Args: - ticker: Internal ticker format (e.g., "NASDAQ:AAPL", "HKEX:00700", "CRYPTO:BTC") + def convert_to_source_ticker(self, internal_ticker: str) -> str: + """Convert internal ticker to Yahoo Finance source ticker. - Returns: - True if ticker format is valid + For INDEX assets, adds ^ prefix (e.g., NASDAQ:IXIC -> ^IXIC). + For other assets, applies exchange-specific suffix rules. """ try: - if ":" not in ticker: - return False + exchange, symbol = internal_ticker.split(":", 1) - exchange, symbol = ticker.split(":", 1) - - # Validate exchange - supported_exchanges = ["NASDAQ", "NYSE", "SSE", "SZSE", "HKEX", "CRYPTO"] - if exchange not in supported_exchanges: - return False - - # Validate symbol format based on exchange - if exchange in ["NASDAQ", "NYSE"]: - # US stocks: 1-5 uppercase letters, no special characters except hyphen - return ( - bool(symbol) - and len(symbol) <= 5 - and symbol.replace("-", "").isalnum() + # Check asset type from database to determine if it's an index + # Use lazy import to avoid circular dependency + try: + from ...server.db.repositories.asset_repository import ( + get_asset_repository, ) - elif exchange in ["SSE", "SZSE"]: - # A-shares: exactly 6 digits - return symbol.isdigit() and len(symbol) == 6 - - elif exchange == "HKEX": - # HK stocks: 1-5 digits (e.g., 00700) - # HK indices: uppercase letters (e.g., HSI, HSCEI) - # No .HK suffix allowed - if ".HK" in symbol: - return False - return (symbol.isdigit() and 1 <= len(symbol) <= 5) or ( - symbol.isalpha() and symbol.isupper() + asset_repo = get_asset_repository() + asset = asset_repo.get_asset_by_symbol(internal_ticker) + + if ( + asset + and asset.asset_type == AssetType.INDEX.value + and exchange != Exchange.SSE.value + and exchange != Exchange.SZSE.value + ): + # For indices, add ^ prefix + return f"^{symbol}" + except (ImportError, Exception) as e: + # If repository is not available, skip database lookup + logger.debug( + f"Asset repository not available, skipping database lookup: {e}" ) + pass - elif exchange == "CRYPTO": - # Crypto: uppercase letters, no currency suffix (e.g., BTC, not BTC-USD) + # For non-index assets, apply exchange-specific formatting + if exchange == Exchange.HKEX.value: + # Hong Kong stock codes need to be in proper format + # e.g., "700" -> "0700.HK", "00700" -> "0700.HK", "1234" -> "1234.HK" + if symbol.isdigit(): + # Remove leading zeros first, then pad to 4 digits + clean_symbol = str(int(symbol)) # Remove leading zeros + padded_symbol = clean_symbol.zfill(4) # Pad to 4 digits + return f"{padded_symbol}{self.yfinance_exchange_suffix_mapping.get(exchange, '')}" + else: + # For non-numeric symbols, use as-is with .HK suffix + return f"{symbol}{self.yfinance_exchange_suffix_mapping.get(exchange, '')}" + + if exchange in self.yfinance_exchange_suffix_mapping.keys(): return ( - bool(symbol) - and symbol.isalpha() - and symbol.isupper() - and "-" not in symbol + f"{symbol}{self.yfinance_exchange_suffix_mapping.get(exchange, '')}" ) + else: + logger.warning(f"No mapping found for exchange: {exchange} in Yfinance") + return symbol - return False + except ValueError: + logger.error(f"Invalid ticker format: {internal_ticker}, Yfinance adapter.") + return internal_ticker - except (ValueError, AttributeError): - return False + def convert_to_internal_ticker( + self, source_ticker: str, default_exchange: Optional[str] = None + ) -> str: + """Convert Yahoo Finance source ticker to internal ticker. - def validate_ticker(self, ticker: str) -> bool: - """Validate if ticker is supported by Yahoo Finance.""" - return self._is_valid_internal_ticker(ticker) + Simply removes yfinance-specific prefixes/suffixes and formats the symbol. + Asset type determination (e.g., INDEX) is done during search/info retrieval. + """ + # Special handling for indices from yfinance - remove ^ prefix + if source_ticker.startswith("^"): + symbol = source_ticker[1:] # Remove ^ prefix + # Use default exchange if provided, otherwise try to infer from symbol + if default_exchange: + return f"{default_exchange}:{symbol}" + # Common index exchange inference (simplified) + # Most major indices use their primary exchange + return ( + f"{default_exchange}:{symbol}" + if default_exchange + else f"UNKNOWN:{symbol}" + ) + + # Special handling for crypto from yfinance - remove currency suffix + if ( + "-USD" in source_ticker + or "-CAD" in source_ticker + or "-EUR" in source_ticker + ): + # Remove any currency suffix + crypto_symbol = source_ticker.split("-")[0].upper() + return f"CRYPTO:{crypto_symbol}" + + # Special handling for Hong Kong stocks from yfinance + if ".HK" in source_ticker: + symbol = source_ticker.replace(".HK", "") # Remove .HK suffix + # Keep as digits only, no leading zero removal for internal format + if symbol.isdigit(): + # Pad to 5 digits for Hong Kong stocks + symbol = symbol.zfill(5) + return f"HKEX:{symbol}" + + # Special handling for Shanghai stocks from yfinance + if ".SS" in source_ticker: + symbol = source_ticker.replace(".SS", "") + return f"SSE:{symbol}" + + # Special handling for Shenzhen stocks from yfinance + if ".SZ" in source_ticker: + symbol = source_ticker.replace(".SZ", "") + return f"SZSE:{symbol}" + + # If no suffix found and default exchange provided + if default_exchange: + # For US stocks from yfinance, symbol is already clean + return f"{default_exchange}:{source_ticker}" + + # For other assets without clear exchange mapping + # Fallback to using the source as exchange + return f"YFINANCE:{source_ticker}" diff --git a/python/valuecell/agents/auto_trading_agent/agent.py b/python/valuecell/agents/auto_trading_agent/agent.py index d9b8a3b91..04d1bea9b 100644 --- a/python/valuecell/agents/auto_trading_agent/agent.py +++ b/python/valuecell/agents/auto_trading_agent/agent.py @@ -80,15 +80,15 @@ def __init__(self): raise async def _process_trading_instance( - self, - session_id: str, - instance_id: str, + self, + session_id: str, + instance_id: str, semaphore: asyncio.Semaphore, - unified_timestamp: Optional[datetime] = None + unified_timestamp: Optional[datetime] = None, ) -> None: """ Process a single trading instance with semaphore control for concurrency limiting. - + Args: session_id: Session identifier instance_id: Trading instance identifier @@ -100,16 +100,16 @@ async def _process_trading_instance( # Check if instance still exists and is active if instance_id not in self.trading_instances.get(session_id, {}): return - + instance = self.trading_instances[session_id][instance_id] if not instance["active"]: return - + # Get instance components executor = instance["executor"] config = instance["config"] ai_signal_generator = instance["ai_signal_generator"] - + # Update check info instance["check_count"] += 1 instance["last_check"] = datetime.now() @@ -202,9 +202,7 @@ async def _process_trading_instance( # Phase 2: Make portfolio-level decision logger.info( "\n" + "=" * 50 + "\n" - "🎯 **Phase 2: Portfolio Decision Making...**\n" - + "=" * 50 - + "\n\n" + "🎯 **Phase 2: Portfolio Decision Making...**\n" + "=" * 50 + "\n\n" ) # Get portfolio summary @@ -212,12 +210,10 @@ async def _process_trading_instance( logger.info(portfolio_summary + "\n") # Make coordinated decision (async call for AI analysis) - portfolio_decision = ( - await portfolio_manager.make_portfolio_decision( - current_positions=executor.positions, - available_cash=executor.get_current_capital(), - total_portfolio_value=executor.get_portfolio_value(), - ) + portfolio_decision = await portfolio_manager.make_portfolio_decision( + current_positions=executor.positions, + available_cash=executor.get_current_capital(), + total_portfolio_value=executor.get_portfolio_value(), ) # Display decision reasoning - cache it @@ -248,9 +244,7 @@ async def _process_trading_instance( trade_type, ) in portfolio_decision.trades_to_execute: # Get indicators for this symbol - asset_analysis = portfolio_manager.asset_analyses.get( - symbol - ) + asset_analysis = portfolio_manager.asset_analyses.get(symbol) if not asset_analysis: continue @@ -317,23 +311,21 @@ async def _process_trading_instance( import yfinance as yf ticker = yf.Ticker(symbol) - current_price = ticker.history( - period="1d", interval="1m" - )["Close"].iloc[-1] + current_price = ticker.history(period="1d", interval="1m")[ + "Close" + ].iloc[-1] if pos.trade_type.value == "long": - current_pnl = ( - current_price - pos.entry_price - ) * abs(pos.quantity) + current_pnl = (current_price - pos.entry_price) * abs( + pos.quantity + ) else: - current_pnl = ( - pos.entry_price - current_price - ) * abs(pos.quantity) + current_pnl = (pos.entry_price - current_price) * abs( + pos.quantity + ) pnl_emoji = "🟢" if current_pnl >= 0 else "🔴" portfolio_msg += f"- {symbol}: {pos.trade_type.value.upper()} @ ${pos.entry_price:,.2f} {pnl_emoji} P&L: ${current_pnl:,.2f}\n" except Exception as e: - logger.warning( - f"Failed to calculate P&L for {symbol}: {e}" - ) + logger.warning(f"Failed to calculate P&L for {symbol}: {e}") portfolio_msg += f"- {symbol}: {pos.trade_type.value.upper()} @ ${pos.entry_price:,.2f}\n" logger.info(portfolio_msg + "\n") @@ -352,22 +344,24 @@ async def _process_trading_instance( def _generate_instance_id(self, task_id: str, model_id: str) -> str: """ Generate unique instance ID for a specific model - + Args: task_id: Task ID from the request model_id: Model identifier (e.g., 'deepseek/deepseek-v3.1-terminus') - + Returns: Unique instance ID combining timestamp, task, and model """ import hashlib - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") # Include microseconds for uniqueness + + timestamp = datetime.now().strftime( + "%Y%m%d_%H%M%S_%f" + ) # Include microseconds for uniqueness # Create a short hash from model_id for readability model_hash = hashlib.md5(model_id.encode()).hexdigest()[:6] # Extract model name (last part after /) - model_name = model_id.split('/')[-1].replace('-', '_').replace('.', '_')[:15] - + model_name = model_id.split("/")[-1].replace("-", "_").replace(".", "_")[:15] + return f"trade_{timestamp}_{model_name}_{model_hash}" def _init_notification_cache(self, session_id: str) -> None: @@ -640,14 +634,14 @@ def _get_session_portfolio_chart_data(self, session_id: str) -> str: if model_id not in model_data: model_data[model_id] = { - 'initial_capital': config.initial_capital, - 'history': [] + "initial_capital": config.initial_capital, + "history": [], } portfolio_history = executor.get_portfolio_history() for snapshot in portfolio_history: - model_data[model_id]['history'].append( + model_data[model_id]["history"].append( (snapshot.timestamp, snapshot.total_value) ) @@ -656,12 +650,12 @@ def _get_session_portfolio_chart_data(self, session_id: str) -> str: # Sort each model's history by timestamp for model_id in model_data: - model_data[model_id]['history'].sort(key=lambda x: x[0]) + model_data[model_id]["history"].sort(key=lambda x: x[0]) # Collect all unique timestamps across all models all_timestamps = set() for model_id, data in model_data.items(): - for timestamp, _ in data['history']: + for timestamp, _ in data["history"]: all_timestamps.add(timestamp) if not all_timestamps: @@ -675,8 +669,9 @@ def _get_session_portfolio_chart_data(self, session_id: str) -> str: data_array = [["Time"] + model_ids] # Track last known value for each model (for forward-fill) - last_known_values = {model_id: data['initial_capital'] - for model_id, data in model_data.items()} + last_known_values = { + model_id: data["initial_capital"] for model_id, data in model_data.items() + } # Data rows: ['timestamp', value1, value2, ...] for timestamp in sorted_timestamps: @@ -686,7 +681,7 @@ def _get_session_portfolio_chart_data(self, session_id: str) -> str: for model_id in model_ids: # Find value at this timestamp for this model value_at_timestamp = None - for ts, val in model_data[model_id]['history']: + for ts, val in model_data[model_id]["history"]: if ts == timestamp: value_at_timestamp = val break @@ -818,7 +813,7 @@ async def stream( """ # Track created instances for cleanup created_instances = [] - + try: logger.info( f"Processing auto trading request - session: {session_id}, task: {task_id}" @@ -863,12 +858,12 @@ async def stream( # Get list of models to create instances for agent_models = trading_request.agent_models or [DEFAULT_AGENT_MODEL] - + # Create one trading instance per model yield streaming.message_chunk( f"🚀 **Creating {len(agent_models)} trading instance(s)...**\n\n" ) - + for model_id in agent_models: # Generate unique instance ID for this model instance_id = self._generate_instance_id(task_id, model_id) @@ -898,7 +893,7 @@ async def stream( "check_count": 0, "last_check": None, } - + created_instances.append(instance_id) # Display configuration for this instance @@ -917,7 +912,7 @@ async def stream( ) yield streaming.message_chunk(config_message) - + # Summary message yield streaming.message_chunk( f"**Session ID:** `{session_id[:8]}`\n" @@ -933,7 +928,7 @@ async def stream( instance = self.trading_instances[session_id][instance_id] executor = instance["executor"] config = instance["config"] - + # Send initial portfolio snapshot - cache it portfolio_value = executor.get_portfolio_value() executor.snapshot_portfolio(unified_initial_timestamp) @@ -943,7 +938,9 @@ async def stream( data=f"💰 **Initial Portfolio**\nTotal Value: ${portfolio_value:,.2f}\nAvailable Capital: ${executor.current_capital:,.2f}\n", filters=[config.agent_model], table_title="Portfolio Detail", - create_time=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + create_time=datetime.now(timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S" + ), ) # Cache the initial notification self._cache_notification(session_id, initial_portfolio_msg) @@ -955,7 +952,9 @@ async def stream( semaphore = asyncio.Semaphore(10) # Main trading loop - monitor all instances in parallel - yield streaming.message_chunk("📈 **Starting monitoring loop for all instances...**\n\n") + yield streaming.message_chunk( + "📈 **Starting monitoring loop for all instances...**\n\n" + ) # Check if any instance is still active while any( @@ -966,18 +965,18 @@ async def stream( try: # Create unified timestamp for this iteration to align snapshots unified_timestamp = datetime.now() - + # Process all active instances concurrently using task pool tasks = [] for instance_id in created_instances: # Skip if instance was removed or is inactive if instance_id not in self.trading_instances[session_id]: continue - + instance = self.trading_instances[session_id][instance_id] if not instance["active"]: continue - + # Create task for this instance with semaphore control and unified timestamp task = asyncio.create_task( self._process_trading_instance( @@ -985,17 +984,19 @@ async def stream( ) ) tasks.append(task) - + # Wait for all instance tasks to complete (process concurrently) if tasks: # Gather all tasks and handle any exceptions results = await asyncio.gather(*tasks, return_exceptions=True) - + # Log any exceptions that occurred for i, result in enumerate(results): if isinstance(result, Exception): - logger.error(f"Task {i} failed with exception: {result}") - + logger.error( + f"Task {i} failed with exception: {result}" + ) + # After processing all instances, send batched notifications cached_notifications = self._get_cached_notifications(session_id) if cached_notifications: @@ -1012,7 +1013,7 @@ async def stream( ComponentType.FILTERED_CARD_PUSH_NOTIFICATION, component_id=f"trading_status_{session_id}", ) - + # Send chart data (not cached, sent separately) chart_data = self._get_session_portfolio_chart_data(session_id) if chart_data: @@ -1042,5 +1043,7 @@ async def stream( if session_id in self.trading_instances: for instance_id in created_instances: if instance_id in self.trading_instances[session_id]: - self.trading_instances[session_id][instance_id]["active"] = False + self.trading_instances[session_id][instance_id]["active"] = ( + False + ) logger.info(f"Stopped instance: {instance_id}") diff --git a/python/valuecell/core/conversation/tests/test_conv_manager.py b/python/valuecell/core/conversation/tests/test_conv_manager.py index 8d2cc2dd3..a7606245e 100644 --- a/python/valuecell/core/conversation/tests/test_conv_manager.py +++ b/python/valuecell/core/conversation/tests/test_conv_manager.py @@ -41,7 +41,9 @@ async def test_create_conversation_minimal(self): manager = ConversationManager() user_id = "user-123" - with patch("valuecell.core.conversation.manager.generate_conversation_id") as mock_uuid: + with patch( + "valuecell.core.conversation.manager.generate_conversation_id" + ) as mock_uuid: mock_uuid.return_value = "conv-generated-123" result = await manager.create_conversation(user_id) diff --git a/python/valuecell/server/api/schemas/watchlist.py b/python/valuecell/server/api/schemas/watchlist.py index 8e947f1b1..fc9377643 100644 --- a/python/valuecell/server/api/schemas/watchlist.py +++ b/python/valuecell/server/api/schemas/watchlist.py @@ -107,7 +107,6 @@ class AssetInfoData(BaseModel): market_status_display: Optional[str] = Field( None, description="Localized market status display" ) - relevance_score: Optional[float] = Field(None, description="Search relevance score") class AssetSearchResultData(BaseModel): diff --git a/python/valuecell/server/db/init_db.py b/python/valuecell/server/db/init_db.py index acf780c12..2b4b5c687 100644 --- a/python/valuecell/server/db/init_db.py +++ b/python/valuecell/server/db/init_db.py @@ -4,12 +4,7 @@ import logging import sys from pathlib import Path -from typing import TYPE_CHECKING, Optional - -from valuecell.utils.path import get_agent_card_path - -if TYPE_CHECKING: - from .models.asset import Asset +from typing import Optional from sqlalchemy import inspect, text from sqlalchemy.exc import SQLAlchemyError @@ -17,9 +12,10 @@ from valuecell.server.config.settings import get_settings from valuecell.server.db.connection import DatabaseManager, get_database_manager from valuecell.server.db.models.agent import Agent -from valuecell.server.db.models.asset import Asset from valuecell.server.db.models.base import Base +from valuecell.server.db.repositories.asset_repository import get_asset_repository from valuecell.server.services.assets import get_asset_service +from valuecell.utils.path import get_agent_card_path # Configure logging logging.basicConfig( @@ -132,28 +128,20 @@ def initialize_assets_with_service(self) -> bool: try: logger.info("Initializing assets using AssetService...") - # Get asset service instance + # Get asset service and repository instances asset_service = get_asset_service() + session = self.db_manager.get_session() + asset_repo = get_asset_repository(db_session=session) # Define default tickers to search and initialize # Using proper EXCHANGE:SYMBOL format for better adapter matching default_tickers = [ - "NASDAQ:AAPL", # Apple Inc. - "NASDAQ:GOOGL", # Alphabet Inc. - "NASDAQ:MSFT", # Microsoft Corporation - "NYSE:SPY", # SPDR S&P 500 ETF - "CRYPTO:BTC", # Bitcoin - # Additional diverse assets - "NYSE:TSLA", # Tesla Inc. - "NASDAQ:NVDA", # NVIDIA Corporation - "NYSE:JPM", # JPMorgan Chase & Co. - "CRYPTO:ETH", # Ethereum - "NASDAQ:QQQ", # Invesco QQQ Trust ETF + # Major indices + "NASDAQ:IXIC", # NASDAQ Composite Index + "HKEX:HSI", # Hang Seng Index + "SSE:000001", # Shanghai Composite Index ] - # Get database session for manual asset creation if needed - session = self.db_manager.get_session() - try: initialized_count = 0 @@ -182,7 +170,7 @@ def initialize_assets_with_service(self) -> bool: search_result = {"success": False, "results": []} if search_result["success"] and search_result["results"]: - # Asset found via adapter, create database record + # Asset found via adapter, create or update database record asset_data = search_result["results"][0] # Use the standardized ticker format (ensure EXCHANGE:SYMBOL format) @@ -192,15 +180,33 @@ def initialize_assets_with_service(self) -> bool: asset_ticker = ticker # Check if asset already exists in database - existing_asset = ( - session.query(Asset) - .filter_by(symbol=asset_ticker) - .first() - ) - - if not existing_asset: + if asset_repo.asset_exists(asset_ticker): + # Update existing asset with adapter data + metadata_updates = { + "exchange": asset_data.get("exchange") + or ticker.split(":")[0], + "country": asset_data.get("country"), + "currency": asset_data.get("currency"), + "market_status": asset_data.get("market_status"), + "last_updated_from_adapter": True, + "last_search_query": query, + } + + asset_repo.update_asset( + symbol=asset_ticker, + name=asset_data["display_name"], + asset_type=asset_data["asset_type"], + ) + asset_repo.update_asset_metadata( + symbol=asset_ticker, + metadata_updates=metadata_updates, + ) + logger.info( + f"Updated asset from adapter: {asset_ticker} (searched as '{query}')" + ) + else: # Create new asset from adapter data - new_asset = Asset( + asset_repo.create_asset( symbol=asset_ticker, name=asset_data["display_name"], asset_type=asset_data["asset_type"], @@ -213,41 +219,14 @@ def initialize_assets_with_service(self) -> bool: "market_status" ), "source": "adapter_search", - "relevance_score": asset_data.get( - "relevance_score", 0.0 - ), "original_search_query": query, "standardized_ticker": asset_ticker, }, ) - session.add(new_asset) logger.info( f"Added asset from adapter: {asset_ticker} (searched as '{query}')" ) initialized_count += 1 - else: - # Update existing asset with adapter data - existing_asset.name = asset_data["display_name"] - existing_asset.asset_type = asset_data["asset_type"] - # Update existing asset metadata - existing_metadata = existing_asset.asset_metadata or {} - existing_metadata.update( - { - "exchange": asset_data.get("exchange") - or ticker.split(":")[0], - "country": asset_data.get("country"), - "currency": asset_data.get("currency"), - "market_status": asset_data.get( - "market_status" - ), - "last_updated_from_adapter": True, - "last_search_query": query, - } - ) - existing_asset.asset_metadata = existing_metadata - logger.info( - f"Updated asset from adapter: {asset_ticker} (searched as '{query}')" - ) else: # Fallback: create basic asset record for common tickers @@ -255,13 +234,10 @@ def initialize_assets_with_service(self) -> bool: f"Could not find {ticker} via adapters, creating basic record" ) - existing_asset = ( - session.query(Asset).filter_by(symbol=ticker).first() - ) - if not existing_asset: - fallback_asset = self._create_fallback_asset(ticker) - if fallback_asset: - session.add(fallback_asset) + if not asset_repo.asset_exists(ticker): + fallback_data = self._get_fallback_asset_data(ticker) + if fallback_data: + asset_repo.create_asset(**fallback_data) logger.info(f"Added fallback asset: {ticker}") initialized_count += 1 @@ -294,116 +270,46 @@ def initialize_assets_with_service(self) -> bool: logger.error(f"Error getting asset service or database session: {e}") return False - def _create_fallback_asset(self, ticker: str) -> Optional["Asset"]: - """Create fallback asset data when adapter search fails.""" + def _get_fallback_asset_data(self, ticker: str) -> Optional[dict]: + """Get fallback asset data when adapter search fails. + Returns: + Dictionary with asset data suitable for create_asset() method + """ # Basic fallback data for common tickers (using proper EXCHANGE:SYMBOL format) - fallback_data = { - "NASDAQ:AAPL": { - "name": "Apple Inc.", - "asset_type": "stock", - "sector": "Technology", - "exchange": "NASDAQ", - "metadata": { - "market_cap": "large", - "tags": ["blue-chip", "technology"], - }, - }, - "NASDAQ:GOOGL": { - "name": "Alphabet Inc. Class A", - "asset_type": "stock", - "sector": "Technology", - "exchange": "NASDAQ", - "metadata": { - "market_cap": "large", - "tags": ["growth", "tech-giant", "ai"], - }, - }, - "NASDAQ:MSFT": { - "name": "Microsoft Corporation", - "asset_type": "stock", - "sector": "Technology", - "exchange": "NASDAQ", - "metadata": { - "market_cap": "large", - "tags": ["blue-chip", "cloud", "ai"], - }, - }, - "NYSE:SPY": { - "name": "SPDR S&P 500 ETF Trust", - "asset_type": "etf", - "sector": "Diversified", - "exchange": "NYSE", - "metadata": {"tags": ["index", "diversified", "low-cost"]}, - }, - "CRYPTO:BTC": { - "name": "Bitcoin", - "asset_type": "crypto", - "sector": "Cryptocurrency", - "exchange": "CRYPTO", - "metadata": {"tags": ["crypto", "store-of-value", "digital-gold"]}, - }, - "NYSE:TSLA": { - "name": "Tesla Inc.", - "asset_type": "stock", - "sector": "Automotive", - "exchange": "NYSE", - "metadata": { - "market_cap": "large", - "tags": ["electric-vehicles", "innovation", "growth"], - }, - }, - "NASDAQ:NVDA": { - "name": "NVIDIA Corporation", - "asset_type": "stock", - "sector": "Technology", - "exchange": "NASDAQ", - "metadata": { - "market_cap": "large", - "tags": ["semiconductors", "ai", "gaming"], - }, - }, - "NYSE:JPM": { - "name": "JPMorgan Chase & Co.", - "asset_type": "stock", - "sector": "Financial Services", - "exchange": "NYSE", - "metadata": { - "market_cap": "large", - "tags": ["banking", "blue-chip", "finance"], - }, + fallback_configs = { + "SSE:000001": { + "name": "Shanghai Composite Index", + "asset_type": "index", + "exchange": "SSE", }, - "CRYPTO:ETH": { - "name": "Ethereum", - "asset_type": "crypto", - "sector": "Cryptocurrency", - "exchange": "CRYPTO", - "metadata": {"tags": ["crypto", "smart-contracts", "defi"]}, + "HKEX:HSI": { + "name": "Hang Seng Index", + "asset_type": "index", + "exchange": "HKEX", }, - "NASDAQ:QQQ": { - "name": "Invesco QQQ Trust ETF", - "asset_type": "etf", - "sector": "Technology", + "NASDAQ:IXIC": { + "name": "NASDAQ Composite Index", + "asset_type": "index", "exchange": "NASDAQ", - "metadata": {"tags": ["tech-etf", "index", "growth"]}, }, } - if ticker in fallback_data: - data = fallback_data[ticker] - return Asset( - symbol=ticker, - name=data["name"], - asset_type=data["asset_type"], - sector=data.get("sector"), - is_active=True, - asset_metadata={ - **data.get("metadata", {}), - "exchange": data.get("exchange"), + if ticker in fallback_configs: + config = fallback_configs[ticker] + return { + "symbol": ticker, + "name": config["name"], + "asset_type": config["asset_type"], + "sector": config.get("sector"), + "is_active": True, + "asset_metadata": { + **config.get("metadata", {}), + "exchange": config.get("exchange"), "source": "fallback_data", "initialized_at": "database_init", }, - ) + } return None def initialize_basic_data(self) -> bool: diff --git a/python/valuecell/server/db/models/asset.py b/python/valuecell/server/db/models/asset.py index 6304a0d09..70b939746 100644 --- a/python/valuecell/server/db/models/asset.py +++ b/python/valuecell/server/db/models/asset.py @@ -31,7 +31,7 @@ class Asset(Base): unique=True, nullable=False, index=True, - comment="Asset symbol/ticker (e.g., AAPL, BTC, etc.)", + comment="Asset symbol/ticker (e.g., NASDAQ:MSFT, CRYPTO:BTC, etc.)", ) name = Column(String(200), nullable=False, comment="Full name of the asset") description = Column( diff --git a/python/valuecell/server/db/repositories/__init__.py b/python/valuecell/server/db/repositories/__init__.py index e69de29bb..624fcf2bc 100644 --- a/python/valuecell/server/db/repositories/__init__.py +++ b/python/valuecell/server/db/repositories/__init__.py @@ -0,0 +1,21 @@ +"""Database repositories for ValueCell Server.""" + +from .asset_repository import ( + AssetRepository, + get_asset_repository, + reset_asset_repository, +) +from .watchlist_repository import ( + WatchlistRepository, + get_watchlist_repository, + reset_watchlist_repository, +) + +__all__ = [ + "AssetRepository", + "get_asset_repository", + "reset_asset_repository", + "WatchlistRepository", + "get_watchlist_repository", + "reset_watchlist_repository", +] diff --git a/python/valuecell/server/db/repositories/asset_repository.py b/python/valuecell/server/db/repositories/asset_repository.py new file mode 100644 index 000000000..e9e9c4c15 --- /dev/null +++ b/python/valuecell/server/db/repositories/asset_repository.py @@ -0,0 +1,442 @@ +""" +ValueCell Server - Asset Repository + +This module provides database operations for asset management. +""" + +from typing import List, Optional + +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from ..connection import get_database_manager +from ..models.asset import Asset + + +class AssetRepository: + """Repository class for asset database operations.""" + + def __init__(self, db_session: Optional[Session] = None): + """Initialize repository with optional database session.""" + self.db_session = db_session + + def _get_session(self) -> Session: + """Get database session.""" + if self.db_session: + return self.db_session + return get_database_manager().get_session() + + def create_asset( + self, + symbol: str, + name: str, + asset_type: str, + description: Optional[str] = None, + sector: Optional[str] = None, + current_price: Optional[float] = None, + is_active: bool = True, + asset_metadata: Optional[dict] = None, + config: Optional[dict] = None, + ) -> Optional[Asset]: + """Create a new asset. + + Args: + symbol: Asset symbol/ticker (e.g., NASDAQ:AAPL) + name: Full name of the asset + asset_type: Type of asset (stock, bond, crypto, etc.) + description: Detailed description of the asset + sector: Industry sector (for stocks) + current_price: Current market price + is_active: Whether the asset is active + asset_metadata: Additional metadata + config: Asset-specific configuration parameters + + Returns: + Created Asset object or None if creation fails + """ + session = self._get_session() + + try: + asset = Asset( + symbol=symbol, + name=name, + asset_type=asset_type, + description=description, + sector=sector, + current_price=current_price, + is_active=is_active, + asset_metadata=asset_metadata, + config=config, + ) + + session.add(asset) + session.commit() + session.refresh(asset) + + # Expunge to avoid session issues + session.expunge(asset) + + return asset + + except IntegrityError: + session.rollback() + return None + except Exception: + session.rollback() + return None + finally: + if not self.db_session: + session.close() + + def get_asset_by_symbol(self, symbol: str) -> Optional[Asset]: + """Get asset by symbol. + + Args: + symbol: Asset symbol/ticker + + Returns: + Asset object or None if not found + """ + session = self._get_session() + + try: + asset = session.query(Asset).filter_by(symbol=symbol).first() + + if asset: + # Expunge to avoid session issues + session.expunge(asset) + + return asset + + finally: + if not self.db_session: + session.close() + + def get_asset_by_id(self, asset_id: int) -> Optional[Asset]: + """Get asset by ID. + + Args: + asset_id: Asset ID + + Returns: + Asset object or None if not found + """ + session = self._get_session() + + try: + asset = session.query(Asset).filter_by(id=asset_id).first() + + if asset: + # Expunge to avoid session issues + session.expunge(asset) + + return asset + + finally: + if not self.db_session: + session.close() + + def get_all_assets( + self, is_active: Optional[bool] = None, limit: Optional[int] = None + ) -> List[Asset]: + """Get all assets with optional filtering. + + Args: + is_active: Filter by active status (None for all) + limit: Maximum number of results + + Returns: + List of Asset objects + """ + session = self._get_session() + + try: + query = session.query(Asset) + + if is_active is not None: + query = query.filter(Asset.is_active == is_active) + + if limit: + query = query.limit(limit) + + assets = query.all() + + # Expunge all assets to avoid session issues + for asset in assets: + session.expunge(asset) + + return assets + + finally: + if not self.db_session: + session.close() + + def update_asset( + self, + symbol: str, + name: Optional[str] = None, + description: Optional[str] = None, + asset_type: Optional[str] = None, + sector: Optional[str] = None, + current_price: Optional[float] = None, + is_active: Optional[bool] = None, + asset_metadata: Optional[dict] = None, + config: Optional[dict] = None, + ) -> Optional[Asset]: + """Update an existing asset. + + Args: + symbol: Asset symbol/ticker + name: Full name of the asset (if updating) + description: Detailed description (if updating) + asset_type: Type of asset (if updating) + sector: Industry sector (if updating) + current_price: Current market price (if updating) + is_active: Whether the asset is active (if updating) + asset_metadata: Additional metadata (if updating) + config: Asset-specific configuration (if updating) + + Returns: + Updated Asset object or None if not found + """ + session = self._get_session() + + try: + asset = session.query(Asset).filter_by(symbol=symbol).first() + + if not asset: + return None + + # Update fields if provided + if name is not None: + asset.name = name + if description is not None: + asset.description = description + if asset_type is not None: + asset.asset_type = asset_type + if sector is not None: + asset.sector = sector + if current_price is not None: + asset.current_price = current_price + if is_active is not None: + asset.is_active = is_active + if asset_metadata is not None: + asset.asset_metadata = asset_metadata + if config is not None: + asset.config = config + + session.commit() + session.refresh(asset) + + # Expunge to avoid session issues + session.expunge(asset) + + return asset + + except Exception: + session.rollback() + return None + finally: + if not self.db_session: + session.close() + + def update_asset_metadata( + self, symbol: str, metadata_updates: dict + ) -> Optional[Asset]: + """Update asset metadata by merging new data with existing. + + Args: + symbol: Asset symbol/ticker + metadata_updates: Dictionary of metadata updates to merge + + Returns: + Updated Asset object or None if not found + """ + session = self._get_session() + + try: + asset = session.query(Asset).filter_by(symbol=symbol).first() + + if not asset: + return None + + # Merge metadata + existing_metadata = asset.asset_metadata or {} + existing_metadata.update(metadata_updates) + asset.asset_metadata = existing_metadata + + session.commit() + session.refresh(asset) + + # Expunge to avoid session issues + session.expunge(asset) + + return asset + + except Exception: + session.rollback() + return None + finally: + if not self.db_session: + session.close() + + def delete_asset(self, symbol: str) -> bool: + """Delete an asset by symbol. + + Args: + symbol: Asset symbol/ticker + + Returns: + True if deleted successfully, False otherwise + """ + session = self._get_session() + + try: + asset = session.query(Asset).filter_by(symbol=symbol).first() + + if not asset: + return False + + session.delete(asset) + session.commit() + + return True + + except Exception: + session.rollback() + return False + finally: + if not self.db_session: + session.close() + + def asset_exists(self, symbol: str) -> bool: + """Check if an asset exists by symbol. + + Args: + symbol: Asset symbol/ticker + + Returns: + True if asset exists, False otherwise + """ + session = self._get_session() + + try: + exists = session.query(Asset).filter_by(symbol=symbol).first() is not None + return exists + + finally: + if not self.db_session: + session.close() + + def upsert_asset( + self, + symbol: str, + name: str, + asset_type: str, + description: Optional[str] = None, + sector: Optional[str] = None, + current_price: Optional[float] = None, + is_active: bool = True, + asset_metadata: Optional[dict] = None, + config: Optional[dict] = None, + ) -> Optional[Asset]: + """Create or update an asset by symbol. + + If asset exists, updates all provided fields. + If asset doesn't exist, creates a new one. + + Args: + symbol: Asset symbol/ticker (e.g., NASDAQ:AAPL) + name: Full name of the asset + asset_type: Type of asset (stock, bond, crypto, index, etc.) + description: Detailed description of the asset + sector: Industry sector (for stocks) + current_price: Current market price + is_active: Whether the asset is active + asset_metadata: Additional metadata + config: Asset-specific configuration parameters + + Returns: + Created or updated Asset object or None if operation fails + """ + session = self._get_session() + + try: + # Try to find existing asset + asset = session.query(Asset).filter_by(symbol=symbol).first() + + if asset: + # Update existing asset + asset.name = name + asset.asset_type = asset_type + if description is not None: + asset.description = description + if sector is not None: + asset.sector = sector + if current_price is not None: + asset.current_price = current_price + asset.is_active = is_active + if asset_metadata is not None: + asset.asset_metadata = asset_metadata + if config is not None: + asset.config = config + else: + # Create new asset + asset = Asset( + symbol=symbol, + name=name, + asset_type=asset_type, + description=description, + sector=sector, + current_price=current_price, + is_active=is_active, + asset_metadata=asset_metadata, + config=config, + ) + session.add(asset) + + session.commit() + session.refresh(asset) + + # Expunge to avoid session issues + session.expunge(asset) + + return asset + + except Exception: + session.rollback() + return None + finally: + if not self.db_session: + session.close() + + +# Global repository instance +_asset_repository: Optional[AssetRepository] = None + + +def get_asset_repository(db_session: Optional[Session] = None) -> AssetRepository: + """Get global asset repository instance or create with custom session. + + Args: + db_session: Optional database session. If provided, creates new instance. + + Returns: + AssetRepository instance + """ + global _asset_repository + + if db_session: + # Return new instance with custom session + return AssetRepository(db_session) + + if _asset_repository is None: + _asset_repository = AssetRepository() + + return _asset_repository + + +def reset_asset_repository() -> None: + """Reset global asset repository instance (mainly for testing).""" + global _asset_repository + _asset_repository = None diff --git a/python/valuecell/server/db/repositories/watchlist_repository.py b/python/valuecell/server/db/repositories/watchlist_repository.py index 096646732..fa7245128 100644 --- a/python/valuecell/server/db/repositories/watchlist_repository.py +++ b/python/valuecell/server/db/repositories/watchlist_repository.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from ..connection import get_database_manager +from ..models.asset import Asset from ..models.watchlist import Watchlist, WatchlistItem @@ -261,6 +262,12 @@ def add_asset_to_watchlist( if not watchlist: return False + # If display_name is not provided, try to get it from assets table + if not display_name: + asset = session.query(Asset).filter(Asset.symbol == ticker).first() + if asset and asset.name: + display_name = asset.name + # Set order_index if not provided if order_index is None: max_order = ( diff --git a/python/valuecell/server/services/assets/asset_service.py b/python/valuecell/server/services/assets/asset_service.py index e9997a2b9..bbde510e2 100644 --- a/python/valuecell/server/services/assets/asset_service.py +++ b/python/valuecell/server/services/assets/asset_service.py @@ -101,12 +101,6 @@ def search_assets( ), "exchange": result.exchange, "country": result.country, - "currency": result.currency, - "market_status": result.market_status.value, - "market_status_display": self.i18n_service.get_market_status_display_name( - result.market_status, language - ), - "relevance_score": result.relevance_score, } result_dicts.append(result_dict) @@ -207,13 +201,31 @@ def get_asset_price( "ticker": ticker, } + # Get asset_type from database to handle formatting correctly + asset_type = None + try: + from ...db.repositories.asset_repository import get_asset_repository + + asset_repo = get_asset_repository() + db_asset = asset_repo.get_asset_by_symbol(ticker) + if db_asset: + asset_type = db_asset.asset_type + except Exception as e: + logger.debug( + f"Could not get asset_type from database for {ticker}: {e}" + ) + # If asset not in database, it will be treated as a regular asset with currency + # Format price data with localization formatted_price = { "success": True, "ticker": price_data.ticker, "price": float(price_data.price), "price_formatted": self.i18n_service.format_currency_amount( - float(price_data.price), price_data.currency, language + float(price_data.price), + price_data.currency, + language, + asset_type, ), "currency": price_data.currency, "timestamp": price_data.timestamp.isoformat(), @@ -271,14 +283,28 @@ def get_multiple_prices( try: price_data = self.adapter_manager.get_multiple_prices(tickers) + # Get asset_types from database for all tickers in batch + asset_types = {} + try: + from ...db.repositories.asset_repository import get_asset_repository + + asset_repo = get_asset_repository() + for ticker in tickers: + db_asset = asset_repo.get_asset_by_symbol(ticker) + if db_asset: + asset_types[ticker] = db_asset.asset_type + except Exception as e: + logger.debug(f"Could not get asset_types from database: {e}") + formatted_prices = {} for ticker, price in price_data.items(): if price: + asset_type = asset_types.get(ticker) formatted_prices[ticker] = { "price": float(price.price), "price_formatted": self.i18n_service.format_currency_amount( - float(price.price), price.currency, language + float(price.price), price.currency, language, asset_type ), "currency": price.currency, "timestamp": price.timestamp.isoformat(), @@ -634,51 +660,6 @@ def get_user_watchlists(self, user_id: str) -> Dict[str, Any]: logger.error(f"Error getting user watchlists: {e}") return {"success": False, "error": str(e), "user_id": user_id} - def get_system_health(self) -> Dict[str, Any]: - """Get system health status for all data adapters. - - Returns: - Dictionary containing health status for all adapters - """ - try: - health_data = self.adapter_manager.health_check() - - # Convert enum keys to strings - health_status = {} - for source, status in health_data.items(): - health_status[source.value] = status - - # Calculate overall health - healthy_count = sum( - 1 - for status in health_status.values() - if status.get("status") == "healthy" - ) - total_count = len(health_status) - - # Determine overall status - if total_count == 0: - overall_status = "no_adapters" - elif healthy_count == total_count: - overall_status = "healthy" - elif healthy_count > 0: - overall_status = "degraded" - else: - overall_status = "unhealthy" - - return { - "success": True, - "overall_status": overall_status, - "healthy_adapters": healthy_count, - "total_adapters": total_count, - "adapters": health_status, - "timestamp": datetime.utcnow().isoformat(), - } - - except Exception as e: - logger.error(f"Error getting system health: {e}") - return {"success": False, "error": str(e), "overall_status": "error"} - # Global service instance _asset_service: Optional[AssetService] = None