mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-26 13:59:22 -05:00
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/split_purge_history
This commit is contained in:
commit
6a0092d371
1
.github/PULL_REQUEST_TEMPLATE.md
vendored
1
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -5,3 +5,4 @@
|
||||
* [ ] Pull request is based on the develop branch
|
||||
* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog)
|
||||
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off)
|
||||
* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style))
|
||||
|
@ -58,10 +58,29 @@ All Matrix projects have a well-defined code-style - and sometimes we've even
|
||||
got as far as documenting it... For instance, synapse's code style doc lives
|
||||
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md.
|
||||
|
||||
To facilitate meeting these criteria you can run ``scripts-dev/lint.sh``
|
||||
locally. Since this runs the tools listed in the above document, you'll need
|
||||
python 3.6 and to install each tool. **Note that the script does not just
|
||||
test/check, but also reformats code, so you may wish to ensure any new code is
|
||||
committed first**. By default this script checks all files and can take some
|
||||
time; if you alter only certain files, you might wish to specify paths as
|
||||
arguments to reduce the run-time.
|
||||
|
||||
Please ensure your changes match the cosmetic style of the existing project,
|
||||
and **never** mix cosmetic and functional changes in the same commit, as it
|
||||
makes it horribly hard to review otherwise.
|
||||
|
||||
Before doing a commit, ensure the changes you've made don't produce
|
||||
linting errors. You can do this by running the linters as follows. Ensure to
|
||||
commit any files that were corrected.
|
||||
|
||||
::
|
||||
# Install the dependencies
|
||||
pip install -U black flake8 isort
|
||||
|
||||
# Run the linter script
|
||||
./scripts-dev/lint.sh
|
||||
|
||||
Changelog
|
||||
~~~~~~~~~
|
||||
|
||||
|
1
changelog.d/5727.feature
Normal file
1
changelog.d/5727.feature
Normal file
@ -0,0 +1 @@
|
||||
Add federation support for cross-signing.
|
1
changelog.d/6164.doc
Normal file
1
changelog.d/6164.doc
Normal file
@ -0,0 +1 @@
|
||||
Contributor documentation now mentions script to run linters.
|
1
changelog.d/6232.bugfix
Normal file
1
changelog.d/6232.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Remove a room from a server's public rooms list on room upgrade.
|
1
changelog.d/6238.feature
Normal file
1
changelog.d/6238.feature
Normal file
@ -0,0 +1 @@
|
||||
Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
|
1
changelog.d/6254.bugfix
Normal file
1
changelog.d/6254.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Make notification of cross-signing signatures work with workers.
|
1
changelog.d/6298.misc
Normal file
1
changelog.d/6298.misc
Normal file
@ -0,0 +1 @@
|
||||
Refactor EventContext for clarity.
|
1
changelog.d/6301.feature
Normal file
1
changelog.d/6301.feature
Normal file
@ -0,0 +1 @@
|
||||
Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)).
|
1
changelog.d/6304.misc
Normal file
1
changelog.d/6304.misc
Normal file
@ -0,0 +1 @@
|
||||
Update the version of black used to 19.10b0.
|
1
changelog.d/6305.misc
Normal file
1
changelog.d/6305.misc
Normal file
@ -0,0 +1 @@
|
||||
Add some documentation about worker replication.
|
1
changelog.d/6306.bugfix
Normal file
1
changelog.d/6306.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash.
|
1
changelog.d/6312.misc
Normal file
1
changelog.d/6312.misc
Normal file
@ -0,0 +1 @@
|
||||
Document the use of `lint.sh` for code style enforcement & extend it to run on specified paths only.
|
1
changelog.d/6313.bugfix
Normal file
1
changelog.d/6313.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix the `hidden` field in the `devices` table for SQLite versions prior to 3.23.0.
|
1
changelog.d/6314.misc
Normal file
1
changelog.d/6314.misc
Normal file
@ -0,0 +1 @@
|
||||
Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated.
|
@ -78,7 +78,7 @@ class InputOutput(object):
|
||||
m = re.match("^join (\S+)$", line)
|
||||
if m:
|
||||
# The `sender` wants to join a room.
|
||||
room_name, = m.groups()
|
||||
(room_name,) = m.groups()
|
||||
self.print_line("%s joining %s" % (self.user, room_name))
|
||||
self.server.join_room(room_name, self.user, self.user)
|
||||
# self.print_line("OK.")
|
||||
@ -105,7 +105,7 @@ class InputOutput(object):
|
||||
m = re.match("^backfill (\S+)$", line)
|
||||
if m:
|
||||
# we want to backfill a room
|
||||
room_name, = m.groups()
|
||||
(room_name,) = m.groups()
|
||||
self.print_line("backfill %s" % room_name)
|
||||
self.server.backfill(room_name)
|
||||
return
|
||||
|
@ -199,7 +199,20 @@ client (C):
|
||||
|
||||
#### REPLICATE (C)
|
||||
|
||||
Asks the server to replicate a given stream
|
||||
Asks the server to replicate a given stream. The syntax is:
|
||||
|
||||
```
|
||||
REPLICATE <stream_name> <token>
|
||||
```
|
||||
|
||||
Where `<token>` may be either:
|
||||
* a numeric stream_id to stream updates since (exclusive)
|
||||
* `NOW` to stream all subsequent updates.
|
||||
|
||||
The `<stream_name>` is the name of a replication stream to subscribe
|
||||
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
|
||||
of streams). It can also be `ALL` to subscribe to all known streams,
|
||||
in which case the `<token>` must be set to `NOW`.
|
||||
|
||||
#### USER_SYNC (C)
|
||||
|
||||
|
5
mypy.ini
5
mypy.ini
@ -1,7 +1,10 @@
|
||||
[mypy]
|
||||
namespace_packages = True
|
||||
plugins = mypy_zope:plugin
|
||||
follow_imports=skip
|
||||
follow_imports = normal
|
||||
check_untyped_defs = True
|
||||
show_error_codes = True
|
||||
show_traceback = True
|
||||
mypy_path = stubs
|
||||
|
||||
[mypy-zope]
|
||||
|
@ -7,7 +7,15 @@
|
||||
|
||||
set -e
|
||||
|
||||
isort -y -rc synapse tests scripts-dev scripts
|
||||
flake8 synapse tests
|
||||
python3 -m black synapse tests scripts-dev scripts
|
||||
if [ $# -ge 1 ]
|
||||
then
|
||||
files=$*
|
||||
else
|
||||
files="synapse tests scripts-dev scripts"
|
||||
fi
|
||||
|
||||
echo "Linting these locations: $files"
|
||||
isort -y -rc $files
|
||||
flake8 $files
|
||||
python3 -m black $files
|
||||
./scripts-dev/config-lint.sh
|
||||
|
@ -138,3 +138,10 @@ class LimitBlockingTypes(object):
|
||||
|
||||
MONTHLY_ACTIVE_USER = "monthly_active_user"
|
||||
HS_DISABLED = "hs_disabled"
|
||||
|
||||
|
||||
class EventContentFields(object):
|
||||
"""Fields found in events' content, regardless of type."""
|
||||
|
||||
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||
LABELS = "org.matrix.labels"
|
||||
|
@ -20,6 +20,7 @@ from jsonschema import FormatChecker
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import RoomID, UserID
|
||||
@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = {
|
||||
"contains_url": {"type": "boolean"},
|
||||
"lazy_load_members": {"type": "boolean"},
|
||||
"include_redundant_members": {"type": "boolean"},
|
||||
# Include or exclude events with the provided labels.
|
||||
# cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
|
||||
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
|
||||
@ -259,6 +264,9 @@ class Filter(object):
|
||||
|
||||
self.contains_url = self.filter_json.get("contains_url", None)
|
||||
|
||||
self.labels = self.filter_json.get("org.matrix.labels", None)
|
||||
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
|
||||
|
||||
def filters_all_types(self):
|
||||
return "*" in self.not_types
|
||||
|
||||
@ -282,6 +290,7 @@ class Filter(object):
|
||||
room_id = None
|
||||
ev_type = "m.presence"
|
||||
contains_url = False
|
||||
labels = []
|
||||
else:
|
||||
sender = event.get("sender", None)
|
||||
if not sender:
|
||||
@ -300,10 +309,11 @@ class Filter(object):
|
||||
content = event.get("content", {})
|
||||
# check if there is a string url field in the content for filtering purposes
|
||||
contains_url = isinstance(content.get("url"), text_type)
|
||||
labels = content.get(EventContentFields.LABELS, [])
|
||||
|
||||
return self.check_fields(room_id, sender, ev_type, contains_url)
|
||||
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
|
||||
|
||||
def check_fields(self, room_id, sender, event_type, contains_url):
|
||||
def check_fields(self, room_id, sender, event_type, labels, contains_url):
|
||||
"""Checks whether the filter matches the given event fields.
|
||||
|
||||
Returns:
|
||||
@ -313,6 +323,7 @@ class Filter(object):
|
||||
"rooms": lambda v: room_id == v,
|
||||
"senders": lambda v: sender == v,
|
||||
"types": lambda v: _matches_wildcard(event_type, v),
|
||||
"labels": lambda v: v in labels,
|
||||
}
|
||||
|
||||
for name, match_func in literal_keys.items():
|
||||
|
@ -565,7 +565,7 @@ def run(hs):
|
||||
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
|
||||
)
|
||||
try:
|
||||
yield hs.get_simple_http_client().put_json(
|
||||
yield hs.get_proxied_http_client().put_json(
|
||||
hs.config.report_stats_endpoint, stats
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -94,7 +94,9 @@ class ApplicationService(object):
|
||||
ip_range_whitelist=None,
|
||||
):
|
||||
self.token = token
|
||||
self.url = url
|
||||
self.url = (
|
||||
url.rstrip("/") if isinstance(url, str) else None
|
||||
) # url must not end with a slash
|
||||
self.hs_token = hs_token
|
||||
self.sender = sender
|
||||
self.server_name = hostname
|
||||
|
@ -234,8 +234,8 @@ def setup_logging(
|
||||
|
||||
# make sure that the first thing we log is a thing we can grep backwards
|
||||
# for
|
||||
logging.warn("***** STARTING SERVER *****")
|
||||
logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
|
||||
logging.warning("***** STARTING SERVER *****")
|
||||
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
|
||||
logging.info("Server hostname: %s", config.server_name)
|
||||
|
||||
return logger
|
||||
|
@ -37,9 +37,6 @@ class EventContext:
|
||||
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
||||
(type, state_key) -> event_id. ``None`` for an outlier.
|
||||
|
||||
prev_state_events (?): XXX: is this ever set to anything other than
|
||||
the empty list?
|
||||
|
||||
app_service: FIXME
|
||||
|
||||
_current_state_ids (dict[(str, str), str]|None):
|
||||
@ -51,36 +48,16 @@ class EventContext:
|
||||
The current state map excluding the current event. None if outlier
|
||||
or we haven't fetched the state from DB yet.
|
||||
(type, state_key) -> event_id
|
||||
|
||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||
been calculated. None if we haven't started calculating yet
|
||||
|
||||
_event_type (str): The type of the event the context is associated with.
|
||||
Only set when state has not been fetched yet.
|
||||
|
||||
_event_state_key (str|None): The state_key of the event the context is
|
||||
associated with. Only set when state has not been fetched yet.
|
||||
|
||||
_prev_state_id (str|None): If the event associated with the context is
|
||||
a state event, then `_prev_state_id` is the event_id of the state
|
||||
that was replaced.
|
||||
Only set when state has not been fetched yet.
|
||||
"""
|
||||
|
||||
state_group = attr.ib(default=None)
|
||||
rejected = attr.ib(default=False)
|
||||
prev_group = attr.ib(default=None)
|
||||
delta_ids = attr.ib(default=None)
|
||||
prev_state_events = attr.ib(default=attr.Factory(list))
|
||||
app_service = attr.ib(default=None)
|
||||
|
||||
_current_state_ids = attr.ib(default=None)
|
||||
_prev_state_ids = attr.ib(default=None)
|
||||
_prev_state_id = attr.ib(default=None)
|
||||
|
||||
_event_type = attr.ib(default=None)
|
||||
_event_state_key = attr.ib(default=None)
|
||||
_fetching_state_deferred = attr.ib(default=None)
|
||||
_current_state_ids = attr.ib(default=None)
|
||||
|
||||
@staticmethod
|
||||
def with_state(
|
||||
@ -90,7 +67,6 @@ class EventContext:
|
||||
current_state_ids=current_state_ids,
|
||||
prev_state_ids=prev_state_ids,
|
||||
state_group=state_group,
|
||||
fetching_state_deferred=defer.succeed(None),
|
||||
prev_group=prev_group,
|
||||
delta_ids=delta_ids,
|
||||
)
|
||||
@ -125,7 +101,6 @@ class EventContext:
|
||||
"rejected": self.rejected,
|
||||
"prev_group": self.prev_group,
|
||||
"delta_ids": _encode_state_dict(self.delta_ids),
|
||||
"prev_state_events": self.prev_state_events,
|
||||
"app_service_id": self.app_service.id if self.app_service else None,
|
||||
}
|
||||
|
||||
@ -141,7 +116,7 @@ class EventContext:
|
||||
Returns:
|
||||
EventContext
|
||||
"""
|
||||
context = EventContext(
|
||||
context = _AsyncEventContextImpl(
|
||||
# We use the state_group and prev_state_id stuff to pull the
|
||||
# current_state_ids out of the DB and construct prev_state_ids.
|
||||
prev_state_id=input["prev_state_id"],
|
||||
@ -151,7 +126,6 @@ class EventContext:
|
||||
prev_group=input["prev_group"],
|
||||
delta_ids=_decode_state_dict(input["delta_ids"]),
|
||||
rejected=input["rejected"],
|
||||
prev_state_events=input["prev_state_events"],
|
||||
)
|
||||
|
||||
app_service_id = input["app_service_id"]
|
||||
@ -170,14 +144,7 @@ class EventContext:
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(
|
||||
self._fill_out_state, store
|
||||
)
|
||||
|
||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
yield self._ensure_fetched(store)
|
||||
return self._current_state_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -190,14 +157,7 @@ class EventContext:
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(
|
||||
self._fill_out_state, store
|
||||
)
|
||||
|
||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
yield self._ensure_fetched(store)
|
||||
return self._prev_state_ids
|
||||
|
||||
def get_cached_current_state_ids(self):
|
||||
@ -211,6 +171,44 @@ class EventContext:
|
||||
|
||||
return self._current_state_ids
|
||||
|
||||
def _ensure_fetched(self, store):
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _AsyncEventContextImpl(EventContext):
|
||||
"""
|
||||
An implementation of EventContext which fetches _current_state_ids and
|
||||
_prev_state_ids from the database on demand.
|
||||
|
||||
Attributes:
|
||||
|
||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||
been calculated. None if we haven't started calculating yet
|
||||
|
||||
_event_type (str): The type of the event the context is associated with.
|
||||
|
||||
_event_state_key (str): The state_key of the event the context is
|
||||
associated with.
|
||||
|
||||
_prev_state_id (str|None): If the event associated with the context is
|
||||
a state event, then `_prev_state_id` is the event_id of the state
|
||||
that was replaced.
|
||||
"""
|
||||
|
||||
_prev_state_id = attr.ib(default=None)
|
||||
_event_type = attr.ib(default=None)
|
||||
_event_state_key = attr.ib(default=None)
|
||||
_fetching_state_deferred = attr.ib(default=None)
|
||||
|
||||
def _ensure_fetched(self, store):
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(
|
||||
self._fill_out_state, store
|
||||
)
|
||||
|
||||
return make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fill_out_state(self, store):
|
||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||
@ -228,27 +226,6 @@ class EventContext:
|
||||
else:
|
||||
self._prev_state_ids = self._current_state_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_state(
|
||||
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
|
||||
):
|
||||
"""Replace the state in the context
|
||||
"""
|
||||
|
||||
# We need to make sure we wait for any ongoing fetching of state
|
||||
# to complete so that the updated state doesn't get clobbered
|
||||
if self._fetching_state_deferred:
|
||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
self.state_group = state_group
|
||||
self._prev_state_ids = prev_state_ids
|
||||
self.prev_group = prev_group
|
||||
self._current_state_ids = current_state_ids
|
||||
self.delta_ids = delta_ids
|
||||
|
||||
# We need to ensure that that we've marked as having fetched the state
|
||||
self._fetching_state_deferred = defer.succeed(None)
|
||||
|
||||
|
||||
def _encode_state_dict(state_dict):
|
||||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
||||
|
@ -555,7 +555,7 @@ class FederationClient(FederationBase):
|
||||
Note that this does not append any events to any graphs.
|
||||
|
||||
Args:
|
||||
destinations (str): Candidate homeservers which are probably
|
||||
destinations (Iterable[str]): Candidate homeservers which are probably
|
||||
participating in the room.
|
||||
room_id (str): The room in which the event will happen.
|
||||
user_id (str): The user whose membership is being evented.
|
||||
|
@ -192,15 +192,16 @@ class PerDestinationQueue(object):
|
||||
# We have to keep 2 free slots for presence and rr_edus
|
||||
limit = MAX_EDUS_PER_TRANSACTION - 2
|
||||
|
||||
device_update_edus, dev_list_id = (
|
||||
yield self._get_device_update_edus(limit)
|
||||
device_update_edus, dev_list_id = yield self._get_device_update_edus(
|
||||
limit
|
||||
)
|
||||
|
||||
limit -= len(device_update_edus)
|
||||
|
||||
to_device_edus, device_stream_id = (
|
||||
yield self._get_to_device_message_edus(limit)
|
||||
)
|
||||
(
|
||||
to_device_edus,
|
||||
device_stream_id,
|
||||
) = yield self._get_to_device_message_edus(limit)
|
||||
|
||||
pending_edus = device_update_edus + to_device_edus
|
||||
|
||||
@ -359,20 +360,20 @@ class PerDestinationQueue(object):
|
||||
last_device_list = self._last_device_list_stream_id
|
||||
|
||||
# Retrieve list of new device updates to send to the destination
|
||||
now_stream_id, results = yield self._store.get_devices_by_remote(
|
||||
now_stream_id, results = yield self._store.get_device_updates_by_remote(
|
||||
self._destination, last_device_list, limit=limit
|
||||
)
|
||||
edus = [
|
||||
Edu(
|
||||
origin=self._server_name,
|
||||
destination=self._destination,
|
||||
edu_type="m.device_list_update",
|
||||
edu_type=edu_type,
|
||||
content=content,
|
||||
)
|
||||
for content in results
|
||||
for (edu_type, content) in results
|
||||
]
|
||||
|
||||
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
|
||||
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
|
||||
|
||||
return (edus, now_stream_id)
|
||||
|
||||
|
@ -38,9 +38,10 @@ class AccountDataEventSource(object):
|
||||
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
|
||||
)
|
||||
|
||||
account_data, room_account_data = (
|
||||
yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
||||
)
|
||||
(
|
||||
account_data,
|
||||
room_account_data,
|
||||
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
||||
|
||||
for account_data_type, content in account_data.items():
|
||||
results.append({"type": account_data_type, "content": content})
|
||||
|
@ -73,7 +73,10 @@ class ApplicationServicesHandler(object):
|
||||
try:
|
||||
limit = 100
|
||||
while True:
|
||||
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
||||
(
|
||||
upper_bound,
|
||||
events,
|
||||
) = yield self.store.get_new_events_for_appservice(
|
||||
self.current_max, limit
|
||||
)
|
||||
|
||||
|
@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
@defer.inlineCallbacks
|
||||
def on_federation_query_user_devices(self, user_id):
|
||||
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
||||
return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
|
||||
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||
self_signing_key = yield self.store.get_e2e_cross_signing_key(
|
||||
user_id, "self_signing"
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"stream_id": stream_id,
|
||||
"devices": devices,
|
||||
"master_key": master_key,
|
||||
"self_signing_key": self_signing_key,
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_left_room(self, user, room_id):
|
||||
|
@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler):
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
logging.warn("Error retrieving alias")
|
||||
logging.warning("Error retrieving alias")
|
||||
if e.code == 404:
|
||||
result = None
|
||||
else:
|
||||
|
@ -36,6 +36,8 @@ from synapse.types import (
|
||||
get_verify_key_from_cross_signing_key,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -49,10 +51,19 @@ class E2eKeysHandler(object):
|
||||
self.is_mine = hs.is_mine
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._edu_updater = SigningKeyEduUpdater(hs, self)
|
||||
|
||||
federation_registry = hs.get_federation_registry()
|
||||
|
||||
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||
federation_registry.register_edu_handler(
|
||||
"org.matrix.signing_key_update",
|
||||
self._edu_updater.incoming_signing_key_update,
|
||||
)
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
# query request requires an object POST, but we abuse the
|
||||
# "query handler" interface.
|
||||
hs.get_federation_registry().register_query_handler(
|
||||
federation_registry.register_query_handler(
|
||||
"client_keys", self.on_federation_query_client_keys
|
||||
)
|
||||
|
||||
@ -119,9 +130,10 @@ class E2eKeysHandler(object):
|
||||
else:
|
||||
query_list.append((user_id, None))
|
||||
|
||||
user_ids_not_in_cache, remote_results = (
|
||||
yield self.store.get_user_devices_from_cache(query_list)
|
||||
)
|
||||
(
|
||||
user_ids_not_in_cache,
|
||||
remote_results,
|
||||
) = yield self.store.get_user_devices_from_cache(query_list)
|
||||
for user_id, devices in iteritems(remote_results):
|
||||
user_devices = results.setdefault(user_id, {})
|
||||
for device_id, device in iteritems(devices):
|
||||
@ -207,10 +219,12 @@ class E2eKeysHandler(object):
|
||||
if user_id in destination_query:
|
||||
results[user_id] = keys
|
||||
|
||||
if "master_keys" in remote_result:
|
||||
for user_id, key in remote_result["master_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["master_keys"][user_id] = key
|
||||
|
||||
if "self_signing_keys" in remote_result:
|
||||
for user_id, key in remote_result["self_signing_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||
@ -251,7 +265,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
Returns:
|
||||
defer.Deferred[dict[str, dict[str, dict]]]: map from
|
||||
(master|self_signing|user_signing) -> user_id -> key
|
||||
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
|
||||
"""
|
||||
master_keys = {}
|
||||
self_signing_keys = {}
|
||||
@ -343,7 +357,16 @@ class E2eKeysHandler(object):
|
||||
"""
|
||||
device_keys_query = query_body.get("device_keys", {})
|
||||
res = yield self.query_local_devices(device_keys_query)
|
||||
return {"device_keys": res}
|
||||
ret = {"device_keys": res}
|
||||
|
||||
# add in the cross-signing keys
|
||||
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
|
||||
device_keys_query, None
|
||||
)
|
||||
|
||||
ret.update(cross_signing_keys)
|
||||
|
||||
return ret
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
@ -688,17 +711,21 @@ class E2eKeysHandler(object):
|
||||
|
||||
try:
|
||||
# get our self-signing key to verify the signatures
|
||||
_, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
|
||||
user_id, "self_signing"
|
||||
)
|
||||
(
|
||||
_,
|
||||
self_signing_key_id,
|
||||
self_signing_verify_key,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
|
||||
|
||||
# get our master key, since we may have received a signature of it.
|
||||
# We need to fetch it here so that we know what its key ID is, so
|
||||
# that we can check if a signature that was sent is a signature of
|
||||
# the master key or of a device
|
||||
master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key(
|
||||
user_id, "master"
|
||||
)
|
||||
(
|
||||
master_key,
|
||||
_,
|
||||
master_verify_key,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
|
||||
|
||||
# fetch our stored devices. This is used to 1. verify
|
||||
# signatures on the master key, and 2. to compare with what
|
||||
@ -838,9 +865,11 @@ class E2eKeysHandler(object):
|
||||
|
||||
try:
|
||||
# get our user-signing key to verify the signatures
|
||||
user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
|
||||
user_id, "user_signing"
|
||||
)
|
||||
(
|
||||
user_signing_key,
|
||||
user_signing_key_id,
|
||||
user_signing_verify_key,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
|
||||
except SynapseError as e:
|
||||
failure = _exception_to_failure(e)
|
||||
for user, devicemap in signatures.items():
|
||||
@ -859,7 +888,11 @@ class E2eKeysHandler(object):
|
||||
try:
|
||||
# get the target user's master key, to make sure it matches
|
||||
# what was sent
|
||||
master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key(
|
||||
(
|
||||
master_key,
|
||||
master_key_id,
|
||||
_,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(
|
||||
target_user, "master", user_id
|
||||
)
|
||||
|
||||
@ -1047,3 +1080,100 @@ class SignatureListItem:
|
||||
target_user_id = attr.ib()
|
||||
target_device_id = attr.ib()
|
||||
signature = attr.ib()
|
||||
|
||||
|
||||
class SigningKeyEduUpdater(object):
|
||||
"""Handles incoming signing key updates from federation and updates the DB"""
|
||||
|
||||
def __init__(self, hs, e2e_keys_handler):
|
||||
self.store = hs.get_datastore()
|
||||
self.federation = hs.get_federation_client()
|
||||
self.clock = hs.get_clock()
|
||||
self.e2e_keys_handler = e2e_keys_handler
|
||||
|
||||
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||
|
||||
# user_id -> list of updates waiting to be handled.
|
||||
self._pending_updates = {}
|
||||
|
||||
# Recently seen stream ids. We don't bother keeping these in the DB,
|
||||
# but they're useful to have them about to reduce the number of spurious
|
||||
# resyncs.
|
||||
self._seen_updates = ExpiringCache(
|
||||
cache_name="signing_key_update_edu",
|
||||
clock=self.clock,
|
||||
max_len=10000,
|
||||
expiry_ms=30 * 60 * 1000,
|
||||
iterable=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def incoming_signing_key_update(self, origin, edu_content):
|
||||
"""Called on incoming signing key update from federation. Responsible for
|
||||
parsing the EDU and adding to pending updates list.
|
||||
|
||||
Args:
|
||||
origin (string): the server that sent the EDU
|
||||
edu_content (dict): the contents of the EDU
|
||||
"""
|
||||
|
||||
user_id = edu_content.pop("user_id")
|
||||
master_key = edu_content.pop("master_key", None)
|
||||
self_signing_key = edu_content.pop("self_signing_key", None)
|
||||
|
||||
if get_domain_from_id(user_id) != origin:
|
||||
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
|
||||
return
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
# We don't share any rooms with this user. Ignore update, as we
|
||||
# probably won't get any further updates.
|
||||
return
|
||||
|
||||
self._pending_updates.setdefault(user_id, []).append(
|
||||
(master_key, self_signing_key)
|
||||
)
|
||||
|
||||
yield self._handle_signing_key_updates(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_signing_key_updates(self, user_id):
|
||||
"""Actually handle pending updates.
|
||||
|
||||
Args:
|
||||
user_id (string): the user whose updates we are processing
|
||||
"""
|
||||
|
||||
device_handler = self.e2e_keys_handler.device_handler
|
||||
|
||||
with (yield self._remote_edu_linearizer.queue(user_id)):
|
||||
pending_updates = self._pending_updates.pop(user_id, [])
|
||||
if not pending_updates:
|
||||
# This can happen since we batch updates
|
||||
return
|
||||
|
||||
device_ids = []
|
||||
|
||||
logger.info("pending updates: %r", pending_updates)
|
||||
|
||||
for master_key, self_signing_key in pending_updates:
|
||||
if master_key:
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
user_id, "master", master_key
|
||||
)
|
||||
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
|
||||
# verify_key is a VerifyKey from signedjson, which uses
|
||||
# .version to denote the portion of the key ID after the
|
||||
# algorithm and colon, which is the device ID
|
||||
device_ids.append(verify_key.version)
|
||||
if self_signing_key:
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
user_id, "self_signing", self_signing_key
|
||||
)
|
||||
_, verify_key = get_verify_key_from_cross_signing_key(
|
||||
self_signing_key
|
||||
)
|
||||
device_ids.append(verify_key.version)
|
||||
|
||||
yield device_handler.notify_device_update(user_id, device_ids)
|
||||
|
@ -45,6 +45,7 @@ from synapse.api.errors import (
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.event_auth import auth_types_for_event
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.logging.context import (
|
||||
make_deferred_yieldable,
|
||||
@ -352,11 +353,12 @@ class FederationHandler(BaseHandler):
|
||||
# note that if any of the missing prevs share missing state or
|
||||
# auth events, the requests to fetch those events are deduped
|
||||
# by the get_pdu_cache in federation_client.
|
||||
remote_state, got_auth_chain = (
|
||||
yield self.federation_client.get_state_for_room(
|
||||
(
|
||||
remote_state,
|
||||
got_auth_chain,
|
||||
) = yield self.federation_client.get_state_for_room(
|
||||
origin, room_id, p
|
||||
)
|
||||
)
|
||||
|
||||
# we want the state *after* p; get_state_for_room returns the
|
||||
# state *before* p.
|
||||
@ -1105,7 +1107,7 @@ class FederationHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
def do_invite_join(self, target_hosts, room_id, joinee, content):
|
||||
""" Attempts to join the `joinee` to the room `room_id` via the
|
||||
server `target_host`.
|
||||
servers contained in `target_hosts`.
|
||||
|
||||
This first triggers a /make_join/ request that returns a partial
|
||||
event that we can fill out and sign. This is then sent to the
|
||||
@ -1114,6 +1116,15 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
We suspend processing of any received events from this room until we
|
||||
have finished processing the join.
|
||||
|
||||
Args:
|
||||
target_hosts (Iterable[str]): List of servers to attempt to join the room with.
|
||||
|
||||
room_id (str): The ID of the room to join.
|
||||
|
||||
joinee (str): The User ID of the joining user.
|
||||
|
||||
content (dict): The event content to use for the join event.
|
||||
"""
|
||||
logger.debug("Joining %s to %s", joinee, room_id)
|
||||
|
||||
@ -1173,6 +1184,22 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
yield self._persist_auth_tree(origin, auth_chain, state, event)
|
||||
|
||||
# Check whether this room is the result of an upgrade of a room we already know
|
||||
# about. If so, migrate over user information
|
||||
predecessor = yield self.store.get_room_predecessor(room_id)
|
||||
if not predecessor:
|
||||
return
|
||||
old_room_id = predecessor["room_id"]
|
||||
logger.debug(
|
||||
"Found predecessor for %s during remote join: %s", room_id, old_room_id
|
||||
)
|
||||
|
||||
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||
member_handler = self.hs.get_room_member_handler()
|
||||
yield member_handler.transfer_room_state_on_room_upgrade(
|
||||
old_room_id, room_id
|
||||
)
|
||||
|
||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||
finally:
|
||||
room_queue = self.room_queues[room_id]
|
||||
@ -1845,14 +1872,7 @@ class FederationHandler(BaseHandler):
|
||||
if c and c.type == EventTypes.Create:
|
||||
auth_events[(c.type, c.state_key)] = c
|
||||
|
||||
try:
|
||||
yield self.do_auth(origin, event, context, auth_events=auth_events)
|
||||
except AuthError as e:
|
||||
logger.warning(
|
||||
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
|
||||
)
|
||||
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
context = yield self.do_auth(origin, event, context, auth_events=auth_events)
|
||||
|
||||
if not context.rejected:
|
||||
yield self._check_for_soft_fail(event, state, backfilled)
|
||||
@ -2021,12 +2041,12 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
Also NB that this function adds entries to it.
|
||||
Returns:
|
||||
defer.Deferred[None]
|
||||
defer.Deferred[EventContext]: updated context object
|
||||
"""
|
||||
room_version = yield self.store.get_room_version(event.room_id)
|
||||
|
||||
try:
|
||||
yield self._update_auth_events_and_context_for_auth(
|
||||
context = yield self._update_auth_events_and_context_for_auth(
|
||||
origin, event, context, auth_events
|
||||
)
|
||||
except Exception:
|
||||
@ -2044,7 +2064,9 @@ class FederationHandler(BaseHandler):
|
||||
event_auth.check(room_version, event, auth_events=auth_events)
|
||||
except AuthError as e:
|
||||
logger.warning("Failed auth resolution for %r because %s", event, e)
|
||||
raise e
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
return context
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_auth_events_and_context_for_auth(
|
||||
@ -2068,7 +2090,7 @@ class FederationHandler(BaseHandler):
|
||||
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
||||
|
||||
Returns:
|
||||
defer.Deferred[None]
|
||||
defer.Deferred[EventContext]: updated context
|
||||
"""
|
||||
event_auth_events = set(event.auth_event_ids())
|
||||
|
||||
@ -2107,7 +2129,7 @@ class FederationHandler(BaseHandler):
|
||||
# The other side isn't around or doesn't implement the
|
||||
# endpoint, so lets just bail out.
|
||||
logger.info("Failed to get event auth from remote: %s", e)
|
||||
return
|
||||
return context
|
||||
|
||||
seen_remotes = yield self.store.have_seen_events(
|
||||
[e.event_id for e in remote_auth_chain]
|
||||
@ -2148,7 +2170,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
if event.internal_metadata.is_outlier():
|
||||
logger.info("Skipping auth_event fetch for outlier")
|
||||
return
|
||||
return context
|
||||
|
||||
# FIXME: Assumes we have and stored all the state for all the
|
||||
# prev_events
|
||||
@ -2157,7 +2179,7 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if not different_auth:
|
||||
return
|
||||
return context
|
||||
|
||||
logger.info(
|
||||
"auth_events refers to events which are not in our calculated auth "
|
||||
@ -2204,10 +2226,12 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
auth_events.update(new_state)
|
||||
|
||||
yield self._update_context_for_auth_events(
|
||||
context = yield self._update_context_for_auth_events(
|
||||
event, context, auth_events, event_key
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
|
||||
"""Update the state_ids in an event context after auth event resolution,
|
||||
@ -2216,14 +2240,16 @@ class FederationHandler(BaseHandler):
|
||||
Args:
|
||||
event (Event): The event we're handling the context for
|
||||
|
||||
context (synapse.events.snapshot.EventContext): event context
|
||||
to be updated
|
||||
context (synapse.events.snapshot.EventContext): initial event context
|
||||
|
||||
auth_events (dict[(str, str)->str]): Events to update in the event
|
||||
context.
|
||||
|
||||
event_key ((str, str)): (type, state_key) for the current event.
|
||||
this will not be included in the current_state in the context.
|
||||
|
||||
Returns:
|
||||
Deferred[EventContext]: new event context
|
||||
"""
|
||||
state_updates = {
|
||||
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
||||
@ -2248,7 +2274,7 @@ class FederationHandler(BaseHandler):
|
||||
current_state_ids=current_state_ids,
|
||||
)
|
||||
|
||||
yield context.update_state(
|
||||
return EventContext.with_state(
|
||||
state_group=state_group,
|
||||
current_state_ids=current_state_ids,
|
||||
prev_state_ids=prev_state_ids,
|
||||
@ -2441,6 +2467,8 @@ class FederationHandler(BaseHandler):
|
||||
raise e
|
||||
|
||||
yield self._check_signature(event, context)
|
||||
|
||||
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||
member_handler = self.hs.get_room_member_handler()
|
||||
yield member_handler.send_membership_event(None, event, context)
|
||||
else:
|
||||
@ -2501,6 +2529,7 @@ class FederationHandler(BaseHandler):
|
||||
# though the sender isn't a local user.
|
||||
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
|
||||
|
||||
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||
member_handler = self.hs.get_room_member_handler()
|
||||
yield member_handler.send_membership_event(None, event, context)
|
||||
|
||||
|
@ -128,8 +128,8 @@ class InitialSyncHandler(BaseHandler):
|
||||
|
||||
tags_by_room = yield self.store.get_tags_for_user(user_id)
|
||||
|
||||
account_data, account_data_by_room = (
|
||||
yield self.store.get_account_data_for_user(user_id)
|
||||
account_data, account_data_by_room = yield self.store.get_account_data_for_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
public_room_ids = yield self.store.get_public_room_ids()
|
||||
|
@ -76,9 +76,10 @@ class MessageHandler(object):
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
(
|
||||
membership,
|
||||
membership_event_id,
|
||||
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
data = yield self.state.get_current_state(room_id, event_type, state_key)
|
||||
@ -153,9 +154,10 @@ class MessageHandler(object):
|
||||
% (user_id, room_id, at_token),
|
||||
)
|
||||
else:
|
||||
membership, membership_event_id = (
|
||||
yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
)
|
||||
(
|
||||
membership,
|
||||
membership_event_id,
|
||||
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
state_ids = yield self.store.get_filtered_current_state_ids(
|
||||
|
@ -214,9 +214,10 @@ class PaginationHandler(object):
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
with (yield self.pagination_lock.read(room_id)):
|
||||
membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
|
||||
if source_config.direction == "b":
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
@ -299,10 +300,8 @@ class PaginationHandler(object):
|
||||
}
|
||||
|
||||
if state:
|
||||
chunk["state"] = (
|
||||
yield self._event_serializer.serialize_events(
|
||||
chunk["state"] = yield self._event_serializer.serialize_events(
|
||||
state, time_now, as_client_event=as_client_event
|
||||
)
|
||||
)
|
||||
|
||||
return chunk
|
||||
|
@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler):
|
||||
room_id = room_identifier
|
||||
elif RoomAlias.is_valid(room_identifier):
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
room_id, remote_room_hosts = (
|
||||
yield room_member_handler.lookup_room_alias(room_alias)
|
||||
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
|
||||
room_alias
|
||||
)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
|
@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
old_room_id,
|
||||
new_version, # args for _upgrade_room
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -147,8 +148,10 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
# we create and auth the tombstone event before properly creating the new
|
||||
# room, to check our user has perms in the old room.
|
||||
tombstone_event, tombstone_context = (
|
||||
yield self.event_creation_handler.create_event(
|
||||
(
|
||||
tombstone_event,
|
||||
tombstone_context,
|
||||
) = yield self.event_creation_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Tombstone,
|
||||
@ -162,7 +165,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
},
|
||||
token_id=requester.access_token_id,
|
||||
)
|
||||
)
|
||||
old_room_version = yield self.store.get_room_version(old_room_id)
|
||||
yield self.auth.check_from_context(
|
||||
old_room_version, tombstone_event, tombstone_context
|
||||
@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler):
|
||||
requester, old_room_id, new_room_id, old_room_state
|
||||
)
|
||||
|
||||
# and finally, shut down the PLs in the old room, and update them in the new
|
||||
# Copy over user push rules, tags and migrate room directory state
|
||||
yield self.room_member_handler.transfer_room_state_on_room_upgrade(
|
||||
old_room_id, new_room_id
|
||||
)
|
||||
|
||||
# finally, shut down the PLs in the old room, and update them in the new
|
||||
# room.
|
||||
yield self._update_upgraded_room_pls(
|
||||
requester, old_room_id, new_room_id, old_room_state
|
||||
|
@ -203,10 +203,6 @@ class RoomMemberHandler(object):
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
# Copy over user state if we're joining an upgraded room
|
||||
yield self.copy_user_state_if_room_upgrade(
|
||||
room_id, requester.user.to_string()
|
||||
)
|
||||
yield self._user_joined_room(target, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event_id:
|
||||
@ -455,11 +451,6 @@ class RoomMemberHandler(object):
|
||||
requester, remote_room_hosts, room_id, target, content
|
||||
)
|
||||
|
||||
# Copy over user state if this is a join on an remote upgraded room
|
||||
yield self.copy_user_state_if_room_upgrade(
|
||||
room_id, requester.user.to_string()
|
||||
)
|
||||
|
||||
return remote_join_response
|
||||
|
||||
elif effective_membership_state == Membership.LEAVE:
|
||||
@ -498,36 +489,72 @@ class RoomMemberHandler(object):
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
|
||||
"""Copy user-specific information when they join a new room if that new room is the
|
||||
result of a room upgrade
|
||||
def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
|
||||
"""Upon our server becoming aware of an upgraded room, either by upgrading a room
|
||||
ourselves or joining one, we can transfer over information from the previous room.
|
||||
|
||||
Copies user state (tags/push rules) for every local user that was in the old room, as
|
||||
well as migrating the room directory state.
|
||||
|
||||
Args:
|
||||
new_room_id (str): The ID of the room the user is joining
|
||||
user_id (str): The ID of the user
|
||||
old_room_id (str): The ID of the old room
|
||||
|
||||
room_id (str): The ID of the new room
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# Find all local users that were in the old room and copy over each user's state
|
||||
users = yield self.store.get_users_in_room(old_room_id)
|
||||
yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
|
||||
|
||||
# Add new room to the room directory if the old room was there
|
||||
# Remove old room from the room directory
|
||||
old_room = yield self.store.get_room(old_room_id)
|
||||
if old_room and old_room["is_public"]:
|
||||
yield self.store.set_room_is_public(old_room_id, False)
|
||||
yield self.store.set_room_is_public(room_id, True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
|
||||
"""Copy user-specific information when they join a new room when that new room is the
|
||||
result of a room upgrade
|
||||
|
||||
Args:
|
||||
old_room_id (str): The ID of upgraded room
|
||||
new_room_id (str): The ID of the new room
|
||||
user_ids (Iterable[str]): User IDs to copy state for
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# Check if the new room is an upgraded room
|
||||
predecessor = yield self.store.get_room_predecessor(new_room_id)
|
||||
if not predecessor:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Found predecessor for %s: %s. Copying over room tags and push " "rules",
|
||||
"Copying over room tags and push rules from %s to %s for users %s",
|
||||
old_room_id,
|
||||
new_room_id,
|
||||
predecessor,
|
||||
user_ids,
|
||||
)
|
||||
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
# It is an upgraded room. Copy over old tags
|
||||
yield self.copy_room_tags_and_direct_to_room(
|
||||
predecessor["room_id"], new_room_id, user_id
|
||||
old_room_id, new_room_id, user_id
|
||||
)
|
||||
# Copy over push rules
|
||||
yield self.store.copy_push_rules_from_room_to_room_for_user(
|
||||
predecessor["room_id"], new_room_id, user_id
|
||||
old_room_id, new_room_id, user_id
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error copying tags and/or push rules from rooms %s to %s for user %s. "
|
||||
"Skipping...",
|
||||
old_room_id,
|
||||
new_room_id,
|
||||
user_id,
|
||||
)
|
||||
continue
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_membership_event(self, requester, event, context, ratelimit=True):
|
||||
@ -759,8 +786,12 @@ class RoomMemberHandler(object):
|
||||
if room_avatar_event:
|
||||
room_avatar_url = room_avatar_event.content.get("url", "")
|
||||
|
||||
token, public_keys, fallback_public_key, display_name = (
|
||||
yield self.identity_handler.ask_id_server_for_third_party_invite(
|
||||
(
|
||||
token,
|
||||
public_keys,
|
||||
fallback_public_key,
|
||||
display_name,
|
||||
) = yield self.identity_handler.ask_id_server_for_third_party_invite(
|
||||
requester=requester,
|
||||
id_server=id_server,
|
||||
medium=medium,
|
||||
@ -775,7 +806,6 @@ class RoomMemberHandler(object):
|
||||
inviter_avatar_url=inviter_avatar_url,
|
||||
id_access_token=id_access_token,
|
||||
)
|
||||
)
|
||||
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
|
@ -396,16 +396,12 @@ class SearchHandler(BaseHandler):
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
for context in contexts.values():
|
||||
context["events_before"] = (
|
||||
yield self._event_serializer.serialize_events(
|
||||
context["events_before"] = yield self._event_serializer.serialize_events(
|
||||
context["events_before"], time_now
|
||||
)
|
||||
)
|
||||
context["events_after"] = (
|
||||
yield self._event_serializer.serialize_events(
|
||||
context["events_after"] = yield self._event_serializer.serialize_events(
|
||||
context["events_after"], time_now
|
||||
)
|
||||
)
|
||||
|
||||
state_results = {}
|
||||
if include_state:
|
||||
|
@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler):
|
||||
user_deltas = {}
|
||||
|
||||
# Then count deltas for total_events and total_event_bytes.
|
||||
room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
|
||||
(
|
||||
room_count,
|
||||
user_count,
|
||||
) = yield self.store.get_changes_room_total_events_and_bytes(
|
||||
self.pos, max_pos
|
||||
)
|
||||
|
||||
|
@ -1206,11 +1206,12 @@ class SyncHandler(object):
|
||||
since_token = sync_result_builder.since_token
|
||||
|
||||
if since_token and not sync_result_builder.full_state:
|
||||
account_data, account_data_by_room = (
|
||||
yield self.store.get_updated_account_data_for_user(
|
||||
(
|
||||
account_data,
|
||||
account_data_by_room,
|
||||
) = yield self.store.get_updated_account_data_for_user(
|
||||
user_id, since_token.account_data_key
|
||||
)
|
||||
)
|
||||
|
||||
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
|
||||
user_id, int(since_token.push_rules_key)
|
||||
@ -1221,9 +1222,10 @@ class SyncHandler(object):
|
||||
sync_config.user
|
||||
)
|
||||
else:
|
||||
account_data, account_data_by_room = (
|
||||
yield self.store.get_account_data_for_user(sync_config.user.to_string())
|
||||
)
|
||||
(
|
||||
account_data,
|
||||
account_data_by_room,
|
||||
) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
|
||||
|
||||
account_data["m.push_rules"] = yield self.push_rules_for_user(
|
||||
sync_config.user
|
||||
|
@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
self._enabled = bool(hs.config.recaptcha_private_key)
|
||||
self._http_client = hs.get_simple_http_client()
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._url = hs.config.recaptcha_siteverify_api
|
||||
self._secret = hs.config.recaptcha_private_key
|
||||
|
||||
|
@ -45,6 +45,7 @@ from synapse.http import (
|
||||
cancelled_to_request_timed_out_error,
|
||||
redact_uri,
|
||||
)
|
||||
from synapse.http.proxyagent import ProxyAgent
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
@ -183,7 +184,15 @@ class SimpleHttpClient(object):
|
||||
using HTTP in Matrix
|
||||
"""
|
||||
|
||||
def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
|
||||
def __init__(
|
||||
self,
|
||||
hs,
|
||||
treq_args={},
|
||||
ip_whitelist=None,
|
||||
ip_blacklist=None,
|
||||
http_proxy=None,
|
||||
https_proxy=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
@ -192,6 +201,8 @@ class SimpleHttpClient(object):
|
||||
we may not request.
|
||||
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
|
||||
request if it were otherwise caught in a blacklist.
|
||||
http_proxy (bytes): proxy server to use for http connections. host[:port]
|
||||
https_proxy (bytes): proxy server to use for https connections. host[:port]
|
||||
"""
|
||||
self.hs = hs
|
||||
|
||||
@ -236,11 +247,13 @@ class SimpleHttpClient(object):
|
||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||
# 'like a browser'
|
||||
self.agent = Agent(
|
||||
self.agent = ProxyAgent(
|
||||
self.reactor,
|
||||
connectTimeout=15,
|
||||
contextFactory=self.hs.get_http_client_context_factory(),
|
||||
pool=pool,
|
||||
http_proxy=http_proxy,
|
||||
https_proxy=https_proxy,
|
||||
)
|
||||
|
||||
if self._ip_blacklist:
|
||||
|
195
synapse/http/connectproxyclient.py
Normal file
195
synapse/http/connectproxyclient.py
Normal file
@ -0,0 +1,195 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, protocol
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||
from twisted.internet.protocol import connectionDone
|
||||
from twisted.web import http
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyConnectError(ConnectError):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
class HTTPConnectProxyEndpoint(object):
|
||||
"""An Endpoint implementation which will send a CONNECT request to an http proxy
|
||||
|
||||
Wraps an existing HostnameEndpoint for the proxy.
|
||||
|
||||
When we get the connect() request from the connection pool (via the TLS wrapper),
|
||||
we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
|
||||
CONNECT request. Once that completes, we invoke the protocolFactory which was passed
|
||||
in.
|
||||
|
||||
Args:
|
||||
reactor: the Twisted reactor to use for the connection
|
||||
proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
|
||||
proxy
|
||||
host (bytes): hostname that we want to CONNECT to
|
||||
port (int): port that we want to connect to
|
||||
"""
|
||||
|
||||
def __init__(self, reactor, proxy_endpoint, host, port):
|
||||
self._reactor = reactor
|
||||
self._proxy_endpoint = proxy_endpoint
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
||||
def __repr__(self):
|
||||
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
|
||||
|
||||
def connect(self, protocolFactory):
|
||||
f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
|
||||
d = self._proxy_endpoint.connect(f)
|
||||
# once the tcp socket connects successfully, we need to wait for the
|
||||
# CONNECT to complete.
|
||||
d.addCallback(lambda conn: f.on_connection)
|
||||
return d
|
||||
|
||||
|
||||
class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||
"""ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
|
||||
|
||||
Once the CONNECT completes, invokes the original ClientFactory to build the
|
||||
HTTP Protocol object and run the rest of the connection.
|
||||
|
||||
Args:
|
||||
dst_host (bytes): hostname that we want to CONNECT to
|
||||
dst_port (int): port that we want to connect to
|
||||
wrapped_factory (protocol.ClientFactory): The original Factory
|
||||
"""
|
||||
|
||||
def __init__(self, dst_host, dst_port, wrapped_factory):
|
||||
self.dst_host = dst_host
|
||||
self.dst_port = dst_port
|
||||
self.wrapped_factory = wrapped_factory
|
||||
self.on_connection = defer.Deferred()
|
||||
|
||||
def startedConnecting(self, connector):
|
||||
return self.wrapped_factory.startedConnecting(connector)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
|
||||
|
||||
return HTTPConnectProtocol(
|
||||
self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
|
||||
)
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
logger.debug("Connection to proxy failed: %s", reason)
|
||||
if not self.on_connection.called:
|
||||
self.on_connection.errback(reason)
|
||||
return self.wrapped_factory.clientConnectionFailed(connector, reason)
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
logger.debug("Connection to proxy lost: %s", reason)
|
||||
if not self.on_connection.called:
|
||||
self.on_connection.errback(reason)
|
||||
return self.wrapped_factory.clientConnectionLost(connector, reason)
|
||||
|
||||
|
||||
class HTTPConnectProtocol(protocol.Protocol):
|
||||
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
|
||||
|
||||
Args:
|
||||
host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
|
||||
to put in the CONNECT request
|
||||
|
||||
port (int): The original HTTP(s) port to put in the CONNECT request
|
||||
|
||||
wrapped_protocol (interfaces.IProtocol): the original protocol (probably
|
||||
HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
|
||||
|
||||
connected_deferred (Deferred): a Deferred which will be callbacked with
|
||||
wrapped_protocol when the CONNECT completes
|
||||
"""
|
||||
|
||||
def __init__(self, host, port, wrapped_protocol, connected_deferred):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.wrapped_protocol = wrapped_protocol
|
||||
self.connected_deferred = connected_deferred
|
||||
self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
|
||||
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
|
||||
|
||||
def connectionMade(self):
|
||||
self.http_setup_client.makeConnection(self.transport)
|
||||
|
||||
def connectionLost(self, reason=connectionDone):
|
||||
if self.wrapped_protocol.connected:
|
||||
self.wrapped_protocol.connectionLost(reason)
|
||||
|
||||
self.http_setup_client.connectionLost(reason)
|
||||
|
||||
if not self.connected_deferred.called:
|
||||
self.connected_deferred.errback(reason)
|
||||
|
||||
def proxyConnected(self, _):
|
||||
self.wrapped_protocol.makeConnection(self.transport)
|
||||
|
||||
self.connected_deferred.callback(self.wrapped_protocol)
|
||||
|
||||
# Get any pending data from the http buf and forward it to the original protocol
|
||||
buf = self.http_setup_client.clearLineBuffer()
|
||||
if buf:
|
||||
self.wrapped_protocol.dataReceived(buf)
|
||||
|
||||
def dataReceived(self, data):
|
||||
# if we've set up the HTTP protocol, we can send the data there
|
||||
if self.wrapped_protocol.connected:
|
||||
return self.wrapped_protocol.dataReceived(data)
|
||||
|
||||
# otherwise, we must still be setting up the connection: send the data to the
|
||||
# setup client
|
||||
return self.http_setup_client.dataReceived(data)
|
||||
|
||||
|
||||
class HTTPConnectSetupClient(http.HTTPClient):
|
||||
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
|
||||
|
||||
Args:
|
||||
host (bytes): The hostname to send in the CONNECT message
|
||||
port (int): The port to send in the CONNECT message
|
||||
"""
|
||||
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.on_connected = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
logger.debug("Connected to proxy, sending CONNECT")
|
||||
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
|
||||
self.endHeaders()
|
||||
|
||||
def handleStatus(self, version, status, message):
|
||||
logger.debug("Got Status: %s %s %s", status, message, version)
|
||||
if status != b"200":
|
||||
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
|
||||
|
||||
def handleEndHeaders(self):
|
||||
logger.debug("End Headers")
|
||||
self.on_connected.callback(None)
|
||||
|
||||
def handleResponse(self, body):
|
||||
pass
|
195
synapse/http/proxyagent.py
Normal file
195
synapse/http/proxyagent.py
Normal file
@ -0,0 +1,195 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
|
||||
from twisted.web.error import SchemeNotSupported
|
||||
from twisted.web.iweb import IAgent
|
||||
|
||||
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
|
||||
|
||||
|
||||
@implementer(IAgent)
|
||||
class ProxyAgent(_AgentBase):
|
||||
"""An Agent implementation which will use an HTTP proxy if one was requested
|
||||
|
||||
Args:
|
||||
reactor: twisted reactor to place outgoing
|
||||
connections.
|
||||
|
||||
contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
|
||||
verification parameters of OpenSSL. The default is to use a
|
||||
`BrowserLikePolicyForHTTPS`, so unless you have special
|
||||
requirements you can leave this as-is.
|
||||
|
||||
connectTimeout (float): The amount of time that this Agent will wait
|
||||
for the peer to accept a connection.
|
||||
|
||||
bindAddress (bytes): The local address for client sockets to bind to.
|
||||
|
||||
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
|
||||
non-persistent pool instance will be created.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reactor,
|
||||
contextFactory=BrowserLikePolicyForHTTPS(),
|
||||
connectTimeout=None,
|
||||
bindAddress=None,
|
||||
pool=None,
|
||||
http_proxy=None,
|
||||
https_proxy=None,
|
||||
):
|
||||
_AgentBase.__init__(self, reactor, pool)
|
||||
|
||||
self._endpoint_kwargs = {}
|
||||
if connectTimeout is not None:
|
||||
self._endpoint_kwargs["timeout"] = connectTimeout
|
||||
if bindAddress is not None:
|
||||
self._endpoint_kwargs["bindAddress"] = bindAddress
|
||||
|
||||
self.http_proxy_endpoint = _http_proxy_endpoint(
|
||||
http_proxy, reactor, **self._endpoint_kwargs
|
||||
)
|
||||
|
||||
self.https_proxy_endpoint = _http_proxy_endpoint(
|
||||
https_proxy, reactor, **self._endpoint_kwargs
|
||||
)
|
||||
|
||||
self._policy_for_https = contextFactory
|
||||
self._reactor = reactor
|
||||
|
||||
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||
"""
|
||||
Issue a request to the server indicated by the given uri.
|
||||
|
||||
Supports `http` and `https` schemes.
|
||||
|
||||
An existing connection from the connection pool may be used or a new one may be
|
||||
created.
|
||||
|
||||
See also: twisted.web.iweb.IAgent.request
|
||||
|
||||
Args:
|
||||
method (bytes): The request method to use, such as `GET`, `POST`, etc
|
||||
|
||||
uri (bytes): The location of the resource to request.
|
||||
|
||||
headers (Headers|None): Extra headers to send with the request
|
||||
|
||||
bodyProducer (IBodyProducer|None): An object which can generate bytes to
|
||||
make up the body of this request (for example, the properly encoded
|
||||
contents of a file for a file upload). Or, None if the request is to
|
||||
have no body.
|
||||
|
||||
Returns:
|
||||
Deferred[IResponse]: completes when the header of the response has
|
||||
been received (regardless of the response status code).
|
||||
"""
|
||||
uri = uri.strip()
|
||||
if not _VALID_URI.match(uri):
|
||||
raise ValueError("Invalid URI {!r}".format(uri))
|
||||
|
||||
parsed_uri = URI.fromBytes(uri)
|
||||
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
|
||||
request_path = parsed_uri.originForm
|
||||
|
||||
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
|
||||
# Cache *all* connections under the same key, since we are only
|
||||
# connecting to a single destination, the proxy:
|
||||
pool_key = ("http-proxy", self.http_proxy_endpoint)
|
||||
endpoint = self.http_proxy_endpoint
|
||||
request_path = uri
|
||||
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
|
||||
endpoint = HTTPConnectProxyEndpoint(
|
||||
self._reactor,
|
||||
self.https_proxy_endpoint,
|
||||
parsed_uri.host,
|
||||
parsed_uri.port,
|
||||
)
|
||||
else:
|
||||
# not using a proxy
|
||||
endpoint = HostnameEndpoint(
|
||||
self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
|
||||
)
|
||||
|
||||
logger.debug("Requesting %s via %s", uri, endpoint)
|
||||
|
||||
if parsed_uri.scheme == b"https":
|
||||
tls_connection_creator = self._policy_for_https.creatorForNetloc(
|
||||
parsed_uri.host, parsed_uri.port
|
||||
)
|
||||
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
|
||||
elif parsed_uri.scheme == b"http":
|
||||
pass
|
||||
else:
|
||||
return defer.fail(
|
||||
Failure(
|
||||
SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
|
||||
)
|
||||
)
|
||||
|
||||
return self._requestWithEndpoint(
|
||||
pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
|
||||
)
|
||||
|
||||
|
||||
def _http_proxy_endpoint(proxy, reactor, **kwargs):
|
||||
"""Parses an http proxy setting and returns an endpoint for the proxy
|
||||
|
||||
Args:
|
||||
proxy (bytes|None): the proxy setting
|
||||
reactor: reactor to be used to connect to the proxy
|
||||
kwargs: other args to be passed to HostnameEndpoint
|
||||
|
||||
Returns:
|
||||
interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
|
||||
or None
|
||||
"""
|
||||
if proxy is None:
|
||||
return None
|
||||
|
||||
# currently we only support hostname:port. Some apps also support
|
||||
# protocol://<host>[:port], which allows a way of requiring a TLS connection to the
|
||||
# proxy.
|
||||
|
||||
host, port = parse_host_port(proxy, default_port=1080)
|
||||
return HostnameEndpoint(reactor, host, port, **kwargs)
|
||||
|
||||
|
||||
def parse_host_port(hostport, default_port=None):
|
||||
# could have sworn we had one of these somewhere else...
|
||||
if b":" in hostport:
|
||||
host, port = hostport.rsplit(b":", 1)
|
||||
try:
|
||||
port = int(port)
|
||||
return host, port
|
||||
except ValueError:
|
||||
# the thing after the : wasn't a valid port; presumably this is an
|
||||
# IPv6 address.
|
||||
pass
|
||||
|
||||
return hostport, default_port
|
@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
|
||||
|
||||
|
||||
def parse_drain_configs(
|
||||
drains: dict
|
||||
drains: dict,
|
||||
) -> typing.Generator[DrainConfiguration, None, None]:
|
||||
"""
|
||||
Parse the drain configurations.
|
||||
|
@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object):
|
||||
|
||||
room_members = yield self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
(power_levels, sender_power_level) = (
|
||||
yield self._get_power_levels_and_sender_level(event, context)
|
||||
)
|
||||
(
|
||||
power_levels,
|
||||
sender_power_level,
|
||||
) = yield self._get_power_levels_and_sender_level(event, context)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(
|
||||
event, len(room_members), sender_power_level, power_levels
|
||||
|
@ -234,15 +234,13 @@ class EmailPusher(object):
|
||||
return
|
||||
|
||||
self.last_stream_ordering = last_stream_ordering
|
||||
pusher_still_exists = (
|
||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
self.app_id,
|
||||
self.email,
|
||||
self.user_id,
|
||||
last_stream_ordering,
|
||||
self.clock.time_msec(),
|
||||
)
|
||||
)
|
||||
if not pusher_still_exists:
|
||||
# The pusher has been deleted while we were processing, so
|
||||
# lets just stop and return.
|
||||
|
@ -103,7 +103,7 @@ class HttpPusher(object):
|
||||
if "url" not in self.data:
|
||||
raise PusherConfigException("'url' required in data for HTTP pusher")
|
||||
self.url = self.data["url"]
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.http_client = hs.get_proxied_http_client()
|
||||
self.data_minus_url = {}
|
||||
self.data_minus_url.update(self.data)
|
||||
del self.data_minus_url["url"]
|
||||
@ -211,15 +211,13 @@ class HttpPusher(object):
|
||||
http_push_processed_counter.inc()
|
||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||
self.last_stream_ordering = push_action["stream_ordering"]
|
||||
pusher_still_exists = (
|
||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_id,
|
||||
self.last_stream_ordering,
|
||||
self.clock.time_msec(),
|
||||
)
|
||||
)
|
||||
if not pusher_still_exists:
|
||||
# The pusher has been deleted while we were processing, so
|
||||
# lets just stop and return.
|
||||
|
@ -103,9 +103,7 @@ class PusherPool:
|
||||
# create the pusher setting last_stream_ordering to the current maximum
|
||||
# stream ordering in event_push_actions, so it will process
|
||||
# pushes from this point onwards.
|
||||
last_stream_ordering = (
|
||||
yield self.store.get_latest_push_action_stream_ordering()
|
||||
)
|
||||
last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
|
||||
|
||||
yield self.store.add_pusher(
|
||||
user_id=user_id,
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import six
|
||||
|
||||
@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
|
||||
|
||||
self.hs = hs
|
||||
|
||||
def stream_positions(self):
|
||||
def stream_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the current positions of all the streams this store wants to subscribe to
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
pos = {}
|
||||
if self._cache_id_gen:
|
||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
||||
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
|
||||
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceStore, self).stream_positions()
|
||||
result["device_lists"] = self._device_list_id_gen.get_current_token()
|
||||
# The user signature stream uses the same stream ID generator as the
|
||||
# device list stream, so set them both to the device list ID
|
||||
# generator's current token.
|
||||
current_token = self._device_list_id_gen.get_current_token()
|
||||
result[DeviceListsStream.NAME] = current_token
|
||||
result[UserSignatureStream.NAME] = current_token
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "device_lists":
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
|
||||
elif stream_name == UserSignatureStream.NAME:
|
||||
for row in rows:
|
||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||
return super(SlavedDeviceStore, self).process_replication_rows(
|
||||
stream_name, token, rows
|
||||
)
|
||||
|
@ -16,10 +16,17 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.protocol import (
|
||||
AbstractReplicationClientHandler,
|
||||
ClientReplicationStreamProtocol,
|
||||
)
|
||||
|
||||
from .commands import (
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
@ -27,7 +34,6 @@ from .commands import (
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from .protocol import ClientReplicationStreamProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
|
||||
maxDelay = 30 # Try at least once every N seconds
|
||||
|
||||
def __init__(self, hs, client_name, handler):
|
||||
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
|
||||
self.client_name = client_name
|
||||
self.handler = handler
|
||||
self.server_name = hs.config.server_name
|
||||
@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
||||
|
||||
|
||||
class ReplicationClientHandler(object):
|
||||
class ReplicationClientHandler(AbstractReplicationClientHandler):
|
||||
"""A base handler that can be passed to the ReplicationClientFactory.
|
||||
|
||||
By default proxies incoming replication data to the SlaveStore.
|
||||
"""
|
||||
|
||||
def __init__(self, store):
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
self.store = store
|
||||
|
||||
# The current connection. None if we are currently (re)connecting
|
||||
@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
|
||||
if d:
|
||||
d.callback(data)
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns a dictionary of stream name to token.
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
args = self.store.stream_positions()
|
||||
user_account_data = args.pop("user_account_data", None)
|
||||
|
@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
|
||||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
"""
|
||||
|
||||
import abc
|
||||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
@ -65,6 +65,7 @@ from twisted.python.failure import Failure
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from .commands import (
|
||||
@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
self.streamer.lost_connection(self)
|
||||
|
||||
|
||||
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The interface for the handler that should be passed to
|
||||
ClientReplicationStreamProtocol
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
|
||||
Returns:
|
||||
Deferred|None
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_sync(self, data):
|
||||
"""Called when get a new SYNC command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_streams_to_replicate(self):
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||
|
||||
def __init__(self, client_name, server_name, clock, handler):
|
||||
def __init__(
|
||||
self,
|
||||
client_name: str,
|
||||
server_name: str,
|
||||
clock: Clock,
|
||||
handler: AbstractReplicationClientHandler,
|
||||
):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock)
|
||||
|
||||
self.client_name = client_name
|
||||
|
@ -45,5 +45,6 @@ STREAMS_MAP = {
|
||||
_base.TagAccountDataStream,
|
||||
_base.AccountDataStream,
|
||||
_base.GroupServerStream,
|
||||
_base.UserSignatureStream,
|
||||
)
|
||||
}
|
||||
|
@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
|
||||
"GroupsStreamRow",
|
||||
("group_id", "user_id", "type", "content"), # str # str # str # dict
|
||||
)
|
||||
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
|
||||
|
||||
|
||||
class Stream(object):
|
||||
@ -438,3 +439,20 @@ class GroupServerStream(Stream):
|
||||
self.update_function = store.get_all_groups_changes
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
|
||||
class UserSignatureStream(Stream):
|
||||
"""A user has signed their own device with their user-signing key
|
||||
"""
|
||||
|
||||
NAME = "user_signature"
|
||||
_LIMITED = False
|
||||
ROW_TYPE = UserSignatureStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token
|
||||
self.update_function = store.get_all_user_signature_changes_for_remotes
|
||||
|
||||
super(UserSignatureStream, self).__init__(hs)
|
||||
|
@ -203,11 +203,12 @@ class LoginRestServlet(RestServlet):
|
||||
address = address.lower()
|
||||
|
||||
# Check for login providers that support 3pid login types
|
||||
canonical_user_id, callback_3pid = (
|
||||
yield self.auth_handler.check_password_provider_3pid(
|
||||
(
|
||||
canonical_user_id,
|
||||
callback_3pid,
|
||||
) = yield self.auth_handler.check_password_provider_3pid(
|
||||
medium, address, login_submission["password"]
|
||||
)
|
||||
)
|
||||
if canonical_user_id:
|
||||
# Authentication through password provider and 3pid succeeded
|
||||
result = yield self._register_device_with_callback(
|
||||
@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet):
|
||||
def do_token_login(self, login_submission):
|
||||
token = login_submission["token"]
|
||||
auth_handler = self.auth_handler
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
|
||||
result = yield self._register_device_with_callback(user_id, login_submission)
|
||||
@ -380,7 +381,7 @@ class CasTicketServlet(RestServlet):
|
||||
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||
self._http_client = hs.get_simple_http_client()
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
self.failure_email_template, = load_jinja2_templates(
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_password_reset_template_failure_html],
|
||||
)
|
||||
@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
self.failure_email_template, = load_jinja2_templates(
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_add_threepid_template_failure_html],
|
||||
)
|
||||
|
@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
self.failure_email_template, = load_jinja2_templates(
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_registration_template_failure_html],
|
||||
)
|
||||
|
||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||
self.failure_email_template, = load_jinja2_templates(
|
||||
(self.failure_email_template,) = load_jinja2_templates(
|
||||
self.config.email_template_dir,
|
||||
[self.config.email_registration_template_failure_html],
|
||||
)
|
||||
|
@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet):
|
||||
"m.require_identity_server": False,
|
||||
# as per MSC2290
|
||||
"m.separate_add_and_bind": True,
|
||||
# Implements support for label-based filtering as described in
|
||||
# MSC2326.
|
||||
"org.matrix.label_based_filtering": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource):
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
if len(request.postpath) == 1:
|
||||
server, = request.postpath
|
||||
(server,) = request.postpath
|
||||
query = {server.decode("ascii"): {}}
|
||||
elif len(request.postpath) == 2:
|
||||
server, key_id = request.postpath
|
||||
|
@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
|
||||
treq_args={"browser_like_redirects": True},
|
||||
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
|
||||
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
|
||||
http_proxy=os.getenv("http_proxy"),
|
||||
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||
)
|
||||
self.media_repo = media_repo
|
||||
self.primary_base_path = media_repo.primary_base_path
|
||||
|
@ -23,6 +23,7 @@
|
||||
# Imports required for the default HomeServer() implementation
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.mail.smtp import sendmail
|
||||
@ -168,6 +169,7 @@ class HomeServer(object):
|
||||
"filtering",
|
||||
"http_client_context_factory",
|
||||
"simple_http_client",
|
||||
"proxied_http_client",
|
||||
"media_repository",
|
||||
"media_repository_resource",
|
||||
"federation_transport_client",
|
||||
@ -311,6 +313,13 @@ class HomeServer(object):
|
||||
def build_simple_http_client(self):
|
||||
return SimpleHttpClient(self)
|
||||
|
||||
def build_proxied_http_client(self):
|
||||
return SimpleHttpClient(
|
||||
self,
|
||||
http_proxy=os.getenv("http_proxy"),
|
||||
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||
)
|
||||
|
||||
def build_room_creation_handler(self):
|
||||
return RoomCreationHandler(self)
|
||||
|
||||
|
@ -12,6 +12,7 @@ import synapse.handlers.message
|
||||
import synapse.handlers.room
|
||||
import synapse.handlers.room_member
|
||||
import synapse.handlers.set_password
|
||||
import synapse.http.client
|
||||
import synapse.rest.media.v1.media_repository
|
||||
import synapse.server_notices.server_notices_manager
|
||||
import synapse.server_notices.server_notices_sender
|
||||
@ -38,8 +39,16 @@ class HomeServer(object):
|
||||
pass
|
||||
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
||||
pass
|
||||
def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||
or support any HTTP_PROXY settings"""
|
||||
pass
|
||||
def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||
but does support HTTP_PROXY settings"""
|
||||
pass
|
||||
def get_deactivate_account_handler(
|
||||
self
|
||||
self,
|
||||
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||
pass
|
||||
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
|
||||
@ -47,32 +56,32 @@ class HomeServer(object):
|
||||
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
|
||||
pass
|
||||
def get_event_creation_handler(
|
||||
self
|
||||
self,
|
||||
) -> synapse.handlers.message.EventCreationHandler:
|
||||
pass
|
||||
def get_set_password_handler(
|
||||
self
|
||||
self,
|
||||
) -> synapse.handlers.set_password.SetPasswordHandler:
|
||||
pass
|
||||
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
|
||||
pass
|
||||
def get_federation_transport_client(
|
||||
self
|
||||
self,
|
||||
) -> synapse.federation.transport.client.TransportLayerClient:
|
||||
pass
|
||||
def get_media_repository_resource(
|
||||
self
|
||||
self,
|
||||
) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
|
||||
pass
|
||||
def get_media_repository(
|
||||
self
|
||||
self,
|
||||
) -> synapse.rest.media.v1.media_repository.MediaRepository:
|
||||
pass
|
||||
def get_server_notices_manager(
|
||||
self
|
||||
self,
|
||||
) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
|
||||
pass
|
||||
def get_server_notices_sender(
|
||||
self
|
||||
self,
|
||||
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
|
||||
pass
|
||||
|
@ -139,7 +139,10 @@ class DataStore(
|
||||
db_conn, "public_room_list_stream", "stream_id"
|
||||
)
|
||||
self._device_list_id_gen = StreamIdGenerator(
|
||||
db_conn, "device_lists_stream", "stream_id"
|
||||
db_conn,
|
||||
"device_lists_stream",
|
||||
"stream_id",
|
||||
extra_tables=[("user_signature_stream", "stream_id")],
|
||||
)
|
||||
self._cross_signing_id_gen = StreamIdGenerator(
|
||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||
@ -317,7 +320,7 @@ class DataStore(
|
||||
) u
|
||||
"""
|
||||
txn.execute(sql, (time_from,))
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
def count_r30_users(self):
|
||||
@ -396,7 +399,7 @@ class DataStore(
|
||||
|
||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
results["all"] = count
|
||||
|
||||
return results
|
||||
|
@ -37,6 +37,7 @@ from synapse.storage._base import (
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
from synapse.types import get_verify_key_from_cross_signing_key
|
||||
from synapse.util import batch_iter
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||
|
||||
@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_remote(self, destination, from_stream_id, limit):
|
||||
"""Get stream of updates to send to remote servers
|
||||
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
|
||||
"""Get a stream of device updates to send to the given remote server.
|
||||
|
||||
Args:
|
||||
destination (str): The host the device updates are intended for
|
||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||
limit (int): Maximum number of device updates to return
|
||||
Returns:
|
||||
Deferred[tuple[int, list[dict]]]:
|
||||
Deferred[tuple[int, list[tuple[string,dict]]]]:
|
||||
current stream id (ie, the stream id of the last update included in the
|
||||
response), and the list of updates
|
||||
response), and the list of updates, where each update is a pair of EDU
|
||||
type and EDU contents
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
# stream_id; the rationale being that such a large device list update
|
||||
# is likely an error.
|
||||
updates = yield self.runInteraction(
|
||||
"get_devices_by_remote",
|
||||
self._get_devices_by_remote_txn,
|
||||
"get_device_updates_by_remote",
|
||||
self._get_device_updates_by_remote_txn,
|
||||
destination,
|
||||
from_stream_id,
|
||||
now_stream_id,
|
||||
@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
if not updates:
|
||||
return now_stream_id, []
|
||||
|
||||
# get the cross-signing keys of the users in the list, so that we can
|
||||
# determine which of the device changes were cross-signing keys
|
||||
users = set(r[0] for r in updates)
|
||||
master_key_by_user = {}
|
||||
self_signing_key_by_user = {}
|
||||
for user in users:
|
||||
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
cross_signing_key
|
||||
)
|
||||
# verify_key is a VerifyKey from signedjson, which uses
|
||||
# .version to denote the portion of the key ID after the
|
||||
# algorithm and colon, which is the device ID
|
||||
master_key_by_user[user] = {
|
||||
"key_info": cross_signing_key,
|
||||
"device_id": verify_key.version,
|
||||
}
|
||||
|
||||
cross_signing_key = yield self.get_e2e_cross_signing_key(
|
||||
user, "self_signing"
|
||||
)
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
cross_signing_key
|
||||
)
|
||||
self_signing_key_by_user[user] = {
|
||||
"key_info": cross_signing_key,
|
||||
"device_id": verify_key.version,
|
||||
}
|
||||
|
||||
# if we have exceeded the limit, we need to exclude any results with the
|
||||
# same stream_id as the last row.
|
||||
if len(updates) > limit:
|
||||
@ -153,15 +190,28 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
# context which created the Edu.
|
||||
|
||||
query_map = {}
|
||||
for update in updates:
|
||||
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
|
||||
cross_signing_keys_by_user = {}
|
||||
for user_id, device_id, update_stream_id, update_context in updates:
|
||||
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
|
||||
# Stop processing updates
|
||||
break
|
||||
|
||||
key = (update[0], update[1])
|
||||
|
||||
update_context = update[3]
|
||||
update_stream_id = update[2]
|
||||
if (
|
||||
user_id in master_key_by_user
|
||||
and device_id == master_key_by_user[user_id]["device_id"]
|
||||
):
|
||||
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||
result["master_key"] = master_key_by_user[user_id]["key_info"]
|
||||
elif (
|
||||
user_id in self_signing_key_by_user
|
||||
and device_id == self_signing_key_by_user[user_id]["device_id"]
|
||||
):
|
||||
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||
result["self_signing_key"] = self_signing_key_by_user[user_id][
|
||||
"key_info"
|
||||
]
|
||||
else:
|
||||
key = (user_id, device_id)
|
||||
|
||||
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
||||
|
||||
@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
# devices, in which case E2E isn't going to work well anyway. We'll just
|
||||
# skip that stream_id and return an empty list, and continue with the next
|
||||
# stream_id next time.
|
||||
if not query_map:
|
||||
if not query_map and not cross_signing_keys_by_user:
|
||||
return stream_id_cutoff, []
|
||||
|
||||
results = yield self._get_device_update_edus_by_remote(
|
||||
destination, from_stream_id, query_map
|
||||
)
|
||||
|
||||
# add the updated cross-signing keys to the results list
|
||||
for user_id, result in iteritems(cross_signing_keys_by_user):
|
||||
result["user_id"] = user_id
|
||||
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||
results.append(("org.matrix.signing_key_update", result))
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
def _get_devices_by_remote_txn(
|
||||
def _get_device_updates_by_remote_txn(
|
||||
self, txn, destination, from_stream_id, now_stream_id, limit
|
||||
):
|
||||
"""Return device update information for a given remote destination
|
||||
@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
List: List of device updates
|
||||
"""
|
||||
# get the list of device updates that need to be sent
|
||||
sql = """
|
||||
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||
@ -225,13 +282,17 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
List[Dict]: List of objects representing an device update EDU
|
||||
|
||||
"""
|
||||
devices = yield self.runInteraction(
|
||||
devices = (
|
||||
yield self.runInteraction(
|
||||
"_get_e2e_device_keys_txn",
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_map.keys(),
|
||||
include_all_devices=True,
|
||||
include_deleted_devices=True,
|
||||
)
|
||||
if query_map
|
||||
else {}
|
||||
)
|
||||
|
||||
results = []
|
||||
for user_id, user_devices in iteritems(devices):
|
||||
@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
else:
|
||||
result["deleted"] = True
|
||||
|
||||
results.append(result)
|
||||
results.append(("m.device_list_update", result))
|
||||
|
||||
return results
|
||||
|
||||
|
@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
from_user_id,
|
||||
)
|
||||
|
||||
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
|
||||
"""Return a list of changes from the user signature stream to notify remotes.
|
||||
Note that the user signature stream represents when a user signs their
|
||||
device with their user-signing key, which is not published to other
|
||||
users or servers, so no `destination` is needed in the returned
|
||||
list. However, this is needed to poke workers.
|
||||
|
||||
Args:
|
||||
from_key (int): the stream ID to start at (exclusive)
|
||||
to_key (int): the stream ID to end at (inclusive)
|
||||
|
||||
Returns:
|
||||
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
|
||||
"""
|
||||
sql = """
|
||||
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
|
||||
FROM user_signature_stream
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
GROUP BY user_id
|
||||
"""
|
||||
return self._execute(
|
||||
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
|
||||
)
|
||||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||
|
@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
)
|
||||
stream_row = txn.fetchone()
|
||||
if stream_row:
|
||||
offset_stream_ordering, = stream_row
|
||||
(offset_stream_ordering,) = stream_row
|
||||
rotate_to_stream_ordering = min(
|
||||
self.stream_ordering_day_ago, offset_stream_ordering
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ from prometheus_client import Counter
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase # noqa: F401
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
@ -933,6 +933,13 @@ class EventsStore(
|
||||
|
||||
self._handle_event_relations(txn, event)
|
||||
|
||||
# Store the labels for this event.
|
||||
labels = event.content.get(EventContentFields.LABELS)
|
||||
if labels:
|
||||
self.insert_labels_for_event_txn(
|
||||
txn, event.event_id, labels, event.room_id, event.depth
|
||||
)
|
||||
|
||||
# Insert into the room_memberships table.
|
||||
self._store_room_members_txn(
|
||||
txn,
|
||||
@ -1126,7 +1133,7 @@ class EventsStore(
|
||||
AND stream_ordering > ?
|
||||
"""
|
||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
ret = yield self.runInteraction("count_messages", _count_messages)
|
||||
@ -1147,7 +1154,7 @@ class EventsStore(
|
||||
"""
|
||||
|
||||
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
|
||||
@ -1162,7 +1169,7 @@ class EventsStore(
|
||||
AND stream_ordering > ?
|
||||
"""
|
||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
ret = yield self.runInteraction("count_daily_active_rooms", _count)
|
||||
@ -1596,7 +1603,7 @@ class EventsStore(
|
||||
""",
|
||||
(room_id,),
|
||||
)
|
||||
min_depth, = txn.fetchone()
|
||||
(min_depth,) = txn.fetchone()
|
||||
|
||||
logger.info("[purge] updating room_depth to %d", min_depth)
|
||||
|
||||
@ -1905,6 +1912,33 @@ class EventsStore(
|
||||
get_all_updated_current_state_deltas_txn,
|
||||
)
|
||||
|
||||
def insert_labels_for_event_txn(
|
||||
self, txn, event_id, labels, room_id, topological_ordering
|
||||
):
|
||||
"""Store the mapping between an event's ID and its labels, with one row per
|
||||
(event_id, label) tuple.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The transaction to execute.
|
||||
event_id (str): The event's ID.
|
||||
labels (list[str]): A list of text labels.
|
||||
room_id (str): The ID of the room the event was sent to.
|
||||
topological_ordering (int): The position of the event in the room's topology.
|
||||
"""
|
||||
return self._simple_insert_many_txn(
|
||||
txn=txn,
|
||||
table="event_labels",
|
||||
values=[
|
||||
{
|
||||
"event_id": event_id,
|
||||
"label": label,
|
||||
"room_id": room_id,
|
||||
"topological_ordering": topological_ordering,
|
||||
}
|
||||
for label in labels
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
|
@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
upper_event_id, = rows[-1]
|
||||
(upper_event_id,) = rows[-1]
|
||||
|
||||
# Update the redactions with the received_ts.
|
||||
#
|
||||
|
@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
WHERE group_id = ? AND category_id = ?
|
||||
"""
|
||||
txn.execute(sql, (group_id, category_id))
|
||||
order, = txn.fetchone()
|
||||
(order,) = txn.fetchone()
|
||||
|
||||
if existing:
|
||||
to_update = {}
|
||||
@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore):
|
||||
WHERE group_id = ? AND role_id = ?
|
||||
"""
|
||||
txn.execute(sql, (group_id, role_id))
|
||||
order, = txn.fetchone()
|
||||
(order,) = txn.fetchone()
|
||||
|
||||
if existing:
|
||||
to_update = {}
|
||||
|
@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
|
||||
|
||||
txn.execute(sql)
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return self.runInteraction("count_users", _count_users)
|
||||
|
@ -143,7 +143,7 @@ class PushRulesWorkerStore(
|
||||
" WHERE user_id = ? AND ? < stream_id"
|
||||
)
|
||||
txn.execute(sql, (user_id, last_id))
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return bool(count)
|
||||
|
||||
return self.runInteraction(
|
||||
|
@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
WHERE appservice_id IS NULL
|
||||
"""
|
||||
)
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
ret = yield self.runInteraction("count_users", _count_users)
|
||||
|
@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
|
||||
if not row or not row[0]:
|
||||
return processed, True
|
||||
|
||||
next_room, = row
|
||||
(next_room,) = row
|
||||
|
||||
sql = """
|
||||
UPDATE current_state_events
|
||||
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- room_id and topoligical_ordering are denormalised from the events table in order to
|
||||
-- make the index work.
|
||||
CREATE TABLE IF NOT EXISTS event_labels (
|
||||
event_id TEXT,
|
||||
label TEXT,
|
||||
room_id TEXT NOT NULL,
|
||||
topological_ordering BIGINT NOT NULL,
|
||||
PRIMARY KEY(event_id, label)
|
||||
);
|
||||
|
||||
|
||||
-- This index enables an event pagination looking for a particular label to index the
|
||||
-- event_labels table first, which is much quicker than scanning the events table and then
|
||||
-- filtering by label, if the label is rarely used relative to the size of the room.
|
||||
CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* Change the hidden column from a default value of FALSE to a default value of
|
||||
* 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
|
||||
* string 'FALSE', which is truthy.
|
||||
*
|
||||
* Since sqlite doesn't allow us to just change the default value, we have to
|
||||
* recreate the table, copy the data, fix the rows that have incorrect data, and
|
||||
* replace the old table with the new table.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS devices2 (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
last_seen BIGINT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
hidden BOOLEAN DEFAULT 0,
|
||||
CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
|
||||
);
|
||||
|
||||
INSERT INTO devices2 SELECT * FROM devices;
|
||||
|
||||
UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
|
||||
|
||||
DROP TABLE devices;
|
||||
|
||||
ALTER TABLE devices2 RENAME TO devices;
|
@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
)
|
||||
)
|
||||
txn.execute(query, (value, search_query))
|
||||
headline, = txn.fetchall()[0]
|
||||
(headline,) = txn.fetchall()[0]
|
||||
|
||||
# Now we need to pick the possible highlights out of the haedline
|
||||
# result.
|
||||
|
@ -725,17 +725,19 @@ class StateGroupWorkerStore(
|
||||
member_filter, non_member_filter = state_filter.get_member_split()
|
||||
|
||||
# Now we look them up in the member and non-member caches
|
||||
non_member_state, incomplete_groups_nm, = (
|
||||
yield self._get_state_for_groups_using_cache(
|
||||
(
|
||||
non_member_state,
|
||||
incomplete_groups_nm,
|
||||
) = yield self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_cache, state_filter=non_member_filter
|
||||
)
|
||||
)
|
||||
|
||||
member_state, incomplete_groups_m, = (
|
||||
yield self._get_state_for_groups_using_cache(
|
||||
(
|
||||
member_state,
|
||||
incomplete_groups_m,
|
||||
) = yield self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_members_cache, state_filter=member_filter
|
||||
)
|
||||
)
|
||||
|
||||
state = dict(non_member_state)
|
||||
for group in groups:
|
||||
@ -1106,7 +1108,7 @@ class StateBackgroundUpdateStore(
|
||||
" WHERE id < ? AND room_id = ?",
|
||||
(state_group, room_id),
|
||||
)
|
||||
prev_group, = txn.fetchone()
|
||||
(prev_group,) = txn.fetchone()
|
||||
new_last_state_group = state_group
|
||||
|
||||
if prev_group:
|
||||
|
@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore):
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
current_state_events_count, = txn.fetchone()
|
||||
(current_state_events_count,) = txn.fetchone()
|
||||
|
||||
users_in_room = self.get_users_in_room_txn(txn, room_id)
|
||||
|
||||
@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore):
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
count, = txn.fetchone()
|
||||
(count,) = txn.fetchone()
|
||||
return count, pos
|
||||
|
||||
joined_rooms, pos = yield self.runInteraction(
|
||||
|
@ -229,6 +229,14 @@ def filter_to_clause(event_filter):
|
||||
clauses.append("contains_url = ?")
|
||||
args.append(event_filter.contains_url)
|
||||
|
||||
# We're only applying the "labels" filter on the database query, because applying the
|
||||
# "not_labels" filter via a SQL query is non-trivial. Instead, we let
|
||||
# event_filter.check_fields apply it, which is not as efficient but makes the
|
||||
# implementation simpler.
|
||||
if event_filter.labels:
|
||||
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
|
||||
args.extend(event_filter.labels)
|
||||
|
||||
return " AND ".join(clauses), args
|
||||
|
||||
|
||||
@ -864,8 +872,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
args.append(int(limit))
|
||||
|
||||
sql = (
|
||||
"SELECT event_id, topological_ordering, stream_ordering"
|
||||
"SELECT DISTINCT event_id, topological_ordering, stream_ordering"
|
||||
" FROM events"
|
||||
" LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)"
|
||||
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
|
||||
" ORDER BY topological_ordering %(order)s,"
|
||||
" stream_ordering %(order)s LIMIT ?"
|
||||
|
@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
|
||||
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
|
||||
else:
|
||||
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
|
||||
val, = cur.fetchone()
|
||||
(val,) = cur.fetchone()
|
||||
cur.close()
|
||||
current_id = int(val) if val else step
|
||||
return (max if step > 0 else min)(current_id, step)
|
||||
|
@ -19,6 +19,7 @@ import jsonschema
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events import FrozenEvent
|
||||
@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase):
|
||||
"types": ["m.room.message"],
|
||||
"not_rooms": ["!726s6s6q:example.com"],
|
||||
"not_senders": ["@spam:example.com"],
|
||||
"org.matrix.labels": ["#fun"],
|
||||
"org.matrix.not_labels": ["#work"],
|
||||
},
|
||||
"ephemeral": {
|
||||
"types": ["m.receipt", "m.typing"],
|
||||
@ -320,6 +323,46 @@ class FilteringTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_filter_labels(self):
|
||||
definition = {"org.matrix.labels": ["#fun"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#fun"]},
|
||||
)
|
||||
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#notfun"]},
|
||||
)
|
||||
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_filter_not_labels(self):
|
||||
definition = {"org.matrix.not_labels": ["#fun"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#fun"]},
|
||||
)
|
||||
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#notfun"]},
|
||||
)
|
||||
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_filter_presence_match(self):
|
||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||
|
@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
"get_devices_by_remote",
|
||||
"get_device_updates_by_remote",
|
||||
# Bits that user_directory needs
|
||||
"get_user_directory_stream_pos",
|
||||
"get_current_state_deltas",
|
||||
@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
retry_timings_res
|
||||
)
|
||||
|
||||
self.datastore.get_devices_by_remote.return_value = (0, [])
|
||||
self.datastore.get_device_updates_by_remote.return_value = (0, [])
|
||||
|
||||
def get_received_txn_response(*args):
|
||||
return defer.succeed(None)
|
||||
|
@ -20,6 +20,23 @@ from zope.interface import implementer
|
||||
from OpenSSL import SSL
|
||||
from OpenSSL.SSL import Connection
|
||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
||||
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
||||
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
||||
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
||||
|
||||
|
||||
def get_test_https_policy():
|
||||
"""Get a test IPolicyForHTTPS which trusts the test CA cert
|
||||
|
||||
Returns:
|
||||
IPolicyForHTTPS
|
||||
"""
|
||||
ca_file = get_test_ca_cert_file()
|
||||
with open(ca_file) as stream:
|
||||
content = stream.read()
|
||||
cert = Certificate.loadPEM(content)
|
||||
trust_root = trustRootFromCertificates([cert])
|
||||
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
||||
|
||||
|
||||
def get_test_ca_cert_file():
|
||||
|
@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
||||
)
|
||||
|
||||
# grab a hold of the TLS connection, in case it gets torn down
|
||||
server_tls_connection = server_tls_protocol._tlsConnection
|
||||
|
||||
# fish the test server back out of the server-side TLS protocol.
|
||||
http_protocol = server_tls_protocol.wrappedProtocol
|
||||
|
||||
# give the reactor a pump to get the TLS juices flowing.
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
# check the SNI
|
||||
server_name = server_tls_protocol._tlsConnection.get_servername()
|
||||
server_name = server_tls_connection.get_servername()
|
||||
self.assertEqual(
|
||||
server_name,
|
||||
expected_sni,
|
||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||
)
|
||||
|
||||
# fish the test server back out of the server-side TLS protocol.
|
||||
return server_tls_protocol.wrappedProtocol
|
||||
return http_protocol
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_get_request(self, uri):
|
||||
|
334
tests/http/test_proxyagent.py
Normal file
334
tests/http/test_proxyagent.py
Normal file
@ -0,0 +1,334 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import treq
|
||||
|
||||
from twisted.internet import interfaces # noqa: F401
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.http.proxyagent import ProxyAgent
|
||||
|
||||
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
|
||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HTTPFactory = Factory.forProtocol(HTTPChannel)
|
||||
|
||||
|
||||
class MatrixFederationAgentTests(TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
||||
def _make_connection(
|
||||
self, client_factory, server_factory, ssl=False, expected_sni=None
|
||||
):
|
||||
"""Builds a test server, and completes the outgoing client connection
|
||||
|
||||
Args:
|
||||
client_factory (interfaces.IProtocolFactory): the the factory that the
|
||||
application is trying to use to make the outbound connection. We will
|
||||
invoke it to build the client Protocol
|
||||
|
||||
server_factory (interfaces.IProtocolFactory): a factory to build the
|
||||
server-side protocol
|
||||
|
||||
ssl (bool): If true, we will expect an ssl connection and wrap
|
||||
server_factory with a TLSMemoryBIOFactory
|
||||
|
||||
expected_sni (bytes|None): the expected SNI value
|
||||
|
||||
Returns:
|
||||
IProtocol: the server Protocol returned by server_factory
|
||||
"""
|
||||
if ssl:
|
||||
server_factory = _wrap_server_factory_for_tls(server_factory)
|
||||
|
||||
server_protocol = server_factory.buildProtocol(None)
|
||||
|
||||
# now, tell the client protocol factory to build the client protocol,
|
||||
# and wire the output of said protocol up to the server via
|
||||
# a FakeTransport.
|
||||
#
|
||||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||
# stubbing that out here.
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_protocol.makeConnection(
|
||||
FakeTransport(server_protocol, self.reactor, client_protocol)
|
||||
)
|
||||
|
||||
# tell the server protocol to send its stuff back to the client, too
|
||||
server_protocol.makeConnection(
|
||||
FakeTransport(client_protocol, self.reactor, server_protocol)
|
||||
)
|
||||
|
||||
if ssl:
|
||||
http_protocol = server_protocol.wrappedProtocol
|
||||
tls_connection = server_protocol._tlsConnection
|
||||
else:
|
||||
http_protocol = server_protocol
|
||||
tls_connection = None
|
||||
|
||||
# give the reactor a pump to get the TLS juices flowing (if needed)
|
||||
self.reactor.advance(0)
|
||||
|
||||
if expected_sni is not None:
|
||||
server_name = tls_connection.get_servername()
|
||||
self.assertEqual(
|
||||
server_name,
|
||||
expected_sni,
|
||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||
)
|
||||
|
||||
return http_protocol
|
||||
|
||||
def test_http_request(self):
|
||||
agent = ProxyAgent(self.reactor)
|
||||
|
||||
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||
d = agent.request(b"GET", b"http://test.com")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 80)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory, _get_test_protocol_factory()
|
||||
)
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
def test_https_request(self):
|
||||
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
|
||||
|
||||
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||
d = agent.request(b"GET", b"https://test.com/abc")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory,
|
||||
_get_test_protocol_factory(),
|
||||
ssl=True,
|
||||
expected_sni=b"test.com",
|
||||
)
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/abc")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
def test_http_request_via_proxy(self):
|
||||
agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
d = agent.request(b"GET", b"http://test.com")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.5")
|
||||
self.assertEqual(port, 8888)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory, _get_test_protocol_factory()
|
||||
)
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"http://test.com")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
def test_https_request_via_proxy(self):
|
||||
agent = ProxyAgent(
|
||||
self.reactor,
|
||||
contextFactory=get_test_https_policy(),
|
||||
https_proxy=b"proxy.com",
|
||||
)
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
d = agent.request(b"GET", b"https://test.com/abc")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.5")
|
||||
self.assertEqual(port, 1080)
|
||||
|
||||
# make a test HTTP server, and wire up the client
|
||||
proxy_server = self._make_connection(
|
||||
client_factory, _get_test_protocol_factory()
|
||||
)
|
||||
|
||||
# fish the transports back out so that we can do the old switcheroo
|
||||
s2c_transport = proxy_server.transport
|
||||
client_protocol = s2c_transport.other
|
||||
c2s_transport = client_protocol.transport
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending CONNECT request
|
||||
self.assertEqual(len(proxy_server.requests), 1)
|
||||
|
||||
request = proxy_server.requests[0]
|
||||
self.assertEqual(request.method, b"CONNECT")
|
||||
self.assertEqual(request.path, b"test.com:443")
|
||||
|
||||
# tell the proxy server not to close the connection
|
||||
proxy_server.persistent = True
|
||||
|
||||
# this just stops the http Request trying to do a chunked response
|
||||
# request.setHeader(b"Content-Length", b"0")
|
||||
request.finish()
|
||||
|
||||
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
||||
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
|
||||
ssl_protocol = ssl_factory.buildProtocol(None)
|
||||
http_server = ssl_protocol.wrappedProtocol
|
||||
|
||||
ssl_protocol.makeConnection(
|
||||
FakeTransport(client_protocol, self.reactor, ssl_protocol)
|
||||
)
|
||||
c2s_transport.other = ssl_protocol
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
server_name = ssl_protocol._tlsConnection.get_servername()
|
||||
expected_sni = b"test.com"
|
||||
self.assertEqual(
|
||||
server_name,
|
||||
expected_sni,
|
||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||
)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/abc")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
|
||||
def _wrap_server_factory_for_tls(factory, sanlist=None):
|
||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||
|
||||
The resultant factory will create a TLS server which presents a certificate
|
||||
signed by our test CA, valid for the domains in `sanlist`
|
||||
|
||||
Args:
|
||||
factory (interfaces.IProtocolFactory): protocol factory to wrap
|
||||
sanlist (iterable[bytes]): list of domains the cert should be valid for
|
||||
|
||||
Returns:
|
||||
interfaces.IProtocolFactory
|
||||
"""
|
||||
if sanlist is None:
|
||||
sanlist = [b"DNS:test.com"]
|
||||
|
||||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
||||
return TLSMemoryBIOFactory(
|
||||
connection_creator, isClient=False, wrappedFactory=factory
|
||||
)
|
||||
|
||||
|
||||
def _get_test_protocol_factory():
|
||||
"""Get a protocol Factory which will build an HTTPChannel
|
||||
|
||||
Returns:
|
||||
interfaces.IProtocolFactory
|
||||
"""
|
||||
server_factory = Factory.forProtocol(HTTPChannel)
|
||||
|
||||
# Request.finish expects the factory to have a 'log' method.
|
||||
server_factory.log = _log_request
|
||||
|
||||
return server_factory
|
||||
|
||||
|
||||
def _log_request(request):
|
||||
"""Implements Factory.log, which is expected by Request.finish"""
|
||||
logger.info("Completed request %s", request)
|
@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
config = self.default_config()
|
||||
config["start_pushers"] = True
|
||||
|
||||
hs = self.setup_test_homeserver(config=config, simple_http_client=m)
|
||||
hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
|
||||
|
||||
return hs
|
||||
|
||||
|
@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.rest.client.v1 import login, profile, room
|
||||
|
||||
from tests import unittest
|
||||
@ -811,6 +811,105 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.assertTrue("chunk" in channel.json_body)
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_filter_labels(self):
|
||||
"""Test that we can filter by a label."""
|
||||
message_filter = json.dumps(
|
||||
{"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
|
||||
)
|
||||
|
||||
events = self._test_filter_labels(message_filter)
|
||||
|
||||
self.assertEqual(len(events), 2, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||
|
||||
def test_filter_not_labels(self):
|
||||
"""Test that we can filter by the absence of a label."""
|
||||
message_filter = json.dumps(
|
||||
{"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
|
||||
)
|
||||
|
||||
events = self._test_filter_labels(message_filter)
|
||||
|
||||
self.assertEqual(len(events), 3, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "without label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
|
||||
self.assertEqual(
|
||||
events[2]["content"]["body"], "with two wrong labels", events[2]
|
||||
)
|
||||
|
||||
def test_filter_labels_not_labels(self):
|
||||
"""Test that we can filter by both a label and the absence of another label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.labels": ["#work"],
|
||||
"org.matrix.not_labels": ["#notfun"],
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||
|
||||
def _test_filter_labels(self, message_filter):
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": "m.text", "body": "without label"},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with wrong label",
|
||||
EventContentFields.LABELS: ["#work"],
|
||||
},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with two wrong labels",
|
||||
EventContentFields.LABELS: ["#work", "#notfun"],
|
||||
},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
)
|
||||
|
||||
token = "s0_0_0_0_0_0_0_0_0"
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/messages?access_token=x&from=%s&filter=%s"
|
||||
% (self.room_id, token, message_filter),
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
return channel.json_body["chunk"]
|
||||
|
||||
|
||||
class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
|
@ -106,13 +106,22 @@ class RestHelper(object):
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
if body is None:
|
||||
body = "body_text_here"
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
||||
content = {"msgtype": "m.text", "body": body}
|
||||
|
||||
return self.send_event(
|
||||
room_id, "m.room.message", content, txn_id, tok, expect_code
|
||||
)
|
||||
|
||||
def send_event(
|
||||
self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
|
||||
):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
|
@ -12,10 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
|
||||
@ -26,7 +28,12 @@ from tests.server import TimedOutException
|
||||
class FilterTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
user_id = "@apple:test"
|
||||
servlets = [sync.register_servlets]
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
@ -70,6 +77,140 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
|
||||
class SyncFilterTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def test_sync_filter_labels(self):
|
||||
"""Test that we can filter by a label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"room": {
|
||||
"timeline": {
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.labels": ["#fun"],
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_sync_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 2, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||
|
||||
def test_sync_filter_not_labels(self):
|
||||
"""Test that we can filter by the absence of a label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"room": {
|
||||
"timeline": {
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.not_labels": ["#fun"],
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_sync_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 3, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "without label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
|
||||
self.assertEqual(
|
||||
events[2]["content"]["body"], "with two wrong labels", events[2]
|
||||
)
|
||||
|
||||
def test_sync_filter_labels_not_labels(self):
|
||||
"""Test that we can filter by both a label and the absence of another label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"room": {
|
||||
"timeline": {
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.labels": ["#work"],
|
||||
"org.matrix.not_labels": ["#notfun"],
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_sync_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||
|
||||
def _test_sync_filter_labels(self, sync_filter):
|
||||
user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test")
|
||||
|
||||
room_id = self.helper.create_room_as(user_id, tok=tok)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": "m.text", "body": "without label"},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with wrong label",
|
||||
EventContentFields.LABELS: ["#work"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with two wrong labels",
|
||||
EventContentFields.LABELS: ["#work", "#notfun"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "/sync?filter=%s" % sync_filter, access_token=tok
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
|
||||
|
||||
|
||||
class SyncTypingTests(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
|
@ -395,11 +395,24 @@ class FakeTransport(object):
|
||||
self.disconnecting = True
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(reason)
|
||||
|
||||
# if we still have data to write, delay until that is done
|
||||
if self.buffer:
|
||||
logger.info(
|
||||
"FakeTransport: Delaying disconnect until buffer is flushed"
|
||||
)
|
||||
else:
|
||||
self.disconnected = True
|
||||
|
||||
def abortConnection(self):
|
||||
logger.info("FakeTransport: abortConnection()")
|
||||
self.loseConnection()
|
||||
|
||||
if not self.disconnecting:
|
||||
self.disconnecting = True
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(None)
|
||||
|
||||
self.disconnected = True
|
||||
|
||||
def pauseProducing(self):
|
||||
if not self.producer:
|
||||
@ -430,6 +443,9 @@ class FakeTransport(object):
|
||||
self._reactor.callLater(0.0, _produce)
|
||||
|
||||
def write(self, byt):
|
||||
if self.disconnecting:
|
||||
raise Exception("Writing to disconnecting FakeTransport")
|
||||
|
||||
self.buffer = self.buffer + byt
|
||||
|
||||
# always actually do the write asynchronously. Some protocols (notably the
|
||||
@ -474,6 +490,10 @@ class FakeTransport(object):
|
||||
if self.buffer and self.autoflush:
|
||||
self._reactor.callLater(0.0, self.flush)
|
||||
|
||||
if not self.buffer and self.disconnecting:
|
||||
logger.info("FakeTransport: Buffer now empty, completing disconnect")
|
||||
self.disconnected = True
|
||||
|
||||
|
||||
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
|
||||
"""
|
||||
|
@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_devices_by_remote(self):
|
||||
def test_get_device_updates_by_remote(self):
|
||||
device_ids = ["device_id1", "device_id2"]
|
||||
|
||||
# Add two device updates with a single stream_id
|
||||
@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
)
|
||||
|
||||
# Get all device updates ever meant for this remote
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"somehost", -1, limit=100
|
||||
)
|
||||
|
||||
@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
self._check_devices_in_updates(device_ids, device_updates)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_devices_by_remote_limited(self):
|
||||
def test_get_device_updates_by_remote_limited(self):
|
||||
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
||||
|
||||
# first add one device
|
||||
@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
#
|
||||
|
||||
# first we should get a single update
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"someotherhost", -1, limit=100
|
||||
)
|
||||
self._check_devices_in_updates(device_ids1, device_updates)
|
||||
|
||||
# Then we should get an empty list back as the 101 devices broke the limit
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"someotherhost", now_stream_id, limit=100
|
||||
)
|
||||
self.assertEqual(len(device_updates), 0)
|
||||
|
||||
# The 101 devices should've been cleared, so we should now just get one device
|
||||
# update
|
||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"someotherhost", now_stream_id, limit=100
|
||||
)
|
||||
self._check_devices_in_updates(device_ids3, device_updates)
|
||||
@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
"""Check that an specific device ids exist in a list of device update EDUs"""
|
||||
self.assertEqual(len(device_updates), len(expected_device_ids))
|
||||
|
||||
received_device_ids = {update["device_id"] for update in device_updates}
|
||||
received_device_ids = {
|
||||
update["device_id"] for edu_type, update in device_updates
|
||||
}
|
||||
self.assertEqual(received_device_ids, set(expected_device_ids))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.handler = self.homeserver.get_handlers().federation_handler
|
||||
self.handler.do_auth = lambda *a, **b: succeed(True)
|
||||
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
|
||||
context
|
||||
)
|
||||
self.client = self.homeserver.get_federation_client()
|
||||
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
||||
pdus
|
||||
|
4
tox.ini
4
tox.ini
@ -114,7 +114,7 @@ skip_install = True
|
||||
basepython = python3.6
|
||||
deps =
|
||||
flake8
|
||||
black==19.3b0 # We pin so that our tests don't start failing on new releases of black.
|
||||
black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
|
||||
commands =
|
||||
python -m black --check --diff .
|
||||
/bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
|
||||
@ -167,6 +167,6 @@ deps =
|
||||
env =
|
||||
MYPYPATH = stubs/
|
||||
extras = all
|
||||
commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \
|
||||
commands = mypy \
|
||||
synapse/logging/ \
|
||||
synapse/config/
|
||||
|
Loading…
Reference in New Issue
Block a user