diff --git a/goupnp.go b/goupnp.go index 93c588b..9fa8963 100644 --- a/goupnp.go +++ b/goupnp.go @@ -85,7 +85,10 @@ func DiscoverDevicesCtx(ctx context.Context, searchTarget string) ([]MaybeRootDe return nil, err } defer hcCleanup() - responses, err := ssdp.SSDPRawSearchCtx(ctx, hc, string(searchTarget), 2, 3) + + searchCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + responses, err := ssdp.RawSearch(searchCtx, hc, string(searchTarget), 3) if err != nil { return nil, err } diff --git a/httpu/httpu.go b/httpu/httpu.go index 808d600..5bb8d67 100644 --- a/httpu/httpu.go +++ b/httpu/httpu.go @@ -3,6 +3,7 @@ package httpu import ( "bufio" "bytes" + "context" "errors" "fmt" "log" @@ -26,6 +27,27 @@ type ClientInterface interface { ) ([]*http.Response, error) } +// ClientInterfaceCtx is the equivalent of ClientInterface, except with methods +// taking a context.Context parameter. +type ClientInterfaceCtx interface { + // DoWithContext performs a request. If the input request has a + // deadline, then that value will be used as the timeout for how long + // to wait before returning the responses that were received. If the + // request's context is canceled, this method will return immediately. + // + // If the request's context is never canceled, and does not have a + // deadline, then this function WILL NEVER RETURN. You MUST set an + // appropriate deadline on the context, or otherwise cancel it when you + // want to finish an operation. + // + // An error is only returned for failing to send the request. Failures + // in receipt simply do not add to the resulting responses. + DoWithContext( + req *http.Request, + numSends int, + ) ([]*http.Response, error) +} + // HTTPUClient is a client for dealing with HTTPU (HTTP over UDP). Its typical // function is for HTTPMU, and particularly SSDP. type HTTPUClient struct { @@ -34,6 +56,7 @@ type HTTPUClient struct { } var _ ClientInterface = &HTTPUClient{} +var _ ClientInterfaceCtx = &HTTPUClient{} // NewHTTPUClient creates a new HTTPUClient, opening up a new UDP socket for the // purpose. @@ -75,6 +98,25 @@ func (httpu *HTTPUClient) Do( req *http.Request, timeout time.Duration, numSends int, +) ([]*http.Response, error) { + ctx := req.Context() + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + req = req.WithContext(ctx) + } + + return httpu.DoWithContext(req, numSends) +} + +// DoWithContext implements ClientInterfaceCtx.DoWithContext. +// +// Make sure to read the documentation on the ClientInterfaceCtx interface +// regarding cancellation! +func (httpu *HTTPUClient) DoWithContext( + req *http.Request, + numSends int, ) ([]*http.Response, error) { httpu.connLock.Lock() defer httpu.connLock.Unlock() @@ -101,10 +143,28 @@ func (httpu *HTTPUClient) Do( if err != nil { return nil, err } - if err = httpu.conn.SetDeadline(time.Now().Add(timeout)); err != nil { - return nil, err + + // Handle context deadline/timeout + ctx := req.Context() + deadline, ok := ctx.Deadline() + if ok { + if err = httpu.conn.SetDeadline(deadline); err != nil { + return nil, err + } } + // Handle context cancelation + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + // if context is cancelled, stop any connections by setting time in the past. + httpu.conn.SetDeadline(time.Now().Add(-time.Second)) + case <-done: + } + }() + // Send request. for i := 0; i < numSends; i++ { if n, err := httpu.conn.WriteTo(requestBuf.Bytes(), destAddr); err != nil { diff --git a/httpu/multiclient.go b/httpu/multiclient.go index 463ab7a..5cc65e9 100644 --- a/httpu/multiclient.go +++ b/httpu/multiclient.go @@ -49,14 +49,14 @@ func (mc *MultiClient) Do( } func (mc *MultiClient) sendRequests( - results chan<-[]*http.Response, + results chan<- []*http.Response, req *http.Request, timeout time.Duration, numSends int, ) error { tasks := &errgroup.Group{} for _, d := range mc.delegates { - d := d // copy for closure + d := d // copy for closure tasks.Go(func() error { responses, err := d.Do(req, timeout, numSends) if err != nil { @@ -68,3 +68,65 @@ func (mc *MultiClient) sendRequests( } return tasks.Wait() } + +// MultiClientCtx dispatches requests out to all the delegated clients. +type MultiClientCtx struct { + // The HTTPU clients to delegate to. + delegates []ClientInterfaceCtx +} + +var _ ClientInterfaceCtx = &MultiClientCtx{} + +// NewMultiClient creates a new MultiClient that delegates to all the given +// clients. +func NewMultiClientCtx(delegates []ClientInterfaceCtx) *MultiClientCtx { + return &MultiClientCtx{ + delegates: delegates, + } +} + +// DoWithContext implements ClientInterfaceCtx.DoWithContext. +func (mc *MultiClientCtx) DoWithContext( + req *http.Request, + numSends int, +) ([]*http.Response, error) { + tasks, ctx := errgroup.WithContext(req.Context()) + req = req.WithContext(ctx) // so we cancel if the errgroup errors + results := make(chan []*http.Response) + + // For each client, send the request to it and collect results. + tasks.Go(func() error { + defer close(results) + return mc.sendRequestsCtx(results, req, numSends) + }) + + var responses []*http.Response + tasks.Go(func() error { + for rs := range results { + responses = append(responses, rs...) + } + return nil + }) + + return responses, tasks.Wait() +} + +func (mc *MultiClientCtx) sendRequestsCtx( + results chan<- []*http.Response, + req *http.Request, + numSends int, +) error { + tasks := &errgroup.Group{} + for _, d := range mc.delegates { + d := d // copy for closure + tasks.Go(func() error { + responses, err := d.DoWithContext(req, numSends) + if err != nil { + return err + } + results <- responses + return nil + }) + } + return tasks.Wait() +} diff --git a/network.go b/network.go index e93763a..a2c3a45 100644 --- a/network.go +++ b/network.go @@ -10,14 +10,14 @@ import ( // httpuClient creates a HTTPU client that multiplexes to all multicast-capable // IPv4 addresses on the host. Returns a function to clean up once the client is // no longer required. -func httpuClient() (httpu.ClientInterface, func(), error) { +func httpuClient() (httpu.ClientInterfaceCtx, func(), error) { addrs, err := localIPv4MCastAddrs() if err != nil { return nil, nil, ctxError(err, "requesting host IPv4 addresses") } closers := make([]io.Closer, 0, len(addrs)) - delegates := make([]httpu.ClientInterface, 0, len(addrs)) + delegates := make([]httpu.ClientInterfaceCtx, 0, len(addrs)) for _, addr := range addrs { c, err := httpu.NewHTTPUClientAddr(addr) if err != nil { @@ -34,7 +34,7 @@ func httpuClient() (httpu.ClientInterface, func(), error) { } } - return httpu.NewMultiClient(delegates), closer, nil + return httpu.NewMultiClientCtx(delegates), closer, nil } // localIPv2MCastAddrs returns the set of IPv4 addresses on multicast-able diff --git a/ssdp/ssdp.go b/ssdp/ssdp.go index 240dfa7..2f318f3 100644 --- a/ssdp/ssdp.go +++ b/ssdp/ssdp.go @@ -35,6 +35,15 @@ type HTTPUClient interface { ) ([]*http.Response, error) } +// HTTPUClientCtx is an optional interface that will be used to perform +// HTTP-over-UDP requests if the client implements it. +type HTTPUClientCtx interface { + DoWithContext( + req *http.Request, + numSends int, + ) ([]*http.Response, error) +} + // SSDPRawSearchCtx performs a fairly raw SSDP search request, and returns the // unique response(s) that it receives. Each response has the requested // searchTarget, a USN, and a valid location. maxWaitSeconds states how long to @@ -49,8 +58,64 @@ func SSDPRawSearchCtx( maxWaitSeconds int, numSends int, ) ([]*http.Response, error) { + req, err := prepareRequest(ctx, searchTarget, maxWaitSeconds) + if err != nil { + return nil, err + } + + allResponses, err := httpu.Do(req, time.Duration(maxWaitSeconds)*time.Second+100*time.Millisecond, numSends) + if err != nil { + return nil, err + } + return processSSDPResponses(searchTarget, allResponses) +} + +// RawSearch performs a fairly raw SSDP search request, and returns the +// unique response(s) that it receives. Each response has the requested +// searchTarget, a USN, and a valid location. If the provided context times out +// or is canceled, the search will be aborted. numSends is the number of +// requests to send - 3 is a reasonable value for this. +// +// The provided context should have a deadline, since the SSDP protocol +// requires the max wait time be included in search requests. If the context +// has no deadline, then a default deadline of 3 seconds will be applied. +func RawSearch( + ctx context.Context, + httpu HTTPUClientCtx, + searchTarget string, + numSends int, +) ([]*http.Response, error) { + // We need a timeout value to include in the SSDP request; get it by + // checking the deadline on the context. + var maxWaitSeconds int + if deadline, ok := ctx.Deadline(); ok { + maxWaitSeconds = int(deadline.Sub(time.Now()) / time.Second) + } else { + // Pick a default timeout of 3 seconds if none was provided. + maxWaitSeconds = 3 + + var cancel func() + ctx, cancel = context.WithTimeout(ctx, time.Duration(maxWaitSeconds)*time.Second) + defer cancel() + } + + req, err := prepareRequest(ctx, searchTarget, maxWaitSeconds) + if err != nil { + return nil, err + } + + allResponses, err := httpu.DoWithContext(req, numSends) + if err != nil { + return nil, err + } + return processSSDPResponses(searchTarget, allResponses) +} + +// prepareRequest checks the provided parameters and constructs a SSDP search +// request to be sent. +func prepareRequest(ctx context.Context, searchTarget string, maxWaitSeconds int) (*http.Request, error) { if maxWaitSeconds < 1 { - return nil, errors.New("ssdp: maxWaitSeconds must be >= 1") + return nil, errors.New("ssdp: request timeout must be at least 1s") } req := (&http.Request{ @@ -67,11 +132,13 @@ func SSDPRawSearchCtx( "ST": []string{searchTarget}, }, }).WithContext(ctx) - allResponses, err := httpu.Do(req, time.Duration(maxWaitSeconds)*time.Second+100*time.Millisecond, numSends) - if err != nil { - return nil, err - } + return req, nil +} +func processSSDPResponses( + searchTarget string, + allResponses []*http.Response, +) ([]*http.Response, error) { isExactSearch := searchTarget != SSDPAll && searchTarget != UPNPRootDevice seenIDs := make(map[string]bool)