diff --git a/python/valuecell/agents/common/trading/data/market.py b/python/valuecell/agents/common/trading/data/market.py index 2ecf1a87f..577e0837e 100644 --- a/python/valuecell/agents/common/trading/data/market.py +++ b/python/valuecell/agents/common/trading/data/market.py @@ -1,3 +1,5 @@ +import asyncio +import itertools from collections import defaultdict from typing import List, Optional @@ -56,33 +58,32 @@ def _normalize_symbol(self, symbol: str) -> str: async def get_recent_candles( self, symbols: List[str], interval: str, lookback: int ) -> List[Candle]: - async def _fetch(symbol: str, normalized_symbol: str) -> List[List]: + async def _fetch_and_process(symbol: str) -> List[Candle]: # instantiate exchange class by name (e.g., ccxtpro.kraken) exchange_cls = get_exchange_cls(self._exchange_id) exchange = exchange_cls({"newUpdates": False}) + + symbol_candles: List[Candle] = [] + normalized_symbol = self._normalize_symbol(symbol) try: - # ccxt.pro uses async fetch_ohlcv with normalized symbol - data = await exchange.fetch_ohlcv( - normalized_symbol, timeframe=interval, since=None, limit=lookback - ) - return data - finally: try: - await exchange.close() - except Exception: - pass + # ccxt.pro uses async fetch_ohlcv with normalized symbol + raw = await exchange.fetch_ohlcv( + normalized_symbol, + timeframe=interval, + since=None, + limit=lookback, + ) + finally: + try: + await exchange.close() + except Exception: + pass - candles: List[Candle] = [] - # Run fetch for each symbol sequentially - for symbol in symbols: - try: - # Normalize symbol format for the exchange (e.g., BTC-USDC -> BTC/USDC:USDC) - normalized_symbol = self._normalize_symbol(symbol) - raw = await _fetch(symbol, normalized_symbol) # raw is list of [ts, open, high, low, close, volume] for row in raw: ts, open_v, high_v, low_v, close_v, vol = row - candles.append( + symbol_candles.append( Candle( ts=int(ts), instrument=InstrumentRef( @@ -98,6 +99,7 @@ async def _fetch(symbol: str, normalized_symbol: str) -> List[List]: interval=interval, ) ) + return symbol_candles except Exception as exc: logger.warning( "Failed to fetch candles for {} (normalized: {}) from {}, data interval is {}, return empty candles. Error: {}", @@ -107,6 +109,15 @@ async def _fetch(symbol: str, normalized_symbol: str) -> List[List]: interval, exc, ) + return [] + + # Run fetch for each symbol concurrently + tasks = [_fetch_and_process(symbol) for symbol in symbols] + results = await asyncio.gather(*tasks) + + # Flatten the list of lists results into a single list of candles + candles: List[Candle] = list(itertools.chain.from_iterable(results)) + logger.debug( f"Fetch {len(candles)} candles symbols: {symbols}, interval: {interval}, lookback: {lookback}" ) diff --git a/python/valuecell/agents/common/trading/decision/prompt_based/composer.py b/python/valuecell/agents/common/trading/decision/prompt_based/composer.py index 6b1cd3536..0a9d9fb42 100644 --- a/python/valuecell/agents/common/trading/decision/prompt_based/composer.py +++ b/python/valuecell/agents/common/trading/decision/prompt_based/composer.py @@ -52,6 +52,20 @@ def __init__( self._request = request self._default_slippage_bps = default_slippage_bps self._quantity_precision = quantity_precision + cfg = self._request.llm_model_config + self._model = model_utils.create_model_with_provider( + provider=cfg.provider, + model_id=cfg.model_id, + api_key=cfg.api_key, + ) + self.agent = AgnoAgent( + model=self._model, + output_schema=TradePlanProposal, + markdown=False, + instructions=[SYSTEM_PROMPT], + use_json_mode=model_utils.model_should_use_json_mode(self._model), + debug_mode=env_utils.agent_debug_mode_enabled(), + ) def _build_prompt_text(self) -> str: """Return a resolved prompt text by fusing custom_prompt and prompt_text. @@ -186,24 +200,7 @@ async def _call_llm(self, prompt: str) -> TradePlanProposal: agent's `response.content` is returned (or validated) as a `LlmPlanProposal`. """ - - cfg = self._request.llm_model_config - model = model_utils.create_model_with_provider( - provider=cfg.provider, - model_id=cfg.model_id, - api_key=cfg.api_key, - ) - - # Wrap model in an Agent (consistent with parser_agent usage) - agent = AgnoAgent( - model=model, - output_schema=TradePlanProposal, - markdown=False, - instructions=[SYSTEM_PROMPT], - use_json_mode=model_utils.model_should_use_json_mode(model), - debug_mode=env_utils.agent_debug_mode_enabled(), - ) - response = await agent.arun(prompt) + response = await self.agent.arun(prompt) # Agent may return a raw object or a wrapper with `.content`. content = getattr(response, "content", None) or response logger.debug("Received LLM response {}", content) @@ -216,7 +213,7 @@ async def _call_llm(self, prompt: str) -> TradePlanProposal: items=[], rationale=( "LLM output failed validation. The model you chose " - f"`{model_utils.describe_model(model)}` " + f"`{model_utils.describe_model(self._model)}` " "may be incompatible or returned unexpected output. " f"Raw output: {content}" ), diff --git a/python/valuecell/agents/common/trading/features/pipeline.py b/python/valuecell/agents/common/trading/features/pipeline.py index 29af101d2..547fd6ebd 100644 --- a/python/valuecell/agents/common/trading/features/pipeline.py +++ b/python/valuecell/agents/common/trading/features/pipeline.py @@ -8,9 +8,14 @@ from __future__ import annotations -from typing import List +import asyncio +import itertools +from typing import List, Optional + +from loguru import logger from valuecell.agents.common.trading.models import ( + CandleConfig, FeaturesPipelineResult, FeatureVector, UserRequest, @@ -36,53 +41,65 @@ def __init__( market_data_source: BaseMarketDataSource, candle_feature_computer: CandleBasedFeatureComputer, market_snapshot_computer: MarketSnapshotFeatureComputer, - micro_interval: str = "1s", - micro_lookback: int = 60 * 3, - medium_interval: str = "1m", - medium_lookback: int = 60 * 4, + candle_configurations: Optional[List[CandleConfig]] = None, ) -> None: self._request = request self._market_data_source = market_data_source self._candle_feature_computer = candle_feature_computer - self._micro_interval = micro_interval - self._micro_lookback = micro_lookback - self._medium_interval = medium_interval - self._medium_lookback = medium_lookback self._symbols = list(dict.fromkeys(request.trading_config.symbols)) self._market_snapshot_computer = market_snapshot_computer + self._candle_configurations = candle_configurations + self._candle_configurations = candle_configurations or [ + CandleConfig(interval="1s", lookback=60 * 3), + CandleConfig(interval="1m", lookback=60 * 4), + ] async def build(self) -> FeaturesPipelineResult: - """Fetch candles, compute feature vectors, and append market features.""" - # Determine symbols from the configured request so caller doesn't pass them - candles_micro = await self._market_data_source.get_recent_candles( - self._symbols, self._micro_interval, self._micro_lookback - ) - micro_features = self._candle_feature_computer.compute_features( - candles=candles_micro - ) + """ + Fetch candles and market snapshot, compute feature vectors concurrently, + and combine results. + """ - candles_medium = await self._market_data_source.get_recent_candles( - self._symbols, self._medium_interval, self._medium_lookback - ) - medium_features = self._candle_feature_computer.compute_features( - candles=candles_medium - ) + async def _fetch_candles(interval: str, lookback: int) -> List[FeatureVector]: + """Fetches candles and computes features for a single (interval, lookback) pair.""" + _candles = await self._market_data_source.get_recent_candles( + self._symbols, interval, lookback + ) + return self._candle_feature_computer.compute_features(candles=_candles) - features: List[FeatureVector] = [] - features.extend(medium_features or []) - features.extend(micro_features or []) + async def _fetch_market_features() -> List[FeatureVector]: + """Fetches market snapshot for all symbols and computes features.""" + market_snapshot = await self._market_data_source.get_market_snapshot( + self._symbols + ) + market_snapshot = market_snapshot or {} + return self._market_snapshot_computer.build( + market_snapshot, self._request.exchange_config.exchange_id + ) - market_snapshot = await self._market_data_source.get_market_snapshot( - self._symbols + logger.info( + f"Starting concurrent data fetching for {len(self._candle_configurations)} candle sets and markets snapshot..." ) - market_snapshot = market_snapshot or {} + tasks = [ + _fetch_candles(config.interval, config.lookback) + for config in self._candle_configurations + ] + tasks.append(_fetch_market_features()) + + # results = [ [candle_features_1], [candle_features_2], ..., [market_features] ] + results = await asyncio.gather(*tasks) + logger.info("Concurrent data fetching complete.") - market_features = self._market_snapshot_computer.build( - market_snapshot, self._request.exchange_config.exchange_id + market_features: List[FeatureVector] = results.pop() + + # Flatten the list of lists of candle features + candle_features: List[FeatureVector] = list( + itertools.chain.from_iterable(results) ) - features.extend(market_features) - return FeaturesPipelineResult(features=features) + candle_features.extend(market_features) + + return FeaturesPipelineResult(features=candle_features) @classmethod def from_request(cls, request: UserRequest) -> DefaultFeaturesPipeline: diff --git a/python/valuecell/agents/common/trading/models.py b/python/valuecell/agents/common/trading/models.py index 6d5900a8d..983123674 100644 --- a/python/valuecell/agents/common/trading/models.py +++ b/python/valuecell/agents/common/trading/models.py @@ -327,6 +327,16 @@ class InstrumentRef(BaseModel): # ) +@dataclass(frozen=True) +class CandleConfig: + """Configuration for a specific candle size and lookback number.""" + + interval: str = Field( + ..., description="the interval of each candle, e.g., '1s', '1m'" + ) + lookback: int = Field(..., gt=0, description="the number of candles to look back") + + class Candle(BaseModel): """Aggregated OHLCV candle for a fixed interval."""