821 lines
31 KiB
Python
821 lines
31 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import socket
|
||
import ssl as ssl_module
|
||
import traceback
|
||
import urllib.parse
|
||
from collections.abc import AsyncIterator, Generator, Sequence
|
||
from types import TracebackType
|
||
from typing import Any, Callable, Literal, cast
|
||
|
||
from ..client import ClientProtocol, backoff
|
||
from ..datastructures import Headers, HeadersLike
|
||
from ..exceptions import (
|
||
InvalidMessage,
|
||
InvalidProxyMessage,
|
||
InvalidProxyStatus,
|
||
InvalidStatus,
|
||
ProxyError,
|
||
SecurityError,
|
||
)
|
||
from ..extensions.base import ClientExtensionFactory
|
||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||
from ..headers import build_authorization_basic, build_host, validate_subprotocols
|
||
from ..http11 import USER_AGENT, Response
|
||
from ..protocol import CONNECTING, Event
|
||
from ..streams import StreamReader
|
||
from ..typing import LoggerLike, Origin, Subprotocol
|
||
from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri
|
||
from .compatibility import TimeoutError, asyncio_timeout
|
||
from .connection import Connection
|
||
|
||
|
||
__all__ = [
    "connect",
    "unix_connect",
    "ClientConnection",
]

# Upper bound on the number of connection attempts connect() makes while
# following HTTP redirects before raising SecurityError. It can be overridden
# with the WEBSOCKETS_MAX_REDIRECTS environment variable.
MAX_REDIRECTS = int(os.getenv("WEBSOCKETS_MAX_REDIRECTS", "10"))
|
||
|
||
|
||
class ClientConnection(Connection):
    """
    :mod:`asyncio` implementation of a WebSocket client connection.

    :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines
    for receiving and sending messages.

    It supports asynchronous iteration to receive messages::

        async for message in websocket:
            await process(message)

    The iterator exits normally when the connection is closed with close code
    1000 (OK) or 1001 (going away) or without a close code. It raises a
    :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
    closed with any other code.

    The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
    and ``write_limit`` arguments have the same meaning as in :func:`connect`.

    Args:
        protocol: Sans-I/O connection.

    """

    def __init__(
        self,
        protocol: ClientProtocol,
        *,
        ping_interval: float | None = 20,
        ping_timeout: float | None = 20,
        close_timeout: float | None = 10,
        max_queue: int | None | tuple[int | None, int | None] = 16,
        write_limit: int | tuple[int, int | None] = 2**15,
    ) -> None:
        # Narrow the type of the attribute that the base class assigns.
        self.protocol: ClientProtocol
        super().__init__(
            protocol,
            write_limit=write_limit,
            max_queue=max_queue,
            close_timeout=close_timeout,
            ping_timeout=ping_timeout,
            ping_interval=ping_interval,
        )
        # Resolved by process_event() when the handshake response arrives.
        self.response_rcvd: asyncio.Future[None] = self.loop.create_future()

    async def handshake(
        self,
        additional_headers: HeadersLike | None = None,
        user_agent_header: str | None = USER_AGENT,
    ) -> None:
        """
        Perform the opening handshake.

        Send the request built by the Sans-I/O protocol, extended with
        ``additional_headers`` and, unless already present, a ``User-Agent``
        header, then wait until a response arrives or the connection is lost.

        Raises:
            Exception: The error recorded by the Sans-I/O protocol in
                ``handshake_exc``, if any.

        """
        async with self.send_context(expected_state=CONNECTING):
            request = self.request = self.protocol.connect()
            headers = request.headers
            if additional_headers is not None:
                headers.update(additional_headers)
            if user_agent_header is not None:
                headers.setdefault("User-Agent", user_agent_header)
            self.protocol.send_request(request)

        waiters = [self.response_rcvd, self.connection_lost_waiter]
        await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)

        # The protocol records handshake_exc when the connection is lost before
        # a response, when the response can't be parsed, or when it fails the
        # handshake.
        handshake_exc = self.protocol.handshake_exc
        if handshake_exc is not None:
            raise handshake_exc

    def process_event(self, event: Event) -> None:
        """
        Process one incoming event.

        The first event is the handshake response; every later event is a
        frame, which is delegated to the base class.

        """
        if self.response is not None:
            super().process_event(event)
            return

        assert isinstance(event, Response)
        self.response = event
        self.response_rcvd.set_result(None)
|
||
|
||
|
||
def process_exception(exc: Exception) -> Exception | None:
|
||
"""
|
||
Determine whether a connection error is retryable or fatal.
|
||
|
||
When reconnecting automatically with ``async for ... in connect(...)``, if a
|
||
connection attempt fails, :func:`process_exception` is called to determine
|
||
whether to retry connecting or to raise the exception.
|
||
|
||
This function defines the default behavior, which is to retry on:
|
||
|
||
* :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network
|
||
errors;
|
||
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
|
||
502, 503, or 504: server or proxy errors.
|
||
|
||
All other exceptions are considered fatal.
|
||
|
||
You can change this behavior with the ``process_exception`` argument of
|
||
:func:`connect`.
|
||
|
||
Return :obj:`None` if the exception is retryable i.e. when the error could
|
||
be transient and trying to reconnect with the same parameters could succeed.
|
||
The exception will be logged at the ``INFO`` level.
|
||
|
||
Return an exception, either ``exc`` or a new exception, if the exception is
|
||
fatal i.e. when trying to reconnect will most likely produce the same error.
|
||
That exception will be raised, breaking out of the retry loop.
|
||
|
||
"""
|
||
# This catches python-socks' ProxyConnectionError and ProxyTimeoutError.
|
||
# Remove asyncio.TimeoutError when dropping Python < 3.11.
|
||
if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)):
|
||
return None
|
||
if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError):
|
||
return None
|
||
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
|
||
500, # Internal Server Error
|
||
502, # Bad Gateway
|
||
503, # Service Unavailable
|
||
504, # Gateway Timeout
|
||
]:
|
||
return None
|
||
return exc
|
||
|
||
|
||
# This is spelled in lower case because it's exposed as a callable in the API.
class connect:
    """
    Connect to the WebSocket server at ``uri``.

    This coroutine returns a :class:`ClientConnection` instance, which you can
    use to send and receive messages.

    :func:`connect` may be used as an asynchronous context manager::

        from websockets.asyncio.client import connect

        async with connect(...) as websocket:
            ...

    The connection is closed automatically when exiting the context.

    :func:`connect` can be used as an infinite asynchronous iterator to
    reconnect automatically on errors::

        async for websocket in connect(...):
            try:
                ...
            except websockets.exceptions.ConnectionClosed:
                continue

    If the connection fails with a transient error, it is retried with
    exponential backoff. If it fails with a fatal error, the exception is
    raised, breaking out of the loop.

    The connection is closed automatically after each iteration of the loop.

    Args:
        uri: URI of the WebSocket server.
        origin: Value of the ``Origin`` header, for servers that require it.
        extensions: List of supported extensions, in order in which they
            should be negotiated and run.
        subprotocols: List of supported subprotocols, in order of decreasing
            preference.
        compression: The "permessage-deflate" extension is enabled by default.
            Set ``compression`` to :obj:`None` to disable it. See the
            :doc:`compression guide <../../topics/compression>` for details.
        additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
            to the handshake request.
        user_agent_header: Value of the ``User-Agent`` request header.
            It defaults to ``"Python/x.y.z websockets/X.Y"``.
            Setting it to :obj:`None` removes the header.
        proxy: If a proxy is configured, it is used by default. Set ``proxy``
            to :obj:`None` to disable the proxy or to the address of a proxy
            to override the system configuration. See the :doc:`proxy docs
            <../../topics/proxies>` for details.
        process_exception: When reconnecting automatically, tell whether an
            error is transient or fatal. The default behavior is defined by
            :func:`process_exception`. Refer to its documentation for details.
        open_timeout: Timeout for opening the connection in seconds.
            :obj:`None` disables the timeout.
        ping_interval: Interval between keepalive pings in seconds.
            :obj:`None` disables keepalive.
        ping_timeout: Timeout for keepalive pings in seconds.
            :obj:`None` disables timeouts.
        close_timeout: Timeout for closing the connection in seconds.
            :obj:`None` disables the timeout.
        max_size: Maximum size of incoming messages in bytes.
            :obj:`None` disables the limit.
        max_queue: High-water mark of the buffer where frames are received.
            It defaults to 16 frames. The low-water mark defaults to ``max_queue
            // 4``. You may pass a ``(high, low)`` tuple to set the high-water
            and low-water marks. If you want to disable flow control entirely,
            you may set it to ``None``, although that's a bad idea.
        write_limit: High-water mark of write buffer in bytes. It is passed to
            :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
            to 32 KiB. You may pass a ``(high, low)`` tuple to set the
            high-water and low-water marks.
        logger: Logger for this client.
            It defaults to ``logging.getLogger("websockets.client")``.
            See the :doc:`logging guide <../../topics/logging>` for details.
        create_connection: Factory for the :class:`ClientConnection` managing
            the connection. Set it to a wrapper or a subclass to customize
            connection handling.

    Any other keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_connection` method.

    For example:

    * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings.
      When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS
      context is created with :func:`~ssl.create_default_context`.

    * You can set ``server_hostname`` to override the host name from ``uri`` in
      the TLS handshake.

    * You can set ``host`` and ``port`` to connect to a different host and port
      from those found in ``uri``. This only changes the destination of the TCP
      connection. The host name from ``uri`` is still used in the TLS handshake
      for secure connections and in the ``Host`` header.

    * You can set ``sock`` to provide a preexisting TCP socket. You may call
      :func:`socket.create_connection` (not to be confused with the event loop's
      :meth:`~asyncio.loop.create_connection` method) to create a suitable
      client socket and customize it.

    When using a proxy:

    * Prefix keyword arguments with ``proxy_`` for configuring TLS between the
      client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``,
      ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``.
    * Use the standard keyword arguments for configuring TLS between the proxy
      and the WebSocket server: ``ssl``, ``server_hostname``,
      ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``.
    * Other keyword arguments are used only for connecting to the proxy.

    Raises:
        InvalidURI: If ``uri`` isn't a valid WebSocket URI.
        InvalidProxy: If ``proxy`` isn't a valid proxy.
        OSError: If the TCP connection fails.
        InvalidHandshake: If the opening handshake fails.
        SecurityError: If too many redirects are followed or a redirect would
            downgrade from ``wss://`` to ``ws://``.
        TimeoutError: If the opening handshake times out.

    """

    def __init__(
        self,
        uri: str,
        *,
        # WebSocket
        origin: Origin | None = None,
        extensions: Sequence[ClientExtensionFactory] | None = None,
        subprotocols: Sequence[Subprotocol] | None = None,
        compression: str | None = "deflate",
        # HTTP
        additional_headers: HeadersLike | None = None,
        user_agent_header: str | None = USER_AGENT,
        proxy: str | Literal[True] | None = True,
        process_exception: Callable[[Exception], Exception | None] = process_exception,
        # Timeouts
        open_timeout: float | None = 10,
        ping_interval: float | None = 20,
        ping_timeout: float | None = 20,
        close_timeout: float | None = 10,
        # Limits
        max_size: int | None = 2**20,
        max_queue: int | None | tuple[int | None, int | None] = 16,
        write_limit: int | tuple[int, int | None] = 2**15,
        # Logging
        logger: LoggerLike | None = None,
        # Escape hatch for advanced customization
        create_connection: type[ClientConnection] | None = None,
        # Other keyword arguments are passed to loop.create_connection
        **kwargs: Any,
    ) -> None:
        # Store the configuration only; no network I/O happens until the
        # object is awaited, entered, or iterated.
        self.uri = uri

        # Validate arguments eagerly so that mistakes surface at call time.
        if subprotocols is not None:
            validate_subprotocols(subprotocols)

        if compression == "deflate":
            extensions = enable_client_permessage_deflate(extensions)
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        if logger is None:
            logger = logging.getLogger("websockets.client")

        if create_connection is None:
            create_connection = ClientConnection

        # Build a fresh Sans-I/O protocol and connection for each connection
        # attempt. It is invoked through factory() in create_connection(),
        # which runs again for every redirect.
        def protocol_factory(uri: WebSocketURI) -> ClientConnection:
            # This is a protocol in the Sans-I/O implementation of websockets.
            protocol = ClientProtocol(
                uri,
                origin=origin,
                extensions=extensions,
                subprotocols=subprotocols,
                max_size=max_size,
                logger=logger,
            )
            # This is a connection in websockets and a protocol in asyncio.
            connection = create_connection(
                protocol,
                ping_interval=ping_interval,
                ping_timeout=ping_timeout,
                close_timeout=close_timeout,
                max_queue=max_queue,
                write_limit=write_limit,
            )
            return connection

        self.proxy = proxy
        self.protocol_factory = protocol_factory
        self.additional_headers = additional_headers
        self.user_agent_header = user_agent_header
        self.process_exception = process_exception
        self.open_timeout = open_timeout
        self.logger = logger
        self.connection_kwargs = kwargs

    async def create_connection(self) -> ClientConnection:
        """
        Create TCP or Unix connection.

        Open the network connection to ``self.uri``: over a Unix socket,
        through a SOCKS or HTTP(S) proxy, or directly. Return the
        :class:`ClientConnection` bound to it. The opening handshake is not
        performed here.

        Raises:
            ValueError: If the ``ssl`` or ``proxy_ssl`` arguments contradict
                the URI or proxy scheme.

        """
        loop = asyncio.get_running_loop()
        # Work on a copy: the code below pops and setdefaults keys, and this
        # method runs again for every redirect.
        kwargs = self.connection_kwargs.copy()

        ws_uri = parse_uri(self.uri)

        # Proxies are bypassed for Unix sockets and caller-provided sockets.
        proxy = self.proxy
        if kwargs.get("unix", False):
            proxy = None
        if kwargs.get("sock") is not None:
            proxy = None
        # proxy=True means "use the proxy configured for this URI", if any.
        if proxy is True:
            proxy = get_proxy(ws_uri)

        def factory() -> ClientConnection:
            return self.protocol_factory(ws_uri)

        # wss:// enables TLS by default; reject contradictory ssl arguments.
        if ws_uri.secure:
            kwargs.setdefault("ssl", True)
            kwargs.setdefault("server_hostname", ws_uri.host)
            if kwargs.get("ssl") is None:
                raise ValueError("ssl=None is incompatible with a wss:// URI")
        else:
            if kwargs.get("ssl") is not None:
                raise ValueError("ssl argument is incompatible with a ws:// URI")

        if kwargs.pop("unix", False):
            _, connection = await loop.create_unix_connection(factory, **kwargs)
        elif proxy is not None:
            proxy_parsed = parse_proxy(proxy)
            if proxy_parsed.scheme[:5] == "socks":
                # Connect to the server through the proxy.
                sock = await connect_socks_proxy(
                    proxy_parsed,
                    ws_uri,
                    local_addr=kwargs.pop("local_addr", None),
                )
                # Initialize WebSocket connection via the proxy.
                _, connection = await loop.create_connection(
                    factory,
                    sock=sock,
                    **kwargs,
                )
            elif proxy_parsed.scheme[:4] == "http":
                # Split keyword arguments between the proxy and the server.
                # TLS settings for the WebSocket server (ssl*, server_hostname)
                # apply to the tunnel; proxy_* settings apply to the proxy
                # connection with the prefix stripped; the rest also goes to
                # the proxy connection.
                all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {}
                for key, value in all_kwargs.items():
                    if key.startswith("ssl") or key == "server_hostname":
                        kwargs[key] = value
                    elif key.startswith("proxy_"):
                        proxy_kwargs[key[6:]] = value
                    else:
                        proxy_kwargs[key] = value
                # Validate the proxy_ssl argument.
                if proxy_parsed.scheme == "https":
                    proxy_kwargs.setdefault("ssl", True)
                    if proxy_kwargs.get("ssl") is None:
                        raise ValueError(
                            "proxy_ssl=None is incompatible with an https:// proxy"
                        )
                else:
                    if proxy_kwargs.get("ssl") is not None:
                        raise ValueError(
                            "proxy_ssl argument is incompatible with an http:// proxy"
                        )
                # Connect to the server through the proxy.
                transport = await connect_http_proxy(
                    proxy_parsed,
                    ws_uri,
                    user_agent_header=self.user_agent_header,
                    **proxy_kwargs,
                )
                # Initialize WebSocket connection via the proxy.
                # The tunnel's transport was created for the CONNECT exchange;
                # hand it over to the WebSocket connection.
                connection = factory()
                transport.set_protocol(connection)
                ssl = kwargs.pop("ssl", None)
                if ssl is True:
                    ssl = ssl_module.create_default_context()
                # For wss:// URIs, upgrade the tunnel to TLS end-to-end with
                # the WebSocket server.
                if ssl is not None:
                    new_transport = await loop.start_tls(
                        transport, connection, ssl, **kwargs
                    )
                    assert new_transport is not None  # help mypy
                    transport = new_transport
                # set_protocol() doesn't call connection_made(); call it once
                # the final (possibly TLS) transport is known.
                connection.connection_made(transport)
            else:
                raise AssertionError("unsupported proxy")
        else:
            # Connect to the server directly.
            # host and port must not be passed along with sock.
            if kwargs.get("sock") is None:
                kwargs.setdefault("host", ws_uri.host)
                kwargs.setdefault("port", ws_uri.port)
            # Initialize WebSocket connection.
            _, connection = await loop.create_connection(factory, **kwargs)
        return connection

    def process_redirect(self, exc: Exception) -> Exception | str:
        """
        Determine whether a connection error is a redirect that can be followed.

        Return the new URI if it's a valid redirect. Else, return an exception.

        A relative ``Location`` header is resolved against the current URI.
        Redirects are refused when they would reuse a caller-provided socket,
        downgrade from ``wss://`` to ``ws://``, or change the origin while a
        Unix socket or an explicit ``host``/``port`` is configured.

        """
        if not (
            isinstance(exc, InvalidStatus)
            and exc.response.status_code
            in [
                300,  # Multiple Choices
                301,  # Moved Permanently
                302,  # Found
                303,  # See Other
                307,  # Temporary Redirect
                308,  # Permanent Redirect
            ]
            and "Location" in exc.response.headers
        ):
            return exc

        old_ws_uri = parse_uri(self.uri)
        new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
        new_ws_uri = parse_uri(new_uri)

        # If connect() received a socket, it is closed and cannot be reused.
        if self.connection_kwargs.get("sock") is not None:
            return ValueError(
                f"cannot follow redirect to {new_uri} with a preexisting socket"
            )

        # TLS downgrade is forbidden.
        if old_ws_uri.secure and not new_ws_uri.secure:
            return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")

        # Apply restrictions to cross-origin redirects.
        if (
            old_ws_uri.secure != new_ws_uri.secure
            or old_ws_uri.host != new_ws_uri.host
            or old_ws_uri.port != new_ws_uri.port
        ):
            # Cross-origin redirects on Unix sockets don't quite make sense.
            if self.connection_kwargs.get("unix", False):
                return ValueError(
                    f"cannot follow cross-origin redirect to {new_uri} "
                    f"with a Unix socket"
                )

            # Cross-origin redirects when host and port are overridden are ill-defined.
            if (
                self.connection_kwargs.get("host") is not None
                or self.connection_kwargs.get("port") is not None
            ):
                return ValueError(
                    f"cannot follow cross-origin redirect to {new_uri} "
                    f"with an explicit host or port"
                )

        return new_uri

    # ... = await connect(...)

    def __await__(self) -> Generator[Any, None, ClientConnection]:
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl__().__await__()

    async def __await_impl__(self) -> ClientConnection:
        """
        Connect, perform the opening handshake, and follow redirects.

        ``open_timeout`` bounds the whole sequence, including every redirect.
        Each loop iteration is one connection attempt; a followed redirect
        updates ``self.uri`` and starts the next iteration.

        """
        try:
            async with asyncio_timeout(self.open_timeout):
                for _ in range(MAX_REDIRECTS):
                    self.connection = await self.create_connection()
                    try:
                        await self.connection.handshake(
                            self.additional_headers,
                            self.user_agent_header,
                        )
                    except asyncio.CancelledError:
                        # Don't leave the network connection open on cancellation.
                        self.connection.transport.abort()
                        raise
                    except Exception as exc:
                        # Always close the connection even though keep-alive is
                        # the default in HTTP/1.1 because create_connection ties
                        # opening the network connection with initializing the
                        # protocol. In the current design of connect(), there is
                        # no easy way to reuse the network connection that works
                        # in every case nor to reinitialize the protocol.
                        self.connection.transport.abort()

                        uri_or_exc = self.process_redirect(exc)
                        # Response is a valid redirect; follow it.
                        if isinstance(uri_or_exc, str):
                            self.uri = uri_or_exc
                            continue
                        # Response isn't a valid redirect; raise the exception.
                        if uri_or_exc is exc:
                            raise
                        else:
                            raise uri_or_exc from exc

                    else:
                        # Handshake succeeded; only now start keepalive.
                        self.connection.start_keepalive()
                        return self.connection
                else:
                    # Every attempt ended in a redirect.
                    raise SecurityError(f"more than {MAX_REDIRECTS} redirects")

        except TimeoutError as exc:
            # Re-raise exception with an informative error message.
            raise TimeoutError("timed out during opening handshake") from exc

    # ... = yield from connect(...) - remove when dropping Python < 3.10

    __iter__ = __await__

    # async with connect(...) as ...: ...

    async def __aenter__(self) -> ClientConnection:
        """Open the connection and perform the handshake on ``async with``."""
        return await self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close the connection when leaving ``async with``."""
        await self.connection.close()

    # async for ... in connect(...):

    async def __aiter__(self) -> AsyncIterator[ClientConnection]:
        """
        Yield one open connection per iteration, reconnecting forever.

        Failures are classified with ``self.process_exception``: fatal errors
        are raised and end the iteration; retryable errors are logged and
        retried after a delay taken from :func:`backoff`, which restarts after
        each successful connection.

        """
        delays: Generator[float] | None = None
        while True:
            try:
                async with self as protocol:
                    yield protocol
            except Exception as exc:
                # Determine whether the exception is retryable or fatal.
                # The API of process_exception is "return an exception or None";
                # "raise an exception" is also supported because it's a frequent
                # mistake. It isn't documented in order to keep the API simple.
                try:
                    new_exc = self.process_exception(exc)
                except Exception as raised_exc:
                    new_exc = raised_exc

                # The connection failed with a fatal error.
                # Raise the exception and exit the loop.
                if new_exc is exc:
                    raise
                if new_exc is not None:
                    raise new_exc from exc

                # The connection failed with a retryable error.
                # Start or continue backoff and reconnect.
                if delays is None:
                    delays = backoff()
                delay = next(delays)
                self.logger.info(
                    "connect failed; reconnecting in %.1f seconds: %s",
                    delay,
                    # Remove first argument when dropping Python 3.9.
                    traceback.format_exception_only(type(exc), exc)[0].strip(),
                )
                await asyncio.sleep(delay)
                continue

            else:
                # The connection succeeded. Reset backoff.
                delays = None
|
||
|
||
|
||
def unix_connect(
    path: str | None = None,
    uri: str | None = None,
    **kwargs: Any,
) -> connect:
    """
    Connect to a WebSocket server listening on a Unix socket.

    This function accepts the same keyword arguments as :func:`connect`.

    It's only available on Unix.

    It's mainly useful for debugging servers listening on Unix sockets.

    Args:
        path: File system path to the Unix socket.
        uri: URI of the WebSocket server. ``uri`` defaults to
            ``ws://localhost/`` or, when a ``ssl`` argument is provided, to
            ``wss://localhost/``.

    """
    if uri is None:
        # The scheme only reflects whether TLS was requested; the socket path
        # determines where the connection actually goes.
        uri = "ws://localhost/" if kwargs.get("ssl") is None else "wss://localhost/"
    return connect(uri=uri, unix=True, path=path, **kwargs)
|
||
|
||
|
||
try:
    from python_socks import ProxyType
    from python_socks.async_.asyncio import Proxy as SocksProxy

    # python-socks proxy type for each supported SOCKS URI scheme.
    SOCKS_PROXY_TYPES = {
        "socks5h": ProxyType.SOCKS5,
        "socks5": ProxyType.SOCKS5,
        "socks4a": ProxyType.SOCKS4,
        "socks4": ProxyType.SOCKS4,
    }

    # rdns flag passed to python-socks for each scheme: enabled for the
    # "socks5h" and "socks4a" variants only.
    SOCKS_PROXY_RDNS = {
        "socks5h": True,
        "socks5": False,
        "socks4a": True,
        "socks4": False,
    }

    async def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        **kwargs: Any,
    ) -> socket.socket:
        """
        Connect via a SOCKS proxy and return the socket.

        Extra keyword arguments are forwarded to python-socks' ``connect()``.

        Raises:
            OSError: If python-socks reports a network error. connect() is
                documented to raise OSError, so these pass through unchanged.
                python-socks doesn't raise TimeoutError; it gets canceled.
            ProxyError: For any other failure, wrapping the original error.

        """
        scheme = proxy.scheme
        socks_proxy = SocksProxy(
            SOCKS_PROXY_TYPES[scheme],
            proxy.host,
            proxy.port,
            proxy.username,
            proxy.password,
            SOCKS_PROXY_RDNS[scheme],
        )
        try:
            sock = await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
        except Exception as exc:
            if isinstance(exc, OSError):
                raise
            raise ProxyError("failed to connect to SOCKS proxy") from exc
        return sock

except ImportError:

    async def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        **kwargs: Any,
    ) -> socket.socket:
        """Stand-in used when python-socks isn't installed; always fails."""
        raise ImportError("python-socks is required to use a SOCKS proxy")
|
||
|
||
|
||
def prepare_connect_request(
    proxy: Proxy,
    ws_uri: WebSocketURI,
    user_agent_header: str | None = None,
) -> bytes:
    """
    Serialize the HTTP ``CONNECT`` request asking ``proxy`` for a tunnel.

    The request target always includes the port of ``ws_uri``. The ``Host``
    header is built by :func:`build_host` without ``always_include_port``.
    ``User-Agent`` is added when ``user_agent_header`` isn't :obj:`None`, and
    ``Proxy-Authorization`` (Basic) when the proxy URI carries credentials.

    """
    headers = Headers()
    headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
    if user_agent_header is not None:
        headers["User-Agent"] = user_agent_header
    if proxy.username is not None:
        assert proxy.password is not None  # enforced by parse_proxy()
        credentials = build_authorization_basic(proxy.username, proxy.password)
        headers["Proxy-Authorization"] = credentials

    # The Request class only supports GET, so write the request line by hand.
    target = build_host(
        ws_uri.host,
        ws_uri.port,
        ws_uri.secure,
        always_include_port=True,
    )
    request_line = f"CONNECT {target} HTTP/1.1\r\n"
    return request_line.encode() + headers.serialize()
|
||
|
||
|
||
class HTTPProxyConnection(asyncio.Protocol):
    """
    asyncio protocol that sends an HTTP ``CONNECT`` request to a proxy and
    parses the proxy's response.

    The outcome is exposed through :attr:`response`, a future that resolves to
    the parsed :class:`Response` on a 2xx status. It fails with
    :exc:`InvalidProxyStatus` on any other status, with
    :exc:`InvalidProxyMessage` when the response can't be parsed (including
    when the connection ends before a complete response), or with the error
    that closed the connection.

    Once :attr:`response` is done — resolved, failed, or cancelled by the
    caller — later protocol events are ignored.

    """

    def __init__(
        self,
        ws_uri: WebSocketURI,
        proxy: Proxy,
        user_agent_header: str | None = None,
    ):
        self.ws_uri = ws_uri
        self.proxy = proxy
        self.user_agent_header = user_agent_header

        self.reader = StreamReader()
        self.parser = Response.parse(
            self.reader.read_line,
            self.reader.read_exact,
            self.reader.read_to_eof,
            include_body=False,
        )

        loop = asyncio.get_running_loop()
        self.response: asyncio.Future[Response] = loop.create_future()

    def run_parser(self) -> None:
        """Advance the response parser and settle :attr:`response` when done."""
        # Once the outcome is settled (or the caller cancelled the future), the
        # parser is exhausted or irrelevant. Running it again would call next()
        # on a finished generator (StopIteration(None) -> AttributeError) and
        # try to settle a done future (InvalidStateError).
        if self.response.done():
            return
        try:
            next(self.parser)
        except StopIteration as exc:
            response = exc.value
            if 200 <= response.status_code < 300:
                self.response.set_result(response)
            else:
                self.response.set_exception(InvalidProxyStatus(response))
        except Exception as exc:
            proxy_exc = InvalidProxyMessage(
                "did not receive a valid HTTP response from proxy"
            )
            proxy_exc.__cause__ = exc
            self.response.set_exception(proxy_exc)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Send the ``CONNECT`` request as soon as the proxy connection opens."""
        transport = cast(asyncio.Transport, transport)
        self.transport = transport
        self.transport.write(
            prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header)
        )

    def data_received(self, data: bytes) -> None:
        """Buffer incoming bytes and let the parser consume them."""
        self.reader.feed_data(data)
        self.run_parser()

    def eof_received(self) -> None:
        """Signal end of stream to the parser; returning None closes the transport."""
        self._feed_eof()
        self.run_parser()

    def connection_lost(self, exc: Exception | None) -> None:
        """Settle :attr:`response` if the connection ends before it's known."""
        self._feed_eof()
        if self.response.done():
            return
        if exc is not None:
            self.response.set_exception(exc)
        else:
            # Clean close without a complete response: let the parser observe
            # EOF and fail with InvalidProxyMessage rather than leaving the
            # future pending until open_timeout expires.
            self.run_parser()

    def _feed_eof(self) -> None:
        """End the reader's stream, tolerating an earlier call."""
        # asyncio calls eof_received() and then connection_lost() when the
        # proxy closes the connection; StreamReader.feed_eof() raises EOFError
        # when the stream was already ended.
        try:
            self.reader.feed_eof()
        except EOFError:
            pass
|
||
|
||
|
||
async def connect_http_proxy(
    proxy: Proxy,
    ws_uri: WebSocketURI,
    user_agent_header: str | None = None,
    **kwargs: Any,
) -> asyncio.Transport:
    """
    Connect to an HTTP(S) proxy and ask it to open a tunnel to ``ws_uri``.

    Extra keyword arguments are passed to the event loop's
    :meth:`~asyncio.loop.create_connection` for the connection to the proxy.

    Returns:
        The transport connected to the proxy, once the proxy accepted the
        ``CONNECT`` request with a 2xx status. Its protocol is still the
        :class:`HTTPProxyConnection` used for the exchange.

    Raises:
        OSError: If the connection to the proxy fails.
        InvalidProxyStatus: If the proxy answers with a non-2xx status.
        InvalidProxyMessage: If the proxy's response can't be parsed.

    """
    transport, protocol = await asyncio.get_running_loop().create_connection(
        lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header),
        proxy.host,
        proxy.port,
        **kwargs,
    )

    try:
        # This raises exceptions if the connection to the proxy fails.
        await protocol.response
    except BaseException:
        # Also close on cancellation, e.g. when open_timeout expires:
        # asyncio.CancelledError isn't an Exception subclass, so catching only
        # Exception leaked the connection to the proxy.
        transport.close()
        raise

    return transport
|