diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 5c2a1468..2f66a8b0 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -383,7 +383,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover if self.on_exit is not None: self.on_exit(self) - async def prefetcher( + async def prefetcher( # noqa: C901 self, queue: "asyncio.Queue[bytes | AckableMessage]", finish_event: asyncio.Event, @@ -396,48 +396,59 @@ async def prefetcher( """ fetched_tasks: int = 0 iterator = self.broker.listen() - current_message: asyncio.Task[bytes | AckableMessage] = asyncio.create_task( - iterator.__anext__(), # type: ignore - ) + current_message: asyncio.Task[bytes | AckableMessage] | None = None - while True: - if finish_event.is_set(): - break - try: - await self.sem_prefetch.acquire() - if ( - self.max_tasks_to_execute - and fetched_tasks >= self.max_tasks_to_execute - ): - logger.info("Max number of tasks executed.") - break - # Here we wait for the message to be fetched, - # but we make it with timeout so it can be interrupted - done, _ = await asyncio.wait({current_message}, timeout=0.3) - # If the message is not fetched, we release the semaphore - # and continue the loop. So it will check if finished event was set. - if not done: - self.sem_prefetch.release() - continue - # We're done, so now we need to check - # whether task has returned an error. - message = current_message.result() - current_message = asyncio.create_task(iterator.__anext__()) # type: ignore - fetched_tasks += 1 - await queue.put(message) - # Custom hooks for OTel and any future instrumentations - for middleware in reversed(self.broker.middlewares): - if hasattr(middleware, "on_prefetch_queue_add"): - await maybe_awaitable( - middleware.on_prefetch_queue_add(), # type: ignore + try: + while not finish_event.is_set(): + try: + await self.sem_prefetch.acquire() + if ( + self.max_tasks_to_execute + and fetched_tasks >= self.max_tasks_to_execute + ): + logger.info("Max number of tasks executed.") + break + if current_message is None: + current_message = asyncio.create_task( + iterator.__anext__(), # type: ignore ) - except (asyncio.CancelledError, StopAsyncIteration): - break - # We don't want to fetch new messages if we are shutting down. - logger.info("Stopping prefetching messages...") - current_message.cancel() - await queue.put(QUEUE_DONE) - self.sem_prefetch.release() + # Here we wait for the message to be fetched, + # but we make it with timeout so it can be interrupted + done, _ = await asyncio.wait({current_message}, timeout=0.3) + # If the message is not fetched, we release the semaphore + # and continue the loop. So it will check if finished event was set. + if not done: + self.sem_prefetch.release() + continue + # We're done, so now we need to check + # whether task has returned an error. + message = current_message.result() + current_message = None + fetched_tasks += 1 + await queue.put(message) + # Custom hooks for OTel and any future instrumentations + for middleware in reversed(self.broker.middlewares): + if hasattr(middleware, "on_prefetch_queue_add"): + await maybe_awaitable( + middleware.on_prefetch_queue_add(), # type: ignore + ) + except (asyncio.CancelledError, StopAsyncIteration): + break + finally: + # We don't want to fetch new messages if we are shutting down. + logger.info("Stopping prefetching messages...") + # Short window to deliver, then forward or cancel. + if current_message is not None: + await asyncio.wait({current_message}, timeout=0.3) + if not current_message.done(): + current_message.cancel() + elif ( + not current_message.cancelled() + and current_message.exception() is None + ): + await queue.put(current_message.result()) + await queue.put(QUEUE_DONE) + self.sem_prefetch.release() async def runner( self, diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index eeb29c11..f3189b91 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -600,3 +600,28 @@ async def test_no_semaphore_without_max_async_tasks() -> None: """Test that semaphore is None when max_async_tasks is not set.""" receiver = get_receiver(max_async_tasks=None) assert receiver.sem is None + + +async def test_prefetcher_does_not_pop_message_past_max_tasks() -> None: + """Test not pulling a message without the intention of running it.""" + broker = AsyncQueueBroker() + + @broker.task + async def noop() -> None: + return None + + for _ in range(6): + await noop.kiq() + + assert broker.queue.qsize() == 6 + + receiver = Receiver( + broker, + executor=ThreadPoolExecutor(max_workers=1), + max_async_tasks=1, + max_tasks_to_execute=5, + ) + + await receiver.listen(asyncio.Event()) + + assert broker.queue.qsize() == 1