@@ -77,12 +77,12 @@ import (
77
77
"encoding/json"
78
78
"errors"
79
79
"fmt"
80
+ "net/http"
80
81
"sync"
81
82
"time"
82
83
83
84
"golang.org/x/time/rate"
84
85
85
- "github.com/google/uuid"
86
86
"github.com/gorilla/websocket"
87
87
"github.com/rs/zerolog"
88
88
"golang.org/x/sync/errgroup"
@@ -129,7 +129,7 @@ type Controller struct {
129
129
// issues such as sending on a closed channel while maintaining proper cleanup.
130
130
multiplexedStream chan interface {}
131
131
132
- dataProviders * concurrentmap.Map [uuid. UUID , dp.DataProvider ]
132
+ dataProviders * concurrentmap.Map [SubscriptionID , dp.DataProvider ]
133
133
dataProviderFactory dp.DataProviderFactory
134
134
dataProvidersGroup * sync.WaitGroup
135
135
limiter * rate.Limiter
@@ -146,7 +146,7 @@ func NewWebSocketController(
146
146
config : config ,
147
147
conn : conn ,
148
148
multiplexedStream : make (chan interface {}),
149
- dataProviders : concurrentmap .New [uuid. UUID , dp.DataProvider ](),
149
+ dataProviders : concurrentmap .New [SubscriptionID , dp.DataProvider ](),
150
150
dataProviderFactory : dataProviderFactory ,
151
151
dataProvidersGroup : & sync.WaitGroup {},
152
152
limiter : rate .NewLimiter (rate .Limit (config .MaxResponsesPerSecond ), 1 ),
@@ -246,7 +246,7 @@ func (c *Controller) keepalive(ctx context.Context) error {
246
246
// If no messages are sent within InactivityTimeout and no active data providers exist,
247
247
// the connection will be closed.
248
248
func (c * Controller ) writeMessages (ctx context.Context ) error {
249
- inactivityTicker := time .NewTicker (c .config . InactivityTimeout / 10 )
249
+ inactivityTicker := time .NewTicker (c .inactivityTickerPeriod () )
250
250
defer inactivityTicker .Stop ()
251
251
252
252
lastMessageSentAt := time .Now ()
@@ -301,6 +301,10 @@ func (c *Controller) writeMessages(ctx context.Context) error {
301
301
}
302
302
}
303
303
304
+ func (c * Controller ) inactivityTickerPeriod () time.Duration {
305
+ return c .config .InactivityTimeout / 10
306
+ }
307
+
304
308
// readMessages continuously reads messages from a client WebSocket connection,
305
309
// validates each message, and processes it based on the message type.
306
310
func (c * Controller ) readMessages (ctx context.Context ) error {
@@ -314,7 +318,8 @@ func (c *Controller) readMessages(ctx context.Context) error {
314
318
c .writeErrorResponse (
315
319
ctx ,
316
320
err ,
317
- wrapErrorMessage (InvalidMessage , "error reading message" , "" , "" , "" ))
321
+ wrapErrorMessage (http .StatusBadRequest , "error reading message" , "" , "" ),
322
+ )
318
323
continue
319
324
}
320
325
@@ -323,7 +328,8 @@ func (c *Controller) readMessages(ctx context.Context) error {
323
328
c .writeErrorResponse (
324
329
ctx ,
325
330
err ,
326
- wrapErrorMessage (InvalidMessage , "error parsing message" , "" , "" , "" ))
331
+ wrapErrorMessage (http .StatusBadRequest , "error parsing message" , "" , "" ),
332
+ )
327
333
continue
328
334
}
329
335
}
@@ -366,24 +372,34 @@ func (c *Controller) handleMessage(ctx context.Context, message json.RawMessage)
366
372
}
367
373
368
374
func (c * Controller ) handleSubscribe (ctx context.Context , msg models.SubscribeMessageRequest ) {
375
+ subscriptionID , err := c .parseOrCreateSubscriptionID (msg .SubscriptionID )
376
+ if err != nil {
377
+ c .writeErrorResponse (
378
+ ctx ,
379
+ err ,
380
+ wrapErrorMessage (http .StatusBadRequest , "error parsing subscription id" ,
381
+ models .SubscribeAction , msg .SubscriptionID ),
382
+ )
383
+ return
384
+ }
385
+
369
386
// register new provider
370
- provider , err := c .dataProviderFactory .NewDataProvider (ctx , msg .Topic , msg .Arguments , c .multiplexedStream )
387
+ provider , err := c .dataProviderFactory .NewDataProvider (ctx , subscriptionID . String (), msg .Topic , msg .Arguments , c .multiplexedStream )
371
388
if err != nil {
372
389
c .writeErrorResponse (
373
390
ctx ,
374
391
err ,
375
- wrapErrorMessage (InvalidArgument , "error creating data provider" , msg .ClientMessageID , models .SubscribeAction , "" ),
392
+ wrapErrorMessage (http .StatusBadRequest , "error creating data provider" ,
393
+ models .SubscribeAction , subscriptionID .String ()),
376
394
)
377
395
return
378
396
}
379
- c .dataProviders .Add (provider . ID () , provider )
397
+ c .dataProviders .Add (subscriptionID , provider )
380
398
381
399
// write OK response to client
382
400
responseOk := models.SubscribeMessageResponse {
383
401
BaseMessageResponse : models.BaseMessageResponse {
384
- ClientMessageID : msg .ClientMessageID ,
385
- Success : true ,
386
- SubscriptionID : provider .ID ().String (),
402
+ SubscriptionID : subscriptionID .String (),
387
403
},
388
404
}
389
405
c .writeResponse (ctx , responseOk )
@@ -396,72 +412,63 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe
396
412
c .writeErrorResponse (
397
413
ctx ,
398
414
err ,
399
- wrapErrorMessage (SubscriptionError , "subscription finished with error" , "" , "" , "" ),
415
+ wrapErrorMessage (http .StatusInternalServerError , "internal error" ,
416
+ models .SubscribeAction , subscriptionID .String ()),
400
417
)
401
418
}
402
419
403
420
c .dataProvidersGroup .Done ()
404
- c .dataProviders .Remove (provider . ID () )
421
+ c .dataProviders .Remove (subscriptionID )
405
422
}()
406
423
}
407
424
408
425
func (c * Controller ) handleUnsubscribe (ctx context.Context , msg models.UnsubscribeMessageRequest ) {
409
- id , err := uuid . Parse (msg .SubscriptionID )
426
+ subscriptionID , err := ParseClientSubscriptionID (msg .SubscriptionID )
410
427
if err != nil {
411
428
c .writeErrorResponse (
412
429
ctx ,
413
430
err ,
414
- wrapErrorMessage (InvalidArgument , "error parsing subscription ID" , msg .ClientMessageID , models .UnsubscribeAction , msg .SubscriptionID ),
431
+ wrapErrorMessage (http .StatusBadRequest , "error parsing subscription id" ,
432
+ models .UnsubscribeAction , msg .SubscriptionID ),
415
433
)
416
434
return
417
435
}
418
436
419
- provider , ok := c .dataProviders .Get (id )
437
+ provider , ok := c .dataProviders .Get (subscriptionID )
420
438
if ! ok {
421
439
c .writeErrorResponse (
422
440
ctx ,
423
441
err ,
424
- wrapErrorMessage (NotFound , "subscription not found" , msg .ClientMessageID , models .UnsubscribeAction , msg .SubscriptionID ),
442
+ wrapErrorMessage (http .StatusNotFound , "subscription not found" ,
443
+ models .UnsubscribeAction , subscriptionID .String ()),
425
444
)
426
445
return
427
446
}
428
447
429
448
provider .Close ()
430
- c .dataProviders .Remove (id )
449
+ c .dataProviders .Remove (subscriptionID )
431
450
432
451
responseOk := models.UnsubscribeMessageResponse {
433
452
BaseMessageResponse : models.BaseMessageResponse {
434
- ClientMessageID : msg .ClientMessageID ,
435
- Success : true ,
436
- SubscriptionID : msg .SubscriptionID ,
453
+ SubscriptionID : subscriptionID .String (),
437
454
},
438
455
}
439
456
c .writeResponse (ctx , responseOk )
440
457
}
441
458
442
- func (c * Controller ) handleListSubscriptions (ctx context.Context , msg models.ListSubscriptionsMessageRequest ) {
459
+ func (c * Controller ) handleListSubscriptions (ctx context.Context , _ models.ListSubscriptionsMessageRequest ) {
443
460
var subs []* models.SubscriptionEntry
444
- err : = c .dataProviders .ForEach (func (id uuid. UUID , provider dp.DataProvider ) error {
461
+ _ = c .dataProviders .ForEach (func (id SubscriptionID , provider dp.DataProvider ) error {
445
462
subs = append (subs , & models.SubscriptionEntry {
446
- ID : id .String (),
447
- Topic : provider .Topic (),
463
+ SubscriptionID : id .String (),
464
+ Topic : provider .Topic (),
448
465
})
449
466
return nil
450
467
})
451
468
452
- if err != nil {
453
- c .writeErrorResponse (
454
- ctx ,
455
- err ,
456
- wrapErrorMessage (NotFound , "error listing subscriptions" , msg .ClientMessageID , models .ListSubscriptionsAction , "" ),
457
- )
458
- return
459
- }
460
-
461
469
responseOk := models.ListSubscriptionsMessageResponse {
462
- Success : true ,
463
- ClientMessageID : msg .ClientMessageID ,
464
- Subscriptions : subs ,
470
+ Subscriptions : subs ,
471
+ Action : models .ListSubscriptionsAction ,
465
472
}
466
473
c .writeResponse (ctx , responseOk )
467
474
}
@@ -472,13 +479,10 @@ func (c *Controller) shutdownConnection() {
472
479
c .logger .Debug ().Err (err ).Msg ("error closing connection" )
473
480
}
474
481
475
- err = c .dataProviders .ForEach (func (_ uuid. UUID , provider dp.DataProvider ) error {
482
+ _ = c .dataProviders .ForEach (func (_ SubscriptionID , provider dp.DataProvider ) error {
476
483
provider .Close ()
477
484
return nil
478
485
})
479
- if err != nil {
480
- c .logger .Debug ().Err (err ).Msg ("error closing data provider" )
481
- }
482
486
483
487
c .dataProviders .Clear ()
484
488
c .dataProvidersGroup .Wait ()
@@ -498,15 +502,26 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) {
498
502
}
499
503
}
500
504
501
- func wrapErrorMessage (code Code , message string , msgId string , action string , subscriptionID string ) models.BaseMessageResponse {
505
+ func wrapErrorMessage (code int , message string , action string , subscriptionID string ) models.BaseMessageResponse {
502
506
return models.BaseMessageResponse {
503
- ClientMessageID : msgId ,
504
- Success : false ,
505
- SubscriptionID : subscriptionID ,
507
+ SubscriptionID : subscriptionID ,
506
508
Error : models.ErrorMessage {
507
- Code : int ( code ) ,
509
+ Code : code ,
508
510
Message : message ,
509
- Action : action ,
510
511
},
512
+ Action : action ,
511
513
}
512
514
}
515
+
516
+ func (c * Controller ) parseOrCreateSubscriptionID (id string ) (SubscriptionID , error ) {
517
+ newId , err := NewSubscriptionID (id )
518
+ if err != nil {
519
+ return SubscriptionID {}, err
520
+ }
521
+
522
+ if c .dataProviders .Has (newId ) {
523
+ return SubscriptionID {}, fmt .Errorf ("subscription ID is already in use: %s" , newId )
524
+ }
525
+
526
+ return newId , nil
527
+ }
0 commit comments