Use getClientAddress instead of getClientIP. (#12599)

getClientIP was deprecated in Twisted 18.4.0, which also added
getClientAddress. The Synapse minimum version for Twisted is
currently 18.9.0, so all supported versions have the new API.
This commit is contained in:
Patrick Cloke 2022-05-04 14:11:21 -04:00 committed by GitHub
parent 116a4c8340
commit 7fbf42499d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 62 additions and 46 deletions

View file

@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
request.getClientAddress.return_value.host = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.getClientAddress.return_value.host = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
f = self.get_failure(
@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_device = simple_async_mock({"hidden": False})
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_device = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))
@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))

View file

@ -204,7 +204,7 @@ def _mock_request():
mock = Mock(
spec=[
"finish",
"getClientIP",
"getClientAddress",
"getHeader",
"setHeader",
"setResponseCode",

View file

@ -1300,7 +1300,7 @@ def _build_callback_request(
"getCookie",
"cookies",
"requestHeaders",
"getClientIP",
"getClientAddress",
"getHeader",
]
)
@ -1310,5 +1310,5 @@ def _build_callback_request(
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.getClientAddress.return_value.host = ip_address
return request

View file

@ -352,7 +352,7 @@ def _mock_request():
mock = Mock(
spec=[
"finish",
"getClientIP",
"getClientAddress",
"getHeader",
"setHeader",
"setResponseCode",

View file

@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(port, 8765)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
# Set up the server side protocol
channel = self.site.buildProtocol(None)
server_address = IPv4Address("TCP", host, port)
channel = self.site.buildProtocol((host, port))
# hook into the channel's request factory so that we can keep a record
# of the requests
@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
channel, self.reactor, client_protocol, server_address, client_address
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
client_protocol, self.reactor, channel, client_address, server_address
)
channel.makeConnection(server_to_client_transport)
@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(port, repl_port)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
# Set up the server side protocol
channel = self._hs_to_site[hs].buildProtocol(None)
server_address = IPv4Address("TCP", host, port)
channel = self._hs_to_site[hs].buildProtocol((host, port))
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
channel, self.reactor, client_protocol, server_address, client_address
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
client_protocol, self.reactor, channel, client_address, server_address
)
channel.makeConnection(server_to_client_transport)

View file

@ -181,7 +181,7 @@ class FakeChannel:
self.resource_usage = _self.logcontext.get_resource_usage()
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423)
@ -562,7 +562,10 @@ class FakeTransport:
"""
_peer_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returend by getPeer"""
"""The value to be returned by getPeer"""
_host_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returned by getHost"""
disconnecting = False
disconnected = False
@ -571,11 +574,11 @@ class FakeTransport:
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)
def getPeer(self):
def getPeer(self) -> Optional[IAddress]:
return self._peer_address
def getHost(self):
return None
def getHost(self) -> Optional[IAddress]:
return self._host_address
def loseConnection(self, reason=None):
if not self.disconnecting: