Skip to content

Commit

Permalink
Merge pull request #51 from ksysoev/get_address_and_port
Browse files Browse the repository at this point in the history
Adds possibility to get address and port that used by the server
  • Loading branch information
ksysoev committed May 19, 2024
2 parents dc190b1 + 619ff6b commit bdf051b
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 12 deletions.
52 changes: 40 additions & 12 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ const (
var ErrServerAlreadyRunning = fmt.Errorf("server is already running")

type Server struct {
baseCtx context.Context
mutex *sync.Mutex
handler *http.Server
addr string
channels []wasabi.Channel
baseCtx context.Context
listener net.Listener
listenerLock *sync.Mutex
mutex *sync.Mutex
handler *http.Server
addr string
channels []wasabi.Channel
}

type Option func(*Server)
Expand All @@ -33,11 +35,16 @@ type Option func(*Server)
// port - port to listen on
// returns new instance of Server
func NewServer(addr string, opts ...Option) *Server {
if addr == "" {
addr = ":http"
}

server := &Server{
addr: addr,
channels: make([]wasabi.Channel, 0, 1),
mutex: &sync.Mutex{},
baseCtx: context.Background(),
addr: addr,
channels: make([]wasabi.Channel, 0, 1),
mutex: &sync.Mutex{},
listenerLock: &sync.Mutex{},
baseCtx: context.Background(),
}

for _, opt := range opts {
Expand All @@ -64,7 +71,7 @@ func (s *Server) AddChannel(channel wasabi.Channel) {
// Run starts the server
// returns error if server is already running
// or if server fails to start
func (s *Server) Run() error {
func (s *Server) Run() (err error) {
if !s.mutex.TryLock() {
return ErrServerAlreadyRunning
}
Expand All @@ -82,9 +89,17 @@ func (s *Server) Run() error {

s.handler.Handler = mux

slog.Info("Starting app server on " + s.addr)
s.listenerLock.Lock()
s.listener, err = net.Listen("tcp", s.addr)
s.listenerLock.Unlock()

if err != nil {
return err
}

slog.Info("Starting app server on " + s.listener.Addr().String())

err := s.handler.ListenAndServe()
err = s.handler.Serve(s.listener)

if err != nil && err != http.ErrServerClosed {
return err
Expand Down Expand Up @@ -132,6 +147,19 @@ func (s *Server) Close(ctx ...context.Context) error {
return <-done
}

// Addr returns the server's network address.
// If the server is not running, it returns nil.
func (s *Server) Addr() net.Addr {
s.listenerLock.Lock()
defer s.listenerLock.Unlock()

if s.listener == nil {
return nil
}

return s.listener.Addr()
}

// BaseContext optionally specifies based context that will be used for all connections.
// If not specified, context.Background() will be used.
func WithBaseContext(ctx context.Context) Option {
Expand Down
56 changes: 56 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ func TestNewServer(t *testing.T) {
if server.mutex == nil {
t.Error("Expected non-nil mutex")
}

server = NewServer("")

if server.addr != ":http" {
t.Errorf("Expected default port :http, but got %s", server.addr)
}
}
func TestServer_AddChannel(t *testing.T) {
// Create a new Server instance
Expand Down Expand Up @@ -218,3 +224,53 @@ func TestServer_Close_NoContext(t *testing.T) {
t.Error("Expected server to stop")
}
}

func TestServer_Addr(t *testing.T) {
// Create a new Server instance
server := NewServer(":0")
defer server.Close()

// Create a mock channel
channel := mocks.NewMockChannel(t)
channel.EXPECT().Path().Return("/test")
channel.EXPECT().Handler().Return(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
channel.EXPECT().Close().Return(nil)

server.AddChannel(channel)

if server.Addr() != nil {
t.Error("Expected nil address for server that is not running")
}

// Start the server in a separate goroutine
done := make(chan struct{})

// Run the server
for i := 0; i < 2; i++ {
go func() {
err := server.Run()
switch err {
case nil:
case ErrServerAlreadyRunning:
close(done)
default:
t.Errorf("Got unexpected error: %v", err)
}
}()
}

select {
case <-done:
case <-time.After(1 * time.Second):
t.Error("Expected server to start")
}

// Wait for the server to fully start
time.Sleep(1 * time.Millisecond)

addr := server.Addr()

if addr == nil {
t.Error("Expected non-empty address")
}
}

0 comments on commit bdf051b

Please sign in to comment.