diff --git a/server/server.go b/server/server.go index 1fec328..bf4d583 100644 --- a/server/server.go +++ b/server/server.go @@ -20,11 +20,13 @@ const ( var ErrServerAlreadyRunning = fmt.Errorf("server is already running") type Server struct { - baseCtx context.Context - mutex *sync.Mutex - handler *http.Server - addr string - channels []wasabi.Channel + baseCtx context.Context + listener net.Listener + listenerLock *sync.Mutex + mutex *sync.Mutex + handler *http.Server + addr string + channels []wasabi.Channel } type Option func(*Server) @@ -33,11 +35,16 @@ type Option func(*Server) // port - port to listen on // returns new instance of Server func NewServer(addr string, opts ...Option) *Server { + if addr == "" { + addr = ":http" + } + server := &Server{ - addr: addr, - channels: make([]wasabi.Channel, 0, 1), - mutex: &sync.Mutex{}, - baseCtx: context.Background(), + addr: addr, + channels: make([]wasabi.Channel, 0, 1), + mutex: &sync.Mutex{}, + listenerLock: &sync.Mutex{}, + baseCtx: context.Background(), } for _, opt := range opts { @@ -64,7 +71,7 @@ func (s *Server) AddChannel(channel wasabi.Channel) { // Run starts the server // returns error if server is already running // or if server fails to start -func (s *Server) Run() error { +func (s *Server) Run() (err error) { if !s.mutex.TryLock() { return ErrServerAlreadyRunning } @@ -82,9 +89,17 @@ func (s *Server) Run() error { s.handler.Handler = mux - slog.Info("Starting app server on " + s.addr) + s.listenerLock.Lock() + s.listener, err = net.Listen("tcp", s.addr) + s.listenerLock.Unlock() + + if err != nil { + return err + } + + slog.Info("Starting app server on " + s.listener.Addr().String()) - err := s.handler.ListenAndServe() + err = s.handler.Serve(s.listener) if err != nil && err != http.ErrServerClosed { return err @@ -132,6 +147,19 @@ func (s *Server) Close(ctx ...context.Context) error { return <-done } +// Addr returns the server's network address. +// If the server is not running, it returns nil. +func (s *Server) Addr() net.Addr { + s.listenerLock.Lock() + defer s.listenerLock.Unlock() + + if s.listener == nil { + return nil + } + + return s.listener.Addr() +} + // BaseContext optionally specifies based context that will be used for all connections. // If not specified, context.Background() will be used. func WithBaseContext(ctx context.Context) Option { diff --git a/server/server_test.go b/server/server_test.go index 05fe6e0..bddaf8d 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -26,6 +26,12 @@ func TestNewServer(t *testing.T) { if server.mutex == nil { t.Error("Expected non-nil mutex") } + + server = NewServer("") + + if server.addr != ":http" { + t.Errorf("Expected default port :http, but got %s", server.addr) + } } func TestServer_AddChannel(t *testing.T) { // Create a new Server instance @@ -218,3 +224,53 @@ func TestServer_Close_NoContext(t *testing.T) { t.Error("Expected server to stop") } } + +func TestServer_Addr(t *testing.T) { + // Create a new Server instance + server := NewServer(":0") + defer server.Close() + + // Create a mock channel + channel := mocks.NewMockChannel(t) + channel.EXPECT().Path().Return("/test") + channel.EXPECT().Handler().Return(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + channel.EXPECT().Close().Return(nil) + + server.AddChannel(channel) + + if server.Addr() != nil { + t.Error("Expected nil address for server that is not running") + } + + // Start the server in a separate goroutine + done := make(chan struct{}) + + // Run the server + for i := 0; i < 2; i++ { + go func() { + err := server.Run() + switch err { + case nil: + case ErrServerAlreadyRunning: + close(done) + default: + t.Errorf("Got unexpected error: %v", err) + } + }() + } + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("Expected server to start") + } + + // Wait for the server to fully start + time.Sleep(1 * time.Millisecond) + + addr := server.Addr() + + if addr == nil { + t.Error("Expected non-empty address") + } +}