Skip to content

Commit 21c68b8

Browse files
authored
feat: add /sse endpoint to test Server-Sent Events (#160)
Each event is a "ping" that includes an incrementing integer ID and an integer Unix timestamp with millisecond resolution: event: ping data: {"id":9,"timestamp":1702417925258} Fixes #150.
1 parent 6ad2943 commit 21c68b8

File tree

6 files changed

+439
-80
lines changed

6 files changed

+439
-80
lines changed

httpbin/handlers.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"errors"
99
"fmt"
10+
"io"
1011
"net/http"
1112
"net/http/httputil"
1213
"net/url"
@@ -1108,6 +1109,115 @@ func (h *HTTPBin) Hostname(w http.ResponseWriter, _ *http.Request) {
11081109
})
11091110
}
11101111

1112+
// SSE writes a stream of events over a duration after an optional
1113+
// initial delay.
1114+
func (h *HTTPBin) SSE(w http.ResponseWriter, r *http.Request) {
1115+
q := r.URL.Query()
1116+
var (
1117+
count = h.DefaultParams.SSECount
1118+
duration = h.DefaultParams.SSEDuration
1119+
delay = h.DefaultParams.SSEDelay
1120+
err error
1121+
)
1122+
1123+
if userCount := q.Get("count"); userCount != "" {
1124+
count, err = strconv.Atoi(userCount)
1125+
if err != nil {
1126+
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %w", err))
1127+
return
1128+
}
1129+
if count < 1 || int64(count) > h.maxSSECount {
1130+
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: must in range [1, %d]", h.maxSSECount))
1131+
return
1132+
}
1133+
}
1134+
1135+
if userDuration := q.Get("duration"); userDuration != "" {
1136+
duration, err = parseBoundedDuration(userDuration, 1, h.MaxDuration)
1137+
if err != nil {
1138+
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid duration: %w", err))
1139+
return
1140+
}
1141+
}
1142+
1143+
if userDelay := q.Get("delay"); userDelay != "" {
1144+
delay, err = parseBoundedDuration(userDelay, 0, h.MaxDuration)
1145+
if err != nil {
1146+
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid delay: %w", err))
1147+
return
1148+
}
1149+
}
1150+
1151+
if duration+delay > h.MaxDuration {
1152+
http.Error(w, "Too much time", http.StatusBadRequest)
1153+
return
1154+
}
1155+
1156+
pause := duration
1157+
if count > 1 {
1158+
// compensate for lack of pause after final write (i.e. if we're
1159+
// writing 10 events, we will only pause 9 times)
1160+
pause = duration / time.Duration(count-1)
1161+
}
1162+
1163+
// Initial delay before we send any response data
1164+
if delay > 0 {
1165+
select {
1166+
case <-time.After(delay):
1167+
// ok
1168+
case <-r.Context().Done():
1169+
w.WriteHeader(499) // "Client Closed Request" https://httpstatuses.com/499
1170+
return
1171+
}
1172+
}
1173+
1174+
w.Header().Set("Content-Type", sseContentType)
1175+
w.WriteHeader(http.StatusOK)
1176+
1177+
flusher := w.(http.Flusher)
1178+
1179+
// special case when we only have one event to write
1180+
if count == 1 {
1181+
writeServerSentEvent(w, 0, time.Now())
1182+
flusher.Flush()
1183+
return
1184+
}
1185+
1186+
ticker := time.NewTicker(pause)
1187+
defer ticker.Stop()
1188+
1189+
for i := 0; i < count; i++ {
1190+
writeServerSentEvent(w, i, time.Now())
1191+
flusher.Flush()
1192+
1193+
// don't pause after last byte
1194+
if i == count-1 {
1195+
return
1196+
}
1197+
1198+
select {
1199+
case <-ticker.C:
1200+
// ok
1201+
case <-r.Context().Done():
1202+
return
1203+
}
1204+
}
1205+
}
1206+
1207+
// writeServerSentEvent writes the bytes that constitute a single server-sent
1208+
// event message, including both the event type and data.
1209+
func writeServerSentEvent(dst io.Writer, id int, ts time.Time) {
1210+
dst.Write([]byte("event: ping\n"))
1211+
dst.Write([]byte("data: "))
1212+
json.NewEncoder(dst).Encode(serverSentEvent{
1213+
ID: id,
1214+
Timestamp: ts.UnixMilli(),
1215+
})
1216+
// each SSE ends with two newlines (\n\n), the first of which is written
1217+
// automatically by json.NewEncoder().Encode()
1218+
dst.Write([]byte("\n"))
1219+
}
1220+
11111221
// WebSocketEcho - simple websocket echo server, where the max fragment size
11121222
// and max message size can be controlled by clients.
11131223
func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) {

httpbin/handlers_test.go

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ func createApp(opts ...OptionFunc) *HTTPBin {
6363
DripDelay: 0,
6464
DripDuration: 100 * time.Millisecond,
6565
DripNumBytes: 10,
66+
SSECount: 10,
67+
SSEDelay: 0,
68+
SSEDuration: 100 * time.Millisecond,
6669
}),
6770
WithMaxBodySize(maxBodySize),
6871
WithMaxDuration(maxDuration),
@@ -2957,6 +2960,246 @@ func TestHostname(t *testing.T) {
29572960
})
29582961
}
29592962

2963+
func TestSSE(t *testing.T) {
2964+
t.Parallel()
2965+
2966+
parseServerSentEvent := func(t *testing.T, buf *bufio.Reader) (serverSentEvent, error) {
2967+
t.Helper()
2968+
2969+
// match "event: ping" line
2970+
eventLine, err := buf.ReadBytes('\n')
2971+
if err != nil {
2972+
return serverSentEvent{}, err
2973+
}
2974+
_, eventType, _ := bytes.Cut(eventLine, []byte(":"))
2975+
assert.Equal(t, string(bytes.TrimSpace(eventType)), "ping", "unexpected event type")
2976+
2977+
// match "data: {...}" line
2978+
dataLine, err := buf.ReadBytes('\n')
2979+
if err != nil {
2980+
return serverSentEvent{}, err
2981+
}
2982+
_, data, _ := bytes.Cut(dataLine, []byte(":"))
2983+
var event serverSentEvent
2984+
assert.NilError(t, json.Unmarshal(data, &event))
2985+
2986+
// match newline after event data
2987+
b, err := buf.ReadByte()
2988+
if err != nil && err != io.EOF {
2989+
assert.NilError(t, err)
2990+
}
2991+
if b != '\n' {
2992+
t.Fatalf("expected newline after event data, got %q", b)
2993+
}
2994+
2995+
return event, nil
2996+
}
2997+
2998+
parseServerSentEventStream := func(t *testing.T, resp *http.Response) []serverSentEvent {
2999+
t.Helper()
3000+
buf := bufio.NewReader(resp.Body)
3001+
var events []serverSentEvent
3002+
for {
3003+
event, err := parseServerSentEvent(t, buf)
3004+
if err == io.EOF {
3005+
break
3006+
}
3007+
assert.NilError(t, err)
3008+
events = append(events, event)
3009+
}
3010+
return events
3011+
}
3012+
3013+
okTests := []struct {
3014+
params *url.Values
3015+
duration time.Duration
3016+
count int
3017+
}{
3018+
// there are useful defaults for all values
3019+
{&url.Values{}, 0, 10},
3020+
3021+
// go-style durations are accepted
3022+
{&url.Values{"duration": {"5ms"}}, 5 * time.Millisecond, 10},
3023+
{&url.Values{"duration": {"10ns"}}, 0, 10},
3024+
{&url.Values{"delay": {"5ms"}}, 5 * time.Millisecond, 10},
3025+
{&url.Values{"delay": {"0h"}}, 0, 10},
3026+
3027+
// or floating point seconds
3028+
{&url.Values{"duration": {"0.25"}}, 250 * time.Millisecond, 10},
3029+
{&url.Values{"duration": {"1"}}, 1 * time.Second, 10},
3030+
{&url.Values{"delay": {"0.25"}}, 250 * time.Millisecond, 10},
3031+
{&url.Values{"delay": {"0"}}, 0, 10},
3032+
3033+
{&url.Values{"count": {"1"}}, 0, 1},
3034+
{&url.Values{"count": {"011"}}, 0, 11},
3035+
{&url.Values{"count": {fmt.Sprintf("%d", app.maxSSECount)}}, 0, int(app.maxSSECount)},
3036+
3037+
{&url.Values{"duration": {"250ms"}, "delay": {"250ms"}}, 500 * time.Millisecond, 10},
3038+
{&url.Values{"duration": {"250ms"}, "delay": {"0.25s"}}, 500 * time.Millisecond, 10},
3039+
}
3040+
for _, test := range okTests {
3041+
test := test
3042+
t.Run(fmt.Sprintf("ok/%s", test.params.Encode()), func(t *testing.T) {
3043+
t.Parallel()
3044+
3045+
url := "/sse?" + test.params.Encode()
3046+
3047+
start := time.Now()
3048+
req := newTestRequest(t, "GET", url)
3049+
resp := must.DoReq(t, client, req)
3050+
assert.StatusCode(t, resp, http.StatusOK)
3051+
events := parseServerSentEventStream(t, resp)
3052+
3053+
if elapsed := time.Since(start); elapsed < test.duration {
3054+
t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed)
3055+
}
3056+
assert.ContentType(t, resp, sseContentType)
3057+
assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "unexpected Transfer-Encoding header")
3058+
assert.Equal(t, len(events), test.count, "unexpected number of events")
3059+
})
3060+
}
3061+
3062+
badTests := []struct {
3063+
params *url.Values
3064+
code int
3065+
}{
3066+
{&url.Values{"duration": {"0"}}, http.StatusBadRequest},
3067+
{&url.Values{"duration": {"0s"}}, http.StatusBadRequest},
3068+
{&url.Values{"duration": {"1m"}}, http.StatusBadRequest},
3069+
{&url.Values{"duration": {"-1ms"}}, http.StatusBadRequest},
3070+
{&url.Values{"duration": {"1001"}}, http.StatusBadRequest},
3071+
{&url.Values{"duration": {"-1"}}, http.StatusBadRequest},
3072+
{&url.Values{"duration": {"foo"}}, http.StatusBadRequest},
3073+
3074+
{&url.Values{"delay": {"1m"}}, http.StatusBadRequest},
3075+
{&url.Values{"delay": {"-1ms"}}, http.StatusBadRequest},
3076+
{&url.Values{"delay": {"1001"}}, http.StatusBadRequest},
3077+
{&url.Values{"delay": {"-1"}}, http.StatusBadRequest},
3078+
{&url.Values{"delay": {"foo"}}, http.StatusBadRequest},
3079+
3080+
{&url.Values{"count": {"foo"}}, http.StatusBadRequest},
3081+
{&url.Values{"count": {"0"}}, http.StatusBadRequest},
3082+
{&url.Values{"count": {"-1"}}, http.StatusBadRequest},
3083+
{&url.Values{"count": {"0xff"}}, http.StatusBadRequest},
3084+
{&url.Values{"count": {fmt.Sprintf("%d", app.maxSSECount+1)}}, http.StatusBadRequest},
3085+
3086+
// request would take too long
3087+
{&url.Values{"duration": {"750ms"}, "delay": {"500ms"}}, http.StatusBadRequest},
3088+
}
3089+
for _, test := range badTests {
3090+
test := test
3091+
t.Run(fmt.Sprintf("bad/%s", test.params.Encode()), func(t *testing.T) {
3092+
t.Parallel()
3093+
url := "/sse?" + test.params.Encode()
3094+
req := newTestRequest(t, "GET", url)
3095+
resp := must.DoReq(t, client, req)
3096+
defer consumeAndCloseBody(resp)
3097+
assert.StatusCode(t, resp, test.code)
3098+
})
3099+
}
3100+
3101+
t.Run("writes are actually incremmental", func(t *testing.T) {
3102+
t.Parallel()
3103+
3104+
var (
3105+
duration = 100 * time.Millisecond
3106+
count = 3
3107+
endpoint = fmt.Sprintf("/sse?duration=%s&count=%d", duration, count)
3108+
3109+
// Match server logic for calculating the delay between writes
3110+
wantPauseBetweenWrites = duration / time.Duration(count-1)
3111+
)
3112+
3113+
req := newTestRequest(t, "GET", endpoint)
3114+
resp := must.DoReq(t, client, req)
3115+
buf := bufio.NewReader(resp.Body)
3116+
eventCount := 0
3117+
3118+
// Here we read from the response one byte at a time, and ensure that
3119+
// at least the expected delay occurs for each read.
3120+
//
3121+
// The request above includes an initial delay equal to the expected
3122+
// wait between writes so that even the first iteration of this loop
3123+
// expects to wait the same amount of time for a read.
3124+
for i := 0; ; i++ {
3125+
start := time.Now()
3126+
event, err := parseServerSentEvent(t, buf)
3127+
if err == io.EOF {
3128+
break
3129+
}
3130+
assert.NilError(t, err)
3131+
gotPause := time.Since(start)
3132+
3133+
// We expect to read exactly one byte on each iteration. On the
3134+
// last iteration, we expct to hit EOF after reading the final
3135+
// byte, because the server does not pause after the last write.
3136+
assert.Equal(t, event.ID, i, "unexpected SSE event ID")
3137+
3138+
// only ensure that we pause for the expected time between writes
3139+
// (allowing for minor mismatch in local timers and server timers)
3140+
// after the first byte.
3141+
if i > 0 {
3142+
assert.RoughDuration(t, gotPause, wantPauseBetweenWrites, 3*time.Millisecond)
3143+
}
3144+
3145+
eventCount++
3146+
}
3147+
3148+
assert.Equal(t, eventCount, count, "unexpected number of events")
3149+
})
3150+
3151+
t.Run("handle cancelation during initial delay", func(t *testing.T) {
3152+
t.Parallel()
3153+
3154+
// For this test, we expect the client to time out and cancel the
3155+
// request after 10ms. The handler should still be in its intitial
3156+
// delay period, so this will result in a request error since no status
3157+
// code will be written before the cancelation.
3158+
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
3159+
defer cancel()
3160+
3161+
req := newTestRequest(t, "GET", "/sse?duration=500ms&delay=500ms").WithContext(ctx)
3162+
if _, err := client.Do(req); !os.IsTimeout(err) {
3163+
t.Fatalf("expected timeout error, got %s", err)
3164+
}
3165+
})
3166+
3167+
t.Run("handle cancelation during stream", func(t *testing.T) {
3168+
t.Parallel()
3169+
3170+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
3171+
defer cancel()
3172+
3173+
req := newTestRequest(t, "GET", "/sse?duration=900ms&delay=0&count=2").WithContext(ctx)
3174+
resp := must.DoReq(t, client, req)
3175+
defer consumeAndCloseBody(resp)
3176+
3177+
// In this test, the server should have started an OK response before
3178+
// our client timeout cancels the request, so we should get an OK here.
3179+
assert.StatusCode(t, resp, http.StatusOK)
3180+
3181+
// But, we should time out while trying to read the whole response
3182+
// body.
3183+
body, err := io.ReadAll(resp.Body)
3184+
if !os.IsTimeout(err) {
3185+
t.Fatalf("expected timeout reading body, got %s", err)
3186+
}
3187+
3188+
// partial read should include the first whole event
3189+
event, err := parseServerSentEvent(t, bufio.NewReader(bytes.NewReader(body)))
3190+
assert.NilError(t, err)
3191+
assert.Equal(t, event.ID, 0, "unexpected SSE event ID")
3192+
})
3193+
3194+
t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) {
3195+
t.Parallel()
3196+
req := newTestRequest(t, "HEAD", "/sse?duration=900ms&delay=100ms")
3197+
resp := must.DoReq(t, client, req)
3198+
assert.StatusCode(t, resp, http.StatusOK)
3199+
assert.BodySize(t, resp, 0)
3200+
})
3201+
}
3202+
29603203
func TestWebSocketEcho(t *testing.T) {
29613204
// ========================================================================
29623205
// Note: Here we only test input validation for the websocket endpoint.
@@ -3028,6 +3271,7 @@ func TestWebSocketEcho(t *testing.T) {
30283271
})
30293272
}
30303273
}
3274+
30313275
func newTestServer(handler http.Handler) (*httptest.Server, *http.Client) {
30323276
srv := httptest.NewServer(handler)
30333277
client := srv.Client()

0 commit comments

Comments
 (0)