-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds documentation and tests for middleware code
- Loading branch information
Showing
4 changed files
with
202 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package http | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"sync" | ||
"testing" | ||
) | ||
|
||
func TestNewStashMiddleware(t *testing.T) { | ||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// Test if the stash value is set in the request context | ||
stash := r.Context().Value(Stash) | ||
if stash == nil { | ||
t.Error("Expected stash value to be set in the request context") | ||
} | ||
|
||
// Test if the stash value is of type *sync.Map | ||
_, ok := stash.(*sync.Map) | ||
if !ok { | ||
t.Error("Expected stash value to be of type *sync.Map") | ||
} | ||
|
||
// Test if the next handler is called | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
middleware := NewStashMiddleware() | ||
testServer := httptest.NewServer(middleware(handler)) | ||
|
||
defer testServer.Close() | ||
|
||
// Send a test request to the server | ||
resp, err := http.Get(testServer.URL) | ||
if err != nil { | ||
t.Fatalf("Failed to send request to test server: %v", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
// Test if the response status code is 200 OK | ||
if resp.StatusCode != http.StatusOK { | ||
t.Errorf("Expected status code %d, but got %d", http.StatusOK, resp.StatusCode) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package request | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
|
||
"github.com/ksysoev/wasabi" | ||
"github.com/ksysoev/wasabi/dispatch" | ||
"github.com/ksysoev/wasabi/mocks" | ||
) | ||
|
||
func TestNewErrorHandlingMiddleware(t *testing.T) { | ||
// Define a mock request handler | ||
mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { | ||
// Simulate an error | ||
return errors.New("mock error") | ||
}) | ||
|
||
// Define a mock error handler | ||
mockErrorHandler := func(conn wasabi.Connection, req wasabi.Request, err error) error { | ||
// Verify that the error handler is called with the correct parameters | ||
if conn == nil || req == nil || err == nil { | ||
t.Error("Error handler called with nil parameters") | ||
} | ||
|
||
// Return a custom error | ||
return errors.New("custom error") | ||
} | ||
|
||
// Create the error handling middleware | ||
middleware := NewErrorHandlingMiddleware(mockErrorHandler) | ||
|
||
// Create a mock connection and request | ||
mockConn := mocks.NewMockConnection(t) | ||
mockReq := mocks.NewMockRequest(t) | ||
|
||
// Call the middleware with the mock handler | ||
err := middleware(mockHandler).Handle(mockConn, mockReq) | ||
|
||
// Verify that the error returned by the middleware is the custom error | ||
if err == nil || err.Error() != "custom error" { | ||
t.Errorf("Expected error to be 'custom error', but got '%v'", err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package request | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/ksysoev/wasabi" | ||
"github.com/ksysoev/wasabi/dispatch" | ||
"github.com/ksysoev/wasabi/mocks" | ||
) | ||
|
||
func TestNewRateLimiterMiddleware(t *testing.T) { | ||
// Mock requestLimit function | ||
requestLimit := func(req wasabi.Request) (string, time.Duration, uint64) { | ||
return "test_key", time.Second, 10 | ||
} | ||
|
||
// Mock next RequestHandler | ||
next := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { | ||
// Mock implementation of next handler | ||
return nil | ||
}) | ||
|
||
// Create rate limiter middleware | ||
middleware := NewRateLimiterMiddleware(requestLimit) | ||
|
||
// Create a mock connection and request | ||
conn := mocks.NewMockConnection(t) | ||
req := mocks.NewMockRequest(t) | ||
|
||
// Test rate limiter middleware | ||
err := middleware(next).Handle(conn, req) | ||
|
||
if err != nil { | ||
t.Errorf("Expected no error, but got %v", err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package request | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/ksysoev/wasabi" | ||
"github.com/ksysoev/wasabi/dispatch" | ||
"github.com/ksysoev/wasabi/mocks" | ||
) | ||
|
||
func TestNewTrottlerMiddleware(t *testing.T) { | ||
limit := uint(3) | ||
middleware := NewTrottlerMiddleware(limit) | ||
|
||
// Create a mock request handler | ||
mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { | ||
select { | ||
case <-req.Context().Done(): | ||
return nil | ||
case <-time.After(time.Second): | ||
return fmt.Errorf("request timed out") | ||
} | ||
}) | ||
ctx1, cancel := context.WithCancel(context.Background()) | ||
// Create a mock connection and request | ||
mockConn := mocks.NewMockConnection(t) | ||
mockReq := mocks.NewMockRequest(t) | ||
|
||
mockReq.EXPECT().Context().Return(ctx1) | ||
// Test with multiple concurrent requests | ||
wg := sync.WaitGroup{} | ||
readyChan := make(chan struct{}, limit) | ||
|
||
for i := 0; i < int(limit); i++ { | ||
wg.Add(1) | ||
|
||
go func() { | ||
readyChan <- struct{}{} | ||
|
||
err := middleware(mockHandler).Handle(mockConn, mockReq) | ||
if err != nil && err != context.Canceled { | ||
t.Errorf("Expected no error, but got %v", err) | ||
} | ||
|
||
wg.Done() | ||
}() | ||
} | ||
|
||
for i := 0; i < int(limit); i++ { | ||
<-readyChan | ||
} | ||
|
||
time.Sleep(10 * time.Millisecond) | ||
|
||
mockHandlerInstant := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { | ||
return nil | ||
}) | ||
|
||
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Millisecond) | ||
defer cancel2() | ||
|
||
mockReq1 := mocks.NewMockRequest(t) | ||
mockReq1.EXPECT().Context().Return(ctx2) | ||
|
||
// Test with additional requests that should be throttled | ||
err := middleware(mockHandlerInstant).Handle(mockConn, mockReq1) | ||
if err == nil { | ||
t.Error("Expected error due to throttling, but got nil") | ||
} | ||
|
||
cancel() | ||
wg.Wait() | ||
} |