Fix client IPs being broken on Python 3 (#3908)

This commit is contained in:
Amber Brown 2018-09-20 20:14:34 +10:00 committed by GitHub
parent 3fd68d533b
commit 1f3f5fcf52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 238 additions and 58 deletions

View File

@ -31,6 +31,11 @@ matrix:
- python: 3.6 - python: 3.6
env: TOX_ENV=py36 env: TOX_ENV=py36
- python: 3.6
env: TOX_ENV=py36-postgres TRIAL_FLAGS="-j 4"
services:
- postgresql
- python: 3.6 - python: 3.6
env: TOX_ENV=check_isort env: TOX_ENV=check_isort

1
changelog.d/3908.bugfix Normal file
View File

@ -0,0 +1 @@
Fix adding client IPs to the database failing on Python 3.

View File

@ -308,7 +308,7 @@ class XForwardedForRequest(SynapseRequest):
C{b"-"}. C{b"-"}.
""" """
return self.requestHeaders.getRawHeaders( return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip() b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii')
class SynapseRequestFactory(object): class SynapseRequestFactory(object):

View File

@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
for entry in iteritems(to_update): for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn( try:
txn, self._simple_upsert_txn(
table="user_ips", txn,
keyvalues={ table="user_ips",
"user_id": user_id, keyvalues={
"access_token": access_token, "user_id": user_id,
"ip": ip, "access_token": access_token,
"user_agent": user_agent, "ip": ip,
"device_id": device_id, "user_agent": user_agent,
}, "device_id": device_id,
values={ },
"last_seen": last_seen, values={
}, "last_seen": last_seen,
lock=False, },
) lock=False,
)
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_last_client_ip_by_device(self, user_id, device_id): def get_last_client_ip_by_device(self, user_id, device_id):

View File

@ -98,7 +98,7 @@ class FakeSite:
return FakeLogger() return FakeLogger()
def make_request(method, path, content=b"", access_token=None): def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
""" """
Make a web request using the given method and path, feed it the Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
@ -120,14 +120,16 @@ def make_request(method, path, content=b"", access_token=None):
site = FakeSite() site = FakeSite()
channel = FakeChannel() channel = FakeChannel()
req = SynapseRequest(site, channel) req = request(site, channel)
req.process = lambda: b"" req.process = lambda: b""
req.content = BytesIO(content) req.content = BytesIO(content)
if access_token: if access_token:
req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1") if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
req.requestReceived(method, path, b"1.1") req.requestReceived(method, path, b"1.1")
return req, channel return req, channel

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,35 +13,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import hmac
import json
from mock import Mock from mock import Mock
from twisted.internet import defer from twisted.internet import defer
import tests.unittest from synapse.http.site import XForwardedForRequest
import tests.utils from synapse.rest.client.v1 import admin, login
from tests import unittest
class ClientIpStoreTestCase(tests.unittest.TestCase): class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def __init__(self, *args, **kwargs): def make_homeserver(self, reactor, clock):
super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) hs = self.setup_test_homeserver()
self.store = None # type: synapse.storage.DataStore return hs
self.clock = None # type: tests.utils.MockClock
@defer.inlineCallbacks def prepare(self, hs, reactor, clock):
def setUp(self):
self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
@defer.inlineCallbacks
def test_insert_new_client_ip(self): def test_insert_new_client_ip(self):
self.clock.now = 12345678 self.reactor.advance(12345678)
user_id = "@user:id" user_id = "@user:id"
yield self.store.insert_client_ip( self.get_success(
user_id, "access_token", "ip", "user_agent", "device_id" self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
) )
result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") # Trigger the storage loop
self.reactor.advance(10)
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, "device_id")
)
r = result[(user_id, "device_id")] r = result[(user_id, "device_id")]
self.assertDictContainsSubset( self.assertDictContainsSubset(
@ -55,18 +66,18 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
r, r,
) )
@defer.inlineCallbacks
def test_disabled_monthly_active_user(self): def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
user_id = "@user:server" user_id = "@user:server"
yield self.store.insert_client_ip( self.get_success(
user_id, "access_token", "ip", "user_agent", "device_id" self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
) )
active = yield self.store.user_last_seen_monthly_active(user_id) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) self.assertFalse(active)
@defer.inlineCallbacks
def test_adding_monthly_active_user_when_full(self): def test_adding_monthly_active_user_when_full(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
@ -76,38 +87,159 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
) )
yield self.store.insert_client_ip( self.get_success(
user_id, "access_token", "ip", "user_agent", "device_id" self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
) )
active = yield self.store.user_last_seen_monthly_active(user_id) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) self.assertFalse(active)
@defer.inlineCallbacks
def test_adding_monthly_active_user_when_space(self): def test_adding_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
user_id = "@user:server" user_id = "@user:server"
active = yield self.store.user_last_seen_monthly_active(user_id) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) self.assertFalse(active)
yield self.store.insert_client_ip( # Trigger the saving loop
user_id, "access_token", "ip", "user_agent", "device_id" self.reactor.advance(10)
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
) )
active = yield self.store.user_last_seen_monthly_active(user_id) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active) self.assertTrue(active)
@defer.inlineCallbacks
def test_updating_monthly_active_user_when_space(self): def test_updating_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
user_id = "@user:server" user_id = "@user:server"
yield self.store.register(user_id=user_id, token="123", password_hash=None) self.get_success(
self.store.register(user_id=user_id, token="123", password_hash=None)
)
active = yield self.store.user_last_seen_monthly_active(user_id) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) self.assertFalse(active)
yield self.store.insert_client_ip( # Trigger the saving loop
user_id, "access_token", "ip", "user_agent", "device_id" self.reactor.advance(10)
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
)
) )
active = yield self.store.user_last_seen_monthly_active(user_id) active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active) self.assertTrue(active)
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets, login.register_servlets]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
return hs
def prepare(self, hs, reactor, clock):
self.hs.config.registration_shared_secret = u"shared"
self.store = self.hs.get_datastore()
# Create the user
request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
)
request, channel = self.make_request(
"POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
)
self.render(request)
self.assertEqual(channel.code, 200)
self.user_id = channel.json_body["user_id"]
def test_request_with_xforwarded(self):
"""
The IP in X-Forwarded-For is entered into the client IPs table.
"""
self._runtest(
{b"X-Forwarded-For": b"127.9.0.1"},
"127.9.0.1",
{"request": XForwardedForRequest},
)
def test_request_from_getPeer(self):
"""
The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header.
"""
self._runtest({}, "127.0.0.1", {})
def _runtest(self, headers, expected_ip, make_request_args):
device_id = "bleb"
body = json.dumps(
{
"type": "m.login.password",
"user": "bob",
"password": "abc123",
"device_id": device_id,
}
)
request, channel = self.make_request(
"POST", "/_matrix/client/r0/login", body.encode('utf8'), **make_request_args
)
self.render(request)
self.assertEqual(channel.code, 200)
access_token = channel.json_body["access_token"].encode('ascii')
# Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds())
request, channel = self.make_request(
"GET",
"/_matrix/client/r0/admin/users/" + self.user_id,
body.encode('utf8'),
access_token=access_token,
**make_request_args
)
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
# Add the optional headers
for h, v in headers.items():
request.requestHeaders.addRawHeader(h, v)
self.render(request)
# Advance so the save loop occurs
self.reactor.advance(100)
result = self.get_success(
self.store.get_last_client_ip_by_device(self.user_id, device_id)
)
r = result[(self.user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": self.user_id,
"device_id": device_id,
"ip": expected_ip,
"user_agent": "Mozzila pizza",
"last_seen": 123456100,
},
r,
)

View File

@ -26,6 +26,7 @@ from twisted.internet.defer import Deferred
from twisted.trial import unittest from twisted.trial import unittest
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.logcontext import LoggingContextFilter from synapse.util.logcontext import LoggingContextFilter
@ -237,7 +238,9 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses. Function to optionally be overridden in subclasses.
""" """
def make_request(self, method, path, content=b""): def make_request(
self, method, path, content=b"", access_token=None, request=SynapseRequest
):
""" """
Create a SynapseRequest at the path using the method and containing the Create a SynapseRequest at the path using the method and containing the
given content. given content.
@ -255,7 +258,7 @@ class HomeserverTestCase(TestCase):
if isinstance(content, dict): if isinstance(content, dict):
content = json.dumps(content).encode('utf8') content = json.dumps(content).encode('utf8')
return make_request(method, path, content) return make_request(method, path, content, access_token, request)
def render(self, request): def render(self, request):
""" """

View File

@ -16,7 +16,9 @@
import atexit import atexit
import hashlib import hashlib
import os import os
import time
import uuid import uuid
import warnings
from inspect import getcallargs from inspect import getcallargs
from mock import Mock, patch from mock import Mock, patch
@ -237,20 +239,41 @@ def setup_test_homeserver(
else: else:
# We need to do cleanup on PostgreSQL # We need to do cleanup on PostgreSQL
def cleanup(): def cleanup():
import psycopg2
# Close all the db pools # Close all the db pools
hs.get_db_pool().close() hs.get_db_pool().close()
dropped = False
# Drop the test database # Drop the test database
db_conn = db_engine.module.connect( db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB, user=POSTGRES_USER database=POSTGRES_BASE_DB, user=POSTGRES_USER
) )
db_conn.autocommit = True db_conn.autocommit = True
cur = db_conn.cursor() cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit() # Try a few times to drop the DB. Some things may hold on to the
# database for a few more seconds due to flakiness, preventing
# us from dropping it when the test is over. If we can't drop
# it, warn and move on.
for x in range(5):
try:
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit()
dropped = True
except psycopg2.OperationalError as e:
warnings.warn(
"Couldn't drop old db: " + str(e), category=UserWarning
)
time.sleep(0.5)
cur.close() cur.close()
db_conn.close() db_conn.close()
if not dropped:
warnings.warn("Failed to drop old DB.", category=UserWarning)
if not LEAVE_DB: if not LEAVE_DB:
# Register the cleanup hook # Register the cleanup hook
cleanup_func(cleanup) cleanup_func(cleanup)

10
tox.ini
View File

@ -70,6 +70,16 @@ usedevelop=true
[testenv:py36] [testenv:py36]
usedevelop=true usedevelop=true
[testenv:py36-postgres]
usedevelop=true
deps =
{[base]deps}
psycopg2
setenv =
{[base]setenv}
SYNAPSE_POSTGRES = 1
[testenv:packaging] [testenv:packaging]
deps = deps =
check-manifest check-manifest