mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Use getClientAddress
instead of getClientIP
. (#12599)
getClientIP was deprecated in Twisted 18.4.0, which also added getClientAddress. The Synapse minimum version for Twisted is currently 18.9.0, so all supported versions have the new API.
This commit is contained in:
parent
116a4c8340
commit
7fbf42499d
1
changelog.d/12599.misc
Normal file
1
changelog.d/12599.misc
Normal file
@ -0,0 +1 @@
|
||||
Use `getClientAddress` instead of the deprecated `getClientIP`.
|
@ -187,7 +187,7 @@ class Auth:
|
||||
Once get_user_by_req has set up the opentracing span, this does the actual work.
|
||||
"""
|
||||
try:
|
||||
ip_addr = request.getClientIP()
|
||||
ip_addr = request.getClientAddress().host
|
||||
user_agent = get_request_user_agent(request)
|
||||
|
||||
access_token = self.get_access_token_from_request(request)
|
||||
@ -356,7 +356,7 @@ class Auth:
|
||||
return None, None, None
|
||||
|
||||
if app_service.ip_range_whitelist:
|
||||
ip_address = IPAddress(request.getClientIP())
|
||||
ip_address = IPAddress(request.getClientAddress().host)
|
||||
if ip_address not in app_service.ip_range_whitelist:
|
||||
return None, None, None
|
||||
|
||||
|
@ -551,7 +551,7 @@ class AuthHandler:
|
||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||
|
||||
user_agent = get_request_user_agent(request)
|
||||
clientip = request.getClientIP()
|
||||
clientip = request.getClientAddress().host
|
||||
|
||||
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||
session.session_id, user_agent, clientip
|
||||
|
@ -92,7 +92,7 @@ class IdentityHandler:
|
||||
"""
|
||||
|
||||
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||
None, (medium, request.getClientIP())
|
||||
None, (medium, request.getClientAddress().host)
|
||||
)
|
||||
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||
None, (medium, address)
|
||||
|
@ -468,7 +468,7 @@ class SsoHandler:
|
||||
auth_provider_id,
|
||||
remote_user_id,
|
||||
get_request_user_agent(request),
|
||||
request.getClientIP(),
|
||||
request.getClientAddress().host,
|
||||
)
|
||||
new_user = True
|
||||
elif self._sso_update_profile_information:
|
||||
@ -928,7 +928,7 @@ class SsoHandler:
|
||||
session.auth_provider_id,
|
||||
session.remote_user_id,
|
||||
get_request_user_agent(request),
|
||||
request.getClientIP(),
|
||||
request.getClientAddress().host,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
@ -238,7 +238,7 @@ class SynapseRequest(Request):
|
||||
request_id,
|
||||
request=ContextRequest(
|
||||
request_id=request_id,
|
||||
ip_address=self.getClientIP(),
|
||||
ip_address=self.getClientAddress().host,
|
||||
site_tag=self.synapse_site.site_tag,
|
||||
# The requester is going to be unknown at this point.
|
||||
requester=None,
|
||||
@ -381,7 +381,7 @@ class SynapseRequest(Request):
|
||||
|
||||
self.synapse_site.access_logger.debug(
|
||||
"%s - %s - Received request: %s %s",
|
||||
self.getClientIP(),
|
||||
self.getClientAddress().host,
|
||||
self.synapse_site.site_tag,
|
||||
self.get_method(),
|
||||
self.get_redacted_uri(),
|
||||
@ -429,7 +429,7 @@ class SynapseRequest(Request):
|
||||
"%s - %s - {%s}"
|
||||
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
||||
self.getClientIP(),
|
||||
self.getClientAddress().host,
|
||||
self.synapse_site.site_tag,
|
||||
requester,
|
||||
processing_time,
|
||||
|
@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||
tags.HTTP_METHOD: request.get_method(),
|
||||
tags.HTTP_URL: request.get_redacted_uri(),
|
||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
||||
tags.PEER_HOST_IPV6: request.getClientAddress().host,
|
||||
}
|
||||
|
||||
request_name = request.request_metrics.name
|
||||
|
@ -112,7 +112,7 @@ class AuthRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.RECAPTCHA, authdict, request.getClientIP()
|
||||
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
@ -132,7 +132,7 @@ class AuthRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.TERMS, authdict, request.getClientIP()
|
||||
LoginType.TERMS, authdict, request.getClientAddress().host
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
@ -161,7 +161,9 @@ class AuthRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
|
||||
LoginType.REGISTRATION_TOKEN,
|
||||
authdict,
|
||||
request.getClientAddress().host,
|
||||
)
|
||||
except LoginError as e:
|
||||
html = self.registration_token_template.render(
|
||||
|
@ -176,7 +176,7 @@ class LoginRestServlet(RestServlet):
|
||||
|
||||
if appservice.is_rate_limited():
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientIP()
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
|
||||
result = await self._do_appservice_login(
|
||||
@ -188,19 +188,25 @@ class LoginRestServlet(RestServlet):
|
||||
self.jwt_enabled
|
||||
and login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
):
|
||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
result = await self._do_jwt_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
result = await self._do_token_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
)
|
||||
else:
|
||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
result = await self._do_other_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
|
@ -352,7 +352,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||
if self.inhibit_user_in_use_error:
|
||||
return 200, {"available": True}
|
||||
|
||||
ip = request.getClientIP()
|
||||
ip = request.getClientAddress().host
|
||||
with self.ratelimiter.ratelimit(ip) as wait_deferred:
|
||||
await wait_deferred
|
||||
|
||||
@ -394,7 +394,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
|
||||
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
|
||||
|
||||
if not self.hs.config.registration.enable_registration:
|
||||
raise SynapseError(
|
||||
@ -441,7 +441,7 @@ class RegisterRestServlet(RestServlet):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
client_addr = request.getClientIP()
|
||||
client_addr = request.getClientAddress().host
|
||||
|
||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||
|
||||
|
@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "192.168.10.10"
|
||||
request.getClientAddress.return_value.host = "192.168.10.10"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "131.111.8.42"
|
||||
request.getClientAddress.return_value.host = "131.111.8.42"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
f = self.get_failure(
|
||||
@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_device = simple_async_mock({"hidden": False})
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||
@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_device = simple_async_mock(None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||
@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.store.insert_client_ip = simple_async_mock(None)
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_success(self.auth.get_user_by_req(request))
|
||||
@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.store.insert_client_ip = simple_async_mock(None)
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_success(self.auth.get_user_by_req(request))
|
||||
|
@ -204,7 +204,7 @@ def _mock_request():
|
||||
mock = Mock(
|
||||
spec=[
|
||||
"finish",
|
||||
"getClientIP",
|
||||
"getClientAddress",
|
||||
"getHeader",
|
||||
"setHeader",
|
||||
"setResponseCode",
|
||||
|
@ -1300,7 +1300,7 @@ def _build_callback_request(
|
||||
"getCookie",
|
||||
"cookies",
|
||||
"requestHeaders",
|
||||
"getClientIP",
|
||||
"getClientAddress",
|
||||
"getHeader",
|
||||
]
|
||||
)
|
||||
@ -1310,5 +1310,5 @@ def _build_callback_request(
|
||||
request.args = {}
|
||||
request.args[b"code"] = [code.encode("utf-8")]
|
||||
request.args[b"state"] = [state.encode("utf-8")]
|
||||
request.getClientIP.return_value = ip_address
|
||||
request.getClientAddress.return_value.host = ip_address
|
||||
return request
|
||||
|
@ -352,7 +352,7 @@ def _mock_request():
|
||||
mock = Mock(
|
||||
spec=[
|
||||
"finish",
|
||||
"getClientIP",
|
||||
"getClientAddress",
|
||||
"getHeader",
|
||||
"setHeader",
|
||||
"setResponseCode",
|
||||
|
@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(port, 8765)
|
||||
|
||||
# Set up client side protocol
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
|
||||
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = self.site.buildProtocol(None)
|
||||
server_address = IPv4Address("TCP", host, port)
|
||||
channel = self.site.buildProtocol((host, port))
|
||||
|
||||
# hook into the channel's request factory so that we can keep a record
|
||||
# of the requests
|
||||
@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
channel, self.reactor, client_protocol
|
||||
channel, self.reactor, client_protocol, server_address, client_address
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, channel
|
||||
client_protocol, self.reactor, channel, client_address, server_address
|
||||
)
|
||||
channel.makeConnection(server_to_client_transport)
|
||||
|
||||
@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(port, repl_port)
|
||||
|
||||
# Set up client side protocol
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
|
||||
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = self._hs_to_site[hs].buildProtocol(None)
|
||||
server_address = IPv4Address("TCP", host, port)
|
||||
channel = self._hs_to_site[hs].buildProtocol((host, port))
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
channel, self.reactor, client_protocol
|
||||
channel, self.reactor, client_protocol, server_address, client_address
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, channel
|
||||
client_protocol, self.reactor, channel, client_address, server_address
|
||||
)
|
||||
channel.makeConnection(server_to_client_transport)
|
||||
|
||||
|
@ -181,7 +181,7 @@ class FakeChannel:
|
||||
self.resource_usage = _self.logcontext.get_resource_usage()
|
||||
|
||||
def getPeer(self):
|
||||
# We give an address so that getClientIP returns a non null entry,
|
||||
# We give an address so that getClientAddress/getClientIP returns a non null entry,
|
||||
# causing us to record the MAU
|
||||
return address.IPv4Address("TCP", self._ip, 3423)
|
||||
|
||||
@ -562,7 +562,10 @@ class FakeTransport:
|
||||
"""
|
||||
|
||||
_peer_address: Optional[IAddress] = attr.ib(default=None)
|
||||
"""The value to be returend by getPeer"""
|
||||
"""The value to be returned by getPeer"""
|
||||
|
||||
_host_address: Optional[IAddress] = attr.ib(default=None)
|
||||
"""The value to be returned by getHost"""
|
||||
|
||||
disconnecting = False
|
||||
disconnected = False
|
||||
@ -571,11 +574,11 @@ class FakeTransport:
|
||||
producer = attr.ib(default=None)
|
||||
autoflush = attr.ib(default=True)
|
||||
|
||||
def getPeer(self):
|
||||
def getPeer(self) -> Optional[IAddress]:
|
||||
return self._peer_address
|
||||
|
||||
def getHost(self):
|
||||
return None
|
||||
def getHost(self) -> Optional[IAddress]:
|
||||
return self._host_address
|
||||
|
||||
def loseConnection(self, reason=None):
|
||||
if not self.disconnecting:
|
||||
|
Loading…
Reference in New Issue
Block a user