Fix a Bug in Initialization Logic of Bot (#5030)

Co-authored-by: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com>
This commit is contained in:
Gritty_dev
2025-11-11 18:09:15 -05:00
committed by GitHub
parent 1fd084cb9c
commit 98d9908937
5 changed files with 435 additions and 23 deletions
@@ -0,0 +1,5 @@
bugfixes = "Fix a Bug in Initialization Logic of ``Bot``"
[[pull_requests]]
uid = "5030"
author_uids = ["codomposer"]
closes_threads = ["5021"]
+16 -10
View File
@@ -302,11 +302,12 @@ class Bot(TelegramObject, contextlib.AbstractAsyncContextManager["Bot"]):
__slots__ = (
"_base_file_url",
"_base_url",
"_bot_initialized",
"_bot_user",
"_initialized",
"_local_mode",
"_private_key",
"_request",
"_requests_initialized",
"_token",
)
@@ -334,7 +335,8 @@ class Bot(TelegramObject, contextlib.AbstractAsyncContextManager["Bot"]):
self._local_mode: bool = local_mode
self._bot_user: User | None = None
self._private_key: bytes | None = None
self._initialized: bool = False
self._requests_initialized: bool = False
self._bot_initialized: bool = False
self._request: tuple[BaseRequest, BaseRequest] = (
(
@@ -837,18 +839,21 @@ class Bot(TelegramObject, contextlib.AbstractAsyncContextManager["Bot"]):
.. versionadded:: 20.0
"""
if self._initialized:
if self._requests_initialized and self._bot_initialized:
self._LOGGER.debug("This Bot is already initialized.")
return
await asyncio.gather(self._request[0].initialize(), self._request[1].initialize())
# this needs to be set before we call get_me, since this can trigger an error in the
# request backend, which would then NOT lead to a proper shutdown if this flag isn't set
self._initialized = True
# Initialize request objects if not already done
if not self._requests_initialized:
await asyncio.gather(self._request[0].initialize(), self._request[1].initialize())
self._requests_initialized = True
# Initialize bot user
# Since the bot is to be initialized only once, we can also use it for
# verifying the token passed and raising an exception if it's invalid.
try:
await self.get_me()
self._bot_initialized = True
except InvalidToken as exc:
raise InvalidToken(f"The token `{self._token}` was rejected by the server.") from exc
@@ -860,12 +865,13 @@ class Bot(TelegramObject, contextlib.AbstractAsyncContextManager["Bot"]):
.. versionadded:: 20.0
"""
if not self._initialized:
if not self._requests_initialized:
self._LOGGER.debug("This Bot is already shut down. Returning.")
return
await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown())
self._initialized = False
self._requests_initialized = False
self._bot_initialized = False
async def do_api_request(
self,
@@ -943,7 +949,7 @@ class Bot(TelegramObject, contextlib.AbstractAsyncContextManager["Bot"]):
# 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod
# 2) is relevant only for Bot.do_api_request, that's why we have special handling for
# that here rather than in BaseRequest._request_wrapper
if self._initialized:
if self._bot_initialized:
raise EndPointNotFound(
f"Endpoint '{camel_case_endpoint}' not found in Bot API"
) from exc
+39 -12
View File
@@ -103,6 +103,34 @@ async def network_retry_loop(
log_prefix = f"Network Retry Loop ({description}):"
effective_is_running = is_running or (lambda: True)
def check_max_retries_and_log(current_retries: int, exception_info: str = "") -> bool:
"""Check if max retries reached and log accordingly.
Args:
current_retries: The current retry count.
exception_info: Additional context about the exception (e.g., "Timed out: ...").
Returns:
bool: True if max retries reached (should abort), False otherwise (should retry).
"""
prefix_with_info = f"{log_prefix} {exception_info}" if exception_info else log_prefix
if max_retries < 0 or current_retries < max_retries:
_LOGGER.debug(
"%s Failed run number %s of %s. Retrying.",
prefix_with_info,
current_retries,
max_retries,
)
return False
_LOGGER.exception(
"%s Failed run number %s of %s. Aborting.",
prefix_with_info,
current_retries,
max_retries,
)
return True
async def do_action() -> None:
if not stop_event:
await action_cb()
@@ -136,15 +164,21 @@ async def network_retry_loop(
break
except RetryAfter as exc:
slack_time = 0.5
_LOGGER.info(
"%s %s. Adding %s seconds to the specified time.", log_prefix, exc, slack_time
)
# pylint: disable=protected-access
cur_interval = slack_time + exc._retry_after.total_seconds()
exception_info = f"{exc}. Adding {slack_time} seconds to the specified time."
# Check max_retries for RetryAfter as well
if check_max_retries_and_log(retries, exception_info):
raise
except TimedOut as toe:
_LOGGER.debug("%s Timed out: %s. Retrying immediately.", log_prefix, toe)
# If failure is due to timeout, we should retry asap.
cur_interval = 0
exception_info = f"Timed out: {toe}."
# Check max_retries for TimedOut as well
if check_max_retries_and_log(retries, exception_info):
raise
except InvalidToken:
_LOGGER.exception("%s Invalid token. Aborting retry loop.", log_prefix)
raise
@@ -152,14 +186,7 @@ async def network_retry_loop(
if on_err_cb:
on_err_cb(telegram_exc)
if max_retries < 0 or retries < max_retries:
_LOGGER.debug(
"%s Failed run number %s of %s. Retrying.", log_prefix, retries, max_retries
)
else:
_LOGGER.exception(
"%s Failed run number %s of %s. Aborting.", log_prefix, retries, max_retries
)
if check_max_retries_and_log(retries):
raise
# increase waiting times on subsequent errors up to 30secs
+245
View File
@@ -0,0 +1,245 @@
#!/usr/bin/env python
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2025
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains tests for the network_retry_loop function.
Note:
Most of the retry loop functionality is already covered in test_updater and test_application.
These tests focus specifically on the max_retries behavior for different exception types
and the error callback handling, which were added as part of the bug fix in #5030.
"""
import pytest
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut
from telegram.ext._utils.networkloop import network_retry_loop
class TestNetworkRetryLoop:
"""Tests for the network_retry_loop function.
Note:
The general retry loop functionality is extensively tested in test_updater and
test_application. These tests focus on the specific max_retries behavior for
different exception types.
"""
@pytest.mark.parametrize(
("exception_class", "exception_args"),
[
(RetryAfter, (1,)),
(TimedOut, ("Test timeout",)),
],
ids=["RetryAfter", "TimedOut"],
)
async def test_exception_respects_max_retries(self, exception_class, exception_args):
"""Test that RetryAfter and TimedOut exceptions respect max_retries limit."""
call_count = 0
async def action_with_exception():
nonlocal call_count
call_count += 1
raise exception_class(*exception_args)
with pytest.raises(exception_class):
await network_retry_loop(
action_cb=action_with_exception,
description=f"Test {exception_class.__name__}",
interval=0,
max_retries=2,
)
# Should be called 3 times: initial call + 2 retries
assert call_count == 3
@pytest.mark.parametrize(
("exception_class", "exception_args"),
[
(RetryAfter, (1,)),
(TimedOut, ("Test timeout",)),
],
ids=["RetryAfter", "TimedOut"],
)
async def test_exception_with_zero_max_retries(self, exception_class, exception_args):
"""Test that RetryAfter and TimedOut with max_retries=0 don't retry."""
call_count = 0
async def action_with_exception():
nonlocal call_count
call_count += 1
raise exception_class(*exception_args)
with pytest.raises(exception_class):
await network_retry_loop(
action_cb=action_with_exception,
description=f"Test {exception_class.__name__} no retries",
interval=0,
max_retries=0,
)
# Should be called only once with max_retries=0
assert call_count == 1
async def test_invalid_token_aborts_immediately(self):
"""Test that InvalidToken exceptions abort immediately without retries."""
call_count = 0
async def action_with_invalid_token():
nonlocal call_count
call_count += 1
raise InvalidToken("Invalid token")
with pytest.raises(InvalidToken):
await network_retry_loop(
action_cb=action_with_invalid_token,
description="Test InvalidToken",
interval=0,
max_retries=5,
)
# Should be called only once, no retries for invalid token
assert call_count == 1
async def test_telegram_error_respects_max_retries(self):
"""Test that general TelegramError exceptions respect max_retries limit."""
call_count = 0
async def action_with_telegram_error():
nonlocal call_count
call_count += 1
raise TelegramError("Test error")
with pytest.raises(TelegramError):
await network_retry_loop(
action_cb=action_with_telegram_error,
description="Test TelegramError",
interval=0,
max_retries=3,
)
# Should be called 4 times: initial call + 3 retries
assert call_count == 4
@pytest.mark.parametrize(
("exception_class", "exception_args"),
[
(RetryAfter, (1,)),
(TimedOut, ("Test timeout",)),
(InvalidToken, ("Invalid token",)),
],
ids=["RetryAfter", "TimedOut", "InvalidToken"],
)
async def test_error_callback_not_called_for_specific_exceptions(
self, exception_class, exception_args
):
"""Test that error callback is not called for RetryAfter, TimedOut, or InvalidToken."""
error_callback_called = False
def error_callback(exc):
nonlocal error_callback_called
error_callback_called = True
async def action_with_exception():
raise exception_class(*exception_args)
with pytest.raises(exception_class):
await network_retry_loop(
action_cb=action_with_exception,
on_err_cb=error_callback,
description=f"Test {exception_class.__name__} callback",
interval=0,
max_retries=1,
)
assert not error_callback_called
async def test_error_callback_called_for_telegram_error(self):
"""Test that error callback is called for general TelegramError exceptions."""
error_callback_count = 0
caught_exception = None
def error_callback(exc):
nonlocal error_callback_count, caught_exception
error_callback_count += 1
caught_exception = exc
async def action_with_telegram_error():
raise TelegramError("Test error")
with pytest.raises(TelegramError):
await network_retry_loop(
action_cb=action_with_telegram_error,
on_err_cb=error_callback,
description="Test TelegramError callback",
interval=0,
max_retries=2,
)
# Should be called 3 times (initial + 2 retries)
assert error_callback_count == 3
assert isinstance(caught_exception, TelegramError)
async def test_success_after_retries(self):
"""Test that action succeeds after some retries."""
call_count = 0
async def action_succeeds_on_third_try():
nonlocal call_count
call_count += 1
if call_count < 3:
raise TimedOut("Test timeout")
# Success on third try
await network_retry_loop(
action_cb=action_succeeds_on_third_try,
description="Test success after retries",
interval=0,
max_retries=5,
)
assert call_count == 3
@pytest.mark.parametrize(
("exception_class", "exception_args", "success_after"),
[
(RetryAfter, (0.01,), 5),
(TimedOut, ("Test timeout",), 4),
],
ids=["RetryAfter", "TimedOut"],
)
async def test_exception_with_negative_max_retries(
self, exception_class, exception_args, success_after
):
"""Test that exceptions with max_retries=-1 retry indefinitely until success."""
call_count = 0
async def action_succeeds_after_few_tries():
nonlocal call_count
call_count += 1
if call_count < success_after:
raise exception_class(*exception_args)
# Success after specified tries
await network_retry_loop(
action_cb=action_succeeds_after_few_tries,
description=f"Test {exception_class.__name__} infinite retries",
interval=0,
max_retries=-1,
)
assert call_count == success_after
+130 -1
View File
@@ -93,7 +93,7 @@ from telegram.constants import (
ParseMode,
ReactionEmoji,
)
from telegram.error import BadRequest, EndPointNotFound, InvalidToken
from telegram.error import BadRequest, EndPointNotFound, InvalidToken, TimedOut
from telegram.ext import ExtBot, InvalidCallbackData
from telegram.helpers import escape_markdown
from telegram.request import BaseRequest, HTTPXRequest, RequestData
@@ -373,6 +373,59 @@ class TestBotWithoutRequest:
assert self.received["init"] == 2
assert self.received["shutdown"] == 2
async def test_initialize_with_get_me_failure_then_success(self, offline_bot, monkeypatch):
"""Test that bot can recover from get_me failure during initialization."""
get_me_call_count = 0
request_init_count = 0
test_bot = PytestBot(token=offline_bot.token, request=OfflineRequest())
original_get_me = test_bot.get_me
original_request_init = test_bot.request.initialize
async def failing_then_succeeding_get_me(*args, **kwargs):
nonlocal get_me_call_count
get_me_call_count += 1
if get_me_call_count == 1:
# First call fails
raise TimedOut("Test timeout")
# Subsequent calls succeed
return await original_get_me(*args, **kwargs)
async def counting_request_init(*args, **kwargs):
nonlocal request_init_count
request_init_count += 1
await original_request_init(*args, **kwargs)
monkeypatch.setattr(test_bot, "get_me", failing_then_succeeding_get_me)
monkeypatch.setattr(test_bot.request, "initialize", counting_request_init)
try:
# First initialize attempt should fail due to get_me timeout
with pytest.raises(TimedOut):
await test_bot.initialize()
# Request initialization should have been called (once per initialize call)
assert request_init_count == 1
# get_me should have been called once and failed
assert get_me_call_count == 1
# Second initialize attempt should succeed
await test_bot.initialize()
# Request initialization should not be called again (still 1)
assert request_init_count == 1
# get_me should have been called a second time and succeeded
assert get_me_call_count == 2
# Verify bot is now accessible
assert test_bot.bot.id == offline_bot.id
# Third initialize attempt should be a no-op (both flags already True)
await test_bot.initialize()
# Neither should be called again
assert request_init_count == 1
assert get_me_call_count == 2
finally:
await test_bot.shutdown()
async def test_context_manager(self, monkeypatch, offline_bot):
async def initialize():
self.test_flag = ["initialize"]
@@ -4633,3 +4686,79 @@ class TestBotWithRequest:
balance = await bot.get_my_star_balance()
assert isinstance(balance, StarAmount)
assert balance.amount == 0
async def test_initialize_tracks_requests_and_bot_separately(self, offline_bot, monkeypatch):
"""Test that requests and bot user are initialized separately and only once."""
request_init_count = 0
get_me_call_count = 0
async def counting_request_init(*args, **kwargs):
nonlocal request_init_count
request_init_count += 1
original_get_me = offline_bot.get_me
async def counting_get_me(*args, **kwargs):
nonlocal get_me_call_count
get_me_call_count += 1
return await original_get_me(*args, **kwargs)
test_bot = PytestBot(token=offline_bot.token, request=OfflineRequest())
monkeypatch.setattr(test_bot.request, "initialize", counting_request_init)
monkeypatch.setattr(test_bot, "get_me", counting_get_me)
try:
# First initialization
await test_bot.initialize()
assert request_init_count == 1
assert get_me_call_count == 1
# Second initialization should not call either again
await test_bot.initialize()
assert request_init_count == 1
assert get_me_call_count == 1
finally:
await test_bot.shutdown()
async def test_shutdown_allows_reinitialization(self, offline_bot, monkeypatch):
"""Test that after shutdown, bot can be reinitialized."""
request_init_count = 0
request_shutdown_count = 0
get_me_call_count = 0
async def counting_request_init(*args, **kwargs):
nonlocal request_init_count
request_init_count += 1
async def counting_request_shutdown(*args, **kwargs):
nonlocal request_shutdown_count
request_shutdown_count += 1
original_get_me = offline_bot.get_me
async def counting_get_me(*args, **kwargs):
nonlocal get_me_call_count
get_me_call_count += 1
return await original_get_me(*args, **kwargs)
test_bot = PytestBot(token=offline_bot.token, request=OfflineRequest())
monkeypatch.setattr(test_bot.request, "initialize", counting_request_init)
monkeypatch.setattr(test_bot.request, "shutdown", counting_request_shutdown)
monkeypatch.setattr(test_bot, "get_me", counting_get_me)
try:
# First initialization
await test_bot.initialize()
assert request_init_count == 1
assert get_me_call_count == 1
# Shutdown
await test_bot.shutdown()
assert request_shutdown_count == 1
# Re-initialize should call everything again
await test_bot.initialize()
assert request_init_count == 2
assert get_me_call_count == 2
finally:
await test_bot.shutdown()