diff --git a/cmd/bridge/main.go b/cmd/bridge/main.go index d17bbef..3562558 100644 --- a/cmd/bridge/main.go +++ b/cmd/bridge/main.go @@ -170,6 +170,7 @@ func main() { e.GET("/bridge/events", h.EventRegistrationHandler) e.POST("/bridge/message", h.SendMessageHandler) + e.POST("/bridge/verify", h.ConnectVerifyHandler) // Health and ready endpoints e.GET("/health", func(c echo.Context) error { diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 15fe5e2..6da2317 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -59,6 +59,24 @@ var ( }) ) +type connectClient struct { + clientId string + ip string + referrer string // normalized origin + time time.Time +} + +type verifyRequest struct { + Type string `json:"type"` + ClientID string `json:"client_id"` + URL string `json:"url"` + Message string `json:"message,omitempty"` +} + +type verifyResponse struct { + Status string `json:"status"` +} + type stream struct { Sessions []*Session mux sync.RWMutex @@ -69,6 +87,8 @@ type handler struct { storage storage.Storage eventIDGen *EventIDGenerator heartbeatInterval time.Duration + datamap map[string][]connectClient // todo - use lru maps, add ttl 5 minutes + } func NewHandler(s storage.Storage, heartbeatInterval time.Duration) *handler { @@ -78,6 +98,7 @@ func NewHandler(s storage.Storage, heartbeatInterval time.Duration) *handler { storage: s, eventIDGen: NewEventIDGenerator(), heartbeatInterval: heartbeatInterval, + datamap: make(map[string][]connectClient), } return &h } @@ -148,6 +169,18 @@ func (h *handler) EventRegistrationHandler(c echo.Context) error { clientIdsPerConnectionMetric.Observe(float64(len(clientIds))) session := h.CreateSession(clientIds, lastEventId) + ip := c.RealIP() + origin := utils.ExtractOrigin(c.Request().Header.Get("Origin")) + connect_client := connectClient{ + clientId: clientId[0], + ip: ip, + referrer: origin, + time: time.Now(), + } + h.Mux.Lock() + h.datamap[clientId[0]] = append(h.datamap[clientId[0]], connect_client) + h.Mux.Unlock() + ctx := c.Request().Context() notify := ctx.Done() go func() { @@ -210,6 +243,72 @@ loop: return nil } +func (h *handler) ConnectVerifyHandler(c echo.Context) error { + ip := c.RealIP() // Todo - move all ip extraction to single function + + // Support new JSON POST format; fallback to legacy query params for backward compatibility + var req verifyRequest + if c.Request().Method == http.MethodPost { + decoder := json.NewDecoder(c.Request().Body) + if err := decoder.Decode(&req); err != nil { + badRequestMetric.Inc() + return c.JSON(utils.HttpResError("invalid JSON body", http.StatusBadRequest)) + } + } else { + params := c.QueryParams() + clientId, ok := params["client_id"] + if ok && len(clientId) > 0 { + req.ClientID = clientId[0] + } + urls, ok := params["url"] + if ok && len(urls) > 0 { + req.URL = urls[0] + } + types, ok := params["type"] + if ok && len(types) > 0 { + req.Type = types[0] + } else { + req.Type = "connect" + } + } + + if req.ClientID == "" { + badRequestMetric.Inc() + return c.JSON(utils.HttpResError("param \"client_id\" not present", http.StatusBadRequest)) + } + if req.URL == "" { + badRequestMetric.Inc() + return c.JSON(utils.HttpResError("param \"url\" not present", http.StatusBadRequest)) + } + req.URL = utils.ExtractOrigin(req.URL) + if req.Type == "" { + badRequestMetric.Inc() + return c.JSON(utils.HttpResError("param \"type\" not present", http.StatusBadRequest)) + } + + // Default status + status := "unknown" + now := time.Now() + + switch strings.ToLower(req.Type) { + case "connect": + h.Mux.RLock() + existingConnects := h.datamap[req.ClientID] + h.Mux.RUnlock() + for _, connect := range existingConnects { + if connect.ip == ip && connect.referrer == req.URL && now.Sub(connect.time) < 5*time.Minute { + status = "ok" + break + } + } + default: + badRequestMetric.Inc() + return c.JSON(utils.HttpResError("param \"type\" must be one of: connect, message", http.StatusBadRequest)) + } + + return c.JSON(http.StatusOK, verifyResponse{Status: status}) +} + func (h *handler) SendMessageHandler(c echo.Context) error { ctx := c.Request().Context() log := logrus.WithContext(ctx).WithField("prefix", "SendMessageHandler") @@ -299,10 +398,35 @@ func (h *handler) SendMessageHandler(c echo.Context) error { } } + origin := utils.ExtractOrigin(c.Request().Header.Get("Origin")) + ip := c.RealIP() + userAgent := c.Request().Header.Get("User-Agent") + + // Create request source metadata + requestSource := models.BridgeRequestSource{ + Origin: origin, + IP: ip, + Time: time.Now().UTC().Format(time.RFC3339), + ClientID: clientId[0], + UserAgent: userAgent, + } + + // Encrypt the request source metadata using the wallet's public key + encryptedRequestSource, err := utils.EncryptRequestSourceWithWalletID( + requestSource, + toId[0], // todo - check to id properly + ) + if err != nil { + badRequestMetric.Inc() + log.Error(err) + return c.JSON(utils.HttpResError(fmt.Sprintf("failed to encrypt request source: %v", err), http.StatusBadRequest)) + } + mes, err := json.Marshal(models.BridgeMessage{ - From: clientId[0], - Message: string(message), - TraceId: traceId, + From: clientId[0], + Message: string(message), + BridgeRequestSource: encryptedRequestSource, + TraceId: traceId, }) if err != nil { badRequestMetric.Inc() diff --git a/internal/models/models.go b/internal/models/models.go index 60c26b2..d43bd78 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -7,7 +7,16 @@ type SseMessage struct { } type BridgeMessage struct { - From string `json:"from"` - Message string `json:"message"` - TraceId string `json:"trace_id"` + From string `json:"from"` + Message string `json:"message"` + TraceId string `json:"trace_id"` + BridgeRequestSource string `json:"request_source"` +} + +type BridgeRequestSource struct { + Origin string `json:"origin"` + IP string `json:"ip"` + Time string `json:"time"` + ClientID string `json:"client_id"` + UserAgent string `json:"user_agent"` } diff --git a/internal/utils/http.go b/internal/utils/http.go index 5bd1918..222e47a 100644 --- a/internal/utils/http.go +++ b/internal/utils/http.go @@ -1,6 +1,9 @@ package utils -import "net/http" +import ( + "net/http" + "net/url" +) type HttpRes struct { Message string `json:"message,omitempty" example:"status ok"` @@ -20,3 +23,17 @@ func HttpResError(errMsg string, statusCode int) (int, HttpRes) { StatusCode: statusCode, } } + +func ExtractOrigin(rawURL string) string { + if rawURL == "" { + return "" + } + u, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + if u.Scheme == "" || u.Host == "" { + return rawURL + } + return u.Scheme + "://" + u.Host +} \ No newline at end of file diff --git a/internal/utils/tls.go b/internal/utils/tls.go index 79f60cb..e197e0f 100644 --- a/internal/utils/tls.go +++ b/internal/utils/tls.go @@ -6,11 +6,19 @@ import ( "crypto/rand" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" + "encoding/hex" + "encoding/json" "encoding/pem" + "fmt" "math/big" "time" + + "github.com/ton-connect/bridge3/internal/models" + "golang.org/x/crypto/nacl/box" ) +// GenerateSelfSignedCertificate generates a self-signed X.509 certificate and private key func GenerateSelfSignedCertificate() ([]byte, []byte, error) { privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -50,3 +58,31 @@ func GenerateSelfSignedCertificate() ([]byte, []byte, error) { return certPEM, keyPEM, nil } + +// EncryptRequestSourceWithWalletID encrypts the request source metadata using the wallet's Curve25519 public key +func EncryptRequestSourceWithWalletID(requestSource models.BridgeRequestSource, walletID string) (string, error) { + data, err := json.Marshal(requestSource) + if err != nil { + return "", fmt.Errorf("failed to marshal request source: %w", err) + } + + publicKeyBytes, err := hex.DecodeString(walletID) + if err != nil { + return "", fmt.Errorf("failed to decode wallet ID: %w", err) + } + + if len(publicKeyBytes) != 32 { + return "", fmt.Errorf("invalid public key length: expected 32 bytes, got %d", len(publicKeyBytes)) + } + + // Convert to Curve25519 public key format + var recipientPublicKey [32]byte + copy(recipientPublicKey[:], publicKeyBytes) + + encrypted, err := box.SealAnonymous(nil, data, &recipientPublicKey, rand.Reader) + if err != nil { + return "", fmt.Errorf("failed to encrypt data: %w", err) + } + + return base64.StdEncoding.EncodeToString(encrypted), nil +}