709 lines
25 KiB
Python
709 lines
25 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import functools
|
|
import logging
|
|
import random
|
|
import urllib.parse
|
|
import warnings
|
|
from types import TracebackType
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Callable,
|
|
Generator,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
cast,
|
|
)
|
|
|
|
from ..datastructures import Headers, HeadersLike
|
|
from ..exceptions import (
|
|
InvalidHandshake,
|
|
InvalidHeader,
|
|
InvalidMessage,
|
|
InvalidStatusCode,
|
|
NegotiationError,
|
|
RedirectHandshake,
|
|
SecurityError,
|
|
)
|
|
from ..extensions import ClientExtensionFactory, Extension
|
|
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
|
from ..headers import (
|
|
build_authorization_basic,
|
|
build_extension,
|
|
build_host,
|
|
build_subprotocol,
|
|
parse_extension,
|
|
parse_subprotocol,
|
|
validate_subprotocols,
|
|
)
|
|
from ..http import USER_AGENT
|
|
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
|
|
from ..uri import WebSocketURI, parse_uri
|
|
from .handshake import build_request, check_response
|
|
from .http import read_response
|
|
from .protocol import WebSocketCommonProtocol
|
|
|
|
|
|
# Names exported by ``from ... import *``: the public API of this module.
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
|
|
|
|
|
|
class WebSocketClientProtocol(WebSocketCommonProtocol):
    """
    WebSocket client connection.

    :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
    coroutines for receiving and sending messages.

    It supports asynchronous iteration to receive incoming messages::

        async for message in websocket:
            await process(message)

    The iterator exits normally when the connection is closed with close code
    1000 (OK) or 1001 (going away). It raises
    a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
    is closed with any other code.

    See :func:`connect` for the documentation of ``logger``, ``origin``,
    ``extensions``, ``subprotocols``, and ``extra_headers``.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    """

    is_client = True
    side = "client"

    def __init__(
        self,
        *,
        logger: Optional[LoggerLike] = None,
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        **kwargs: Any,
    ) -> None:
        # Default to the module's dedicated logger when none is supplied.
        if logger is None:
            logger = logging.getLogger("websockets.client")
        super().__init__(logger=logger, **kwargs)
        # Handshake settings, stored so that Connect can pass them back to
        # handshake() once the transport is connected.
        self.origin = origin
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers

    def write_http_request(self, path: str, headers: Headers) -> None:
        """
        Write request line and headers to the HTTP request.

        Also records ``path`` and ``headers`` on the instance as
        ``self.path`` and ``self.request_headers``.

        """
        self.path = path
        self.request_headers = headers

        if self.debug:
            self.logger.debug("> GET %s HTTP/1.1", path)
            for key, value in headers.raw_items():
                self.logger.debug("> %s: %s", key, value)

        # Since the path and headers only contain ASCII characters,
        # we can keep this simple.
        request = f"GET {path} HTTP/1.1\r\n"
        request += str(headers)

        self.transport.write(request.encode())

    async def read_http_response(self) -> Tuple[int, Headers]:
        """
        Read status line and headers from the HTTP response.

        If the response contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        Returns:
            The status code and the response headers; the headers are also
            stored in ``self.response_headers``.

        Raises:
            InvalidMessage: if the HTTP message is malformed or isn't an
                HTTP/1.1 GET response.

        """
        try:
            status_code, reason, headers = await read_response(self.reader)
        # Remove this branch when dropping support for Python < 3.8
        # because CancelledError no longer inherits Exception.
        except asyncio.CancelledError:  # pragma: no cover
            raise
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP response") from exc

        if self.debug:
            self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
            for key, value in headers.raw_items():
                self.logger.debug("< %s: %s", key, value)

        self.response_headers = headers

        return status_code, self.response_headers

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Optional[Sequence[ClientExtensionFactory]],
    ) -> List[Extension]:
        """
        Handle the Sec-WebSocket-Extensions HTTP response header.

        Check that each extension is supported, as well as its parameters.

        Return the list of accepted extensions.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
        connection.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension accepted by
        the server, we check for a match with each extension available in the
        client configuration. If no match is found, an exception is raised.

        If several variants of the same extension are accepted by the server,
        it may be configured several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        """
        accepted_extensions: List[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values:

            if available_extensions is None:
                raise InvalidHandshake("no extensions supported")

            # Flatten the parsed values of every header occurrence; a flat
            # comprehension avoids the quadratic cost of sum(lists, []).
            parsed_header_values: List[ExtensionHeader] = [
                extension_header
                for header_value in header_values
                for extension_header in parse_extension(header_value)
            ]

            for name, response_params in parsed_header_values:

                for extension_factory in available_extensions:

                    # Skip non-matching extensions based on their name.
                    if extension_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        extension = extension_factory.process_response_params(
                            response_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the server sent. Fail the connection.
                else:
                    raise NegotiationError(
                        f"Unsupported extension: "
                        f"name = {name}, params = {response_params}"
                    )

        return accepted_extensions

    @staticmethod
    def process_subprotocol(
        headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
    ) -> Optional[Subprotocol]:
        """
        Handle the Sec-WebSocket-Protocol HTTP response header.

        Check that it contains exactly one supported subprotocol.

        Return the selected subprotocol, or :obj:`None` when the server
        didn't send the header.

        """
        subprotocol: Optional[Subprotocol] = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values:

            if available_subprotocols is None:
                raise InvalidHandshake("no subprotocols supported")

            # Flatten the parsed values of every header occurrence.
            parsed_header_values: List[Subprotocol] = [
                parsed
                for header_value in header_values
                for parsed in parse_subprotocol(header_value)
            ]

            if len(parsed_header_values) > 1:
                subprotocols = ", ".join(parsed_header_values)
                raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")

            subprotocol = parsed_header_values[0]

            if subprotocol not in available_subprotocols:
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")

        return subprotocol

    async def handshake(
        self,
        wsuri: WebSocketURI,
        origin: Optional[Origin] = None,
        available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        available_subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
    ) -> None:
        """
        Perform the client side of the opening handshake.

        Args:
            wsuri: URI of the WebSocket server.
            origin: value of the ``Origin`` header.
            available_extensions: list of supported extensions, in the order
                in which they should be tried.
            available_subprotocols: list of supported subprotocols, in order
                of decreasing preference.
            extra_headers: arbitrary HTTP headers to add to the request.

        Raises:
            InvalidHandshake: if the handshake fails.

        """
        request_headers = Headers()

        request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)

        if wsuri.user_info:
            request_headers["Authorization"] = build_authorization_basic(
                *wsuri.user_info
            )

        if origin is not None:
            request_headers["Origin"] = origin

        # build_request() adds the handshake headers and returns the key
        # later needed to validate the server's response.
        key = build_request(request_headers)

        if available_extensions is not None:
            extensions_header = build_extension(
                [
                    (extension_factory.name, extension_factory.get_request_params())
                    for extension_factory in available_extensions
                ]
            )
            request_headers["Sec-WebSocket-Extensions"] = extensions_header

        if available_subprotocols is not None:
            protocol_header = build_subprotocol(available_subprotocols)
            request_headers["Sec-WebSocket-Protocol"] = protocol_header

        # Use the headers passed to this method; previously the argument was
        # ignored in favor of self.extra_headers.
        if extra_headers is not None:
            request_headers.update(extra_headers)

        request_headers.setdefault("User-Agent", USER_AGENT)

        self.write_http_request(wsuri.resource_name, request_headers)

        status_code, response_headers = await self.read_http_response()
        if status_code in (301, 302, 303, 307, 308):
            if "Location" not in response_headers:
                raise InvalidHeader("Location")
            raise RedirectHandshake(response_headers["Location"])
        elif status_code != 101:
            raise InvalidStatusCode(status_code, response_headers)

        check_response(response_headers, key)

        self.extensions = self.process_extensions(
            response_headers, available_extensions
        )

        self.subprotocol = self.process_subprotocol(
            response_headers, available_subprotocols
        )

        self.connection_open()
|
|
|
|
|
|
class Connect:
    """
    Connect to the WebSocket server at ``uri``.

    Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
    can then be used to send and receive messages.

    :func:`connect` can be used as an asynchronous context manager::

        async with websockets.connect(...) as websocket:
            ...

    The connection is closed automatically when exiting the context.

    :func:`connect` can be used as an infinite asynchronous iterator to
    reconnect automatically on errors::

        async for websocket in websockets.connect(...):
            try:
                ...
            except websockets.ConnectionClosed:
                continue

    The connection is closed automatically after each iteration of the loop.

    If an error occurs while establishing the connection, :func:`connect`
    retries with exponential backoff. After a random initial delay of up to
    five seconds, the backoff delay starts at three seconds and increases up
    to one minute.

    If an error occurs in the body of the loop, you can handle the exception
    and :func:`connect` will reconnect with the next iteration; or you can
    let the exception bubble up and break out of the loop. This lets you
    decide which errors trigger a reconnection and which errors are fatal.

    Args:
        uri: URI of the WebSocket server.
        create_protocol: factory for the :class:`asyncio.Protocol` managing
            the connection; defaults to :class:`WebSocketClientProtocol`; may
            be set to a wrapper or a subclass to customize connection handling.
        logger: logger for this connection;
            defaults to ``logging.getLogger("websockets.client")``;
            see the :doc:`logging guide <../topics/logging>` for details.
        compression: shortcut that enables the "permessage-deflate" extension
            by default; may be set to :obj:`None` to disable compression;
            see the :doc:`compression guide <../topics/compression>` for details.
        origin: value of the ``Origin`` header. This is useful when connecting
            to a server that validates the ``Origin`` header to defend against
            Cross-Site WebSocket Hijacking attacks.
        extensions: list of supported extensions, in the order in which they
            should be tried.
        subprotocols: list of supported subprotocols, in order of decreasing
            preference.
        extra_headers: arbitrary HTTP headers to add to the request.
        open_timeout: timeout for opening the connection in seconds;
            :obj:`None` to disable the timeout.

    See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
    documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.

    Any other keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_connection` method.

    For example:

    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
      settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
      provided, a TLS context is created
      with :func:`~ssl.create_default_context`.

    * You can set ``host`` and ``port`` to connect to a different host and
      port from those found in ``uri``. This only changes the destination of
      the TCP connection. The host name from ``uri`` is still used in the TLS
      handshake for secure connections and in the ``Host`` header.

    Returns:
        WebSocketClientProtocol: WebSocket connection.

    Raises:
        InvalidURI: if ``uri`` isn't a valid WebSocket URI.
        InvalidHandshake: if the opening handshake fails.
        ~asyncio.TimeoutError: if the opening handshake times out.

    """

    # Upper bound on the number of HTTP redirects followed while connecting.
    MAX_REDIRECTS_ALLOWED = 10

    def __init__(
        self,
        uri: str,
        *,
        create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None,
        logger: Optional[LoggerLike] = None,
        compression: Optional[str] = "deflate",
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        open_timeout: Optional[float] = 10,
        ping_interval: Optional[float] = 20,
        ping_timeout: Optional[float] = 20,
        close_timeout: Optional[float] = None,
        max_size: Optional[int] = 2**20,
        max_queue: Optional[int] = 2**5,
        read_limit: int = 2**16,
        write_limit: int = 2**16,
        **kwargs: Any,
    ) -> None:
        # Backwards compatibility: close_timeout used to be called timeout.
        timeout: Optional[float] = kwargs.pop("timeout", None)
        if timeout is None:
            timeout = 10
        else:
            warnings.warn("rename timeout to close_timeout", DeprecationWarning)
        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None)
        if klass is None:
            klass = WebSocketClientProtocol
        else:
            warnings.warn("rename klass to create_protocol", DeprecationWarning)
        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        # Backwards compatibility: recv() used to return None on closed connections
        legacy_recv: bool = kwargs.pop("legacy_recv", False)

        # Backwards compatibility: the loop parameter used to be supported.
        _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
        if _loop is None:
            loop = asyncio.get_event_loop()
        else:
            loop = _loop
            warnings.warn("remove loop argument", DeprecationWarning)

        wsuri = parse_uri(uri)
        if wsuri.secure:
            kwargs.setdefault("ssl", True)
        elif kwargs.get("ssl") is not None:
            raise ValueError(
                "connect() received a ssl argument for a ws:// URI, "
                "use a wss:// URI to enable TLS"
            )

        if compression == "deflate":
            extensions = enable_client_permessage_deflate(extensions)
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        if subprotocols is not None:
            validate_subprotocols(subprotocols)

        # Protocol factory handed to the event loop; handle_redirect() relies
        # on it being a functools.partial so it can rewrite its keywords.
        factory = functools.partial(
            create_protocol,
            logger=logger,
            origin=origin,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            host=wsuri.host,
            port=wsuri.port,
            secure=wsuri.secure,
            legacy_recv=legacy_recv,
            loop=_loop,
        )

        if kwargs.pop("unix", False):
            path: Optional[str] = kwargs.pop("path", None)
            create_connection = functools.partial(
                loop.create_unix_connection, factory, path, **kwargs
            )
        else:
            host: Optional[str]
            port: Optional[int]
            if kwargs.get("sock") is None:
                host, port = wsuri.host, wsuri.port
            else:
                # If sock is given, host and port shouldn't be specified.
                host, port = None, None
            # If host and port are given, override values from the URI.
            host = kwargs.pop("host", host)
            port = kwargs.pop("port", port)
            create_connection = functools.partial(
                loop.create_connection, factory, host, port, **kwargs
            )

        self.open_timeout = open_timeout
        if logger is None:
            logger = logging.getLogger("websockets.client")
        self.logger = logger

        # This is a coroutine function.
        self._create_connection = create_connection
        self._uri = uri
        self._wsuri = wsuri

    def handle_redirect(self, uri: str) -> None:
        """
        Update the state of this instance to connect to ``uri``.

        ``uri`` may be relative; it is resolved against the current URI.

        Raises:
            SecurityError: if the redirect downgrades from wss:// to ws://.

        """
        old_uri = self._uri
        old_wsuri = self._wsuri
        new_uri = urllib.parse.urljoin(old_uri, uri)
        new_wsuri = parse_uri(new_uri)

        # Forbid TLS downgrade.
        if old_wsuri.secure and not new_wsuri.secure:
            raise SecurityError("redirect from WSS to WS")

        # A change of scheme is a change of origin even when the host and
        # port are unchanged, because the connection must switch to TLS.
        same_origin = (
            old_wsuri.secure == new_wsuri.secure
            and old_wsuri.host == new_wsuri.host
            and old_wsuri.port == new_wsuri.port
        )

        # Rewrite the secure, host, and port arguments for cross-origin
        # redirects. This preserves connection overrides with the host and
        # port arguments if the redirect points to the same origin.
        # NOTE(review): for Unix socket connections, create_connection's
        # positional arguments are (factory, path); a cross-origin redirect
        # replaces them with (factory, host, port) as before.
        if not same_origin:
            factory = self._create_connection.args[0]
            connection_kwargs = dict(self._create_connection.keywords)

            # Support TLS upgrade: a ws:// -> wss:// redirect must open the
            # new connection over TLS, otherwise the opening handshake would
            # be sent in plain text to a TLS endpoint.
            if (
                not old_wsuri.secure
                and new_wsuri.secure
                and connection_kwargs.get("ssl") is None
            ):
                connection_kwargs["ssl"] = True

            # Replace the secure, host, and port arguments of the factory.
            factory = functools.partial(
                factory.func,
                *factory.args,
                **dict(
                    factory.keywords,
                    host=new_wsuri.host,
                    port=new_wsuri.port,
                    secure=new_wsuri.secure,
                ),
            )
            # Replace the host and port arguments of create_connection.
            self._create_connection = functools.partial(
                self._create_connection.func,
                *(factory, new_wsuri.host, new_wsuri.port),
                **connection_kwargs,
            )

        # Set the new WebSocket URI. This suffices for same-origin redirects.
        self._uri = new_uri
        self._wsuri = new_wsuri

    # async for ... in connect(...):

    BACKOFF_MIN = 1.92
    BACKOFF_MAX = 60.0
    BACKOFF_FACTOR = 1.618
    BACKOFF_INITIAL = 5

    async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
        """
        Yield connections forever, reconnecting with backoff on errors.

        """
        backoff_delay = self.BACKOFF_MIN
        while True:
            try:
                async with self as protocol:
                    yield protocol
            # Remove this branch when dropping support for Python < 3.8
            # because CancelledError no longer inherits Exception.
            except asyncio.CancelledError:  # pragma: no cover
                raise
            except Exception:
                # Add a random initial delay between 0 and 5 seconds.
                # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
                if backoff_delay == self.BACKOFF_MIN:
                    initial_delay = random.random() * self.BACKOFF_INITIAL
                    self.logger.info(
                        "! connect failed; reconnecting in %.1f seconds",
                        initial_delay,
                        exc_info=True,
                    )
                    await asyncio.sleep(initial_delay)
                else:
                    self.logger.info(
                        "! connect failed again; retrying in %d seconds",
                        int(backoff_delay),
                        exc_info=True,
                    )
                    await asyncio.sleep(int(backoff_delay))
                # Increase delay with truncated exponential backoff.
                backoff_delay = backoff_delay * self.BACKOFF_FACTOR
                backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
                continue
            else:
                # Connection succeeded - reset backoff delay
                backoff_delay = self.BACKOFF_MIN

    # async with connect(...) as ...:

    async def __aenter__(self) -> WebSocketClientProtocol:
        return await self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.protocol.close()

    # ... = await connect(...)

    def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl_timeout__().__await__()

    async def __await_impl_timeout__(self) -> WebSocketClientProtocol:
        # Apply open_timeout to the whole connection + handshake sequence,
        # including any redirects.
        return await asyncio.wait_for(self.__await_impl__(), self.open_timeout)

    async def __await_impl__(self) -> WebSocketClientProtocol:
        for _ in range(self.MAX_REDIRECTS_ALLOWED):
            _transport, protocol = await self._create_connection()
            protocol = cast(WebSocketClientProtocol, protocol)

            try:
                await protocol.handshake(
                    self._wsuri,
                    origin=protocol.origin,
                    available_extensions=protocol.available_extensions,
                    available_subprotocols=protocol.available_subprotocols,
                    extra_headers=protocol.extra_headers,
                )
            except RedirectHandshake as exc:
                # Drop this connection and retry against the new location.
                protocol.fail_connection()
                await protocol.wait_closed()
                self.handle_redirect(exc.uri)
            # Avoid leaking a connected socket when the handshake fails.
            except (Exception, asyncio.CancelledError):
                protocol.fail_connection()
                await protocol.wait_closed()
                raise
            else:
                self.protocol = protocol
                return protocol

        # Every attempt ended in a redirect.
        raise SecurityError("too many redirects")

    # ... = yield from connect(...) - remove when dropping Python < 3.10

    __iter__ = __await__
|
|
|
|
|
|
# Public, function-style name for the Connect class. Connect implements
# __await__, __aenter__/__aexit__ and __aiter__, so ``await connect(...)``,
# ``async with connect(...)`` and ``async for ... in connect(...)`` all
# construct a Connect instance.
connect = Connect
|
|
|
|
|
|
def unix_connect(
    path: Optional[str] = None,
    uri: str = "ws://localhost/",
    **kwargs: Any,
) -> Connect:
    """
    Connect to a WebSocket server over a Unix socket, like :func:`connect`.

    Instead of opening a TCP connection, this relies on the event loop's
    :meth:`~asyncio.loop.create_unix_connection` method, so it only works
    on Unix. Its main use is debugging servers that listen on Unix sockets.

    Args:
        path: file system path to the Unix socket.
        uri: URI of the WebSocket server; its host goes into the ``Host``
            header and, for secure connections, into the TLS handshake.

    """
    # unix=True tells Connect to call create_unix_connection with ``path``
    # rather than connecting to the host and port found in ``uri``.
    return connect(uri, unix=True, path=path, **kwargs)
|