Merge pull request #33 from npazosmendez/njpm/context feat: some support for context propagation
OvyFlash 46941696+OvyFlash@users.noreply.github.com
Sun, 08 Sep 2024 19:53:42 +0300
1 files changed,
58 insertions(+),
20 deletions(-)
jump to
M
bot.go
→
bot.go
@@ -3,6 +3,7 @@ // the Telegram Bot API.
package tgbotapi import ( + "context" "encoding/json" "errors" "fmt"@@ -11,6 +12,7 @@ "mime/multipart"
"net/http" "net/url" "strings" + "sync" "time" )@@ -25,11 +27,13 @@ Token string `json:"token"`
Debug bool `json:"debug"` Buffer int `json:"buffer"` - Self User `json:"-"` - Client HTTPClient `json:"-"` - shutdownChannel chan interface{} + Self User `json:"-"` + Client HTTPClient `json:"-"` apiEndpoint string + + stoppers []context.CancelFunc + mu sync.RWMutex } // NewBotAPI creates a new BotAPI instance.@@ -53,10 +57,9 @@ //
// It requires a token, provided by @BotFather on Telegram and API endpoint. func NewBotAPIWithClient(token, apiEndpoint string, client HTTPClient) (*BotAPI, error) { bot := &BotAPI{ - Token: token, - Client: client, - Buffer: 100, - shutdownChannel: make(chan interface{}), + Token: token, + Client: client, + Buffer: 100, apiEndpoint: apiEndpoint, }@@ -92,6 +95,10 @@ }
// MakeRequest makes a request to a specific endpoint with our token. func (bot *BotAPI) MakeRequest(endpoint string, params Params) (*APIResponse, error) { + return bot.MakeRequestWithContext(context.Background(), endpoint, params) +} + +func (bot *BotAPI) MakeRequestWithContext(ctx context.Context, endpoint string, params Params) (*APIResponse, error) { if bot.Debug { log.Printf("Endpoint: %s, params: %v\n", endpoint, params) }@@ -100,7 +107,7 @@ method := fmt.Sprintf(bot.apiEndpoint, bot.Token, endpoint)
values := buildParams(params) - req, err := http.NewRequest("POST", method, strings.NewReader(values.Encode())) + req, err := http.NewRequestWithContext(ctx, "POST", method, strings.NewReader(values.Encode())) if err != nil { return &APIResponse{}, err }@@ -165,6 +172,10 @@ }
// UploadFiles makes a request to the API with files. func (bot *BotAPI) UploadFiles(endpoint string, params Params, files []RequestFile) (*APIResponse, error) { + return bot.UploadFilesWithContext(context.Background(), endpoint, params, files) +} + +func (bot *BotAPI) UploadFilesWithContext(ctx context.Context, endpoint string, params Params, files []RequestFile) (*APIResponse, error) { r, w := io.Pipe() m := multipart.NewWriter(w)@@ -223,7 +234,7 @@ }
method := fmt.Sprintf(bot.apiEndpoint, bot.Token, endpoint) - req, err := http.NewRequest("POST", method, r) + req, err := http.NewRequestWithContext(ctx, "POST", method, r) if err != nil { return nil, err }@@ -281,7 +292,11 @@ // This method is called upon creation to validate the token,
// and so you may get this data from BotAPI.Self without the need for // another request. func (bot *BotAPI) GetMe() (User, error) { - resp, err := bot.MakeRequest("getMe", nil) + return bot.GetMeWithContext(context.Background()) +} + +func (bot *BotAPI) GetMeWithContext(ctx context.Context) (User, error) { + resp, err := bot.MakeRequestWithContext(ctx, "getMe", nil) if err != nil { return User{}, err }@@ -311,6 +326,10 @@ }
// Request sends a Chattable to Telegram, and returns the APIResponse. func (bot *BotAPI) Request(c Chattable) (*APIResponse, error) { + return bot.RequestWithContext(context.Background(), c) +} + +func (bot *BotAPI) RequestWithContext(ctx context.Context, c Chattable) (*APIResponse, error) { params, err := c.params() if err != nil { return nil, err@@ -332,7 +351,7 @@ params[file.Name] = file.Data.SendData()
} } - return bot.MakeRequest(c.method(), params) + return bot.MakeRequestWithContext(ctx, c.method(), params) } // Send will send a Chattable item to Telegram and provides the@@ -401,7 +420,11 @@ // To avoid stale items, set Offset to one higher than the previous item.
// Set Timeout to a large number to reduce requests, so you can get updates // instantly instead of having to wait between requests. func (bot *BotAPI) GetUpdates(config UpdateConfig) ([]Update, error) { - resp, err := bot.Request(config) + return bot.GetUpdatesWithContext(context.Background(), config) +} + +func (bot *BotAPI) GetUpdatesWithContext(ctx context.Context, config UpdateConfig) ([]Update, error) { + resp, err := bot.RequestWithContext(ctx, config) if err != nil { return []Update{}, err }@@ -415,7 +438,11 @@
// GetWebhookInfo allows you to fetch information about a webhook and if // one currently is set, along with pending update count and error messages. func (bot *BotAPI) GetWebhookInfo() (WebhookInfo, error) { - resp, err := bot.MakeRequest("getWebhookInfo", nil) + return bot.GetWebhookInfoWithContext(context.Background()) +} + +func (bot *BotAPI) GetWebhookInfoWithContext(ctx context.Context) (WebhookInfo, error) { + resp, err := bot.MakeRequestWithContext(ctx, "getWebhookInfo", nil) if err != nil { return WebhookInfo{}, err }@@ -430,21 +457,27 @@ // GetUpdatesChan starts and returns a channel for getting updates.
func (bot *BotAPI) GetUpdatesChan(config UpdateConfig) UpdatesChannel { ch := make(chan Update, bot.Buffer) + ctx, cancel := context.WithCancel(context.Background()) + bot.mu.Lock() + bot.stoppers = append(bot.stoppers, cancel) + bot.mu.Unlock() + go func() { for { select { - case <-bot.shutdownChannel: + case <-ctx.Done(): close(ch) return default: } - updates, err := bot.GetUpdates(config) + updates, err := bot.GetUpdatesWithContext(ctx, config) if err != nil { - log.Println(err) - log.Println("Failed to get updates, retrying in 3 seconds...") - time.Sleep(time.Second * 3) - + if ctx.Err() == nil { + log.Println(err) + log.Println("Failed to get updates, retrying in 3 seconds...") + time.Sleep(time.Second * 3) + } continue }@@ -462,10 +495,15 @@ }
// StopReceivingUpdates stops the go routine which receives updates func (bot *BotAPI) StopReceivingUpdates() { + bot.mu.Lock() + defer bot.mu.Unlock() + if bot.Debug { log.Println("Stopping the update receiver routine...") } - close(bot.shutdownChannel) + for _, stopper := range bot.stoppers { + stopper() + } } // ListenForWebhook registers a http handler for a webhook.