mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-27 04:39:24 -05:00
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/split_purge_history
This commit is contained in:
commit
6a0092d371
1
.github/PULL_REQUEST_TEMPLATE.md
vendored
1
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -5,3 +5,4 @@
|
|||||||
* [ ] Pull request is based on the develop branch
|
* [ ] Pull request is based on the develop branch
|
||||||
* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog)
|
* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog)
|
||||||
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off)
|
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off)
|
||||||
|
* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style))
|
||||||
|
@ -58,10 +58,29 @@ All Matrix projects have a well-defined code-style - and sometimes we've even
|
|||||||
got as far as documenting it... For instance, synapse's code style doc lives
|
got as far as documenting it... For instance, synapse's code style doc lives
|
||||||
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md.
|
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md.
|
||||||
|
|
||||||
|
To facilitate meeting these criteria you can run ``scripts-dev/lint.sh``
|
||||||
|
locally. Since this runs the tools listed in the above document, you'll need
|
||||||
|
python 3.6 and to install each tool. **Note that the script does not just
|
||||||
|
test/check, but also reformats code, so you may wish to ensure any new code is
|
||||||
|
committed first**. By default this script checks all files and can take some
|
||||||
|
time; if you alter only certain files, you might wish to specify paths as
|
||||||
|
arguments to reduce the run-time.
|
||||||
|
|
||||||
Please ensure your changes match the cosmetic style of the existing project,
|
Please ensure your changes match the cosmetic style of the existing project,
|
||||||
and **never** mix cosmetic and functional changes in the same commit, as it
|
and **never** mix cosmetic and functional changes in the same commit, as it
|
||||||
makes it horribly hard to review otherwise.
|
makes it horribly hard to review otherwise.
|
||||||
|
|
||||||
|
Before doing a commit, ensure the changes you've made don't produce
|
||||||
|
linting errors. You can do this by running the linters as follows. Ensure to
|
||||||
|
commit any files that were corrected.
|
||||||
|
|
||||||
|
::
|
||||||
|
# Install the dependencies
|
||||||
|
pip install -U black flake8 isort
|
||||||
|
|
||||||
|
# Run the linter script
|
||||||
|
./scripts-dev/lint.sh
|
||||||
|
|
||||||
Changelog
|
Changelog
|
||||||
~~~~~~~~~
|
~~~~~~~~~
|
||||||
|
|
||||||
|
1
changelog.d/5727.feature
Normal file
1
changelog.d/5727.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add federation support for cross-signing.
|
1
changelog.d/6164.doc
Normal file
1
changelog.d/6164.doc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Contributor documentation now mentions script to run linters.
|
1
changelog.d/6232.bugfix
Normal file
1
changelog.d/6232.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Remove a room from a server's public rooms list on room upgrade.
|
1
changelog.d/6238.feature
Normal file
1
changelog.d/6238.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
|
1
changelog.d/6254.bugfix
Normal file
1
changelog.d/6254.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make notification of cross-signing signatures work with workers.
|
1
changelog.d/6298.misc
Normal file
1
changelog.d/6298.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor EventContext for clarity.
|
1
changelog.d/6301.feature
Normal file
1
changelog.d/6301.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)).
|
1
changelog.d/6304.misc
Normal file
1
changelog.d/6304.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Update the version of black used to 19.10b0.
|
1
changelog.d/6305.misc
Normal file
1
changelog.d/6305.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add some documentation about worker replication.
|
1
changelog.d/6306.bugfix
Normal file
1
changelog.d/6306.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash.
|
1
changelog.d/6312.misc
Normal file
1
changelog.d/6312.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Document the use of `lint.sh` for code style enforcement & extend it to run on specified paths only.
|
1
changelog.d/6313.bugfix
Normal file
1
changelog.d/6313.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix the `hidden` field in the `devices` table for SQLite versions prior to 3.23.0.
|
1
changelog.d/6314.misc
Normal file
1
changelog.d/6314.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated.
|
@ -78,7 +78,7 @@ class InputOutput(object):
|
|||||||
m = re.match("^join (\S+)$", line)
|
m = re.match("^join (\S+)$", line)
|
||||||
if m:
|
if m:
|
||||||
# The `sender` wants to join a room.
|
# The `sender` wants to join a room.
|
||||||
room_name, = m.groups()
|
(room_name,) = m.groups()
|
||||||
self.print_line("%s joining %s" % (self.user, room_name))
|
self.print_line("%s joining %s" % (self.user, room_name))
|
||||||
self.server.join_room(room_name, self.user, self.user)
|
self.server.join_room(room_name, self.user, self.user)
|
||||||
# self.print_line("OK.")
|
# self.print_line("OK.")
|
||||||
@ -105,7 +105,7 @@ class InputOutput(object):
|
|||||||
m = re.match("^backfill (\S+)$", line)
|
m = re.match("^backfill (\S+)$", line)
|
||||||
if m:
|
if m:
|
||||||
# we want to backfill a room
|
# we want to backfill a room
|
||||||
room_name, = m.groups()
|
(room_name,) = m.groups()
|
||||||
self.print_line("backfill %s" % room_name)
|
self.print_line("backfill %s" % room_name)
|
||||||
self.server.backfill(room_name)
|
self.server.backfill(room_name)
|
||||||
return
|
return
|
||||||
|
@ -199,7 +199,20 @@ client (C):
|
|||||||
|
|
||||||
#### REPLICATE (C)
|
#### REPLICATE (C)
|
||||||
|
|
||||||
Asks the server to replicate a given stream
|
Asks the server to replicate a given stream. The syntax is:
|
||||||
|
|
||||||
|
```
|
||||||
|
REPLICATE <stream_name> <token>
|
||||||
|
```
|
||||||
|
|
||||||
|
Where `<token>` may be either:
|
||||||
|
* a numeric stream_id to stream updates since (exclusive)
|
||||||
|
* `NOW` to stream all subsequent updates.
|
||||||
|
|
||||||
|
The `<stream_name>` is the name of a replication stream to subscribe
|
||||||
|
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
|
||||||
|
of streams). It can also be `ALL` to subscribe to all known streams,
|
||||||
|
in which case the `<token>` must be set to `NOW`.
|
||||||
|
|
||||||
#### USER_SYNC (C)
|
#### USER_SYNC (C)
|
||||||
|
|
||||||
|
5
mypy.ini
5
mypy.ini
@ -1,7 +1,10 @@
|
|||||||
[mypy]
|
[mypy]
|
||||||
namespace_packages = True
|
namespace_packages = True
|
||||||
plugins = mypy_zope:plugin
|
plugins = mypy_zope:plugin
|
||||||
follow_imports=skip
|
follow_imports = normal
|
||||||
|
check_untyped_defs = True
|
||||||
|
show_error_codes = True
|
||||||
|
show_traceback = True
|
||||||
mypy_path = stubs
|
mypy_path = stubs
|
||||||
|
|
||||||
[mypy-zope]
|
[mypy-zope]
|
||||||
|
@ -7,7 +7,15 @@
|
|||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
isort -y -rc synapse tests scripts-dev scripts
|
if [ $# -ge 1 ]
|
||||||
flake8 synapse tests
|
then
|
||||||
python3 -m black synapse tests scripts-dev scripts
|
files=$*
|
||||||
|
else
|
||||||
|
files="synapse tests scripts-dev scripts"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Linting these locations: $files"
|
||||||
|
isort -y -rc $files
|
||||||
|
flake8 $files
|
||||||
|
python3 -m black $files
|
||||||
./scripts-dev/config-lint.sh
|
./scripts-dev/config-lint.sh
|
||||||
|
@ -138,3 +138,10 @@ class LimitBlockingTypes(object):
|
|||||||
|
|
||||||
MONTHLY_ACTIVE_USER = "monthly_active_user"
|
MONTHLY_ACTIVE_USER = "monthly_active_user"
|
||||||
HS_DISABLED = "hs_disabled"
|
HS_DISABLED = "hs_disabled"
|
||||||
|
|
||||||
|
|
||||||
|
class EventContentFields(object):
|
||||||
|
"""Fields found in events' content, regardless of type."""
|
||||||
|
|
||||||
|
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||||
|
LABELS = "org.matrix.labels"
|
||||||
|
@ -20,6 +20,7 @@ from jsonschema import FormatChecker
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import EventContentFields
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.types import RoomID, UserID
|
from synapse.types import RoomID, UserID
|
||||||
@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = {
|
|||||||
"contains_url": {"type": "boolean"},
|
"contains_url": {"type": "boolean"},
|
||||||
"lazy_load_members": {"type": "boolean"},
|
"lazy_load_members": {"type": "boolean"},
|
||||||
"include_redundant_members": {"type": "boolean"},
|
"include_redundant_members": {"type": "boolean"},
|
||||||
|
# Include or exclude events with the provided labels.
|
||||||
|
# cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||||
|
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
|
||||||
|
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,6 +264,9 @@ class Filter(object):
|
|||||||
|
|
||||||
self.contains_url = self.filter_json.get("contains_url", None)
|
self.contains_url = self.filter_json.get("contains_url", None)
|
||||||
|
|
||||||
|
self.labels = self.filter_json.get("org.matrix.labels", None)
|
||||||
|
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
|
||||||
|
|
||||||
def filters_all_types(self):
|
def filters_all_types(self):
|
||||||
return "*" in self.not_types
|
return "*" in self.not_types
|
||||||
|
|
||||||
@ -282,6 +290,7 @@ class Filter(object):
|
|||||||
room_id = None
|
room_id = None
|
||||||
ev_type = "m.presence"
|
ev_type = "m.presence"
|
||||||
contains_url = False
|
contains_url = False
|
||||||
|
labels = []
|
||||||
else:
|
else:
|
||||||
sender = event.get("sender", None)
|
sender = event.get("sender", None)
|
||||||
if not sender:
|
if not sender:
|
||||||
@ -300,10 +309,11 @@ class Filter(object):
|
|||||||
content = event.get("content", {})
|
content = event.get("content", {})
|
||||||
# check if there is a string url field in the content for filtering purposes
|
# check if there is a string url field in the content for filtering purposes
|
||||||
contains_url = isinstance(content.get("url"), text_type)
|
contains_url = isinstance(content.get("url"), text_type)
|
||||||
|
labels = content.get(EventContentFields.LABELS, [])
|
||||||
|
|
||||||
return self.check_fields(room_id, sender, ev_type, contains_url)
|
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
|
||||||
|
|
||||||
def check_fields(self, room_id, sender, event_type, contains_url):
|
def check_fields(self, room_id, sender, event_type, labels, contains_url):
|
||||||
"""Checks whether the filter matches the given event fields.
|
"""Checks whether the filter matches the given event fields.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -313,6 +323,7 @@ class Filter(object):
|
|||||||
"rooms": lambda v: room_id == v,
|
"rooms": lambda v: room_id == v,
|
||||||
"senders": lambda v: sender == v,
|
"senders": lambda v: sender == v,
|
||||||
"types": lambda v: _matches_wildcard(event_type, v),
|
"types": lambda v: _matches_wildcard(event_type, v),
|
||||||
|
"labels": lambda v: v in labels,
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, match_func in literal_keys.items():
|
for name, match_func in literal_keys.items():
|
||||||
|
@ -565,7 +565,7 @@ def run(hs):
|
|||||||
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
|
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
yield hs.get_simple_http_client().put_json(
|
yield hs.get_proxied_http_client().put_json(
|
||||||
hs.config.report_stats_endpoint, stats
|
hs.config.report_stats_endpoint, stats
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -94,7 +94,9 @@ class ApplicationService(object):
|
|||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
):
|
):
|
||||||
self.token = token
|
self.token = token
|
||||||
self.url = url
|
self.url = (
|
||||||
|
url.rstrip("/") if isinstance(url, str) else None
|
||||||
|
) # url must not end with a slash
|
||||||
self.hs_token = hs_token
|
self.hs_token = hs_token
|
||||||
self.sender = sender
|
self.sender = sender
|
||||||
self.server_name = hostname
|
self.server_name = hostname
|
||||||
|
@ -234,8 +234,8 @@ def setup_logging(
|
|||||||
|
|
||||||
# make sure that the first thing we log is a thing we can grep backwards
|
# make sure that the first thing we log is a thing we can grep backwards
|
||||||
# for
|
# for
|
||||||
logging.warn("***** STARTING SERVER *****")
|
logging.warning("***** STARTING SERVER *****")
|
||||||
logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
|
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
|
||||||
logging.info("Server hostname: %s", config.server_name)
|
logging.info("Server hostname: %s", config.server_name)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
@ -37,9 +37,6 @@ class EventContext:
|
|||||||
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
||||||
(type, state_key) -> event_id. ``None`` for an outlier.
|
(type, state_key) -> event_id. ``None`` for an outlier.
|
||||||
|
|
||||||
prev_state_events (?): XXX: is this ever set to anything other than
|
|
||||||
the empty list?
|
|
||||||
|
|
||||||
app_service: FIXME
|
app_service: FIXME
|
||||||
|
|
||||||
_current_state_ids (dict[(str, str), str]|None):
|
_current_state_ids (dict[(str, str), str]|None):
|
||||||
@ -51,36 +48,16 @@ class EventContext:
|
|||||||
The current state map excluding the current event. None if outlier
|
The current state map excluding the current event. None if outlier
|
||||||
or we haven't fetched the state from DB yet.
|
or we haven't fetched the state from DB yet.
|
||||||
(type, state_key) -> event_id
|
(type, state_key) -> event_id
|
||||||
|
|
||||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
|
||||||
been calculated. None if we haven't started calculating yet
|
|
||||||
|
|
||||||
_event_type (str): The type of the event the context is associated with.
|
|
||||||
Only set when state has not been fetched yet.
|
|
||||||
|
|
||||||
_event_state_key (str|None): The state_key of the event the context is
|
|
||||||
associated with. Only set when state has not been fetched yet.
|
|
||||||
|
|
||||||
_prev_state_id (str|None): If the event associated with the context is
|
|
||||||
a state event, then `_prev_state_id` is the event_id of the state
|
|
||||||
that was replaced.
|
|
||||||
Only set when state has not been fetched yet.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
state_group = attr.ib(default=None)
|
state_group = attr.ib(default=None)
|
||||||
rejected = attr.ib(default=False)
|
rejected = attr.ib(default=False)
|
||||||
prev_group = attr.ib(default=None)
|
prev_group = attr.ib(default=None)
|
||||||
delta_ids = attr.ib(default=None)
|
delta_ids = attr.ib(default=None)
|
||||||
prev_state_events = attr.ib(default=attr.Factory(list))
|
|
||||||
app_service = attr.ib(default=None)
|
app_service = attr.ib(default=None)
|
||||||
|
|
||||||
_current_state_ids = attr.ib(default=None)
|
|
||||||
_prev_state_ids = attr.ib(default=None)
|
_prev_state_ids = attr.ib(default=None)
|
||||||
_prev_state_id = attr.ib(default=None)
|
_current_state_ids = attr.ib(default=None)
|
||||||
|
|
||||||
_event_type = attr.ib(default=None)
|
|
||||||
_event_state_key = attr.ib(default=None)
|
|
||||||
_fetching_state_deferred = attr.ib(default=None)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def with_state(
|
def with_state(
|
||||||
@ -90,7 +67,6 @@ class EventContext:
|
|||||||
current_state_ids=current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
prev_state_ids=prev_state_ids,
|
prev_state_ids=prev_state_ids,
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
fetching_state_deferred=defer.succeed(None),
|
|
||||||
prev_group=prev_group,
|
prev_group=prev_group,
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
)
|
)
|
||||||
@ -125,7 +101,6 @@ class EventContext:
|
|||||||
"rejected": self.rejected,
|
"rejected": self.rejected,
|
||||||
"prev_group": self.prev_group,
|
"prev_group": self.prev_group,
|
||||||
"delta_ids": _encode_state_dict(self.delta_ids),
|
"delta_ids": _encode_state_dict(self.delta_ids),
|
||||||
"prev_state_events": self.prev_state_events,
|
|
||||||
"app_service_id": self.app_service.id if self.app_service else None,
|
"app_service_id": self.app_service.id if self.app_service else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,7 +116,7 @@ class EventContext:
|
|||||||
Returns:
|
Returns:
|
||||||
EventContext
|
EventContext
|
||||||
"""
|
"""
|
||||||
context = EventContext(
|
context = _AsyncEventContextImpl(
|
||||||
# We use the state_group and prev_state_id stuff to pull the
|
# We use the state_group and prev_state_id stuff to pull the
|
||||||
# current_state_ids out of the DB and construct prev_state_ids.
|
# current_state_ids out of the DB and construct prev_state_ids.
|
||||||
prev_state_id=input["prev_state_id"],
|
prev_state_id=input["prev_state_id"],
|
||||||
@ -151,7 +126,6 @@ class EventContext:
|
|||||||
prev_group=input["prev_group"],
|
prev_group=input["prev_group"],
|
||||||
delta_ids=_decode_state_dict(input["delta_ids"]),
|
delta_ids=_decode_state_dict(input["delta_ids"]),
|
||||||
rejected=input["rejected"],
|
rejected=input["rejected"],
|
||||||
prev_state_events=input["prev_state_events"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app_service_id = input["app_service_id"]
|
app_service_id = input["app_service_id"]
|
||||||
@ -170,14 +144,7 @@ class EventContext:
|
|||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
|
yield self._ensure_fetched(store)
|
||||||
if not self._fetching_state_deferred:
|
|
||||||
self._fetching_state_deferred = run_in_background(
|
|
||||||
self._fill_out_state, store
|
|
||||||
)
|
|
||||||
|
|
||||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
|
||||||
|
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -190,14 +157,7 @@ class EventContext:
|
|||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
|
yield self._ensure_fetched(store)
|
||||||
if not self._fetching_state_deferred:
|
|
||||||
self._fetching_state_deferred = run_in_background(
|
|
||||||
self._fill_out_state, store
|
|
||||||
)
|
|
||||||
|
|
||||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
|
||||||
|
|
||||||
return self._prev_state_ids
|
return self._prev_state_ids
|
||||||
|
|
||||||
def get_cached_current_state_ids(self):
|
def get_cached_current_state_ids(self):
|
||||||
@ -211,6 +171,44 @@ class EventContext:
|
|||||||
|
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
|
def _ensure_fetched(self, store):
|
||||||
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class _AsyncEventContextImpl(EventContext):
|
||||||
|
"""
|
||||||
|
An implementation of EventContext which fetches _current_state_ids and
|
||||||
|
_prev_state_ids from the database on demand.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||||
|
been calculated. None if we haven't started calculating yet
|
||||||
|
|
||||||
|
_event_type (str): The type of the event the context is associated with.
|
||||||
|
|
||||||
|
_event_state_key (str): The state_key of the event the context is
|
||||||
|
associated with.
|
||||||
|
|
||||||
|
_prev_state_id (str|None): If the event associated with the context is
|
||||||
|
a state event, then `_prev_state_id` is the event_id of the state
|
||||||
|
that was replaced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_prev_state_id = attr.ib(default=None)
|
||||||
|
_event_type = attr.ib(default=None)
|
||||||
|
_event_state_key = attr.ib(default=None)
|
||||||
|
_fetching_state_deferred = attr.ib(default=None)
|
||||||
|
|
||||||
|
def _ensure_fetched(self, store):
|
||||||
|
if not self._fetching_state_deferred:
|
||||||
|
self._fetching_state_deferred = run_in_background(
|
||||||
|
self._fill_out_state, store
|
||||||
|
)
|
||||||
|
|
||||||
|
return make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _fill_out_state(self, store):
|
def _fill_out_state(self, store):
|
||||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||||
@ -228,27 +226,6 @@ class EventContext:
|
|||||||
else:
|
else:
|
||||||
self._prev_state_ids = self._current_state_ids
|
self._prev_state_ids = self._current_state_ids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def update_state(
|
|
||||||
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
|
|
||||||
):
|
|
||||||
"""Replace the state in the context
|
|
||||||
"""
|
|
||||||
|
|
||||||
# We need to make sure we wait for any ongoing fetching of state
|
|
||||||
# to complete so that the updated state doesn't get clobbered
|
|
||||||
if self._fetching_state_deferred:
|
|
||||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
|
||||||
|
|
||||||
self.state_group = state_group
|
|
||||||
self._prev_state_ids = prev_state_ids
|
|
||||||
self.prev_group = prev_group
|
|
||||||
self._current_state_ids = current_state_ids
|
|
||||||
self.delta_ids = delta_ids
|
|
||||||
|
|
||||||
# We need to ensure that that we've marked as having fetched the state
|
|
||||||
self._fetching_state_deferred = defer.succeed(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_state_dict(state_dict):
|
def _encode_state_dict(state_dict):
|
||||||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
||||||
|
@ -555,7 +555,7 @@ class FederationClient(FederationBase):
|
|||||||
Note that this does not append any events to any graphs.
|
Note that this does not append any events to any graphs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destinations (str): Candidate homeservers which are probably
|
destinations (Iterable[str]): Candidate homeservers which are probably
|
||||||
participating in the room.
|
participating in the room.
|
||||||
room_id (str): The room in which the event will happen.
|
room_id (str): The room in which the event will happen.
|
||||||
user_id (str): The user whose membership is being evented.
|
user_id (str): The user whose membership is being evented.
|
||||||
|
@ -192,15 +192,16 @@ class PerDestinationQueue(object):
|
|||||||
# We have to keep 2 free slots for presence and rr_edus
|
# We have to keep 2 free slots for presence and rr_edus
|
||||||
limit = MAX_EDUS_PER_TRANSACTION - 2
|
limit = MAX_EDUS_PER_TRANSACTION - 2
|
||||||
|
|
||||||
device_update_edus, dev_list_id = (
|
device_update_edus, dev_list_id = yield self._get_device_update_edus(
|
||||||
yield self._get_device_update_edus(limit)
|
limit
|
||||||
)
|
)
|
||||||
|
|
||||||
limit -= len(device_update_edus)
|
limit -= len(device_update_edus)
|
||||||
|
|
||||||
to_device_edus, device_stream_id = (
|
(
|
||||||
yield self._get_to_device_message_edus(limit)
|
to_device_edus,
|
||||||
)
|
device_stream_id,
|
||||||
|
) = yield self._get_to_device_message_edus(limit)
|
||||||
|
|
||||||
pending_edus = device_update_edus + to_device_edus
|
pending_edus = device_update_edus + to_device_edus
|
||||||
|
|
||||||
@ -359,20 +360,20 @@ class PerDestinationQueue(object):
|
|||||||
last_device_list = self._last_device_list_stream_id
|
last_device_list = self._last_device_list_stream_id
|
||||||
|
|
||||||
# Retrieve list of new device updates to send to the destination
|
# Retrieve list of new device updates to send to the destination
|
||||||
now_stream_id, results = yield self._store.get_devices_by_remote(
|
now_stream_id, results = yield self._store.get_device_updates_by_remote(
|
||||||
self._destination, last_device_list, limit=limit
|
self._destination, last_device_list, limit=limit
|
||||||
)
|
)
|
||||||
edus = [
|
edus = [
|
||||||
Edu(
|
Edu(
|
||||||
origin=self._server_name,
|
origin=self._server_name,
|
||||||
destination=self._destination,
|
destination=self._destination,
|
||||||
edu_type="m.device_list_update",
|
edu_type=edu_type,
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
for content in results
|
for (edu_type, content) in results
|
||||||
]
|
]
|
||||||
|
|
||||||
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
|
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
|
||||||
|
|
||||||
return (edus, now_stream_id)
|
return (edus, now_stream_id)
|
||||||
|
|
||||||
|
@ -38,9 +38,10 @@ class AccountDataEventSource(object):
|
|||||||
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
|
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
account_data, room_account_data = (
|
(
|
||||||
yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
account_data,
|
||||||
)
|
room_account_data,
|
||||||
|
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
||||||
|
|
||||||
for account_data_type, content in account_data.items():
|
for account_data_type, content in account_data.items():
|
||||||
results.append({"type": account_data_type, "content": content})
|
results.append({"type": account_data_type, "content": content})
|
||||||
|
@ -73,7 +73,10 @@ class ApplicationServicesHandler(object):
|
|||||||
try:
|
try:
|
||||||
limit = 100
|
limit = 100
|
||||||
while True:
|
while True:
|
||||||
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
(
|
||||||
|
upper_bound,
|
||||||
|
events,
|
||||||
|
) = yield self.store.get_new_events_for_appservice(
|
||||||
self.current_max, limit
|
self.current_max, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_federation_query_user_devices(self, user_id):
|
def on_federation_query_user_devices(self, user_id):
|
||||||
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
||||||
return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
|
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||||
|
self_signing_key = yield self.store.get_e2e_cross_signing_key(
|
||||||
|
user_id, "self_signing"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"devices": devices,
|
||||||
|
"master_key": master_key,
|
||||||
|
"self_signing_key": self_signing_key,
|
||||||
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_left_room(self, user, room_id):
|
def user_left_room(self, user, room_id):
|
||||||
|
@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
logging.warn("Error retrieving alias")
|
logging.warning("Error retrieving alias")
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
result = None
|
result = None
|
||||||
else:
|
else:
|
||||||
|
@ -36,6 +36,8 @@ from synapse.types import (
|
|||||||
get_verify_key_from_cross_signing_key,
|
get_verify_key_from_cross_signing_key,
|
||||||
)
|
)
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -49,10 +51,19 @@ class E2eKeysHandler(object):
|
|||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self._edu_updater = SigningKeyEduUpdater(hs, self)
|
||||||
|
|
||||||
|
federation_registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||||
|
federation_registry.register_edu_handler(
|
||||||
|
"org.matrix.signing_key_update",
|
||||||
|
self._edu_updater.incoming_signing_key_update,
|
||||||
|
)
|
||||||
# doesn't really work as part of the generic query API, because the
|
# doesn't really work as part of the generic query API, because the
|
||||||
# query request requires an object POST, but we abuse the
|
# query request requires an object POST, but we abuse the
|
||||||
# "query handler" interface.
|
# "query handler" interface.
|
||||||
hs.get_federation_registry().register_query_handler(
|
federation_registry.register_query_handler(
|
||||||
"client_keys", self.on_federation_query_client_keys
|
"client_keys", self.on_federation_query_client_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -119,9 +130,10 @@ class E2eKeysHandler(object):
|
|||||||
else:
|
else:
|
||||||
query_list.append((user_id, None))
|
query_list.append((user_id, None))
|
||||||
|
|
||||||
user_ids_not_in_cache, remote_results = (
|
(
|
||||||
yield self.store.get_user_devices_from_cache(query_list)
|
user_ids_not_in_cache,
|
||||||
)
|
remote_results,
|
||||||
|
) = yield self.store.get_user_devices_from_cache(query_list)
|
||||||
for user_id, devices in iteritems(remote_results):
|
for user_id, devices in iteritems(remote_results):
|
||||||
user_devices = results.setdefault(user_id, {})
|
user_devices = results.setdefault(user_id, {})
|
||||||
for device_id, device in iteritems(devices):
|
for device_id, device in iteritems(devices):
|
||||||
@ -207,10 +219,12 @@ class E2eKeysHandler(object):
|
|||||||
if user_id in destination_query:
|
if user_id in destination_query:
|
||||||
results[user_id] = keys
|
results[user_id] = keys
|
||||||
|
|
||||||
|
if "master_keys" in remote_result:
|
||||||
for user_id, key in remote_result["master_keys"].items():
|
for user_id, key in remote_result["master_keys"].items():
|
||||||
if user_id in destination_query:
|
if user_id in destination_query:
|
||||||
cross_signing_keys["master_keys"][user_id] = key
|
cross_signing_keys["master_keys"][user_id] = key
|
||||||
|
|
||||||
|
if "self_signing_keys" in remote_result:
|
||||||
for user_id, key in remote_result["self_signing_keys"].items():
|
for user_id, key in remote_result["self_signing_keys"].items():
|
||||||
if user_id in destination_query:
|
if user_id in destination_query:
|
||||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||||
@ -251,7 +265,7 @@ class E2eKeysHandler(object):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[dict[str, dict[str, dict]]]: map from
|
defer.Deferred[dict[str, dict[str, dict]]]: map from
|
||||||
(master|self_signing|user_signing) -> user_id -> key
|
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
|
||||||
"""
|
"""
|
||||||
master_keys = {}
|
master_keys = {}
|
||||||
self_signing_keys = {}
|
self_signing_keys = {}
|
||||||
@ -343,7 +357,16 @@ class E2eKeysHandler(object):
|
|||||||
"""
|
"""
|
||||||
device_keys_query = query_body.get("device_keys", {})
|
device_keys_query = query_body.get("device_keys", {})
|
||||||
res = yield self.query_local_devices(device_keys_query)
|
res = yield self.query_local_devices(device_keys_query)
|
||||||
return {"device_keys": res}
|
ret = {"device_keys": res}
|
||||||
|
|
||||||
|
# add in the cross-signing keys
|
||||||
|
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
|
||||||
|
device_keys_query, None
|
||||||
|
)
|
||||||
|
|
||||||
|
ret.update(cross_signing_keys)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -688,17 +711,21 @@ class E2eKeysHandler(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# get our self-signing key to verify the signatures
|
# get our self-signing key to verify the signatures
|
||||||
_, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
|
(
|
||||||
user_id, "self_signing"
|
_,
|
||||||
)
|
self_signing_key_id,
|
||||||
|
self_signing_verify_key,
|
||||||
|
) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
|
||||||
|
|
||||||
# get our master key, since we may have received a signature of it.
|
# get our master key, since we may have received a signature of it.
|
||||||
# We need to fetch it here so that we know what its key ID is, so
|
# We need to fetch it here so that we know what its key ID is, so
|
||||||
# that we can check if a signature that was sent is a signature of
|
# that we can check if a signature that was sent is a signature of
|
||||||
# the master key or of a device
|
# the master key or of a device
|
||||||
master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key(
|
(
|
||||||
user_id, "master"
|
master_key,
|
||||||
)
|
_,
|
||||||
|
master_verify_key,
|
||||||
|
) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
|
||||||
|
|
||||||
# fetch our stored devices. This is used to 1. verify
|
# fetch our stored devices. This is used to 1. verify
|
||||||
# signatures on the master key, and 2. to compare with what
|
# signatures on the master key, and 2. to compare with what
|
||||||
@ -838,9 +865,11 @@ class E2eKeysHandler(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# get our user-signing key to verify the signatures
|
# get our user-signing key to verify the signatures
|
||||||
user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key(
|
(
|
||||||
user_id, "user_signing"
|
user_signing_key,
|
||||||
)
|
user_signing_key_id,
|
||||||
|
user_signing_verify_key,
|
||||||
|
) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
failure = _exception_to_failure(e)
|
failure = _exception_to_failure(e)
|
||||||
for user, devicemap in signatures.items():
|
for user, devicemap in signatures.items():
|
||||||
@ -859,7 +888,11 @@ class E2eKeysHandler(object):
|
|||||||
try:
|
try:
|
||||||
# get the target user's master key, to make sure it matches
|
# get the target user's master key, to make sure it matches
|
||||||
# what was sent
|
# what was sent
|
||||||
master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key(
|
(
|
||||||
|
master_key,
|
||||||
|
master_key_id,
|
||||||
|
_,
|
||||||
|
) = yield self._get_e2e_cross_signing_verify_key(
|
||||||
target_user, "master", user_id
|
target_user, "master", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1047,3 +1080,100 @@ class SignatureListItem:
|
|||||||
target_user_id = attr.ib()
|
target_user_id = attr.ib()
|
||||||
target_device_id = attr.ib()
|
target_device_id = attr.ib()
|
||||||
signature = attr.ib()
|
signature = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
|
class SigningKeyEduUpdater(object):
|
||||||
|
"""Handles incoming signing key updates from federation and updates the DB"""
|
||||||
|
|
||||||
|
def __init__(self, hs, e2e_keys_handler):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.federation = hs.get_federation_client()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.e2e_keys_handler = e2e_keys_handler
|
||||||
|
|
||||||
|
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||||
|
|
||||||
|
# user_id -> list of updates waiting to be handled.
|
||||||
|
self._pending_updates = {}
|
||||||
|
|
||||||
|
# Recently seen stream ids. We don't bother keeping these in the DB,
|
||||||
|
# but they're useful to have them about to reduce the number of spurious
|
||||||
|
# resyncs.
|
||||||
|
self._seen_updates = ExpiringCache(
|
||||||
|
cache_name="signing_key_update_edu",
|
||||||
|
clock=self.clock,
|
||||||
|
max_len=10000,
|
||||||
|
expiry_ms=30 * 60 * 1000,
|
||||||
|
iterable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def incoming_signing_key_update(self, origin, edu_content):
|
||||||
|
"""Called on incoming signing key update from federation. Responsible for
|
||||||
|
parsing the EDU and adding to pending updates list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin (string): the server that sent the EDU
|
||||||
|
edu_content (dict): the contents of the EDU
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = edu_content.pop("user_id")
|
||||||
|
master_key = edu_content.pop("master_key", None)
|
||||||
|
self_signing_key = edu_content.pop("self_signing_key", None)
|
||||||
|
|
||||||
|
if get_domain_from_id(user_id) != origin:
|
||||||
|
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
|
||||||
|
return
|
||||||
|
|
||||||
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
|
if not room_ids:
|
||||||
|
# We don't share any rooms with this user. Ignore update, as we
|
||||||
|
# probably won't get any further updates.
|
||||||
|
return
|
||||||
|
|
||||||
|
self._pending_updates.setdefault(user_id, []).append(
|
||||||
|
(master_key, self_signing_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self._handle_signing_key_updates(user_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_signing_key_updates(self, user_id):
|
||||||
|
"""Actually handle pending updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (string): the user whose updates we are processing
|
||||||
|
"""
|
||||||
|
|
||||||
|
device_handler = self.e2e_keys_handler.device_handler
|
||||||
|
|
||||||
|
with (yield self._remote_edu_linearizer.queue(user_id)):
|
||||||
|
pending_updates = self._pending_updates.pop(user_id, [])
|
||||||
|
if not pending_updates:
|
||||||
|
# This can happen since we batch updates
|
||||||
|
return
|
||||||
|
|
||||||
|
device_ids = []
|
||||||
|
|
||||||
|
logger.info("pending updates: %r", pending_updates)
|
||||||
|
|
||||||
|
for master_key, self_signing_key in pending_updates:
|
||||||
|
if master_key:
|
||||||
|
yield self.store.set_e2e_cross_signing_key(
|
||||||
|
user_id, "master", master_key
|
||||||
|
)
|
||||||
|
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
|
||||||
|
# verify_key is a VerifyKey from signedjson, which uses
|
||||||
|
# .version to denote the portion of the key ID after the
|
||||||
|
# algorithm and colon, which is the device ID
|
||||||
|
device_ids.append(verify_key.version)
|
||||||
|
if self_signing_key:
|
||||||
|
yield self.store.set_e2e_cross_signing_key(
|
||||||
|
user_id, "self_signing", self_signing_key
|
||||||
|
)
|
||||||
|
_, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
|
self_signing_key
|
||||||
|
)
|
||||||
|
device_ids.append(verify_key.version)
|
||||||
|
|
||||||
|
yield device_handler.notify_device_update(user_id, device_ids)
|
||||||
|
@ -45,6 +45,7 @@ from synapse.api.errors import (
|
|||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
from synapse.event_auth import auth_types_for_event
|
from synapse.event_auth import auth_types_for_event
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
@ -352,11 +353,12 @@ class FederationHandler(BaseHandler):
|
|||||||
# note that if any of the missing prevs share missing state or
|
# note that if any of the missing prevs share missing state or
|
||||||
# auth events, the requests to fetch those events are deduped
|
# auth events, the requests to fetch those events are deduped
|
||||||
# by the get_pdu_cache in federation_client.
|
# by the get_pdu_cache in federation_client.
|
||||||
remote_state, got_auth_chain = (
|
(
|
||||||
yield self.federation_client.get_state_for_room(
|
remote_state,
|
||||||
|
got_auth_chain,
|
||||||
|
) = yield self.federation_client.get_state_for_room(
|
||||||
origin, room_id, p
|
origin, room_id, p
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# we want the state *after* p; get_state_for_room returns the
|
# we want the state *after* p; get_state_for_room returns the
|
||||||
# state *before* p.
|
# state *before* p.
|
||||||
@ -1105,7 +1107,7 @@ class FederationHandler(BaseHandler):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_invite_join(self, target_hosts, room_id, joinee, content):
|
def do_invite_join(self, target_hosts, room_id, joinee, content):
|
||||||
""" Attempts to join the `joinee` to the room `room_id` via the
|
""" Attempts to join the `joinee` to the room `room_id` via the
|
||||||
server `target_host`.
|
servers contained in `target_hosts`.
|
||||||
|
|
||||||
This first triggers a /make_join/ request that returns a partial
|
This first triggers a /make_join/ request that returns a partial
|
||||||
event that we can fill out and sign. This is then sent to the
|
event that we can fill out and sign. This is then sent to the
|
||||||
@ -1114,6 +1116,15 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
We suspend processing of any received events from this room until we
|
We suspend processing of any received events from this room until we
|
||||||
have finished processing the join.
|
have finished processing the join.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_hosts (Iterable[str]): List of servers to attempt to join the room with.
|
||||||
|
|
||||||
|
room_id (str): The ID of the room to join.
|
||||||
|
|
||||||
|
joinee (str): The User ID of the joining user.
|
||||||
|
|
||||||
|
content (dict): The event content to use for the join event.
|
||||||
"""
|
"""
|
||||||
logger.debug("Joining %s to %s", joinee, room_id)
|
logger.debug("Joining %s to %s", joinee, room_id)
|
||||||
|
|
||||||
@ -1173,6 +1184,22 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
yield self._persist_auth_tree(origin, auth_chain, state, event)
|
yield self._persist_auth_tree(origin, auth_chain, state, event)
|
||||||
|
|
||||||
|
# Check whether this room is the result of an upgrade of a room we already know
|
||||||
|
# about. If so, migrate over user information
|
||||||
|
predecessor = yield self.store.get_room_predecessor(room_id)
|
||||||
|
if not predecessor:
|
||||||
|
return
|
||||||
|
old_room_id = predecessor["room_id"]
|
||||||
|
logger.debug(
|
||||||
|
"Found predecessor for %s during remote join: %s", room_id, old_room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||||
|
member_handler = self.hs.get_room_member_handler()
|
||||||
|
yield member_handler.transfer_room_state_on_room_upgrade(
|
||||||
|
old_room_id, room_id
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||||
finally:
|
finally:
|
||||||
room_queue = self.room_queues[room_id]
|
room_queue = self.room_queues[room_id]
|
||||||
@ -1845,14 +1872,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if c and c.type == EventTypes.Create:
|
if c and c.type == EventTypes.Create:
|
||||||
auth_events[(c.type, c.state_key)] = c
|
auth_events[(c.type, c.state_key)] = c
|
||||||
|
|
||||||
try:
|
context = yield self.do_auth(origin, event, context, auth_events=auth_events)
|
||||||
yield self.do_auth(origin, event, context, auth_events=auth_events)
|
|
||||||
except AuthError as e:
|
|
||||||
logger.warning(
|
|
||||||
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
|
|
||||||
)
|
|
||||||
|
|
||||||
context.rejected = RejectedReason.AUTH_ERROR
|
|
||||||
|
|
||||||
if not context.rejected:
|
if not context.rejected:
|
||||||
yield self._check_for_soft_fail(event, state, backfilled)
|
yield self._check_for_soft_fail(event, state, backfilled)
|
||||||
@ -2021,12 +2041,12 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
Also NB that this function adds entries to it.
|
Also NB that this function adds entries to it.
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[None]
|
defer.Deferred[EventContext]: updated context object
|
||||||
"""
|
"""
|
||||||
room_version = yield self.store.get_room_version(event.room_id)
|
room_version = yield self.store.get_room_version(event.room_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self._update_auth_events_and_context_for_auth(
|
context = yield self._update_auth_events_and_context_for_auth(
|
||||||
origin, event, context, auth_events
|
origin, event, context, auth_events
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -2044,7 +2064,9 @@ class FederationHandler(BaseHandler):
|
|||||||
event_auth.check(room_version, event, auth_events=auth_events)
|
event_auth.check(room_version, event, auth_events=auth_events)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warning("Failed auth resolution for %r because %s", event, e)
|
logger.warning("Failed auth resolution for %r because %s", event, e)
|
||||||
raise e
|
context.rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_auth_events_and_context_for_auth(
|
def _update_auth_events_and_context_for_auth(
|
||||||
@ -2068,7 +2090,7 @@ class FederationHandler(BaseHandler):
|
|||||||
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[None]
|
defer.Deferred[EventContext]: updated context
|
||||||
"""
|
"""
|
||||||
event_auth_events = set(event.auth_event_ids())
|
event_auth_events = set(event.auth_event_ids())
|
||||||
|
|
||||||
@ -2107,7 +2129,7 @@ class FederationHandler(BaseHandler):
|
|||||||
# The other side isn't around or doesn't implement the
|
# The other side isn't around or doesn't implement the
|
||||||
# endpoint, so lets just bail out.
|
# endpoint, so lets just bail out.
|
||||||
logger.info("Failed to get event auth from remote: %s", e)
|
logger.info("Failed to get event auth from remote: %s", e)
|
||||||
return
|
return context
|
||||||
|
|
||||||
seen_remotes = yield self.store.have_seen_events(
|
seen_remotes = yield self.store.have_seen_events(
|
||||||
[e.event_id for e in remote_auth_chain]
|
[e.event_id for e in remote_auth_chain]
|
||||||
@ -2148,7 +2170,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
if event.internal_metadata.is_outlier():
|
if event.internal_metadata.is_outlier():
|
||||||
logger.info("Skipping auth_event fetch for outlier")
|
logger.info("Skipping auth_event fetch for outlier")
|
||||||
return
|
return context
|
||||||
|
|
||||||
# FIXME: Assumes we have and stored all the state for all the
|
# FIXME: Assumes we have and stored all the state for all the
|
||||||
# prev_events
|
# prev_events
|
||||||
@ -2157,7 +2179,7 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not different_auth:
|
if not different_auth:
|
||||||
return
|
return context
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"auth_events refers to events which are not in our calculated auth "
|
"auth_events refers to events which are not in our calculated auth "
|
||||||
@ -2204,10 +2226,12 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
auth_events.update(new_state)
|
auth_events.update(new_state)
|
||||||
|
|
||||||
yield self._update_context_for_auth_events(
|
context = yield self._update_context_for_auth_events(
|
||||||
event, context, auth_events, event_key
|
event, context, auth_events, event_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
|
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
|
||||||
"""Update the state_ids in an event context after auth event resolution,
|
"""Update the state_ids in an event context after auth event resolution,
|
||||||
@ -2216,14 +2240,16 @@ class FederationHandler(BaseHandler):
|
|||||||
Args:
|
Args:
|
||||||
event (Event): The event we're handling the context for
|
event (Event): The event we're handling the context for
|
||||||
|
|
||||||
context (synapse.events.snapshot.EventContext): event context
|
context (synapse.events.snapshot.EventContext): initial event context
|
||||||
to be updated
|
|
||||||
|
|
||||||
auth_events (dict[(str, str)->str]): Events to update in the event
|
auth_events (dict[(str, str)->str]): Events to update in the event
|
||||||
context.
|
context.
|
||||||
|
|
||||||
event_key ((str, str)): (type, state_key) for the current event.
|
event_key ((str, str)): (type, state_key) for the current event.
|
||||||
this will not be included in the current_state in the context.
|
this will not be included in the current_state in the context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[EventContext]: new event context
|
||||||
"""
|
"""
|
||||||
state_updates = {
|
state_updates = {
|
||||||
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
||||||
@ -2248,7 +2274,7 @@ class FederationHandler(BaseHandler):
|
|||||||
current_state_ids=current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield context.update_state(
|
return EventContext.with_state(
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
current_state_ids=current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
prev_state_ids=prev_state_ids,
|
prev_state_ids=prev_state_ids,
|
||||||
@ -2441,6 +2467,8 @@ class FederationHandler(BaseHandler):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
yield self._check_signature(event, context)
|
yield self._check_signature(event, context)
|
||||||
|
|
||||||
|
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||||
member_handler = self.hs.get_room_member_handler()
|
member_handler = self.hs.get_room_member_handler()
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
@ -2501,6 +2529,7 @@ class FederationHandler(BaseHandler):
|
|||||||
# though the sender isn't a local user.
|
# though the sender isn't a local user.
|
||||||
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
|
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
|
||||||
|
|
||||||
|
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||||
member_handler = self.hs.get_room_member_handler()
|
member_handler = self.hs.get_room_member_handler()
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
|
|
||||||
|
@ -128,8 +128,8 @@ class InitialSyncHandler(BaseHandler):
|
|||||||
|
|
||||||
tags_by_room = yield self.store.get_tags_for_user(user_id)
|
tags_by_room = yield self.store.get_tags_for_user(user_id)
|
||||||
|
|
||||||
account_data, account_data_by_room = (
|
account_data, account_data_by_room = yield self.store.get_account_data_for_user(
|
||||||
yield self.store.get_account_data_for_user(user_id)
|
user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
public_room_ids = yield self.store.get_public_room_ids()
|
public_room_ids = yield self.store.get_public_room_ids()
|
||||||
|
@ -76,9 +76,10 @@ class MessageHandler(object):
|
|||||||
Raises:
|
Raises:
|
||||||
SynapseError if something went wrong.
|
SynapseError if something went wrong.
|
||||||
"""
|
"""
|
||||||
membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
|
(
|
||||||
room_id, user_id
|
membership,
|
||||||
)
|
membership_event_id,
|
||||||
|
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
data = yield self.state.get_current_state(room_id, event_type, state_key)
|
data = yield self.state.get_current_state(room_id, event_type, state_key)
|
||||||
@ -153,9 +154,10 @@ class MessageHandler(object):
|
|||||||
% (user_id, room_id, at_token),
|
% (user_id, room_id, at_token),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
membership, membership_event_id = (
|
(
|
||||||
yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
membership,
|
||||||
)
|
membership_event_id,
|
||||||
|
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
state_ids = yield self.store.get_filtered_current_state_ids(
|
state_ids = yield self.store.get_filtered_current_state_ids(
|
||||||
|
@ -214,9 +214,10 @@ class PaginationHandler(object):
|
|||||||
source_config = pagin_config.get_source_config("room")
|
source_config = pagin_config.get_source_config("room")
|
||||||
|
|
||||||
with (yield self.pagination_lock.read(room_id)):
|
with (yield self.pagination_lock.read(room_id)):
|
||||||
membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
|
(
|
||||||
room_id, user_id
|
membership,
|
||||||
)
|
member_event_id,
|
||||||
|
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||||
|
|
||||||
if source_config.direction == "b":
|
if source_config.direction == "b":
|
||||||
# if we're going backwards, we might need to backfill. This
|
# if we're going backwards, we might need to backfill. This
|
||||||
@ -299,10 +300,8 @@ class PaginationHandler(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if state:
|
if state:
|
||||||
chunk["state"] = (
|
chunk["state"] = yield self._event_serializer.serialize_events(
|
||||||
yield self._event_serializer.serialize_events(
|
|
||||||
state, time_now, as_client_event=as_client_event
|
state, time_now, as_client_event=as_client_event
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return chunk
|
return chunk
|
||||||
|
@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
room_id = room_identifier
|
room_id = room_identifier
|
||||||
elif RoomAlias.is_valid(room_identifier):
|
elif RoomAlias.is_valid(room_identifier):
|
||||||
room_alias = RoomAlias.from_string(room_identifier)
|
room_alias = RoomAlias.from_string(room_identifier)
|
||||||
room_id, remote_room_hosts = (
|
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
|
||||||
yield room_member_handler.lookup_room_alias(room_alias)
|
room_alias
|
||||||
)
|
)
|
||||||
room_id = room_id.to_string()
|
room_id = room_id.to_string()
|
||||||
else:
|
else:
|
||||||
|
@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
old_room_id,
|
old_room_id,
|
||||||
new_version, # args for _upgrade_room
|
new_version, # args for _upgrade_room
|
||||||
)
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -147,8 +148,10 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
# we create and auth the tombstone event before properly creating the new
|
# we create and auth the tombstone event before properly creating the new
|
||||||
# room, to check our user has perms in the old room.
|
# room, to check our user has perms in the old room.
|
||||||
tombstone_event, tombstone_context = (
|
(
|
||||||
yield self.event_creation_handler.create_event(
|
tombstone_event,
|
||||||
|
tombstone_context,
|
||||||
|
) = yield self.event_creation_handler.create_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.Tombstone,
|
"type": EventTypes.Tombstone,
|
||||||
@ -162,7 +165,6 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
},
|
},
|
||||||
token_id=requester.access_token_id,
|
token_id=requester.access_token_id,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
old_room_version = yield self.store.get_room_version(old_room_id)
|
old_room_version = yield self.store.get_room_version(old_room_id)
|
||||||
yield self.auth.check_from_context(
|
yield self.auth.check_from_context(
|
||||||
old_room_version, tombstone_event, tombstone_context
|
old_room_version, tombstone_event, tombstone_context
|
||||||
@ -188,7 +190,12 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
requester, old_room_id, new_room_id, old_room_state
|
requester, old_room_id, new_room_id, old_room_state
|
||||||
)
|
)
|
||||||
|
|
||||||
# and finally, shut down the PLs in the old room, and update them in the new
|
# Copy over user push rules, tags and migrate room directory state
|
||||||
|
yield self.room_member_handler.transfer_room_state_on_room_upgrade(
|
||||||
|
old_room_id, new_room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# finally, shut down the PLs in the old room, and update them in the new
|
||||||
# room.
|
# room.
|
||||||
yield self._update_upgraded_room_pls(
|
yield self._update_upgraded_room_pls(
|
||||||
requester, old_room_id, new_room_id, old_room_state
|
requester, old_room_id, new_room_id, old_room_state
|
||||||
|
@ -203,10 +203,6 @@ class RoomMemberHandler(object):
|
|||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
if newly_joined:
|
if newly_joined:
|
||||||
# Copy over user state if we're joining an upgraded room
|
|
||||||
yield self.copy_user_state_if_room_upgrade(
|
|
||||||
room_id, requester.user.to_string()
|
|
||||||
)
|
|
||||||
yield self._user_joined_room(target, room_id)
|
yield self._user_joined_room(target, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
@ -455,11 +451,6 @@ class RoomMemberHandler(object):
|
|||||||
requester, remote_room_hosts, room_id, target, content
|
requester, remote_room_hosts, room_id, target, content
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copy over user state if this is a join on an remote upgraded room
|
|
||||||
yield self.copy_user_state_if_room_upgrade(
|
|
||||||
room_id, requester.user.to_string()
|
|
||||||
)
|
|
||||||
|
|
||||||
return remote_join_response
|
return remote_join_response
|
||||||
|
|
||||||
elif effective_membership_state == Membership.LEAVE:
|
elif effective_membership_state == Membership.LEAVE:
|
||||||
@ -498,36 +489,72 @@ class RoomMemberHandler(object):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
|
def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
|
||||||
"""Copy user-specific information when they join a new room if that new room is the
|
"""Upon our server becoming aware of an upgraded room, either by upgrading a room
|
||||||
result of a room upgrade
|
ourselves or joining one, we can transfer over information from the previous room.
|
||||||
|
|
||||||
|
Copies user state (tags/push rules) for every local user that was in the old room, as
|
||||||
|
well as migrating the room directory state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
new_room_id (str): The ID of the room the user is joining
|
old_room_id (str): The ID of the old room
|
||||||
user_id (str): The ID of the user
|
|
||||||
|
room_id (str): The ID of the new room
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
# Find all local users that were in the old room and copy over each user's state
|
||||||
|
users = yield self.store.get_users_in_room(old_room_id)
|
||||||
|
yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
|
||||||
|
|
||||||
|
# Add new room to the room directory if the old room was there
|
||||||
|
# Remove old room from the room directory
|
||||||
|
old_room = yield self.store.get_room(old_room_id)
|
||||||
|
if old_room and old_room["is_public"]:
|
||||||
|
yield self.store.set_room_is_public(old_room_id, False)
|
||||||
|
yield self.store.set_room_is_public(room_id, True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
|
||||||
|
"""Copy user-specific information when they join a new room when that new room is the
|
||||||
|
result of a room upgrade
|
||||||
|
|
||||||
|
Args:
|
||||||
|
old_room_id (str): The ID of upgraded room
|
||||||
|
new_room_id (str): The ID of the new room
|
||||||
|
user_ids (Iterable[str]): User IDs to copy state for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred
|
||||||
"""
|
"""
|
||||||
# Check if the new room is an upgraded room
|
|
||||||
predecessor = yield self.store.get_room_predecessor(new_room_id)
|
|
||||||
if not predecessor:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Found predecessor for %s: %s. Copying over room tags and push " "rules",
|
"Copying over room tags and push rules from %s to %s for users %s",
|
||||||
|
old_room_id,
|
||||||
new_room_id,
|
new_room_id,
|
||||||
predecessor,
|
user_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for user_id in user_ids:
|
||||||
|
try:
|
||||||
# It is an upgraded room. Copy over old tags
|
# It is an upgraded room. Copy over old tags
|
||||||
yield self.copy_room_tags_and_direct_to_room(
|
yield self.copy_room_tags_and_direct_to_room(
|
||||||
predecessor["room_id"], new_room_id, user_id
|
old_room_id, new_room_id, user_id
|
||||||
)
|
)
|
||||||
# Copy over push rules
|
# Copy over push rules
|
||||||
yield self.store.copy_push_rules_from_room_to_room_for_user(
|
yield self.store.copy_push_rules_from_room_to_room_for_user(
|
||||||
predecessor["room_id"], new_room_id, user_id
|
old_room_id, new_room_id, user_id
|
||||||
)
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Error copying tags and/or push rules from rooms %s to %s for user %s. "
|
||||||
|
"Skipping...",
|
||||||
|
old_room_id,
|
||||||
|
new_room_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_membership_event(self, requester, event, context, ratelimit=True):
|
def send_membership_event(self, requester, event, context, ratelimit=True):
|
||||||
@ -759,8 +786,12 @@ class RoomMemberHandler(object):
|
|||||||
if room_avatar_event:
|
if room_avatar_event:
|
||||||
room_avatar_url = room_avatar_event.content.get("url", "")
|
room_avatar_url = room_avatar_event.content.get("url", "")
|
||||||
|
|
||||||
token, public_keys, fallback_public_key, display_name = (
|
(
|
||||||
yield self.identity_handler.ask_id_server_for_third_party_invite(
|
token,
|
||||||
|
public_keys,
|
||||||
|
fallback_public_key,
|
||||||
|
display_name,
|
||||||
|
) = yield self.identity_handler.ask_id_server_for_third_party_invite(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
id_server=id_server,
|
id_server=id_server,
|
||||||
medium=medium,
|
medium=medium,
|
||||||
@ -775,7 +806,6 @@ class RoomMemberHandler(object):
|
|||||||
inviter_avatar_url=inviter_avatar_url,
|
inviter_avatar_url=inviter_avatar_url,
|
||||||
id_access_token=id_access_token,
|
id_access_token=id_access_token,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
|
@ -396,16 +396,12 @@ class SearchHandler(BaseHandler):
|
|||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
for context in contexts.values():
|
for context in contexts.values():
|
||||||
context["events_before"] = (
|
context["events_before"] = yield self._event_serializer.serialize_events(
|
||||||
yield self._event_serializer.serialize_events(
|
|
||||||
context["events_before"], time_now
|
context["events_before"], time_now
|
||||||
)
|
)
|
||||||
)
|
context["events_after"] = yield self._event_serializer.serialize_events(
|
||||||
context["events_after"] = (
|
|
||||||
yield self._event_serializer.serialize_events(
|
|
||||||
context["events_after"], time_now
|
context["events_after"], time_now
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
state_results = {}
|
state_results = {}
|
||||||
if include_state:
|
if include_state:
|
||||||
|
@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler):
|
|||||||
user_deltas = {}
|
user_deltas = {}
|
||||||
|
|
||||||
# Then count deltas for total_events and total_event_bytes.
|
# Then count deltas for total_events and total_event_bytes.
|
||||||
room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
|
(
|
||||||
|
room_count,
|
||||||
|
user_count,
|
||||||
|
) = yield self.store.get_changes_room_total_events_and_bytes(
|
||||||
self.pos, max_pos
|
self.pos, max_pos
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1206,11 +1206,12 @@ class SyncHandler(object):
|
|||||||
since_token = sync_result_builder.since_token
|
since_token = sync_result_builder.since_token
|
||||||
|
|
||||||
if since_token and not sync_result_builder.full_state:
|
if since_token and not sync_result_builder.full_state:
|
||||||
account_data, account_data_by_room = (
|
(
|
||||||
yield self.store.get_updated_account_data_for_user(
|
account_data,
|
||||||
|
account_data_by_room,
|
||||||
|
) = yield self.store.get_updated_account_data_for_user(
|
||||||
user_id, since_token.account_data_key
|
user_id, since_token.account_data_key
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
|
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
|
||||||
user_id, int(since_token.push_rules_key)
|
user_id, int(since_token.push_rules_key)
|
||||||
@ -1221,9 +1222,10 @@ class SyncHandler(object):
|
|||||||
sync_config.user
|
sync_config.user
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
account_data, account_data_by_room = (
|
(
|
||||||
yield self.store.get_account_data_for_user(sync_config.user.to_string())
|
account_data,
|
||||||
)
|
account_data_by_room,
|
||||||
|
) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
|
||||||
|
|
||||||
account_data["m.push_rules"] = yield self.push_rules_for_user(
|
account_data["m.push_rules"] = yield self.push_rules_for_user(
|
||||||
sync_config.user
|
sync_config.user
|
||||||
|
@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self._enabled = bool(hs.config.recaptcha_private_key)
|
self._enabled = bool(hs.config.recaptcha_private_key)
|
||||||
self._http_client = hs.get_simple_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
self._url = hs.config.recaptcha_siteverify_api
|
self._url = hs.config.recaptcha_siteverify_api
|
||||||
self._secret = hs.config.recaptcha_private_key
|
self._secret = hs.config.recaptcha_private_key
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ from synapse.http import (
|
|||||||
cancelled_to_request_timed_out_error,
|
cancelled_to_request_timed_out_error,
|
||||||
redact_uri,
|
redact_uri,
|
||||||
)
|
)
|
||||||
|
from synapse.http.proxyagent import ProxyAgent
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import timeout_deferred
|
||||||
@ -183,7 +184,15 @@ class SimpleHttpClient(object):
|
|||||||
using HTTP in Matrix
|
using HTTP in Matrix
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs,
|
||||||
|
treq_args={},
|
||||||
|
ip_whitelist=None,
|
||||||
|
ip_blacklist=None,
|
||||||
|
http_proxy=None,
|
||||||
|
https_proxy=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer)
|
hs (synapse.server.HomeServer)
|
||||||
@ -192,6 +201,8 @@ class SimpleHttpClient(object):
|
|||||||
we may not request.
|
we may not request.
|
||||||
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
|
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
|
||||||
request if it were otherwise caught in a blacklist.
|
request if it were otherwise caught in a blacklist.
|
||||||
|
http_proxy (bytes): proxy server to use for http connections. host[:port]
|
||||||
|
https_proxy (bytes): proxy server to use for https connections. host[:port]
|
||||||
"""
|
"""
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
@ -236,11 +247,13 @@ class SimpleHttpClient(object):
|
|||||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||||
# 'like a browser'
|
# 'like a browser'
|
||||||
self.agent = Agent(
|
self.agent = ProxyAgent(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
connectTimeout=15,
|
connectTimeout=15,
|
||||||
contextFactory=self.hs.get_http_client_context_factory(),
|
contextFactory=self.hs.get_http_client_context_factory(),
|
||||||
pool=pool,
|
pool=pool,
|
||||||
|
http_proxy=http_proxy,
|
||||||
|
https_proxy=https_proxy,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._ip_blacklist:
|
if self._ip_blacklist:
|
||||||
|
195
synapse/http/connectproxyclient.py
Normal file
195
synapse/http/connectproxyclient.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer, protocol
|
||||||
|
from twisted.internet.error import ConnectError
|
||||||
|
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||||
|
from twisted.internet.protocol import connectionDone
|
||||||
|
from twisted.web import http
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyConnectError(ConnectError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IStreamClientEndpoint)
|
||||||
|
class HTTPConnectProxyEndpoint(object):
|
||||||
|
"""An Endpoint implementation which will send a CONNECT request to an http proxy
|
||||||
|
|
||||||
|
Wraps an existing HostnameEndpoint for the proxy.
|
||||||
|
|
||||||
|
When we get the connect() request from the connection pool (via the TLS wrapper),
|
||||||
|
we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
|
||||||
|
CONNECT request. Once that completes, we invoke the protocolFactory which was passed
|
||||||
|
in.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reactor: the Twisted reactor to use for the connection
|
||||||
|
proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
|
||||||
|
proxy
|
||||||
|
host (bytes): hostname that we want to CONNECT to
|
||||||
|
port (int): port that we want to connect to
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, reactor, proxy_endpoint, host, port):
|
||||||
|
self._reactor = reactor
|
||||||
|
self._proxy_endpoint = proxy_endpoint
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
|
||||||
|
|
||||||
|
def connect(self, protocolFactory):
|
||||||
|
f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
|
||||||
|
d = self._proxy_endpoint.connect(f)
|
||||||
|
# once the tcp socket connects successfully, we need to wait for the
|
||||||
|
# CONNECT to complete.
|
||||||
|
d.addCallback(lambda conn: f.on_connection)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||||
|
"""ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
|
||||||
|
|
||||||
|
Once the CONNECT completes, invokes the original ClientFactory to build the
|
||||||
|
HTTP Protocol object and run the rest of the connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst_host (bytes): hostname that we want to CONNECT to
|
||||||
|
dst_port (int): port that we want to connect to
|
||||||
|
wrapped_factory (protocol.ClientFactory): The original Factory
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dst_host, dst_port, wrapped_factory):
|
||||||
|
self.dst_host = dst_host
|
||||||
|
self.dst_port = dst_port
|
||||||
|
self.wrapped_factory = wrapped_factory
|
||||||
|
self.on_connection = defer.Deferred()
|
||||||
|
|
||||||
|
def startedConnecting(self, connector):
|
||||||
|
return self.wrapped_factory.startedConnecting(connector)
|
||||||
|
|
||||||
|
def buildProtocol(self, addr):
|
||||||
|
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
|
||||||
|
|
||||||
|
return HTTPConnectProtocol(
|
||||||
|
self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
|
||||||
|
)
|
||||||
|
|
||||||
|
def clientConnectionFailed(self, connector, reason):
|
||||||
|
logger.debug("Connection to proxy failed: %s", reason)
|
||||||
|
if not self.on_connection.called:
|
||||||
|
self.on_connection.errback(reason)
|
||||||
|
return self.wrapped_factory.clientConnectionFailed(connector, reason)
|
||||||
|
|
||||||
|
def clientConnectionLost(self, connector, reason):
|
||||||
|
logger.debug("Connection to proxy lost: %s", reason)
|
||||||
|
if not self.on_connection.called:
|
||||||
|
self.on_connection.errback(reason)
|
||||||
|
return self.wrapped_factory.clientConnectionLost(connector, reason)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPConnectProtocol(protocol.Protocol):
|
||||||
|
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
|
||||||
|
to put in the CONNECT request
|
||||||
|
|
||||||
|
port (int): The original HTTP(s) port to put in the CONNECT request
|
||||||
|
|
||||||
|
wrapped_protocol (interfaces.IProtocol): the original protocol (probably
|
||||||
|
HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
|
||||||
|
|
||||||
|
connected_deferred (Deferred): a Deferred which will be callbacked with
|
||||||
|
wrapped_protocol when the CONNECT completes
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, host, port, wrapped_protocol, connected_deferred):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.wrapped_protocol = wrapped_protocol
|
||||||
|
self.connected_deferred = connected_deferred
|
||||||
|
self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
|
||||||
|
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
|
||||||
|
|
||||||
|
def connectionMade(self):
|
||||||
|
self.http_setup_client.makeConnection(self.transport)
|
||||||
|
|
||||||
|
def connectionLost(self, reason=connectionDone):
|
||||||
|
if self.wrapped_protocol.connected:
|
||||||
|
self.wrapped_protocol.connectionLost(reason)
|
||||||
|
|
||||||
|
self.http_setup_client.connectionLost(reason)
|
||||||
|
|
||||||
|
if not self.connected_deferred.called:
|
||||||
|
self.connected_deferred.errback(reason)
|
||||||
|
|
||||||
|
def proxyConnected(self, _):
|
||||||
|
self.wrapped_protocol.makeConnection(self.transport)
|
||||||
|
|
||||||
|
self.connected_deferred.callback(self.wrapped_protocol)
|
||||||
|
|
||||||
|
# Get any pending data from the http buf and forward it to the original protocol
|
||||||
|
buf = self.http_setup_client.clearLineBuffer()
|
||||||
|
if buf:
|
||||||
|
self.wrapped_protocol.dataReceived(buf)
|
||||||
|
|
||||||
|
def dataReceived(self, data):
|
||||||
|
# if we've set up the HTTP protocol, we can send the data there
|
||||||
|
if self.wrapped_protocol.connected:
|
||||||
|
return self.wrapped_protocol.dataReceived(data)
|
||||||
|
|
||||||
|
# otherwise, we must still be setting up the connection: send the data to the
|
||||||
|
# setup client
|
||||||
|
return self.http_setup_client.dataReceived(data)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPConnectSetupClient(http.HTTPClient):
|
||||||
|
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (bytes): The hostname to send in the CONNECT message
|
||||||
|
port (int): The port to send in the CONNECT message
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.on_connected = defer.Deferred()
|
||||||
|
|
||||||
|
def connectionMade(self):
|
||||||
|
logger.debug("Connected to proxy, sending CONNECT")
|
||||||
|
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
|
||||||
|
self.endHeaders()
|
||||||
|
|
||||||
|
def handleStatus(self, version, status, message):
|
||||||
|
logger.debug("Got Status: %s %s %s", status, message, version)
|
||||||
|
if status != b"200":
|
||||||
|
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
|
||||||
|
|
||||||
|
def handleEndHeaders(self):
|
||||||
|
logger.debug("End Headers")
|
||||||
|
self.on_connected.callback(None)
|
||||||
|
|
||||||
|
def handleResponse(self, body):
|
||||||
|
pass
|
195
synapse/http/proxyagent.py
Normal file
195
synapse/http/proxyagent.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
|
||||||
|
from twisted.web.error import SchemeNotSupported
|
||||||
|
from twisted.web.iweb import IAgent
|
||||||
|
|
||||||
|
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IAgent)
|
||||||
|
class ProxyAgent(_AgentBase):
|
||||||
|
"""An Agent implementation which will use an HTTP proxy if one was requested
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reactor: twisted reactor to place outgoing
|
||||||
|
connections.
|
||||||
|
|
||||||
|
contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
|
||||||
|
verification parameters of OpenSSL. The default is to use a
|
||||||
|
`BrowserLikePolicyForHTTPS`, so unless you have special
|
||||||
|
requirements you can leave this as-is.
|
||||||
|
|
||||||
|
connectTimeout (float): The amount of time that this Agent will wait
|
||||||
|
for the peer to accept a connection.
|
||||||
|
|
||||||
|
bindAddress (bytes): The local address for client sockets to bind to.
|
||||||
|
|
||||||
|
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
|
||||||
|
non-persistent pool instance will be created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reactor,
|
||||||
|
contextFactory=BrowserLikePolicyForHTTPS(),
|
||||||
|
connectTimeout=None,
|
||||||
|
bindAddress=None,
|
||||||
|
pool=None,
|
||||||
|
http_proxy=None,
|
||||||
|
https_proxy=None,
|
||||||
|
):
|
||||||
|
_AgentBase.__init__(self, reactor, pool)
|
||||||
|
|
||||||
|
self._endpoint_kwargs = {}
|
||||||
|
if connectTimeout is not None:
|
||||||
|
self._endpoint_kwargs["timeout"] = connectTimeout
|
||||||
|
if bindAddress is not None:
|
||||||
|
self._endpoint_kwargs["bindAddress"] = bindAddress
|
||||||
|
|
||||||
|
self.http_proxy_endpoint = _http_proxy_endpoint(
|
||||||
|
http_proxy, reactor, **self._endpoint_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.https_proxy_endpoint = _http_proxy_endpoint(
|
||||||
|
https_proxy, reactor, **self._endpoint_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self._policy_for_https = contextFactory
|
||||||
|
self._reactor = reactor
|
||||||
|
|
||||||
|
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||||
|
"""
|
||||||
|
Issue a request to the server indicated by the given uri.
|
||||||
|
|
||||||
|
Supports `http` and `https` schemes.
|
||||||
|
|
||||||
|
An existing connection from the connection pool may be used or a new one may be
|
||||||
|
created.
|
||||||
|
|
||||||
|
See also: twisted.web.iweb.IAgent.request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method (bytes): The request method to use, such as `GET`, `POST`, etc
|
||||||
|
|
||||||
|
uri (bytes): The location of the resource to request.
|
||||||
|
|
||||||
|
headers (Headers|None): Extra headers to send with the request
|
||||||
|
|
||||||
|
bodyProducer (IBodyProducer|None): An object which can generate bytes to
|
||||||
|
make up the body of this request (for example, the properly encoded
|
||||||
|
contents of a file for a file upload). Or, None if the request is to
|
||||||
|
have no body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[IResponse]: completes when the header of the response has
|
||||||
|
been received (regardless of the response status code).
|
||||||
|
"""
|
||||||
|
uri = uri.strip()
|
||||||
|
if not _VALID_URI.match(uri):
|
||||||
|
raise ValueError("Invalid URI {!r}".format(uri))
|
||||||
|
|
||||||
|
parsed_uri = URI.fromBytes(uri)
|
||||||
|
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
|
||||||
|
request_path = parsed_uri.originForm
|
||||||
|
|
||||||
|
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
|
||||||
|
# Cache *all* connections under the same key, since we are only
|
||||||
|
# connecting to a single destination, the proxy:
|
||||||
|
pool_key = ("http-proxy", self.http_proxy_endpoint)
|
||||||
|
endpoint = self.http_proxy_endpoint
|
||||||
|
request_path = uri
|
||||||
|
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
|
||||||
|
endpoint = HTTPConnectProxyEndpoint(
|
||||||
|
self._reactor,
|
||||||
|
self.https_proxy_endpoint,
|
||||||
|
parsed_uri.host,
|
||||||
|
parsed_uri.port,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# not using a proxy
|
||||||
|
endpoint = HostnameEndpoint(
|
||||||
|
self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Requesting %s via %s", uri, endpoint)
|
||||||
|
|
||||||
|
if parsed_uri.scheme == b"https":
|
||||||
|
tls_connection_creator = self._policy_for_https.creatorForNetloc(
|
||||||
|
parsed_uri.host, parsed_uri.port
|
||||||
|
)
|
||||||
|
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
|
||||||
|
elif parsed_uri.scheme == b"http":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
return defer.fail(
|
||||||
|
Failure(
|
||||||
|
SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._requestWithEndpoint(
|
||||||
|
pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _http_proxy_endpoint(proxy, reactor, **kwargs):
|
||||||
|
"""Parses an http proxy setting and returns an endpoint for the proxy
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy (bytes|None): the proxy setting
|
||||||
|
reactor: reactor to be used to connect to the proxy
|
||||||
|
kwargs: other args to be passed to HostnameEndpoint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
|
||||||
|
or None
|
||||||
|
"""
|
||||||
|
if proxy is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# currently we only support hostname:port. Some apps also support
|
||||||
|
# protocol://<host>[:port], which allows a way of requiring a TLS connection to the
|
||||||
|
# proxy.
|
||||||
|
|
||||||
|
host, port = parse_host_port(proxy, default_port=1080)
|
||||||
|
return HostnameEndpoint(reactor, host, port, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_host_port(hostport, default_port=None):
|
||||||
|
# could have sworn we had one of these somewhere else...
|
||||||
|
if b":" in hostport:
|
||||||
|
host, port = hostport.rsplit(b":", 1)
|
||||||
|
try:
|
||||||
|
port = int(port)
|
||||||
|
return host, port
|
||||||
|
except ValueError:
|
||||||
|
# the thing after the : wasn't a valid port; presumably this is an
|
||||||
|
# IPv6 address.
|
||||||
|
pass
|
||||||
|
|
||||||
|
return hostport, default_port
|
@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
|
|||||||
|
|
||||||
|
|
||||||
def parse_drain_configs(
|
def parse_drain_configs(
|
||||||
drains: dict
|
drains: dict,
|
||||||
) -> typing.Generator[DrainConfiguration, None, None]:
|
) -> typing.Generator[DrainConfiguration, None, None]:
|
||||||
"""
|
"""
|
||||||
Parse the drain configurations.
|
Parse the drain configurations.
|
||||||
|
@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object):
|
|||||||
|
|
||||||
room_members = yield self.store.get_joined_users_from_context(event, context)
|
room_members = yield self.store.get_joined_users_from_context(event, context)
|
||||||
|
|
||||||
(power_levels, sender_power_level) = (
|
(
|
||||||
yield self._get_power_levels_and_sender_level(event, context)
|
power_levels,
|
||||||
)
|
sender_power_level,
|
||||||
|
) = yield self._get_power_levels_and_sender_level(event, context)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(
|
evaluator = PushRuleEvaluatorForEvent(
|
||||||
event, len(room_members), sender_power_level, power_levels
|
event, len(room_members), sender_power_level, power_levels
|
||||||
|
@ -234,15 +234,13 @@ class EmailPusher(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.last_stream_ordering = last_stream_ordering
|
self.last_stream_ordering = last_stream_ordering
|
||||||
pusher_still_exists = (
|
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.email,
|
self.email,
|
||||||
self.user_id,
|
self.user_id,
|
||||||
last_stream_ordering,
|
last_stream_ordering,
|
||||||
self.clock.time_msec(),
|
self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if not pusher_still_exists:
|
if not pusher_still_exists:
|
||||||
# The pusher has been deleted while we were processing, so
|
# The pusher has been deleted while we were processing, so
|
||||||
# lets just stop and return.
|
# lets just stop and return.
|
||||||
|
@ -103,7 +103,7 @@ class HttpPusher(object):
|
|||||||
if "url" not in self.data:
|
if "url" not in self.data:
|
||||||
raise PusherConfigException("'url' required in data for HTTP pusher")
|
raise PusherConfigException("'url' required in data for HTTP pusher")
|
||||||
self.url = self.data["url"]
|
self.url = self.data["url"]
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_proxied_http_client()
|
||||||
self.data_minus_url = {}
|
self.data_minus_url = {}
|
||||||
self.data_minus_url.update(self.data)
|
self.data_minus_url.update(self.data)
|
||||||
del self.data_minus_url["url"]
|
del self.data_minus_url["url"]
|
||||||
@ -211,15 +211,13 @@ class HttpPusher(object):
|
|||||||
http_push_processed_counter.inc()
|
http_push_processed_counter.inc()
|
||||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||||
self.last_stream_ordering = push_action["stream_ordering"]
|
self.last_stream_ordering = push_action["stream_ordering"]
|
||||||
pusher_still_exists = (
|
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
self.user_id,
|
self.user_id,
|
||||||
self.last_stream_ordering,
|
self.last_stream_ordering,
|
||||||
self.clock.time_msec(),
|
self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if not pusher_still_exists:
|
if not pusher_still_exists:
|
||||||
# The pusher has been deleted while we were processing, so
|
# The pusher has been deleted while we were processing, so
|
||||||
# lets just stop and return.
|
# lets just stop and return.
|
||||||
|
@ -103,9 +103,7 @@ class PusherPool:
|
|||||||
# create the pusher setting last_stream_ordering to the current maximum
|
# create the pusher setting last_stream_ordering to the current maximum
|
||||||
# stream ordering in event_push_actions, so it will process
|
# stream ordering in event_push_actions, so it will process
|
||||||
# pushes from this point onwards.
|
# pushes from this point onwards.
|
||||||
last_stream_ordering = (
|
last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
|
||||||
yield self.store.get_latest_push_action_stream_ordering()
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.store.add_pusher(
|
yield self.store.add_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
|
|||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self) -> Dict[str, int]:
|
||||||
|
"""
|
||||||
|
Get the current positions of all the streams this store wants to subscribe to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
map from stream name to the most recent update we have for
|
||||||
|
that stream (ie, the point we want to start replicating from)
|
||||||
|
"""
|
||||||
pos = {}
|
pos = {}
|
||||||
if self._cache_id_gen:
|
if self._cache_id_gen:
|
||||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
|
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
||||||
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
|
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
|
||||||
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
|
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedDeviceStore, self).stream_positions()
|
result = super(SlavedDeviceStore, self).stream_positions()
|
||||||
result["device_lists"] = self._device_list_id_gen.get_current_token()
|
# The user signature stream uses the same stream ID generator as the
|
||||||
|
# device list stream, so set them both to the device list ID
|
||||||
|
# generator's current token.
|
||||||
|
current_token = self._device_list_id_gen.get_current_token()
|
||||||
|
result[DeviceListsStream.NAME] = current_token
|
||||||
|
result[UserSignatureStream.NAME] = current_token
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "device_lists":
|
if stream_name == DeviceListsStream.NAME:
|
||||||
self._device_list_id_gen.advance(token)
|
self._device_list_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
|
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
|
||||||
|
elif stream_name == UserSignatureStream.NAME:
|
||||||
|
for row in rows:
|
||||||
|
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
return super(SlavedDeviceStore, self).process_replication_rows(
|
return super(SlavedDeviceStore, self).process_replication_rows(
|
||||||
stream_name, token, rows
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
|
@ -16,10 +16,17 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.protocol import ReconnectingClientFactory
|
from twisted.internet.protocol import ReconnectingClientFactory
|
||||||
|
|
||||||
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
|
from synapse.replication.tcp.protocol import (
|
||||||
|
AbstractReplicationClientHandler,
|
||||||
|
ClientReplicationStreamProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
from .commands import (
|
from .commands import (
|
||||||
FederationAckCommand,
|
FederationAckCommand,
|
||||||
InvalidateCacheCommand,
|
InvalidateCacheCommand,
|
||||||
@ -27,7 +34,6 @@ from .commands import (
|
|||||||
UserIpCommand,
|
UserIpCommand,
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
)
|
)
|
||||||
from .protocol import ClientReplicationStreamProtocol
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
|||||||
|
|
||||||
maxDelay = 30 # Try at least once every N seconds
|
maxDelay = 30 # Try at least once every N seconds
|
||||||
|
|
||||||
def __init__(self, hs, client_name, handler):
|
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
|
||||||
self.client_name = client_name
|
self.client_name = client_name
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.server_name = hs.config.server_name
|
self.server_name = hs.config.server_name
|
||||||
@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
|||||||
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
||||||
|
|
||||||
|
|
||||||
class ReplicationClientHandler(object):
|
class ReplicationClientHandler(AbstractReplicationClientHandler):
|
||||||
"""A base handler that can be passed to the ReplicationClientFactory.
|
"""A base handler that can be passed to the ReplicationClientFactory.
|
||||||
|
|
||||||
By default proxies incoming replication data to the SlaveStore.
|
By default proxies incoming replication data to the SlaveStore.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, store):
|
def __init__(self, store: BaseSlavedStore):
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
# The current connection. None if we are currently (re)connecting
|
# The current connection. None if we are currently (re)connecting
|
||||||
@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
|
|||||||
if d:
|
if d:
|
||||||
d.callback(data)
|
d.callback(data)
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||||
"""Called when a new connection has been established and we need to
|
"""Called when a new connection has been established and we need to
|
||||||
subscribe to streams.
|
subscribe to streams.
|
||||||
|
|
||||||
Returns a dictionary of stream name to token.
|
Returns:
|
||||||
|
map from stream name to the most recent update we have for
|
||||||
|
that stream (ie, the point we want to start replicating from)
|
||||||
"""
|
"""
|
||||||
args = self.store.stream_positions()
|
args = self.store.stream_positions()
|
||||||
user_account_data = args.pop("user_account_data", None)
|
user_account_data = args.pop("user_account_data", None)
|
||||||
|
@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
|
|||||||
> ERROR server stopping
|
> ERROR server stopping
|
||||||
* connection closed by server *
|
* connection closed by server *
|
||||||
"""
|
"""
|
||||||
|
import abc
|
||||||
import fcntl
|
import fcntl
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
@ -65,6 +65,7 @@ from twisted.python.failure import Failure
|
|||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from .commands import (
|
from .commands import (
|
||||||
@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
self.streamer.lost_connection(self)
|
self.streamer.lost_connection(self)
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
The interface for the handler that should be passed to
|
||||||
|
ClientReplicationStreamProtocol
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def on_rdata(self, stream_name, token, rows):
|
||||||
|
"""Called to handle a batch of replication data with a given stream token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_name (str): name of the replication stream for this batch of rows
|
||||||
|
token (int): stream token for this batch of rows
|
||||||
|
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||||
|
Stream.parse_row.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred|None
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def on_position(self, stream_name, token):
|
||||||
|
"""Called when we get new position data."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def on_sync(self, data):
|
||||||
|
"""Called when get a new SYNC command."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_streams_to_replicate(self):
|
||||||
|
"""Called when a new connection has been established and we need to
|
||||||
|
subscribe to streams.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
map from stream name to the most recent update we have for
|
||||||
|
that stream (ie, the point we want to start replicating from)
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_currently_syncing_users(self):
|
||||||
|
"""Get the list of currently syncing users (if any). This is called
|
||||||
|
when a connection has been established and we need to send the
|
||||||
|
currently syncing users."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def update_connection(self, connection):
|
||||||
|
"""Called when a connection has been established (or lost with None).
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def finished_connecting(self):
|
||||||
|
"""Called when we have successfully subscribed and caught up to all
|
||||||
|
streams we're interested in.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||||
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||||
|
|
||||||
def __init__(self, client_name, server_name, clock, handler):
|
def __init__(
|
||||||
|
self,
|
||||||
|
client_name: str,
|
||||||
|
server_name: str,
|
||||||
|
clock: Clock,
|
||||||
|
handler: AbstractReplicationClientHandler,
|
||||||
|
):
|
||||||
BaseReplicationStreamProtocol.__init__(self, clock)
|
BaseReplicationStreamProtocol.__init__(self, clock)
|
||||||
|
|
||||||
self.client_name = client_name
|
self.client_name = client_name
|
||||||
|
@ -45,5 +45,6 @@ STREAMS_MAP = {
|
|||||||
_base.TagAccountDataStream,
|
_base.TagAccountDataStream,
|
||||||
_base.AccountDataStream,
|
_base.AccountDataStream,
|
||||||
_base.GroupServerStream,
|
_base.GroupServerStream,
|
||||||
|
_base.UserSignatureStream,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
|
|||||||
"GroupsStreamRow",
|
"GroupsStreamRow",
|
||||||
("group_id", "user_id", "type", "content"), # str # str # str # dict
|
("group_id", "user_id", "type", "content"), # str # str # str # dict
|
||||||
)
|
)
|
||||||
|
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
|
||||||
|
|
||||||
|
|
||||||
class Stream(object):
|
class Stream(object):
|
||||||
@ -438,3 +439,20 @@ class GroupServerStream(Stream):
|
|||||||
self.update_function = store.get_all_groups_changes
|
self.update_function = store.get_all_groups_changes
|
||||||
|
|
||||||
super(GroupServerStream, self).__init__(hs)
|
super(GroupServerStream, self).__init__(hs)
|
||||||
|
|
||||||
|
|
||||||
|
class UserSignatureStream(Stream):
|
||||||
|
"""A user has signed their own device with their user-signing key
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "user_signature"
|
||||||
|
_LIMITED = False
|
||||||
|
ROW_TYPE = UserSignatureStreamRow
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.current_token = store.get_device_stream_token
|
||||||
|
self.update_function = store.get_all_user_signature_changes_for_remotes
|
||||||
|
|
||||||
|
super(UserSignatureStream, self).__init__(hs)
|
||||||
|
@ -203,11 +203,12 @@ class LoginRestServlet(RestServlet):
|
|||||||
address = address.lower()
|
address = address.lower()
|
||||||
|
|
||||||
# Check for login providers that support 3pid login types
|
# Check for login providers that support 3pid login types
|
||||||
canonical_user_id, callback_3pid = (
|
(
|
||||||
yield self.auth_handler.check_password_provider_3pid(
|
canonical_user_id,
|
||||||
|
callback_3pid,
|
||||||
|
) = yield self.auth_handler.check_password_provider_3pid(
|
||||||
medium, address, login_submission["password"]
|
medium, address, login_submission["password"]
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if canonical_user_id:
|
if canonical_user_id:
|
||||||
# Authentication through password provider and 3pid succeeded
|
# Authentication through password provider and 3pid succeeded
|
||||||
result = yield self._register_device_with_callback(
|
result = yield self._register_device_with_callback(
|
||||||
@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet):
|
|||||||
def do_token_login(self, login_submission):
|
def do_token_login(self, login_submission):
|
||||||
token = login_submission["token"]
|
token = login_submission["token"]
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = (
|
user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
token
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self._register_device_with_callback(user_id, login_submission)
|
result = yield self._register_device_with_callback(user_id, login_submission)
|
||||||
@ -380,7 +381,7 @@ class CasTicketServlet(RestServlet):
|
|||||||
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||||
self._http_client = hs.get_simple_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||||
self.failure_email_template, = load_jinja2_templates(
|
(self.failure_email_template,) = load_jinja2_templates(
|
||||||
self.config.email_template_dir,
|
self.config.email_template_dir,
|
||||||
[self.config.email_password_reset_template_failure_html],
|
[self.config.email_password_reset_template_failure_html],
|
||||||
)
|
)
|
||||||
@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||||
self.failure_email_template, = load_jinja2_templates(
|
(self.failure_email_template,) = load_jinja2_templates(
|
||||||
self.config.email_template_dir,
|
self.config.email_template_dir,
|
||||||
[self.config.email_add_threepid_template_failure_html],
|
[self.config.email_add_threepid_template_failure_html],
|
||||||
)
|
)
|
||||||
|
@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||||
self.failure_email_template, = load_jinja2_templates(
|
(self.failure_email_template,) = load_jinja2_templates(
|
||||||
self.config.email_template_dir,
|
self.config.email_template_dir,
|
||||||
[self.config.email_registration_template_failure_html],
|
[self.config.email_registration_template_failure_html],
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
|
||||||
self.failure_email_template, = load_jinja2_templates(
|
(self.failure_email_template,) = load_jinja2_templates(
|
||||||
self.config.email_template_dir,
|
self.config.email_template_dir,
|
||||||
[self.config.email_registration_template_failure_html],
|
[self.config.email_registration_template_failure_html],
|
||||||
)
|
)
|
||||||
|
@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet):
|
|||||||
"m.require_identity_server": False,
|
"m.require_identity_server": False,
|
||||||
# as per MSC2290
|
# as per MSC2290
|
||||||
"m.separate_add_and_bind": True,
|
"m.separate_add_and_bind": True,
|
||||||
|
# Implements support for label-based filtering as described in
|
||||||
|
# MSC2326.
|
||||||
|
"org.matrix.label_based_filtering": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource):
|
|||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
if len(request.postpath) == 1:
|
if len(request.postpath) == 1:
|
||||||
server, = request.postpath
|
(server,) = request.postpath
|
||||||
query = {server.decode("ascii"): {}}
|
query = {server.decode("ascii"): {}}
|
||||||
elif len(request.postpath) == 2:
|
elif len(request.postpath) == 2:
|
||||||
server, key_id = request.postpath
|
server, key_id = request.postpath
|
||||||
|
@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
|
|||||||
treq_args={"browser_like_redirects": True},
|
treq_args={"browser_like_redirects": True},
|
||||||
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
|
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
|
||||||
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
|
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
|
||||||
|
http_proxy=os.getenv("http_proxy"),
|
||||||
|
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||||
)
|
)
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.primary_base_path = media_repo.primary_base_path
|
self.primary_base_path = media_repo.primary_base_path
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
# Imports required for the default HomeServer() implementation
|
# Imports required for the default HomeServer() implementation
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from twisted.enterprise import adbapi
|
from twisted.enterprise import adbapi
|
||||||
from twisted.mail.smtp import sendmail
|
from twisted.mail.smtp import sendmail
|
||||||
@ -168,6 +169,7 @@ class HomeServer(object):
|
|||||||
"filtering",
|
"filtering",
|
||||||
"http_client_context_factory",
|
"http_client_context_factory",
|
||||||
"simple_http_client",
|
"simple_http_client",
|
||||||
|
"proxied_http_client",
|
||||||
"media_repository",
|
"media_repository",
|
||||||
"media_repository_resource",
|
"media_repository_resource",
|
||||||
"federation_transport_client",
|
"federation_transport_client",
|
||||||
@ -311,6 +313,13 @@ class HomeServer(object):
|
|||||||
def build_simple_http_client(self):
|
def build_simple_http_client(self):
|
||||||
return SimpleHttpClient(self)
|
return SimpleHttpClient(self)
|
||||||
|
|
||||||
|
def build_proxied_http_client(self):
|
||||||
|
return SimpleHttpClient(
|
||||||
|
self,
|
||||||
|
http_proxy=os.getenv("http_proxy"),
|
||||||
|
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||||
|
)
|
||||||
|
|
||||||
def build_room_creation_handler(self):
|
def build_room_creation_handler(self):
|
||||||
return RoomCreationHandler(self)
|
return RoomCreationHandler(self)
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ import synapse.handlers.message
|
|||||||
import synapse.handlers.room
|
import synapse.handlers.room
|
||||||
import synapse.handlers.room_member
|
import synapse.handlers.room_member
|
||||||
import synapse.handlers.set_password
|
import synapse.handlers.set_password
|
||||||
|
import synapse.http.client
|
||||||
import synapse.rest.media.v1.media_repository
|
import synapse.rest.media.v1.media_repository
|
||||||
import synapse.server_notices.server_notices_manager
|
import synapse.server_notices.server_notices_manager
|
||||||
import synapse.server_notices.server_notices_sender
|
import synapse.server_notices.server_notices_sender
|
||||||
@ -38,8 +39,16 @@ class HomeServer(object):
|
|||||||
pass
|
pass
|
||||||
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
||||||
pass
|
pass
|
||||||
|
def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||||
|
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||||
|
or support any HTTP_PROXY settings"""
|
||||||
|
pass
|
||||||
|
def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||||
|
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||||
|
but does support HTTP_PROXY settings"""
|
||||||
|
pass
|
||||||
def get_deactivate_account_handler(
|
def get_deactivate_account_handler(
|
||||||
self
|
self,
|
||||||
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||||
pass
|
pass
|
||||||
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
|
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
|
||||||
@ -47,32 +56,32 @@ class HomeServer(object):
|
|||||||
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
|
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
|
||||||
pass
|
pass
|
||||||
def get_event_creation_handler(
|
def get_event_creation_handler(
|
||||||
self
|
self,
|
||||||
) -> synapse.handlers.message.EventCreationHandler:
|
) -> synapse.handlers.message.EventCreationHandler:
|
||||||
pass
|
pass
|
||||||
def get_set_password_handler(
|
def get_set_password_handler(
|
||||||
self
|
self,
|
||||||
) -> synapse.handlers.set_password.SetPasswordHandler:
|
) -> synapse.handlers.set_password.SetPasswordHandler:
|
||||||
pass
|
pass
|
||||||
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
|
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
|
||||||
pass
|
pass
|
||||||
def get_federation_transport_client(
|
def get_federation_transport_client(
|
||||||
self
|
self,
|
||||||
) -> synapse.federation.transport.client.TransportLayerClient:
|
) -> synapse.federation.transport.client.TransportLayerClient:
|
||||||
pass
|
pass
|
||||||
def get_media_repository_resource(
|
def get_media_repository_resource(
|
||||||
self
|
self,
|
||||||
) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
|
) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
|
||||||
pass
|
pass
|
||||||
def get_media_repository(
|
def get_media_repository(
|
||||||
self
|
self,
|
||||||
) -> synapse.rest.media.v1.media_repository.MediaRepository:
|
) -> synapse.rest.media.v1.media_repository.MediaRepository:
|
||||||
pass
|
pass
|
||||||
def get_server_notices_manager(
|
def get_server_notices_manager(
|
||||||
self
|
self,
|
||||||
) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
|
) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
|
||||||
pass
|
pass
|
||||||
def get_server_notices_sender(
|
def get_server_notices_sender(
|
||||||
self
|
self,
|
||||||
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
|
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
|
||||||
pass
|
pass
|
||||||
|
@ -139,7 +139,10 @@ class DataStore(
|
|||||||
db_conn, "public_room_list_stream", "stream_id"
|
db_conn, "public_room_list_stream", "stream_id"
|
||||||
)
|
)
|
||||||
self._device_list_id_gen = StreamIdGenerator(
|
self._device_list_id_gen = StreamIdGenerator(
|
||||||
db_conn, "device_lists_stream", "stream_id"
|
db_conn,
|
||||||
|
"device_lists_stream",
|
||||||
|
"stream_id",
|
||||||
|
extra_tables=[("user_signature_stream", "stream_id")],
|
||||||
)
|
)
|
||||||
self._cross_signing_id_gen = StreamIdGenerator(
|
self._cross_signing_id_gen = StreamIdGenerator(
|
||||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||||
@ -317,7 +320,7 @@ class DataStore(
|
|||||||
) u
|
) u
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (time_from,))
|
txn.execute(sql, (time_from,))
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def count_r30_users(self):
|
def count_r30_users(self):
|
||||||
@ -396,7 +399,7 @@ class DataStore(
|
|||||||
|
|
||||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||||
|
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
results["all"] = count
|
results["all"] = count
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
@ -37,6 +37,7 @@ from synapse.storage._base import (
|
|||||||
make_in_list_sql_clause,
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
|
from synapse.types import get_verify_key_from_cross_signing_key
|
||||||
from synapse.util import batch_iter
|
from synapse.util import batch_iter
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||||
|
|
||||||
@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_devices_by_remote(self, destination, from_stream_id, limit):
|
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
|
||||||
"""Get stream of updates to send to remote servers
|
"""Get a stream of device updates to send to the given remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): The host the device updates are intended for
|
||||||
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||||
|
limit (int): Maximum number of device updates to return
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[tuple[int, list[dict]]]:
|
Deferred[tuple[int, list[tuple[string,dict]]]]:
|
||||||
current stream id (ie, the stream id of the last update included in the
|
current stream id (ie, the stream id of the last update included in the
|
||||||
response), and the list of updates
|
response), and the list of updates, where each update is a pair of EDU
|
||||||
|
type and EDU contents
|
||||||
"""
|
"""
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
# stream_id; the rationale being that such a large device list update
|
# stream_id; the rationale being that such a large device list update
|
||||||
# is likely an error.
|
# is likely an error.
|
||||||
updates = yield self.runInteraction(
|
updates = yield self.runInteraction(
|
||||||
"get_devices_by_remote",
|
"get_device_updates_by_remote",
|
||||||
self._get_devices_by_remote_txn,
|
self._get_device_updates_by_remote_txn,
|
||||||
destination,
|
destination,
|
||||||
from_stream_id,
|
from_stream_id,
|
||||||
now_stream_id,
|
now_stream_id,
|
||||||
@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
if not updates:
|
if not updates:
|
||||||
return now_stream_id, []
|
return now_stream_id, []
|
||||||
|
|
||||||
|
# get the cross-signing keys of the users in the list, so that we can
|
||||||
|
# determine which of the device changes were cross-signing keys
|
||||||
|
users = set(r[0] for r in updates)
|
||||||
|
master_key_by_user = {}
|
||||||
|
self_signing_key_by_user = {}
|
||||||
|
for user in users:
|
||||||
|
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
|
||||||
|
if cross_signing_key:
|
||||||
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
|
cross_signing_key
|
||||||
|
)
|
||||||
|
# verify_key is a VerifyKey from signedjson, which uses
|
||||||
|
# .version to denote the portion of the key ID after the
|
||||||
|
# algorithm and colon, which is the device ID
|
||||||
|
master_key_by_user[user] = {
|
||||||
|
"key_info": cross_signing_key,
|
||||||
|
"device_id": verify_key.version,
|
||||||
|
}
|
||||||
|
|
||||||
|
cross_signing_key = yield self.get_e2e_cross_signing_key(
|
||||||
|
user, "self_signing"
|
||||||
|
)
|
||||||
|
if cross_signing_key:
|
||||||
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
|
cross_signing_key
|
||||||
|
)
|
||||||
|
self_signing_key_by_user[user] = {
|
||||||
|
"key_info": cross_signing_key,
|
||||||
|
"device_id": verify_key.version,
|
||||||
|
}
|
||||||
|
|
||||||
# if we have exceeded the limit, we need to exclude any results with the
|
# if we have exceeded the limit, we need to exclude any results with the
|
||||||
# same stream_id as the last row.
|
# same stream_id as the last row.
|
||||||
if len(updates) > limit:
|
if len(updates) > limit:
|
||||||
@ -153,15 +190,28 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
# context which created the Edu.
|
# context which created the Edu.
|
||||||
|
|
||||||
query_map = {}
|
query_map = {}
|
||||||
for update in updates:
|
cross_signing_keys_by_user = {}
|
||||||
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
|
for user_id, device_id, update_stream_id, update_context in updates:
|
||||||
|
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
|
||||||
# Stop processing updates
|
# Stop processing updates
|
||||||
break
|
break
|
||||||
|
|
||||||
key = (update[0], update[1])
|
if (
|
||||||
|
user_id in master_key_by_user
|
||||||
update_context = update[3]
|
and device_id == master_key_by_user[user_id]["device_id"]
|
||||||
update_stream_id = update[2]
|
):
|
||||||
|
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||||
|
result["master_key"] = master_key_by_user[user_id]["key_info"]
|
||||||
|
elif (
|
||||||
|
user_id in self_signing_key_by_user
|
||||||
|
and device_id == self_signing_key_by_user[user_id]["device_id"]
|
||||||
|
):
|
||||||
|
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||||
|
result["self_signing_key"] = self_signing_key_by_user[user_id][
|
||||||
|
"key_info"
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
key = (user_id, device_id)
|
||||||
|
|
||||||
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
||||||
|
|
||||||
@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
# devices, in which case E2E isn't going to work well anyway. We'll just
|
# devices, in which case E2E isn't going to work well anyway. We'll just
|
||||||
# skip that stream_id and return an empty list, and continue with the next
|
# skip that stream_id and return an empty list, and continue with the next
|
||||||
# stream_id next time.
|
# stream_id next time.
|
||||||
if not query_map:
|
if not query_map and not cross_signing_keys_by_user:
|
||||||
return stream_id_cutoff, []
|
return stream_id_cutoff, []
|
||||||
|
|
||||||
results = yield self._get_device_update_edus_by_remote(
|
results = yield self._get_device_update_edus_by_remote(
|
||||||
destination, from_stream_id, query_map
|
destination, from_stream_id, query_map
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add the updated cross-signing keys to the results list
|
||||||
|
for user_id, result in iteritems(cross_signing_keys_by_user):
|
||||||
|
result["user_id"] = user_id
|
||||||
|
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||||
|
results.append(("org.matrix.signing_key_update", result))
|
||||||
|
|
||||||
return now_stream_id, results
|
return now_stream_id, results
|
||||||
|
|
||||||
def _get_devices_by_remote_txn(
|
def _get_device_updates_by_remote_txn(
|
||||||
self, txn, destination, from_stream_id, now_stream_id, limit
|
self, txn, destination, from_stream_id, now_stream_id, limit
|
||||||
):
|
):
|
||||||
"""Return device update information for a given remote destination
|
"""Return device update information for a given remote destination
|
||||||
@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List: List of device updates
|
List: List of device updates
|
||||||
"""
|
"""
|
||||||
|
# get the list of device updates that need to be sent
|
||||||
sql = """
|
sql = """
|
||||||
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
|
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
|
||||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||||
@ -225,13 +282,17 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
List[Dict]: List of objects representing an device update EDU
|
List[Dict]: List of objects representing an device update EDU
|
||||||
|
|
||||||
"""
|
"""
|
||||||
devices = yield self.runInteraction(
|
devices = (
|
||||||
|
yield self.runInteraction(
|
||||||
"_get_e2e_device_keys_txn",
|
"_get_e2e_device_keys_txn",
|
||||||
self._get_e2e_device_keys_txn,
|
self._get_e2e_device_keys_txn,
|
||||||
query_map.keys(),
|
query_map.keys(),
|
||||||
include_all_devices=True,
|
include_all_devices=True,
|
||||||
include_deleted_devices=True,
|
include_deleted_devices=True,
|
||||||
)
|
)
|
||||||
|
if query_map
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for user_id, user_devices in iteritems(devices):
|
for user_id, user_devices in iteritems(devices):
|
||||||
@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
else:
|
else:
|
||||||
result["deleted"] = True
|
result["deleted"] = True
|
||||||
|
|
||||||
results.append(result)
|
results.append(("m.device_list_update", result))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
from_user_id,
|
from_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
|
||||||
|
"""Return a list of changes from the user signature stream to notify remotes.
|
||||||
|
Note that the user signature stream represents when a user signs their
|
||||||
|
device with their user-signing key, which is not published to other
|
||||||
|
users or servers, so no `destination` is needed in the returned
|
||||||
|
list. However, this is needed to poke workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_key (int): the stream ID to start at (exclusive)
|
||||||
|
to_key (int): the stream ID to end at (inclusive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
|
||||||
|
FROM user_signature_stream
|
||||||
|
WHERE ? < stream_id AND stream_id <= ?
|
||||||
|
GROUP BY user_id
|
||||||
|
"""
|
||||||
|
return self._execute(
|
||||||
|
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||||
|
@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||||||
)
|
)
|
||||||
stream_row = txn.fetchone()
|
stream_row = txn.fetchone()
|
||||||
if stream_row:
|
if stream_row:
|
||||||
offset_stream_ordering, = stream_row
|
(offset_stream_ordering,) = stream_row
|
||||||
rotate_to_stream_ordering = min(
|
rotate_to_stream_ordering = min(
|
||||||
self.stream_ordering_day_ago, offset_stream_ordering
|
self.stream_ordering_day_ago, offset_stream_ordering
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,7 @@ from prometheus_client import Counter
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventContentFields, EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.events import EventBase # noqa: F401
|
from synapse.events import EventBase # noqa: F401
|
||||||
from synapse.events.snapshot import EventContext # noqa: F401
|
from synapse.events.snapshot import EventContext # noqa: F401
|
||||||
@ -933,6 +933,13 @@ class EventsStore(
|
|||||||
|
|
||||||
self._handle_event_relations(txn, event)
|
self._handle_event_relations(txn, event)
|
||||||
|
|
||||||
|
# Store the labels for this event.
|
||||||
|
labels = event.content.get(EventContentFields.LABELS)
|
||||||
|
if labels:
|
||||||
|
self.insert_labels_for_event_txn(
|
||||||
|
txn, event.event_id, labels, event.room_id, event.depth
|
||||||
|
)
|
||||||
|
|
||||||
# Insert into the room_memberships table.
|
# Insert into the room_memberships table.
|
||||||
self._store_room_members_txn(
|
self._store_room_members_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -1126,7 +1133,7 @@ class EventsStore(
|
|||||||
AND stream_ordering > ?
|
AND stream_ordering > ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
ret = yield self.runInteraction("count_messages", _count_messages)
|
ret = yield self.runInteraction("count_messages", _count_messages)
|
||||||
@ -1147,7 +1154,7 @@ class EventsStore(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
|
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
|
||||||
@ -1162,7 +1169,7 @@ class EventsStore(
|
|||||||
AND stream_ordering > ?
|
AND stream_ordering > ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
ret = yield self.runInteraction("count_daily_active_rooms", _count)
|
ret = yield self.runInteraction("count_daily_active_rooms", _count)
|
||||||
@ -1596,7 +1603,7 @@ class EventsStore(
|
|||||||
""",
|
""",
|
||||||
(room_id,),
|
(room_id,),
|
||||||
)
|
)
|
||||||
min_depth, = txn.fetchone()
|
(min_depth,) = txn.fetchone()
|
||||||
|
|
||||||
logger.info("[purge] updating room_depth to %d", min_depth)
|
logger.info("[purge] updating room_depth to %d", min_depth)
|
||||||
|
|
||||||
@ -1905,6 +1912,33 @@ class EventsStore(
|
|||||||
get_all_updated_current_state_deltas_txn,
|
get_all_updated_current_state_deltas_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def insert_labels_for_event_txn(
|
||||||
|
self, txn, event_id, labels, room_id, topological_ordering
|
||||||
|
):
|
||||||
|
"""Store the mapping between an event's ID and its labels, with one row per
|
||||||
|
(event_id, label) tuple.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (LoggingTransaction): The transaction to execute.
|
||||||
|
event_id (str): The event's ID.
|
||||||
|
labels (list[str]): A list of text labels.
|
||||||
|
room_id (str): The ID of the room the event was sent to.
|
||||||
|
topological_ordering (int): The position of the event in the room's topology.
|
||||||
|
"""
|
||||||
|
return self._simple_insert_many_txn(
|
||||||
|
txn=txn,
|
||||||
|
table="event_labels",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"event_id": event_id,
|
||||||
|
"label": label,
|
||||||
|
"room_id": room_id,
|
||||||
|
"topological_ordering": topological_ordering,
|
||||||
|
}
|
||||||
|
for label in labels
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AllNewEventsResult = namedtuple(
|
AllNewEventsResult = namedtuple(
|
||||||
"AllNewEventsResult",
|
"AllNewEventsResult",
|
||||||
|
@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
|
|||||||
if not rows:
|
if not rows:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
upper_event_id, = rows[-1]
|
(upper_event_id,) = rows[-1]
|
||||||
|
|
||||||
# Update the redactions with the received_ts.
|
# Update the redactions with the received_ts.
|
||||||
#
|
#
|
||||||
|
@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore):
|
|||||||
WHERE group_id = ? AND category_id = ?
|
WHERE group_id = ? AND category_id = ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (group_id, category_id))
|
txn.execute(sql, (group_id, category_id))
|
||||||
order, = txn.fetchone()
|
(order,) = txn.fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
to_update = {}
|
to_update = {}
|
||||||
@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore):
|
|||||||
WHERE group_id = ? AND role_id = ?
|
WHERE group_id = ? AND role_id = ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (group_id, role_id))
|
txn.execute(sql, (group_id, role_id))
|
||||||
order, = txn.fetchone()
|
(order,) = txn.fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
to_update = {}
|
to_update = {}
|
||||||
|
@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
|||||||
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
|
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
|
||||||
|
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return self.runInteraction("count_users", _count_users)
|
return self.runInteraction("count_users", _count_users)
|
||||||
|
@ -143,7 +143,7 @@ class PushRulesWorkerStore(
|
|||||||
" WHERE user_id = ? AND ? < stream_id"
|
" WHERE user_id = ? AND ? < stream_id"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id, last_id))
|
txn.execute(sql, (user_id, last_id))
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return bool(count)
|
return bool(count)
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
|
@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
WHERE appservice_id IS NULL
|
WHERE appservice_id IS NULL
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
ret = yield self.runInteraction("count_users", _count_users)
|
ret = yield self.runInteraction("count_users", _count_users)
|
||||||
|
@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
|
|||||||
if not row or not row[0]:
|
if not row or not row[0]:
|
||||||
return processed, True
|
return processed, True
|
||||||
|
|
||||||
next_room, = row
|
(next_room,) = row
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
UPDATE current_state_events
|
UPDATE current_state_events
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- room_id and topoligical_ordering are denormalised from the events table in order to
|
||||||
|
-- make the index work.
|
||||||
|
CREATE TABLE IF NOT EXISTS event_labels (
|
||||||
|
event_id TEXT,
|
||||||
|
label TEXT,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
topological_ordering BIGINT NOT NULL,
|
||||||
|
PRIMARY KEY(event_id, label)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
-- This index enables an event pagination looking for a particular label to index the
|
||||||
|
-- event_labels table first, which is much quicker than scanning the events table and then
|
||||||
|
-- filtering by label, if the label is rarely used relative to the size of the room.
|
||||||
|
CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
|
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* Change the hidden column from a default value of FALSE to a default value of
|
||||||
|
* 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
|
||||||
|
* string 'FALSE', which is truthy.
|
||||||
|
*
|
||||||
|
* Since sqlite doesn't allow us to just change the default value, we have to
|
||||||
|
* recreate the table, copy the data, fix the rows that have incorrect data, and
|
||||||
|
* replace the old table with the new table.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS devices2 (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
|
last_seen BIGINT,
|
||||||
|
ip TEXT,
|
||||||
|
user_agent TEXT,
|
||||||
|
hidden BOOLEAN DEFAULT 0,
|
||||||
|
CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO devices2 SELECT * FROM devices;
|
||||||
|
|
||||||
|
UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
|
||||||
|
|
||||||
|
DROP TABLE devices;
|
||||||
|
|
||||||
|
ALTER TABLE devices2 RENAME TO devices;
|
@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
txn.execute(query, (value, search_query))
|
txn.execute(query, (value, search_query))
|
||||||
headline, = txn.fetchall()[0]
|
(headline,) = txn.fetchall()[0]
|
||||||
|
|
||||||
# Now we need to pick the possible highlights out of the haedline
|
# Now we need to pick the possible highlights out of the haedline
|
||||||
# result.
|
# result.
|
||||||
|
@ -725,17 +725,19 @@ class StateGroupWorkerStore(
|
|||||||
member_filter, non_member_filter = state_filter.get_member_split()
|
member_filter, non_member_filter = state_filter.get_member_split()
|
||||||
|
|
||||||
# Now we look them up in the member and non-member caches
|
# Now we look them up in the member and non-member caches
|
||||||
non_member_state, incomplete_groups_nm, = (
|
(
|
||||||
yield self._get_state_for_groups_using_cache(
|
non_member_state,
|
||||||
|
incomplete_groups_nm,
|
||||||
|
) = yield self._get_state_for_groups_using_cache(
|
||||||
groups, self._state_group_cache, state_filter=non_member_filter
|
groups, self._state_group_cache, state_filter=non_member_filter
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
member_state, incomplete_groups_m, = (
|
(
|
||||||
yield self._get_state_for_groups_using_cache(
|
member_state,
|
||||||
|
incomplete_groups_m,
|
||||||
|
) = yield self._get_state_for_groups_using_cache(
|
||||||
groups, self._state_group_members_cache, state_filter=member_filter
|
groups, self._state_group_members_cache, state_filter=member_filter
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
state = dict(non_member_state)
|
state = dict(non_member_state)
|
||||||
for group in groups:
|
for group in groups:
|
||||||
@ -1106,7 +1108,7 @@ class StateBackgroundUpdateStore(
|
|||||||
" WHERE id < ? AND room_id = ?",
|
" WHERE id < ? AND room_id = ?",
|
||||||
(state_group, room_id),
|
(state_group, room_id),
|
||||||
)
|
)
|
||||||
prev_group, = txn.fetchone()
|
(prev_group,) = txn.fetchone()
|
||||||
new_last_state_group = state_group
|
new_last_state_group = state_group
|
||||||
|
|
||||||
if prev_group:
|
if prev_group:
|
||||||
|
@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore):
|
|||||||
(room_id,),
|
(room_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state_events_count, = txn.fetchone()
|
(current_state_events_count,) = txn.fetchone()
|
||||||
|
|
||||||
users_in_room = self.get_users_in_room_txn(txn, room_id)
|
users_in_room = self.get_users_in_room_txn(txn, room_id)
|
||||||
|
|
||||||
@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore):
|
|||||||
""",
|
""",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
)
|
)
|
||||||
count, = txn.fetchone()
|
(count,) = txn.fetchone()
|
||||||
return count, pos
|
return count, pos
|
||||||
|
|
||||||
joined_rooms, pos = yield self.runInteraction(
|
joined_rooms, pos = yield self.runInteraction(
|
||||||
|
@ -229,6 +229,14 @@ def filter_to_clause(event_filter):
|
|||||||
clauses.append("contains_url = ?")
|
clauses.append("contains_url = ?")
|
||||||
args.append(event_filter.contains_url)
|
args.append(event_filter.contains_url)
|
||||||
|
|
||||||
|
# We're only applying the "labels" filter on the database query, because applying the
|
||||||
|
# "not_labels" filter via a SQL query is non-trivial. Instead, we let
|
||||||
|
# event_filter.check_fields apply it, which is not as efficient but makes the
|
||||||
|
# implementation simpler.
|
||||||
|
if event_filter.labels:
|
||||||
|
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
|
||||||
|
args.extend(event_filter.labels)
|
||||||
|
|
||||||
return " AND ".join(clauses), args
|
return " AND ".join(clauses), args
|
||||||
|
|
||||||
|
|
||||||
@ -864,8 +872,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
args.append(int(limit))
|
args.append(int(limit))
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT event_id, topological_ordering, stream_ordering"
|
"SELECT DISTINCT event_id, topological_ordering, stream_ordering"
|
||||||
" FROM events"
|
" FROM events"
|
||||||
|
" LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)"
|
||||||
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
|
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
|
||||||
" ORDER BY topological_ordering %(order)s,"
|
" ORDER BY topological_ordering %(order)s,"
|
||||||
" stream_ordering %(order)s LIMIT ?"
|
" stream_ordering %(order)s LIMIT ?"
|
||||||
|
@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
|
|||||||
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
|
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
|
||||||
else:
|
else:
|
||||||
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
|
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
|
||||||
val, = cur.fetchone()
|
(val,) = cur.fetchone()
|
||||||
cur.close()
|
cur.close()
|
||||||
current_id = int(val) if val else step
|
current_id = int(val) if val else step
|
||||||
return (max if step > 0 else min)(current_id, step)
|
return (max if step > 0 else min)(current_id, step)
|
||||||
|
@ -19,6 +19,7 @@ import jsonschema
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import EventContentFields
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
"types": ["m.room.message"],
|
"types": ["m.room.message"],
|
||||||
"not_rooms": ["!726s6s6q:example.com"],
|
"not_rooms": ["!726s6s6q:example.com"],
|
||||||
"not_senders": ["@spam:example.com"],
|
"not_senders": ["@spam:example.com"],
|
||||||
|
"org.matrix.labels": ["#fun"],
|
||||||
|
"org.matrix.not_labels": ["#work"],
|
||||||
},
|
},
|
||||||
"ephemeral": {
|
"ephemeral": {
|
||||||
"types": ["m.receipt", "m.typing"],
|
"types": ["m.receipt", "m.typing"],
|
||||||
@ -320,6 +323,46 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
|
def test_filter_labels(self):
|
||||||
|
definition = {"org.matrix.labels": ["#fun"]}
|
||||||
|
event = MockEvent(
|
||||||
|
sender="@foo:bar",
|
||||||
|
type="m.room.message",
|
||||||
|
room_id="!secretbase:unknown",
|
||||||
|
content={EventContentFields.LABELS: ["#fun"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(Filter(definition).check(event))
|
||||||
|
|
||||||
|
event = MockEvent(
|
||||||
|
sender="@foo:bar",
|
||||||
|
type="m.room.message",
|
||||||
|
room_id="!secretbase:unknown",
|
||||||
|
content={EventContentFields.LABELS: ["#notfun"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
|
def test_filter_not_labels(self):
|
||||||
|
definition = {"org.matrix.not_labels": ["#fun"]}
|
||||||
|
event = MockEvent(
|
||||||
|
sender="@foo:bar",
|
||||||
|
type="m.room.message",
|
||||||
|
room_id="!secretbase:unknown",
|
||||||
|
content={EventContentFields.LABELS: ["#fun"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(Filter(definition).check(event))
|
||||||
|
|
||||||
|
event = MockEvent(
|
||||||
|
sender="@foo:bar",
|
||||||
|
type="m.room.message",
|
||||||
|
room_id="!secretbase:unknown",
|
||||||
|
content={EventContentFields.LABELS: ["#notfun"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(Filter(definition).check(event))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_filter_presence_match(self):
|
def test_filter_presence_match(self):
|
||||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||||
|
@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
"get_received_txn_response",
|
"get_received_txn_response",
|
||||||
"set_received_txn_response",
|
"set_received_txn_response",
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
"get_devices_by_remote",
|
"get_device_updates_by_remote",
|
||||||
# Bits that user_directory needs
|
# Bits that user_directory needs
|
||||||
"get_user_directory_stream_pos",
|
"get_user_directory_stream_pos",
|
||||||
"get_current_state_deltas",
|
"get_current_state_deltas",
|
||||||
@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
retry_timings_res
|
retry_timings_res
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_devices_by_remote.return_value = (0, [])
|
self.datastore.get_device_updates_by_remote.return_value = (0, [])
|
||||||
|
|
||||||
def get_received_txn_response(*args):
|
def get_received_txn_response(*args):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -20,6 +20,23 @@ from zope.interface import implementer
|
|||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from OpenSSL.SSL import Connection
|
from OpenSSL.SSL import Connection
|
||||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
||||||
|
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
||||||
|
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
||||||
|
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_https_policy():
|
||||||
|
"""Get a test IPolicyForHTTPS which trusts the test CA cert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IPolicyForHTTPS
|
||||||
|
"""
|
||||||
|
ca_file = get_test_ca_cert_file()
|
||||||
|
with open(ca_file) as stream:
|
||||||
|
content = stream.read()
|
||||||
|
cert = Certificate.loadPEM(content)
|
||||||
|
trust_root = trustRootFromCertificates([cert])
|
||||||
|
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
||||||
|
|
||||||
|
|
||||||
def get_test_ca_cert_file():
|
def get_test_ca_cert_file():
|
||||||
|
@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||||||
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# grab a hold of the TLS connection, in case it gets torn down
|
||||||
|
server_tls_connection = server_tls_protocol._tlsConnection
|
||||||
|
|
||||||
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
|
http_protocol = server_tls_protocol.wrappedProtocol
|
||||||
|
|
||||||
# give the reactor a pump to get the TLS juices flowing.
|
# give the reactor a pump to get the TLS juices flowing.
|
||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
# check the SNI
|
# check the SNI
|
||||||
server_name = server_tls_protocol._tlsConnection.get_servername()
|
server_name = server_tls_connection.get_servername()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
server_name,
|
server_name,
|
||||||
expected_sni,
|
expected_sni,
|
||||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
# fish the test server back out of the server-side TLS protocol.
|
return http_protocol
|
||||||
return server_tls_protocol.wrappedProtocol
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _make_get_request(self, uri):
|
def _make_get_request(self, uri):
|
||||||
|
334
tests/http/test_proxyagent.py
Normal file
334
tests/http/test_proxyagent.py
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import treq
|
||||||
|
|
||||||
|
from twisted.internet import interfaces # noqa: F401
|
||||||
|
from twisted.internet.protocol import Factory
|
||||||
|
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||||
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
|
from synapse.http.proxyagent import ProxyAgent
|
||||||
|
|
||||||
|
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
|
||||||
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HTTPFactory = Factory.forProtocol(HTTPChannel)
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixFederationAgentTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
|
||||||
|
def _make_connection(
|
||||||
|
self, client_factory, server_factory, ssl=False, expected_sni=None
|
||||||
|
):
|
||||||
|
"""Builds a test server, and completes the outgoing client connection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_factory (interfaces.IProtocolFactory): the the factory that the
|
||||||
|
application is trying to use to make the outbound connection. We will
|
||||||
|
invoke it to build the client Protocol
|
||||||
|
|
||||||
|
server_factory (interfaces.IProtocolFactory): a factory to build the
|
||||||
|
server-side protocol
|
||||||
|
|
||||||
|
ssl (bool): If true, we will expect an ssl connection and wrap
|
||||||
|
server_factory with a TLSMemoryBIOFactory
|
||||||
|
|
||||||
|
expected_sni (bytes|None): the expected SNI value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IProtocol: the server Protocol returned by server_factory
|
||||||
|
"""
|
||||||
|
if ssl:
|
||||||
|
server_factory = _wrap_server_factory_for_tls(server_factory)
|
||||||
|
|
||||||
|
server_protocol = server_factory.buildProtocol(None)
|
||||||
|
|
||||||
|
# now, tell the client protocol factory to build the client protocol,
|
||||||
|
# and wire the output of said protocol up to the server via
|
||||||
|
# a FakeTransport.
|
||||||
|
#
|
||||||
|
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||||
|
# stubbing that out here.
|
||||||
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
client_protocol.makeConnection(
|
||||||
|
FakeTransport(server_protocol, self.reactor, client_protocol)
|
||||||
|
)
|
||||||
|
|
||||||
|
# tell the server protocol to send its stuff back to the client, too
|
||||||
|
server_protocol.makeConnection(
|
||||||
|
FakeTransport(client_protocol, self.reactor, server_protocol)
|
||||||
|
)
|
||||||
|
|
||||||
|
if ssl:
|
||||||
|
http_protocol = server_protocol.wrappedProtocol
|
||||||
|
tls_connection = server_protocol._tlsConnection
|
||||||
|
else:
|
||||||
|
http_protocol = server_protocol
|
||||||
|
tls_connection = None
|
||||||
|
|
||||||
|
# give the reactor a pump to get the TLS juices flowing (if needed)
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
if expected_sni is not None:
|
||||||
|
server_name = tls_connection.get_servername()
|
||||||
|
self.assertEqual(
|
||||||
|
server_name,
|
||||||
|
expected_sni,
|
||||||
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
return http_protocol
|
||||||
|
|
||||||
|
def test_http_request(self):
|
||||||
|
agent = ProxyAgent(self.reactor)
|
||||||
|
|
||||||
|
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||||
|
d = agent.request(b"GET", b"http://test.com")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.4")
|
||||||
|
self.assertEqual(port, 80)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory, _get_test_protocol_factory()
|
||||||
|
)
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"/")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
def test_https_request(self):
|
||||||
|
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
|
||||||
|
|
||||||
|
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||||
|
d = agent.request(b"GET", b"https://test.com/abc")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.4")
|
||||||
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
_get_test_protocol_factory(),
|
||||||
|
ssl=True,
|
||||||
|
expected_sni=b"test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"/abc")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
def test_http_request_via_proxy(self):
|
||||||
|
agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
|
||||||
|
|
||||||
|
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||||
|
d = agent.request(b"GET", b"http://test.com")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.5")
|
||||||
|
self.assertEqual(port, 8888)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory, _get_test_protocol_factory()
|
||||||
|
)
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"http://test.com")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
def test_https_request_via_proxy(self):
|
||||||
|
agent = ProxyAgent(
|
||||||
|
self.reactor,
|
||||||
|
contextFactory=get_test_https_policy(),
|
||||||
|
https_proxy=b"proxy.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||||
|
d = agent.request(b"GET", b"https://test.com/abc")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.5")
|
||||||
|
self.assertEqual(port, 1080)
|
||||||
|
|
||||||
|
# make a test HTTP server, and wire up the client
|
||||||
|
proxy_server = self._make_connection(
|
||||||
|
client_factory, _get_test_protocol_factory()
|
||||||
|
)
|
||||||
|
|
||||||
|
# fish the transports back out so that we can do the old switcheroo
|
||||||
|
s2c_transport = proxy_server.transport
|
||||||
|
client_protocol = s2c_transport.other
|
||||||
|
c2s_transport = client_protocol.transport
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending CONNECT request
|
||||||
|
self.assertEqual(len(proxy_server.requests), 1)
|
||||||
|
|
||||||
|
request = proxy_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"CONNECT")
|
||||||
|
self.assertEqual(request.path, b"test.com:443")
|
||||||
|
|
||||||
|
# tell the proxy server not to close the connection
|
||||||
|
proxy_server.persistent = True
|
||||||
|
|
||||||
|
# this just stops the http Request trying to do a chunked response
|
||||||
|
# request.setHeader(b"Content-Length", b"0")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
||||||
|
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
|
||||||
|
ssl_protocol = ssl_factory.buildProtocol(None)
|
||||||
|
http_server = ssl_protocol.wrappedProtocol
|
||||||
|
|
||||||
|
ssl_protocol.makeConnection(
|
||||||
|
FakeTransport(client_protocol, self.reactor, ssl_protocol)
|
||||||
|
)
|
||||||
|
c2s_transport.other = ssl_protocol
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
server_name = ssl_protocol._tlsConnection.get_servername()
|
||||||
|
expected_sni = b"test.com"
|
||||||
|
self.assertEqual(
|
||||||
|
server_name,
|
||||||
|
expected_sni,
|
||||||
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"/abc")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_server_factory_for_tls(factory, sanlist=None):
|
||||||
|
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||||
|
|
||||||
|
The resultant factory will create a TLS server which presents a certificate
|
||||||
|
signed by our test CA, valid for the domains in `sanlist`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factory (interfaces.IProtocolFactory): protocol factory to wrap
|
||||||
|
sanlist (iterable[bytes]): list of domains the cert should be valid for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interfaces.IProtocolFactory
|
||||||
|
"""
|
||||||
|
if sanlist is None:
|
||||||
|
sanlist = [b"DNS:test.com"]
|
||||||
|
|
||||||
|
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
||||||
|
return TLSMemoryBIOFactory(
|
||||||
|
connection_creator, isClient=False, wrappedFactory=factory
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_test_protocol_factory():
|
||||||
|
"""Get a protocol Factory which will build an HTTPChannel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interfaces.IProtocolFactory
|
||||||
|
"""
|
||||||
|
server_factory = Factory.forProtocol(HTTPChannel)
|
||||||
|
|
||||||
|
# Request.finish expects the factory to have a 'log' method.
|
||||||
|
server_factory.log = _log_request
|
||||||
|
|
||||||
|
return server_factory
|
||||||
|
|
||||||
|
|
||||||
|
def _log_request(request):
|
||||||
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
|
logger.info("Completed request %s", request)
|
@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["start_pushers"] = True
|
config["start_pushers"] = True
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(config=config, simple_http_client=m)
|
hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||||
from synapse.rest.client.v1 import login, profile, room
|
from synapse.rest.client.v1 import login, profile, room
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -811,6 +811,105 @@ class RoomMessageListTestCase(RoomBase):
|
|||||||
self.assertTrue("chunk" in channel.json_body)
|
self.assertTrue("chunk" in channel.json_body)
|
||||||
self.assertTrue("end" in channel.json_body)
|
self.assertTrue("end" in channel.json_body)
|
||||||
|
|
||||||
|
def test_filter_labels(self):
|
||||||
|
"""Test that we can filter by a label."""
|
||||||
|
message_filter = json.dumps(
|
||||||
|
{"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
events = self._test_filter_labels(message_filter)
|
||||||
|
|
||||||
|
self.assertEqual(len(events), 2, [event["content"] for event in events])
|
||||||
|
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||||
|
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||||
|
|
||||||
|
def test_filter_not_labels(self):
|
||||||
|
"""Test that we can filter by the absence of a label."""
|
||||||
|
message_filter = json.dumps(
|
||||||
|
{"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
events = self._test_filter_labels(message_filter)
|
||||||
|
|
||||||
|
self.assertEqual(len(events), 3, [event["content"] for event in events])
|
||||||
|
self.assertEqual(events[0]["content"]["body"], "without label", events[0])
|
||||||
|
self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
|
||||||
|
self.assertEqual(
|
||||||
|
events[2]["content"]["body"], "with two wrong labels", events[2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_labels_not_labels(self):
|
||||||
|
"""Test that we can filter by both a label and the absence of another label."""
|
||||||
|
sync_filter = json.dumps(
|
||||||
|
{
|
||||||
|
"types": [EventTypes.Message],
|
||||||
|
"org.matrix.labels": ["#work"],
|
||||||
|
"org.matrix.not_labels": ["#notfun"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
events = self._test_filter_labels(sync_filter)
|
||||||
|
|
||||||
|
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||||
|
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||||
|
|
||||||
|
def _test_filter_labels(self, message_filter):
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with right label",
|
||||||
|
EventContentFields.LABELS: ["#fun"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "without label"},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with wrong label",
|
||||||
|
EventContentFields.LABELS: ["#work"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with two wrong labels",
|
||||||
|
EventContentFields.LABELS: ["#work", "#notfun"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with right label",
|
||||||
|
EventContentFields.LABELS: ["#fun"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
token = "s0_0_0_0_0_0_0_0_0"
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/rooms/%s/messages?access_token=x&from=%s&filter=%s"
|
||||||
|
% (self.room_id, token, message_filter),
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
return channel.json_body["chunk"]
|
||||||
|
|
||||||
|
|
||||||
class RoomSearchTestCase(unittest.HomeserverTestCase):
|
class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
|
@ -106,13 +106,22 @@ class RestHelper(object):
|
|||||||
self.auth_user_id = temp_id
|
self.auth_user_id = temp_id
|
||||||
|
|
||||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||||
if txn_id is None:
|
|
||||||
txn_id = "m%s" % (str(time.time()))
|
|
||||||
if body is None:
|
if body is None:
|
||||||
body = "body_text_here"
|
body = "body_text_here"
|
||||||
|
|
||||||
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
|
||||||
content = {"msgtype": "m.text", "body": body}
|
content = {"msgtype": "m.text", "body": body}
|
||||||
|
|
||||||
|
return self.send_event(
|
||||||
|
room_id, "m.room.message", content, txn_id, tok, expect_code
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_event(
|
||||||
|
self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
|
||||||
|
):
|
||||||
|
if txn_id is None:
|
||||||
|
txn_id = "m%s" % (str(time.time()))
|
||||||
|
|
||||||
|
path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
|
||||||
if tok:
|
if tok:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
|
@ -12,10 +12,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
|
from synapse.api.constants import EventContentFields, EventTypes
|
||||||
from synapse.rest.client.v1 import login, room
|
from synapse.rest.client.v1 import login, room
|
||||||
from synapse.rest.client.v2_alpha import sync
|
from synapse.rest.client.v2_alpha import sync
|
||||||
|
|
||||||
@ -26,7 +28,12 @@ from tests.server import TimedOutException
|
|||||||
class FilterTestCase(unittest.HomeserverTestCase):
|
class FilterTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
user_id = "@apple:test"
|
user_id = "@apple:test"
|
||||||
servlets = [sync.register_servlets]
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
|
room.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
sync.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
@ -70,6 +77,140 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SyncFilterTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
|
room.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
sync.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_sync_filter_labels(self):
|
||||||
|
"""Test that we can filter by a label."""
|
||||||
|
sync_filter = json.dumps(
|
||||||
|
{
|
||||||
|
"room": {
|
||||||
|
"timeline": {
|
||||||
|
"types": [EventTypes.Message],
|
||||||
|
"org.matrix.labels": ["#fun"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
events = self._test_sync_filter_labels(sync_filter)
|
||||||
|
|
||||||
|
self.assertEqual(len(events), 2, [event["content"] for event in events])
|
||||||
|
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||||
|
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||||
|
|
||||||
|
def test_sync_filter_not_labels(self):
|
||||||
|
"""Test that we can filter by the absence of a label."""
|
||||||
|
sync_filter = json.dumps(
|
||||||
|
{
|
||||||
|
"room": {
|
||||||
|
"timeline": {
|
||||||
|
"types": [EventTypes.Message],
|
||||||
|
"org.matrix.not_labels": ["#fun"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
events = self._test_sync_filter_labels(sync_filter)
|
||||||
|
|
||||||
|
self.assertEqual(len(events), 3, [event["content"] for event in events])
|
||||||
|
self.assertEqual(events[0]["content"]["body"], "without label", events[0])
|
||||||
|
self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
|
||||||
|
self.assertEqual(
|
||||||
|
events[2]["content"]["body"], "with two wrong labels", events[2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sync_filter_labels_not_labels(self):
|
||||||
|
"""Test that we can filter by both a label and the absence of another label."""
|
||||||
|
sync_filter = json.dumps(
|
||||||
|
{
|
||||||
|
"room": {
|
||||||
|
"timeline": {
|
||||||
|
"types": [EventTypes.Message],
|
||||||
|
"org.matrix.labels": ["#work"],
|
||||||
|
"org.matrix.not_labels": ["#notfun"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
events = self._test_sync_filter_labels(sync_filter)
|
||||||
|
|
||||||
|
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||||
|
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||||
|
|
||||||
|
def _test_sync_filter_labels(self, sync_filter):
|
||||||
|
user_id = self.register_user("kermit", "test")
|
||||||
|
tok = self.login("kermit", "test")
|
||||||
|
|
||||||
|
room_id = self.helper.create_room_as(user_id, tok=tok)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with right label",
|
||||||
|
EventContentFields.LABELS: ["#fun"],
|
||||||
|
},
|
||||||
|
tok=tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "without label"},
|
||||||
|
tok=tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with wrong label",
|
||||||
|
EventContentFields.LABELS: ["#work"],
|
||||||
|
},
|
||||||
|
tok=tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with two wrong labels",
|
||||||
|
EventContentFields.LABELS: ["#work", "#notfun"],
|
||||||
|
},
|
||||||
|
tok=tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "with right label",
|
||||||
|
EventContentFields.LABELS: ["#fun"],
|
||||||
|
},
|
||||||
|
tok=tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "/sync?filter=%s" % sync_filter, access_token=tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
|
return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
|
||||||
|
|
||||||
|
|
||||||
class SyncTypingTests(unittest.HomeserverTestCase):
|
class SyncTypingTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
|
@ -395,11 +395,24 @@ class FakeTransport(object):
|
|||||||
self.disconnecting = True
|
self.disconnecting = True
|
||||||
if self._protocol:
|
if self._protocol:
|
||||||
self._protocol.connectionLost(reason)
|
self._protocol.connectionLost(reason)
|
||||||
|
|
||||||
|
# if we still have data to write, delay until that is done
|
||||||
|
if self.buffer:
|
||||||
|
logger.info(
|
||||||
|
"FakeTransport: Delaying disconnect until buffer is flushed"
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.disconnected = True
|
self.disconnected = True
|
||||||
|
|
||||||
def abortConnection(self):
|
def abortConnection(self):
|
||||||
logger.info("FakeTransport: abortConnection()")
|
logger.info("FakeTransport: abortConnection()")
|
||||||
self.loseConnection()
|
|
||||||
|
if not self.disconnecting:
|
||||||
|
self.disconnecting = True
|
||||||
|
if self._protocol:
|
||||||
|
self._protocol.connectionLost(None)
|
||||||
|
|
||||||
|
self.disconnected = True
|
||||||
|
|
||||||
def pauseProducing(self):
|
def pauseProducing(self):
|
||||||
if not self.producer:
|
if not self.producer:
|
||||||
@ -430,6 +443,9 @@ class FakeTransport(object):
|
|||||||
self._reactor.callLater(0.0, _produce)
|
self._reactor.callLater(0.0, _produce)
|
||||||
|
|
||||||
def write(self, byt):
|
def write(self, byt):
|
||||||
|
if self.disconnecting:
|
||||||
|
raise Exception("Writing to disconnecting FakeTransport")
|
||||||
|
|
||||||
self.buffer = self.buffer + byt
|
self.buffer = self.buffer + byt
|
||||||
|
|
||||||
# always actually do the write asynchronously. Some protocols (notably the
|
# always actually do the write asynchronously. Some protocols (notably the
|
||||||
@ -474,6 +490,10 @@ class FakeTransport(object):
|
|||||||
if self.buffer and self.autoflush:
|
if self.buffer and self.autoflush:
|
||||||
self._reactor.callLater(0.0, self.flush)
|
self._reactor.callLater(0.0, self.flush)
|
||||||
|
|
||||||
|
if not self.buffer and self.disconnecting:
|
||||||
|
logger.info("FakeTransport: Buffer now empty, completing disconnect")
|
||||||
|
self.disconnected = True
|
||||||
|
|
||||||
|
|
||||||
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
|
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
|
||||||
"""
|
"""
|
||||||
|
@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_devices_by_remote(self):
|
def test_get_device_updates_by_remote(self):
|
||||||
device_ids = ["device_id1", "device_id2"]
|
device_ids = ["device_id1", "device_id2"]
|
||||||
|
|
||||||
# Add two device updates with a single stream_id
|
# Add two device updates with a single stream_id
|
||||||
@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get all device updates ever meant for this remote
|
# Get all device updates ever meant for this remote
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"somehost", -1, limit=100
|
"somehost", -1, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
self._check_devices_in_updates(device_ids, device_updates)
|
self._check_devices_in_updates(device_ids, device_updates)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_devices_by_remote_limited(self):
|
def test_get_device_updates_by_remote_limited(self):
|
||||||
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
||||||
|
|
||||||
# first add one device
|
# first add one device
|
||||||
@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
#
|
#
|
||||||
|
|
||||||
# first we should get a single update
|
# first we should get a single update
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"someotherhost", -1, limit=100
|
"someotherhost", -1, limit=100
|
||||||
)
|
)
|
||||||
self._check_devices_in_updates(device_ids1, device_updates)
|
self._check_devices_in_updates(device_ids1, device_updates)
|
||||||
|
|
||||||
# Then we should get an empty list back as the 101 devices broke the limit
|
# Then we should get an empty list back as the 101 devices broke the limit
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"someotherhost", now_stream_id, limit=100
|
"someotherhost", now_stream_id, limit=100
|
||||||
)
|
)
|
||||||
self.assertEqual(len(device_updates), 0)
|
self.assertEqual(len(device_updates), 0)
|
||||||
|
|
||||||
# The 101 devices should've been cleared, so we should now just get one device
|
# The 101 devices should've been cleared, so we should now just get one device
|
||||||
# update
|
# update
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"someotherhost", now_stream_id, limit=100
|
"someotherhost", now_stream_id, limit=100
|
||||||
)
|
)
|
||||||
self._check_devices_in_updates(device_ids3, device_updates)
|
self._check_devices_in_updates(device_ids3, device_updates)
|
||||||
@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
"""Check that an specific device ids exist in a list of device update EDUs"""
|
"""Check that an specific device ids exist in a list of device update EDUs"""
|
||||||
self.assertEqual(len(device_updates), len(expected_device_ids))
|
self.assertEqual(len(device_updates), len(expected_device_ids))
|
||||||
|
|
||||||
received_device_ids = {update["device_id"] for update in device_updates}
|
received_device_ids = {
|
||||||
|
update["device_id"] for edu_type, update in device_updates
|
||||||
|
}
|
||||||
self.assertEqual(received_device_ids, set(expected_device_ids))
|
self.assertEqual(received_device_ids, set(expected_device_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.handler = self.homeserver.get_handlers().federation_handler
|
self.handler = self.homeserver.get_handlers().federation_handler
|
||||||
self.handler.do_auth = lambda *a, **b: succeed(True)
|
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
|
||||||
|
context
|
||||||
|
)
|
||||||
self.client = self.homeserver.get_federation_client()
|
self.client = self.homeserver.get_federation_client()
|
||||||
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
||||||
pdus
|
pdus
|
||||||
|
4
tox.ini
4
tox.ini
@ -114,7 +114,7 @@ skip_install = True
|
|||||||
basepython = python3.6
|
basepython = python3.6
|
||||||
deps =
|
deps =
|
||||||
flake8
|
flake8
|
||||||
black==19.3b0 # We pin so that our tests don't start failing on new releases of black.
|
black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
|
||||||
commands =
|
commands =
|
||||||
python -m black --check --diff .
|
python -m black --check --diff .
|
||||||
/bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
|
/bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}"
|
||||||
@ -167,6 +167,6 @@ deps =
|
|||||||
env =
|
env =
|
||||||
MYPYPATH = stubs/
|
MYPYPATH = stubs/
|
||||||
extras = all
|
extras = all
|
||||||
commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \
|
commands = mypy \
|
||||||
synapse/logging/ \
|
synapse/logging/ \
|
||||||
synapse/config/
|
synapse/config/
|
||||||
|
Loading…
Reference in New Issue
Block a user