mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Allow providing credentials to HTTPS_PROXY (#9657)
Addresses https://github.com/matrix-org/synapse-dinsic/issues/70 This PR causes `ProxyAgent` to attempt to extract credentials from an `HTTPS_PROXY` env var. If credentials are found, a `Proxy-Authorization` header ([details](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Proxy-Authorization)) is sent to the proxy server to authenticate against it. The headers are *not* passed to the remote server. Also added some type hints.
This commit is contained in:
parent
4612302399
commit
5b268997bd
1
changelog.d/9657.feature
Normal file
1
changelog.d/9657.feature
Normal file
@ -0,0 +1 @@
|
||||
Add support for credentials for proxy authentication in the `HTTPS_PROXY` environment variable.
|
@ -19,9 +19,10 @@ from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, protocol
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||
from twisted.internet.protocol import connectionDone
|
||||
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
|
||||
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
|
||||
from twisted.web import http
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint:
|
||||
|
||||
Args:
|
||||
reactor: the Twisted reactor to use for the connection
|
||||
proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
|
||||
proxy
|
||||
host (bytes): hostname that we want to CONNECT to
|
||||
port (int): port that we want to connect to
|
||||
proxy_endpoint: the endpoint to use to connect to the proxy
|
||||
host: hostname that we want to CONNECT to
|
||||
port: port that we want to connect to
|
||||
headers: Extra HTTP headers to include in the CONNECT request
|
||||
"""
|
||||
|
||||
def __init__(self, reactor, proxy_endpoint, host, port):
|
||||
def __init__(
|
||||
self,
|
||||
reactor: IReactorCore,
|
||||
proxy_endpoint: IStreamClientEndpoint,
|
||||
host: bytes,
|
||||
port: int,
|
||||
headers: Headers,
|
||||
):
|
||||
self._reactor = reactor
|
||||
self._proxy_endpoint = proxy_endpoint
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._headers = headers
|
||||
|
||||
def __repr__(self):
|
||||
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
|
||||
|
||||
def connect(self, protocolFactory):
|
||||
f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
|
||||
def connect(self, protocolFactory: ClientFactory):
|
||||
f = HTTPProxiedClientFactory(
|
||||
self._host, self._port, protocolFactory, self._headers
|
||||
)
|
||||
d = self._proxy_endpoint.connect(f)
|
||||
# once the tcp socket connects successfully, we need to wait for the
|
||||
# CONNECT to complete.
|
||||
@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||
HTTP Protocol object and run the rest of the connection.
|
||||
|
||||
Args:
|
||||
dst_host (bytes): hostname that we want to CONNECT to
|
||||
dst_port (int): port that we want to connect to
|
||||
wrapped_factory (protocol.ClientFactory): The original Factory
|
||||
dst_host: hostname that we want to CONNECT to
|
||||
dst_port: port that we want to connect to
|
||||
wrapped_factory: The original Factory
|
||||
headers: Extra HTTP headers to include in the CONNECT request
|
||||
"""
|
||||
|
||||
def __init__(self, dst_host, dst_port, wrapped_factory):
|
||||
def __init__(
|
||||
self,
|
||||
dst_host: bytes,
|
||||
dst_port: int,
|
||||
wrapped_factory: ClientFactory,
|
||||
headers: Headers,
|
||||
):
|
||||
self.dst_host = dst_host
|
||||
self.dst_port = dst_port
|
||||
self.wrapped_factory = wrapped_factory
|
||||
self.headers = headers
|
||||
self.on_connection = defer.Deferred()
|
||||
|
||||
def startedConnecting(self, connector):
|
||||
@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
|
||||
|
||||
return HTTPConnectProtocol(
|
||||
self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
|
||||
self.dst_host,
|
||||
self.dst_port,
|
||||
wrapped_protocol,
|
||||
self.on_connection,
|
||||
self.headers,
|
||||
)
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol):
|
||||
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
|
||||
|
||||
Args:
|
||||
host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
|
||||
host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
|
||||
to put in the CONNECT request
|
||||
|
||||
port (int): The original HTTP(s) port to put in the CONNECT request
|
||||
port: The original HTTP(s) port to put in the CONNECT request
|
||||
|
||||
wrapped_protocol (interfaces.IProtocol): the original protocol (probably
|
||||
HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
|
||||
wrapped_protocol: the original protocol (probably HTTPChannel or
|
||||
TLSMemoryBIOProtocol, but could be anything really)
|
||||
|
||||
connected_deferred (Deferred): a Deferred which will be callbacked with
|
||||
connected_deferred: a Deferred which will be callbacked with
|
||||
wrapped_protocol when the CONNECT completes
|
||||
|
||||
headers: Extra HTTP headers to include in the CONNECT request
|
||||
"""
|
||||
|
||||
def __init__(self, host, port, wrapped_protocol, connected_deferred):
|
||||
def __init__(
|
||||
self,
|
||||
host: bytes,
|
||||
port: int,
|
||||
wrapped_protocol: Protocol,
|
||||
connected_deferred: defer.Deferred,
|
||||
headers: Headers,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.wrapped_protocol = wrapped_protocol
|
||||
self.connected_deferred = connected_deferred
|
||||
self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
|
||||
self.headers = headers
|
||||
|
||||
self.http_setup_client = HTTPConnectSetupClient(
|
||||
self.host, self.port, self.headers
|
||||
)
|
||||
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
|
||||
|
||||
def connectionMade(self):
|
||||
@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol):
|
||||
if buf:
|
||||
self.wrapped_protocol.dataReceived(buf)
|
||||
|
||||
def dataReceived(self, data):
|
||||
def dataReceived(self, data: bytes):
|
||||
# if we've set up the HTTP protocol, we can send the data there
|
||||
if self.wrapped_protocol.connected:
|
||||
return self.wrapped_protocol.dataReceived(data)
|
||||
@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient):
|
||||
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
|
||||
|
||||
Args:
|
||||
host (bytes): The hostname to send in the CONNECT message
|
||||
port (int): The port to send in the CONNECT message
|
||||
host: The hostname to send in the CONNECT message
|
||||
port: The port to send in the CONNECT message
|
||||
headers: Extra headers to send with the CONNECT message
|
||||
"""
|
||||
|
||||
def __init__(self, host, port):
|
||||
def __init__(self, host: bytes, port: int, headers: Headers):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.headers = headers
|
||||
self.on_connected = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
logger.debug("Connected to proxy, sending CONNECT")
|
||||
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
|
||||
|
||||
# Send any additional specified headers
|
||||
for name, values in self.headers.getAllRawHeaders():
|
||||
for value in values:
|
||||
self.sendHeader(name, value)
|
||||
|
||||
self.endHeaders()
|
||||
|
||||
def handleStatus(self, version, status, message):
|
||||
def handleStatus(self, version: bytes, status: bytes, message: bytes):
|
||||
logger.debug("Got Status: %s %s %s", status, message, version)
|
||||
if status != b"200":
|
||||
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
|
||||
|
@ -12,10 +12,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from urllib.request import getproxies_environment, proxy_bypass_environment
|
||||
|
||||
import attr
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
@ -23,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
|
||||
from twisted.web.error import SchemeNotSupported
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IAgent
|
||||
|
||||
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
|
||||
@ -32,6 +36,22 @@ logger = logging.getLogger(__name__)
|
||||
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
|
||||
|
||||
|
||||
@attr.s
|
||||
class ProxyCredentials:
|
||||
username_password = attr.ib(type=bytes)
|
||||
|
||||
def as_proxy_authorization_value(self) -> bytes:
|
||||
"""
|
||||
Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
|
||||
|
||||
Returns:
|
||||
A transformation of the authentication string the encoded value for
|
||||
a Proxy-Authorization header.
|
||||
"""
|
||||
# Encode as base64 and prepend the authorization type
|
||||
return b"Basic " + base64.encodebytes(self.username_password)
|
||||
|
||||
|
||||
@implementer(IAgent)
|
||||
class ProxyAgent(_AgentBase):
|
||||
"""An Agent implementation which will use an HTTP proxy if one was requested
|
||||
@ -96,6 +116,9 @@ class ProxyAgent(_AgentBase):
|
||||
https_proxy = proxies["https"].encode() if "https" in proxies else None
|
||||
no_proxy = proxies["no"] if "no" in proxies else None
|
||||
|
||||
# Parse credentials from https proxy connection string if present
|
||||
self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
|
||||
|
||||
self.http_proxy_endpoint = _http_proxy_endpoint(
|
||||
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
|
||||
)
|
||||
@ -175,11 +198,22 @@ class ProxyAgent(_AgentBase):
|
||||
and self.https_proxy_endpoint
|
||||
and not should_skip_proxy
|
||||
):
|
||||
connect_headers = Headers()
|
||||
|
||||
# Determine whether we need to set Proxy-Authorization headers
|
||||
if self.https_proxy_creds:
|
||||
# Set a Proxy-Authorization header
|
||||
connect_headers.addRawHeader(
|
||||
b"Proxy-Authorization",
|
||||
self.https_proxy_creds.as_proxy_authorization_value(),
|
||||
)
|
||||
|
||||
endpoint = HTTPConnectProxyEndpoint(
|
||||
self.proxy_reactor,
|
||||
self.https_proxy_endpoint,
|
||||
parsed_uri.host,
|
||||
parsed_uri.port,
|
||||
headers=connect_headers,
|
||||
)
|
||||
else:
|
||||
# not using a proxy
|
||||
@ -208,12 +242,16 @@ class ProxyAgent(_AgentBase):
|
||||
)
|
||||
|
||||
|
||||
def _http_proxy_endpoint(proxy, reactor, **kwargs):
|
||||
def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
|
||||
"""Parses an http proxy setting and returns an endpoint for the proxy
|
||||
|
||||
Args:
|
||||
proxy (bytes|None): the proxy setting
|
||||
proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
|
||||
Note that compared to other apps, this function currently lacks support
|
||||
for specifying a protocol schema (i.e. protocol://...).
|
||||
|
||||
reactor: reactor to be used to connect to the proxy
|
||||
|
||||
kwargs: other args to be passed to HostnameEndpoint
|
||||
|
||||
Returns:
|
||||
@ -223,16 +261,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs):
|
||||
if proxy is None:
|
||||
return None
|
||||
|
||||
# currently we only support hostname:port. Some apps also support
|
||||
# protocol://<host>[:port], which allows a way of requiring a TLS connection to the
|
||||
# proxy.
|
||||
|
||||
# Parse the connection string
|
||||
host, port = parse_host_port(proxy, default_port=1080)
|
||||
return HostnameEndpoint(reactor, host, port, **kwargs)
|
||||
|
||||
|
||||
def parse_host_port(hostport, default_port=None):
|
||||
# could have sworn we had one of these somewhere else...
|
||||
def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
|
||||
"""
|
||||
Parses the username and password from a proxy declaration e.g
|
||||
username:password@hostname:port.
|
||||
|
||||
Args:
|
||||
proxy: The proxy connection string.
|
||||
|
||||
Returns
|
||||
An instance of ProxyCredentials and the proxy connection string with any credentials
|
||||
stripped, i.e u:p@host:port -> host:port. If no credentials were found, the
|
||||
ProxyCredentials instance is replaced with None.
|
||||
"""
|
||||
if proxy and b"@" in proxy:
|
||||
# We use rsplit here as the password could contain an @ character
|
||||
credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
|
||||
return ProxyCredentials(credentials), proxy_without_credentials
|
||||
|
||||
return None, proxy
|
||||
|
||||
|
||||
def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]:
|
||||
"""
|
||||
Parse the hostname and port from a proxy connection byte string.
|
||||
|
||||
Args:
|
||||
hostport: The proxy connection string. Must be in the form 'host[:port]'.
|
||||
default_port: The default port to return if one is not found in `hostport`.
|
||||
|
||||
Returns:
|
||||
A tuple containing the hostname and port. Uses `default_port` if one was not found.
|
||||
"""
|
||||
if b":" in hostport:
|
||||
host, port = hostport.rsplit(b":", 1)
|
||||
try:
|
||||
|
@ -12,8 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import treq
|
||||
@ -242,6 +244,21 @@ class MatrixFederationAgentTests(TestCase):
|
||||
|
||||
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
|
||||
def test_https_request_via_proxy(self):
|
||||
"""Tests that TLS-encrypted requests can be made through a proxy"""
|
||||
self._do_https_request_via_proxy(auth_credentials=None)
|
||||
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
|
||||
)
|
||||
def test_https_request_via_proxy_with_auth(self):
|
||||
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
|
||||
self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
|
||||
|
||||
def _do_https_request_via_proxy(
|
||||
self,
|
||||
auth_credentials: Optional[str] = None,
|
||||
):
|
||||
agent = ProxyAgent(
|
||||
self.reactor,
|
||||
contextFactory=get_test_https_policy(),
|
||||
@ -278,6 +295,22 @@ class MatrixFederationAgentTests(TestCase):
|
||||
self.assertEqual(request.method, b"CONNECT")
|
||||
self.assertEqual(request.path, b"test.com:443")
|
||||
|
||||
# Check whether auth credentials have been supplied to the proxy
|
||||
proxy_auth_header_values = request.requestHeaders.getRawHeaders(
|
||||
b"Proxy-Authorization"
|
||||
)
|
||||
|
||||
if auth_credentials is not None:
|
||||
# Compute the correct header value for Proxy-Authorization
|
||||
encoded_credentials = base64.b64encode(b"bob:pinkponies")
|
||||
expected_header_value = b"Basic " + encoded_credentials
|
||||
|
||||
# Validate the header's value
|
||||
self.assertIn(expected_header_value, proxy_auth_header_values)
|
||||
else:
|
||||
# Check that the Proxy-Authorization header has not been supplied to the proxy
|
||||
self.assertIsNone(proxy_auth_header_values)
|
||||
|
||||
# tell the proxy server not to close the connection
|
||||
proxy_server.persistent = True
|
||||
|
||||
@ -312,6 +345,13 @@ class MatrixFederationAgentTests(TestCase):
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/abc")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
|
||||
# Check that the destination server DID NOT receive proxy credentials
|
||||
proxy_auth_header_values = request.requestHeaders.getRawHeaders(
|
||||
b"Proxy-Authorization"
|
||||
)
|
||||
self.assertIsNone(proxy_auth_header_values)
|
||||
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user