152 lines
5.4 KiB
Python
152 lines
5.4 KiB
Python
from contextlib import contextmanager
|
|
from inspect import isasyncgenfunction, iscoroutinefunction
|
|
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, cast
|
|
|
|
import pytest
|
|
import sniffio
|
|
|
|
from ._core._eventloop import get_all_backends, get_asynclib
|
|
from .abc import TestRunner
|
|
|
|
if TYPE_CHECKING:
|
|
from _pytest.config import Config
|
|
|
|
# Test runner currently entered by get_runner(); None while no runner is active.
_current_runner: Optional[TestRunner] = None
|
|
|
|
|
|
def extract_backend_and_options(backend: object) -> Tuple[str, Dict[str, Any]]:
    """
    Split an ``anyio_backend`` value into a backend name and an options dict.

    :param backend: either a backend name, or a 2-tuple of (backend name, options dict)
    :return: a tuple of (backend name, backend options); a bare name gets an empty dict
    :raises TypeError: if ``backend`` has neither of the accepted shapes

    """
    if isinstance(backend, str):
        return backend, {}

    if isinstance(backend, tuple) and len(backend) == 2:
        name, options = backend
        if isinstance(name, str) and isinstance(options, dict):
            # Hand back the caller's own tuple; the cast only informs the type checker
            return cast(Tuple[str, Dict[str, Any]], backend)

    raise TypeError('anyio_backend must be either a string or tuple of (string, dict)')
|
|
|
|
|
|
@contextmanager
def get_runner(backend_name: str, backend_options: Dict[str, Any]) -> Iterator[TestRunner]:
    """
    Yield a test runner, reusing the module-wide one if it is already set.

    When no runner is active, the backend module's ``TestRunner`` is created with the given
    options, stored in ``_current_runner`` for the duration of the ``with`` block and cleared
    again afterwards.

    :param backend_name: name of the backend, passed to ``get_asynclib()``
    :param backend_options: keyword arguments for the backend's ``TestRunner`` (a falsy value
        is treated as no options)
    :return: a context manager yielding the test runner

    """
    global _current_runner

    # Nested use: hand out the already active runner and leave its lifecycle alone
    if _current_runner:
        yield _current_runner
        return

    asynclib = get_asynclib(backend_name)
    library_cvar = sniffio.current_async_library_cvar
    reset_token = None
    if library_cvar.get(None) is None:
        # Since we're in control of the event loop, we can cache the name of the async library
        reset_token = library_cvar.set(backend_name)

    try:
        runner_options = backend_options or {}
        with asynclib.TestRunner(**runner_options) as active_runner:
            _current_runner = active_runner
            yield active_runner
    finally:
        _current_runner = None
        # Only undo the context variable if it was set by this call
        if reset_token:
            library_cvar.reset(reset_token)
|
|
|
|
|
|
def pytest_configure(config: "Config") -> None:
    """
    Add the description of the ``anyio`` marker to the ``markers`` ini value.

    :param config: the pytest configuration object

    """
    marker_description = ('anyio: mark the (coroutine function) test to be run '
                          'asynchronously via anyio.')
    config.addinivalue_line('markers', marker_description)
|
|
|
|
|
|
def pytest_fixture_setup(fixturedef: Any, request: Any) -> None:
    """
    Swap async fixture functions for a generator wrapper that runs them via a test runner.

    This is only applied to coroutine functions and async generator functions, and only in
    requests that involve the ``anyio_backend`` fixture. If the fixture does not itself take
    ``anyio_backend``, the name is appended to the fixture's argument names so the wrapper
    still receives it.

    :param fixturedef: the fixture definition whose ``func`` and ``argnames`` may be replaced
    :param request: the fixture request; its ``fixturenames`` are inspected

    """
    func = fixturedef.func
    if not (isasyncgenfunction(func) or iscoroutinefunction(func)):
        return

    if 'anyio_backend' not in request.fixturenames:
        return

    # Decides whether the wrapper forwards anyio_backend to the original fixture function
    has_backend_arg = 'anyio_backend' in fixturedef.argnames

    def wrapper(*args, anyio_backend, **kwargs):  # type: ignore[no-untyped-def]
        backend_name, backend_options = extract_backend_and_options(anyio_backend)
        if has_backend_arg:
            kwargs['anyio_backend'] = anyio_backend

        with get_runner(backend_name, backend_options) as runner:
            if not isasyncgenfunction(func):
                # Plain coroutine function: its result is the fixture value
                yield runner.call(func, *args, **kwargs)
                return

            agen = func(*args, **kwargs)
            try:
                fixture_value = runner.call(agen.asend, None)
            except StopAsyncIteration:
                raise RuntimeError('Async generator did not yield')

            yield fixture_value

            # Teardown: resume the generator once more; it is expected to finish here
            try:
                runner.call(agen.asend, None)
            except StopAsyncIteration:
                return

            runner.call(agen.aclose)
            raise RuntimeError('Async generator fixture did not stop')

    fixturedef.func = wrapper
    if not has_backend_arg:
        fixturedef.argnames += ('anyio_backend',)
|
|
|
|
|
|
@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None:
    """
    Make ``anyio``-marked coroutine test functions use the ``anyio_backend`` fixture.

    :param collector: the collector; asked whether ``obj`` is a test function and for its
        closest ``anyio`` marker
    :param name: the name under which ``obj`` was collected
    :param obj: the candidate test function

    """
    if not collector.istestfunction(obj, name):
        return

    # For hypothesis tests, the coroutine check applies to the wrapped inner test
    if hasattr(obj, 'hypothesis'):
        candidate = obj.hypothesis.inner_test
    else:
        candidate = obj

    if not iscoroutinefunction(candidate):
        return

    closest_marker = collector.get_closest_marker('anyio')
    own_markers = getattr(obj, 'pytestmark', ())
    # own_markers is only iterated when the collector has no anyio marker
    if closest_marker or any(mark.name == 'anyio' for mark in own_markers):
        pytest.mark.usefixtures('anyio_backend')(obj)
|
|
|
|
|
|
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem: Any) -> Optional[bool]:
    """
    Run coroutine test functions through a test runner when ``anyio_backend`` is in use.

    :param pyfuncitem: the test item about to be called
    :return: ``True`` if the test function was run here, ``None`` otherwise (no truthy
        ``anyio_backend`` funcarg, a hypothesis test, or not a coroutine function)

    """
    backend = pyfuncitem.funcargs.get('anyio_backend')
    if not backend:
        return None

    backend_name, backend_options = extract_backend_and_options(backend)

    # The name of this nested function matters: its __qualname__ is how an already
    # wrapped hypothesis test is recognized below
    def run_with_hypothesis(**kwargs: Any) -> None:
        with get_runner(backend_name, backend_options) as runner:
            runner.call(original_func, **kwargs)

    test_func = pyfuncitem.obj
    if hasattr(test_func, 'hypothesis'):
        # Wrap the inner test function unless it's already wrapped
        original_func = test_func.hypothesis.inner_test
        already_wrapped = original_func.__qualname__ == run_with_hypothesis.__qualname__
        if not already_wrapped and iscoroutinefunction(original_func):
            test_func.hypothesis.inner_test = run_with_hypothesis

        return None

    if not iscoroutinefunction(test_func):
        return None

    # Pass only the fixtures that the test function itself declares as arguments
    all_funcargs = pyfuncitem.funcargs
    call_kwargs = {arg: all_funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
    with get_runner(backend_name, backend_options) as runner:
        runner.call(test_func, **call_kwargs)

    return True
|
|
|
|
|
|
@pytest.fixture(params=get_all_backends())
def anyio_backend(request: Any) -> Any:
    """
    Return the backend selected for the current test.

    The fixture is parametrized with the values returned by ``get_all_backends()``.

    :param request: the fixture request carrying the current parameter
    :return: the current parameter, unchanged

    """
    selected_backend = request.param
    return selected_backend
|
|
|
|
|
|
@pytest.fixture
def anyio_backend_name(anyio_backend: Any) -> str:
    """
    Return the name part of the ``anyio_backend`` fixture value.

    :param anyio_backend: the backend value; a string is returned as is, anything else is
        indexed for its first item
    :return: the backend name

    """
    return anyio_backend if isinstance(anyio_backend, str) else anyio_backend[0]
|
|
|
|
|
|
@pytest.fixture
def anyio_backend_options(anyio_backend: Any) -> Dict[str, Any]:
    """
    Return the options part of the ``anyio_backend`` fixture value.

    :param anyio_backend: the backend value; a string carries no options, anything else is
        indexed for its second item
    :return: the backend options (an empty dict for a bare backend name)

    """
    if not isinstance(anyio_backend, str):
        return anyio_backend[1]

    return {}
|