Compare commits

...

9 Commits

Author SHA1 Message Date
Hinrich Mahler b1fff6d90a Use Lock instead of semaphore 2025-02-06 12:05:20 +01:00
Hinrich Mahler 31af1a9db8 Add an example on concurrency in FSM 2025-02-06 11:57:18 +01:00
Hinrich Mahler 4441543043 Try setting up infrastructure for optimistic locking. Example will follow 2025-02-05 23:40:15 +01:00
Hinrich Mahler 646ba37391 Move internal state storage to FiniteStateMachine and add state history 2025-02-05 13:13:21 +01:00
Hinrich Mahler 817b71d914 Disable tests harder … 2025-02-05 12:26:02 +01:00
Hinrich Mahler 434cbfade8 Add Some Abstractions for Timeout Jobs 2025-02-05 12:22:07 +01:00
Hinrich Mahler 34832d9db9 Temporarily Disable Tests on this branch 2025-02-05 11:17:47 +01:00
Hinrich Mahler 07225b9a02 Add State.ANY for fallbacks and allow handling multiple states for one update 2025-02-05 10:52:10 +01:00
Hinrich Mahler 0c06ba0a90 Initial FSM PoC 2025-02-04 21:36:40 +01:00
14 changed files with 774 additions and 283 deletions
-34
View File
@@ -1,34 +0,0 @@
name: Check Links in Documentation
on:
schedule:
# First day of month at 05:46 in every 2nd month
- cron: '46 5 1 */2 *'
pull_request:
paths:
- .github/workflows/docs-linkcheck.yml
permissions: {}
jobs:
test-sphinx-build:
name: test-sphinx-linkcheck
runs-on: ${{matrix.os}}
strategy:
matrix:
python-version: ['3.10']
os: [ubuntu-latest]
fail-fast: False
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -W ignore -m pip install --upgrade pip
python -W ignore -m pip install -r requirements-dev-all.txt
- name: Check Links
run: sphinx-build docs/source docs/build/html -W --keep-going -j auto -b linkcheck
-53
View File
@@ -1,53 +0,0 @@
name: Test Documentation Build
on:
pull_request:
paths:
- telegram/**
- docs/**
push:
branches:
- master
permissions: {}
jobs:
test-sphinx-build:
name: test-sphinx-build
runs-on: ${{matrix.os}}
permissions:
# for uploading artifacts
actions: write
strategy:
matrix:
python-version: ['3.10']
os: [ubuntu-latest]
fail-fast: False
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements*.txt'
- name: Install dependencies
run: |
python -W ignore -m pip install --upgrade pip
python -W ignore -m pip install -r requirements-dev-all.txt
- name: Test autogeneration of admonitions
run: pytest -v --tb=short tests/docs/admonition_inserter.py
- name: Build docs
run: sphinx-build docs/source docs/build/html -W --keep-going -j auto
- name: Upload docs
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: HTML Docs
retention-days: 7
path: |
# Exclude the .doctrees folder and .buildinfo file from the artifact
# since they are not needed and add to the size
docs/build/html/*
!docs/build/html/.doctrees
!docs/build/html/.buildinfo
-51
View File
@@ -1,51 +0,0 @@
name: Bot API Tests
on:
pull_request:
paths:
- telegram/**
- tests/**
push:
branches:
- master
schedule:
# Run monday and friday morning at 03:07 - odd time to spread load on GitHub Actions
- cron: '7 3 * * 1,5'
permissions: {}
jobs:
check-conformity:
name: check-conformity
runs-on: ${{matrix.os}}
strategy:
matrix:
python-version: [3.11]
os: [ubuntu-latest]
fail-fast: False
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
persist-credentials: false
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -W ignore -m pip install --upgrade pip
python -W ignore -m pip install .[all]
python -W ignore -m pip install -r requirements-unit-tests.txt
- name: Compare to official api
run: |
pytest -v tests/test_official/test_official.py --junit-xml=.test_report_official.xml
exit $?
env:
TEST_OFFICIAL: "true"
shell: bash --noprofile --norc {0}
- name: Test Summary
id: test_summary
uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4
if: always() # always run, even if tests fail
with:
paths: .test_report_official.xml
-23
View File
@@ -1,23 +0,0 @@
name: Check Type Completeness
on:
pull_request:
paths:
- telegram/**
- pyproject.toml
- .github/workflows/type_completeness.yml
push:
branches:
- master
permissions: {}
jobs:
test-type-completeness:
name: test-type-completeness
runs-on: ubuntu-latest
steps:
- uses: Bibo-Joshi/pyright-type-completeness@c85a67ff3c66f51dcbb2d06bfcf4fe83a57d69cc # v1.0.1
with:
package-name: telegram
python-version: 3.12
pyright-version: ~=1.1.367
File diff suppressed because one or more lines are too long
+203
View File
@@ -0,0 +1,203 @@
#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.
"""Simple state machine to handle user support.
One admin is supported. The admin can have one active conversation at a time. Other users
are put on hold until the admin finishes the current conversation.
In each conversation, the admin and the user take turns to send messages.
"""
import logging
from typing import Optional
from telegram import Update
from telegram.ext import (
Application,
CommandHandler,
ContextTypes,
FiniteStateMachine,
MessageHandler,
State,
StateInfo,
filters,
)
# Enable logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.DEBUG
)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("telegram").setLevel(logging.WARNING)
logging.getLogger("telegram.ext.Application").setLevel(logging.DEBUG)
logger = logging.getLogger(__name__)
class UserSupportMachine(FiniteStateMachine[Optional[int]]):
HOLD = State("HOLD")
WELCOMING = State("WELCOMING")
WAITING_FOR_REPLY = State("WAITING_FOR_REPLY")
WRITING = State("WRITING")
def __init__(self, admin_id: int):
self.admin_id = admin_id
super().__init__()
def _get_admin_state(self) -> tuple[State, int]:
return self._states[self.admin_id]
def get_state_info(self, update: object) -> StateInfo[Optional[int]]:
if not isinstance(update, Update) or not (user := update.effective_user):
key = None
state, version = self.states[key]
return StateInfo(key=key, state=state, version=version)
# Admin is easy - just return the state
admin_state, admin_version = self._get_admin_state()
if user.id == self.admin_id:
logging.debug("Returning admin state: %s", admin_state)
return StateInfo(self.admin_id, admin_state, admin_version)
# If the user state is active in the conversation, we can just return that state
user_state, user_version = self._states[user.id]
if user_state.matches(self.WELCOMING | self.WRITING | self.WAITING_FOR_REPLY):
logging.debug("Returning user state: %s", user_state)
return StateInfo(user.id, user_state, user_version)
# On first interaction, we need to determine what to do with the user
# if the admin is not idle, we put the user on hold. Otherwise, they may send the first
# message, and we put the admin in waiting for reply to avoid another user occupying the
# admin first
effective_user_state = self.HOLD if admin_state != State.IDLE else self.WELCOMING
self._do_set_state(user.id, effective_user_state, user_version)
if effective_user_state == self.WELCOMING:
self._do_set_state(self.admin_id, self.WAITING_FOR_REPLY)
logging.debug("Returning user state: %s", effective_user_state)
return StateInfo(user.id, effective_user_state, user_version)
async def welcome_user(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await update.effective_message.forward(context.bot_data["admin_id"])
suffix = ""
if UserSupportMachine.HOLD in context.fsm.get_state_history(context.fsm_state_info.key)[:-1]:
suffix = " Thank you for patiently waiting. We hope you enjoyed the music."
await update.effective_message.reply_text(
"Welcome! Your message has been forwarded to the admin. "
f"They will get back to you soon.{suffix}"
)
await context.set_state(UserSupportMachine.WAITING_FOR_REPLY)
await context.fsm.set_state(context.bot_data["admin_id"], UserSupportMachine.WRITING)
context.bot_data["active_user"] = update.effective_user.id
async def conversation_timeout(context: ContextTypes.DEFAULT_TYPE) -> None:
active_user = context.bot_data.get("active_user")
admin_id = context.bot_data["admin_id"]
async def handle(user_id: int) -> None:
await context.bot.send_message(
user_id, "The conversation has been stopped due to inactivity."
)
await context.fsm.set_state(user_id, State.IDLE)
if active_user:
await handle(active_user)
await handle(admin_id)
async def handle_reply(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if not (active_user := context.bot_data.get("active_user")):
logger.warning("No active user found, ignoring message")
target = (
active_user
if update.effective_user.id == (admin_id := context.bot_data["admin_id"])
else admin_id
)
await context.bot.send_message(target, update.effective_message.text)
logging.debug("Forwarded message to %s", target)
await context.set_state(UserSupportMachine.WAITING_FOR_REPLY)
logging.debug("Done setting state to WAITING_FOR_REPLY for %s", target)
await context.fsm.set_state(target, UserSupportMachine.WRITING)
logging.debug("Done setting state to WRITING for %s, context.fsm_key")
context.fsm.schedule_timeout(
when=30,
callback=conversation_timeout,
cancel_keys=[active_user, admin_id],
)
async def stop_conversation(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
text = "The conversation has been stopped."
admin_id = context.bot_data["admin_id"]
active_user = context.bot_data.get("active_user")
await context.bot.send_message(admin_id, text)
await context.fsm.set_state(admin_id, State.IDLE)
if active_user:
await context.bot.send_message(active_user, text)
await context.fsm.set_state(active_user, State.IDLE)
async def hold_melody(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await update.effective_message.reply_text(
"You have been put on hold. The admin will get back to you soon. Please hear some music "
"while you wait: https://www.youtube.com/watch?v=dQw4w9WgXcQ"
)
async def not_your_turn(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await update.effective_message.reply_text(
"It's not your turn yet. Please wait for the other party to reply to your message."
)
async def unsupported_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await update.effective_message.reply_text("This message is not supported.")
def main() -> None:
application = Application.builder().token("TOKEN").build()
application.fsm = UserSupportMachine(admin_id=123456)
application.fsm.set_job_queue(application.job_queue)
application.bot_data["admin_id"] = application.fsm.admin_id
# Users are welcomed only if they are in the corresponding state
application.add_handler(
MessageHandler(~filters.User(application.fsm.admin_id) & filters.TEXT, welcome_user),
state=UserSupportMachine.WELCOMING,
)
# Conversation logic:
# * forward messages between user and admin
# * stop the conversation at any time (admin or user)
# * point out that the other party is currently writing
# Important: Order matters!
application.add_handler(
CommandHandler("stop", stop_conversation),
state=UserSupportMachine.WAITING_FOR_REPLY | UserSupportMachine.WRITING,
)
application.add_handler(
MessageHandler(filters.TEXT, handle_reply), state=UserSupportMachine.WRITING
)
application.add_handler(
MessageHandler(filters.TEXT, not_your_turn), state=UserSupportMachine.WAITING_FOR_REPLY
)
# If the admin is busy, put the user on hold
application.add_handler(
MessageHandler(filters.TEXT, hold_melody), state=UserSupportMachine.HOLD
)
# Fallback
application.add_handler(MessageHandler(filters.ALL, unsupported_message), state=State.ANY)
application.run_polling(allowed_updates=Update.ALL_TYPES)
if __name__ == "__main__":
main()
+172
View File
@@ -0,0 +1,172 @@
#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.
"""State machine bot showcasing how concurrency can be handled with FSM.
How to use:
* Use Case 1: Concurrent balance updates
- /unsafe_update <balance_update>: Unsafe update of the wallet balance. Send the command
multiple times in quick succession (less than 1 second) to see the effect
- /safe_update <balance_update>: Safe update of the wallet balance. Send the command
multiple times in quick succession (less than 1 second) to see the effect
* Use Case 2: Declare a winner - who is the fastest?
- /unsafe_declare_winner: Unsafe declaration of the user as winner. Send the command
multiple times in quick succession (less than 1 second) to see the effect. Needs restart
after the winner is declared.
- /safe_declare_winner: Safe declaration of the user as winner. Send the command
multiple times in quick succession (less than 1 second) to see the effect. Needs restart
after the winner is declared.
"""
import asyncio
import logging
from telegram import Update
from telegram.constants import ChatAction
from telegram.ext import (
Application,
CommandHandler,
ContextTypes,
FiniteStateMachine,
MessageHandler,
State,
StateInfo,
filters,
)
# Enable logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.DEBUG
)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("telegram").setLevel(logging.WARNING)
logging.getLogger("telegram.ext.Application").setLevel(logging.DEBUG)
logger = logging.getLogger(__name__)
class ConcurrentMachine(FiniteStateMachine[None]):
"""This FSM only knows a global state for the whole bot"""
UPDATING_BALANCE = State("UPDATING_BALANCE")
WINNER_DECLARED = State("WINNER_DECLARED")
def get_state_info(self, update: object) -> StateInfo[None]:
state, version = self.states[None]
return StateInfo(key=None, state=state, version=version)
########################################
# Use case 1: Concurrent balance updates
########################################
async def update_balance(context: ContextTypes.DEFAULT_TYPE, update: Update) -> None:
initial_balance = context.bot_data.get("balance", 0)
balance_update = int(context.args[0])
# Simulate heavy computation
await update.effective_message.reply_text(
f"Initiating balance update: {initial_balance}. Updating ..."
)
await update.effective_chat.send_action(ChatAction.TYPING)
await asyncio.sleep(4.5)
new_balance = context.bot_data["balance"] = initial_balance + balance_update
await update.effective_message.reply_text(f"Balance updated. New balance: {new_balance}")
async def unsafe_update(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Unsafe update of the wallet balance"""
# Simulate heavy computation *before* the update is processed
await asyncio.sleep(1)
await context.fsm.set_state(context.fsm_state_info.key, ConcurrentMachine.UPDATING_BALANCE)
# At this point, the lock is released such that multiple updates can update
# the balance concurrently. This can lead to race conditions.
await update_balance(context, update)
await context.fsm.set_state(context.fsm_state_info.key, State.IDLE)
async def safe_update(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Safe update of the wallet balance"""
# Simulate heavy computation *before* the update is processed
await asyncio.sleep(1)
async with context.as_fsm_state(ConcurrentMachine.UPDATING_BALANCE):
# At this point, the lock is acquired such that only one update can update
# the balance at a time. This prevents race conditions.
await update_balance(context, update)
async def busy(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Busy state"""
await update.effective_message.reply_text("I'm busy, try again later.")
####################################################
# Use case 2: Declare a winner - who is the fastest?
####################################################
async def declare_winner_unsafe(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Declare the user as winner"""
# Simulate heavy computation *before* the update is processed
await asyncio.sleep(1)
# Unsafe state update: No version check, so the state might have already changed
await context.fsm.set_state(context.fsm_state_info.key, ConcurrentMachine.WINNER_DECLARED)
await update.effective_message.reply_text("You are the winner!")
async def declare_winner_safe(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Declare the user as winner"""
# Simulate heavy computation *before* the update is processed
await asyncio.sleep(1)
try:
await context.set_state(ConcurrentMachine.WINNER_DECLARED)
await update.effective_message.reply_text("You are the winner!")
except ValueError:
await update.effective_message.reply_text(
"Sorry, you are too late. Someone else was faster."
)
def main() -> None:
application = Application.builder().token("TOKEN").concurrent_updates(True).build()
application.fsm = ConcurrentMachine()
# Note: OR-combination of states is used here to allow both use cases to be handled
# in parallel. Not really necessary for the showcasing, just a nice touch :)
# Use case 2: Declare a winner - who is the fastest?
application.add_handler(
CommandHandler("unsafe_declare_winner", declare_winner_unsafe),
state=State.IDLE | ConcurrentMachine.UPDATING_BALANCE,
)
application.add_handler(
CommandHandler("safe_declare_winner", declare_winner_safe),
state=State.IDLE | ConcurrentMachine.UPDATING_BALANCE,
)
# Use case 1: Concurrent balance updates
application.add_handler(
CommandHandler("unsafe_update", unsafe_update, has_args=1),
state=State.IDLE | ConcurrentMachine.WINNER_DECLARED,
)
application.add_handler(
CommandHandler("safe_update", safe_update, has_args=1),
state=State.IDLE | ConcurrentMachine.WINNER_DECLARED,
)
# Order matters, so this needs to be added last
application.add_handler(
MessageHandler(filters.ALL, busy), state=ConcurrentMachine.UPDATING_BALANCE
)
application.run_polling(allowed_updates=Update.ALL_TYPES)
if __name__ == "__main__":
main()
+5
View File
@@ -42,6 +42,7 @@ __all__ = (
"Defaults",
"DictPersistence",
"ExtBot",
"FiniteStateMachine",
"InlineQueryHandler",
"InvalidCallbackData",
"Job",
@@ -57,6 +58,9 @@ __all__ = (
"PrefixHandler",
"ShippingQueryHandler",
"SimpleUpdateProcessor",
"SingleStateMachine",
"State",
"StateInfo",
"StringCommandHandler",
"StringRegexHandler",
"TypeHandler",
@@ -77,6 +81,7 @@ from ._contexttypes import ContextTypes
from ._defaults import Defaults
from ._dictpersistence import DictPersistence
from ._extbot import ExtBot
from ._fsm import FiniteStateMachine, SingleStateMachine, State, StateInfo
from ._handlers.basehandler import BaseHandler
from ._handlers.businessconnectionhandler import BusinessConnectionHandler
from ._handlers.businessmessagesdeletedhandler import BusinessMessagesDeletedHandler
+51 -13
View File
@@ -48,6 +48,7 @@ from telegram.error import TelegramError
from telegram.ext._basepersistence import BasePersistence
from telegram.ext._contexttypes import ContextTypes
from telegram.ext._extbot import ExtBot
from telegram.ext._fsm import SingleStateMachine, State, StateInfo
from telegram.ext._handlers.basehandler import BaseHandler
from telegram.ext._updater import Updater
from telegram.ext._utils.stack import was_called_by
@@ -59,7 +60,7 @@ if TYPE_CHECKING:
from socket import socket
from telegram import Message
from telegram.ext import ConversationHandler, JobQueue
from telegram.ext import ConversationHandler, FiniteStateMachine, JobQueue
from telegram.ext._applicationbuilder import InitApplicationBuilder
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor
from telegram.ext._jobqueue import Job
@@ -266,6 +267,7 @@ class Application(
"update_queue",
"updater",
"user_data",
"fsm",
)
# Allowing '__weakref__' creation here since we need it for the JobQueue
# Currently the __weakref__ slot is already created
@@ -301,11 +303,12 @@ class Application(
stacklevel=2,
)
self.fsm: FiniteStateMachine = SingleStateMachine()
self.bot: BT = bot
self.update_queue: asyncio.Queue[object] = update_queue
self.context_types: ContextTypes[CCT, UD, CD, BD] = context_types
self.updater: Optional[Updater] = updater
self.handlers: dict[int, list[BaseHandler[Any, CCT, Any]]] = {}
self.handlers: dict[State, dict[int, list[BaseHandler[Any, CCT, Any]]]] = {}
self.error_handlers: dict[
HandlerCallback[object, CCT, None], Union[bool, DefaultValue[bool]]
] = {}
@@ -1278,19 +1281,46 @@ class Application(
# Processing updates before initialize() is a problem e.g. if persistence is used
self._check_initialized()
fsm_state_info = self.fsm.get_state_info(update)
for state, state_handlers in self.handlers.items():
if state.matches(fsm_state_info.state):
_LOGGER.debug("Processing in state %s", state)
was_handled = await self.__process_update_groups(
update, state_handlers, fsm_state_info
)
if was_handled:
_LOGGER.debug(
"Update was handled in state %s. Stopping further processing", state
)
return
_LOGGER.debug(
"No handlers found for key %s in state %s", fsm_state_info.key, fsm_state_info.state
)
return
async def __process_update_groups(
self,
update: object,
state_handlers: dict[int, list[BaseHandler]],
fsm_state_info: StateInfo,
) -> bool:
context = None
was_handled = False
any_blocking = False # Flag which is set to True if any handler specifies block=True
for handlers in self.handlers.values():
for handlers in state_handlers.values():
try:
for handler in handlers:
check = handler.check_update(update) # Should the handler handle this update?
if check is None or check is False:
continue
was_handled = True
if not context: # build a context if not already built
try:
context = self.context_types.context.from_update(update, self)
context.fsm_state_info = fsm_state_info
except Exception as exc:
_LOGGER.critical(
(
@@ -1300,7 +1330,7 @@ class Application(
update,
exc_info=exc,
)
return
return True
await context.refresh_data()
coroutine: Coroutine = handler.handle_update(update, self, check, context)
@@ -1340,7 +1370,14 @@ class Application(
# (in __create_task_callback)
self._mark_for_persistence_update(update=update)
def add_handler(self, handler: BaseHandler[Any, CCT, Any], group: int = DEFAULT_GROUP) -> None:
return was_handled
def add_handler(
self,
handler: BaseHandler[Any, CCT, Any],
group: int = DEFAULT_GROUP,
state: State = State.IDLE,
) -> None:
"""Register a handler.
TL;DR: Order and priority counts. 0 or 1 handlers per group will be used. End handling of
@@ -1399,11 +1436,11 @@ class Application(
stacklevel=2,
)
if group not in self.handlers:
self.handlers[group] = []
self.handlers = dict(sorted(self.handlers.items())) # lower -> higher groups
state_handlers = self.handlers.setdefault(state, {})
if group not in state_handlers:
state_handlers[group] = []
self.handlers[group].append(handler)
state_handlers[group].append(handler)
def add_handlers(
self,
@@ -1475,10 +1512,11 @@ class Application(
group (:obj:`object`, optional): The group identifier. Default is ``0``.
"""
if handler in self.handlers[group]:
self.handlers[group].remove(handler)
if not self.handlers[group]:
del self.handlers[group]
for state_handlers in self.handlers.values():
if handler in state_handlers[group]:
state_handlers[group].remove(handler)
if not state_handlers[group]:
del state_handlers[group]
def drop_chat_data(self, chat_id: int) -> None:
"""Drops the corresponding entry from the :attr:`chat_data`. Will also be deleted from
+22 -2
View File
@@ -17,8 +17,9 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains the CallbackContext class."""
import asyncio
from collections.abc import Awaitable, Generator
from contextlib import AbstractAsyncContextManager
from re import Match
from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, TypeVar, Union
@@ -26,12 +27,13 @@ from telegram._callbackquery import CallbackQuery
from telegram._update import Update
from telegram._utils.warnings import warn
from telegram.ext._extbot import ExtBot
from telegram.ext._fsm import FiniteStateMachine, State
from telegram.ext._utils.types import BD, BT, CD, UD
if TYPE_CHECKING:
from asyncio import Future, Queue
from telegram.ext import Application, Job, JobQueue
from telegram.ext import Application, Job, JobQueue, StateInfo
from telegram.ext._utils.types import CCT
_STORING_DATA_WIKI = (
@@ -121,6 +123,7 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
"args",
"coroutine",
"error",
"fsm_state_info",
"job",
"matches",
)
@@ -141,6 +144,7 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
self.coroutine: Optional[
Union[Generator[Optional[Future[object]], None, Any], Awaitable[Any]]
] = None
self.fsm_state_info: StateInfo = None # type: ignore[assignment]
@property
def application(self) -> "Application[BT, ST, UD, CD, BD, Any]":
@@ -269,6 +273,22 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
"telegram.Bot does not allow for arbitrary callback data."
)
@property
def fsm(self) -> FiniteStateMachine:
return self.application.fsm
def fsm_semaphore(self) -> asyncio.Lock:
return self.fsm.get_lock(self.fsm_state_info.key)
async def set_state(self, state: State) -> None:
await self.fsm.set_state(self.fsm_state_info.key, state, self.fsm_state_info.version)
def set_state_nowait(self, state: State) -> None:
self.fsm.set_state_nowait(self.fsm_state_info.key, state, self.fsm_state_info.version)
def as_fsm_state(self, state: State) -> AbstractAsyncContextManager[None]:
return self.fsm.as_state(self.fsm_state_info.key, state)
@classmethod
def from_error(
cls: type["CCT"],
+6
View File
@@ -0,0 +1,6 @@
"""Private Submbodule for finite state machine implementation."""
__all__ = ["FiniteStateMachine", "SingleStateMachine", "State", "StateInfo"]
from .machine import FiniteStateMachine, SingleStateMachine, StateInfo
from .states import State
+200
View File
@@ -0,0 +1,200 @@
"""This Module contains the FiniteStateMachine class and the built-in subclass SingleStateMachine.
"""
import abc
import asyncio
import contextlib
import datetime as dtm
import logging
import time
import weakref
from collections import defaultdict, deque
from collections.abc import AsyncIterator, Hashable, Mapping, MutableSequence, Sequence
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload
from telegram.ext._fsm.states import State
from telegram.ext._utils.types import JobCallback
if TYPE_CHECKING:
from collections.abc import MutableMapping
from telegram.ext import JobQueue
_KT = TypeVar("_KT", bound=Hashable)
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.DEBUG)
class StateInfo(Generic[_KT]):
def __init__(self: "StateInfo[_KT]", key: _KT, state: State, version: int) -> None:
self.key: _KT = key
self.state: State = state
self.version: int = version
class FiniteStateMachine(abc.ABC, Generic[_KT]):
def __init__(self) -> None:
self._locks: MutableMapping[_KT, asyncio.Lock] = weakref.WeakValueDictionary()
# There is likely litte benefit for a user to customize how exactly the states are stored
# and accessed. So we make this private and only provide a read-only view.
self.__states: dict[_KT, tuple[State, int]] = defaultdict(
lambda: (State.IDLE, time.perf_counter_ns())
)
self._states = MappingProxyType(self.__states)
self.__job_queue: Optional[weakref.ReferenceType[JobQueue]] = None
self.__history: Mapping[_KT, MutableSequence[State]] = defaultdict(
lambda: deque(maxlen=10)
)
@property
def states(self) -> Mapping[_KT, tuple[State, int]]:
return self._states
def store_state_history(self, key: _KT, state: State) -> None:
# Making this public so that users can override if they want to customize the history
# E.g., they could want to store more/fewer states, also depending on the key
self.__history[key].append(state)
def get_state_history(self, key: _KT) -> Sequence[State]:
return list(self.__history[key])
def get_lock(self, key: _KT) -> asyncio.Lock:
"""Returns a lock that is unique for this key at runtime.
It can be used to prevent concurrent access to resources associated to this key.
"""
return self._locks.setdefault(key, asyncio.Lock())
@abc.abstractmethod
def get_state_info(self, update: object) -> StateInfo[_KT]:
"""Returns exactly one active state for the update.
If more than one stored key applies to the update, one must be chosen.
It's recommended to select the most specific one.
Example:
The state of a chat, a user or a user in a specific chat could be tracked.
For a message in that chat, the state of the user in that chat should be returned if
available. Otherwise, the state of the chat should be returned.
Important:
This must be an atomic operation and not e.g. wait for a lock.
Instead, if necessary, return a special state indicating that the key is currently
busy.
"""
def _do_set_state(
self, key: _KT, state: State, version: Optional[int] = None
) -> StateInfo[_KT]:
"""Protected method to set the state for the specified key.
The version can be optionally used for optimistic locking. If the version does not match
the current version, the state should not be updated.
Important:
This should be used exclusively by methods of this class and subclasses.
It should *not* be called directly by users of this class!
"""
_LOGGER.debug("Setting %s state to %s", key, state)
if state is State.ANY:
raise ValueError("State.ANY is not supported in set_state")
if version and version != self._states.get(key, (None, None))[1]:
raise ValueError("Optimistic locking failed. Not updating state.")
if jq := self._get_job_queue(raise_exception=False):
# This is a rather tight coupling between FSM and JobQueue
# Not sure if we like that. Makes it even harder to replace JobQueue
# (or the JQ implementation) with something else.
# The upside is that we don't need to maintain any additional internal state
# for the jobs and persistence is handled by the JobQueue.
cancel_jobs = jq.jobs(pattern=str(hash(key)))
for job in cancel_jobs:
_LOGGER.debug("Cancelling timeout job %s", job)
job.schedule_removal()
# important to use time.perf_counter_ns() here, as time_ns() is not monotonic
self.__states[key] = (state, time.perf_counter_ns())
# Doing this *after* do_set_state so that any exceptions are raised before the history
# is updated
self.store_state_history(key, state)
return StateInfo(key, state, self._states[key][1])
async def set_state(self, key: _KT, state: State, version: Optional[int] = None) -> None:
"""Store the state for the specified key."""
async with self.get_lock(key):
self._do_set_state(key, state, version)
def set_state_nowait(self, key: _KT, state: State, version: Optional[int] = None) -> None:
"""Store the state for the specified key without waiting for a lock."""
if self.get_lock(key).locked():
raise asyncio.InvalidStateError("Lock is locked")
self._do_set_state(key, state, version)
@contextlib.asynccontextmanager
async def as_state(self, key: _KT, state: State) -> AsyncIterator[None]:
"""Context manager to set the state for the specified key and reset it afterwards."""
async with self.get_lock(key):
current_state, current_version = self.states[key]
new_version = self._do_set_state(key, state, current_version).version
try:
yield
finally:
self._do_set_state(key, current_state, new_version)
@staticmethod
def _build_job_name(keys: Sequence[_KT]) -> str:
return f"FSM_Job_{'_'.join(str(hash(k)) for k in keys)}"
def set_job_queue(self, job_queue: "JobQueue") -> None:
self.__job_queue = weakref.ref(job_queue)
@overload
def _get_job_queue(self, raise_exception: Literal[False]) -> Optional["JobQueue"]: ...
@overload
def _get_job_queue(self) -> "JobQueue": ...
def _get_job_queue(self, raise_exception: bool = True) -> Optional["JobQueue"]:
if self.__job_queue is None:
if raise_exception:
raise RuntimeError("JobQueue not set")
return None
job_queue = self.__job_queue()
if job_queue is None:
if raise_exception:
raise RuntimeError("JobQueue was garbage collected")
return None
return job_queue
def schedule_timeout(
self,
callback: JobCallback,
when: Union[float, dtm.timedelta, dtm.datetime, dtm.time],
cancel_keys: Optional[Sequence[_KT]] = None,
job_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""Schedule a timeout job. This is a thin wrapper around JobQueue.run_once.
The callback will have to take care of resetting any state if necessary.
Pass cancel_keys to automatically cancel the job when a new state is set for any of the
keys.
"""
job_kwargs = job_kwargs or {}
if cancel_keys:
if "name" in job_kwargs:
raise ValueError("job_kwargs must not contain a 'name' key")
job_kwargs["name"] = self._build_job_name(cancel_keys)
self._get_job_queue().run_once(callback, when, **job_kwargs)
_LOGGER.debug(
"Scheduled timeout. Will be cancelled when a new set state is for either of: %s",
cancel_keys or [],
)
class SingleStateMachine(FiniteStateMachine[None]):
def get_state_info(self, update: object) -> StateInfo[None]: # noqa: ARG002
return StateInfo(None, State.IDLE, 0)
def do_set_state(self, key: None, state: State) -> None:
pass
+114
View File
@@ -0,0 +1,114 @@
"""This Module contains implementations of State classes for Finite State Machines"""
import abc
import contextlib
from typing import ClassVar, Optional
from uuid import uuid4
class State(abc.ABC):
__knows_uids: ClassVar[set[str]] = set()
__not_cache: ClassVar[dict[str, "_NOTState"]] = {}
__or_cache: ClassVar[dict[tuple[str, str], "_ORState"]] = {}
__and_cache: ClassVar[dict[tuple[str, str], "_ANDState"]] = {}
__xor_cache: ClassVar[dict[tuple[str, str], "_XORState"]] = {}
IDLE: "State"
"""Default State for all Finite State Machines"""
ANY: "State"
"""Special State that matches any other State. Useful to define fallback behavior.
*Not* supported in ``set_state`` method of FSMs.
"""
def __init__(self, uid: Optional[str] = None):
effective_uid = uid or uuid4().hex
if effective_uid in self.__knows_uids:
raise ValueError(f"Duplicate UID: {effective_uid} already registered")
self._uid = effective_uid
self.__knows_uids.add(effective_uid)
def __invert__(self) -> "_NOTState":
with contextlib.suppress(KeyError):
return self.__not_cache[self.uid]
return self.__not_cache.setdefault(self.uid, _NOTState(self))
def __or__(self, other: "State") -> "_ORState":
key = (self.uid, other.uid)
with contextlib.suppress(KeyError):
return self.__or_cache[key]
return self.__or_cache.setdefault(key, _ORState(self, other))
def __and__(self, other: "State") -> "_ANDState":
key = (self.uid, other.uid)
with contextlib.suppress(KeyError):
return self.__and_cache[key]
return self.__and_cache.setdefault(key, _ANDState(self, other))
def __xor__(self, other: "State") -> "_XORState":
key = (self.uid, other.uid)
with contextlib.suppress(KeyError):
return self.__xor_cache[key]
return self.__xor_cache.setdefault(key, _XORState(self, other))
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.uid}>"
def __str__(self) -> str:
return self.uid
@property
def uid(self) -> str:
return self._uid
def matches(self, state: "State") -> bool:
if isinstance(state, (_NOTState, _ANDState, _ORState, _XORState)):
return state.matches(self)
return self.uid == state.uid
class _AnyState(State):
def matches(self, state: "State") -> bool: # noqa: ARG002
return True
State.IDLE = State("IDLE")
State.ANY = _AnyState("ANY")
class _XORState(State):
def __init__(self, state_one: State, state_two: State):
super().__init__(uid=f"({state_one.uid})^({state_two.uid})")
self._state_one = state_one
self._state_two = state_two
def matches(self, state: "State") -> bool:
return self._state_one.matches(state) ^ self._state_two.matches(state)
class _ORState(State):
def __init__(self, state_one: State, state_two: State):
super().__init__(uid=f"({state_one.uid})|({state_two.uid})")
self._state_one = state_one
self._state_two = state_two
def matches(self, state: "State") -> bool:
return self._state_one.matches(state) or self._state_two.matches(state)
class _ANDState(State):
def __init__(self, state_one: State, state_two: State):
super().__init__(uid=f"({state_one.uid})&({state_two.uid})")
self._state_one = state_one
self._state_two = state_two
def matches(self, state: "State") -> bool:
return self._state_one.matches(state) and self._state_two.matches(state)
class _NOTState(State):
def __init__(self, state: State):
super().__init__(uid=f"!({state.uid})")
self._state = state
def matches(self, state: "State") -> bool:
return not self._state.matches(state)
+1 -1
View File
@@ -97,7 +97,7 @@ class JobQueue(Generic[CCT]):
"""
__slots__ = ("_application", "_executor", "scheduler")
__slots__ = ("__weakref__", "_application", "_executor", "scheduler")
_CRON_MAPPING = ("sun", "mon", "tue", "wed", "thu", "fri", "sat")
def __init__(self) -> None: