diff --git a/TASKS.md b/TASKS.md index 931decde..a4dbf5dd 100644 --- a/TASKS.md +++ b/TASKS.md @@ -78,6 +78,8 @@ This document consolidates all implementation tasks from the architecture and de | **TB-005** | Observability | HIGH | 4h | Implement MCP metrics (request count, latency, errors) | | **TB-006** | Security | HIGH | 3h | Add API key expiration/rotation mechanism | +> **Note (2026-03-20)**: Paper session `initial_capital` changed to $100k — affects new auto-created sessions only (existing sessions are unaffected). + **Total Blocking Effort**: ~18 hours (2-3 days) ### ⚠️ CRITICAL FOR BETA LAUNCH diff --git a/cmd/api/handlers_trading.go b/cmd/api/handlers_trading.go index 8d9dff33..75fa2516 100644 --- a/cmd/api/handlers_trading.go +++ b/cmd/api/handlers_trading.go @@ -1,17 +1,33 @@ package main import ( + "errors" + "fmt" "net/http" "strings" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/rs/zerolog/log" "github.com/ajitpratap0/cryptofunk/internal/db" ) +const ( + paperSlippageBuy = 1.001 // 0.1% adverse slippage for market buy orders + paperSlippageSell = 0.999 // 0.1% adverse slippage for market sell orders + // TODO: make slippage configurable via config.Trading. +) + +// errOppositeSide is a sentinel returned from the WithTx callback when the +// incoming order is on the opposite side of an existing open position. It is +// handled by the caller to produce a 422 response without logging as an +// internal server error. +var errOppositeSide = errors.New("opposite side trade on existing position") + // Session handlers func (s *APIServer) handleListSessions(c *gin.Context) { ctx := c.Request.Context() @@ -378,6 +394,20 @@ func (s *APIServer) handleCancelOrder(c *gin.Context) { }) } +// quoteAsset derives the quote asset token from a trading symbol by checking +// common suffixes. Falls back to "USDT" for unrecognised symbols. +func quoteAsset(symbol string) string { + // Ordered most-specific first: "BUSD" before "BTC" prevents "BTCUSDT" from + // matching suffix "BTC" when "USDT" would be the correct quote asset. + // TODO: make configurable for non-Binance exchanges (e.g. Kraken uses XBT/USD). + for _, suffix := range []string{"USDT", "BUSD", "BTC", "ETH", "BNB"} { + if strings.HasSuffix(strings.ToUpper(symbol), suffix) { + return suffix + } + } + return "USDT" +} + // handlePaperTrade executes a paper (simulated) trade order. // Market orders are immediately filled; limit orders remain open (NEW status). // POST /api/v1/trade @@ -389,63 +419,278 @@ func (s *APIServer) handlePaperTrade(c *gin.Context) { Quantity float64 `json:"quantity" binding:"required,gt=0"` Price float64 `json:"price"` } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid request body", - "details": err.Error(), - }) + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body", "details": err.Error()}) + return + } + isLimit := strings.EqualFold(req.Type, "limit") + if isLimit && req.Price <= 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "price is required for limit orders"}) return } - if (req.Type == "limit" || req.Type == "LIMIT") && req.Price <= 0 { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "price is required for limit orders", - }) + ctx := c.Request.Context() + + // NOTE: Session lookup and creation happen outside WithTx intentionally. + // The session row is long-lived (one per trading mode) and is safe to create + // outside the fill transaction. A failed fill transaction does not orphan the + // session — the next request simply reuses it. + + // 1. Resolve or create paper session + sessions, err := s.db.ListActiveSessions(ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to list active sessions for paper trade") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to resolve trading session"}) return } + var sessionID *uuid.UUID + for i := range sessions { + if sessions[i].Mode == db.TradingModePaper { + id := sessions[i].ID + sessionID = &id + break + } + } + if sessionID == nil { + newSession := &db.TradingSession{ + ID: uuid.New(), + Mode: db.TradingModePaper, + // TODO: Add a session_type or is_multi_asset column to trading_sessions in a follow-up + // migration. "PAPER" is a placeholder to distinguish multi-asset paper sessions from + // single-symbol sessions (which use the actual symbol, e.g. "BTCUSDT"). + Symbol: "PAPER", + Exchange: "paper", + InitialCapital: s.config.Trading.InitialCapital, + StartedAt: time.Now(), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + if err := s.db.CreateSession(ctx, newSession); err != nil { + // Only retry on unique constraint violation (PG error code 23505). + // A concurrent request may have inserted the same paper session between the + // ListActiveSessions call above and this insert (TOCTOU race). In that case + // we look up and reuse the existing session. + // All other errors (e.g. connection timeout) are returned immediately. + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + log.Warn().Err(err).Msg("Unique constraint on paper session; retrying lookup for concurrent session") + sessions2, err2 := s.db.ListActiveSessions(ctx) + if err2 == nil { + for i := range sessions2 { + if sessions2[i].Mode == db.TradingModePaper { + id := sessions2[i].ID + sessionID = &id + break + } + } + } + if sessionID == nil { + log.Error().Err(err).Msg("Failed to find paper session after unique constraint violation") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create trading session"}) + return + } + } else { + log.Error().Err(err).Msg("Failed to create paper session") + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "internal server error", + }) + return + } + } else { + sessionID = &newSession.ID + } + } + + // 2. Determine execution price + refPrice := req.Price + if !isLimit && refPrice <= 0 { + refPrice = s.exchange.GetMarketPrice(req.Symbol) + if refPrice <= 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "no market price configured for symbol; provide a price field", + }) + return + } + } + execPrice := refPrice + if !isLimit { + if strings.EqualFold(req.Side, "BUY") { + execPrice = refPrice * paperSlippageBuy + } else { + execPrice = refPrice * paperSlippageSell + } + } - var price *float64 + // 3. Build order struct (inserted inside the transaction below) + now := time.Now() + var pricePtr *float64 if req.Price > 0 { - price = &req.Price + pricePtr = db.PtrFloat64(req.Price) } + orderSide := db.ConvertOrderSide(req.Side) + orderType := db.ConvertOrderType(req.Type) order := &db.Order{ ID: uuid.New(), + SessionID: sessionID, Symbol: req.Symbol, Exchange: "paper", - Side: db.ConvertOrderSide(req.Side), - Type: db.ConvertOrderType(req.Type), + Side: orderSide, + Type: orderType, Quantity: req.Quantity, - Price: price, + Price: pricePtr, Status: db.OrderStatusNew, - PlacedAt: time.Now(), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + PlacedAt: now, + CreatedAt: now, + UpdatedAt: now, } - ctx := c.Request.Context() - if err := s.db.InsertOrder(ctx, order); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "failed to create paper trade order", - }) - return - } - - // Simulate immediate fill for market orders + // 4. Immediate fill for market orders if order.Type == db.OrderTypeMarket { - now := time.Now() - execPrice := req.Price execQuoteQty := execPrice * req.Quantity + // Use the configured paper trading commission rate (taker fee from trading config). + // Falls back to 0.001 (0.1%) if not configured, matching Binance standard tier. + commissionRate := s.config.Trading.CommissionRate + if commissionRate <= 0 { + commissionRate = 0.001 + } + commission := execQuoteQty * commissionRate - if err := s.db.UpdateOrderStatus(ctx, order.ID, db.OrderStatusFilled, req.Quantity, execQuoteQty, &now, nil, nil); err != nil { - log.Warn().Err(err).Str("order_id", order.ID.String()).Msg("Failed to mark paper trade order as filled") - } else { + posSide := db.PositionSideLong + if orderSide == db.OrderSideSell { + posSide = db.PositionSideShort + } + + // Derive the quote asset from the symbol (e.g. "BTCUSDT" → "USDT"). + commissionAsset := quoteAsset(req.Symbol) + + // Wrap all fill writes in a single DB transaction so a mid-flight failure + // does not leave orphaned rows. The order insert is the first step inside + // the transaction so no orphaned order rows can result from a failed fill. + // The existingPos lookup is also inside the transaction to eliminate the + // TOCTOU race where two concurrent BUY orders could both observe + // existingPos == nil and each try to INSERT a new position for the same symbol. + // AggregateSessionStats is intentionally kept outside the transaction: it + // is a read-then-aggregate UPDATE that can be safely retried. + // RepeatableRead + FOR UPDATE on positions prevents concurrent orders from + // both observing existingPos == nil and each inserting a duplicate position. + // Final safety net: migration 019 adds a UNIQUE partial index on + // (session_id, symbol) WHERE exit_time IS NULL. If two concurrent transactions + // both observe existingPos == nil and both attempt to INSERT a new open + // position, the second INSERT fails with SQLSTATE 23505 (unique_violation), + // causing its enclosing transaction to roll back rather than silently creating + // a duplicate open position for the same symbol. + txErr := s.db.WithTx(ctx, pgx.TxOptions{IsoLevel: pgx.RepeatableRead}, func(tx pgx.Tx) error { + // Insert the order as the first step so it is rolled back atomically + // with all fill rows if any later step fails. + if err := s.db.InsertOrderTx(ctx, tx, order); err != nil { + return fmt.Errorf("failed to insert paper trade order: %w", err) + } + + // Re-fetch position inside the transaction for a consistent view. + existingPos, err := s.db.GetOpenPositionBySymbolTx(ctx, tx, *sessionID, req.Symbol) + if err != nil { + return fmt.Errorf("failed to look up existing position: %w", err) + } + + if existingPos != nil && existingPos.Side != posSide { + // Opposite-side trade on an existing open position. Proper close/reduce + // logic (netting, realized PnL calculation) is not yet implemented. + // Return a sentinel so the outer handler can respond with 422. + return errOppositeSide + } + + // UpdateOrderStatus inside transaction + if err := s.db.UpdateOrderStatusTx(ctx, tx, order.ID, db.OrderStatusFilled, req.Quantity, execQuoteQty, &now, nil, nil); err != nil { + return fmt.Errorf("failed to mark paper order filled: %w", err) + } order.Status = db.OrderStatusFilled order.ExecutedQuantity = req.Quantity order.ExecutedQuoteQuantity = execQuoteQty order.FilledAt = &now order.UpdatedAt = now + + // InsertTrade inside transaction via DB-layer method. + trade := &db.Trade{ + ID: uuid.New(), + OrderID: order.ID, + ExchangeTradeID: nil, + Symbol: req.Symbol, + Exchange: "paper", + Side: orderSide, + Price: execPrice, + Quantity: req.Quantity, + QuoteQuantity: execQuoteQty, + Commission: commission, + CommissionAsset: &commissionAsset, + ExecutedAt: now, + IsMaker: false, + Metadata: nil, + CreatedAt: now, + } + if err := s.db.InsertTradeTx(ctx, tx, trade); err != nil { + return fmt.Errorf("failed to insert paper trade fill row: %w", err) + } + + // Create or average into existing position inside transaction via DB-layer methods. + if existingPos == nil { + entryReason := "paper_trade_api" + pos := &db.Position{ + ID: uuid.New(), + SessionID: sessionID, + Symbol: req.Symbol, + Exchange: "paper", + Side: posSide, + EntryPrice: execPrice, + Quantity: req.Quantity, + EntryTime: now, + EntryReason: &entryReason, + CreatedAt: now, + UpdatedAt: now, + } + if err := s.db.CreatePositionTx(ctx, tx, pos); err != nil { + return fmt.Errorf("failed to create position for paper trade: %w", err) + } + } else { + totalQty := existingPos.Quantity + req.Quantity + weightedAvg := (existingPos.Quantity*existingPos.EntryPrice + req.Quantity*execPrice) / totalQty + if err := s.db.UpdatePositionAveragingTx(ctx, tx, existingPos.ID, weightedAvg, totalQty, commission); err != nil { + return fmt.Errorf("failed to update position for paper trade: %w", err) + } + } + return nil + }) + + if txErr != nil { + if errors.Is(txErr, errOppositeSide) { + // Opposite-side trade on an existing open position. Proper close/reduce + // logic (netting, realized PnL calculation) is not yet implemented. + // Reject the trade rather than silently corrupting position data. + log.Warn(). + Str("symbol", req.Symbol). + Str("order_side", string(posSide)). + Msg("Opposite-side trade on existing position; position close logic not yet implemented") + c.JSON(http.StatusUnprocessableEntity, gin.H{ + "error": "position close/reduce not yet implemented; opposite-side trade rejected", + }) + return + } + log.Error().Err(txErr).Msg("Paper trade transaction failed") + c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) + return + } + + // AggregateSessionStats is outside the transaction: it is a safe read-aggregate + // UPDATE that can be retried without risk of partial data corruption. + if err := s.db.AggregateSessionStats(ctx, *sessionID); err != nil { + log.Warn().Err(err).Msg("Failed to aggregate session stats after paper trade") + } + } else { + // Limit orders are not immediately filled; persist the order record in NEW status. + if err := s.db.InsertOrder(ctx, order); err != nil { + log.Error().Err(err).Msg("Failed to insert paper trade limit order") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create paper trade order"}) + return } } diff --git a/cmd/api/main.go b/cmd/api/main.go index 1eddbe00..bb269309 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -18,6 +18,7 @@ import ( "github.com/ajitpratap0/cryptofunk/internal/api" "github.com/ajitpratap0/cryptofunk/internal/config" "github.com/ajitpratap0/cryptofunk/internal/db" + "github.com/ajitpratap0/cryptofunk/internal/exchange" "github.com/ajitpratap0/cryptofunk/internal/safety" ) @@ -43,6 +44,7 @@ type APIServer struct { orderExecSession *mcp.ClientSession // MCP session for order-executor calls mcpClient *mcp.Client // MCP client for creating/reconnecting sessions activeSessionID *uuid.UUID // Currently active trading session ID (guarded by sessionMu) + exchange exchange.Exchange // Shared mock exchange instance for paper trading } // HTTP client for orchestrator communication with timeout and connection pooling @@ -122,6 +124,7 @@ func main() { ctx: ctx, safetyGuard: safetyGuard, orderExecutorURL: getOrderExecutorURL(), + exchange: exchange.NewMockExchange(database), } // Initialize MCP client for order-executor (session connects lazily on first order) diff --git a/cmd/api/paper_trade_pipeline_test.go b/cmd/api/paper_trade_pipeline_test.go new file mode 100644 index 00000000..2e9e10e8 --- /dev/null +++ b/cmd/api/paper_trade_pipeline_test.go @@ -0,0 +1,109 @@ +//go:build integration + +package main + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ajitpratap0/cryptofunk/internal/exchange" +) + +func TestPaperTrade_PersistsAllRows(t *testing.T) { + t.Run("with_explicit_price", func(t *testing.T) { + srv, _ := setupTestAPIServer(t) + // Provide a mock exchange so s.exchange is never nil. + srv.exchange = exchange.NewMockExchange(srv.db) + + body := `{"symbol":"BTCUSDT","side":"BUY","type":"market","quantity":0.1,"price":45000}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/trade", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.router.ServeHTTP(w, req) + + require.Equal(t, http.StatusCreated, w.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + orderID := resp["order"].(map[string]interface{})["id"].(string) + + ctx := context.Background() + + // Verify order row is filled with non-zero executed_quote_quantity + order, err := srv.db.GetOrder(ctx, uuid.MustParse(orderID)) + require.NoError(t, err) + assert.Equal(t, "FILLED", string(order.Status)) + assert.Greater(t, order.ExecutedQuantity, 0.0) + assert.Greater(t, order.ExecutedQuoteQuantity, 0.0, "executed_quote_quantity must be non-zero (price bug)") + + // Verify trade fill row via GetTradesByOrderID (already exists in orders.go) + fills, err := srv.db.GetTradesByOrderID(ctx, uuid.MustParse(orderID)) + require.NoError(t, err) + assert.NotEmpty(t, fills, "expected at least one trade fill row in trades table") + assert.Greater(t, fills[0].Price, 0.0, "fill price must be > 0") + + // Verify open position exists + positions, err := srv.db.GetAllOpenPositions(ctx) + require.NoError(t, err) + found := false + for _, p := range positions { + if p.Symbol == "BTCUSDT" { + found = true + assert.Greater(t, p.EntryPrice, 0.0) + assert.Greater(t, p.Quantity, 0.0) + } + } + assert.True(t, found, "expected BTCUSDT open position after paper trade") + }) + + t.Run("uses_get_market_price_when_no_price_in_request", func(t *testing.T) { + srv, _ := setupTestAPIServer(t) + // Seed a market price so GetMarketPrice returns a non-zero value. + mockEx := exchange.NewMockExchange(srv.db) + mockEx.SetMarketPrice("ETHUSDT", 3000.0) + srv.exchange = mockEx + + body := `{"symbol":"ETHUSDT","side":"BUY","type":"market","quantity":0.05}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/trade", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.router.ServeHTTP(w, req) + + require.Equal(t, http.StatusCreated, w.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + orderID := resp["order"].(map[string]interface{})["id"].(string) + + ctx := context.Background() + + // Verify trade fill row exists and has a positive price derived from GetMarketPrice. + fills, err := srv.db.GetTradesByOrderID(ctx, uuid.MustParse(orderID)) + require.NoError(t, err) + assert.NotEmpty(t, fills, "expected at least one trade fill row") + assert.Greater(t, fills[0].Price, 0.0, "fill price must be > 0 when price comes from GetMarketPrice") + }) + + t.Run("returns_400_when_no_price_and_no_market_price_seeded", func(t *testing.T) { + srv, _ := setupTestAPIServer(t) + // Use a fresh mock exchange with no price seeded for SOLUSDT. + srv.exchange = exchange.NewMockExchange(srv.db) + + body := `{"symbol":"SOLUSDT","side":"BUY","type":"market","quantity":0.05}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/trade", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + srv.router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Contains(t, resp["error"], "no market price configured for symbol") + }) +} diff --git a/cmd/api/routes.go b/cmd/api/routes.go index 5ac451d2..9eb34799 100644 --- a/cmd/api/routes.go +++ b/cmd/api/routes.go @@ -236,6 +236,10 @@ func (s *APIServer) setupRoutes() { dashboardHandler := api.NewDashboardHandlerWithOrchestrator(s.db, orchClient, config.Version) dashboardHandler.RegisterRoutesWithRateLimiter(v1, s.rateLimiter.ReadMiddleware(), s.rateLimiter.ControlMiddleware()) + // Trades routes — fill records from the trades table + tradesHandler := api.NewTradesHandler(s.db) + tradesHandler.RegisterRoutes(v1, s.rateLimiter.ReadMiddleware(), api.AuthMiddleware(s.apiKeyStore, authConfig)) + // TB-006: API Key Management routes // These endpoints allow users to manage their API keys (create, rotate, revoke) // All key management operations require authentication @@ -260,6 +264,14 @@ func (s *APIServer) setupRoutes() { polymarketHandler := api.NewPolymarketHandler(s.db) polymarketHandler.RegisterRoutesWithRateLimiter(v1, s.rateLimiter.ReadMiddleware(), s.rateLimiter.OrderMiddleware()) + // Risk metrics routes + riskHandler := api.NewRiskHandler(s.db, &s.config.Risk) + riskHandler.RegisterRoutes(v1, s.rateLimiter.ReadMiddleware(), api.AuthMiddleware(s.apiKeyStore, authConfig)) + + // Performance routes + perfHandler := api.NewPerformanceHandler(s.db) + perfHandler.RegisterRoutes(v1, s.rateLimiter.ReadMiddleware(), api.AuthMiddleware(s.apiKeyStore, authConfig)) + // Decision analytics and outcome resolution routes decisionAnalyticsHandler := api.NewDecisionAnalyticsHandler(s.db) decisionAnalyticsHandler.RegisterRoutes(v1, s.rateLimiter.ReadMiddleware(), api.AuthMiddleware(s.apiKeyStore, authConfig)) diff --git a/configs/config.yaml b/configs/config.yaml index 9b004124..89377045 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -150,9 +150,13 @@ trading: - "BTCUSDT" - "ETHUSDT" exchange: "binance" - initial_capital: 10000.0 + # initial_capital: starting capital for auto-created paper trading sessions. + # Changed from 10000.0 → 100000.0 (2026-03-20) to reflect realistic position sizing. + # Existing sessions are unaffected; only newly auto-created sessions use this value. + initial_capital: 100000.0 max_positions: 3 default_quantity: 0.01 + commission_rate: 0.001 # 0.1% taker fee applied to paper trade fills (Binance standard tier) risk: max_position_size: 0.1 # 10% of portfolio @@ -163,6 +167,11 @@ risk: llm_approval_required: true # Require LLM approval for trades min_confidence: 0.7 # Minimum confidence for signals + # Dashboard circuit breaker thresholds (absolute units, used by /risk/circuit-breakers) + max_daily_loss_dollars: 5000.0 # Trigger when cumulative session loss exceeds $5000 + max_drawdown_pct: 10.0 # Trigger when max session drawdown exceeds 10% + max_trade_count: 100 # Trigger when total trades across active sessions exceeds 100 + # Safety Guard Configuration # Live trading protection mechanisms to prevent excessive losses # These guards are enforced at the order execution level diff --git a/internal/api/decisions_unit_test.go b/internal/api/decisions_unit_test.go new file mode 100644 index 00000000..92e838e4 --- /dev/null +++ b/internal/api/decisions_unit_test.go @@ -0,0 +1,14 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestItoa(t *testing.T) { + assert.Equal(t, "0", itoa(0)) + assert.Equal(t, "42", itoa(42)) + assert.Equal(t, "-1", itoa(-1)) + assert.Equal(t, "100", itoa(100)) +} diff --git a/internal/api/performance.go b/internal/api/performance.go new file mode 100644 index 00000000..a971c60b --- /dev/null +++ b/internal/api/performance.go @@ -0,0 +1,60 @@ +package api + +import ( + "math" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/ajitpratap0/cryptofunk/internal/db" +) + +// PerformanceHandler provides REST endpoints for per-symbol performance metrics. +type PerformanceHandler struct { + db *db.DB +} + +// NewPerformanceHandler creates a new PerformanceHandler backed by the given database. +func NewPerformanceHandler(database *db.DB) *PerformanceHandler { + return &PerformanceHandler{db: database} +} + +// RegisterRoutes mounts the /performance sub-group under the provided router group. +// If authMiddleware is non-nil it is applied to all routes alongside readMiddleware. +func (h *PerformanceHandler) RegisterRoutes(rg *gin.RouterGroup, readMiddleware gin.HandlerFunc, authMiddleware gin.HandlerFunc) { + g := rg.Group("/performance") + if authMiddleware != nil { + g.GET("/pairs", readMiddleware, authMiddleware, h.GetPairPerformance) + } else { + g.Use(readMiddleware) + g.GET("/pairs", h.GetPairPerformance) + } +} + +// GetPairPerformance returns realized PnL aggregated by trading pair across all active sessions. +// GET /api/v1/performance/pairs +func (h *PerformanceHandler) GetPairPerformance(c *gin.Context) { + ctx := c.Request.Context() + + rows, err := h.db.GetPairPerformance(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to aggregate pair performance"}) + return + } + + type pairPerf struct { + Symbol string `json:"symbol"` + RealizedPnL float64 `json:"realized_pnl"` + TradeCount int `json:"trade_count"` + } + pairs := make([]pairPerf, 0, len(rows)) + for _, r := range rows { + pairs = append(pairs, pairPerf{ + Symbol: r.Symbol, + RealizedPnL: math.Round(r.RealizedPnL*100) / 100, + TradeCount: r.TradeCount, + }) + } + + c.JSON(http.StatusOK, gin.H{"pairs": pairs, "count": len(pairs)}) +} diff --git a/internal/api/risk.go b/internal/api/risk.go new file mode 100644 index 00000000..46ad190b --- /dev/null +++ b/internal/api/risk.go @@ -0,0 +1,252 @@ +package api + +import ( + "context" + "math" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" + + "github.com/ajitpratap0/cryptofunk/internal/config" + "github.com/ajitpratap0/cryptofunk/internal/db" + "github.com/ajitpratap0/cryptofunk/internal/risk" +) + +// RiskHandler provides REST endpoints for risk metrics and circuit breaker status. +type RiskHandler struct { + db *db.DB + riskService *risk.Service + cfg *config.RiskConfig +} + +// NewRiskHandler creates a new RiskHandler backed by the given database and risk config. +// If cfg is nil a zero-value RiskConfig is used, which causes setDefaults values to be +// applied at load time and therefore cfg should always be non-nil in production. +func NewRiskHandler(database *db.DB, cfg *config.RiskConfig) *RiskHandler { + if cfg == nil { + cfg = &config.RiskConfig{} + } + return &RiskHandler{ + db: database, + riskService: risk.NewService(), + cfg: cfg, + } +} + +// RegisterRoutes mounts the /risk sub-group under the provided router group. +// If authMiddleware is non-nil it is applied to all routes alongside readMiddleware. +func (h *RiskHandler) RegisterRoutes(rg *gin.RouterGroup, readMiddleware gin.HandlerFunc, authMiddleware gin.HandlerFunc) { + r := rg.Group("/risk") + if authMiddleware != nil { + r.GET("/metrics", readMiddleware, authMiddleware, h.GetMetrics) + r.GET("/circuit-breakers", readMiddleware, authMiddleware, h.GetCircuitBreakers) + r.GET("/exposure", readMiddleware, authMiddleware, h.GetExposure) + } else { + r.Use(readMiddleware) + r.GET("/metrics", h.GetMetrics) + r.GET("/circuit-breakers", h.GetCircuitBreakers) + r.GET("/exposure", h.GetExposure) + } +} + +// GetMetrics returns VaR, CVaR, open position count, and total exposure. +// GET /api/v1/risk/metrics +func (h *RiskHandler) GetMetrics(c *gin.Context) { + ctx := c.Request.Context() + + openPositions, err := h.db.GetAllOpenPositions(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to query open positions"}) + return + } + + openCount := len(openPositions) + // NOTE: exposure is calculated at cost-basis (entry_price), not mark-to-market. + // Current market price is not stored on the position; a live price lookup + // would be needed for accurate mark-to-market exposure. + var totalExposure float64 + for _, p := range openPositions { + totalExposure += p.Quantity * p.EntryPrice + } + + returns, err := h.collectClosedReturns(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to collect returns"}) + return + } + + dataPoints := len(returns) + response := gin.H{ + "open_positions": openCount, + "total_exposure": math.Round(totalExposure*100) / 100, + "data_points": dataPoints, + "var_95": nil, + "var_99": nil, + "expected_shortfall": nil, + } + if dataPoints < 10 { + log.Debug().Int("data_points", dataPoints).Msg("insufficient data for meaningful VaR estimate; need at least 10 closed positions in the last 90 days") + } + + if dataPoints >= 10 { + // CalculateVaR requires []interface{} not []float64 + returnsIface := make([]interface{}, dataPoints) + for i, v := range returns { + returnsIface[i] = v + } + + // Sum InitialCapital across all active sessions to get total portfolio value. + // Used to convert fractional VaR (e.g. 0.023) into dollar VaR (e.g. $2,300). + // NOTE: scaling by InitialCapital (not CurrentCapital) because TradingSession + // does not yet track current portfolio value. Dollar VaR will be inaccurate + // after significant P&L. Track in TASKS.md follow-up. + activeSessions, sessErr := h.db.ListActiveSessions(ctx) + portfolioValue := 0.0 + if sessErr == nil { + for _, s := range activeSessions { + portfolioValue += s.InitialCapital + } + } else { + log.Warn().Err(sessErr).Msg("ListActiveSessions failed; skipping VaR calculation") + } + + // Skip VaR entirely when there is no portfolio to compute it for. + // var_95, var_99, and expected_shortfall remain nil in the response. + if portfolioValue > 0 { + res95, err := h.riskService.CalculateVaR(map[string]interface{}{ + "returns": returnsIface, + "confidence_level": 0.95, + }) + if err != nil { + log.Debug().Err(err).Msg("VaR calculation failed (95%)") + } else { + if varResult, ok := res95.(*risk.VaRResult); ok { + response["var_95"] = varResult.VaR * portfolioValue + } + } + + res99, err := h.riskService.CalculateVaR(map[string]interface{}{ + "returns": returnsIface, + "confidence_level": 0.99, + }) + if err != nil { + log.Debug().Err(err).Msg("VaR calculation failed (99%)") + } else { + if varResult, ok := res99.(*risk.VaRResult); ok { + response["var_99"] = varResult.VaR * portfolioValue + response["expected_shortfall"] = varResult.CVaR * portfolioValue + } + } + } + } + + c.JSON(http.StatusOK, response) +} + +// GetCircuitBreakers returns the status of system-level circuit breakers. +// GET /api/v1/risk/circuit-breakers +func (h *RiskHandler) GetCircuitBreakers(c *gin.Context) { + ctx := c.Request.Context() + + sessions, err := h.db.ListActiveSessions(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list sessions"}) + return + } + + var totalPnL, maxDrawdown float64 + var totalTrades int + for _, s := range sessions { + totalPnL += s.TotalPnL + if s.MaxDrawdown > maxDrawdown { + maxDrawdown = s.MaxDrawdown + } + totalTrades += s.TotalTrades + } + + // Session Total Loss: triggered when cumulative session losses exceed the configured threshold (dollars). + // NOTE: totalPnL is the sum of TotalPnL across all active sessions (lifetime session PnL), + // not a rolling daily figure. The threshold field MaxDailyLossDollars still applies as the + // absolute dollar loss limit; only the label has been corrected to avoid confusion. + lossAmount := math.Abs(math.Min(totalPnL, 0)) + breakers := []gin.H{ + buildBreaker("Session Total Loss", lossAmount, h.cfg.MaxDailyLossDollars), + buildBreaker("Max Drawdown %", maxDrawdown*100, h.cfg.MaxDrawdownPct), + buildBreaker("Total Trade Count", float64(totalTrades), float64(h.cfg.MaxTradeCount)), + } + + c.JSON(http.StatusOK, gin.H{"circuit_breakers": breakers, "count": len(breakers)}) +} + +func buildBreaker(name string, current, threshold float64) gin.H { + status := "OK" + if threshold <= 0 { + // A zero or negative threshold means the breaker is disabled/unconfigured. + // Guard against false positives: never fire for unset config values. + status = "DISABLED" + } else if current >= threshold { + status = "TRIGGERED" + } else if current/threshold >= 0.8 { + status = "WARNING" + } + return gin.H{ + "name": name, + "current": math.Round(current*100) / 100, + "threshold": threshold, + "status": status, + } +} + +// GetExposure returns open position exposure grouped by symbol. +// GET /api/v1/risk/exposure +func (h *RiskHandler) GetExposure(c *gin.Context) { + ctx := c.Request.Context() + + // NOTE: exposure is calculated at cost-basis (entry_price), not mark-to-market. + // Current market price is not stored on the position; a live price lookup + // would be needed for accurate mark-to-market exposure. + rows, err := h.db.GetExposureBySymbol(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to query exposure by symbol"}) + return + } + + type symbolExposureJSON struct { + Symbol string `json:"symbol"` + Exposure float64 `json:"exposure"` + } + result := make([]symbolExposureJSON, 0, len(rows)) + for _, r := range rows { + result = append(result, symbolExposureJSON{ + Symbol: r.Symbol, + Exposure: math.Round(r.Exposure*100) / 100, + }) + } + + c.JSON(http.StatusOK, gin.H{"exposure": result, "count": len(result)}) +} + +// collectClosedReturns gathers fractional returns from all closed positions. +// Each return is RealizedPnL / (EntryPrice * Quantity) so values are dimensionless +// fractions (e.g. 0.023 = 2.3%) suitable for VaR calculations. +func (h *RiskHandler) collectClosedReturns(ctx context.Context) ([]float64, error) { + positions, err := h.db.GetAllClosedPositions(ctx) + if err != nil { + return nil, err + } + + var returns []float64 + for _, p := range positions { + // RealizedPnL is non-nil here: GetAllClosedPositions filters realized_pnl IS NOT NULL. + // This guard is retained as defense-in-depth against future query changes. + if p.RealizedPnL != nil { + notional := p.EntryPrice * p.Quantity + if notional == 0 { + continue + } + returns = append(returns, *p.RealizedPnL/notional) + } + } + return returns, nil +} diff --git a/internal/api/risk_unit_test.go b/internal/api/risk_unit_test.go new file mode 100644 index 00000000..3c0c484a --- /dev/null +++ b/internal/api/risk_unit_test.go @@ -0,0 +1,215 @@ +package api + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/ajitpratap0/cryptofunk/internal/config" +) + +// TestParseDuration verifies that parseDuration converts human-friendly duration +// strings into the correct past time.Time value relative to now. +func TestParseDuration(t *testing.T) { + tests := []struct { + name string + input string + checkApprox func(t *testing.T, got time.Time) + }{ + { + name: "1h subtracts one hour", + input: "1h", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().Add(-1 * time.Hour) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected time ~1h ago") + }, + }, + { + name: "24h subtracts 24 hours", + input: "24h", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().Add(-24 * time.Hour) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected time ~24h ago") + }, + }, + { + name: "7d subtracts 7 days", + input: "7d", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().AddDate(0, 0, -7) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected time ~7 days ago") + }, + }, + { + name: "30d subtracts 30 days", + input: "30d", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().AddDate(0, 0, -30) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected time ~30 days ago") + }, + }, + { + name: "3m subtracts 3 months", + input: "3m", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().AddDate(0, -3, 0) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected time ~3 months ago") + }, + }, + { + name: "1y subtracts 1 year", + input: "1y", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().AddDate(-1, 0, 0) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected time ~1 year ago") + }, + }, + { + name: "invalid string falls back to 1 month ago", + input: "notvalid", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().AddDate(0, -1, 0) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected default fallback ~1 month ago") + }, + }, + { + name: "empty string falls back to 1 month ago", + input: "", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().AddDate(0, -1, 0) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected default fallback ~1 month ago") + }, + }, + { + name: "unknown unit falls back to 1 month ago", + input: "5z", + checkApprox: func(t *testing.T, got time.Time) { + t.Helper() + want := time.Now().AddDate(0, -1, 0) + diff := want.Sub(got) + if diff < 0 { + diff = -diff + } + assert.Less(t, diff, 2*time.Second, "expected default fallback ~1 month ago") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := parseDuration(tc.input) + assert.False(t, got.IsZero(), "parseDuration should never return zero time") + tc.checkApprox(t, got) + }) + } +} + +// TestNewRiskHandler verifies that NewRiskHandler returns a non-nil handler with +// the provided config, and that a nil config is handled gracefully. +func TestNewRiskHandler(t *testing.T) { + t.Run("with explicit config", func(t *testing.T) { + cfg := &config.RiskConfig{ + MaxDailyLossDollars: 500.0, + MaxDrawdownPct: 20.0, + MaxTradeCount: 100, + } + h := NewRiskHandler(nil, cfg) + assert.NotNil(t, h) + assert.Equal(t, cfg, h.cfg) + assert.Nil(t, h.db) + assert.NotNil(t, h.riskService) + }) + + t.Run("with nil config falls back to zero-value RiskConfig", func(t *testing.T) { + h := NewRiskHandler(nil, nil) + assert.NotNil(t, h) + assert.NotNil(t, h.cfg, "cfg should be non-nil even when nil is passed") + assert.Equal(t, &config.RiskConfig{}, h.cfg) + assert.NotNil(t, h.riskService) + }) +} + +// TestNewTradesHandler verifies that NewTradesHandler returns a non-nil handler. +func TestNewTradesHandler(t *testing.T) { + h := NewTradesHandler(nil) + assert.NotNil(t, h) + assert.Nil(t, h.db) +} + +// TestNewPerformanceHandler verifies that NewPerformanceHandler returns a non-nil handler. +func TestNewPerformanceHandler(t *testing.T) { + h := NewPerformanceHandler(nil) + assert.NotNil(t, h) + assert.Nil(t, h.db) +} + +func TestBuildBreaker(t *testing.T) { + tests := []struct { + name string + current float64 + threshold float64 + expectedStatus string + }{ + {"disabled when threshold zero", 100.0, 0.0, "DISABLED"}, + {"disabled when threshold negative", 50.0, -1.0, "DISABLED"}, + {"triggered when current equals threshold", 100.0, 100.0, "TRIGGERED"}, + {"triggered when current exceeds threshold", 110.0, 100.0, "TRIGGERED"}, + {"warning when at 80 percent", 80.0, 100.0, "WARNING"}, + {"warning just above 80 percent", 81.0, 100.0, "WARNING"}, + {"ok when below 80 percent", 79.0, 100.0, "OK"}, + {"ok at zero current", 0.0, 100.0, "OK"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := buildBreaker("TestBreaker", tc.current, tc.threshold) + assert.Equal(t, tc.expectedStatus, result["status"]) + assert.Equal(t, "TestBreaker", result["name"]) + assert.Equal(t, tc.threshold, result["threshold"]) + // current is rounded to 2dp + assert.Equal(t, math.Round(tc.current*100)/100, result["current"]) + }) + } +} diff --git a/internal/api/trades.go b/internal/api/trades.go new file mode 100644 index 00000000..11bd2fbe --- /dev/null +++ b/internal/api/trades.go @@ -0,0 +1,74 @@ +package api + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" + + "github.com/ajitpratap0/cryptofunk/internal/db" +) + +// TradesHandler serves fill records from the trades table. +type TradesHandler struct { + db *db.DB +} + +func NewTradesHandler(database *db.DB) *TradesHandler { + return &TradesHandler{db: database} +} + +// RegisterRoutes mounts the /trades sub-group under the provided router group. +// If authMiddleware is non-nil it is applied to all routes alongside readMiddleware. +func (h *TradesHandler) RegisterRoutes(rg *gin.RouterGroup, readMiddleware gin.HandlerFunc, authMiddleware gin.HandlerFunc) { + trades := rg.Group("/trades") + if authMiddleware != nil { + trades.GET("", readMiddleware, authMiddleware, h.ListTrades) + } else { + trades.Use(readMiddleware) + trades.GET("", h.ListTrades) + } +} + +// ListTrades returns recent trade fills, newest first. +// GET /api/v1/trades?limit=50&offset=0 +func (h *TradesHandler) ListTrades(c *gin.Context) { + limit := 50 + offset := 0 + if l := c.Query("limit"); l != "" { + if v, err := strconv.Atoi(l); err == nil && v > 0 && v <= 500 { + limit = v + } + } + if o := c.Query("offset"); o != "" { + if v, err := strconv.Atoi(o); err == nil && v >= 0 { + offset = v + } + } + + ctx := c.Request.Context() + trades, err := h.db.ListAllTrades(ctx, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch trades"}) + return + } + if trades == nil { + trades = []*db.Trade{} + } + + total, err := h.db.CountAllTrades(ctx) + if err != nil { + // Non-fatal: return the page without total rather than failing the request. + log.Warn().Err(err).Msg("failed to count trades, total will be 0") + total = 0 + } + + c.JSON(http.StatusOK, gin.H{ + "trades": trades, + "count": len(trades), + "total": total, + "limit": limit, + "offset": offset, + }) +} diff --git a/internal/backtest/job_unit_test.go b/internal/backtest/job_unit_test.go new file mode 100644 index 00000000..4e7c97c9 --- /dev/null +++ b/internal/backtest/job_unit_test.go @@ -0,0 +1,83 @@ +package backtest + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestValidateJob tests the validateJob method without a DB connection. +func TestValidateJob(t *testing.T) { + m := &JobManager{} // nil db is fine — validateJob doesn't touch the DB + + validJob := func() *BacktestJob { + return &BacktestJob{ + Name: "test", + StartDate: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + EndDate: time.Date(2024, 6, 1, 0, 0, 0, 0, time.UTC), + Symbols: []string{"BTCUSDT"}, + InitialCapital: 10000.0, + StrategyConfig: map[string]interface{}{"type": "trend"}, + } + } + + t.Run("valid job passes", func(t *testing.T) { + require.NoError(t, m.validateJob(validJob())) + }) + + t.Run("empty name", func(t *testing.T) { + j := validJob() + j.Name = "" + assert.Error(t, m.validateJob(j)) + }) + + t.Run("end before start", func(t *testing.T) { + j := validJob() + j.EndDate = j.StartDate.Add(-time.Hour) + assert.Error(t, m.validateJob(j)) + }) + + t.Run("equal start and end", func(t *testing.T) { + j := validJob() + j.EndDate = j.StartDate + assert.Error(t, m.validateJob(j)) + }) + + t.Run("no symbols", func(t *testing.T) { + j := validJob() + j.Symbols = nil + assert.Error(t, m.validateJob(j)) + }) + + t.Run("zero capital", func(t *testing.T) { + j := validJob() + j.InitialCapital = 0 + assert.Error(t, m.validateJob(j)) + }) + + t.Run("negative capital", func(t *testing.T) { + j := validJob() + j.InitialCapital = -1 + assert.Error(t, m.validateJob(j)) + }) + + t.Run("empty strategy config", func(t *testing.T) { + j := validJob() + j.StrategyConfig = nil + assert.Error(t, m.validateJob(j)) + }) +} + +func TestGetValue(t *testing.T) { + v := 3.14 + assert.Equal(t, 3.14, getValue(&v)) + assert.Equal(t, 0.0, getValue(nil)) +} + +func TestGetIntValue(t *testing.T) { + n := 42 + assert.Equal(t, 42, getIntValue(&n)) + assert.Equal(t, 0, getIntValue(nil)) +} diff --git a/internal/config/config.go b/internal/config/config.go index 26773b89..f0109b18 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -136,6 +136,7 @@ type TradingConfig struct { InitialCapital float64 `mapstructure:"initial_capital"` // 10000.0 MaxPositions int `mapstructure:"max_positions"` // 3 DefaultQuantity float64 `mapstructure:"default_quantity"` // 0.01 + CommissionRate float64 `mapstructure:"commission_rate"` // 0.001 (0.1%) — paper trade taker fee } // RiskConfig contains risk management settings @@ -149,6 +150,13 @@ type RiskConfig struct { MinConfidence float64 `mapstructure:"min_confidence"` // 0.7 CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"` // Circuit breaker thresholds SafetyGuard SafetyGuardConfig `mapstructure:"safety_guard"` // Live trading safety guards + + // Dashboard circuit breaker thresholds — used by the /risk/circuit-breakers endpoint. + // These are expressed in absolute units: dollars for MaxDailyLossDollars, + // percentage points for MaxDrawdownPct, and a count for MaxTradeCount. + MaxDailyLossDollars float64 `mapstructure:"max_daily_loss_dollars"` // 5000.0 (USD) + MaxDrawdownPct float64 `mapstructure:"max_drawdown_pct"` // 10.0 (percent) + MaxTradeCount int `mapstructure:"max_trade_count"` // 100 } // SafetyGuardConfig contains safety guard settings for live trading protection @@ -434,6 +442,7 @@ func setDefaults(v *viper.Viper) { v.SetDefault("trading.initial_capital", 10000.0) v.SetDefault("trading.max_positions", 3) v.SetDefault("trading.default_quantity", 0.01) + v.SetDefault("trading.commission_rate", 0.001) // 0.1% taker fee (Binance standard tier) // Risk defaults v.SetDefault("risk.max_position_size", 0.1) @@ -444,6 +453,11 @@ func setDefaults(v *viper.Viper) { v.SetDefault("risk.llm_approval_required", true) v.SetDefault("risk.min_confidence", 0.7) + // Dashboard circuit breaker threshold defaults (absolute units) + v.SetDefault("risk.max_daily_loss_dollars", 5000.0) + v.SetDefault("risk.max_drawdown_pct", 10.0) + v.SetDefault("risk.max_trade_count", 100) + // Circuit Breaker defaults - aligned with configs/circuit_breakers.yaml // Exchange v.SetDefault("risk.circuit_breaker.exchange.min_requests", 5) diff --git a/internal/db/db.go b/internal/db/db.go index 07e42381..3fc4abf5 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -6,6 +6,7 @@ import ( "os" "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog/log" "github.com/sony/gobreaker" @@ -142,3 +143,24 @@ func (db *DB) GetCircuitBreaker() *risk.CircuitBreakerManager { func (db *DB) SetCircuitBreaker(cb *risk.CircuitBreakerManager) { db.circuitBreaker = cb } + +// WithTx runs fn inside a single pgx transaction using the provided TxOptions. +// It commits on success and rolls back on any error or panic. +// The caller must not commit or roll back tx. +func (db *DB) WithTx(ctx context.Context, opts pgx.TxOptions, fn func(tx pgx.Tx) error) error { + tx, err := db.pool.BeginTx(ctx, opts) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback(context.Background()) + panic(p) + } + }() + if err := fn(tx); err != nil { + _ = tx.Rollback(ctx) + return err + } + return tx.Commit(ctx) +} diff --git a/internal/db/orders.go b/internal/db/orders.go index d2c595db..277cf43f 100644 --- a/internal/db/orders.go +++ b/internal/db/orders.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" ) @@ -72,38 +73,144 @@ type Order struct { // Trade represents a database trade record (fill) type Trade struct { - ID uuid.UUID - OrderID uuid.UUID - ExchangeTradeID *string - Symbol string - Exchange string - Side OrderSide - Price float64 - Quantity float64 - QuoteQuantity float64 - Commission float64 - CommissionAsset *string - ExecutedAt time.Time - IsMaker bool - Metadata map[string]interface{} - CreatedAt time.Time + ID uuid.UUID `json:"id"` + OrderID uuid.UUID `json:"order_id"` + ExchangeTradeID *string `json:"exchange_trade_id"` + Symbol string `json:"symbol"` + Exchange string `json:"exchange"` + Side OrderSide `json:"side"` + Price float64 `json:"price"` + Quantity float64 `json:"quantity"` + QuoteQuantity float64 `json:"quote_quantity"` + Commission float64 `json:"commission"` + CommissionAsset *string `json:"commission_asset"` + ExecutedAt time.Time `json:"executed_at"` + IsMaker bool `json:"is_maker"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +const sqlInsertOrder = ` + INSERT INTO orders ( + id, session_id, position_id, exchange_order_id, symbol, exchange, + side, type, status, price, stop_price, quantity, executed_quantity, + executed_quote_quantity, time_in_force, placed_at, filled_at, + canceled_at, error_message, metadata, created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, + $16, $17, $18, $19, $20, $21, $22 + ) +` + +const sqlUpdateOrderStatus = ` + UPDATE orders + SET status = $1, + executed_quantity = $2, + executed_quote_quantity = $3, + filled_at = $4, + canceled_at = $5, + error_message = $6, + updated_at = NOW() + WHERE id = $7 +` + +const sqlInsertTrade = ` + INSERT INTO trades ( + id, order_id, exchange_trade_id, symbol, exchange, side, + price, quantity, quote_quantity, commission, commission_asset, + executed_at, is_maker, metadata, created_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15 + ) +` + +// InsertOrderTx inserts a new order into the database within an existing transaction. +func (db *DB) InsertOrderTx(ctx context.Context, tx pgx.Tx, order *Order) error { + _, err := tx.Exec(ctx, sqlInsertOrder, + order.ID, + order.SessionID, + order.PositionID, + order.ExchangeOrderID, + order.Symbol, + order.Exchange, + order.Side, + order.Type, + order.Status, + order.Price, + order.StopPrice, + order.Quantity, + order.ExecutedQuantity, + order.ExecutedQuoteQuantity, + order.TimeInForce, + order.PlacedAt, + order.FilledAt, + order.CanceledAt, + order.ErrorMessage, + order.Metadata, + order.CreatedAt, + order.UpdatedAt, + ) + + if err != nil { + log.Error(). + Err(err). + Str("order_id", order.ID.String()). + Str("symbol", order.Symbol). + Msg("Failed to insert order in transaction") + return fmt.Errorf("failed to insert order: %w", err) + } + + log.Debug(). + Str("order_id", order.ID.String()). + Str("symbol", order.Symbol). + Str("status", string(order.Status)). + Msg("Order inserted into database in transaction") + + return nil +} + +// InsertTradeTx inserts a new trade (fill) into the database within an existing transaction. +func (db *DB) InsertTradeTx(ctx context.Context, tx pgx.Tx, trade *Trade) error { + _, err := tx.Exec(ctx, sqlInsertTrade, + trade.ID, + trade.OrderID, + trade.ExchangeTradeID, + trade.Symbol, + trade.Exchange, + trade.Side, + trade.Price, + trade.Quantity, + trade.QuoteQuantity, + trade.Commission, + trade.CommissionAsset, + trade.ExecutedAt, + trade.IsMaker, + trade.Metadata, + trade.CreatedAt, + ) + + if err != nil { + log.Error(). + Err(err). + Str("trade_id", trade.ID.String()). + Str("order_id", trade.OrderID.String()). + Msg("Failed to insert trade in transaction") + return fmt.Errorf("failed to insert trade: %w", err) + } + + log.Debug(). + Str("trade_id", trade.ID.String()). + Str("order_id", trade.OrderID.String()). + Float64("price", trade.Price). + Float64("quantity", trade.Quantity). + Msg("Trade inserted into database in transaction") + + return nil } // InsertOrder inserts a new order into the database func (db *DB) InsertOrder(ctx context.Context, order *Order) error { - query := ` - INSERT INTO orders ( - id, session_id, position_id, exchange_order_id, symbol, exchange, - side, type, status, price, stop_price, quantity, executed_quantity, - executed_quote_quantity, time_in_force, placed_at, filled_at, - canceled_at, error_message, metadata, created_at, updated_at - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, - $16, $17, $18, $19, $20, $21, $22 - ) - ` - - _, err := db.pool.Exec(ctx, query, + _, err := db.pool.Exec(ctx, sqlInsertOrder, order.ID, order.SessionID, order.PositionID, @@ -148,19 +255,7 @@ func (db *DB) InsertOrder(ctx context.Context, order *Order) error { // UpdateOrderStatus updates an order's status and related fields func (db *DB) UpdateOrderStatus(ctx context.Context, orderID uuid.UUID, status OrderStatus, executedQty, executedQuoteQty float64, filledAt, canceledAt *time.Time, errorMsg *string) error { - query := ` - UPDATE orders - SET status = $1, - executed_quantity = $2, - executed_quote_quantity = $3, - filled_at = $4, - canceled_at = $5, - error_message = $6, - updated_at = NOW() - WHERE id = $7 - ` - - result, err := db.pool.Exec(ctx, query, + result, err := db.pool.Exec(ctx, sqlUpdateOrderStatus, status, executedQty, executedQuoteQty, @@ -190,19 +285,41 @@ func (db *DB) UpdateOrderStatus(ctx context.Context, orderID uuid.UUID, status O return nil } +// UpdateOrderStatusTx updates an order's status and executed fields within an existing transaction. +func (db *DB) UpdateOrderStatusTx(ctx context.Context, tx pgx.Tx, orderID uuid.UUID, status OrderStatus, executedQty, executedQuoteQty float64, filledAt, canceledAt *time.Time, errorMsg *string) error { + result, err := tx.Exec(ctx, sqlUpdateOrderStatus, + status, + executedQty, + executedQuoteQty, + filledAt, + canceledAt, + errorMsg, + orderID, + ) + + if err != nil { + log.Error(). + Err(err). + Str("order_id", orderID.String()). + Msg("Failed to update order status in transaction") + return fmt.Errorf("failed to update order status: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("order not found: %s", orderID.String()) + } + + log.Debug(). + Str("order_id", orderID.String()). + Str("status", string(status)). + Msg("Order status updated in transaction") + + return nil +} + // InsertTrade inserts a new trade (fill) into the database func (db *DB) InsertTrade(ctx context.Context, trade *Trade) error { - query := ` - INSERT INTO trades ( - id, order_id, exchange_trade_id, symbol, exchange, side, - price, quantity, quote_quantity, commission, commission_asset, - executed_at, is_maker, metadata, created_at - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15 - ) - ` - - _, err := db.pool.Exec(ctx, query, + _, err := db.pool.Exec(ctx, sqlInsertTrade, trade.ID, trade.OrderID, trade.ExchangeTradeID, diff --git a/internal/db/positions.go b/internal/db/positions.go index 5376abb2..482d162f 100644 --- a/internal/db/positions.go +++ b/internal/db/positions.go @@ -353,6 +353,32 @@ func (db *DB) GetAllOpenPositions(ctx context.Context) ([]*Position, error) { return scanPositions(rows) } +// GetAllClosedPositions returns positions closed within the last 90 days across all sessions, +// ordered by exit_time DESC. A 90-day window ensures VaR calculations use recent, +// relevant return data rather than an arbitrary row count. +func (db *DB) GetAllClosedPositions(ctx context.Context) ([]*Position, error) { + query := ` + SELECT + id, session_id, symbol, exchange, side, entry_price, exit_price, + quantity, entry_time, exit_time, stop_loss, take_profit, + realized_pnl, unrealized_pnl, fees, entry_reason, exit_reason, + metadata, created_at, updated_at + FROM positions + WHERE exit_time IS NOT NULL + AND exit_time > NOW() - INTERVAL '90 days' + AND realized_pnl IS NOT NULL + ORDER BY exit_time DESC + ` + + rows, err := db.pool.Query(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to query closed positions: %w", err) + } + defer rows.Close() + + return scanPositions(rows) +} + // GetPositionsBySession retrieves all positions (including closed) for a session func (db *DB) GetPositionsBySession(ctx context.Context, sessionID uuid.UUID) ([]*Position, error) { query := ` @@ -419,6 +445,56 @@ func (db *DB) GetPositionBySymbolAndSession(ctx context.Context, symbol string, return &position, nil } +// GetOpenPositionBySymbolTx retrieves the most recent open position for a symbol within a session +// using an existing transaction, providing a consistent read within the transaction boundary. +// Returns (nil, nil) when no open position is found. +func (db *DB) GetOpenPositionBySymbolTx(ctx context.Context, tx pgx.Tx, sessionID uuid.UUID, symbol string) (*Position, error) { + query := ` + SELECT + id, session_id, symbol, exchange, side, entry_price, exit_price, + quantity, entry_time, exit_time, stop_loss, take_profit, + realized_pnl, unrealized_pnl, fees, entry_reason, exit_reason, + metadata, created_at, updated_at + FROM positions + WHERE symbol = $1 AND session_id = $2 AND exit_time IS NULL + ORDER BY entry_time DESC + LIMIT 1 + FOR UPDATE + ` + + var position Position + err := tx.QueryRow(ctx, query, symbol, sessionID).Scan( + &position.ID, + &position.SessionID, + &position.Symbol, + &position.Exchange, + &position.Side, + &position.EntryPrice, + &position.ExitPrice, + &position.Quantity, + &position.EntryTime, + &position.ExitTime, + &position.StopLoss, + &position.TakeProfit, + &position.RealizedPnL, + &position.UnrealizedPnL, + &position.Fees, + &position.EntryReason, + &position.ExitReason, + &position.Metadata, + &position.CreatedAt, + &position.UpdatedAt, + ) + if err != nil { + if err == pgx.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get open position by symbol in transaction: %w", err) + } + + return &position, nil +} + // GetLatestPositionBySymbol retrieves the latest position for a symbol (any session) func (db *DB) GetLatestPositionBySymbol(ctx context.Context, symbol string) (*Position, error) { query := ` @@ -537,6 +613,83 @@ func (db *DB) UpdatePositionQuantity(ctx context.Context, id uuid.UUID, newQuant return nil } +// CreatePositionTx inserts a new position into the database within an existing transaction. +func (db *DB) CreatePositionTx(ctx context.Context, tx pgx.Tx, position *Position) error { + query := ` + INSERT INTO positions ( + id, session_id, symbol, exchange, side, entry_price, quantity, + entry_time, stop_loss, take_profit, entry_reason, metadata, created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14 + ) + ` + + if position.ID == uuid.Nil { + position.ID = uuid.New() + } + if position.CreatedAt.IsZero() { + position.CreatedAt = time.Now() + } + if position.UpdatedAt.IsZero() { + position.UpdatedAt = time.Now() + } + + _, err := tx.Exec(ctx, query, + position.ID, + position.SessionID, + position.Symbol, + position.Exchange, + position.Side, + position.EntryPrice, + position.Quantity, + position.EntryTime, + position.StopLoss, + position.TakeProfit, + position.EntryReason, + position.Metadata, + position.CreatedAt, + position.UpdatedAt, + ) + + if err != nil { + return fmt.Errorf("failed to create position: %w", err) + } + + return nil +} + +// UpdatePositionAveragingTx updates entry price and quantity when adding to a position, +// within an existing transaction. +func (db *DB) UpdatePositionAveragingTx(ctx context.Context, tx pgx.Tx, id uuid.UUID, newEntryPrice, newQuantity float64, additionalFees float64) error { + query := ` + UPDATE positions + SET + entry_price = $2, + quantity = $3, + fees = fees + $4, + updated_at = $5 + WHERE id = $1 AND exit_time IS NULL + ` + + result, err := tx.Exec(ctx, query, + id, + newEntryPrice, + newQuantity, + additionalFees, + time.Now(), + ) + + if err != nil { + return fmt.Errorf("failed to update position averaging: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("position not found or already closed: %s", id) + } + + return nil +} + // UpdatePositionAveraging updates entry price and quantity when adding to a position func (db *DB) UpdatePositionAveraging(ctx context.Context, id uuid.UUID, newEntryPrice, newQuantity float64, additionalFees float64) error { query := ` @@ -634,6 +787,85 @@ func (db *DB) PartialClosePosition(ctx context.Context, id uuid.UUID, closeQuant return closedPosition, nil } +// PairPerformance holds aggregated realized PnL and trade count for a single trading pair. +type PairPerformance struct { + Symbol string `db:"symbol"` + RealizedPnL float64 `db:"realized_pnl"` + TradeCount int `db:"trade_count"` +} + +// GetPairPerformance returns realized PnL and trade count grouped by symbol using SQL GROUP BY, +// covering all closed positions where realized_pnl is not NULL. +func (db *DB) GetPairPerformance(ctx context.Context) ([]PairPerformance, error) { + query := ` + SELECT symbol, COALESCE(SUM(realized_pnl), 0) AS realized_pnl, COUNT(*) AS trade_count + FROM positions + WHERE exit_time IS NOT NULL AND realized_pnl IS NOT NULL + GROUP BY symbol + ORDER BY realized_pnl DESC + -- Cap at 200 rows — sufficient for dashboard display; a full paginated API is a follow-up. + LIMIT 200 + ` + + rows, err := db.pool.Query(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to query pair performance: %w", err) + } + defer rows.Close() + + var results []PairPerformance + for rows.Next() { + var p PairPerformance + if err := rows.Scan(&p.Symbol, &p.RealizedPnL, &p.TradeCount); err != nil { + return nil, fmt.Errorf("failed to scan pair performance row: %w", err) + } + results = append(results, p) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating pair performance rows: %w", err) + } + return results, nil +} + +// SymbolExposure holds the cost-basis exposure for a single symbol across all open positions. +type SymbolExposure struct { + Symbol string `db:"symbol"` + Exposure float64 `db:"exposure"` +} + +// GetExposureBySymbol returns the total open-position exposure (quantity * entry_price) grouped by +// symbol using SQL GROUP BY. Exposure is calculated at cost-basis, not mark-to-market. +func (db *DB) GetExposureBySymbol(ctx context.Context) ([]SymbolExposure, error) { + query := ` + SELECT symbol, SUM(quantity * entry_price) AS exposure + FROM positions + WHERE exit_time IS NULL + GROUP BY symbol + ORDER BY exposure DESC + -- Cap at 200 rows — sufficient for dashboard display; a full paginated API is a follow-up. + LIMIT 200 + ` + + rows, err := db.pool.Query(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to query exposure by symbol: %w", err) + } + defer rows.Close() + + var results []SymbolExposure + for rows.Next() { + var s SymbolExposure + if err := rows.Scan(&s.Symbol, &s.Exposure); err != nil { + return nil, fmt.Errorf("failed to scan symbol exposure row: %w", err) + } + results = append(results, s) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating symbol exposure rows: %w", err) + } + return results, nil +} + // ConvertPositionSide converts a string to PositionSide func ConvertPositionSide(side string) PositionSide { switch side { diff --git a/internal/db/trades.go b/internal/db/trades.go new file mode 100644 index 00000000..b56575ce --- /dev/null +++ b/internal/db/trades.go @@ -0,0 +1,67 @@ +package db + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +// ListAllTrades returns recent trade fills across all orders, newest first. +// For fills by specific order, use GetTradesByOrderID in orders.go instead. +func (db *DB) ListAllTrades(ctx context.Context, limit, offset int) ([]*Trade, error) { + rows, err := db.pool.Query(ctx, ` + SELECT id, order_id, exchange_trade_id, symbol, exchange, side, + price, quantity, quote_quantity, commission, commission_asset, + executed_at, is_maker, metadata, created_at + FROM trades + ORDER BY executed_at DESC + LIMIT $1 OFFSET $2 + `, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query trades: %w", err) + } + defer rows.Close() + + var trades []*Trade + for rows.Next() { + t := &Trade{} + if err := rows.Scan( + &t.ID, &t.OrderID, &t.ExchangeTradeID, &t.Symbol, &t.Exchange, &t.Side, + &t.Price, &t.Quantity, &t.QuoteQuantity, &t.Commission, &t.CommissionAsset, + &t.ExecutedAt, &t.IsMaker, &t.Metadata, &t.CreatedAt, + ); err != nil { + return nil, fmt.Errorf("failed to scan trade row: %w", err) + } + trades = append(trades, t) + } + return trades, rows.Err() +} + +// CountAllTrades returns an approximate count of trade fill records using pg_class statistics. +// This is O(1) instead of O(n) — avoids a full sequential COUNT(*) scan on every request. +// The estimate is sourced from pg_class.reltuples which is updated by ANALYZE/autovacuum. +// If the table has never been analyzed (reltuples = -1), the function returns 0. +// to_regclass('public.trades') is used instead of current_schema() to avoid a wrong +// result when TimescaleDB adds _timescaledb_internal (or another schema) to search_path, +// which would make current_schema() return something other than 'public'. +func (db *DB) CountAllTrades(ctx context.Context) (int, error) { + var estimate int64 + err := db.pool.QueryRow(ctx, + `SELECT COALESCE(reltuples::bigint, -1) AS estimate + FROM pg_class + WHERE oid = to_regclass('public.trades')`, + ).Scan(&estimate) + if err != nil { + // pgx returns pgx.ErrNoRows when to_regclass returns NULL (table doesn't exist). + if err == pgx.ErrNoRows { + return 0, nil + } + return 0, fmt.Errorf("failed to count trades: %w", err) + } + // reltuples is -1 for a freshly created table that has never been analyzed. + if estimate < 0 { + return 0, nil + } + return int(estimate), nil +} diff --git a/internal/db/utils.go b/internal/db/utils.go new file mode 100644 index 00000000..a0bf154b --- /dev/null +++ b/internal/db/utils.go @@ -0,0 +1,4 @@ +package db + +// PtrFloat64 returns a pointer to f. Useful for optional float64 struct fields. +func PtrFloat64(f float64) *float64 { return &f } diff --git a/internal/db/utils_test.go b/internal/db/utils_test.go new file mode 100644 index 00000000..5f961d29 --- /dev/null +++ b/internal/db/utils_test.go @@ -0,0 +1,22 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPtrFloat64(t *testing.T) { + v := PtrFloat64(3.14) + require.NotNil(t, v) + assert.Equal(t, 3.14, *v) + + zero := PtrFloat64(0.0) + require.NotNil(t, zero) + assert.Equal(t, 0.0, *zero) + + neg := PtrFloat64(-99.5) + require.NotNil(t, neg) + assert.Equal(t, -99.5, *neg) +} diff --git a/internal/exchange/safety_guard_nil_test.go b/internal/exchange/safety_guard_nil_test.go new file mode 100644 index 00000000..48ca30a0 --- /dev/null +++ b/internal/exchange/safety_guard_nil_test.go @@ -0,0 +1,117 @@ +package exchange + +// Tests for service safety-guard methods when the guard is initialized but disabled +// (i.e., NewServicePaper — SafetyGuardConfig{Enabled: false}). +// +// Note: safetyGuard is always non-nil; the nil branch in each method is dead code +// because NewService always calls risk.NewSafetyGuard(config.SafetyGuard). + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newPaperServiceNoSafety(t *testing.T) *Service { + t.Helper() + s := NewServicePaper(nil) // nil DB, no safety guard config → guard initialized but disabled + return s +} + +func TestGetSafetyGuardStats_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.GetSafetyGuardStats(context.Background(), nil) + require.NoError(t, err) + m := result.(map[string]interface{}) + // Guard is non-nil, so enabled is always true in the non-nil branch + assert.Equal(t, true, m["enabled"]) + assert.Equal(t, float64(0), m["daily_pnl"]) + assert.Equal(t, int(0), m["daily_trade_count"]) +} + +func TestSetSafetyGuardCapital_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.SetSafetyGuardCapital(context.Background(), map[string]interface{}{"capital": 1000.0}) + require.NoError(t, err) + m := result.(map[string]interface{}) + assert.Equal(t, true, m["success"]) + assert.Equal(t, 1000.0, m["capital"]) +} + +func TestRecordTradePnL_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.RecordTradePnL(context.Background(), map[string]interface{}{"pnl": 50.0}) + require.NoError(t, err) + m := result.(map[string]interface{}) + assert.Equal(t, true, m["success"]) + assert.Equal(t, 50.0, m["pnl_recorded"]) +} + +func TestResetSafetyGuardCircuitBreaker_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.ResetSafetyGuardCircuitBreaker(context.Background(), nil) + require.NoError(t, err) + m := result.(map[string]interface{}) + assert.Equal(t, true, m["success"]) +} + +func TestResetDailyCounters_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + // nil args → capital extraction fails: "capital is required" + _, err := s.ResetDailyCounters(context.Background(), nil) + assert.Error(t, err) + + // valid call succeeds + result, err := s.ResetDailyCounters(context.Background(), map[string]interface{}{"capital": 5000.0}) + require.NoError(t, err) + m := result.(map[string]interface{}) + assert.Equal(t, true, m["success"]) + assert.Equal(t, 5000.0, m["new_capital"]) +} + +func TestEmergencyStop_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.EmergencyStop(context.Background(), map[string]interface{}{"reason": "test"}) + require.NoError(t, err) + m := result.(map[string]interface{}) + assert.Equal(t, true, m["success"]) + assert.Equal(t, "test", m["reason"]) +} + +func TestClearEmergencyStop_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.ClearEmergencyStop(context.Background(), nil) + require.NoError(t, err) + m := result.(map[string]interface{}) + assert.Equal(t, true, m["success"]) + assert.Equal(t, false, m["was_active"]) +} + +func TestGetEmergencyStopStatus_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.GetEmergencyStopStatus(context.Background(), nil) + require.NoError(t, err) + m := result.(map[string]interface{}) + assert.Equal(t, false, m["active"]) +} + +func TestGetTradingHoursStatus_NoGuard(t *testing.T) { + s := newPaperServiceNoSafety(t) + result, err := s.GetTradingHoursStatus(context.Background(), nil) + require.NoError(t, err) + m := result.(map[string]interface{}) + // TradingHours.Enabled is false in zero config + assert.Equal(t, false, m["enabled"]) + // IsWithinTradingHours returns true when disabled + assert.Equal(t, true, m["within_hours"]) +} + +func TestSetMarketPrice_Service(t *testing.T) { + s := newPaperServiceNoSafety(t) + // SetMarketPrice delegates to the mock exchange — should not panic + s.SetMarketPrice("BTCUSDT", 45000.0) + price := s.exchange.GetMarketPrice("BTCUSDT") + assert.Equal(t, 45000.0, price) +} diff --git a/internal/market/pure_unit_test.go b/internal/market/pure_unit_test.go new file mode 100644 index 00000000..a533b0f1 --- /dev/null +++ b/internal/market/pure_unit_test.go @@ -0,0 +1,90 @@ +package market + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCoinGeckoIDToBinanceSymbol(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"bitcoin", "BTCUSDT"}, + {"ethereum", "ETHUSDT"}, + {"solana", "SOLUSDT"}, + {"cardano", "ADAUSDT"}, + {"ripple", "XRPUSDT"}, + {"dogecoin", "DOGEUSDT"}, + {"polkadot", "DOTUSDT"}, + {"avalanche", "AVAXUSDT"}, + {"chainlink", "LINKUSDT"}, + {"polygon", "MATICUSDT"}, + // unknown → uppercase + USDT + {"shiba", "SHIBAUSDT"}, + {"unknowncoin", "UNKNOWNCOINUSDT"}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + assert.Equal(t, tc.want, CoinGeckoIDToBinanceSymbol(tc.input)) + }) + } +} + +func TestEstimateOHLCV(t *testing.T) { + t.Run("nil / empty input", func(t *testing.T) { + result := EstimateOHLCV(nil) + require.NotNil(t, result) + assert.Empty(t, result.Close) + }) + + t.Run("single price uses min range", func(t *testing.T) { + result := EstimateOHLCV([]float64{100.0}) + require.Len(t, result.High, 1) + assert.Greater(t, result.High[0], 100.0, "high should be above close") + assert.Less(t, result.Low[0], 100.0, "low should be below close") + assert.Equal(t, 100.0, result.Close[0]) + }) + + t.Run("two prices", func(t *testing.T) { + closes := []float64{100.0, 110.0} + result := EstimateOHLCV(closes) + require.Len(t, result.High, 2) + for i := range closes { + assert.GreaterOrEqual(t, result.High[i], closes[i]) + assert.LessOrEqual(t, result.Low[i], closes[i]) + assert.Equal(t, closes[i], result.Close[i]) + } + }) + + t.Run("flat prices use min range of 0.1 pct", func(t *testing.T) { + closes := []float64{1000.0, 1000.0, 1000.0} + result := EstimateOHLCV(closes) + // halfRange = max(0, 0) → falls back to minRange = 1.0 → halfRange/2 = 0.5 + expectedHalfRange := 1000.0 * 0.001 / 2.0 + assert.InDelta(t, 1000.0+expectedHalfRange, result.High[1], 1e-9) + assert.InDelta(t, 1000.0-expectedHalfRange, result.Low[1], 1e-9) + }) + + t.Run("high is always >= close and low <= close", func(t *testing.T) { + closes := []float64{50000, 51000, 49000, 52000, 48000} + result := EstimateOHLCV(closes) + for i, c := range closes { + assert.GreaterOrEqual(t, result.High[i], c) + assert.LessOrEqual(t, result.Low[i], c) + assert.False(t, math.IsNaN(result.High[i])) + assert.False(t, math.IsNaN(result.Low[i])) + } + }) +} + +func TestParseStringFloat(t *testing.T) { + assert.Equal(t, 3.14, ParseStringFloat(3.14)) + assert.Equal(t, 42.0, ParseStringFloat("42")) + assert.Equal(t, 0.0, ParseStringFloat("notanumber")) + assert.Equal(t, 0.0, ParseStringFloat(nil)) + assert.Equal(t, 0.0, ParseStringFloat(42)) // int is not handled +} diff --git a/internal/models/convert_test.go b/internal/models/convert_test.go new file mode 100644 index 00000000..68d6c0fb --- /dev/null +++ b/internal/models/convert_test.go @@ -0,0 +1,243 @@ +package models + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ajitpratap0/cryptofunk/internal/db" +) + +// TestBinancePositionToUnified_OpenPosition tests conversion of an open Binance position +// with nil optional fields. +func TestBinancePositionToUnified_OpenPosition(t *testing.T) { + posID := uuid.New() + entryTime := time.Now().Add(-1 * time.Hour) + + p := &db.Position{ + ID: posID, + Symbol: "BTC/USDT", + Side: db.PositionSideLong, + EntryPrice: 50000.0, + Quantity: 0.5, + EntryTime: entryTime, + // nil optionals: SessionID, UnrealizedPnL, RealizedPnL, ExitTime, ExitPrice, EntryReason + } + + u := BinancePositionToUnified(p) + + assert.Equal(t, posID, u.ID) + assert.Equal(t, PlatformBinance, u.Platform) + assert.Equal(t, "BTC/USDT", u.Symbol) + assert.Equal(t, "long", u.Side, "Side should be lowercased") + assert.InDelta(t, 50000.0, u.EntryPrice, 1e-9) + assert.InDelta(t, 0.5, u.Quantity, 1e-9) + assert.InDelta(t, 50000.0*0.5, u.CostBasis, 1e-9, "CostBasis = EntryPrice * Quantity") + assert.Equal(t, PositionStatusOpen, u.Status) + assert.Nil(t, u.ClosedAt) + assert.InDelta(t, 0.0, u.UnrealizedPnL, 1e-9, "nil UnrealizedPnL should default to 0") + assert.InDelta(t, 0.0, u.RealizedPnL, 1e-9, "nil RealizedPnL should default to 0") + assert.Equal(t, uuid.Nil, u.SessionID, "nil SessionID should default to uuid.Nil") +} + +// TestBinancePositionToUnified_ClosedPosition tests conversion of a closed position with all +// optional fields populated. +func TestBinancePositionToUnified_ClosedPosition(t *testing.T) { + posID := uuid.New() + sessionID := uuid.New() + entryTime := time.Now().Add(-2 * time.Hour) + exitTime := time.Now().Add(-1 * time.Hour) + exitPrice := 52000.0 + unrealizedPnL := 0.0 + realizedPnL := 1000.0 + entryReason := "technical-agent" + + p := &db.Position{ + ID: posID, + SessionID: &sessionID, + Symbol: "ETH/USDT", + Side: db.PositionSideShort, + EntryPrice: 3000.0, + Quantity: 2.0, + EntryTime: entryTime, + ExitTime: &exitTime, + ExitPrice: &exitPrice, + UnrealizedPnL: &unrealizedPnL, + RealizedPnL: &realizedPnL, + EntryReason: &entryReason, + } + + u := BinancePositionToUnified(p) + + assert.Equal(t, posID, u.ID) + assert.Equal(t, PlatformBinance, u.Platform) + assert.Equal(t, "short", u.Side, "Side should be lowercased") + assert.Equal(t, PositionStatusClosed, u.Status) + assert.NotNil(t, u.ClosedAt) + assert.Equal(t, exitTime, *u.ClosedAt) + assert.InDelta(t, 52000.0, u.CurrentPrice, 1e-9) + assert.InDelta(t, unrealizedPnL, u.UnrealizedPnL, 1e-9) + assert.InDelta(t, realizedPnL, u.RealizedPnL, 1e-9) + assert.Equal(t, sessionID, u.SessionID) + assert.Equal(t, "technical-agent", u.Agent) + assert.InDelta(t, 3000.0*2.0, u.CostBasis, 1e-9) +} + +// TestBinanceTradeToUnified tests conversion of a Binance trade+order pair. +func TestBinanceTradeToUnified(t *testing.T) { + tradeID := uuid.New() + orderID := uuid.New() + positionID := uuid.New() + executedAt := time.Now() + + trade := &db.Trade{ + ID: tradeID, + OrderID: orderID, + Symbol: "BTC/USDT", + Side: db.OrderSideBuy, + Price: 48000.0, + Quantity: 0.1, + QuoteQuantity: 4800.0, + ExecutedAt: executedAt, + } + + order := &db.Order{ + ID: orderID, + PositionID: &positionID, + } + + u := BinanceTradeToUnified(trade, order) + + assert.Equal(t, tradeID, u.ID) + assert.Equal(t, PlatformBinance, u.Platform) + assert.Equal(t, "BTC/USDT", u.Symbol) + assert.InDelta(t, 48000.0, u.Price, 1e-9) + assert.InDelta(t, 0.1, u.Quantity, 1e-9) + assert.InDelta(t, 4800.0, u.Amount, 1e-9) + assert.Equal(t, positionID, u.PositionID) + assert.Equal(t, "BUY", u.Action) + assert.Equal(t, executedAt, u.Timestamp) +} + +// TestBinanceTradeToUnified_NilOrder verifies that a nil order doesn't panic and PositionID +// is left as zero value. +func TestBinanceTradeToUnified_NilOrder(t *testing.T) { + trade := &db.Trade{ + ID: uuid.New(), + Symbol: "SOL/USDT", + Side: db.OrderSideSell, + Price: 100.0, + } + + u := BinanceTradeToUnified(trade, nil) + + assert.Equal(t, PlatformBinance, u.Platform) + assert.Equal(t, "SOL/USDT", u.Symbol) + assert.Equal(t, uuid.Nil, u.PositionID, "PositionID should be zero when order is nil") + assert.Equal(t, "SELL", u.Action) +} + +// TestNewEmptyPortfolio verifies the portfolio is non-nil with an empty positions slice. +func TestNewEmptyPortfolio(t *testing.T) { + p := NewEmptyPortfolio() + + require.NotNil(t, p) + require.NotNil(t, p.Positions) + assert.Empty(t, p.Positions) + require.NotNil(t, p.ByPlatform) + assert.Contains(t, p.ByPlatform, PlatformBinance) + assert.Contains(t, p.ByPlatform, PlatformPolymarket) + assert.Equal(t, 0, p.OpenPositions) + assert.InDelta(t, 0.0, p.TotalValue, 1e-9) +} + +// TestUnifiedPortfolio_AddPosition verifies that adding a position updates the portfolio. +func TestUnifiedPortfolio_AddPosition(t *testing.T) { + p := NewEmptyPortfolio() + + pos := UnifiedPosition{ + ID: uuid.New(), + Platform: PlatformBinance, + Symbol: "BTC/USDT", + EntryPrice: 50000.0, + CurrentPrice: 51000.0, + Quantity: 1.0, + UnrealizedPnL: 1000.0, + RealizedPnL: 500.0, + Status: PositionStatusOpen, + } + + p.AddPosition(pos) + + require.Len(t, p.Positions, 1) + assert.Equal(t, pos, p.Positions[0]) + assert.Equal(t, 1, p.OpenPositions) + assert.InDelta(t, 1000.0, p.UnrealizedPnL, 1e-9) + assert.InDelta(t, 500.0, p.RealizedPnL, 1e-9) + assert.InDelta(t, 1500.0, p.TotalPnL, 1e-9) + assert.InDelta(t, 51000.0, p.TotalValue, 1e-9, "TotalValue = CurrentPrice * Quantity") + + binanceSummary := p.ByPlatform[PlatformBinance] + assert.Equal(t, 1, binanceSummary.PositionCount) + assert.InDelta(t, 51000.0, binanceSummary.TotalValue, 1e-9) +} + +// TestUnifiedPortfolio_AddPosition_Closed verifies that adding a closed position does not +// increment OpenPositions or UnrealizedPnL. +func TestUnifiedPortfolio_AddPosition_Closed(t *testing.T) { + p := NewEmptyPortfolio() + + closedAt := time.Now() + pos := UnifiedPosition{ + ID: uuid.New(), + Platform: PlatformBinance, + Symbol: "ETH/USDT", + CurrentPrice: 3000.0, + Quantity: 1.0, + RealizedPnL: 200.0, + Status: PositionStatusClosed, + ClosedAt: &closedAt, + } + + p.AddPosition(pos) + + assert.Equal(t, 0, p.OpenPositions) + assert.InDelta(t, 0.0, p.UnrealizedPnL, 1e-9) + assert.InDelta(t, 200.0, p.RealizedPnL, 1e-9) +} + +// TestUnifiedPortfolio_SetPlatformTradeCount verifies that trade counts are stored and +// TotalTrades is updated correctly. +func TestUnifiedPortfolio_SetPlatformTradeCount(t *testing.T) { + p := NewEmptyPortfolio() + + p.SetPlatformTradeCount(PlatformBinance, 10) + p.SetPlatformTradeCount(PlatformPolymarket, 5) + + assert.Equal(t, 10, p.ByPlatform[PlatformBinance].TradeCount) + assert.Equal(t, 5, p.ByPlatform[PlatformPolymarket].TradeCount) + assert.Equal(t, 15, p.TotalTrades) + + // Updating one platform recalculates total correctly. + p.SetPlatformTradeCount(PlatformBinance, 20) + assert.Equal(t, 20, p.ByPlatform[PlatformBinance].TradeCount) + assert.Equal(t, 25, p.TotalTrades) +} + +// TestUnifiedPosition_IsOpen tests the IsOpen method for both open and closed positions. +func TestUnifiedPosition_IsOpen(t *testing.T) { + open := &UnifiedPosition{ + Status: PositionStatusOpen, + } + assert.True(t, open.IsOpen(), "position with status OPEN should return true from IsOpen") + + closedAt := time.Now() + closed := &UnifiedPosition{ + Status: PositionStatusClosed, + ClosedAt: &closedAt, + } + assert.False(t, closed.IsOpen(), "position with status CLOSED should return false from IsOpen") +} diff --git a/internal/polymarket/analyzer/analyzer_unit_test.go b/internal/polymarket/analyzer/analyzer_unit_test.go new file mode 100644 index 00000000..a1b9720d --- /dev/null +++ b/internal/polymarket/analyzer/analyzer_unit_test.go @@ -0,0 +1,78 @@ +package analyzer + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ajitpratap0/cryptofunk/internal/polymarket/gamma" +) + +func TestParseResponse(t *testing.T) { + t.Run("valid JSON response", func(t *testing.T) { + resp := `{"predicted_probability":0.75,"confidence":0.8,"reasoning":"strong evidence","action":"BUY YES"}` + market := &gamma.Market{ + ConditionID: "123", + Question: "Test?", + OutcomeYesPrice: 0.5, + OutcomeNoPrice: 0.5, + } + + analysis := parseResponse(resp, market) + + require.NotNil(t, analysis) + assert.InDelta(t, 0.75, analysis.PredictedProb, 1e-9) + assert.InDelta(t, 0.8, analysis.Confidence, 1e-9) + assert.Equal(t, "BUY YES", analysis.Action) + assert.InDelta(t, 0.25, analysis.Edge, 1e-9) // 0.75 - 0.5 + assert.Equal(t, "123", analysis.MarketID) + }) + + t.Run("JSON embedded in surrounding text", func(t *testing.T) { + resp := `Some preamble text {"predicted_probability":0.3,"confidence":0.6,"reasoning":"weak","action":"BUY NO"} trailing text` + market := &gamma.Market{ + ConditionID: "456", + Question: "Will X happen?", + OutcomeYesPrice: 0.5, + OutcomeNoPrice: 0.5, + } + + analysis := parseResponse(resp, market) + + require.NotNil(t, analysis) + assert.Equal(t, "BUY NO", analysis.Action) + assert.InDelta(t, 0.3, analysis.PredictedProb, 1e-9) + }) + + t.Run("invalid response with no valid JSON", func(t *testing.T) { + resp := "I cannot analyze this market" + market := &gamma.Market{ + ConditionID: "789", + Question: "Another question?", + OutcomeYesPrice: 0.6, + OutcomeNoPrice: 0.4, + } + + analysis := parseResponse(resp, market) + + require.NotNil(t, analysis) + assert.Equal(t, "SKIP", analysis.Action) + assert.InDelta(t, 0.0, analysis.PredictedProb, 1e-9) + }) + + t.Run("empty string response", func(t *testing.T) { + resp := "" + market := &gamma.Market{ + ConditionID: "000", + Question: "Empty test?", + OutcomeYesPrice: 0.5, + OutcomeNoPrice: 0.5, + } + + analysis := parseResponse(resp, market) + + require.NotNil(t, analysis) + assert.Equal(t, "SKIP", analysis.Action) + }) +} diff --git a/internal/polymarket/news/news_unit_test.go b/internal/polymarket/news/news_unit_test.go new file mode 100644 index 00000000..d6ed1203 --- /dev/null +++ b/internal/polymarket/news/news_unit_test.go @@ -0,0 +1,30 @@ +package news + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStripHTML(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"plain text unchanged", "hello world", "hello world"}, + {"strips simple tag", "bold", "bold"}, + {"strips multiple tags", "
hello world
", "hello world"}, + {"empty string", "", ""}, + {"only tags", "