waitgroup.py, waitgroup_test.py

A simple wait group for asyncio futures, written before discovering asyncio.TaskGroup.

waitgroup.py

import asyncio
import logging
from typing import Any, Callable, Coroutine, List, Optional

logger = logging.getLogger(__name__)


class Waitgroup:
    """Collects futures and waits for all of them to resolve together.

    Futures are either created bare with :meth:`create` (the caller resolves
    them) or created implicitly by :meth:`schedule`, which runs a coroutine
    function in a task and resolves the future with its outcome.
    """

    def __init__(self, timeout: Optional[float] = None):
        """Create an empty group.

        Args:
            timeout: Default number of seconds :meth:`wait` waits when it is
                called without its own timeout. ``None`` means no limit.
        """
        self.timeout = timeout
        # Every future handed out by create()/schedule(), in creation order.
        self.futures: List[asyncio.Future] = []
        # Strong references to the tasks started by schedule(). The event loop
        # only keeps weak references to tasks, so an otherwise unreferenced
        # task could be garbage-collected before it finishes.
        self._tasks = set()

    def create(self) -> asyncio.Future:
        """Create a new future, register it with the group and return it."""
        future = asyncio.Future()
        self.futures.append(future)
        return future

    def cancel(self):
        """Cancels all futures managed by this Waitgroup.

        Futures that are already done are left untouched. Tasks started by
        :meth:`schedule` are cancelled as a consequence of their future being
        cancelled.
        """
        for fut in self.futures:
            if not fut.done():
                fut.cancel()

    async def wait(self, timeout: Optional[float] = None):
        """Waits for all saved futures to resolve in parallel, with an optional timeout.

        Args:
            timeout: Seconds to wait. ``None`` falls back to the timeout given
                to the constructor; ``0`` is honoured and returns right away.

        Raises:
            Exception: The exception of the first failed future (in creation
                order). It is logged and every unfinished future is cancelled
                before it is re-raised.

        Futures still pending when the timeout expires are cancelled.
        Futures that were cancelled are skipped, not treated as failures.
        """
        if not self.futures:
            return

        # Compare against None explicitly: ``timeout or self.timeout`` would
        # silently replace an explicit timeout of 0 with the default.
        effective_timeout = self.timeout if timeout is None else timeout

        # wait for all futures to complete
        done, pending = await asyncio.wait(
            self.futures, timeout=effective_timeout, return_when=asyncio.ALL_COMPLETED
        )

        # Walk the futures in creation order (``done`` is an unordered set) so
        # the exception that gets raised is deterministic.
        for fut in self.futures:
            # Future.exception() raises CancelledError on a cancelled future,
            # so those must be skipped rather than inspected.
            if fut not in done or fut.cancelled():
                continue
            exception = fut.exception()
            if exception is not None:
                logger.error(f"Future ended with exception: {exception}")
                self.cancel()  # cancel all remaining futures
                raise exception

        # cancel any pending futures if timed out
        for fut in pending:
            fut.cancel()

    def schedule(self, coro: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
        """Schedules the execution of a coroutine function which resolves a future.

        Args:
            coro: Coroutine function to call.
            *args: Positional arguments passed to ``coro``.
            **kwargs: Keyword arguments passed to ``coro``.

        The new future receives the coroutine's return value or the exception
        it raised. If the future is cancelled first, the running task is
        cancelled too.
        """
        future = self.create()

        async def task_wrapper():
            try:
                result = await coro(*args, **kwargs)
            except asyncio.CancelledError:
                # The task was cancelled; mirror that on the future so wait()
                # does not block forever on a future nobody will resolve.
                future.cancel()
                raise
            except Exception as e:
                if future.done():
                    # The future was already cancelled, so nobody can receive
                    # this error through it; let the task fail so the event
                    # loop reports it instead of dropping it.
                    raise
                future.set_exception(e)
            else:
                # A future cancelled in the meantime cannot take a result.
                if not future.done():
                    future.set_result(result)

        task = asyncio.create_task(task_wrapper())
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

        def stop_task_on_cancel(fut):
            # Cancelling the future (cancel() or a wait() timeout) must also
            # stop the work behind it.
            if fut.cancelled() and not task.done():
                task.cancel()

        future.add_done_callback(stop_task_on_cancel)

waitgroup_test.py

import asyncio
from unittest.mock import patch

import pytest
import Waitgroup, logger


# Test helper: schedules on ``wg`` a coroutine that sleeps for ``delay`` seconds
# and then raises ``exception`` if one is given, otherwise returns ``result``.
def schedule_run(wg: Waitgroup, delay, result=None, exception=None):
    async def delayed_outcome(result=None, exception=None):
        await asyncio.sleep(delay)
        if not exception:
            return result
        raise exception

    outcome = {"result": result, "exception": exception}
    wg.schedule(delayed_outcome, **outcome)


@pytest.mark.asyncio
async def test_waitgroup_initialization():
    """A new Waitgroup stores its timeout and starts with no futures."""
    default_group = Waitgroup()
    assert len(default_group.futures) == 0
    assert default_group.timeout is None

    expected_timeout = 10
    timed_group = Waitgroup(timeout=expected_timeout)
    assert timed_group.timeout == expected_timeout


@pytest.mark.asyncio
async def test_future_creation():
    """create() returns an asyncio.Future and registers it with the group."""
    group = Waitgroup()
    created = group.create()
    registered_count = len(group.futures)
    assert isinstance(created, asyncio.Future)
    assert registered_count == 1


@pytest.mark.asyncio
async def test_cancel_all_futures():
    """cancel() cancels every unresolved future in the group."""
    group = Waitgroup()
    created = [group.create(), group.create()]
    group.cancel()
    for fut in created:
        assert fut.cancelled()


@pytest.mark.asyncio
async def test_wait_no_futures():
    """Waiting on an empty Waitgroup returns without raising."""
    empty_group = Waitgroup()
    # Nothing was registered, so this must simply return.
    await empty_group.wait()


@pytest.mark.asyncio
async def test_wait_futures_complete_before_timeout():
    """Futures that finish within the timeout end up done and not cancelled."""
    group = Waitgroup()
    scheduled = 0
    while scheduled < 3:
        schedule_run(group, 0.1, result=True)
        scheduled += 1

    await group.wait(timeout=1)

    for fut in group.futures:
        assert fut.done()
        assert not fut.cancelled()


@pytest.mark.asyncio
async def test_wait_futures_timeout():
    """Futures still pending when the default timeout expires get cancelled."""
    group = Waitgroup(timeout=0.1)
    # Bare futures that nobody resolves, so wait() can only time out.
    unresolved = [group.create() for _ in range(3)]
    assert len(unresolved) == 3

    await group.wait()

    for fut in group.futures:
        assert fut.cancelled()


@pytest.mark.asyncio
async def test_wait_exceptions_are_logged_and_raised():
    """An exception from a scheduled coroutine is logged once and re-raised by wait()."""
    group = Waitgroup()
    failure = Exception("Test Exception")

    with patch.object(logger, "error") as error_mock:
        schedule_run(group, 0.1, exception=failure)
        with pytest.raises(Exception) as raised:
            await group.wait()
        message = str(raised.value)
        assert message == "Test Exception"

    # Checked after the patch is lifted; the mock still holds its call record.
    error_mock.assert_called_once()


@pytest.mark.asyncio
async def test_schedule_coroutine():
    """schedule() resolves the future with the coroutine's result or its exception."""
    # Success path: the future carries the coroutine's return value.
    ok_group = Waitgroup()
    schedule_run(ok_group, 0.1, result="success")
    await ok_group.wait()
    for fut in ok_group.futures:
        assert fut.done()
        assert fut.result() == "success"

    # Failure path: wait() re-raises what the coroutine raised.
    failing_group = Waitgroup()
    schedule_run(failing_group, 0.1, exception=Exception("Failure"))
    with pytest.raises(Exception) as raised:
        await failing_group.wait()
    assert "Failure" in str(raised.value)