diff --git a/server/server.go b/server/server.go index 1fec328..38f1bc2 100644 --- a/server/server.go +++ b/server/server.go @@ -25,6 +25,7 @@ type Server struct { handler *http.Server addr string channels []wasabi.Channel + listener net.Listener } type Option func(*Server) @@ -33,6 +34,10 @@ type Option func(*Server) // port - port to listen on // returns new instance of Server func NewServer(addr string, opts ...Option) *Server { + if addr == "" { + addr = ":http" + } + server := &Server{ addr: addr, channels: make([]wasabi.Channel, 0, 1), @@ -64,7 +69,7 @@ func (s *Server) AddChannel(channel wasabi.Channel) { // Run starts the server // returns error if server is already running // or if server fails to start -func (s *Server) Run() error { +func (s *Server) Run() (err error) { if !s.mutex.TryLock() { return ErrServerAlreadyRunning } @@ -82,9 +87,14 @@ func (s *Server) Run() error { s.handler.Handler = mux - slog.Info("Starting app server on " + s.addr) + s.listener, err = net.Listen("tcp", s.addr) + if err != nil { + return err + } + + slog.Info("Starting app server on " + s.listener.Addr().String()) - err := s.handler.ListenAndServe() + err = s.handler.Serve(s.listener) if err != nil && err != http.ErrServerClosed { return err @@ -132,6 +142,16 @@ func (s *Server) Close(ctx ...context.Context) error { return <-done } +// Addr returns the server's network address. +// If the server is not running, it returns nil. +func (s *Server) Addr() net.Addr { + if s.listener == nil { + return nil + } + + return s.listener.Addr() +} + // BaseContext optionally specifies based context that will be used for all connections. // If not specified, context.Background() will be used. func WithBaseContext(ctx context.Context) Option { diff --git a/server/server_test.go b/server/server_test.go index 05fe6e0..cdd4c8c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "fmt" "net/http" "testing" "time" @@ -218,3 +219,48 @@ func TestServer_Close_NoContext(t *testing.T) { t.Error("Expected server to stop") } } + +func TestServer_Addr(t *testing.T) { + // Create a new Server instance + server := NewServer(":0") + defer server.Close() + + // Create a mock channel + channel := mocks.NewMockChannel(t) + channel.EXPECT().Path().Return("/test") + channel.EXPECT().Handler().Return(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + channel.EXPECT().Close().Return(nil) + + server.AddChannel(channel) + + // Start the server in a separate goroutine + done := make(chan struct{}) + + // Run the server + for i := 0; i < 2; i++ { + go func() { + err := server.Run() + switch err { + case nil: + case ErrServerAlreadyRunning: + close(done) + default: + t.Errorf("Got unexpected error: %v", err) + } + }() + } + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("Expected server to start") + } + + fmt.Println(server.listener) + + addr := server.Addr() + + if addr == nil { + t.Error("Expected non-empty address") + } +}