Use Lock instead of semaphore

This commit is contained in:
Hinrich Mahler
2025-02-06 12:05:20 +01:00
parent 31af1a9db8
commit b1fff6d90a
3 changed files with 15 additions and 17 deletions
+2 -2
View File
@@ -82,7 +82,7 @@ async def unsafe_update(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N
await context.fsm.set_state(context.fsm_state_info.key, ConcurrentMachine.UPDATING_BALANCE)
# At this point, the semaphore is released such that multiple updates can update
# At this point, the lock is released such that multiple updates can update
# the balance concurrently. This can lead to race conditions.
await update_balance(context, update)
@@ -95,7 +95,7 @@ async def safe_update(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non
await asyncio.sleep(1)
async with context.as_fsm_state(ConcurrentMachine.UPDATING_BALANCE):
# At this point, the semaphore is acquired such that only one update can update
# At this point, the lock is acquired such that only one update can update
# the balance at a time. This prevents race conditions.
await update_balance(context, update)
+2 -2
View File
@@ -277,8 +277,8 @@ class CallbackContext(Generic[BT, UD, CD, BD]):
def fsm(self) -> FiniteStateMachine:
return self.application.fsm
def fsm_semaphore(self) -> asyncio.BoundedSemaphore:
return self.fsm.get_semaphore(self.fsm_state_info.key)
def fsm_semaphore(self) -> asyncio.Lock:
return self.fsm.get_lock(self.fsm_state_info.key)
async def set_state(self, state: State) -> None:
await self.fsm.set_state(self.fsm_state_info.key, state, self.fsm_state_info.version)
+11 -13
View File
@@ -35,9 +35,7 @@ class StateInfo(Generic[_KT]):
class FiniteStateMachine(abc.ABC, Generic[_KT]):
def __init__(self) -> None:
self._semaphores: MutableMapping[_KT, asyncio.BoundedSemaphore] = (
weakref.WeakValueDictionary()
)
self._locks: MutableMapping[_KT, asyncio.Lock] = weakref.WeakValueDictionary()
# There is likely litte benefit for a user to customize how exactly the states are stored
# and accessed. So we make this private and only provide a read-only view.
@@ -63,16 +61,16 @@ class FiniteStateMachine(abc.ABC, Generic[_KT]):
def get_state_history(self, key: _KT) -> Sequence[State]:
return list(self.__history[key])
def get_semaphore(self, key: _KT) -> asyncio.BoundedSemaphore:
"""Returns a semaphore that is unique for this key at runtime.
def get_lock(self, key: _KT) -> asyncio.Lock:
"""Returns a lock that is unique for this key at runtime.
It can be used to prevent concurrent access to resources associated to this key.
"""
return self._semaphores.setdefault(key, asyncio.BoundedSemaphore(1))
return self._locks.setdefault(key, asyncio.Lock())
@abc.abstractmethod
def get_state_info(self, update: object) -> StateInfo[_KT]:
"""Returns exactly one active state for the update.
If more than one stored key applies to the update, one must chosen.
If more than one stored key applies to the update, one must be chosen.
It's recommended to select the most specific one.
Example:
@@ -81,7 +79,7 @@ class FiniteStateMachine(abc.ABC, Generic[_KT]):
available. Otherwise, the state of the chat should be returned.
Important:
This must be an atomic operation and not e.g. wait for a semaphore.
This must be an atomic operation and not e.g. wait for a lock.
Instead, if necessary, return a special state indicating that the key is currently
busy.
"""
@@ -125,19 +123,19 @@ class FiniteStateMachine(abc.ABC, Generic[_KT]):
async def set_state(self, key: _KT, state: State, version: Optional[int] = None) -> None:
"""Store the state for the specified key."""
async with self.get_semaphore(key):
async with self.get_lock(key):
self._do_set_state(key, state, version)
def set_state_nowait(self, key: _KT, state: State, version: Optional[int] = None) -> None:
"""Store the state for the specified key without waiting for a semaphore."""
if self.get_semaphore(key).locked():
raise asyncio.InvalidStateError("Semaphore is locked")
"""Store the state for the specified key without waiting for a lock."""
if self.get_lock(key).locked():
raise asyncio.InvalidStateError("Lock is locked")
self._do_set_state(key, state, version)
@contextlib.asynccontextmanager
async def as_state(self, key: _KT, state: State) -> AsyncIterator[None]:
"""Context manager to set the state for the specified key and reset it afterwards."""
async with self.get_semaphore(key):
async with self.get_lock(key):
current_state, current_version = self.states[key]
new_version = self._do_set_state(key, state, current_version).version
try: