import codecs
import logging
import mimetypes
import netrc
import os
import re
import sys
import time
import typing
from pathlib import Path
from urllib.request import getproxies

import sniffio

from ._types import PrimitiveData

if typing.TYPE_CHECKING:  # pragma: no cover
    from ._models import URL


_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
    {chr(c): "%{:02X}".format(c) for c in range(0x00, 0x1F + 1) if c != 0x1B}
)
_HTML5_FORM_ENCODING_RE = re.compile(
    r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
)


def normalize_header_key(
    value: typing.Union[str, bytes],
    lower: bool,
    encoding: typing.Optional[str] = None,
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header key.
    """
    if isinstance(value, bytes):
        bytes_value = value
    else:
        bytes_value = value.encode(encoding or "ascii")

    return bytes_value.lower() if lower else bytes_value


def normalize_header_value(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header value.
    """
    if isinstance(value, bytes):
        return value
    return value.encode(encoding or "ascii")


def primitive_value_to_str(value: "PrimitiveData") -> str:
    """
    Coerce a primitive data type into a string value.

    Note that we prefer JSON-style 'true'/'false' for boolean values here.
    """
    if value is True:
        return "true"
    elif value is False:
        return "false"
    elif value is None:
        return ""
    return str(value)


def is_known_encoding(encoding: str) -> bool:
    """
    Return `True` if `encoding` is a known codec.
    """
    try:
        codecs.lookup(encoding)
    except LookupError:
        return False
    return True


def format_form_param(name: str, value: typing.Union[str, bytes]) -> bytes:
    """
    Encode a name/value pair within a multipart form.
    """
    if isinstance(value, bytes):
        value = value.decode()

    def replacer(match: typing.Match[str]) -> str:
        return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]

    value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
    return f'{name}="{value}"'.encode()
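

# A quick illustration of the HTML5 form escaping above: double quotes are
# replaced with "%22" so the value can be safely wrapped in quotes.
#
#     >>> format_form_param("name", 'a"b')
#     b'name="a%22b"'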


# Null bytes; no need to recreate these on each call to guess_json_utf
_null = b"\x00"
_null2 = _null * 2
_null3 = _null * 3


def guess_json_utf(data: bytes) -> typing.Optional[str]:
    # JSON always starts with two ASCII characters, so detection is as
    # easy as counting the nulls and from their location and count
    # determine the encoding. Also detect a BOM, if present.
    sample = data[:4]
    if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
        return "utf-32"  # BOM included
    if sample[:3] == codecs.BOM_UTF8:
        return "utf-8-sig"  # BOM included, MS style (discouraged)
    if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
        return "utf-16"  # BOM included
    nullcount = sample.count(_null)
    if nullcount == 0:
        return "utf-8"
    if nullcount == 2:
        if sample[::2] == _null2:  # 1st and 3rd are null
            return "utf-16-be"
        if sample[1::2] == _null2:  # 2nd and 4th are null
            return "utf-16-le"
        # Did not detect 2 valid UTF-16 ascii-range characters
    if nullcount == 3:
        if sample[:3] == _null3:
            return "utf-32-be"
        if sample[1:] == _null3:
            return "utf-32-le"
        # Did not detect a valid UTF-32 ascii-range character
    return None
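

# For example, the UTF-16-BE encoding of '{"' places null bytes in the 1st
# and 3rd positions, which the null-counting logic above picks up:
#
#     >>> guess_json_utf('{"'.encode("utf-16-be"))
#     'utf-16-be'
#     >>> guess_json_utf(b'{"k": 1}')
#     'utf-8'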


class NetRCInfo:
    def __init__(self, files: typing.Optional[typing.List[str]] = None) -> None:
        if files is None:
            files = [os.getenv("NETRC", ""), "~/.netrc", "~/_netrc"]
        self.netrc_files = files

    @property
    def netrc_info(self) -> typing.Optional[netrc.netrc]:
        if not hasattr(self, "_netrc_info"):
            self._netrc_info = None
            for file_path in self.netrc_files:
                expanded_path = Path(file_path).expanduser()
                try:
                    if expanded_path.is_file():
                        self._netrc_info = netrc.netrc(str(expanded_path))
                        break
                except (netrc.NetrcParseError, IOError):  # pragma: nocover
                    # Issue while reading the netrc file, ignore...
                    pass
        return self._netrc_info

    def get_credentials(self, host: str) -> typing.Optional[typing.Tuple[str, str]]:
        if self.netrc_info is None:
            return None

        auth_info = self.netrc_info.authenticators(host)
        if auth_info is None or auth_info[2] is None:
            return None
        return (auth_info[0], auth_info[2])
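

# Typical usage (illustrative): credentials are looked up by host, returning
# a (login, password) tuple, or None when no matching entry exists.
#
#     >>> netrc_info = NetRCInfo()
#     >>> netrc_info.get_credentials("example.org")
#     ('username', 'password')  # assuming ~/.netrc has an entry for example.org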


def get_ca_bundle_from_env() -> typing.Optional[str]:
    if "SSL_CERT_FILE" in os.environ:
        ssl_file = Path(os.environ["SSL_CERT_FILE"])
        if ssl_file.is_file():
            return str(ssl_file)
    if "SSL_CERT_DIR" in os.environ:
        ssl_path = Path(os.environ["SSL_CERT_DIR"])
        if ssl_path.is_dir():
            return str(ssl_path)
    return None


def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]:
    """
    Returns a list of parsed link headers. For more info see:
    https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link

    The generic syntax of those is:

        Link: < uri-reference >; param1=value1; param2="value2"

    So for instance:

        Link: <http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;

    would return

        [
            {"url": "http:/.../front.jpeg", "type": "image/jpeg"},
            {"url": "http://.../back.jpeg"},
        ]

    :param value: HTTP Link entity-header field
    :return: list of parsed link headers
    """
    links: typing.List[typing.Dict[str, str]] = []
    replace_chars = " '\""
    value = value.strip(replace_chars)
    if not value:
        return links
    for val in re.split(", *<", value):
        try:
            url, params = val.split(";", 1)
        except ValueError:
            url, params = val, ""
        link = {"url": url.strip("<> '\"")}
        for param in params.split(";"):
            try:
                key, value = param.split("=")
            except ValueError:
                break
            link[key.strip(replace_chars)] = value.strip(replace_chars)
        links.append(link)
    return links


SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
    items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
    for k, v in items:
        if to_str(k.lower()) in SENSITIVE_HEADERS:
            v = to_bytes_or_str("[secure]", match_type_of=v)
        yield k, v
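

# For example, an Authorization header value is masked while other headers
# pass through untouched:
#
#     >>> list(obfuscate_sensitive_headers([("Authorization", "Bearer abc"), ("Accept", "*/*")]))
#     [('Authorization', '[secure]'), ('Accept', '*/*')]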


_LOGGER_INITIALIZED = False
TRACE_LOG_LEVEL = 5


class Logger(logging.Logger):
    # Stub for type checkers.
    def trace(self, message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        ...  # pragma: nocover


def get_logger(name: str) -> Logger:
    """
    Get a `logging.Logger` instance, and optionally
    set up debug logging based on the HTTPX_LOG_LEVEL environment variable.
    """
    global _LOGGER_INITIALIZED

    if not _LOGGER_INITIALIZED:
        _LOGGER_INITIALIZED = True
        logging.addLevelName(TRACE_LOG_LEVEL, "TRACE")

        log_level = os.environ.get("HTTPX_LOG_LEVEL", "").upper()
        if log_level in ("DEBUG", "TRACE"):
            logger = logging.getLogger("httpx")
            logger.setLevel(logging.DEBUG if log_level == "DEBUG" else TRACE_LOG_LEVEL)
            handler = logging.StreamHandler(sys.stderr)
            handler.setFormatter(
                logging.Formatter(
                    fmt="%(levelname)s [%(asctime)s] %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                )
            )
            logger.addHandler(handler)

    logger = logging.getLogger(name)

    def trace(message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        logger.log(TRACE_LOG_LEVEL, message, *args, **kwargs)

    logger.trace = trace  # type: ignore

    return typing.cast(Logger, logger)
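

# Typical usage (illustrative): loggers returned here gain a `.trace()` method,
# which only produces output when HTTPX_LOG_LEVEL is set to "TRACE".
#
#     >>> logger = get_logger(__name__)
#     >>> logger.trace("Sending request...")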


def port_or_default(url: "URL") -> typing.Optional[int]:
    if url.port is not None:
        return url.port
    return {"http": 80, "https": 443}.get(url.scheme)


def same_origin(url: "URL", other: "URL") -> bool:
    """
    Return 'True' if the given URLs share the same origin.
    """
    return (
        url.scheme == other.scheme
        and url.host == other.host
        and port_or_default(url) == port_or_default(other)
    )
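

# Default ports count as part of the origin, so an explicit ":443" on one URL
# still matches an implicit HTTPS port on the other (illustrative):
#
#     >>> same_origin(URL("https://example.com"), URL("https://example.com:443/path"))
#     True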


def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
    """Gets proxy information from the environment"""

    # urllib.request.getproxies() falls back on System
    # Registry and Config for proxies on Windows and macOS.
    # We don't want to propagate non-HTTP proxies into
    # our configuration such as 'TRAVIS_APT_PROXY'.
    proxy_info = getproxies()
    mounts: typing.Dict[str, typing.Optional[str]] = {}

    for scheme in ("http", "https", "all"):
        if proxy_info.get(scheme):
            hostname = proxy_info[scheme]
            mounts[f"{scheme}://"] = (
                hostname if "://" in hostname else f"http://{hostname}"
            )

    no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")]
    for hostname in no_proxy_hosts:
        # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details
        # on how names in `NO_PROXY` are handled.
        if hostname == "*":
            # If NO_PROXY=* is used or if "*" occurs as any one of the comma
            # separated hostnames, then we should just bypass any information
            # from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore
            # proxies.
            return {}
        elif hostname:
            # NO_PROXY=.google.com is marked as "all://*.google.com",
            # which disables "www.google.com" but not "google.com".
            # NO_PROXY=google.com is marked as "all://*google.com",
            # which disables "www.google.com" and "google.com".
            # (But not "wwwgoogle.com")
            mounts[f"all://*{hostname}"] = None

    return mounts
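

# As a sketch, assuming HTTP_PROXY=http://127.0.0.1:3128 and NO_PROXY=localhost
# are set in the environment, this would produce:
#
#     >>> get_environment_proxies()
#     {'http://': 'http://127.0.0.1:3128', 'all://*localhost': None}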


def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
    return value.encode(encoding) if isinstance(value, str) else value


def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
    return value if isinstance(value, str) else value.decode(encoding)


def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
    return value if isinstance(match_type_of, str) else value.encode()


def unquote(value: str) -> str:
    return value[1:-1] if value[0] == value[-1] == '"' else value


def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
    if filename:
        return mimetypes.guess_type(filename)[0] or "application/octet-stream"
    return None


def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]:
    """
    Given a file-like stream object, return its length in number of bytes
    without reading it into memory.
    """
    try:
        # Is it an actual file?
        fd = stream.fileno()
        # Yup, seems to be an actual file.
        length = os.fstat(fd).st_size
    except (AttributeError, OSError):
        # No... Maybe it's something that supports random access, like `io.BytesIO`?
        try:
            # Assuming so, go to end of stream to figure out its length,
            # then put it back in place.
            offset = stream.tell()
            length = stream.seek(0, os.SEEK_END)
            stream.seek(offset)
        except (AttributeError, OSError):
            # Not even that? Sorry, we're doomed...
            return None

    return length
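

# For example, `io.BytesIO` has no usable `fileno()` (it raises an OSError
# subclass), so the seek/tell branch reports its length, leaving the stream
# position unchanged:
#
#     >>> import io
#     >>> peek_filelike_length(io.BytesIO(b"hello"))
#     5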


class Timer:
    async def _get_time(self) -> float:
        library = sniffio.current_async_library()
        if library == "trio":
            import trio

            return trio.current_time()
        elif library == "curio":  # pragma: nocover
            import curio

            return await curio.clock()

        import asyncio

        return asyncio.get_event_loop().time()

    def sync_start(self) -> None:
        self.started = time.perf_counter()

    async def async_start(self) -> None:
        self.started = await self._get_time()

    def sync_elapsed(self) -> float:
        now = time.perf_counter()
        return now - self.started

    async def async_elapsed(self) -> float:
        now = await self._get_time()
        return now - self.started
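

# Typical synchronous usage (illustrative; the elapsed value will vary):
#
#     >>> timer = Timer()
#     >>> timer.sync_start()
#     >>> timer.sync_elapsed()  # seconds since sync_start(), via time.perf_counter()
#     0.000123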


class URLPattern:
    """
    A utility class currently used for making lookups against proxy keys...

    # Wildcard matching...
    >>> pattern = URLPattern("all")
    >>> pattern.matches(httpx.URL("http://example.com"))
    True

    # With scheme matching...
    >>> pattern = URLPattern("https")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False

    # With domain matching...
    >>> pattern = URLPattern("https://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # Wildcard scheme, with domain matching...
    >>> pattern = URLPattern("all://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    True
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # With port matching...
    >>> pattern = URLPattern("https://example.com:1234")
    >>> pattern.matches(httpx.URL("https://example.com:1234"))
    True
    >>> pattern.matches(httpx.URL("https://example.com"))
    False
    """

    def __init__(self, pattern: str) -> None:
        from ._models import URL

        if pattern and ":" not in pattern:
            raise ValueError(
                f"Proxy keys should use proper URL forms rather "
                f"than plain scheme strings. "
                f'Instead of "{pattern}", use "{pattern}://"'
            )

        url = URL(pattern)
        self.pattern = pattern
        self.scheme = "" if url.scheme == "all" else url.scheme
        self.host = "" if url.host == "*" else url.host
        self.port = url.port
        if not url.host or url.host == "*":
            self.host_regex: typing.Optional[typing.Pattern[str]] = None
        else:
            if url.host.startswith("*."):
                # *.example.com should match "www.example.com", but not "example.com"
                domain = re.escape(url.host[2:])
                self.host_regex = re.compile(f"^.+\\.{domain}$")
            elif url.host.startswith("*"):
                # *example.com should match "www.example.com" and "example.com"
                domain = re.escape(url.host[1:])
                self.host_regex = re.compile(f"^(.+\\.)?{domain}$")
            else:
                # example.com should match "example.com" but not "www.example.com"
                domain = re.escape(url.host)
                self.host_regex = re.compile(f"^{domain}$")

    def matches(self, other: "URL") -> bool:
        if self.scheme and self.scheme != other.scheme:
            return False
        if (
            self.host
            and self.host_regex is not None
            and not self.host_regex.match(other.host)
        ):
            return False
        if self.port is not None and self.port != other.port:
            return False
        return True

    @property
    def priority(self) -> tuple:
        """
        The priority allows URLPattern instances to be sortable, so that
        we can match from most specific to least specific.
        """
        # URLs with a port should take priority over URLs without a port.
        port_priority = 0 if self.port is not None else 1
        # Longer hostnames should match first.
        host_priority = -len(self.host)
        # Longer schemes should match first.
        scheme_priority = -len(self.scheme)
        return (port_priority, host_priority, scheme_priority)

    def __hash__(self) -> int:
        return hash(self.pattern)

    def __lt__(self, other: "URLPattern") -> bool:
        return self.priority < other.priority

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, URLPattern) and self.pattern == other.pattern
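

# Because `priority` orders ports before hosts before schemes, sorting a list
# of patterns yields most-specific-first, e.g. (illustrative):
#
#     >>> patterns = [URLPattern(p) for p in ("all://", "https://", "https://example.com")]
#     >>> [p.pattern for p in sorted(patterns)]
#     ['https://example.com', 'https://', 'all://']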