diff --git a/client.go b/client.go index dd997c2..9188bcd 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package botgolang import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -23,6 +24,10 @@ type Client struct { } func (c *Client) Do(path string, params url.Values, file *os.File) ([]byte, error) { + return c.DoWithContext(context.Background(), path, params, file) +} + +func (c *Client) DoWithContext(ctx context.Context, path string, params url.Values, file *os.File) ([]byte, error) { apiURL, err := url.Parse(c.baseURL + path) params.Set("token", c.token) @@ -31,7 +36,7 @@ func (c *Client) Do(path string, params url.Values, file *os.File) ([]byte, erro } apiURL.RawQuery = params.Encode() - req, err := http.NewRequest(http.MethodGet, apiURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL.String(), nil) if err != nil || req == nil { return nil, fmt.Errorf("cannot init http request: %s", err) } @@ -558,13 +563,17 @@ func (c *Client) UploadVoice(message *Message) error { } func (c *Client) GetEvents(lastEventID int, pollTime int) ([]*Event, error) { + return c.GetEventsWithContext(context.Background(), lastEventID, pollTime) +} + +func (c *Client) GetEventsWithContext(ctx context.Context, lastEventID int, pollTime int) ([]*Event, error) { params := url.Values{ "lastEventId": {strconv.Itoa(lastEventID)}, "pollTime": {strconv.Itoa(pollTime)}, } events := &eventsResponse{} - response, err := c.Do("/events/get", params, nil) + response, err := c.DoWithContext(ctx, "/events/get", params, nil) if err != nil { return events.Events, fmt.Errorf("error while making request: %s", err) } diff --git a/updates.go b/updates.go index 5abb25b..efc8fc0 100644 --- a/updates.go +++ b/updates.go @@ -36,7 +36,7 @@ func (u *Updater) NewMessageFromPayload(message EventPayload) *Message { } func (u *Updater) RunUpdatesCheck(ctx context.Context, ch chan<- Event) { - _, err := u.GetLastEvents(0) + _, err := u.GetLastEventsWithContext(ctx, 0) if err != nil { u.logger.WithFields(logrus.Fields{ "err": err, @@ -49,7 +49,7 @@ func (u *Updater) RunUpdatesCheck(ctx context.Context, ch chan<- Event) { close(ch) return default: - events, err := u.GetLastEvents(u.PollTime) + events, err := u.GetLastEventsWithContext(ctx, u.PollTime) if err != nil { u.logger.WithFields(logrus.Fields{ "err": err, @@ -71,7 +71,11 @@ func (u *Updater) RunUpdatesCheck(ctx context.Context, ch chan<- Event) { } func (u *Updater) GetLastEvents(pollTime int) ([]*Event, error) { - events, err := u.client.GetEvents(u.lastEventID, pollTime) + return u.GetLastEventsWithContext(context.Background(), pollTime) +} + +func (u *Updater) GetLastEventsWithContext(ctx context.Context, pollTime int) ([]*Event, error) { + events, err := u.client.GetEventsWithContext(ctx, u.lastEventID, pollTime) if err != nil { u.logger.WithFields(logrus.Fields{ "err": err,