Skip to content

Commit

Permalink
Merge pull request #3 from synapsecns/ctx
Browse files Browse the repository at this point in the history
Add Context
  • Loading branch information
antelman107 authored Jun 23, 2021
2 parents a000cb2 + ff968b1 commit cf684ae
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
12 changes: 12 additions & 0 deletions wait/wait.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package wait

import (
"context"
"log"
"net"
"sync"
Expand All @@ -14,6 +15,7 @@ type Executor struct {
Deadline time.Duration
Debug bool
UDPPacket []byte
Context context.Context
}

type Option func(*Executor)
Expand All @@ -35,6 +37,7 @@ func New(opts ...Option) *Executor {
Deadline: defaultDeadline,
Debug: defaultDebug,
UDPPacket: []byte(defaultUDPPacket),
Context: context.Background(),
}

for _, opt := range opts {
Expand All @@ -50,6 +53,11 @@ func WithProto(proto string) Option {
}
}

func WithContext(ctx context.Context) Option {
return func(h *Executor) {
h.Context = ctx
}
}
func WithWait(wait time.Duration) Option {
return func(h *Executor) {
h.Wait = wait
Expand Down Expand Up @@ -96,6 +104,8 @@ func (e *Executor) Do(addrs []string) bool {
select {
case <-deadlineCh:
return
case <-e.Context.Done():
return
default:
if e.Proto == "udp" {
if !e.doUDP(addr) {
Expand All @@ -120,6 +130,8 @@ func (e *Executor) Do(addrs []string) bool {
}()

select {
case <-e.Context.Done():
return false
case <-deadlineCh:
return false
case <-successCh:
Expand Down
34 changes: 28 additions & 6 deletions wait/wait_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package wait

import (
"context"
"io"
"net"
"strings"
Expand Down Expand Up @@ -67,13 +68,16 @@ func getUDPServer(proto, addr string, t *testing.T) io.Closer {
}

func TestDO(t *testing.T) {
ctx := context.Background()

type data struct {
name string
addr string
reqAddr string
proto string
packet string
result bool
name string
addr string
reqAddr string
proto string
packet string
result bool
contextCancel bool
}

for _, row := range []data{
Expand Down Expand Up @@ -107,6 +111,15 @@ func TestDO(t *testing.T) {
packet: "1",
result: false,
},
{
name: "context cancel",
addr: "localhost:6433",
reqAddr: "localhost:6433",
proto: "udp",
packet: "1",
result: false,
contextCancel: true,
},
} {
r := row
t.Run(row.name, func(t *testing.T) {
Expand All @@ -118,12 +131,21 @@ func TestDO(t *testing.T) {
}
defer srv.Close()

ctx, cancel := context.WithCancel(ctx)
defer cancel()

if row.contextCancel {
cancel()
}

e := New(
WithProto(r.proto),
WithUDPPacket([]byte(r.packet)),
WithDebug(false),
WithDeadline(time.Second*2),
WithContext(ctx),
)

if e.Do([]string{r.reqAddr}) != r.result {
t.Errorf("%s result is not %#v", r.name, r.result)
}
Expand Down

0 comments on commit cf684ae

Please sign in to comment.