@@ -63,6 +63,9 @@ func createApp(opts ...OptionFunc) *HTTPBin {
63
63
DripDelay : 0 ,
64
64
DripDuration : 100 * time .Millisecond ,
65
65
DripNumBytes : 10 ,
66
+ SSECount : 10 ,
67
+ SSEDelay : 0 ,
68
+ SSEDuration : 100 * time .Millisecond ,
66
69
}),
67
70
WithMaxBodySize (maxBodySize ),
68
71
WithMaxDuration (maxDuration ),
@@ -2957,6 +2960,246 @@ func TestHostname(t *testing.T) {
2957
2960
})
2958
2961
}
2959
2962
2963
+ func TestSSE (t * testing.T ) {
2964
+ t .Parallel ()
2965
+
2966
+ parseServerSentEvent := func (t * testing.T , buf * bufio.Reader ) (serverSentEvent , error ) {
2967
+ t .Helper ()
2968
+
2969
+ // match "event: ping" line
2970
+ eventLine , err := buf .ReadBytes ('\n' )
2971
+ if err != nil {
2972
+ return serverSentEvent {}, err
2973
+ }
2974
+ _ , eventType , _ := bytes .Cut (eventLine , []byte (":" ))
2975
+ assert .Equal (t , string (bytes .TrimSpace (eventType )), "ping" , "unexpected event type" )
2976
+
2977
+ // match "data: {...}" line
2978
+ dataLine , err := buf .ReadBytes ('\n' )
2979
+ if err != nil {
2980
+ return serverSentEvent {}, err
2981
+ }
2982
+ _ , data , _ := bytes .Cut (dataLine , []byte (":" ))
2983
+ var event serverSentEvent
2984
+ assert .NilError (t , json .Unmarshal (data , & event ))
2985
+
2986
+ // match newline after event data
2987
+ b , err := buf .ReadByte ()
2988
+ if err != nil && err != io .EOF {
2989
+ assert .NilError (t , err )
2990
+ }
2991
+ if b != '\n' {
2992
+ t .Fatalf ("expected newline after event data, got %q" , b )
2993
+ }
2994
+
2995
+ return event , nil
2996
+ }
2997
+
2998
+ parseServerSentEventStream := func (t * testing.T , resp * http.Response ) []serverSentEvent {
2999
+ t .Helper ()
3000
+ buf := bufio .NewReader (resp .Body )
3001
+ var events []serverSentEvent
3002
+ for {
3003
+ event , err := parseServerSentEvent (t , buf )
3004
+ if err == io .EOF {
3005
+ break
3006
+ }
3007
+ assert .NilError (t , err )
3008
+ events = append (events , event )
3009
+ }
3010
+ return events
3011
+ }
3012
+
3013
+ okTests := []struct {
3014
+ params * url.Values
3015
+ duration time.Duration
3016
+ count int
3017
+ }{
3018
+ // there are useful defaults for all values
3019
+ {& url.Values {}, 0 , 10 },
3020
+
3021
+ // go-style durations are accepted
3022
+ {& url.Values {"duration" : {"5ms" }}, 5 * time .Millisecond , 10 },
3023
+ {& url.Values {"duration" : {"10ns" }}, 0 , 10 },
3024
+ {& url.Values {"delay" : {"5ms" }}, 5 * time .Millisecond , 10 },
3025
+ {& url.Values {"delay" : {"0h" }}, 0 , 10 },
3026
+
3027
+ // or floating point seconds
3028
+ {& url.Values {"duration" : {"0.25" }}, 250 * time .Millisecond , 10 },
3029
+ {& url.Values {"duration" : {"1" }}, 1 * time .Second , 10 },
3030
+ {& url.Values {"delay" : {"0.25" }}, 250 * time .Millisecond , 10 },
3031
+ {& url.Values {"delay" : {"0" }}, 0 , 10 },
3032
+
3033
+ {& url.Values {"count" : {"1" }}, 0 , 1 },
3034
+ {& url.Values {"count" : {"011" }}, 0 , 11 },
3035
+ {& url.Values {"count" : {fmt .Sprintf ("%d" , app .maxSSECount )}}, 0 , int (app .maxSSECount )},
3036
+
3037
+ {& url.Values {"duration" : {"250ms" }, "delay" : {"250ms" }}, 500 * time .Millisecond , 10 },
3038
+ {& url.Values {"duration" : {"250ms" }, "delay" : {"0.25s" }}, 500 * time .Millisecond , 10 },
3039
+ }
3040
+ for _ , test := range okTests {
3041
+ test := test
3042
+ t .Run (fmt .Sprintf ("ok/%s" , test .params .Encode ()), func (t * testing.T ) {
3043
+ t .Parallel ()
3044
+
3045
+ url := "/sse?" + test .params .Encode ()
3046
+
3047
+ start := time .Now ()
3048
+ req := newTestRequest (t , "GET" , url )
3049
+ resp := must .DoReq (t , client , req )
3050
+ assert .StatusCode (t , resp , http .StatusOK )
3051
+ events := parseServerSentEventStream (t , resp )
3052
+
3053
+ if elapsed := time .Since (start ); elapsed < test .duration {
3054
+ t .Fatalf ("expected minimum duration of %s, request took %s" , test .duration , elapsed )
3055
+ }
3056
+ assert .ContentType (t , resp , sseContentType )
3057
+ assert .DeepEqual (t , resp .TransferEncoding , []string {"chunked" }, "unexpected Transfer-Encoding header" )
3058
+ assert .Equal (t , len (events ), test .count , "unexpected number of events" )
3059
+ })
3060
+ }
3061
+
3062
+ badTests := []struct {
3063
+ params * url.Values
3064
+ code int
3065
+ }{
3066
+ {& url.Values {"duration" : {"0" }}, http .StatusBadRequest },
3067
+ {& url.Values {"duration" : {"0s" }}, http .StatusBadRequest },
3068
+ {& url.Values {"duration" : {"1m" }}, http .StatusBadRequest },
3069
+ {& url.Values {"duration" : {"-1ms" }}, http .StatusBadRequest },
3070
+ {& url.Values {"duration" : {"1001" }}, http .StatusBadRequest },
3071
+ {& url.Values {"duration" : {"-1" }}, http .StatusBadRequest },
3072
+ {& url.Values {"duration" : {"foo" }}, http .StatusBadRequest },
3073
+
3074
+ {& url.Values {"delay" : {"1m" }}, http .StatusBadRequest },
3075
+ {& url.Values {"delay" : {"-1ms" }}, http .StatusBadRequest },
3076
+ {& url.Values {"delay" : {"1001" }}, http .StatusBadRequest },
3077
+ {& url.Values {"delay" : {"-1" }}, http .StatusBadRequest },
3078
+ {& url.Values {"delay" : {"foo" }}, http .StatusBadRequest },
3079
+
3080
+ {& url.Values {"count" : {"foo" }}, http .StatusBadRequest },
3081
+ {& url.Values {"count" : {"0" }}, http .StatusBadRequest },
3082
+ {& url.Values {"count" : {"-1" }}, http .StatusBadRequest },
3083
+ {& url.Values {"count" : {"0xff" }}, http .StatusBadRequest },
3084
+ {& url.Values {"count" : {fmt .Sprintf ("%d" , app .maxSSECount + 1 )}}, http .StatusBadRequest },
3085
+
3086
+ // request would take too long
3087
+ {& url.Values {"duration" : {"750ms" }, "delay" : {"500ms" }}, http .StatusBadRequest },
3088
+ }
3089
+ for _ , test := range badTests {
3090
+ test := test
3091
+ t .Run (fmt .Sprintf ("bad/%s" , test .params .Encode ()), func (t * testing.T ) {
3092
+ t .Parallel ()
3093
+ url := "/sse?" + test .params .Encode ()
3094
+ req := newTestRequest (t , "GET" , url )
3095
+ resp := must .DoReq (t , client , req )
3096
+ defer consumeAndCloseBody (resp )
3097
+ assert .StatusCode (t , resp , test .code )
3098
+ })
3099
+ }
3100
+
3101
+ t .Run ("writes are actually incremmental" , func (t * testing.T ) {
3102
+ t .Parallel ()
3103
+
3104
+ var (
3105
+ duration = 100 * time .Millisecond
3106
+ count = 3
3107
+ endpoint = fmt .Sprintf ("/sse?duration=%s&count=%d" , duration , count )
3108
+
3109
+ // Match server logic for calculating the delay between writes
3110
+ wantPauseBetweenWrites = duration / time .Duration (count - 1 )
3111
+ )
3112
+
3113
+ req := newTestRequest (t , "GET" , endpoint )
3114
+ resp := must .DoReq (t , client , req )
3115
+ buf := bufio .NewReader (resp .Body )
3116
+ eventCount := 0
3117
+
3118
+ // Here we read from the response one byte at a time, and ensure that
3119
+ // at least the expected delay occurs for each read.
3120
+ //
3121
+ // The request above includes an initial delay equal to the expected
3122
+ // wait between writes so that even the first iteration of this loop
3123
+ // expects to wait the same amount of time for a read.
3124
+ for i := 0 ; ; i ++ {
3125
+ start := time .Now ()
3126
+ event , err := parseServerSentEvent (t , buf )
3127
+ if err == io .EOF {
3128
+ break
3129
+ }
3130
+ assert .NilError (t , err )
3131
+ gotPause := time .Since (start )
3132
+
3133
+ // We expect to read exactly one byte on each iteration. On the
3134
+ // last iteration, we expct to hit EOF after reading the final
3135
+ // byte, because the server does not pause after the last write.
3136
+ assert .Equal (t , event .ID , i , "unexpected SSE event ID" )
3137
+
3138
+ // only ensure that we pause for the expected time between writes
3139
+ // (allowing for minor mismatch in local timers and server timers)
3140
+ // after the first byte.
3141
+ if i > 0 {
3142
+ assert .RoughDuration (t , gotPause , wantPauseBetweenWrites , 3 * time .Millisecond )
3143
+ }
3144
+
3145
+ eventCount ++
3146
+ }
3147
+
3148
+ assert .Equal (t , eventCount , count , "unexpected number of events" )
3149
+ })
3150
+
3151
+ t .Run ("handle cancelation during initial delay" , func (t * testing.T ) {
3152
+ t .Parallel ()
3153
+
3154
+ // For this test, we expect the client to time out and cancel the
3155
+ // request after 10ms. The handler should still be in its intitial
3156
+ // delay period, so this will result in a request error since no status
3157
+ // code will be written before the cancelation.
3158
+ ctx , cancel := context .WithTimeout (context .Background (), 25 * time .Millisecond )
3159
+ defer cancel ()
3160
+
3161
+ req := newTestRequest (t , "GET" , "/sse?duration=500ms&delay=500ms" ).WithContext (ctx )
3162
+ if _ , err := client .Do (req ); ! os .IsTimeout (err ) {
3163
+ t .Fatalf ("expected timeout error, got %s" , err )
3164
+ }
3165
+ })
3166
+
3167
+ t .Run ("handle cancelation during stream" , func (t * testing.T ) {
3168
+ t .Parallel ()
3169
+
3170
+ ctx , cancel := context .WithTimeout (context .Background (), 100 * time .Millisecond )
3171
+ defer cancel ()
3172
+
3173
+ req := newTestRequest (t , "GET" , "/sse?duration=900ms&delay=0&count=2" ).WithContext (ctx )
3174
+ resp := must .DoReq (t , client , req )
3175
+ defer consumeAndCloseBody (resp )
3176
+
3177
+ // In this test, the server should have started an OK response before
3178
+ // our client timeout cancels the request, so we should get an OK here.
3179
+ assert .StatusCode (t , resp , http .StatusOK )
3180
+
3181
+ // But, we should time out while trying to read the whole response
3182
+ // body.
3183
+ body , err := io .ReadAll (resp .Body )
3184
+ if ! os .IsTimeout (err ) {
3185
+ t .Fatalf ("expected timeout reading body, got %s" , err )
3186
+ }
3187
+
3188
+ // partial read should include the first whole event
3189
+ event , err := parseServerSentEvent (t , bufio .NewReader (bytes .NewReader (body )))
3190
+ assert .NilError (t , err )
3191
+ assert .Equal (t , event .ID , 0 , "unexpected SSE event ID" )
3192
+ })
3193
+
3194
+ t .Run ("ensure HEAD request works with streaming responses" , func (t * testing.T ) {
3195
+ t .Parallel ()
3196
+ req := newTestRequest (t , "HEAD" , "/sse?duration=900ms&delay=100ms" )
3197
+ resp := must .DoReq (t , client , req )
3198
+ assert .StatusCode (t , resp , http .StatusOK )
3199
+ assert .BodySize (t , resp , 0 )
3200
+ })
3201
+ }
3202
+
2960
3203
func TestWebSocketEcho (t * testing.T ) {
2961
3204
// ========================================================================
2962
3205
// Note: Here we only test input validation for the websocket endpoint.
@@ -3028,6 +3271,7 @@ func TestWebSocketEcho(t *testing.T) {
3028
3271
})
3029
3272
}
3030
3273
}
3274
+
3031
3275
func newTestServer (handler http.Handler ) (* httptest.Server , * http.Client ) {
3032
3276
srv := httptest .NewServer (handler )
3033
3277
client := srv .Client ()
0 commit comments