Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions python/valuecell/agents/common/trading/data/market.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import itertools
from collections import defaultdict
from typing import List, Optional

Expand Down Expand Up @@ -56,33 +58,32 @@ def _normalize_symbol(self, symbol: str) -> str:
async def get_recent_candles(
self, symbols: List[str], interval: str, lookback: int
) -> List[Candle]:
async def _fetch(symbol: str, normalized_symbol: str) -> List[List]:
async def _fetch_and_process(symbol: str) -> List[Candle]:
# instantiate exchange class by name (e.g., ccxtpro.kraken)
exchange_cls = get_exchange_cls(self._exchange_id)
exchange = exchange_cls({"newUpdates": False})

symbol_candles: List[Candle] = []
normalized_symbol = self._normalize_symbol(symbol)
try:
# ccxt.pro uses async fetch_ohlcv with normalized symbol
data = await exchange.fetch_ohlcv(
normalized_symbol, timeframe=interval, since=None, limit=lookback
)
return data
finally:
try:
await exchange.close()
except Exception:
pass
# ccxt.pro uses async fetch_ohlcv with normalized symbol
raw = await exchange.fetch_ohlcv(
normalized_symbol,
timeframe=interval,
since=None,
limit=lookback,
)
finally:
try:
await exchange.close()
except Exception:
pass

candles: List[Candle] = []
# Run fetch for each symbol sequentially
for symbol in symbols:
try:
# Normalize symbol format for the exchange (e.g., BTC-USDC -> BTC/USDC:USDC)
normalized_symbol = self._normalize_symbol(symbol)
raw = await _fetch(symbol, normalized_symbol)
# raw is list of [ts, open, high, low, close, volume]
for row in raw:
ts, open_v, high_v, low_v, close_v, vol = row
candles.append(
symbol_candles.append(
Candle(
ts=int(ts),
instrument=InstrumentRef(
Expand All @@ -98,6 +99,7 @@ async def _fetch(symbol: str, normalized_symbol: str) -> List[List]:
interval=interval,
)
)
return symbol_candles
except Exception as exc:
logger.warning(
"Failed to fetch candles for {} (normalized: {}) from {}, data interval is {}, return empty candles. Error: {}",
Expand All @@ -107,6 +109,15 @@ async def _fetch(symbol: str, normalized_symbol: str) -> List[List]:
interval,
exc,
)
return []

# Run fetch for each symbol concurrently
tasks = [_fetch_and_process(symbol) for symbol in symbols]
results = await asyncio.gather(*tasks)

# Flatten the list of lists results into a single list of candles
candles: List[Candle] = list(itertools.chain.from_iterable(results))

logger.debug(
f"Fetch {len(candles)} candles symbols: {symbols}, interval: {interval}, lookback: {lookback}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def __init__(
self._request = request
self._default_slippage_bps = default_slippage_bps
self._quantity_precision = quantity_precision
cfg = self._request.llm_model_config
self._model = model_utils.create_model_with_provider(
provider=cfg.provider,
model_id=cfg.model_id,
api_key=cfg.api_key,
)
self.agent = AgnoAgent(
model=self._model,
output_schema=TradePlanProposal,
markdown=False,
instructions=[SYSTEM_PROMPT],
use_json_mode=model_utils.model_should_use_json_mode(self._model),
debug_mode=env_utils.agent_debug_mode_enabled(),
)

def _build_prompt_text(self) -> str:
"""Return a resolved prompt text by fusing custom_prompt and prompt_text.
Expand Down Expand Up @@ -186,24 +200,7 @@ async def _call_llm(self, prompt: str) -> TradePlanProposal:
agent's `response.content` is returned (or validated) as a
`LlmPlanProposal`.
"""

cfg = self._request.llm_model_config
model = model_utils.create_model_with_provider(
provider=cfg.provider,
model_id=cfg.model_id,
api_key=cfg.api_key,
)

# Wrap model in an Agent (consistent with parser_agent usage)
agent = AgnoAgent(
model=model,
output_schema=TradePlanProposal,
markdown=False,
instructions=[SYSTEM_PROMPT],
use_json_mode=model_utils.model_should_use_json_mode(model),
debug_mode=env_utils.agent_debug_mode_enabled(),
)
response = await agent.arun(prompt)
response = await self.agent.arun(prompt)
# Agent may return a raw object or a wrapper with `.content`.
content = getattr(response, "content", None) or response
logger.debug("Received LLM response {}", content)
Expand All @@ -216,7 +213,7 @@ async def _call_llm(self, prompt: str) -> TradePlanProposal:
items=[],
rationale=(
"LLM output failed validation. The model you chose "
f"`{model_utils.describe_model(model)}` "
f"`{model_utils.describe_model(self._model)}` "
"may be incompatible or returned unexpected output. "
f"Raw output: {content}"
),
Expand Down
83 changes: 50 additions & 33 deletions python/valuecell/agents/common/trading/features/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@

from __future__ import annotations

from typing import List
import asyncio
import itertools
from typing import List, Optional

from loguru import logger

from valuecell.agents.common.trading.models import (
CandleConfig,
FeaturesPipelineResult,
FeatureVector,
UserRequest,
Expand All @@ -36,53 +41,65 @@ def __init__(
market_data_source: BaseMarketDataSource,
candle_feature_computer: CandleBasedFeatureComputer,
market_snapshot_computer: MarketSnapshotFeatureComputer,
micro_interval: str = "1s",
micro_lookback: int = 60 * 3,
medium_interval: str = "1m",
medium_lookback: int = 60 * 4,
candle_configurations: Optional[List[CandleConfig]] = None,
) -> None:
self._request = request
self._market_data_source = market_data_source
self._candle_feature_computer = candle_feature_computer
self._micro_interval = micro_interval
self._micro_lookback = micro_lookback
self._medium_interval = medium_interval
self._medium_lookback = medium_lookback
self._symbols = list(dict.fromkeys(request.trading_config.symbols))
self._market_snapshot_computer = market_snapshot_computer
self._candle_configurations = candle_configurations
self._candle_configurations = candle_configurations or [
CandleConfig(interval="1s", lookback=60 * 3),
CandleConfig(interval="1m", lookback=60 * 4),
]

async def build(self) -> FeaturesPipelineResult:
"""Fetch candles, compute feature vectors, and append market features."""
# Determine symbols from the configured request so caller doesn't pass them
candles_micro = await self._market_data_source.get_recent_candles(
self._symbols, self._micro_interval, self._micro_lookback
)
micro_features = self._candle_feature_computer.compute_features(
candles=candles_micro
)
"""
Fetch candles and market snapshot, compute feature vectors concurrently,
and combine results.
"""

candles_medium = await self._market_data_source.get_recent_candles(
self._symbols, self._medium_interval, self._medium_lookback
)
medium_features = self._candle_feature_computer.compute_features(
candles=candles_medium
)
async def _fetch_candles(interval: str, lookback: int) -> List[FeatureVector]:
"""Fetches candles and computes features for a single (interval, lookback) pair."""
_candles = await self._market_data_source.get_recent_candles(
self._symbols, interval, lookback
)
return self._candle_feature_computer.compute_features(candles=_candles)

features: List[FeatureVector] = []
features.extend(medium_features or [])
features.extend(micro_features or [])
async def _fetch_market_features() -> List[FeatureVector]:
"""Fetches market snapshot for all symbols and computes features."""
market_snapshot = await self._market_data_source.get_market_snapshot(
self._symbols
)
market_snapshot = market_snapshot or {}
return self._market_snapshot_computer.build(
market_snapshot, self._request.exchange_config.exchange_id
)

market_snapshot = await self._market_data_source.get_market_snapshot(
self._symbols
logger.info(
f"Starting concurrent data fetching for {len(self._candle_configurations)} candle sets and markets snapshot..."
)
market_snapshot = market_snapshot or {}
tasks = [
_fetch_candles(config.interval, config.lookback)
for config in self._candle_configurations
]
tasks.append(_fetch_market_features())

# results = [ [candle_features_1], [candle_features_2], ..., [market_features] ]
results = await asyncio.gather(*tasks)
logger.info("Concurrent data fetching complete.")

market_features = self._market_snapshot_computer.build(
market_snapshot, self._request.exchange_config.exchange_id
market_features: List[FeatureVector] = results.pop()

# Flatten the list of lists of candle features
candle_features: List[FeatureVector] = list(
itertools.chain.from_iterable(results)
)
features.extend(market_features)

return FeaturesPipelineResult(features=features)
candle_features.extend(market_features)

return FeaturesPipelineResult(features=candle_features)

@classmethod
def from_request(cls, request: UserRequest) -> DefaultFeaturesPipeline:
Expand Down
10 changes: 10 additions & 0 deletions python/valuecell/agents/common/trading/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ class InstrumentRef(BaseModel):
# )


@dataclass(frozen=True)
class CandleConfig:
    """Configuration for a specific candle size and lookback number.

    Attributes:
        interval: Interval of each candle, e.g. ``"1s"``, ``"1m"``; passed
            straight through to the market data source.
        lookback: Number of candles to look back; must be strictly positive.

    Raises:
        ValueError: If ``lookback`` is not greater than zero.
    """

    interval: str
    lookback: int

    def __post_init__(self) -> None:
        # A plain stdlib dataclass does not run pydantic validation, so using
        # pydantic.Field(..., gt=0) as a field default was silently ignored
        # (it also made the fields optional, with FieldInfo objects as the
        # accidental defaults). Enforce the intended constraint explicitly.
        if self.lookback <= 0:
            raise ValueError(f"lookback must be > 0, got {self.lookback}")


class Candle(BaseModel):
"""Aggregated OHLCV candle for a fixed interval."""

Expand Down