diff --git a/longpoll-bot/longpoll.go b/longpoll-bot/longpoll.go index 1f43e379..b230c4c1 100644 --- a/longpoll-bot/longpoll.go +++ b/longpoll-bot/longpoll.go @@ -10,13 +10,11 @@ import ( "encoding/json" "fmt" "net/http" - "sync/atomic" "github.com/SevereCloud/vksdk/v2" - "github.com/SevereCloud/vksdk/v2/internal" - "github.com/SevereCloud/vksdk/v2/api" "github.com/SevereCloud/vksdk/v2/events" + "github.com/SevereCloud/vksdk/v2/internal" ) // Response struct. @@ -35,9 +33,9 @@ type LongPoll struct { Wait int VK *api.VK Client *http.Client + cancel context.CancelFunc funcFullResponseList []func(Response) - inShutdown int32 events.FuncList } @@ -105,12 +103,15 @@ func (lp *LongPoll) updateServer(updateTs bool) error { return nil } -func (lp *LongPoll) check() (Response, error) { - var response Response - +func (lp *LongPoll) check(ctx context.Context) (response Response, err error) { u := fmt.Sprintf("%s?act=a_check&key=%s&ts=%s&wait=%d", lp.Server, lp.Key, lp.Ts, lp.Wait) - resp, err := lp.Client.Get(u) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return response, err + } + + resp, err := lp.Client.Do(req) if err != nil { return response, err } @@ -143,12 +144,12 @@ func (lp *LongPoll) checkResponse(response Response) (err error) { return } -func (lp *LongPoll) autoSetting() error { +func (lp *LongPoll) autoSetting(ctx context.Context) error { params := api.Params{ "group_id": lp.GroupID, "enabled": true, "api_version": vksdk.API, - } + }.WithContext(ctx) for _, event := range lp.ListEvents() { params[string(event)] = true } @@ -161,39 +162,55 @@ func (lp *LongPoll) autoSetting() error { // Run handler. func (lp *LongPoll) Run() error { - atomic.StoreInt32(&lp.inShutdown, 0) + return lp.RunWithContext(context.Background()) +} + +// RunWithContext handler. +func (lp *LongPoll) RunWithContext(ctx context.Context) error { + return lp.run(ctx) +} + +func (lp *LongPoll) run(ctx context.Context) error { + ctx, lp.cancel = context.WithCancel(ctx) - err := lp.autoSetting() + err := lp.autoSetting(ctx) if err != nil { return err } - for atomic.LoadInt32(&lp.inShutdown) == 0 { - resp, err := lp.check() - if err != nil { - return err - } - - ctx := context.WithValue(context.Background(), internal.LongPollTsKey, resp.Ts) - - for _, event := range resp.Updates { - err = lp.Handler(ctx, event) + for { + select { + case _, ok := <-ctx.Done(): + if !ok { + return nil + } + default: + resp, err := lp.check(ctx) if err != nil { return err } - } - for _, f := range lp.funcFullResponseList { - f(resp) + ctx = context.WithValue(ctx, internal.LongPollTsKey, resp.Ts) + + for _, event := range resp.Updates { + err = lp.Handler(ctx, event) + if err != nil { + return err + } + } + + for _, f := range lp.funcFullResponseList { + f(resp) + } } } - - return nil } // Shutdown gracefully shuts down the longpoll without interrupting any active connections. func (lp *LongPoll) Shutdown() { - atomic.StoreInt32(&lp.inShutdown, 1) + if lp.cancel != nil { + lp.cancel() + } } // FullResponse handler. diff --git a/longpoll-bot/longpoll_test.go b/longpoll-bot/longpoll_test.go index ffb673fe..117923fd 100644 --- a/longpoll-bot/longpoll_test.go +++ b/longpoll-bot/longpoll_test.go @@ -1,22 +1,66 @@ package longpoll import ( + "context" + "errors" "os" "strconv" "testing" + "time" "github.com/SevereCloud/vksdk/v2/api" + "github.com/SevereCloud/vksdk/v2/events" ) func TestLongPoll_Shutdown(t *testing.T) { t.Parallel() - lp := &LongPoll{} + groupToken := os.Getenv("GROUP_TOKEN") + if groupToken == "" { + t.Skip("GROUP_TOKEN empty") + } - lp.Shutdown() + userToken := os.Getenv("USER_TOKEN") + if userToken == "" { + t.Skip("USER_TOKEN empty") + } - if lp.inShutdown != 1 { - t.Error("inShutdown != 1") + vk := api.NewVK(groupToken) + lp, _ := NewLongPollCommunity(vk) + lp.MessageNew(func(ctx context.Context, obj events.MessageNewObject) { + lp.Shutdown() + }) + + c1 := make(chan string) + + go func() { + err := lp.Run() + if err != nil && !errors.Is(err, context.Canceled) { + t.Error(err) + } + + c1 <- "one" + }() + + time.Sleep(time.Millisecond * 300) + + vkUser := api.NewVK(userToken) + + _, err := vkUser.MessagesSend(api.Params{ + "peer_id": -lp.GroupID, + "random_id": 0, + "message": "lp.Shutdown()", + }) + if err != nil { + t.Fatal(err) + } + + // time.Sleep(time.Millisecond * 300) + select { + case <-time.After(time.Second * 3): + lp.Shutdown() + t.Fatal("timeout") + case <-c1: } } @@ -151,6 +195,10 @@ func TestLongPoll_RunError(t *testing.T) { t.Error(err) } + if err := lp.RunWithContext(context.Background()); err == nil { + t.Error(err) + } + lp.Server = "http://example.com" if err := lp.Run(); err == nil {