Merge branch 'develop' into send_sni_for_federation_requests

# Conflicts:
#	synapse/http/endpoint.py
This commit is contained in:
Jeroen 2018-07-09 08:51:11 +02:00
commit b5e157d895
88 changed files with 1672 additions and 204 deletions

View File

@ -4,7 +4,12 @@ language: python
# tell travis to cache ~/.cache/pip # tell travis to cache ~/.cache/pip
cache: pip cache: pip
before_script:
- git remote set-branches --add origin develop
- git fetch origin develop
matrix: matrix:
fast_finish: true
include: include:
- python: 2.7 - python: 2.7
env: TOX_ENV=packaging env: TOX_ENV=packaging
@ -18,6 +23,9 @@ matrix:
- python: 3.6 - python: 3.6
env: TOX_ENV=py36 env: TOX_ENV=py36
- python: 3.6
env: TOX_ENV=check-newsfragment
install: install:
- pip install tox - pip install tox

View File

@ -1,3 +1,72 @@
Synapse 0.32.2 (2018-07-07)
===========================
Bugfixes
--------
- Amend the Python dependencies to depend on attrs from PyPI, not attr (`#3492 <https://github.com/matrix-org/synapse/issues/3492>`_)
Synapse 0.32.1 (2018-07-06)
===========================
Bugfixes
--------
- Add explicit dependency on netaddr (`#3488 <https://github.com/matrix-org/synapse/issues/3488>`_)
Changes in synapse v0.32.0 (2018-07-06)
===========================================
No changes since 0.32.0rc1
Synapse 0.32.0rc1 (2018-07-05)
==============================
Features
--------
- Add blacklist & whitelist of servers allowed to send events to a room via ``m.room.server_acl`` event.
- Cache factor override system for specific caches (`#3334 <https://github.com/matrix-org/synapse/issues/3334>`_)
- Add metrics to track appservice transactions (`#3344 <https://github.com/matrix-org/synapse/issues/3344>`_)
- Try to log more helpful info when a sig verification fails (`#3372 <https://github.com/matrix-org/synapse/issues/3372>`_)
- Synapse now uses the best performing JSON encoder/decoder according to your runtime (simplejson on CPython, stdlib json on PyPy). (`#3462 <https://github.com/matrix-org/synapse/issues/3462>`_)
- Add optional ip_range_whitelist param to AS registration files to lock AS IP access (`#3465 <https://github.com/matrix-org/synapse/issues/3465>`_)
- Reject invalid server names in federation requests (`#3480 <https://github.com/matrix-org/synapse/issues/3480>`_)
- Reject invalid server names in homeserver.yaml (`#3483 <https://github.com/matrix-org/synapse/issues/3483>`_)
Bugfixes
--------
- Strip access_token from outgoing requests (`#3327 <https://github.com/matrix-org/synapse/issues/3327>`_)
- Redact AS tokens in logs (`#3349 <https://github.com/matrix-org/synapse/issues/3349>`_)
- Fix federation backfill from SQLite servers (`#3355 <https://github.com/matrix-org/synapse/issues/3355>`_)
- Fix event-purge-by-ts admin API (`#3363 <https://github.com/matrix-org/synapse/issues/3363>`_)
- Fix event filtering in get_missing_events handler (`#3371 <https://github.com/matrix-org/synapse/issues/3371>`_)
- Synapse is now stricter regarding accepting events which it cannot retrieve the prev_events for. (`#3456 <https://github.com/matrix-org/synapse/issues/3456>`_)
- Fix bug where synapse would explode when receiving unicode in HTTP User-Agent header (`#3470 <https://github.com/matrix-org/synapse/issues/3470>`_)
- Invalidate cache on correct thread to avoid race (`#3473 <https://github.com/matrix-org/synapse/issues/3473>`_)
Improved Documentation
----------------------
- ``doc/postgres.rst``: fix display of the last command block. Thanks to @ArchangeGabriel! (`#3340 <https://github.com/matrix-org/synapse/issues/3340>`_)
Deprecations and Removals
-------------------------
- Remove was_forgotten_at (`#3324 <https://github.com/matrix-org/synapse/issues/3324>`_)
Misc
----
- `#3332 <https://github.com/matrix-org/synapse/issues/3332>`_, `#3341 <https://github.com/matrix-org/synapse/issues/3341>`_, `#3347 <https://github.com/matrix-org/synapse/issues/3347>`_, `#3348 <https://github.com/matrix-org/synapse/issues/3348>`_, `#3356 <https://github.com/matrix-org/synapse/issues/3356>`_, `#3385 <https://github.com/matrix-org/synapse/issues/3385>`_, `#3446 <https://github.com/matrix-org/synapse/issues/3446>`_, `#3447 <https://github.com/matrix-org/synapse/issues/3447>`_, `#3467 <https://github.com/matrix-org/synapse/issues/3467>`_, `#3474 <https://github.com/matrix-org/synapse/issues/3474>`_
Changes in synapse v0.31.2 (2018-06-14) Changes in synapse v0.31.2 (2018-06-14)
======================================= =======================================

View File

@ -48,6 +48,26 @@ Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise. makes it horribly hard to review otherwise.
Changelog
~~~~~~~~~
All changes, even minor ones, need a corresponding changelog
entry. These are managed by Towncrier
(https://github.com/hawkowl/towncrier).
To create a changelog entry, make a new file in the ``changelog.d``
file named in the format of ``issuenumberOrPR.type``. The type can be
one of ``feature``, ``bugfix``, ``removal`` (also used for
deprecations), or ``misc`` (for internal-only changes). The content of
the file is your changelog entry, which can contain RestructuredText
formatting. A note of contributors is welcomed in changelogs for
non-misc changes (the content of misc changes is not displayed).
For example, a fix for a bug reported in #1234 would have its
changelog entry in ``changelog.d/1234.bugfix``, and contain content
like "The security levels of Florbs are now validated when
recieved over federation. Contributed by Jane Matrix".
Attribution Attribution
~~~~~~~~~~~ ~~~~~~~~~~~
@ -111,10 +131,14 @@ include the line in your commit or pull request comment::
Signed-off-by: Your Name <your@email.example.org> Signed-off-by: Your Name <your@email.example.org>
...using your real name; unfortunately pseudonyms and anonymous contributions We accept contributions under a legally identifiable name, such as
can't be accepted. Git makes this trivial - just use the -s flag when you do your name on government documentation or common-law names (names
``git commit``, having first set ``user.name`` and ``user.email`` git configs claimed by legitimate usage or repute). Unfortunately, we cannot
(which you should have done anyway :) accept anonymous contributions at this time.
Git allows you to add this signoff automatically when using the ``-s``
flag to ``git commit``, which uses the name and email set in your
``user.name`` and ``user.email`` git configs.
Conclusion Conclusion
~~~~~~~~~~ ~~~~~~~~~~

View File

@ -29,5 +29,8 @@ exclude Dockerfile
exclude .dockerignore exclude .dockerignore
recursive-exclude jenkins *.sh recursive-exclude jenkins *.sh
include pyproject.toml
recursive-include changelog.d *
prune .github prune .github
prune demo/etc prune demo/etc

1
changelog.d/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
!.gitignore

0
changelog.d/3463.misc Normal file
View File

View File

@ -44,13 +44,26 @@ Deactivate Account
This API deactivates an account. It removes active access tokens, resets the This API deactivates an account. It removes active access tokens, resets the
password, and deletes third-party IDs (to prevent the user requesting a password, and deletes third-party IDs (to prevent the user requesting a
password reset). password reset). It can also mark the user as GDPR-erased (stopping their data
from distributed further, and deleting it entirely if there are no other
references to it).
The api is:: The api is::
POST /_matrix/client/r0/admin/deactivate/<user_id> POST /_matrix/client/r0/admin/deactivate/<user_id>
including an ``access_token`` of a server admin, and an empty request body. with a body of:
.. code:: json
{
"erase": true
}
including an ``access_token`` of a server admin.
The erase parameter is optional and defaults to 'false'.
An empty body may be passed for backwards compatibility.
Reset password Reset password

5
pyproject.toml Normal file
View File

@ -0,0 +1,5 @@
[tool.towncrier]
package = "synapse"
filename = "CHANGES.rst"
directory = "changelog.d"
issue_format = "`#{issue} <https://github.com/matrix-org/synapse/issues/{issue}>`_"

View File

@ -19,3 +19,15 @@ max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. # W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
# E203 is contrary to PEP8. # E203 is contrary to PEP8.
ignore = W503,E203 ignore = W503,E203
[isort]
line_length = 89
not_skip = __init__.py
sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
default_section=THIRDPARTY
known_first_party = synapse
known_tests=tests
known_compat = mock,six
known_twisted=twisted,OpenSSL
multi_line_output=3
include_trailing_comma=true

View File

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.31.2" __version__ = "0.32.2"

View File

@ -19,6 +19,7 @@ from six import itervalues
import pymacaroons import pymacaroons
from twisted.internet import defer from twisted.internet import defer
from netaddr import IPAddress
import synapse.types import synapse.types
from synapse import event_auth from synapse import event_auth
@ -244,6 +245,11 @@ class Auth(object):
if app_service is None: if app_service is None:
defer.returnValue((None, None)) defer.returnValue((None, None))
if app_service.ip_range_whitelist:
ip_address = IPAddress(self.hs.get_ip_from_request(request))
if ip_address not in app_service.ip_range_whitelist:
defer.returnValue((None, None))
if "user_id" not in request.args: if "user_id" not in request.args:
defer.returnValue((app_service.sender, app_service)) defer.returnValue((app_service.sender, app_service))
@ -488,7 +494,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token): def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
logger.warn("Unrecognised access token - not in store: %s" % (token,)) logger.warn("Unrecognised access token - not in store.")
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
@ -511,7 +517,7 @@ class Auth(object):
) )
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,)) logger.warn("Unrecognised appservice access token.")
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.", "Unrecognised access token.",

View File

@ -76,6 +76,8 @@ class EventTypes(object):
Topic = "m.room.topic" Topic = "m.room.topic"
Name = "m.room.name" Name = "m.room.name"
ServerACL = "m.room.server_acl"
class RejectedReason(object): class RejectedReason(object):
AUTH_ERROR = "auth_error" AUTH_ERROR = "auth_error"

View File

@ -17,7 +17,8 @@
import logging import logging
import simplejson as json from canonicaljson import json
from six import iteritems from six import iteritems
from six.moves import http_client from six.moves import http_client

View File

@ -17,7 +17,8 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from twisted.internet import defer from twisted.internet import defer
import simplejson as json from canonicaljson import json
import jsonschema import jsonschema
from jsonschema import FormatChecker from jsonschema import FormatChecker

View File

@ -85,7 +85,8 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None, rate_limited=True): sender=None, id=None, protocols=None, rate_limited=True,
ip_range_whitelist=None):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
@ -93,6 +94,7 @@ class ApplicationService(object):
self.server_name = hostname self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
self.ip_range_whitelist = ip_range_whitelist
if "|" in self.id: if "|" in self.id:
raise Exception("application service ID cannot contain '|' character") raise Exception("application service ID cannot contain '|' character")

View File

@ -17,6 +17,8 @@ from ._base import Config, ConfigError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID from synapse.types import UserID
from netaddr import IPSet
import yaml import yaml
import logging import logging
@ -154,6 +156,13 @@ def _load_appservice(hostname, as_info, config_filename):
" will not receive events or queries.", " will not receive events or queries.",
config_filename, config_filename,
) )
ip_range_whitelist = None
if as_info.get('ip_range_whitelist'):
ip_range_whitelist = IPSet(
as_info.get('ip_range_whitelist')
)
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
hostname=hostname, hostname=hostname,
@ -163,5 +172,6 @@ def _load_appservice(hostname, as_info, config_filename):
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
protocols=protocols, protocols=protocols,
rate_limited=rate_limited rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist,
) )

View File

@ -16,6 +16,7 @@
import logging import logging
from synapse.http.endpoint import parse_and_validate_server_name
from ._base import Config, ConfigError from ._base import Config, ConfigError
logger = logging.Logger(__name__) logger = logging.Logger(__name__)
@ -25,6 +26,12 @@ class ServerConfig(Config):
def read_config(self, config): def read_config(self, config):
self.server_name = config["server_name"] self.server_name = config["server_name"]
try:
parse_and_validate_server_name(self.server_name)
except ValueError as e:
raise ConfigError(str(e))
self.pid_file = self.abspath(config.get("pid_file")) self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None) self.web_client_location = config.get("web_client_location", None)
@ -162,8 +169,8 @@ class ServerConfig(Config):
}) })
def default_config(self, server_name, **kwargs): def default_config(self, server_name, **kwargs):
if ":" in server_name: _, bind_port = parse_and_validate_server_name(server_name)
bind_port = int(server_name.split(":")[1]) if bind_port is not None:
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400
else: else:
bind_port = 8448 bind_port = 8448

View File

@ -18,7 +18,7 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
import simplejson as json from canonicaljson import json
import logging import logging

View File

@ -76,6 +76,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
return return
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id) room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain: if room_id_domain != sender_domain:
raise AuthError( raise AuthError(
@ -524,7 +525,11 @@ def _check_power_levels(event, auth_events):
"to your own" "to your own"
) )
if old_level > user_level or new_level > user_level: # Check if the old and new levels are greater than the user level
# (if defined)
old_level_too_big = old_level is not None and old_level > user_level
new_level_too_big = new_level is not None and new_level > user_level
if old_level_too_big or new_level_too_big:
raise AuthError( raise AuthError(
403, 403,
"You don't have permission to add ops level greater " "You don't have permission to add ops level greater "

View File

@ -14,10 +14,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import re
import simplejson as json from canonicaljson import json
import six
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import ( from synapse.federation.federation_base import (
@ -27,6 +31,7 @@ from synapse.federation.federation_base import (
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import async from synapse.util import async
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -74,6 +79,9 @@ class FederationServer(FederationBase):
@log_function @log_function
def on_backfill_request(self, origin, room_id, versions, limit): def on_backfill_request(self, origin, room_id, versions, limit):
with (yield self._server_linearizer.queue((origin, room_id))): with (yield self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
pdus = yield self.handler.on_backfill_request( pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit origin, room_id, versions, limit
) )
@ -134,6 +142,8 @@ class FederationServer(FederationBase):
received_pdus_counter.inc(len(transaction.pdus)) received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(transaction.origin)
pdus_by_room = {} pdus_by_room = {}
for p in transaction.pdus: for p in transaction.pdus:
@ -154,9 +164,21 @@ class FederationServer(FederationBase):
# we can process different rooms in parallel (which is useful if they # we can process different rooms in parallel (which is useful if they
# require callouts to other servers to fetch missing events), but # require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu. # impose a limit to avoid going too crazy with ram/cpu.
@defer.inlineCallbacks @defer.inlineCallbacks
def process_pdus_for_room(room_id): def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id) logger.debug("Processing PDUs for %s", room_id)
try:
yield self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
logger.warn(
"Ignoring PDUs for room %s from banned server", room_id,
)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
return
for pdu in pdus_by_room[room_id]: for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id event_id = pdu.event_id
try: try:
@ -211,6 +233,9 @@ class FederationServer(FederationBase):
if not event_id: if not event_id:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
in_room = yield self.auth.check_host_in_room(room_id, origin) in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
@ -234,6 +259,9 @@ class FederationServer(FederationBase):
if not event_id: if not event_id:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
in_room = yield self.auth.check_host_in_room(room_id, origin) in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
@ -298,7 +326,9 @@ class FederationServer(FederationBase):
defer.returnValue((200, resp)) defer.returnValue((200, resp))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id): def on_make_join_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_join_request(room_id, user_id) pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)}) defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@ -306,6 +336,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_invite_request(self, origin, content): def on_invite_request(self, origin, content):
pdu = event_from_pdu_json(content) pdu = event_from_pdu_json(content)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu) ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@ -314,6 +346,10 @@ class FederationServer(FederationBase):
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content) logger.debug("on_send_join_request: content: %s", content)
pdu = event_from_pdu_json(content) pdu = event_from_pdu_json(content)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu) res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
@ -325,7 +361,9 @@ class FederationServer(FederationBase):
})) }))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_leave_request(self, room_id, user_id): def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
pdu = yield self.handler.on_make_leave_request(room_id, user_id) pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)}) defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@ -334,6 +372,10 @@ class FederationServer(FederationBase):
def on_send_leave_request(self, origin, content): def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content) logger.debug("on_send_leave_request: content: %s", content)
pdu = event_from_pdu_json(content) pdu = event_from_pdu_json(content)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu) yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -341,6 +383,9 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id): def on_event_auth(self, origin, room_id, event_id):
with (yield self._server_linearizer.queue((origin, room_id))): with (yield self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id) auth_pdus = yield self.handler.on_event_auth(event_id)
res = { res = {
@ -369,6 +414,9 @@ class FederationServer(FederationBase):
Deferred: Results in `dict` with the same format as `content` Deferred: Results in `dict` with the same format as `content`
""" """
with (yield self._server_linearizer.queue((origin, room_id))): with (yield self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
auth_chain = [ auth_chain = [
event_from_pdu_json(e) event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
@ -442,6 +490,9 @@ class FederationServer(FederationBase):
def on_get_missing_events(self, origin, room_id, earliest_events, def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth): latest_events, limit, min_depth):
with (yield self._server_linearizer.queue((origin, room_id))): with (yield self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
logger.info( logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r," "on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d, min_depth: %d", " limit: %d, min_depth: %d",
@ -549,7 +600,9 @@ class FederationServer(FederationBase):
affected=pdu.event_id, affected=pdu.event_id,
) )
yield self.handler.on_receive_pdu(origin, pdu, get_missing=True) yield self.handler.on_receive_pdu(
origin, pdu, get_missing=True, sent_to_us_directly=True,
)
def __str__(self): def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
@ -577,6 +630,101 @@ class FederationServer(FederationBase):
) )
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def check_server_matches_acl(self, server_name, room_id):
"""Check if the given server is allowed by the server ACLs in the room
Args:
server_name (str): name of server, *without any port part*
room_id (str): ID of the room to check
Raises:
AuthError if the server does not match the ACL
"""
state_ids = yield self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
acl_event = yield self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")
def server_matches_acl_event(server_name, acl_event):
"""Check if the given server is allowed by the ACL event
Args:
server_name (str): name of server, without any port part
acl_event (EventBase): m.room.server_acl event
Returns:
bool: True if this server is allowed by the ACLs
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
# first of all, check if literal IPs are blocked, and if so, whether the
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warn("Ignorning non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
if server_name[0] == '[':
return False
# check for ipv4 literals. We can just lift the routine from twisted.
if isIPAddress(server_name):
return False
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warn("Ignorning non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched deny rule %s", server_name, e)
return False
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warn("Ignorning non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched allow rule %s", server_name, e)
return True
# everything else should be rejected.
# logger.info("%s fell through", server_name)
return False
def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
return False
regex = _glob_to_regex(acl_entry)
return regex.match(server_name)
def _glob_to_regex(glob):
res = ''
for c in glob:
if c == '*':
res = res + '.*'
elif c == '?':
res = res + '.'
else:
res = res + re.escape(c)
return re.compile(res + "\\Z", re.IGNORECASE)
class FederationHandlerRegistry(object): class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or """Allows classes to register themselves as handlers for a given EDU or

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@ -99,26 +100,6 @@ class Authenticator(object):
origin = None origin = None
def parse_auth_header(header_str):
try:
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except Exception:
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers: if not auth_headers:
@ -127,8 +108,8 @@ class Authenticator(object):
) )
for auth in auth_headers: for auth in auth_headers:
if auth.startswith("X-Matrix"): if auth.startswith(b"X-Matrix"):
(origin, key, sig) = parse_auth_header(auth) (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig json_request["signatures"].setdefault(origin, {})[key] = sig
@ -165,6 +146,48 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin) logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes):
"""Parse an X-Matrix auth header
Args:
header_bytes (bytes): header value
Returns:
Tuple[str, str, str]: origin, key id, signature.
Raises:
AuthenticationError if the header could not be parsed
"""
try:
header_str = header_bytes.decode('utf-8')
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith(b"\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
# ensure that the origin is a valid server name
parse_and_validate_server_name(origin)
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return origin, key, sig
except Exception as e:
logger.warn(
"Error parsing auth header '%s': %s",
header_bytes.decode('ascii', 'replace'),
e,
)
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED,
)
class BaseFederationServlet(object): class BaseFederationServlet(object):
REQUIRE_AUTH = True REQUIRE_AUTH = True
@ -362,7 +385,9 @@ class FederationMakeJoinServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_join_request(context, user_id) content = yield self.handler.on_make_join_request(
origin, context, user_id,
)
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -371,7 +396,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_leave_request(context, user_id) content = yield self.handler.on_make_leave_request(
origin, context, user_id,
)
defer.returnValue((200, content)) defer.returnValue((200, content))

View File

@ -16,6 +16,8 @@
from twisted.internet import defer, threads from twisted.internet import defer, threads
from canonicaljson import json
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
@ -32,7 +34,6 @@ from twisted.web.client import PartialDownloadError
import logging import logging
import bcrypt import bcrypt
import pymacaroons import pymacaroons
import simplejson
import attr import attr
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -403,7 +404,7 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted is silly # Twisted is silly
data = pde.response data = pde.response
resp_body = simplejson.loads(data) resp_body = json.loads(data)
if 'success' in resp_body: if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly # Note that we do NOT check the hostname here: we explicitly

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer, reactor from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -39,14 +39,15 @@ class DeactivateAccountHandler(BaseHandler):
# Start the user parter loop so it can resume parting users from rooms where # Start the user parter loop so it can resume parting users from rooms where
# it left off (if it has work left to do). # it left off (if it has work left to do).
reactor.callWhenRunning(self._start_user_parting) hs.get_reactor().callWhenRunning(self._start_user_parting)
@defer.inlineCallbacks @defer.inlineCallbacks
def deactivate_account(self, user_id): def deactivate_account(self, user_id, erase_data):
"""Deactivate a user's account """Deactivate a user's account
Args: Args:
user_id (str): ID of user to be deactivated user_id (str): ID of user to be deactivated
erase_data (bool): whether to GDPR-erase the user's data
Returns: Returns:
Deferred Deferred
@ -92,6 +93,11 @@ class DeactivateAccountHandler(BaseHandler):
# delete from user directory # delete from user directory
yield self.user_directory_handler.handle_user_deactivated(user_id) yield self.user_directory_handler.handle_user_deactivated(user_id)
# Mark the user as erased, if they asked for that
if erase_data:
logger.info("Marking %s as erased", user_id)
yield self.store.mark_user_erased(user_id)
# Now start the process that goes through that list and # Now start the process that goes through that list and
# parts users from rooms (if it isn't already running) # parts users from rooms (if it isn't already running)
self._start_user_parting() self._start_user_parting()

View File

@ -14,10 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import simplejson as json
import logging import logging
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json, json
from twisted.internet import defer from twisted.internet import defer
from six import iteritems from six import iteritems
@ -80,7 +79,7 @@ class E2eKeysHandler(object):
else: else:
remote_queries[user_id] = device_ids remote_queries[user_id] = device_ids
# Firt get local devices. # First get local devices.
failures = {} failures = {}
results = {} results = {}
if local_query: if local_query:
@ -357,7 +356,7 @@ def _exception_to_failure(e):
# include ConnectionRefused and other errors # include ConnectionRefused and other errors
# #
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which simplejson then fails to serialize. # give a string for e.message, which json then fails to serialize.
return { return {
"status": 503, "message": str(e.message), "status": 503, "message": str(e.message),
} }

View File

@ -44,6 +44,7 @@ from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures, compute_event_signature, add_hashes_and_signatures,
) )
from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
@ -89,7 +90,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_receive_pdu(self, origin, pdu, get_missing=True): def on_receive_pdu(
self, origin, pdu, get_missing=True, sent_to_us_directly=False,
):
""" Process a PDU received via a federation /send/ transaction, or """ Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events via backfill of missing prev_events
@ -163,14 +166,11 @@ class FederationHandler(BaseHandler):
"Ignoring PDU %s for room %s from %s as we've left the room!", "Ignoring PDU %s for room %s from %s as we've left the room!",
pdu.event_id, pdu.room_id, origin, pdu.event_id, pdu.room_id, origin,
) )
return defer.returnValue(None)
state = None state = None
auth_chain = [] auth_chain = []
fetch_state = False
# Get missing pdus if necessary. # Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier(): if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth. # We only backfill backwards to the min depth.
@ -225,26 +225,60 @@ class FederationHandler(BaseHandler):
list(prevs - seen)[:5], list(prevs - seen)[:5],
) )
if prevs - seen: if sent_to_us_directly and prevs - seen:
logger.info( # If they have sent it to us directly, and the server
"Still missing %d events for room %r: %r...", # isn't telling us about the auth events that it's
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] # made a message referencing, we explode
raise FederationError(
"ERROR",
403,
(
"Your server isn't divulging details about prev_events "
"referenced in this event."
),
affected=pdu.event_id,
) )
fetch_state = True elif prevs - seen:
# Calculate the state of the previous events, and
# de-conflict them to find the current state.
state_groups = []
auth_chains = set()
try:
# Get the state of the events we know about
ours = yield self.store.get_state_groups(pdu.room_id, list(seen))
state_groups.append(ours)
if fetch_state: # Ask the remote server for the states we don't
# We need to get the state at this event, since we haven't # know about
# processed all the prev events. for p in prevs - seen:
logger.debug( state, got_auth_chain = (
"_handle_new_pdu getting state for %s", yield self.replication_layer.get_state_for_room(
pdu.room_id origin, pdu.room_id, p
) )
try: )
state, auth_chain = yield self.replication_layer.get_state_for_room( auth_chains.update(got_auth_chain)
origin, pdu.room_id, pdu.event_id, state_group = {(x.type, x.state_key): x.event_id for x in state}
) state_groups.append(state_group)
except Exception:
logger.exception("Failed to get state for event: %s", pdu.event_id) # Resolve any conflicting state
def fetch(ev_ids):
return self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False
)
state_map = yield resolve_events_with_factory(
state_groups, {pdu.event_id: pdu}, fetch
)
state = (yield self.store.get_events(state_map.values())).values()
auth_chain = list(auth_chains)
except Exception:
raise FederationError(
"ERROR",
403,
"We can't get valid state history.",
affected=pdu.event_id,
)
yield self._process_received_pdu( yield self._process_received_pdu(
origin, origin,
@ -322,11 +356,17 @@ class FederationHandler(BaseHandler):
for e in missing_events: for e in missing_events:
logger.info("Handling found event %s", e.event_id) logger.info("Handling found event %s", e.event_id)
yield self.on_receive_pdu( try:
origin, yield self.on_receive_pdu(
e, origin,
get_missing=False e,
) get_missing=False
)
except FederationError as e:
if e.code == 403:
logger.warn("Event %s failed history check.")
else:
raise
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
@ -460,6 +500,47 @@ class FederationHandler(BaseHandler):
@measure_func("_filter_events_for_server") @measure_func("_filter_events_for_server")
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
"""Filter the given events for the given server, redacting those the
server can't see.
Assumes the server is currently in the room.
Returns
list[FrozenEvent]
"""
# First lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
)
)
visibility_ids = set()
for sids in event_to_state_ids.itervalues():
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
# If we failed to find any history visibility events then the default
# is "shared" visiblity.
if not visibility_ids:
defer.returnValue(events)
event_map = yield self.store.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.itervalues()
)
if all_open:
defer.returnValue(events)
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
event_to_state_ids = yield self.store.get_state_ids_for_events( event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=( types=(
@ -495,7 +576,20 @@ class FederationHandler(BaseHandler):
for e_id, key_to_eid in event_to_state_ids.iteritems() for e_id, key_to_eid in event_to_state_ids.iteritems()
} }
erased_senders = yield self.store.are_users_erased(
e.sender for e in events,
)
def redact_disallowed(event, state): def redact_disallowed(event, state):
# if the sender has been gdpr17ed, always return a redacted
# copy of the event.
if erased_senders[event.sender]:
logger.info(
"Sender of %s has been erased, redacting",
event.event_id,
)
return prune_event(event)
if not state: if not state:
return event return event

View File

@ -19,7 +19,7 @@
import logging import logging
import simplejson as json from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer

View File

@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson
import sys import sys
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json, json
import six import six
from six import string_types, itervalues, iteritems from six import string_types, itervalues, iteritems
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -157,7 +156,7 @@ class MessageHandler(BaseHandler):
# remove the purge from the list 24 hours after it completes # remove the purge from the list 24 hours after it completes
def clear_purge(): def clear_purge():
del self._purges_by_id[purge_id] del self._purges_by_id[purge_id]
reactor.callLater(24 * 3600, clear_purge) self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id): def get_purge_status(self, purge_id):
"""Get the current status of an active purge """Get the current status of an active purge
@ -491,7 +490,7 @@ class EventCreationHandler(object):
target, e target, e
) )
is_exempt = yield self._is_exempt_from_privacy_policy(builder) is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
if not is_exempt: if not is_exempt:
yield self.assert_accepted_privacy_policy(requester) yield self.assert_accepted_privacy_policy(requester)
@ -509,12 +508,13 @@ class EventCreationHandler(object):
defer.returnValue((event, context)) defer.returnValue((event, context))
def _is_exempt_from_privacy_policy(self, builder): def _is_exempt_from_privacy_policy(self, builder, requester):
""""Determine if an event to be sent is exempt from having to consent """"Determine if an event to be sent is exempt from having to consent
to the privacy policy to the privacy policy
Args: Args:
builder (synapse.events.builder.EventBuilder): event being created builder (synapse.events.builder.EventBuilder): event being created
requester (Requster): user requesting this event
Returns: Returns:
Deferred[bool]: true if the event can be sent without the user Deferred[bool]: true if the event can be sent without the user
@ -525,6 +525,9 @@ class EventCreationHandler(object):
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
if membership == Membership.JOIN: if membership == Membership.JOIN:
return self._is_server_notices_room(builder.room_id) return self._is_server_notices_room(builder.room_id)
elif membership == Membership.LEAVE:
# the user is always allowed to leave (but not kick people)
return builder.state_key == requester.user.to_string()
return succeed(False) return succeed(False)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -793,7 +796,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db # Ensure that we can round trip before trying to persist in db
try: try:
dump = frozendict_json_encoder.encode(event.content) dump = frozendict_json_encoder.encode(event.content)
simplejson.loads(dump) json.loads(dump)
except Exception: except Exception:
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise

View File

@ -22,7 +22,7 @@ The methods that define policy are:
- should_notify - should_notify
""" """
from twisted.internet import defer, reactor from twisted.internet import defer
from contextlib import contextmanager from contextlib import contextmanager
from six import itervalues, iteritems from six import itervalues, iteritems
@ -179,7 +179,7 @@ class PresenceHandler(object):
# have not yet been persisted # have not yet been persisted
self.unpersisted_users_changes = set() self.unpersisted_users_changes = set()
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) hs.get_reactor().addSystemEventTrigger("before", "shutdown", self._on_shutdown)
self.serial_to_user = {} self.serial_to_user = {}
self._next_serial = 1 self._next_serial = 1

View File

@ -145,7 +145,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device. "to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd "device_lists", # List of user_ids whose devices have changed
"device_one_time_keys_count", # Dict of algorithm to count for one time keys "device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device # for this device
"groups", "groups",

View File

@ -42,7 +42,7 @@ from twisted.web._newclient import ResponseDone
from six import StringIO from six import StringIO
from prometheus_client import Counter from prometheus_client import Counter
import simplejson as json from canonicaljson import json
import logging import logging
import urllib import urllib

View File

@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.names import client, dns from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError from twisted.names.error import DNSNameError, DomainError
@ -37,6 +39,71 @@ _Server = collections.namedtuple(
) )
def parse_server_name(server_name):
"""Split a server name into host/port parts.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
try:
if server_name[-1] == ']':
# ipv6 literal, hopefully
return server_name, None
domain_port = server_name.rsplit(":", 1)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
return domain, port
except Exception:
raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile(
"\\A[0-9a-zA-Z.-]+\\Z",
)
def parse_and_validate_server_name(server_name):
"""Split a server name into host/port parts and do some basic validation.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
host, port = parse_server_name(server_name)
# these tests don't need to be bulletproof as we'll find out soon enough
# if somebody is giving us invalid data. What we *do* need is to be sure
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
if host[0] == '[':
if host[-1] != ']':
raise ValueError("Mismatched [...] in server name '%s'" % (
server_name,
))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
raise ValueError("Server name '%s' contains invalid characters" % (
server_name,
))
return host, port
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None, def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None): timeout=None):
"""Construct an endpoint for the given matrix destination. """Construct an endpoint for the given matrix destination.
@ -50,9 +117,7 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
timeout (int): connection timeout in seconds timeout (int): connection timeout in seconds
""" """
domain_port = destination.split(":") domain, port = parse_server_name(destination)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
endpoint_kw_args = {} endpoint_kw_args = {}
@ -74,21 +139,22 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
reactor, "matrix", domain, protocol="tcp", reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint, default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args endpoint_kw_args=endpoint_kw_args
)) ), reactor)
else: else:
return _WrappingEndpointFac(transport_endpoint( return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args reactor, domain, port, **endpoint_kw_args
)) ), reactor)
class _WrappingEndpointFac(object): class _WrappingEndpointFac(object):
def __init__(self, endpoint_fac): def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac self.endpoint_fac = endpoint_fac
self.reactor = reactor
@defer.inlineCallbacks @defer.inlineCallbacks
def connect(self, protocolFactory): def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory) conn = yield self.endpoint_fac.connect(protocolFactory)
conn = _WrappedConnection(conn) conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn) defer.returnValue(conn)
@ -98,9 +164,10 @@ class _WrappedConnection(object):
""" """
__slots__ = ["conn", "last_request"] __slots__ = ["conn", "last_request"]
def __init__(self, conn): def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn) object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time()) object.__setattr__(self, "last_request", time.time())
self._reactor = reactor
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.conn, name) return getattr(self.conn, name)
@ -131,14 +198,14 @@ class _WrappedConnection(object):
# Time this connection out if we haven't send a request in the last # Time this connection out if we haven't send a request in the last
# N minutes # N minutes
# TODO: Cancel the previous callLater? # TODO: Cancel the previous callLater?
reactor.callLater(3 * 60, self._time_things_out_maybe) self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request) d = self.conn.request(request)
def update_request_time(res): def update_request_time(res):
self.last_request = time.time() self.last_request = time.time()
# TODO: Cancel the previous callLater? # TODO: Cancel the previous callLater?
reactor.callLater(3 * 60, self._time_things_out_maybe) self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res return res
d.addCallback(update_request_time) d.addCallback(update_request_time)

View File

@ -27,7 +27,7 @@ from synapse.util import logcontext
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
import synapse.util.retryutils import synapse.util.retryutils
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json, json
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, HttpResponseException, FederationDeniedError, SynapseError, Codes, HttpResponseException, FederationDeniedError,
@ -36,7 +36,6 @@ from synapse.api.errors import (
from signedjson.sign import sign_json from signedjson.sign import sign_json
import cgi import cgi
import simplejson as json
import logging import logging
import random import random
import sys import sys

View File

@ -29,7 +29,7 @@ import synapse.metrics
import synapse.events import synapse.events
from canonicaljson import ( from canonicaljson import (
encode_canonical_json, encode_pretty_printed_json encode_canonical_json, encode_pretty_printed_json, json
) )
from twisted.internet import defer from twisted.internet import defer
@ -41,7 +41,6 @@ from twisted.web.util import redirectTo
import collections import collections
import logging import logging
import urllib import urllib
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -410,7 +409,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
if canonical_json or synapse.events.USE_FROZEN_DICTS: if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
else: else:
json_bytes = simplejson.dumps(json_object) json_bytes = json.dumps(json_object)
return respond_with_json_bytes( return respond_with_json_bytes(
request, code, json_bytes, request, code, json_bytes,

View File

@ -18,7 +18,9 @@
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
import logging import logging
import simplejson
from canonicaljson import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -171,7 +173,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return None return None
try: try:
content = simplejson.loads(content_bytes) content = json.loads(content_bytes)
except Exception as e: except Exception as e:
logger.warn("Unable to parse JSON: %s", e) logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View File

@ -107,13 +107,28 @@ class SynapseRequest(Request):
end_time = time.time() end_time = time.time()
# need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity
if authenticated_entity is not None:
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about
# the charset problem.
# c.f. https://github.com/matrix-org/synapse/issues/3471
user_agent = self.get_user_agent()
if user_agent is not None:
user_agent = user_agent.decode("utf-8", "replace")
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
self.authenticated_entity, authenticated_entity,
end_time - self.start_time, end_time - self.start_time,
ru_utime, ru_utime,
ru_stime, ru_stime,
@ -125,7 +140,7 @@ class SynapseRequest(Request):
self.method, self.method,
self.get_redacted_uri(), self.get_redacted_uri(),
self.clientproto, self.clientproto,
self.get_user_agent(), user_agent,
evt_db_fetch_count, evt_db_fetch_count,
) )

View File

@ -147,7 +147,8 @@ class GCCounts(object):
yield cm yield cm
REGISTRY.register(GCCounts()) if not running_on_pypy:
REGISTRY.register(GCCounts())
# #
# Twisted reactor metrics # Twisted reactor metrics

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging import logging
@ -199,7 +199,7 @@ class EmailPusher(object):
self.timed_call = None self.timed_call = None
if soonest_due_at is not None: if soonest_due_at is not None:
self.timed_call = reactor.callLater( self.timed_call = self.hs.get_reactor().callLater(
self.seconds_until(soonest_due_at), self.on_timer self.seconds_until(soonest_due_at), self.on_timer
) )

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from . import push_rule_evaluator from . import push_rule_evaluator
@ -220,7 +220,9 @@ class HttpPusher(object):
) )
else: else:
logger.info("Push failed: delaying for %ds", self.backoff_delay) logger.info("Push failed: delaying for %ds", self.backoff_delay)
self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer) self.timed_call = self.hs.get_reactor().callLater(
self.backoff_delay, self.on_timer
)
self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC)
break break

View File

@ -57,16 +57,14 @@ REQUIREMENTS = {
"phonenumbers>=8.2.0": ["phonenumbers"], "phonenumbers>=8.2.0": ["phonenumbers"],
"six": ["six"], "six": ["six"],
"prometheus_client": ["prometheus_client"], "prometheus_client": ["prometheus_client"],
"attr": ["attr"], "attrs": ["attr"],
"netaddr>=0.7.18": ["netaddr"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
"matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"], "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],
}, },
"preview_url": {
"netaddr>=0.7.18": ["netaddr"],
},
"email.enable_notifs": { "email.enable_notifs": {
"Jinja2>=2.8": ["Jinja2>=2.8"], "Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"], "bleach>=1.4.2": ["bleach>=1.4.2"],

View File

@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateGroupWorkerStore from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamWorkerStore from synapse.storage.stream import StreamWorkerStore
from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.signatures import SignatureWorkerStore
from synapse.storage.user_erasure_store import UserErasureWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
@ -45,6 +46,7 @@ class SlavedEventStore(EventFederationWorkerStore,
EventsWorkerStore, EventsWorkerStore,
StateGroupWorkerStore, StateGroupWorkerStore,
SignatureWorkerStore, SignatureWorkerStore,
UserErasureWorkerStore,
BaseSlavedStore): BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):

View File

@ -15,7 +15,7 @@
"""A replication client for use by synapse workers. """A replication client for use by synapse workers.
""" """
from twisted.internet import reactor, defer from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory
from .commands import ( from .commands import (
@ -44,7 +44,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self._clock = hs.get_clock() # As self.clock is defined in super class self._clock = hs.get_clock() # As self.clock is defined in super class
reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector): def startedConnecting(self, connector):
logger.info("Connecting to replication: %r", connector.getDestination()) logger.info("Connecting to replication: %r", connector.getDestination())
@ -95,7 +95,7 @@ class ReplicationClientHandler(object):
factory = ReplicationClientFactory(hs, client_name, self) factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host host = hs.config.worker_replication_host
port = hs.config.worker_replication_port port = hs.config.worker_replication_port
reactor.connectTCP(host, port, factory) hs.get_reactor().connectTCP(host, port, factory)
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes """Called when we get new replication data. By default this just pokes

View File

@ -19,13 +19,17 @@ allowed to be sent by which side.
""" """
import logging import logging
import simplejson import platform
if platform.python_implementation() == "PyPy":
import json
_json_encoder = json.JSONEncoder()
else:
import simplejson as json
_json_encoder = json.JSONEncoder(namedtuple_as_object=False)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
class Command(object): class Command(object):
"""The base command class. """The base command class.
@ -102,7 +106,7 @@ class RdataCommand(Command):
return cls( return cls(
stream_name, stream_name,
None if token == "batch" else int(token), None if token == "batch" else int(token),
simplejson.loads(row_json) json.loads(row_json)
) )
def to_line(self): def to_line(self):
@ -300,7 +304,7 @@ class InvalidateCacheCommand(Command):
def from_line(cls, line): def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1) cache_func, keys_json = line.split(" ", 1)
return cls(cache_func, simplejson.loads(keys_json)) return cls(cache_func, json.loads(keys_json))
def to_line(self): def to_line(self):
return " ".join(( return " ".join((
@ -329,7 +333,7 @@ class UserIpCommand(Command):
def from_line(cls, line): def from_line(cls, line):
user_id, jsn = line.split(" ", 1) user_id, jsn = line.split(" ", 1)
access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn) access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
return cls( return cls(
user_id, access_token, ip, user_agent, device_id, last_seen user_id, access_token, ip, user_agent, device_id, last_seen

View File

@ -15,7 +15,7 @@
"""The server side of the replication stream. """The server side of the replication stream.
""" """
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from .streams import STREAMS_MAP, FederationStream from .streams import STREAMS_MAP, FederationStream
@ -109,7 +109,7 @@ class ReplicationStreamer(object):
self.is_looping = False self.is_looping = False
self.pending_updates = False self.pending_updates = False
reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
def on_shutdown(self): def on_shutdown(self):
# close all connections on shutdown # close all connections on shutdown

View File

@ -16,6 +16,8 @@
from twisted.internet import defer from twisted.internet import defer
from six.moves import http_client
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -247,6 +249,15 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, target_user_id): def on_POST(self, request, target_user_id):
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
raise SynapseError(
http_client.BAD_REQUEST,
"Param 'erase' must be a boolean, if given",
Codes.BAD_JSON,
)
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user) is_admin = yield self.auth.is_server_admin(requester.user)
@ -254,7 +265,9 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
yield self._deactivate_account_handler.deactivate_account(target_user_id) yield self._deactivate_account_handler.deactivate_account(
target_user_id, erase,
)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -23,7 +23,8 @@ from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json from canonicaljson import json
import urllib import urllib
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse

View File

@ -31,7 +31,7 @@ from synapse.http.servlet import (
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
import logging import logging
import simplejson as json from canonicaljson import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,6 +16,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from six.moves import http_client
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import has_access_token from synapse.api.auth import has_access_token
@ -186,13 +188,20 @@ class DeactivateAccountRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
raise SynapseError(
http_client.BAD_REQUEST,
"Param 'erase' must be a boolean, if given",
Codes.BAD_JSON,
)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users # allow ASes to dectivate their own users
if requester.app_service: if requester.app_service:
yield self._deactivate_account_handler.deactivate_account( yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string() requester.user.to_string(), erase,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -200,7 +209,7 @@ class DeactivateAccountRestServlet(RestServlet):
requester, body, self.hs.get_ip_from_request(request), requester, body, self.hs.get_ip_from_request(request),
) )
yield self._deactivate_account_handler.deactivate_account( yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), requester.user.to_string(), erase,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -33,7 +33,7 @@ from ._base import set_timeline_upper_limit
import itertools import itertools
import logging import logging
import simplejson as json from canonicaljson import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -22,8 +22,9 @@ from synapse.api.errors import (
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from twisted.web import server, resource from twisted.web import server, resource
from canonicaljson import json
import base64 import base64
import simplejson as json
import logging import logging
import os import os
import re import re

View File

@ -23,7 +23,8 @@ import re
import shutil import shutil
import sys import sys
import traceback import traceback
import simplejson as json
from canonicaljson import json
from six.moves import urllib_parse as urlparse from six.moves import urllib_parse as urlparse
from six import string_types from six import string_types

View File

@ -20,6 +20,7 @@ import time
import logging import logging
from synapse.storage.devices import DeviceStore from synapse.storage.devices import DeviceStore
from synapse.storage.user_erasure_store import UserErasureStore
from .appservice import ( from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore ApplicationServiceStore, ApplicationServiceTransactionStore
) )
@ -88,6 +89,7 @@ class DataStore(RoomMemberStore, RoomStore,
DeviceInboxStore, DeviceInboxStore,
UserDirectoryStore, UserDirectoryStore,
GroupServerStore, GroupServerStore,
UserErasureStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):

View File

@ -22,8 +22,9 @@ from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from canonicaljson import json
import abc import abc
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -15,8 +15,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from canonicaljson import json
from synapse.appservice import AppServiceTransaction from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices

View File

@ -18,7 +18,8 @@ from . import engines
from twisted.internet import defer from twisted.internet import defer
import simplejson as json from canonicaljson import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson
from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
@ -85,7 +86,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
) )
rows = [] rows = []
for destination, edu in remote_messages_by_destination.items(): for destination, edu in remote_messages_by_destination.items():
edu_json = simplejson.dumps(edu) edu_json = json.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json)) rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows) txn.executemany(sql, rows)
@ -177,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" WHERE user_id = ?" " WHERE user_id = ?"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
message_json = simplejson.dumps(messages_by_device["*"]) message_json = json.dumps(messages_by_device["*"])
for row in txn: for row in txn:
# Add the message for all devices for this user on this # Add the message for all devices for this user on this
# server. # server.
@ -199,7 +200,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server
device = row[0] device = row[0]
message_json = simplejson.dumps(messages_by_device[device]) message_json = json.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json messages_json_for_user[device] = message_json
if messages_json_for_user: if messages_json_for_user:
@ -253,7 +254,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = [] messages = []
for row in txn: for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(simplejson.loads(row[1])) messages.append(json.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:
stream_pos = current_stream_id stream_pos = current_stream_id
return (messages, stream_pos) return (messages, stream_pos)
@ -389,7 +390,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = [] messages = []
for row in txn: for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(simplejson.loads(row[1])) messages.append(json.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:
stream_pos = current_stream_id stream_pos = current_stream_id
return (messages, stream_pos) return (messages, stream_pos)

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson as json
from twisted.internet import defer from twisted.internet import defer
@ -21,6 +20,8 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Cache from ._base import SQLBaseStore, Cache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
from canonicaljson import json
from six import itervalues, iteritems from six import itervalues, iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -16,8 +16,7 @@ from twisted.internet import defer
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json, json
import simplejson as json
from ._base import SQLBaseStore from ._base import SQLBaseStore

View File

@ -19,7 +19,8 @@ from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
import logging import logging
import simplejson as json
from canonicaljson import json
from six import iteritems from six import iteritems

View File

@ -19,7 +19,8 @@ from functools import wraps
import itertools import itertools
import logging import logging
import simplejson as json from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
@ -800,7 +801,8 @@ class EventsStore(EventsWorkerStore):
] ]
) )
self._curr_state_delta_stream_cache.entity_has_changed( txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id, max_stream_order, room_id, max_stream_order,
) )

View File

@ -29,7 +29,8 @@ from synapse.api.errors import SynapseError
from collections import namedtuple from collections import namedtuple
import logging import logging
import simplejson as json
from canonicaljson import json
# these are only included to make the type annotations work # these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401

View File

@ -19,8 +19,7 @@ from ._base import SQLBaseStore
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json, json
import simplejson as json
class FilteringStore(SQLBaseStore): class FilteringStore(SQLBaseStore):

View File

@ -20,7 +20,7 @@ from synapse.api.errors import SynapseError
from ._base import SQLBaseStore from ._base import SQLBaseStore
import simplejson as json from canonicaljson import json
# The category ID for the "default" category. We don't store as null in the # The category ID for the "default" category. We don't store as null in the

View File

@ -25,9 +25,10 @@ from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from twisted.internet import defer from twisted.internet import defer
from canonicaljson import json
import abc import abc
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -17,12 +17,11 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json, json
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
import logging import logging
import simplejson as json
import types import types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -21,9 +21,10 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
from canonicaljson import json
import abc import abc
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -623,7 +623,9 @@ class RegistrationStore(RegistrationWorkerStore,
Removes the given user to the table of users who need to be parted from all the Removes the given user to the table of users who need to be parted from all the
rooms they're in, effectively marking that user as fully deactivated. rooms they're in, effectively marking that user as fully deactivated.
""" """
return self._simple_delete_one( # XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
return self._simple_delete(
"users_pending_deactivation", "users_pending_deactivation",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,

View File

@ -20,9 +20,10 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from canonicaljson import json
import collections import collections
import logging import logging
import simplejson as json
import re import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -28,7 +28,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
import logging import logging
import simplejson as json from canonicaljson import json
from six import itervalues, iteritems from six import itervalues, iteritems

View File

@ -0,0 +1,21 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- a table of users who have requested that their details be erased
CREATE TABLE erased_users (
user_id TEXT NOT NULL
);
CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id);

View File

@ -16,7 +16,7 @@
from collections import namedtuple from collections import namedtuple
import logging import logging
import re import re
import simplejson as json from canonicaljson import json
from six import string_types from six import string_types

View File

@ -19,7 +19,8 @@ from synapse.storage.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from twisted.internet import defer from twisted.internet import defer
import simplejson as json from canonicaljson import json
import logging import logging
from six.moves import range from six.moves import range

View File

@ -19,12 +19,11 @@ from synapse.util.caches.descriptors import cached
from twisted.internet import defer from twisted.internet import defer
import six import six
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json, json
from collections import namedtuple from collections import namedtuple
import logging import logging
import simplejson as json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview

View File

@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedList, cached
class UserErasureWorkerStore(SQLBaseStore):
@cached()
def is_user_erased(self, user_id):
"""
Check if the given user id has requested erasure
Args:
user_id (str): full user id to check
Returns:
Deferred[bool]: True if the user has requested erasure
"""
return self._simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
desc="is_user_erased",
).addCallback(operator.truth)
@cachedList(
cached_method_name="is_user_erased",
list_name="user_ids",
inlineCallbacks=True,
)
def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
Args:
user_ids (iterable[str]): full user id to check
Returns:
Deferred[dict[str, bool]]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
def _get_erased_users(txn):
txn.execute(
"SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
",".join("?" * len(user_ids))
),
user_ids,
)
return set(r[0] for r in txn)
erased_users = yield self.runInteraction(
"are_users_erased", _get_erased_users,
)
res = dict((u, u in erased_users) for u in user_ids)
defer.returnValue(res)
class UserErasureStore(UserErasureWorkerStore):
def mark_user_erased(self, user_id):
"""Indicate that user_id wishes their message history to be erased.
Args:
user_id (str): full user_id to be erased
"""
def f(txn):
# first check if they are already in the list
txn.execute(
"SELECT 1 FROM erased_users WHERE user_id = ?",
(user_id, )
)
if txn.fetchone():
return
# they are not already there: do the insert.
txn.execute(
"INSERT INTO erased_users (user_id) VALUES (?)",
(user_id, )
)
self._invalidate_cache_and_stream(
txn, self.is_user_erased, (user_id,)
)
return self.runInteraction("mark_user_erased", f)

View File

@ -34,6 +34,9 @@ def unwrapFirstError(failure):
class Clock(object): class Clock(object):
""" """
A Clock wraps a Twisted reactor and provides utilities on top of it. A Clock wraps a Twisted reactor and provides utilities on top of it.
Args:
reactor: The Twisted reactor to use.
""" """
_reactor = attr.ib() _reactor = attr.ib()

View File

@ -78,7 +78,8 @@ class StreamChangeCache(object):
not_known_entities = set(entities) - set(self._entity_to_key) not_known_entities = set(entities) - set(self._entity_to_key)
result = ( result = (
set(self._cache.values()[self._cache.bisect_right(stream_pos) :]) {self._cache[k] for k in self._cache.islice(
start=self._cache.bisect_right(stream_pos))}
.intersection(entities) .intersection(entities)
.union(not_known_entities) .union(not_known_entities)
) )
@ -113,7 +114,8 @@ class StreamChangeCache(object):
assert type(stream_pos) is int assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos: if stream_pos >= self._earliest_known_stream_pos:
return self._cache.values()[self._cache.bisect_right(stream_pos) :] return [self._cache[k] for k in self._cache.islice(
start=self._cache.bisect_right(stream_pos))]
else: else:
return None return None

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from frozendict import frozendict from frozendict import frozendict
import simplejson as json from canonicaljson import json
from six import string_types from six import string_types

View File

@ -12,15 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import logging import logging
import operator
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import (
make_deferred_yieldable, preserve_fn,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -95,16 +97,27 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
if ignore_dict_content else [] if ignore_dict_content else []
) )
erased_senders = yield store.are_users_erased((e.sender for e in events))
def allowed(event): def allowed(event):
""" """
Args: Args:
event (synapse.events.EventBase): event to check event (synapse.events.EventBase): event to check
Returns:
None|EventBase:
None if the user cannot see this event at all
a redacted copy of the event if they can only see a redacted
version
the original event if they can see it as normal.
""" """
if not event.is_state() and event.sender in ignore_list: if not event.is_state() and event.sender in ignore_list:
return False return None
if event.event_id in always_include_ids: if event.event_id in always_include_ids:
return True return event
state = event_id_to_state[event.event_id] state = event_id_to_state[event.event_id]
@ -118,10 +131,6 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
if visibility not in VISIBILITY_PRIORITY: if visibility not in VISIBILITY_PRIORITY:
visibility = "shared" visibility = "shared"
# if it was world_readable, it's easy: everyone can read it
if visibility == "world_readable":
return True
# Always allow history visibility events on boundaries. This is done # Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive # by setting the effective visibility to the least restrictive
# of the old vs new. # of the old vs new.
@ -155,7 +164,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
if membership == "leave" and ( if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite" prev_membership == "join" or prev_membership == "invite"
): ):
return True return event
new_priority = MEMBERSHIP_PRIORITY.index(membership) new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
@ -166,31 +175,55 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
if membership is None: if membership is None:
membership_event = state.get((EventTypes.Member, user_id), None) membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event: if membership_event:
# XXX why do we do this?
# https://github.com/matrix-org/synapse/issues/3350
if membership_event.event_id not in event_id_forgotten: if membership_event.event_id not in event_id_forgotten:
membership = membership_event.membership membership = membership_event.membership
# if the user was a member of the room at the time of the event, # if the user was a member of the room at the time of the event,
# they can see it. # they can see it.
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return event
# otherwise, it depends on the room visibility.
if visibility == "joined": if visibility == "joined":
# we weren't a member at the time of the event, so we can't # we weren't a member at the time of the event, so we can't
# see this event. # see this event.
return False return None
elif visibility == "invited": elif visibility == "invited":
# user can also see the event if they were *invited* at the time # user can also see the event if they were *invited* at the time
# of the event. # of the event.
return membership == Membership.INVITE return (
event if membership == Membership.INVITE else None
)
else: elif visibility == "shared" and is_peeking:
# visibility is shared: user can also see the event if they have # if the visibility is shared, users cannot see the event unless
# become a member since the event # they have *subequently* joined the room (or were members at the
# time, of course)
# #
# XXX: if the user has subsequently joined and then left again, # XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But # ideally we would share history up to the point they left. But
# we don't know when they left. # we don't know when they left. We just treat it as though they
return not is_peeking # never joined, and restrict access.
return None
defer.returnValue(list(filter(allowed, events))) # the visibility is either shared or world_readable, and the user was
# not a member at the time. We allow it, provided the original sender
# has not requested their data to be erased, in which case, we return
# a redacted version.
if erased_senders[event.sender]:
return prune_event(event)
return event
# check each event: gives an iterable[None|EventBase]
filtered_events = itertools.imap(allowed, events)
# remove the None entries
filtered_events = filter(operator.truth, filtered_events)
# we turn it into a list before returning it.
defer.returnValue(list(filtered_events))

View File

@ -86,16 +86,53 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self): def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(
token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None,
)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_good_ip(self):
from netaddr import IPSet
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
from netaddr import IPSet
app_service = Mock(
token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
@ -119,12 +156,16 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = "@doppelganger:matrix.org"
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(
token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None,
)
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -133,12 +174,16 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = "@doppelganger:matrix.org"
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(
token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None,
)
app_service.is_interested_in_user = Mock(return_value=False) app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()

View File

View File

@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.events import FrozenEvent
from synapse.federation.federation_server import server_matches_acl_event
from tests import unittest
@unittest.DEBUG
class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self):
e = _create_acl_event({
"allow": ["*"],
"deny": ["evil.com"],
})
logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("evil.com", e))
self.assertFalse(server_matches_acl_event("EVIL.COM", e))
self.assertTrue(server_matches_acl_event("evil.com.au", e))
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
def test_block_ip_literals(self):
e = _create_acl_event({
"allow_ip_literals": False,
"allow": ["*"],
})
logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("1.2.3.4", e))
self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
self.assertFalse(server_matches_acl_event("[1:2::]", e))
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
def _create_acl_event(content):
return FrozenEvent({
"room_id": "!a:b",
"event_id": "$a:b",
"type": "m.room.server_acls",
"sender": "@a:b",
"content": content
})

0
tests/http/__init__.py Normal file
View File

View File

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.endpoint import (
parse_server_name,
parse_and_validate_server_name,
)
from tests import unittest
class ServerNameTestCase(unittest.TestCase):
def test_parse_server_name(self):
test_data = {
'localhost': ('localhost', None),
'my-example.com:1234': ('my-example.com', 1234),
'1.2.3.4': ('1.2.3.4', None),
'[0abc:1def::1234]': ('[0abc:1def::1234]', None),
'1.2.3.4:1': ('1.2.3.4', 1),
'[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
}
for i, o in test_data.items():
self.assertEqual(parse_server_name(i), o)
def test_validate_bad_server_names(self):
test_data = [
"", # empty
"localhost:http", # non-numeric port
"1234]", # smells like ipv6 literal but isn't
"[1234",
"underscore_.com",
"percent%65.com",
"1234:5678:80", # too many colons
]
for i in test_data:
try:
parse_and_validate_server_name(i)
self.fail(
"Expected parse_and_validate_server_name('%s') to throw" % (
i,
),
)
except ValueError:
pass

181
tests/server.py Normal file
View File

@ -0,0 +1,181 @@
from io import BytesIO
import attr
import json
from six import text_type
from twisted.python.failure import Failure
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactorClock
from synapse.http.site import SynapseRequest
from twisted.internet import threads
from tests.utils import setup_test_homeserver as _sth
@attr.s
class FakeChannel(object):
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
"""
result = attr.ib(factory=dict)
@property
def json_body(self):
if not self.result:
raise Exception("No result yet.")
return json.loads(self.result["body"])
def writeHeaders(self, version, code, reason, headers):
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
self.result["headers"] = headers
def write(self, content):
if "body" not in self.result:
self.result["body"] = b""
self.result["body"] += content
def requestDone(self, _self):
self.result["done"] = True
def getPeer(self):
return None
def getHost(self):
return None
@property
def transport(self):
return self
class FakeSite:
"""
A fake Twisted Web Site, with mocks of the extra things that
Synapse adds.
"""
server_version_string = b"1"
site_tag = "test"
@property
def access_logger(self):
class FakeLogger:
def info(self, *args, **kwargs):
pass
return FakeLogger()
def make_request(method, path, content=b""):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
"""
if isinstance(content, text_type):
content = content.encode('utf8')
site = FakeSite()
channel = FakeChannel()
req = SynapseRequest(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
req.requestReceived(method, path, b"1.1")
return req, channel
def wait_until_result(clock, channel, timeout=100):
"""
Wait until the channel has a result.
"""
clock.run()
x = 0
while not channel.result:
x += 1
if x > timeout:
raise Exception("Timed out waiting for request to finish.")
clock.advance(0.1)
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
"""
def callFromThread(self, callback, *args, **kwargs):
"""
Make the callback fire in the next reactor iteration.
"""
d = Deferred()
d.addCallback(lambda x: callback(*args, **kwargs))
self.callLater(0, d.callback, True)
return d
def setup_test_homeserver(*args, **kwargs):
"""
Set up a synchronous test server, driven by the reactor used by
the homeserver.
"""
d = _sth(*args, **kwargs).result
# Make the thread pool synchronous.
clock = d.get_clock()
pool = d.get_db_pool()
def runWithConnection(func, *args, **kwargs):
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runWithConnection,
func,
*args,
**kwargs
)
def runInteraction(interaction, *args, **kwargs):
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runInteraction,
interaction,
*args,
**kwargs
)
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
class ThreadPool:
"""
Threadless thread pool.
"""
def start(self):
pass
def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
def _(res):
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)
d = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
clock._reactor.callLater(0, d.callback, True)
return d
clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
return d

243
tests/test_federation.py Normal file
View File

@ -0,0 +1,243 @@
from twisted.internet.defer import succeed, maybeDeferred
from synapse.util import Clock
from synapse.events import FrozenEvent
from synapse.types import Requester, UserID
from tests import unittest
from tests.server import setup_test_homeserver, ThreadedMemoryReactorClock
from mock import Mock
class MessageAcceptTests(unittest.TestCase):
def setUp(self):
self.http_client = Mock()
self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor
)
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room = room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
)
)[0]
join_event = FrozenEvent(
{
"room_id": self.room_id,
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
"event_id": "$join:test.serv",
"depth": 1000,
"origin_server_ts": 1,
"type": "m.room.member",
"origin": "test.servx",
"content": {"membership": "join"},
"auth_events": [],
"prev_state": [(most_recent, {})],
"prev_events": [(most_recent, {})],
}
)
self.handler = self.homeserver.get_handlers().federation_handler
self.handler.do_auth = lambda *a, **b: succeed(True)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
)
# Send the join, it should return None (which is not an error)
d = self.handler.on_receive_pdu(
"test.serv", join_event, sent_to_us_directly=True
)
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
# Make sure we actually joined the room
self.assertEqual(
self.successResultOf(
maybeDeferred(
self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
)
)[0],
"$join:test.serv",
)
def test_cant_hide_direct_ancestors(self):
"""
If you send a message, you must be able to provide the direct
prev_events that said event references.
"""
def post_json(destination, path, data, headers=None, timeout=0):
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
self.http_client.post_json = post_json
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
)
)[0]
# Now lie about an event
lying_event = FrozenEvent(
{
"room_id": self.room_id,
"sender": "@baduser:test.serv",
"event_id": "one:test.serv",
"depth": 1000,
"origin_server_ts": 1,
"type": "m.room.message",
"origin": "test.serv",
"content": "hewwo?",
"auth_events": [],
"prev_events": [("two:test.serv", {}), (most_recent, {})],
}
)
d = self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
)
# Step the reactor, so the database fetches come back
self.reactor.advance(1)
# on_receive_pdu should throw an error
failure = self.failureResultOf(d)
self.assertEqual(
failure.value.args[0],
(
"ERROR 403: Your server isn't divulging details about prev_events "
"referenced in this event."
),
)
# Make sure the invalid event isn't there
extrem = maybeDeferred(
self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
@unittest.DEBUG
def test_cant_hide_past_history(self):
"""
If you send a message, you must be able to provide the direct
prev_events that said event references.
"""
def post_json(destination, path, data, headers=None, timeout=0):
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {
"events": [
{
"room_id": self.room_id,
"sender": "@baduser:test.serv",
"event_id": "three:test.serv",
"depth": 1000,
"origin_server_ts": 1,
"type": "m.room.message",
"origin": "test.serv",
"content": "hewwo?",
"auth_events": [],
"prev_events": [("four:test.serv", {})],
}
]
}
self.http_client.post_json = post_json
def get_json(destination, path, args, headers=None):
if path.startswith("/_matrix/federation/v1/state_ids/"):
d = self.successResultOf(
self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
)
return succeed(
{
"pdu_ids": [
y
for x, y in d.items()
if x == ("m.room.member", "@us:test")
],
"auth_chain_ids": d.values(),
}
)
self.http_client.get_json = get_json
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
)
)[0]
# Make a good event
good_event = FrozenEvent(
{
"room_id": self.room_id,
"sender": "@baduser:test.serv",
"event_id": "one:test.serv",
"depth": 1000,
"origin_server_ts": 1,
"type": "m.room.message",
"origin": "test.serv",
"content": "hewwo?",
"auth_events": [],
"prev_events": [(most_recent, {})],
}
)
d = self.handler.on_receive_pdu(
"test.serv", good_event, sent_to_us_directly=True
)
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
bad_event = FrozenEvent(
{
"room_id": self.room_id,
"sender": "@baduser:test.serv",
"event_id": "two:test.serv",
"depth": 1000,
"origin_server_ts": 1,
"type": "m.room.message",
"origin": "test.serv",
"content": "hewwo?",
"auth_events": [],
"prev_events": [("one:test.serv", {}), ("three:test.serv", {})],
}
)
d = self.handler.on_receive_pdu(
"test.serv", bad_event, sent_to_us_directly=True
)
self.reactor.advance(1)
extrem = maybeDeferred(
self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
)
self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
self.reactor.advance(1)
self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys())

128
tests/test_server.py Normal file
View File

@ -0,0 +1,128 @@
import json
import re
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactorClock
from synapse.util import Clock
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from tests import unittest
from tests.server import make_request, setup_test_homeserver
class JsonResourceTests(unittest.TestCase):
def setUp(self):
self.reactor = MemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.reactor
)
def test_handler_for_request(self):
"""
JsonResource.handler_for_request gives correctly decoded URL args to
the callback, while Twisted will give the raw bytes of URL query
arguments.
"""
got_kwargs = {}
def _callback(request, **kwargs):
got_kwargs.update(kwargs)
return (200, kwargs)
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback)
request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83")
request.render(res)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
def test_callback_direct_exception(self):
"""
If the web callback raises an uncaught exception, it will be translated
into a 500.
"""
def _callback(request, **kwargs):
raise Exception("boo")
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo")
request.render(res)
self.assertEqual(channel.result["code"], b'500')
def test_callback_indirect_exception(self):
"""
If the web callback raises an uncaught exception in a Deferred, it will
be translated into a 500.
"""
def _throw(*args):
raise Exception("boo")
def _callback(request, **kwargs):
d = Deferred()
d.addCallback(_throw)
self.reactor.callLater(1, d.callback, True)
return d
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo")
request.render(res)
# No error has been raised yet
self.assertTrue("code" not in channel.result)
# Advance time, now there's an error
self.reactor.advance(1)
self.assertEqual(channel.result["code"], b'500')
def test_callback_synapseerror(self):
"""
If the web callback raises a SynapseError, it returns the appropriate
status code and message set in it.
"""
def _callback(request, **kwargs):
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo")
request.render(res)
self.assertEqual(channel.result["code"], b'403')
reply_body = json.loads(channel.result["body"])
self.assertEqual(reply_body["error"], "Forbidden!!one!")
self.assertEqual(reply_body["errcode"], "M_FORBIDDEN")
def test_no_handler(self):
"""
If there is no handler to process the request, Synapse will return 400.
"""
def _callback(request, **kwargs):
"""
Not ever actually called!
"""
self.fail("shouldn't ever get here")
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)
request, channel = make_request(b"GET", b"/foobar")
request.render(res)
self.assertEqual(channel.result["code"], b'400')
reply_body = json.loads(channel.result["body"])
self.assertEqual(reply_body["error"], "Unrecognized request")
self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED")

View File

@ -35,7 +35,10 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record): def emit(self, record):
log_entry = self.format(record) log_entry = self.format(record)
log_level = record.levelname.lower().replace('warning', 'warn') log_level = record.levelname.lower().replace('warning', 'warn')
self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry) self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level),
log_entry.replace("{", r"(").replace("}", r")"),
)
handler = ToTwistedHandler() handler = ToTwistedHandler()

View File

@ -19,13 +19,19 @@ import logging
import mock import mock
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util import logcontext from synapse.util import logcontext
from twisted.internet import defer from twisted.internet import defer, reactor
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
from tests import unittest from tests import unittest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
return logcontext.make_deferred_yieldable(d)
class CacheTestCase(unittest.TestCase): class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self): def test_invalidate_all(self):
cache = descriptors.Cache("testcache") cache = descriptors.Cache("testcache")
@ -194,6 +200,8 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1): def fn(self, arg1):
@defer.inlineCallbacks @defer.inlineCallbacks
def inner_fn(): def inner_fn():
# we want this to behave like an asynchronous function
yield run_on_reactor()
raise SynapseError(400, "blah") raise SynapseError(400, "blah")
return inner_fn() return inner_fn()
@ -203,7 +211,12 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1: with logcontext.LoggingContext() as c1:
c1.name = "c1" c1.name = "c1"
try: try:
yield obj.fn(1) d = obj.fn(1)
self.assertEqual(
logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
yield d
self.fail("No exception thrown") self.fail("No exception thrown")
except SynapseError: except SynapseError:
pass pass

View File

@ -102,3 +102,11 @@ basepython = python2.7
deps = deps =
flake8 flake8
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}" commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}"
[testenv:check-newsfragment]
skip_install = True
deps = towncrier>=18.6.0rc1
commands =
python -m towncrier.check --compare-with=origin/develop
basepython = python3.6