Merge branch 'develop' of github.com:matrix-org/synapse into erikj/sqlite_native_upsert

This commit is contained in:
Erik Johnston 2019-01-25 14:11:17 +00:00
commit 431e485914
13 changed files with 120 additions and 7 deletions

1
changelog.d/4415.feature Normal file
View File

@ -0,0 +1 @@
Search now includes results from predecessor rooms after a room upgrade.

1
changelog.d/4466.misc Normal file
View File

@ -0,0 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

1
changelog.d/4468.misc Normal file
View File

@ -0,0 +1 @@
Move SRV logic into the Agent layer

1
changelog.d/4476.misc Normal file
View File

@ -0,0 +1 @@
Fix quoting for allowed_local_3pids example config

View File

@ -444,6 +444,20 @@ class Filter(object):
def include_redundant_members(self): def include_redundant_members(self):
return self.filter_json.get("include_redundant_members", False) return self.filter_json.get("include_redundant_members", False)
def with_room_ids(self, room_ids):
"""Returns a new filter with the given room IDs appended.
Args:
room_ids (iterable[unicode]): The room_ids to add
Returns:
filter: A new filter including the given rooms and the old
filter's rooms.
"""
newFilter = Filter(self.filter_json)
newFilter.rooms += room_ids
return newFilter
def _matches_wildcard(actual_value, filter_value): def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"): if filter_value.endswith("*"):

View File

@ -84,11 +84,11 @@ class RegistrationConfig(Config):
# #
# allowed_local_3pids: # allowed_local_3pids:
# - medium: email # - medium: email
# pattern: ".*@matrix\\.org" # pattern: '.*@matrix\\.org'
# - medium: email # - medium: email
# pattern: ".*@vector\\.im" # pattern: '.*@vector\\.im'
# - medium: msisdn # - medium: msisdn
# pattern: "\\+44" # pattern: '\\+44'
# If set, allows registration by anyone who also has the shared # If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.

View File

@ -37,6 +37,41 @@ class SearchHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(SearchHandler, self).__init__(hs) super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
`predecessor` key. If it exists, we add the room ID to our return
list and then check that room for a m.room.create event and so on
until we can no longer find any more previous rooms.
The full list of all found rooms in then returned.
Args:
room_id (str): id of the room to search through.
Returns:
Deferred[iterable[unicode]]: predecessor room ids
"""
historical_room_ids = []
while True:
predecessor = yield self.store.get_room_predecessor(room_id)
# If no predecessor, assume we've hit a dead end
if not predecessor:
break
# Add predecessor's room ID
historical_room_ids.append(predecessor["room_id"])
# Scan through the old room for further predecessors
room_id = predecessor["room_id"]
defer.returnValue(historical_room_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def search(self, user, content, batch=None): def search(self, user, content, batch=None):
"""Performs a full text search for a user. """Performs a full text search for a user.
@ -137,6 +172,18 @@ class SearchHandler(BaseHandler):
) )
room_ids = set(r.room_id for r in rooms) room_ids = set(r.room_id for r in rooms)
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = yield self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids
# Prevent any historical events from being filtered
search_filter = search_filter.with_room_ids(historical_room_ids)
room_ids = search_filter.filter_rooms(room_ids) room_ids = search_filter.filter_rooms(room_ids)
if batch_group == "room_id": if batch_group == "room_id":

View File

@ -19,6 +19,7 @@ from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name from synapse.http.endpoint import parse_server_name
@ -109,6 +110,15 @@ class MatrixFederationAgent(object):
else: else:
target = pick_server_from_list(server_list) target = pick_server_from_list(server_list)
# make sure that the Host header is set correctly
if headers is None:
headers = Headers()
else:
headers = headers.copy()
if not headers.hasHeader(b'host'):
headers.addRawHeader(b'host', server_name_bytes)
class EndpointFactory(object): class EndpointFactory(object):
@staticmethod @staticmethod
def endpointForURI(_uri): def endpointForURI(_uri):

View File

@ -255,7 +255,6 @@ class MatrixFederationHttpClient(object):
headers_dict = { headers_dict = {
b"User-Agent": [self.version_string_bytes], b"User-Agent": [self.version_string_bytes],
b"Host": [destination_bytes],
} }
with limiter: with limiter:

View File

@ -15,7 +15,6 @@
import struct import struct
import threading import threading
from sqlite3 import sqlite_version_info
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
@ -40,7 +39,7 @@ class Sqlite3Engine(object):
# when its enabled. # when its enabled.
# FIXME: Figure out what is wrong so we can re-enable native upserts # FIXME: Figure out what is wrong so we can re-enable native upserts
# return sqlite_version_info >= (3, 24, 0) # return self.module.sqlite_version_info >= (3, 24, 0)
return False return False
def check_database(self, txn): def check_database(self, txn):

View File

@ -437,6 +437,30 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = yield self.get_event(create_id) create_event = yield self.get_event(create_id)
defer.returnValue(create_event.content.get("room_version", "1")) defer.returnValue(create_event.content.get("room_version", "1"))
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
"""Get the predecessor room of an upgraded room if one exists.
Otherwise return None.
Args:
room_id (str)
Returns:
Deferred[unicode|None]: predecessor room id
"""
state_ids = yield self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end
if not create_id:
defer.returnValue(None)
# Retrieve the room's create event
create_event = yield self.get_event(create_id)
# Return predecessor if present
defer.returnValue(create_event.content.get("predecessor", None))
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id): def get_current_state_ids(self, room_id):
"""Get the current state event ids for a room based on the """Get the current state event ids for a room based on the

View File

@ -131,6 +131,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv:8448']
)
content = request.content.read() content = request.content.read()
self.assertEqual(content, b'') self.assertEqual(content, b'')
@ -195,6 +199,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'1.2.3.4'],
)
# finish the request # finish the request
request.finish() request.finish()
@ -235,6 +243,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv'],
)
# finish the request # finish the request
request.finish() request.finish()
@ -276,6 +288,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv'],
)
# finish the request # finish the request
request.finish() request.finish()

View File

@ -49,7 +49,6 @@ class FederationClientTests(HomeserverTestCase):
return hs return hs
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
self.cl = MatrixFederationHttpClient(self.hs) self.cl = MatrixFederationHttpClient(self.hs)
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
@ -95,6 +94,7 @@ class FederationClientTests(HomeserverTestCase):
# that should have made it send the request to the transport # that should have made it send the request to the transport
self.assertRegex(transport.value(), b"^GET /foo/bar") self.assertRegex(transport.value(), b"^GET /foo/bar")
self.assertRegex(transport.value(), b"Host: testserv:8008")
# Deferred is still without a result # Deferred is still without a result
self.assertNoResult(test_d) self.assertNoResult(test_d)