Merge branch 'develop' into rav/saml2_client

This commit is contained in:
Richard van der Hoff 2019-06-26 22:34:41 +01:00
commit a4daa899ec
478 changed files with 18927 additions and 11500 deletions

View file

@ -17,14 +17,22 @@
""" This is a reference implementation of a Matrix home server.
"""
import sys
# Check that we're not running on an unsupported Python version.
if sys.version_info < (3, 5):
print("Synapse requires Python 3.5 or above.")
sys.exit(1)
try:
from twisted.internet import protocol
from twisted.internet.protocol import Factory
from twisted.names.dns import DNSDatagramProtocol
protocol.Factory.noisy = False
Factory.noisy = False
DNSDatagramProtocol.noisy = False
except ImportError:
pass
__version__ = "1.0.0rc3"
__version__ = "1.0.0"

View file

@ -57,18 +57,18 @@ def request_registration(
nonce = r.json()["nonce"]
mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1)
mac.update(nonce.encode('utf8'))
mac.update(nonce.encode("utf8"))
mac.update(b"\x00")
mac.update(user.encode('utf8'))
mac.update(user.encode("utf8"))
mac.update(b"\x00")
mac.update(password.encode('utf8'))
mac.update(password.encode("utf8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
if user_type:
mac.update(b"\x00")
mac.update(user_type.encode('utf8'))
mac.update(user_type.encode("utf8"))
mac = mac.hexdigest()
@ -134,8 +134,9 @@ def register_new_user(user, password, server_location, shared_secret, admin, use
else:
admin = False
request_registration(user, password, server_location, shared_secret,
bool(admin), user_type)
request_registration(
user, password, server_location, shared_secret, bool(admin), user_type
)
def main():
@ -189,7 +190,7 @@ def main():
group.add_argument(
"-c",
"--config",
type=argparse.FileType('r'),
type=argparse.FileType("r"),
help="Path to server config file. Used to read in shared secret.",
)
@ -200,7 +201,7 @@ def main():
parser.add_argument(
"server_url",
default="https://localhost:8448",
nargs='?',
nargs="?",
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
)
@ -220,8 +221,9 @@ def main():
if args.admin or args.no_admin:
admin = args.admin
register_new_user(args.user, args.password, args.server_url, secret,
admin, args.user_type)
register_new_user(
args.user, args.password, args.server_url, secret, admin, args.user_type
)
if __name__ == "__main__":

View file

@ -36,8 +36,11 @@ logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
EventTypes.Create,
EventTypes.Member,
EventTypes.PowerLevels,
EventTypes.JoinRules,
EventTypes.RoomHistoryVisibility,
EventTypes.ThirdPartyInvite,
)
@ -54,6 +57,7 @@ class Auth(object):
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
@ -70,15 +74,12 @@ class Auth(object):
def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in itervalues(auth_events)
}
auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
self.check(
room_version, event,
auth_events=auth_events, do_sig_check=do_sig_check,
room_version, event, auth_events=auth_events, do_sig_check=do_sig_check
)
def check(self, room_version, event, auth_events, do_sig_check=True):
@ -115,15 +116,10 @@ class Auth(object):
the room.
"""
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
self._check_joined_room(member, user_id, room_id)
@ -143,23 +139,17 @@ class Auth(object):
the room. This will be the leave event if they have left the room.
"""
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE):
raise AuthError(403, "User %s not in room %s" % (
user_id, room_id
))
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
if membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if forgot:
raise AuthError(403, "User %s not in room %s" % (
user_id, room_id
))
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
defer.returnValue(member)
@ -171,9 +161,9 @@ class Auth(object):
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
raise AuthError(
403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
)
def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
@ -185,11 +175,7 @@ class Auth(object):
@defer.inlineCallbacks
def get_user_by_req(
self,
request,
allow_guest=False,
rights="access",
allow_expired=False,
self, request, allow_guest=False, rights="access", allow_expired=False
):
""" Get a registered user's ID.
@ -209,9 +195,8 @@ class Auth(object):
try:
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent",
default=[b""]
)[0].decode('ascii', 'surrogateescape')
b"User-Agent", default=[b""]
)[0].decode("ascii", "surrogateescape")
access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@ -243,11 +228,12 @@ class Auth(object):
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
if expiration_ts is not None and self.clock.time_msec() >= expiration_ts:
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
):
raise AuthError(
403,
"User account has expired",
errcode=Codes.EXPIRED_ACCOUNT,
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
# device_id may not be present if get_user_by_access_token has been
@ -265,18 +251,23 @@ class Auth(object):
if is_guest and not allow_guest:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
403,
"Guest access not allowed",
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
request.authenticated_entity = user.to_string()
defer.returnValue(synapse.types.create_requester(
user, token_id, is_guest, device_id, app_service=app_service)
defer.returnValue(
synapse.types.create_requester(
user, token_id, is_guest, device_id, app_service=app_service
)
)
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Missing access token.",
errcode=Codes.MISSING_TOKEN,
)
@defer.inlineCallbacks
@ -297,20 +288,14 @@ class Auth(object):
if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service))
user_id = request.args[b"user_id"][0].decode('utf8')
user_id = request.args[b"user_id"][0].decode("utf8")
if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service))
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
raise AuthError(403, "Application service cannot masquerade as this user.")
if not (yield self.store.get_user_by_id(user_id)):
raise AuthError(
403,
"Application service has not registered this user"
)
raise AuthError(403, "Application service has not registered this user")
defer.returnValue((user_id, app_service))
@defer.inlineCallbacks
@ -368,13 +353,13 @@ class Auth(object):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unknown user_id %s" % user_id,
errcode=Codes.UNKNOWN_TOKEN
errcode=Codes.UNKNOWN_TOKEN,
)
if not stored_user["is_guest"]:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Guest access token used for regular user",
errcode=Codes.UNKNOWN_TOKEN
errcode=Codes.UNKNOWN_TOKEN,
)
ret = {
"user": user,
@ -402,8 +387,9 @@ class Auth(object):
) as e:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN,
)
def _parse_and_validate_macaroon(self, token, rights="access"):
@ -441,13 +427,13 @@ class Auth(object):
guest = True
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN,
)
if not has_expiry and rights == "access":
@ -472,10 +458,11 @@ class Auth(object):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
return caveat.caveat_id[len(user_prefix) :]
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN,
)
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
@ -522,7 +509,7 @@ class Auth(object):
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix):])
expiry = int(caveat[len(prefix) :])
now = self.hs.get_clock().time_msec()
return now < expiry
@ -554,14 +541,12 @@ class Auth(object):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
errcode=Codes.UNKNOWN_TOKEN,
)
request.authenticated_entity = service.sender
return defer.succeed(service)
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
)
raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.")
def is_server_admin(self, user):
""" Check if the given user is a local server admin.
@ -581,19 +566,19 @@ class Auth(object):
auth_ids = []
key = (EventTypes.PowerLevels, "", )
key = (EventTypes.PowerLevels, "")
power_level_event_id = current_state_ids.get(key)
if power_level_event_id:
auth_ids.append(power_level_event_id)
key = (EventTypes.JoinRules, "", )
key = (EventTypes.JoinRules, "")
join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.sender, )
key = (EventTypes.Member, event.sender)
member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", )
key = (EventTypes.Create, "")
create_event_id = current_state_ids.get(key)
if create_event_id:
auth_ids.append(create_event_id)
@ -619,7 +604,7 @@ class Auth(object):
auth_ids.append(member_event_id)
if for_verification:
key = (EventTypes.Member, event.state_key, )
key = (EventTypes.Member, event.state_key)
existing_event_id = current_state_ids.get(key)
if existing_event_id:
auth_ids.append(existing_event_id)
@ -628,7 +613,7 @@ class Auth(object):
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
event.content["third_party_invite"]["signed"]["token"],
)
third_party_invite_id = current_state_ids.get(key)
if third_party_invite_id:
@ -684,7 +669,7 @@ class Auth(object):
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = event_auth.get_send_level(
EventTypes.Aliases, "", power_level_event,
EventTypes.Aliases, "", power_level_event
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
@ -692,7 +677,7 @@ class Auth(object):
raise AuthError(
403,
"This server requires you to be a moderator in the room to"
" edit its room list entry"
" edit its room list entry",
)
@staticmethod
@ -742,7 +727,7 @@ class Auth(object):
)
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
return parts[1].decode('ascii')
return parts[1].decode("ascii")
else:
raise AuthError(
token_not_found_http_status,
@ -755,10 +740,10 @@ class Auth(object):
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
errcode=Codes.MISSING_TOKEN,
)
return query_params[0].decode('ascii')
return query_params[0].decode("ascii")
@defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id):
@ -785,8 +770,8 @@ class Auth(object):
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
visibility
and visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
@ -820,10 +805,11 @@ class Auth(object):
if self.hs.config.hs_disabled:
raise ResourceLimitError(
403, self.hs.config.hs_disabled_message,
403,
self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=self.hs.config.admin_contact,
limit_type=self.hs.config.hs_disabled_limit_type
limit_type=self.hs.config.hs_disabled_limit_type,
)
if self.hs.config.limit_usage_by_mau is True:
assert not (user_id and threepid)
@ -848,8 +834,9 @@ class Auth(object):
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
403, "Monthly Active User Limit Exceeded",
403,
"Monthly Active User Limit Exceeded",
admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type="monthly_active_user"
limit_type="monthly_active_user",
)

View file

@ -18,7 +18,7 @@
"""Contains constants from the specification."""
# the "depth" field on events is limited to 2**63 - 1
MAX_DEPTH = 2**63 - 1
MAX_DEPTH = 2 ** 63 - 1
# the maximum length for a room alias is 255 characters
MAX_ALIAS_LENGTH = 255
@ -30,39 +30,41 @@ MAX_USERID_LENGTH = 255
class Membership(object):
"""Represents the membership states of a user in a room."""
INVITE = u"invite"
JOIN = u"join"
KNOCK = u"knock"
LEAVE = u"leave"
BAN = u"ban"
INVITE = "invite"
JOIN = "join"
KNOCK = "knock"
LEAVE = "leave"
BAN = "ban"
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class PresenceState(object):
"""Represents the presence state of a user."""
OFFLINE = u"offline"
UNAVAILABLE = u"unavailable"
ONLINE = u"online"
OFFLINE = "offline"
UNAVAILABLE = "unavailable"
ONLINE = "online"
class JoinRules(object):
PUBLIC = u"public"
KNOCK = u"knock"
INVITE = u"invite"
PRIVATE = u"private"
PUBLIC = "public"
KNOCK = "knock"
INVITE = "invite"
PRIVATE = "private"
class LoginType(object):
PASSWORD = u"m.login.password"
EMAIL_IDENTITY = u"m.login.email.identity"
MSISDN = u"m.login.msisdn"
RECAPTCHA = u"m.login.recaptcha"
TERMS = u"m.login.terms"
DUMMY = u"m.login.dummy"
PASSWORD = "m.login.password"
EMAIL_IDENTITY = "m.login.email.identity"
MSISDN = "m.login.msisdn"
RECAPTCHA = "m.login.recaptcha"
TERMS = "m.login.terms"
DUMMY = "m.login.dummy"
# Only for C/S API v1
APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret"
APPLICATION_SERVICE = "m.login.application_service"
SHARED_SECRET = "org.matrix.login.shared_secret"
class EventTypes(object):
@ -118,6 +120,7 @@ class UserTypes(object):
"""Allows for user type specific behaviour. With the benefit of hindsight
'admin' and 'guest' users should also be UserTypes. Normal users are type None
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)
@ -125,6 +128,7 @@ class UserTypes(object):
class RelationTypes(object):
"""The types of relations known to this server.
"""
ANNOTATION = "m.annotation"
REPLACE = "m.replace"
REFERENCE = "m.reference"

View file

@ -70,6 +70,7 @@ class CodeMessageException(RuntimeError):
code (int): HTTP error code
msg (str): string describing the error
"""
def __init__(self, code, msg):
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code
@ -83,6 +84,7 @@ class SynapseError(CodeMessageException):
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error.
@ -95,10 +97,7 @@ class SynapseError(CodeMessageException):
self.errcode = errcode
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
)
return cs_error(self.msg, self.errcode)
class ProxiedRequestError(SynapseError):
@ -107,27 +106,23 @@ class ProxiedRequestError(SynapseError):
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(
code, msg, errcode
)
super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
self._additional_fields = {}
else:
self._additional_fields = dict(additional_fields)
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
**self._additional_fields
)
return cs_error(self.msg, self.errcode, **self._additional_fields)
class ConsentNotGivenError(SynapseError):
"""The error returned to the client when the user has not consented to the
privacy policy.
"""
def __init__(self, msg, consent_uri):
"""Constructs a ConsentNotGivenError
@ -136,22 +131,17 @@ class ConsentNotGivenError(SynapseError):
consent_url (str): The URL where the user can give their consent
"""
super(ConsentNotGivenError, self).__init__(
code=http_client.FORBIDDEN,
msg=msg,
errcode=Codes.CONSENT_NOT_GIVEN
code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
)
self._consent_uri = consent_uri
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
consent_uri=self._consent_uri
)
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass
@ -190,15 +180,17 @@ class InteractiveAuthIncompleteError(Exception):
result (dict): the server response to the request, which should be
passed back to the client
"""
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete",
"Interactive auth not yet complete"
)
self.result = result
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
@ -207,21 +199,14 @@ class UnrecognizedRequestError(SynapseError):
message = "Unrecognized request"
else:
message = args[0]
super(UnrecognizedRequestError, self).__init__(
400,
message,
**kwargs
)
super(UnrecognizedRequestError, self).__init__(400, message, **kwargs)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
super(NotFoundError, self).__init__(
404,
msg,
errcode=errcode
)
super(NotFoundError, self).__init__(404, msg, errcode=errcode)
class AuthError(SynapseError):
@ -238,8 +223,11 @@ class ResourceLimitError(SynapseError):
Any error raised when there is a problem with resource usage.
For instance, the monthly active user limit for the server has been exceeded
"""
def __init__(
self, code, msg,
self,
code,
msg,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=None,
limit_type=None,
@ -253,7 +241,7 @@ class ResourceLimitError(SynapseError):
self.msg,
self.errcode,
admin_contact=self.admin_contact,
limit_type=self.limit_type
limit_type=self.limit_type,
)
@ -268,6 +256,7 @@ class EventSizeError(SynapseError):
class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
@ -276,47 +265,53 @@ class EventStreamError(SynapseError):
class LoginError(SynapseError):
"""An error raised when there was a problem logging in."""
pass
class StoreError(SynapseError):
"""An error raised when there was a problem storing some data."""
pass
class InvalidCaptchaError(SynapseError):
def __init__(self, code=400, msg="Invalid captcha.", error_url=None,
errcode=Codes.CAPTCHA_INVALID):
def __init__(
self,
code=400,
msg="Invalid captcha.",
error_url=None,
errcode=Codes.CAPTCHA_INVALID,
):
super(InvalidCaptchaError, self).__init__(code, msg, errcode)
self.error_url = error_url
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
error_url=self.error_url,
)
return cs_error(self.msg, self.errcode, error_url=self.error_url)
class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled.
"""
def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
errcode=Codes.LIMIT_EXCEEDED):
def __init__(
self,
code=429,
msg="Too Many Requests",
retry_after_ms=None,
errcode=Codes.LIMIT_EXCEEDED,
):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
retry_after_ms=self.retry_after_ms,
)
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store
"""
def __init__(self, current_version):
"""
Args:
@ -331,6 +326,7 @@ class RoomKeysVersionError(SynapseError):
class UnsupportedRoomVersionError(SynapseError):
"""The client's request to create a room used a room version that the server does
not support."""
def __init__(self):
super(UnsupportedRoomVersionError, self).__init__(
code=400,
@ -354,22 +350,19 @@ class IncompatibleRoomVersionError(SynapseError):
Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join
failing.
"""
def __init__(self, room_version):
super(IncompatibleRoomVersionError, self).__init__(
code=400,
msg="Your homeserver does not support the features required to "
"join this room",
"join this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
)
self._room_version = room_version
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
room_version=self._room_version,
)
return cs_error(self.msg, self.errcode, room_version=self._room_version)
class RequestSendFailed(RuntimeError):
@ -380,11 +373,11 @@ class RequestSendFailed(RuntimeError):
networking (e.g. DNS failures, connection timeouts etc), versus unexpected
errors (like programming errors).
"""
def __init__(self, inner_exception, can_retry):
super(RequestSendFailed, self).__init__(
"Failed to send request: %s: %s" % (
type(inner_exception).__name__, inner_exception,
)
"Failed to send request: %s: %s"
% (type(inner_exception).__name__, inner_exception)
)
self.inner_exception = inner_exception
self.can_retry = can_retry
@ -428,7 +421,7 @@ class FederationError(RuntimeError):
self.affected = affected
self.source = source
msg = "%s %s: %s" % (level, code, reason,)
msg = "%s %s: %s" % (level, code, reason)
super(FederationError, self).__init__(msg)
def get_dict(self):
@ -448,6 +441,7 @@ class HttpResponseException(CodeMessageException):
Attributes:
response (bytes): body of response
"""
def __init__(self, code, msg, response):
"""
@ -486,7 +480,7 @@ class HttpResponseException(CodeMessageException):
if not isinstance(j, dict):
j = {}
errcode = j.pop('errcode', Codes.UNKNOWN)
errmsg = j.pop('error', self.msg)
errcode = j.pop("errcode", Codes.UNKNOWN)
errmsg = j.pop("error", self.msg)
return ProxiedRequestError(self.code, errmsg, errcode, j)

View file

@ -28,117 +28,55 @@ FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
"limit": {"type": "number"},
"senders": {"$ref": "#/definitions/user_id_array"},
"not_senders": {"$ref": "#/definitions/user_id_array"},
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
}
}
"types": {"type": "array", "items": {"type": "string"}},
"not_types": {"type": "array", "items": {"type": "string"}},
},
}
ROOM_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"ephemeral": {
"$ref": "#/definitions/room_event_filter"
},
"include_leave": {
"type": "boolean"
},
"state": {
"$ref": "#/definitions/room_event_filter"
},
"timeline": {
"$ref": "#/definitions/room_event_filter"
},
"account_data": {
"$ref": "#/definitions/room_event_filter"
},
}
"not_rooms": {"$ref": "#/definitions/room_id_array"},
"rooms": {"$ref": "#/definitions/room_id_array"},
"ephemeral": {"$ref": "#/definitions/room_event_filter"},
"include_leave": {"type": "boolean"},
"state": {"$ref": "#/definitions/room_event_filter"},
"timeline": {"$ref": "#/definitions/room_event_filter"},
"account_data": {"$ref": "#/definitions/room_event_filter"},
},
}
ROOM_EVENT_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"contains_url": {
"type": "boolean"
},
"lazy_load_members": {
"type": "boolean"
},
"include_redundant_members": {
"type": "boolean"
},
}
"limit": {"type": "number"},
"senders": {"$ref": "#/definitions/user_id_array"},
"not_senders": {"$ref": "#/definitions/user_id_array"},
"types": {"type": "array", "items": {"type": "string"}},
"not_types": {"type": "array", "items": {"type": "string"}},
"rooms": {"$ref": "#/definitions/room_id_array"},
"not_rooms": {"$ref": "#/definitions/room_id_array"},
"contains_url": {"type": "boolean"},
"lazy_load_members": {"type": "boolean"},
"include_redundant_members": {"type": "boolean"},
},
}
USER_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_user_id"
}
"items": {"type": "string", "format": "matrix_user_id"},
}
ROOM_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_room_id"
}
"items": {"type": "string", "format": "matrix_room_id"},
}
USER_FILTER_SCHEMA = {
@ -150,22 +88,13 @@ USER_FILTER_SCHEMA = {
"user_id_array": USER_ID_ARRAY_SCHEMA,
"filter": FILTER_SCHEMA,
"room_filter": ROOM_FILTER_SCHEMA,
"room_event_filter": ROOM_EVENT_FILTER_SCHEMA
"room_event_filter": ROOM_EVENT_FILTER_SCHEMA,
},
"properties": {
"presence": {
"$ref": "#/definitions/filter"
},
"account_data": {
"$ref": "#/definitions/filter"
},
"room": {
"$ref": "#/definitions/room_filter"
},
"event_format": {
"type": "string",
"enum": ["client", "federation"]
},
"presence": {"$ref": "#/definitions/filter"},
"account_data": {"$ref": "#/definitions/filter"},
"room": {"$ref": "#/definitions/room_filter"},
"event_format": {"type": "string", "enum": ["client", "federation"]},
"event_fields": {
"type": "array",
"items": {
@ -177,26 +106,25 @@ USER_FILTER_SCHEMA = {
#
# Note that because this is a regular expression, we have to escape
# each backslash in the pattern.
"pattern": r"^((?!\\\\).)*$"
}
}
"pattern": r"^((?!\\\\).)*$",
},
},
},
"additionalProperties": False
"additionalProperties": False,
}
@FormatChecker.cls_checks('matrix_room_id')
@FormatChecker.cls_checks("matrix_room_id")
def matrix_room_id_validator(room_id_str):
return RoomID.from_string(room_id_str)
@FormatChecker.cls_checks('matrix_user_id')
@FormatChecker.cls_checks("matrix_user_id")
def matrix_user_id_validator(user_id_str):
return UserID.from_string(user_id_str)
class Filtering(object):
def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
@ -228,8 +156,9 @@ class Filtering(object):
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
try:
jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
format_checker=FormatChecker())
jsonschema.validate(
user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker()
)
except jsonschema.ValidationError as e:
raise SynapseError(400, str(e))
@ -240,10 +169,9 @@ class FilterCollection(object):
room_filter_json = self._filter_json.get("room", {})
self._room_filter = Filter({
k: v for k, v in room_filter_json.items()
if k in ("rooms", "not_rooms")
})
self._room_filter = Filter(
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
)
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {}))
@ -252,9 +180,7 @@ class FilterCollection(object):
self._presence_filter = Filter(filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {}))
self.include_leave = filter_json.get("room", {}).get(
"include_leave", False
)
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
@ -299,22 +225,22 @@ class FilterCollection(object):
def blocks_all_presence(self):
return (
self._presence_filter.filters_all_types() or
self._presence_filter.filters_all_senders()
self._presence_filter.filters_all_types()
or self._presence_filter.filters_all_senders()
)
def blocks_all_room_ephemeral(self):
return (
self._room_ephemeral_filter.filters_all_types() or
self._room_ephemeral_filter.filters_all_senders() or
self._room_ephemeral_filter.filters_all_rooms()
self._room_ephemeral_filter.filters_all_types()
or self._room_ephemeral_filter.filters_all_senders()
or self._room_ephemeral_filter.filters_all_rooms()
)
def blocks_all_room_timeline(self):
return (
self._room_timeline_filter.filters_all_types() or
self._room_timeline_filter.filters_all_senders() or
self._room_timeline_filter.filters_all_rooms()
self._room_timeline_filter.filters_all_types()
or self._room_timeline_filter.filters_all_senders()
or self._room_timeline_filter.filters_all_rooms()
)
@ -375,12 +301,7 @@ class Filter(object):
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), text_type)
return self.check_fields(
room_id,
sender,
ev_type,
contains_url,
)
return self.check_fields(room_id, sender, ev_type, contains_url)
def check_fields(self, room_id, sender, event_type, contains_url):
"""Checks whether the filter matches the given event fields.
@ -391,7 +312,7 @@ class Filter(object):
literal_keys = {
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(event_type, v)
"types": lambda v: _matches_wildcard(event_type, v),
}
for name, match_func in literal_keys.items():

View file

@ -44,29 +44,25 @@ class Ratelimiter(object):
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.get(
key, (0., time_now_s, None),
key, (0.0, time_now_s, None)
)
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * rate_hz
if sent_count < 0:
allowed = True
time_start = time_now_s
message_count = 1.
elif sent_count > burst_count - 1.:
message_count = 1.0
elif sent_count > burst_count - 1.0:
allowed = False
else:
allowed = True
message_count += 1
if update:
self.message_counts[key] = (
message_count, time_start, rate_hz
)
self.message_counts[key] = (message_count, time_start, rate_hz)
if rate_hz > 0:
time_allowed = (
time_start + (message_count - burst_count + 1) / rate_hz
)
time_allowed = time_start + (message_count - burst_count + 1) / rate_hz
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
@ -76,9 +72,7 @@ class Ratelimiter(object):
def prune_message_counts(self, time_now_s):
for key in list(self.message_counts.keys()):
message_count, time_start, rate_hz = (
self.message_counts[key]
)
message_count, time_start, rate_hz = self.message_counts[key]
time_delta = time_now_s - time_start
if message_count - time_delta * rate_hz > 0:
break
@ -92,5 +86,5 @@ class Ratelimiter(object):
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s)),
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)

View file

@ -19,9 +19,10 @@ class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
V1 = 1 # $id:server event id format
V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
V3 = 3 # MSC1884-style $hash format: introduced for room v4
V1 = 1 # $id:server event id format
V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
V3 = 3 # MSC1884-style $hash format: introduced for room v4
KNOWN_EVENT_FORMAT_VERSIONS = {
@ -33,8 +34,9 @@ KNOWN_EVENT_FORMAT_VERSIONS = {
class StateResolutionVersions(object):
"""Enum to identify the state resolution algorithms"""
V1 = 1 # room v1 state res
V2 = 2 # MSC1442 state res: room v2 and later
V1 = 1 # room v1 state res
V2 = 2 # MSC1442 state res: room v2 and later
class RoomDisposition(object):
@ -46,10 +48,10 @@ class RoomDisposition(object):
class RoomVersion(object):
"""An object which describes the unique attributes of a room version."""
identifier = attr.ib() # str; the identifier for this version
disposition = attr.ib() # str; one of the RoomDispositions
event_format = attr.ib() # int; one of the EventFormatVersions
state_res = attr.ib() # int; one of the StateResolutionVersions
identifier = attr.ib() # str; the identifier for this version
disposition = attr.ib() # str; one of the RoomDispositions
event_format = attr.ib() # int; one of the EventFormatVersions
state_res = attr.ib() # int; one of the StateResolutionVersions
enforce_key_validity = attr.ib() # bool
@ -92,11 +94,12 @@ class RoomVersions(object):
KNOWN_ROOM_VERSIONS = {
v.identifier: v for v in (
v.identifier: v
for v in (
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.V4,
RoomVersions.V5,
)
} # type: dict[str, RoomVersion]
} # type: dict[str, RoomVersion]

View file

@ -42,13 +42,9 @@ class ConsentURIBuilder(object):
hs_config (synapse.config.homeserver.HomeServerConfig):
"""
if hs_config.form_secret is None:
raise ConfigError(
"form_secret not set in config",
)
raise ConfigError("form_secret not set in config")
if hs_config.public_baseurl is None:
raise ConfigError(
"public_baseurl not set in config",
)
raise ConfigError("public_baseurl not set in config")
self._hmac_secret = hs_config.form_secret.encode("utf-8")
self._public_baseurl = hs_config.public_baseurl
@ -64,15 +60,10 @@ class ConsentURIBuilder(object):
(str) the URI where the user can do consent
"""
mac = hmac.new(
key=self._hmac_secret,
msg=user_id.encode('ascii'),
digestmod=sha256,
key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256
).hexdigest()
consent_uri = "%s_matrix/consent?%s" % (
self._public_baseurl,
urlencode({
"u": user_id,
"h": mac
}),
urlencode({"u": user_id, "h": mac}),
)
return consent_uri

View file

@ -43,7 +43,7 @@ def check_bind_error(e, address, bind_addresses):
address (str): Address on which binding was attempted.
bind_addresses (list): Addresses on which the service listens.
"""
if address == '0.0.0.0' and '::' in bind_addresses:
logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
if address == "0.0.0.0" and "::" in bind_addresses:
logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]")
else:
raise e

View file

@ -19,7 +19,6 @@ import signal
import sys
import traceback
import psutil
from daemonize import Daemonize
from twisted.internet import defer, error, reactor
@ -68,21 +67,13 @@ def start_worker_reactor(appname, config):
gc_thresholds=config.gc_thresholds,
pid_file=config.worker_pid_file,
daemonize=config.worker_daemonize,
cpu_affinity=config.worker_cpu_affinity,
print_pidfile=config.print_pidfile,
logger=logger,
)
def start_reactor(
appname,
soft_file_limit,
gc_thresholds,
pid_file,
daemonize,
cpu_affinity,
print_pidfile,
logger,
appname, soft_file_limit, gc_thresholds, pid_file, daemonize, print_pidfile, logger
):
""" Run the reactor in the main process
@ -95,7 +86,6 @@ def start_reactor(
gc_thresholds:
pid_file (str): name of pid file to write to if daemonize is True
daemonize (bool): true to run the reactor in a background process
cpu_affinity (int|None): cpu affinity mask
print_pidfile (bool): whether to print the pid file, if daemonize is True
logger (logging.Logger): logger instance to pass to Daemonize
"""
@ -109,20 +99,6 @@ def start_reactor(
# between the sentinel and `run` logcontexts.
with PreserveLoggingContext():
logger.info("Running")
if cpu_affinity is not None:
# Turn the bitmask into bits, reverse it so we go from 0 up
mask_to_bits = bin(cpu_affinity)[2:][::-1]
cpus = []
cpu_num = 0
for i in mask_to_bits:
if i == "1":
cpus.append(cpu_num)
cpu_num += 1
p = psutil.Process()
p.cpu_affinity(cpus)
change_resource_limit(soft_file_limit)
if gc_thresholds:
@ -149,10 +125,10 @@ def start_reactor(
def quit_with_error(error_string):
message_lines = error_string.split("\n")
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
sys.stderr.write("*" * line_length + '\n')
sys.stderr.write("*" * line_length + "\n")
for line in message_lines:
sys.stderr.write(" %s\n" % (line.rstrip(),))
sys.stderr.write("*" * line_length + '\n')
sys.stderr.write("*" * line_length + "\n")
sys.exit(1)
@ -178,14 +154,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
r = []
for address in bind_addresses:
try:
r.append(
reactor.listenTCP(
port,
factory,
backlog,
address
)
)
r.append(reactor.listenTCP(port, factory, backlog, address))
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
@ -205,13 +174,7 @@ def listen_ssl(
for address in bind_addresses:
try:
r.append(
reactor.listenSSL(
port,
factory,
context_factory,
backlog,
address
)
reactor.listenSSL(port, factory, context_factory, backlog, address)
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
@ -243,15 +206,13 @@ def refresh_certificate(hs):
if isinstance(i.factory, TLSMemoryBIOFactory):
addr = i.getHost()
logger.info(
"Replacing TLS context factory on [%s]:%i", addr.host, addr.port,
"Replacing TLS context factory on [%s]:%i", addr.host, addr.port
)
# We want to replace TLS factories with a new one, with the new
# TLS configuration. We do this by reaching in and pulling out
# the wrappedFactory, and then re-wrapping it.
i.factory = TLSMemoryBIOFactory(
hs.tls_server_context_factory,
False,
i.factory.wrappedFactory
hs.tls_server_context_factory, False, i.factory.wrappedFactory
)
logger.info("Context factories updated.")
@ -267,6 +228,7 @@ def start(hs, listeners=None):
try:
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
def handle_sighup(*args, **kwargs):
for i in _sighup_callbacks:
i(hs)
@ -302,10 +264,8 @@ def setup_sentry(hs):
return
import sentry_sdk
sentry_sdk.init(
dsn=hs.config.sentry_dsn,
release=get_version_string(synapse),
)
sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse))
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
@ -326,7 +286,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
many DNS queries at once
"""
new_resolver = _LimitedHostnameResolver(
reactor.nameResolver, max_dns_requests_in_flight,
reactor.nameResolver, max_dns_requests_in_flight
)
reactor.installNameResolver(new_resolver)
@ -339,11 +299,17 @@ class _LimitedHostnameResolver(object):
def __init__(self, resolver, max_dns_requests_in_flight):
self._resolver = resolver
self._limiter = Linearizer(
name="dns_client_limiter", max_count=max_dns_requests_in_flight,
name="dns_client_limiter", max_count=max_dns_requests_in_flight
)
def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
addressTypes=None, transportSemantics='TCP'):
def resolveHostName(
self,
resolutionReceiver,
hostName,
portNumber=0,
addressTypes=None,
transportSemantics="TCP",
):
# We need this function to return `resolutionReceiver` so we do all the
# actual logic involving deferreds in a separate function.
@ -363,8 +329,14 @@ class _LimitedHostnameResolver(object):
return resolutionReceiver
@defer.inlineCallbacks
def _resolve(self, resolutionReceiver, hostName, portNumber=0,
addressTypes=None, transportSemantics='TCP'):
def _resolve(
self,
resolutionReceiver,
hostName,
portNumber=0,
addressTypes=None,
transportSemantics="TCP",
):
with (yield self._limiter.queue(())):
# resolveHostName doesn't return a Deferred, so we need to hook into
@ -374,8 +346,7 @@ class _LimitedHostnameResolver(object):
receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)
self._resolver.resolveHostName(
receiver, hostName, portNumber,
addressTypes, transportSemantics,
receiver, hostName, portNumber, addressTypes, transportSemantics
)
yield deferred

View file

@ -44,7 +44,9 @@ logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
DirectoryStore,
SlavedEventStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
@ -74,7 +76,7 @@ class AppserviceServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse appservice now listening on port %d", port)
@ -88,18 +90,19 @@ class AppserviceServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -132,9 +135,7 @@ class ASReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
config = HomeServerConfig.load_config("Synapse appservice", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@ -173,6 +174,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-appservice", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -37,6 +37,7 @@ from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
@ -52,6 +53,7 @@ from synapse.rest.client.v1.room import (
PublicRoomListRestServlet,
RoomEventContextServlet,
RoomMemberListRestServlet,
RoomMessageListRestServlet,
RoomStateRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
@ -74,6 +76,7 @@ class ClientReaderSlavedStore(
SlavedDeviceStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
SlavedGroupServerStore,
SlavedAccountDataStore,
SlavedEventStore,
SlavedKeyStore,
@ -109,6 +112,7 @@ class ClientReaderServer(HomeServer):
JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource)
RoomMessageListRestServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
@ -118,9 +122,7 @@ class ClientReaderServer(HomeServer):
PushRuleRestServlet(self).register(resource)
VersionsRestServlet().register(resource)
resources.update({
"/_matrix/client": resource,
})
resources.update({"/_matrix/client": resource})
root_resource = create_resource_tree(resources, NoResource())
@ -133,7 +135,7 @@ class ClientReaderServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse client reader now listening on port %d", port)
@ -147,18 +149,19 @@ class ClientReaderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -170,9 +173,7 @@ class ClientReaderServer(HomeServer):
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse client reader", config_options
)
config = HomeServerConfig.load_config("Synapse client reader", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@ -199,6 +200,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-client-reader", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -109,12 +109,14 @@ class EventCreatorServer(HomeServer):
ProfileAvatarURLRestServlet(self).register(resource)
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)
root_resource = create_resource_tree(resources, NoResource())
@ -127,7 +129,7 @@ class EventCreatorServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse event creator now listening on port %d", port)
@ -141,18 +143,19 @@ class EventCreatorServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -164,9 +167,7 @@ class EventCreatorServer(HomeServer):
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse event creator", config_options
)
config = HomeServerConfig.load_config("Synapse event creator", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@ -198,6 +199,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-event-creator", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -86,19 +86,18 @@ class FederationReaderServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
if name == "openid" and "federation" not in res["names"]:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
resources.update({
FEDERATION_PREFIX: TransportLayerServer(
self,
servlet_groups=["openid"],
),
})
resources.update(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"]
)
}
)
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
@ -115,7 +114,7 @@ class FederationReaderServer(HomeServer):
root_resource,
self.version_string,
),
reactor=self.get_reactor()
reactor=self.get_reactor(),
)
logger.info("Synapse federation reader now listening on port %d", port)
@ -129,18 +128,19 @@ class FederationReaderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -181,6 +181,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-federation-reader", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -52,8 +52,13 @@ logger = logging.getLogger("synapse.app.federation_sender")
class FederationSenderSlaveStore(
SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
SlavedDeviceInboxStore,
SlavedTransactionStore,
SlavedReceiptsStore,
SlavedEventStore,
SlavedRegistrationStore,
SlavedDeviceStore,
SlavedPresenceStore,
):
def __init__(self, db_conn, hs):
super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
@ -65,10 +70,7 @@ class FederationSenderSlaveStore(
self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
def _get_federation_out_pos(self, db_conn):
sql = (
"SELECT stream_id FROM federation_stream_position"
" WHERE type = ?"
)
sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?"
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
@ -103,7 +105,7 @@ class FederationSenderServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse federation_sender now listening on port %d", port)
@ -117,18 +119,19 @@ class FederationSenderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -151,7 +154,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
self.send_handler.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
args = super(
FederationSenderReplicationHandler, self
).get_streams_to_replicate()
args.update(self.send_handler.stream_positions())
return args
@ -203,6 +208,7 @@ class FederationSenderHandler(object):
"""Processes the replication stream and forwards the appropriate entries
to the federation sender.
"""
def __init__(self, hs, replication_client):
self.store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
@ -241,7 +247,7 @@ class FederationSenderHandler(object):
# ... and when new receipts happen
elif stream_name == ReceiptsStream.NAME:
run_as_background_process(
"process_receipts_for_federation", self._on_new_receipts, rows,
"process_receipts_for_federation", self._on_new_receipts, rows
)
@defer.inlineCallbacks
@ -278,12 +284,14 @@ class FederationSenderHandler(object):
# We ACK this token over replication so that the master can drop
# its in memory queues
self.replication_client.send_federation_ack(self.federation_position)
self.replication_client.send_federation_ack(
self.federation_position
)
self._last_ack = self.federation_position
except Exception:
logger.exception("Error updating federation stream position")
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -62,14 +62,11 @@ class PresenceStatusStubServlet(RestServlet):
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
headers = {
"Authorization": auth_headers,
}
headers = {"Authorization": auth_headers}
try:
result = yield self.http_client.get_json(
self.main_uri + request.uri.decode('ascii'),
headers=headers,
self.main_uri + request.uri.decode("ascii"), headers=headers
)
except HttpResponseException as e:
raise e.to_synapse_error()
@ -105,18 +102,19 @@ class KeyUploadServlet(RestServlet):
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if (requester.device_id is not None and
device_id != requester.device_id):
logger.warning("Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id, device_id)
if requester.device_id is not None and device_id != requester.device_id:
logger.warning(
"Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id,
device_id,
)
else:
device_id = requester.device_id
if device_id is None:
raise SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
400, "To upload keys, you must pass device_id when authenticating"
)
if body:
@ -124,13 +122,9 @@ class KeyUploadServlet(RestServlet):
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
headers = {
"Authorization": auth_headers,
}
headers = {"Authorization": auth_headers}
result = yield self.http_client.post_json_get_json(
self.main_uri + request.uri.decode('ascii'),
body,
headers=headers,
self.main_uri + request.uri.decode("ascii"), body, headers=headers
)
defer.returnValue((200, result))
@ -171,12 +165,14 @@ class FrontendProxyServer(HomeServer):
if not self.config.use_presence:
PresenceStatusStubServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)
root_resource = create_resource_tree(resources, NoResource())
@ -190,7 +186,7 @@ class FrontendProxyServer(HomeServer):
root_resource,
self.version_string,
),
reactor=self.get_reactor()
reactor=self.get_reactor(),
)
logger.info("Synapse client reader now listening on port %d", port)
@ -204,18 +200,19 @@ class FrontendProxyServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -227,9 +224,7 @@ class FrontendProxyServer(HomeServer):
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse frontend proxy", config_options
)
config = HomeServerConfig.load_config("Synapse frontend proxy", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@ -258,6 +253,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-frontend-proxy", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -101,13 +101,12 @@ class SynapseHomeServer(HomeServer):
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
resources.update(self._configure_named_resource(
name, res.get("compress", False),
))
resources.update(
self._configure_named_resource(name, res.get("compress", False))
)
additional_resources = listener_config.get("additional_resources", {})
logger.debug("Configuring additional resources: %r",
additional_resources)
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
@ -174,60 +173,67 @@ class SynapseHomeServer(HomeServer):
if compress:
client_resource = gz_wrap(client_resource)
resources.update({
"/_matrix/client/api/v1": client_resource,
"/_synapse/password_reset": client_resource,
"/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
})
resources.update(
{
"/_matrix/client/api/v1": client_resource,
"/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
}
)
if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource
resources["/_matrix/saml2"] = SAML2Resource(self)
if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource
consent_resource = ConsentResource(self)
if compress:
consent_resource = gz_wrap(consent_resource)
resources.update({
"/_matrix/consent": consent_resource,
})
resources.update({"/_matrix/consent": consent_resource})
if name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
if name == "openid":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]),
})
resources.update(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"]
)
}
)
if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
})
resources.update(
{
STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
}
)
if name in ["media", "federation", "client"]:
if self.get_config().enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
resources.update(
{
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
}
)
elif name == "media":
raise ConfigError(
"'media' resource conflicts with enable_media_repo=False",
"'media' resource conflicts with enable_media_repo=False"
)
if name in ["keys", "federation"]:
@ -258,18 +264,14 @@ class SynapseHomeServer(HomeServer):
for listener in listeners:
if listener["type"] == "http":
self._listening_services.extend(
self._listener_http(config, listener)
)
self._listening_services.extend(self._listener_http(config, listener))
elif listener["type"] == "manhole":
listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "replication":
services = listen_tcp(
@ -278,16 +280,17 @@ class SynapseHomeServer(HomeServer):
ReplicationStreamProtocolFactory(self),
)
for s in services:
reactor.addSystemEventTrigger(
"before", "shutdown", s.stopListening,
)
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -313,7 +316,7 @@ current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
registered_reserved_users_mau_gauge = Gauge(
"synapse_admin_mau:registered_reserved_users",
"Registered users with reserved threepids"
"Registered users with reserved threepids",
)
@ -328,8 +331,7 @@ def setup(config_options):
"""
try:
config = HomeServerConfig.load_or_generate_config(
"Synapse Homeserver",
config_options,
"Synapse Homeserver", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
@ -340,10 +342,7 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
synapse.config.logger.setup_logging(
config,
use_worker_options=False
)
synapse.config.logger.setup_logging(config, use_worker_options=False)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@ -358,7 +357,7 @@ def setup(config_options):
database_engine=database_engine,
)
logger.info("Preparing database: %s...", config.database_config['name'])
logger.info("Preparing database: %s...", config.database_config["name"])
try:
with hs.get_db_conn(run_new_connection=False) as db_conn:
@ -376,7 +375,7 @@ def setup(config_options):
)
sys.exit(1)
logger.info("Database prepared in %s.", config.database_config['name'])
logger.info("Database prepared in %s.", config.database_config["name"])
hs.setup()
hs.setup_master()
@ -392,9 +391,7 @@ def setup(config_options):
acme = hs.get_acme_handler()
# Check how long the certificate is active for.
cert_days_remaining = hs.config.is_disk_cert_valid(
allow_self_signed=False
)
cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False)
# We want to reprovision if cert_days_remaining is None (meaning no
# certificate exists), or the days remaining number it returns
@ -402,8 +399,8 @@ def setup(config_options):
provision = False
if (
cert_days_remaining is None or
cert_days_remaining < hs.config.acme_reprovision_threshold
cert_days_remaining is None
or cert_days_remaining < hs.config.acme_reprovision_threshold
):
provision = True
@ -434,10 +431,7 @@ def setup(config_options):
yield do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(
reprovision_acme,
24 * 60 * 60 * 1000
)
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
_base.start(hs, config.listeners)
@ -464,6 +458,7 @@ class SynapseService(service.Service):
A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""
def __init__(self, config):
self.config = config
@ -480,6 +475,7 @@ class SynapseService(service.Service):
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
def profile(func):
from cProfile import Profile
from threading import current_thread
@ -490,13 +486,14 @@ def run(hs):
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
hs.hostname, func.__name__, ident
))
profile.dump_stats(
"/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
)
return profiled
from twisted.python.threadpool import ThreadPool
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
@ -541,7 +538,10 @@ def run(hs):
stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
stats[
"daily_active_rooms"
] = yield hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users()
@ -565,8 +565,7 @@ def run(hs):
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
"https://matrix.org/report-usage-stats/push", stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
@ -581,14 +580,11 @@ def run(hs):
logger.info("report_stats can use psutil")
stats_process.append(process)
except (AttributeError):
logger.warning(
"Unable to read memory/cpu stats. Disabling reporting."
)
logger.warning("Unable to read memory/cpu stats. Disabling reporting.")
def generate_user_daily_visit_stats():
return run_as_background_process(
"generate_user_daily_visits",
hs.get_datastore().generate_user_daily_visits,
"generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
)
# Rather than update on per session basis, batch up the requests.
@ -599,9 +595,9 @@ def run(hs):
# monthly active user limiting functionality
def reap_monthly_active_users():
return run_as_background_process(
"reap_monthly_active_users",
hs.get_datastore().reap_monthly_active_users,
"reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
)
clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
reap_monthly_active_users()
@ -619,8 +615,7 @@ def run(hs):
def start_generate_monthly_active_users():
return run_as_background_process(
"generate_monthly_active_users",
generate_monthly_active_users,
"generate_monthly_active_users", generate_monthly_active_users
)
start_generate_monthly_active_users()
@ -646,7 +641,6 @@ def run(hs):
gc_thresholds=hs.config.gc_thresholds,
pid_file=hs.config.pid_file,
daemonize=hs.config.daemonize,
cpu_affinity=hs.config.cpu_affinity,
print_pidfile=hs.config.print_pidfile,
logger=logger,
)
@ -660,5 +654,5 @@ def main():
run(hs)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -72,13 +72,15 @@ class MediaRepositoryServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "media":
media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
resources.update(
{
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
}
)
root_resource = create_resource_tree(resources, NoResource())
@ -91,7 +93,7 @@ class MediaRepositoryServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse media repository now listening on port %d", port)
@ -105,18 +107,19 @@ class MediaRepositoryServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -164,6 +167,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-media-repository", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -46,36 +46,27 @@ logger = logging.getLogger("synapse.app.pusher")
class PusherSlaveStore(
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
SlavedAccountDataStore
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore
):
update_pusher_last_stream_ordering_and_success = (
__func__(DataStore.update_pusher_last_stream_ordering_and_success)
update_pusher_last_stream_ordering_and_success = __func__(
DataStore.update_pusher_last_stream_ordering_and_success
)
update_pusher_failing_since = (
__func__(DataStore.update_pusher_failing_since)
update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since)
update_pusher_last_stream_ordering = __func__(
DataStore.update_pusher_last_stream_ordering
)
update_pusher_last_stream_ordering = (
__func__(DataStore.update_pusher_last_stream_ordering)
get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room)
set_throttle_params = __func__(DataStore.set_throttle_params)
get_time_of_last_push_action_before = __func__(
DataStore.get_time_of_last_push_action_before
)
get_throttle_params_by_room = (
__func__(DataStore.get_throttle_params_by_room)
)
set_throttle_params = (
__func__(DataStore.set_throttle_params)
)
get_time_of_last_push_action_before = (
__func__(DataStore.get_time_of_last_push_action_before)
)
get_profile_displayname = (
__func__(DataStore.get_profile_displayname)
)
get_profile_displayname = __func__(DataStore.get_profile_displayname)
class PusherServer(HomeServer):
@ -105,7 +96,7 @@ class PusherServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse pusher now listening on port %d", port)
@ -119,18 +110,19 @@ class PusherServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -161,9 +153,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events":
yield self.pusher_pool.on_new_notifications(
token, token,
)
yield self.pusher_pool.on_new_notifications(token, token)
elif stream_name == "receipts":
yield self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows)
@ -188,9 +178,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse pusher", config_options
)
config = HomeServerConfig.load_config("Synapse pusher", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@ -234,6 +222,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-pusher", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
ps = start(sys.argv[1:])

View file

@ -98,10 +98,7 @@ class SynchrotronPresence(object):
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
state.user_id: state
for state in active_presence
}
self.user_to_current_state = {state.user_id: state for state in active_presence}
# user_id -> last_sync_ms. Lists the users that have stopped syncing
# but we haven't notified the master of that yet
@ -196,17 +193,26 @@ class SynchrotronPresence(object):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
"presence_key",
stream_id,
rooms=room_ids_to_states.keys(),
users=users_to_states.keys(),
)
@defer.inlineCallbacks
def process_replication_rows(self, token, rows):
states = [UserPresenceState(
row.user_id, row.state, row.last_active_ts,
row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
row.currently_active
) for row in rows]
states = [
UserPresenceState(
row.user_id,
row.state,
row.last_active_ts,
row.last_federation_update_ts,
row.last_user_sync_ts,
row.status_msg,
row.currently_active,
)
for row in rows
]
for state in states:
self.user_to_current_state[state.user_id] = state
@ -217,7 +223,8 @@ class SynchrotronPresence(object):
def get_currently_syncing_users(self):
if self.hs.config.use_presence:
return [
user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
user_id
for user_id, count in iteritems(self.user_to_num_current_syncs)
if count > 0
]
else:
@ -281,12 +288,14 @@ class SynchrotronServer(HomeServer):
events.register_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
RoomInitialSyncRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)
root_resource = create_resource_tree(resources, NoResource())
@ -299,7 +308,7 @@ class SynchrotronServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse synchrotron now listening on port %d", port)
@ -313,18 +322,19 @@ class SynchrotronServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -382,40 +392,36 @@ class SyncReplicationHandler(ReplicationClientHandler):
)
elif stream_name == "push_rules":
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows],
"push_rules_key", token, users=[row.user_id for row in rows]
)
elif stream_name in ("account_data", "tag_account_data",):
elif stream_name in ("account_data", "tag_account_data"):
self.notifier.on_new_event(
"account_data_key", token, users=[row.user_id for row in rows],
"account_data_key", token, users=[row.user_id for row in rows]
)
elif stream_name == "receipts":
self.notifier.on_new_event(
"receipt_key", token, rooms=[row.room_id for row in rows],
"receipt_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "typing":
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows],
"typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "to_device":
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
self.notifier.on_new_event(
"to_device_key", token, users=entities,
)
self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
room_ids = yield self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
self.notifier.on_new_event(
"device_list_key", token, rooms=all_room_ids,
)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows],
"groups_key", token, users=[row.user_id for row in rows]
)
except Exception:
logger.exception("Error processing replication")
@ -423,9 +429,7 @@ class SyncReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse synchrotron", config_options
)
config = HomeServerConfig.load_config("Synapse synchrotron", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@ -453,6 +457,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-synchrotron", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -66,14 +66,16 @@ class UserDirectorySlaveStore(
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
@ -110,12 +112,14 @@ class UserDirectoryServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
user_directory.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)
root_resource = create_resource_tree(resources, NoResource())
@ -128,7 +132,7 @@ class UserDirectoryServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)
logger.info("Synapse user_dir now listening on port %d", port)
@ -142,18 +146,19 @@ class UserDirectoryServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -186,9 +191,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse user directory", config_options
)
config = HomeServerConfig.load_config("Synapse user directory", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@ -227,6 +230,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-user-dir", config)
if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -48,9 +48,7 @@ class AppServiceTransaction(object):
A Deferred which resolves to True if the transaction was sent.
"""
return as_api.push_bulk(
service=self.service,
events=self.events,
txn_id=self.id
service=self.service, events=self.events, txn_id=self.id
)
def complete(self, store):
@ -64,10 +62,7 @@ class AppServiceTransaction(object):
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
return store.complete_appservice_txn(
service=self.service,
txn_id=self.id
)
return store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService(object):
@ -76,6 +71,7 @@ class ApplicationService(object):
Provides methods to check if this service is "interested" in events.
"""
NS_USERS = "users"
NS_ALIASES = "aliases"
NS_ROOMS = "rooms"
@ -84,9 +80,19 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None, rate_limited=True,
ip_range_whitelist=None):
def __init__(
self,
token,
hostname,
url=None,
namespaces=None,
hs_token=None,
sender=None,
id=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
):
self.token = token
self.url = url
self.hs_token = hs_token
@ -128,9 +134,7 @@ class ApplicationService(object):
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
group_id = regex_obj.get("group_id")
if group_id:
if not isinstance(group_id, str):
@ -153,9 +157,7 @@ class ApplicationService(object):
if isinstance(regex, string_types):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns
)
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces
def _matches_regex(self, test_string, namespace_key):
@ -178,8 +180,9 @@ class ApplicationService(object):
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key
if (event.type == EventTypes.Member and
self.is_interested_in_user(event.state_key)):
if event.type == EventTypes.Member and self.is_interested_in_user(
event.state_key
):
defer.returnValue(True)
if not store:

View file

@ -32,19 +32,17 @@ logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
"synapse_appservice_api_sent_transactions",
"Number of /transactions/ requests sent",
["service"]
["service"],
)
failed_transactions_counter = Counter(
"synapse_appservice_api_failed_transactions",
"Number of /transactions/ requests that failed to send",
["service"]
["service"],
)
sent_events_counter = Counter(
"synapse_appservice_api_sent_events",
"Number of events sent to the AS",
["service"]
"synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
)
HOUR_IN_MS = 60 * 60 * 1000
@ -92,8 +90,9 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
timeout_ms=HOUR_IN_MS)
self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
)
@defer.inlineCallbacks
def query_user(self, service, user_id):
@ -102,9 +101,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
@ -123,9 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
@ -144,9 +139,7 @@ class ApplicationServiceApi(SimpleHttpClient):
elif kind == ThirdPartyEntityKind.LOCATION:
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind)
if service.url is None:
defer.returnValue([])
@ -154,14 +147,13 @@ class ApplicationServiceApi(SimpleHttpClient):
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.parse.quote(protocol)
urllib.parse.quote(protocol),
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
"query_3pe to %s returned an invalid response %r", uri, response
)
defer.returnValue([])
@ -171,8 +163,7 @@ class ApplicationServiceApi(SimpleHttpClient):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
"query_3pe to %s returned an invalid result %r", uri, r
)
defer.returnValue(ret)
@ -189,27 +180,27 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol)
urllib.parse.quote(protocol),
)
try:
info = yield self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning("query_3pe_protocol to %s did not return a"
" valid result", uri)
logger.warning(
"query_3pe_protocol to %s did not return a" " valid result", uri
)
defer.returnValue(None)
for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
if network_id is not None:
instance["instance_id"] = ThirdPartyInstanceID(
service.id, network_id,
service.id, network_id
).to_string()
defer.returnValue(info)
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex)
defer.returnValue(None)
key = (service.id, protocol)
@ -223,22 +214,19 @@ class ApplicationServiceApi(SimpleHttpClient):
events = self._serialize(events)
if txn_id is None:
logger.warning("push_bulk: Missing txn ID sending events to %s",
service.url)
logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url
)
txn_id = str(0)
txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" %
urllib.parse.quote(txn_id))
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
yield self.put_json(
uri=uri,
json_body={
"events": events
},
args={
"access_token": service.hs_token
})
json_body={"events": events},
args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
defer.returnValue(True)
@ -252,6 +240,4 @@ class ApplicationServiceApi(SimpleHttpClient):
def _serialize(self, events):
time_now = self.clock.time_msec()
return [
serialize_event(e, time_now, as_client_event=True) for e in events
]
return [serialize_event(e, time_now, as_client_event=True) for e in events]

View file

@ -112,15 +112,14 @@ class _ServiceQueuer(object):
return
run_as_background_process(
"as-sender-%s" % (service.id, ),
self._send_request, service,
"as-sender-%s" % (service.id,), self._send_request, service
)
@defer.inlineCallbacks
def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert(service.id not in self.requests_in_flight)
assert service.id not in self.requests_in_flight
self.requests_in_flight.add(service.id)
try:
@ -137,7 +136,6 @@ class _ServiceQueuer(object):
class _TransactionController(object):
def __init__(self, clock, store, as_api, recoverer_fn):
self.clock = clock
self.store = store
@ -149,10 +147,7 @@ class _TransactionController(object):
@defer.inlineCallbacks
def send(self, service, events):
try:
txn = yield self.store.create_appservice_txn(
service=service,
events=events
)
txn = yield self.store.create_appservice_txn(service=service, events=events)
service_is_up = yield self._is_service_up(service)
if service_is_up:
sent = yield txn.send(self.as_api)
@ -167,12 +162,12 @@ class _TransactionController(object):
@defer.inlineCallbacks
def on_recovered(self, recoverer):
self.recoverers.remove(recoverer)
logger.info("Successfully recovered application service AS ID %s",
recoverer.service.id)
logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id
)
logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state(
recoverer.service,
ApplicationServiceState.UP
recoverer.service, ApplicationServiceState.UP
)
def add_recoverers(self, recoverers):
@ -184,13 +179,10 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _start_recoverer(self, service):
try:
yield self.store.set_appservice_state(
service,
ApplicationServiceState.DOWN
)
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
logger.info(
"Application service falling behind. Starting recoverer. AS ID %s",
service.id
service.id,
)
recoverer = self.recoverer_fn(service, self.on_recovered)
self.add_recoverers([recoverer])
@ -205,19 +197,16 @@ class _TransactionController(object):
class _Recoverer(object):
@staticmethod
@defer.inlineCallbacks
def start(clock, store, as_api, callback):
services = yield store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
recoverers = [
_Recoverer(clock, store, as_api, s, callback) for s in services
]
services = yield store.get_appservices_by_state(ApplicationServiceState.DOWN)
recoverers = [_Recoverer(clock, store, as_api, s, callback) for s in services]
for r in recoverers:
logger.info("Starting recoverer for AS ID %s which was marked as "
"DOWN", r.service.id)
logger.info(
"Starting recoverer for AS ID %s which was marked as " "DOWN",
r.service.id,
)
r.recover()
defer.returnValue(recoverers)
@ -232,9 +221,9 @@ class _Recoverer(object):
def recover(self):
def _retry():
run_as_background_process(
"as-recoverer-%s" % (self.service.id,),
self.retry,
"as-recoverer-%s" % (self.service.id,), self.retry
)
self.clock.call_later((2 ** self.backoff_counter), _retry)
def _backoff(self):
@ -248,8 +237,9 @@ class _Recoverer(object):
try:
txn = yield self.store.get_oldest_unsent_txn(self.service)
if txn:
logger.info("Retrying transaction %s for AS ID %s",
txn.id, txn.service.id)
logger.info(
"Retrying transaction %s for AS ID %s", txn.id, txn.service.id
)
sent = yield txn.send(self.as_api)
if sent:
yield txn.complete(self.store)

View file

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -134,11 +136,6 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
@staticmethod
def read_config_file(file_path):
with open(file_path) as file_stream:
return yaml.safe_load(file_stream)
def invoke_all(self, name, *args, **kargs):
results = []
for cls in type(self).mro():
@ -153,12 +150,12 @@ class Config(object):
server_name,
generate_secrets=False,
report_stats=None,
open_private_ports=False,
):
"""Build a default configuration file
This is used both when the user explicitly asks us to generate a config file
(eg with --generate_config), and before loading the config at runtime (to give
a base which the config files override)
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
Args:
config_dir_path (str): The path where the config files are kept. Used to
@ -177,25 +174,33 @@ class Config(object):
report_stats (bool|None): Initial setting for the report_stats setting.
If None, report_stats will be left unset.
open_private_ports (bool): True to leave private ports (such as the non-TLS
HTTP listener) open to the internet.
Returns:
str: the yaml config file
"""
default_config = "\n\n".join(
return "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
"default_config",
"generate_config_section",
config_dir_path=config_dir_path,
data_dir_path=data_dir_path,
server_name=server_name,
generate_secrets=generate_secrets,
report_stats=report_stats,
open_private_ports=open_private_ports,
)
)
return default_config
@classmethod
def load_config(cls, description, argv):
"""Parse the commandline and config files
Doesn't support config-file-generation: used by the worker apps.
Returns: Config object.
"""
config_parser = argparse.ArgumentParser(description=description)
config_parser.add_argument(
"-c",
@ -210,7 +215,7 @@ class Config(object):
"--keys-directory",
metavar="DIRECTORY",
help="Where files such as certs and signing keys are stored when"
" their location is given explicitly in the config."
" their location is not given explicitly in the config."
" Defaults to the directory containing the last config file",
)
@ -222,8 +227,19 @@ class Config(object):
config_files = find_config_files(search_paths=config_args.config_path)
obj.read_config_files(
config_files, keys_directory=config_args.keys_directory, generate_keys=False
if not config_files:
config_parser.error("Must supply a config file.")
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
config_dir_path = os.path.dirname(config_files[-1])
config_dir_path = os.path.abspath(config_dir_path)
data_dir_path = os.getcwd()
config_dict = read_config_files(config_files)
obj.parse_config_dict(
config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
)
obj.invoke_all("read_arguments", config_args)
@ -232,6 +248,12 @@ class Config(object):
@classmethod
def load_or_generate_config(cls, description, argv):
"""Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app.
Returns: Config object, or None if --generate-config or --generate-keys was set
"""
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c",
@ -241,37 +263,74 @@ class Config(object):
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files.",
)
config_parser.add_argument(
generate_group = config_parser.add_argument_group("Config generation")
generate_group.add_argument(
"--generate-config",
action="store_true",
help="Generate a config file for the server name",
help="Generate a config file, then exit.",
)
config_parser.add_argument(
"--report-stats",
action="store",
help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"],
)
config_parser.add_argument(
generate_group.add_argument(
"--generate-missing-configs",
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit",
help="Generate any missing additional config files, then exit.",
)
config_parser.add_argument(
generate_group.add_argument(
"-H", "--server-name", help="The server name to generate a config file for."
)
generate_group.add_argument(
"--report-stats",
action="store",
help="Whether the generated config reports anonymized usage statistics.",
choices=["yes", "no"],
)
generate_group.add_argument(
"--config-directory",
"--keys-directory",
metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as"
" signing keys should be stored, unless explicitly"
" specified in the config.",
help=(
"Specify where additional config files such as signing keys and log"
" config should be stored. Defaults to the same directory as the last"
" config file."
),
)
config_parser.add_argument(
"-H", "--server-name", help="The server name to generate a config file for"
generate_group.add_argument(
"--data-directory",
metavar="DIRECTORY",
help=(
"Specify where data such as the media store and database file should be"
" stored. Defaults to the current working directory."
),
)
generate_group.add_argument(
"--open-private-ports",
action="store_true",
help=(
"Leave private ports (such as the non-TLS HTTP listener) open to the"
" internet. Do not use this unless you know what you are doing."
),
)
config_args, remaining_args = config_parser.parse_known_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
generate_keys = config_args.generate_keys
if not config_files:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
' generated using "--generate-config -H SERVER_NAME'
' -c CONFIG-FILE"'
)
if config_args.config_directory:
config_dir_path = config_args.config_directory
else:
config_dir_path = os.path.dirname(config_files[-1])
config_dir_path = os.path.abspath(config_dir_path)
data_dir_path = os.getcwd()
generate_missing_configs = config_args.generate_missing_configs
obj = cls()
@ -281,19 +340,16 @@ class Config(object):
"Please specify either --report-stats=yes or --report-stats=no\n\n"
+ MISSING_REPORT_STATS_SPIEL
)
if not config_files:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
)
(config_path,) = config_files
if not cls.path_exists(config_path):
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
print("Generating config file %s" % (config_path,))
if config_args.data_directory:
data_dir_path = config_args.data_directory
else:
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path)
data_dir_path = os.getcwd()
data_dir_path = os.path.abspath(data_dir_path)
server_name = config_args.server_name
if not server_name:
@ -304,22 +360,21 @@ class Config(object):
config_str = obj.generate_config(
config_dir_path=config_dir_path,
data_dir_path=os.getcwd(),
data_dir_path=data_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
generate_secrets=True,
open_private_ports=config_args.open_private_ports,
)
if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_file.write(
"# vim:ft=yaml\n\n"
)
config_file.write("# vim:ft=yaml\n\n")
config_file.write(config_str)
config = yaml.safe_load(config_str)
obj.invoke_all("generate_files", config)
config_dict = yaml.safe_load(config_str)
obj.generate_missing_files(config_dict, config_dir_path)
print(
(
@ -333,12 +388,12 @@ class Config(object):
else:
print(
(
"Config file %r already exists. Generating any missing key"
"Config file %r already exists. Generating any missing config"
" files."
)
% (config_path,)
)
generate_keys = True
generate_missing_configs = True
parser = argparse.ArgumentParser(
parents=[config_parser],
@ -349,66 +404,63 @@ class Config(object):
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
if not config_files:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
)
obj.read_config_files(
config_files,
keys_directory=config_args.keys_directory,
generate_keys=generate_keys,
)
if generate_keys:
config_dict = read_config_files(config_files)
if generate_missing_configs:
obj.generate_missing_files(config_dict, config_dir_path)
return None
obj.parse_config_dict(
config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
)
obj.invoke_all("read_arguments", args)
return obj
def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])
def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
"""Read the information from the config dict into this Config object.
self.config_dir_path = os.path.abspath(keys_directory)
Args:
config_dict (dict): Configuration data, as read from the yaml
specified_config = {}
for config_file in config_files:
yaml_config = self.read_config_file(config_file)
specified_config.update(yaml_config)
config_dir_path (str): The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
if "server_name" not in specified_config:
raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
config_string = self.generate_config(
config_dir_path=self.config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
generate_secrets=False,
data_dir_path (str): The path where the data files are kept. Used to create
filenames for things like the database and media store.
"""
self.invoke_all(
"read_config",
config_dict,
config_dir_path=config_dir_path,
data_dir_path=data_dir_path,
)
config = yaml.safe_load(config_string)
config.pop("log_config")
config.update(specified_config)
if "report_stats" not in config:
raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
+ "\n"
+ MISSING_REPORT_STATS_SPIEL
)
def generate_missing_files(self, config_dict, config_dir_path):
self.invoke_all("generate_files", config_dict, config_dir_path)
if generate_keys:
self.invoke_all("generate_files", config)
return
self.parse_config_dict(config)
def read_config_files(config_files):
"""Read the config files into a dict
def parse_config_dict(self, config_dict):
self.invoke_all("read_config", config_dict)
Args:
config_files (iterable[str]): A list of the config files to read
Returns: dict
"""
specified_config = {}
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
specified_config.update(yaml_config)
if "server_name" not in specified_config:
raise ConfigError(MISSING_SERVER_NAME)
if "report_stats" not in specified_config:
raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_SPIEL
)
return specified_config
def find_config_files(search_paths):

View file

@ -18,17 +18,19 @@ from ._base import Config
class ApiConfig(Config):
def read_config(self, config, **kwargs):
self.room_invite_state_types = config.get(
"room_invite_state_types",
[
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.RoomEncryption,
EventTypes.Name,
],
)
def read_config(self, config):
self.room_invite_state_types = config.get("room_invite_state_types", [
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.RoomEncryption,
EventTypes.Name,
])
def default_config(cls, **kwargs):
def generate_config_section(cls, **kwargs):
return """\
## API Configuration ##
@ -40,4 +42,6 @@ class ApiConfig(Config):
# - "{RoomAvatar}"
# - "{RoomEncryption}"
# - "{Name}"
""".format(**vars(EventTypes))
""".format(
**vars(EventTypes)
)

View file

@ -29,13 +29,12 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
def default_config(cls, **kwargs):
def generate_config_section(cls, **kwargs):
return """\
# A list of application service config files to use
#
@ -53,9 +52,7 @@ class AppServiceConfig(Config):
def load_appservices(hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
logger.warning("Expected %s to be a list of AS config files.", config_files)
return []
# Dicts of value -> filename
@ -66,22 +63,20 @@ def load_appservices(hostname, config_files):
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = _load_appservice(
hostname, yaml.safe_load(f), config_file
)
with open(config_file, "r") as f:
appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
"%s (files: %s, %s)"
% (appservice.id, config_file, seen_ids[appservice.id])
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
"%s (files: %s, %s)"
% (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
@ -98,28 +93,26 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [
"id", "as_token", "hs_token", "sender_localpart"
]
required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
for field in required_string_fields:
if not isinstance(as_info.get(field), string_types):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
raise KeyError(
"Required string field: '%s' (%s)" % (field, config_filename)
)
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
if (not isinstance(as_info.get("url"), string_types) and
as_info.get("url", "") is not None):
if (
not isinstance(as_info.get("url"), string_types)
and as_info.get("url", "") is not None
):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
if urlparse.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
raise ValueError("sender_localpart needs characters which are not URL encoded.")
user = UserID(localpart, hostname)
user_id = user.to_string()
@ -138,13 +131,12 @@ def _load_appservice(hostname, as_info, config_filename):
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
"Expected namespace entry in %s to be an object," " but got %s",
ns,
regex_obj,
)
if not isinstance(regex_obj.get("regex"), string_types):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
raise ValueError("Missing/bad type 'regex' key in %s", regex_obj)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
@ -167,10 +159,8 @@ def _load_appservice(hostname, as_info, config_filename):
)
ip_range_whitelist = None
if as_info.get('ip_range_whitelist'):
ip_range_whitelist = IPSet(
as_info.get('ip_range_whitelist')
)
if as_info.get("ip_range_whitelist"):
ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
return ApplicationService(
token=as_info["as_token"],

View file

@ -16,8 +16,7 @@ from ._base import Config
class CaptchaConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")
self.enable_registration_captcha = config.get(
@ -29,7 +28,7 @@ class CaptchaConfig(Config):
"https://www.recaptcha.net/recaptcha/api/siteverify",
)
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return """\
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.

View file

@ -22,7 +22,7 @@ class CasConfig(Config):
cas_server_url: URL of CAS server
"""
def read_config(self, config):
def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = cas_config.get("enabled", True)
@ -35,7 +35,7 @@ class CasConfig(Config):
self.cas_service_url = None
self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#

View file

@ -84,35 +84,32 @@ class ConsentConfig(Config):
self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy"
def read_config(self, config):
def read_config(self, config, **kwargs):
consent_config = config.get("user_consent")
if consent_config is None:
return
self.user_consent_version = str(consent_config["version"])
self.user_consent_template_dir = self.abspath(
consent_config["template_dir"]
)
self.user_consent_template_dir = self.abspath(consent_config["template_dir"])
if not path.isdir(self.user_consent_template_dir):
raise ConfigError(
"Could not find template directory '%s'" % (
self.user_consent_template_dir,
),
"Could not find template directory '%s'"
% (self.user_consent_template_dir,)
)
self.user_consent_server_notice_content = consent_config.get(
"server_notice_content",
"server_notice_content"
)
self.block_events_without_consent_error = consent_config.get(
"block_events_error",
"block_events_error"
)
self.user_consent_server_notice_to_guests = bool(
consent_config.get("send_server_notice_to_guests", False)
)
self.user_consent_at_registration = bool(
consent_config.get("require_at_registration", False)
)
self.user_consent_server_notice_to_guests = bool(consent_config.get(
"send_server_notice_to_guests", False,
))
self.user_consent_at_registration = bool(consent_config.get(
"require_at_registration", False,
))
self.user_consent_policy_name = consent_config.get(
"policy_name", "Privacy Policy",
"policy_name", "Privacy Policy"
)
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG

View file

@ -18,37 +18,30 @@ from ._base import Config
class DatabaseConfig(Config):
def read_config(self, config):
self.event_cache_size = self.parse_size(
config.get("event_cache_size", "10K")
)
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
self.database_config = config.get("database")
if self.database_config is None:
self.database_config = {
"name": "sqlite3",
"args": {},
}
self.database_config = {"name": "sqlite3", "args": {}}
name = self.database_config.get("name", None)
if name == "psycopg2":
pass
elif name == "sqlite3":
self.database_config.setdefault("args", {}).update({
"cp_min": 1,
"cp_max": 1,
"check_same_thread": False,
})
self.database_config.setdefault("args", {}).update(
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
)
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
self.set_databasepath(config.get("database_path"))
def default_config(self, data_dir_path, **kwargs):
def generate_config_section(self, data_dir_path, **kwargs):
database_path = os.path.join(data_dir_path, "homeserver.db")
return """\
return (
"""\
## Database ##
database:
@ -62,7 +55,9 @@ class DatabaseConfig(Config):
# Number of events to cache in memory.
#
#event_cache_size: 10K
""" % locals()
"""
% locals()
)
def read_arguments(self, args):
self.set_databasepath(args.database_path)
@ -77,6 +72,8 @@ class DatabaseConfig(Config):
def add_arguments(self, parser):
db_group = parser.add_argument_group("database")
db_group.add_argument(
"-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use."
"-d",
"--database-path",
metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use.",
)

View file

@ -19,18 +19,15 @@ from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
import email.utils
import logging
import os
import pkg_resources
from ._base import Config, ConfigError
logger = logging.getLogger(__name__)
class EmailConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification
# and account validity config.
@ -59,7 +56,7 @@ class EmailConfig(Config):
if self.email_notif_from is not None:
# make sure it's valid
parsed = email.utils.parseaddr(self.email_notif_from)
if parsed[1] == '':
if parsed[1] == "":
raise RuntimeError("Invalid notif_from address")
template_dir = email_config.get("template_dir")
@ -68,27 +65,27 @@ class EmailConfig(Config):
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
template_dir = pkg_resources.resource_filename(
'synapse', 'res/templates'
)
template_dir = pkg_resources.resource_filename("synapse", "res/templates")
self.email_template_dir = os.path.abspath(template_dir)
self.email_enable_notifs = email_config.get("enable_notifs", False)
account_validity_renewal_enabled = config.get(
"account_validity", {},
).get("renew_at")
account_validity_renewal_enabled = config.get("account_validity", {}).get(
"renew_at"
)
email_trust_identity_server_for_password_resets = email_config.get(
"trust_identity_server_for_password_resets", False,
"trust_identity_server_for_password_resets", False
)
self.email_password_reset_behaviour = (
"remote" if email_trust_identity_server_for_password_resets else "local"
)
self.password_resets_were_disabled_due_to_email_config = False
if self.email_password_reset_behaviour == "local" and email_config == {}:
logger.warn(
"User password resets have been disabled due to lack of email config"
)
# We cannot warn the user this has happened here
# Instead do so when a user attempts to reset their password
self.password_resets_were_disabled_due_to_email_config = True
self.email_password_reset_behaviour = "off"
# Get lifetime of a validation token in milliseconds
@ -104,62 +101,59 @@ class EmailConfig(Config):
# make sure we can import the required deps
import jinja2
import bleach
# prevent unused warnings
jinja2
bleach
if self.email_password_reset_behaviour == "local":
required = [
"smtp_host",
"smtp_port",
"notif_from",
]
required = ["smtp_host", "smtp_port", "notif_from"]
missing = []
for k in required:
if k not in email_config:
missing.append(k)
if (len(missing) > 0):
if len(missing) > 0:
raise RuntimeError(
"email.password_reset_behaviour is set to 'local' "
"but required keys are missing: %s" %
(", ".join(["email." + k for k in missing]),)
"but required keys are missing: %s"
% (", ".join(["email." + k for k in missing]),)
)
# Templates for password reset emails
self.email_password_reset_template_html = email_config.get(
"password_reset_template_html", "password_reset.html",
"password_reset_template_html", "password_reset.html"
)
self.email_password_reset_template_text = email_config.get(
"password_reset_template_text", "password_reset.txt",
"password_reset_template_text", "password_reset.txt"
)
self.email_password_reset_failure_template = email_config.get(
"password_reset_failure_template", "password_reset_failure.html",
"password_reset_failure_template", "password_reset_failure.html"
)
# This template does not support any replaceable variables, so we will
# read it from the disk once during setup
email_password_reset_success_template = email_config.get(
"password_reset_success_template", "password_reset_success.html",
"password_reset_success_template", "password_reset_success.html"
)
# Check templates exist
for f in [self.email_password_reset_template_html,
self.email_password_reset_template_text,
self.email_password_reset_failure_template,
email_password_reset_success_template]:
for f in [
self.email_password_reset_template_html,
self.email_password_reset_template_text,
self.email_password_reset_failure_template,
email_password_reset_success_template,
]:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find template file %s" % (p, ))
raise ConfigError("Unable to find template file %s" % (p,))
# Retrieve content of web templates
filepath = os.path.join(
self.email_template_dir,
email_password_reset_success_template,
self.email_template_dir, email_password_reset_success_template
)
self.email_password_reset_success_html_content = self.read_file(
filepath,
"email.password_reset_template_success_html",
filepath, "email.password_reset_template_success_html"
)
if config.get("public_baseurl") is None:
@ -183,10 +177,10 @@ class EmailConfig(Config):
if k not in email_config:
missing.append(k)
if (len(missing) > 0):
if len(missing) > 0:
raise RuntimeError(
"email.enable_notifs is True but required keys are missing: %s" %
(", ".join(["email." + k for k in missing]),)
"email.enable_notifs is True but required keys are missing: %s"
% (", ".join(["email." + k for k in missing]),)
)
if config.get("public_baseurl") is None:
@ -200,29 +194,27 @@ class EmailConfig(Config):
for f in self.email_notif_template_text, self.email_notif_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p, ))
raise ConfigError("Unable to find email template file %s" % (p,))
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
self.email_riot_base_url = email_config.get("riot_base_url", None)
if account_validity_renewal_enabled:
self.email_expiry_template_html = email_config.get(
"expiry_template_html", "notice_expiry.html",
"expiry_template_html", "notice_expiry.html"
)
self.email_expiry_template_text = email_config.get(
"expiry_template_text", "notice_expiry.txt",
"expiry_template_text", "notice_expiry.txt"
)
for f in self.email_expiry_template_text, self.email_expiry_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p, ))
raise ConfigError("Unable to find email template file %s" % (p,))
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for password resets, notification events or
# account expiry notices

View file

@ -17,11 +17,11 @@ from ._base import Config
class GroupsConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return """\
# Uncomment to allow non-server-admin users to create groups on this server
#

View file

@ -38,6 +38,7 @@ from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
from .stats import StatsConfig
from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
@ -73,5 +74,6 @@ class HomeServerConfig(
StatsConfig,
ServerNoticesConfig,
RoomDirectoryConfig,
ThirdPartyRulesConfig,
):
pass

View file

@ -15,17 +15,15 @@
from ._base import Config, ConfigError
MISSING_JWT = (
"""Missing jwt library. This is required for jwt login.
MISSING_JWT = """Missing jwt library. This is required for jwt login.
Install by running:
pip install pyjwt
"""
)
class JWTConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
@ -34,6 +32,7 @@ class JWTConfig(Config):
try:
import jwt
jwt # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_JWT)
@ -42,7 +41,7 @@ class JWTConfig(Config):
self.jwt_secret = None
self.jwt_algorithm = None
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return """\
# The JWT needs to contain a globally unique "sub" (subject) claim.
#

View file

@ -65,13 +65,18 @@ class TrustedKeyServer(object):
class KeyConfig(Config):
def read_config(self, config):
def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file
if "signing_key" in config:
self.signing_key = read_signing_keys([config["signing_key"]])
else:
self.signing_key_path = config["signing_key_path"]
self.signing_key = self.read_signing_key(self.signing_key_path)
signing_key_path = config.get("signing_key_path")
if signing_key_path is None:
signing_key_path = os.path.join(
config_dir_path, config["server_name"] + ".signing.key"
)
self.signing_key = self.read_signing_key(signing_key_path)
self.old_signing_keys = self.read_old_signing_keys(
config.get("old_signing_keys", {})
@ -117,7 +122,7 @@ class KeyConfig(Config):
# falsification of values
self.form_secret = config.get("form_secret", None)
def default_config(
def generate_config_section(
self, config_dir_path, server_name, generate_secrets=False, **kwargs
):
base_key_name = os.path.join(config_dir_path, server_name)
@ -237,10 +242,18 @@ class KeyConfig(Config):
)
return keys
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
def generate_files(self, config, config_dir_path):
if "signing_key" in config:
return
signing_key_path = config.get("signing_key_path")
if signing_key_path is None:
signing_key_path = os.path.join(
config_dir_path, config["server_name"] + ".signing.key"
)
if not self.path_exists(signing_key_path):
print("Generating signing key file %s" % (signing_key_path,))
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(signing_key_file, (generate_signing_key(key_id),))
@ -348,9 +361,8 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
result.verify_keys[key_id] = verify_key
if (
not federation_verify_certificates and
not server.get("accept_keys_insecurely")
if not federation_verify_certificates and not server.get(
"accept_keys_insecurely"
):
_assert_keyserver_has_verify_keys(result)

View file

@ -29,7 +29,8 @@ from synapse.util.versionstring import get_version_string
from ._base import Config
DEFAULT_LOG_CONFIG = Template("""
DEFAULT_LOG_CONFIG = Template(
"""
version: 1
formatters:
@ -68,26 +69,29 @@ loggers:
root:
level: INFO
handlers: [file, console]
""")
"""
)
class LoggingConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.verbosity = config.get("verbose", 0)
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return """\
return (
"""\
## Logging ##
# A yaml python logging config file
#
log_config: "%(log_config)s"
""" % locals()
"""
% locals()
)
def read_arguments(self, args):
if args.verbose is not None:
@ -102,32 +106,43 @@ class LoggingConfig(Config):
def add_arguments(cls, parser):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count',
"-v",
"--verbose",
dest="verbose",
action="count",
help="The verbosity level. Specify multiple times to increase "
"verbosity. (Ignored if --log-config is specified.)"
"verbosity. (Ignored if --log-config is specified.)",
)
logging_group.add_argument(
'-f', '--log-file', dest="log_file",
help="File to log to. (Ignored if --log-config is specified.)"
"-f",
"--log-file",
dest="log_file",
help="File to log to. (Ignored if --log-config is specified.)",
)
logging_group.add_argument(
'--log-config', dest="log_config", default=None,
help="Python logging config file"
"--log-config",
dest="log_config",
default=None,
help="Python logging config file",
)
logging_group.add_argument(
'-n', '--no-redirect-stdio',
action='store_true', default=None,
help="Do not redirect stdout/stderr to the log"
"-n",
"--no-redirect-stdio",
action="store_true",
default=None,
help="Do not redirect stdout/stderr to the log",
)
def generate_files(self, config):
def generate_files(self, config, config_dir_path):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
print(
"Generating log config file %s which will log to %s"
% (log_config, log_file)
)
with open(log_config, "w") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
)
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
def setup_logging(config, use_worker_options=False):
@ -143,10 +158,8 @@ def setup_logging(config, use_worker_options=False):
register_sighup (func | None): Function to call to register a
sighup handler.
"""
log_config = (config.worker_log_config if use_worker_options
else config.log_config)
log_file = (config.worker_log_file if use_worker_options
else config.log_file)
log_config = config.worker_log_config if use_worker_options else config.log_config
log_file = config.worker_log_file if use_worker_options else config.log_file
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
@ -164,23 +177,23 @@ def setup_logging(config, use_worker_options=False):
if config.verbosity > 1:
level_for_storage = logging.DEBUG
logger = logging.getLogger('')
logger = logging.getLogger("")
logger.setLevel(level)
logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage)
formatter = logging.Formatter(log_format)
if log_file:
# TODO: Customisable file size / backup count
handler = logging.handlers.RotatingFileHandler(
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
encoding='utf8'
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8"
)
def sighup(signum, stack):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
else:
handler = logging.StreamHandler()
@ -193,8 +206,9 @@ def setup_logging(config, use_worker_options=False):
logger.addHandler(handler)
else:
def load_log_config():
with open(log_config, 'r') as f:
with open(log_config, "r") as f:
logging.config.dictConfig(yaml.safe_load(f))
def sighup(*args):
@ -209,10 +223,7 @@ def setup_logging(config, use_worker_options=False):
# make sure that the first thing we log is a thing we can grep backwards
# for
logging.warn("***** STARTING SERVER *****")
logging.warn(
"Server %s version %s",
sys.argv[0], get_version_string(synapse),
)
logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
# It's critical to point twisted's internal logging somewhere, otherwise it
@ -242,8 +253,7 @@ def setup_logging(config, use_worker_options=False):
return observer(event)
globalLogBeginner.beginLoggingTo(
[_log],
redirectStandardIO=not config.no_redirect_stdio,
[_log], redirectStandardIO=not config.no_redirect_stdio
)
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")

View file

@ -15,15 +15,13 @@
from ._base import Config, ConfigError
MISSING_SENTRY = (
"""Missing sentry-sdk library. This is required to enable sentry
MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry
integration.
"""
)
class MetricsConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)
self.metrics_port = config.get("metrics_port")
@ -39,10 +37,10 @@ class MetricsConfig(Config):
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
raise ConfigError(
"sentry.dsn field is required when sentry integration is enabled",
"sentry.dsn field is required when sentry integration is enabled"
)
def default_config(self, report_stats=None, **kwargs):
def generate_config_section(self, report_stats=None, **kwargs):
res = """\
## Metrics ###
@ -66,6 +64,6 @@ class MetricsConfig(Config):
if report_stats is None:
res += "# report_stats: true|false\n"
else:
res += "report_stats: %s\n" % ('true' if report_stats else 'false')
res += "report_stats: %s\n" % ("true" if report_stats else "false")
return res

View file

@ -20,7 +20,7 @@ class PasswordConfig(Config):
"""Password login configuration
"""
def read_config(self, config):
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
if password_config is None:
password_config = {}
@ -28,7 +28,7 @@ class PasswordConfig(Config):
self.password_enabled = password_config.get("enabled", True)
self.password_pepper = password_config.get("pepper", "")
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
# Uncomment to disable password login

View file

@ -17,11 +17,11 @@ from synapse.util.module_loader import load_module
from ._base import Config
LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.password_providers = []
providers = []
@ -29,28 +29,24 @@ class PasswordAuthProviderConfig(Config):
# param.
ldap_config = config.get("ldap_config", {})
if ldap_config.get("enabled", False):
providers.append({
'module': LDAP_PROVIDER,
'config': ldap_config,
})
providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers", []))
for provider in providers:
mod_name = provider['module']
mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided
# in this package.
if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
mod_name = LDAP_PROVIDER
(provider_class, provider_config) = load_module({
"module": mod_name,
"config": provider['config'],
})
(provider_class, provider_config) = load_module(
{"module": mod_name, "config": provider["config"]}
)
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return """\
#password_providers:
# - module: "ldap_auth_provider.LdapAuthProvider"

View file

@ -18,7 +18,7 @@ from ._base import Config
class PushConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
@ -42,7 +42,7 @@ class PushConfig(Config):
)
self.push_include_content = not redact_content
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details

View file

@ -36,7 +36,7 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
@ -80,7 +80,7 @@ class RatelimitConfig(Config):
"federation_rr_transactions_per_room_per_second", 50
)
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##

View file

@ -23,7 +23,7 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
def __init__(self, config, synapse_config):
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = ("renew_at" in config)
self.renew_by_email_enabled = "renew_at" in config
if self.enabled:
if "period" in config:
@ -39,15 +39,14 @@ class AccountValidityConfig(Config):
else:
self.renew_email_subject = "Renew your %(app)s account"
self.startup_job_max_delta = self.period * 10. / 100.
self.startup_job_max_delta = self.period * 10.0 / 100.0
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
class RegistrationConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))
)
@ -57,7 +56,7 @@ class RegistrationConfig(Config):
)
self.account_validity = AccountValidityConfig(
config.get("account_validity", {}), config,
config.get("account_validity", {}), config
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
@ -67,35 +66,37 @@ class RegistrationConfig(Config):
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
"trusted_third_party_id_servers",
["matrix.org", "vector.im"],
"trusted_third_party_id_servers", ["matrix.org", "vector.im"]
)
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
self.invite_3pid_guest = (
self.allow_guest_access and config.get("invite_3pid_guest", False)
self.invite_3pid_guest = self.allow_guest_access and config.get(
"invite_3pid_guest", False
)
self.auto_join_rooms = config.get("auto_join_rooms", [])
for room_alias in self.auto_join_rooms:
if not RoomAlias.is_valid(room_alias):
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
self.disable_msisdn_registration = (
config.get("disable_msisdn_registration", False)
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
def default_config(self, generate_secrets=False, **kwargs):
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50),
)
else:
registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>'
registration_shared_secret = (
"# registration_shared_secret: <PRIVATE STRING>"
)
return """\
return (
"""\
## Registration ##
#
# Registration can be rate-limited using the parameters in the "Ratelimiting"
@ -217,17 +218,19 @@ class RegistrationConfig(Config):
# users cannot be auto-joined since they do not exist.
#
#autocreate_auto_join_rooms: true
""" % locals()
"""
% locals()
)
def add_arguments(self, parser):
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--enable-registration", action="store_true", default=None,
help="Enable registration for new users."
"--enable-registration",
action="store_true",
default=None,
help="Enable registration for new users.",
)
def read_arguments(self, args):
if args.enable_registration is not None:
self.enable_registration = bool(
strtobool(str(args.enable_registration))
)
self.enable_registration = bool(strtobool(str(args.enable_registration)))

View file

@ -20,27 +20,11 @@ from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
DEFAULT_THUMBNAIL_SIZES = [
{
"width": 32,
"height": 32,
"method": "crop",
}, {
"width": 96,
"height": 96,
"method": "crop",
}, {
"width": 320,
"height": 240,
"method": "scale",
}, {
"width": 640,
"height": 480,
"method": "scale",
}, {
"width": 800,
"height": 600,
"method": "scale"
},
{"width": 32, "height": 32, "method": "crop"},
{"width": 96, "height": 96, "method": "crop"},
{"width": 320, "height": 240, "method": "scale"},
{"width": 640, "height": 480, "method": "scale"},
{"width": 800, "height": 600, "method": "scale"},
]
THUMBNAIL_SIZE_YAML = """\
@ -49,19 +33,15 @@ THUMBNAIL_SIZE_YAML = """\
# method: %(method)s
"""
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
)
MISSING_NETADDR = "Missing netaddr library. This is required for URL preview API."
MISSING_LXML = (
"""Missing lxml library. This is required for URL preview API.
MISSING_LXML = """Missing lxml library. This is required for URL preview API.
Install by running:
pip install lxml
Requires libxslt1-dev system package.
"""
)
ThumbnailRequirement = namedtuple(
@ -69,7 +49,8 @@ ThumbnailRequirement = namedtuple(
)
MediaStorageProviderConfig = namedtuple(
"MediaStorageProviderConfig", (
"MediaStorageProviderConfig",
(
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
@ -100,18 +81,19 @@ def parse_thumbnail_requirements(thumbnail_sizes):
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
media_type: tuple(thumbnails)
for media_type, thumbnails in requirements.items()
media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
}
class ContentRepositoryConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M"))
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
self.media_store_path = self.ensure_directory(config["media_store_path"])
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
backup_media_store_path = config.get("backup_media_store_path")
@ -127,15 +109,15 @@ class ContentRepositoryConfig(Config):
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
)
storage_providers = [{
"module": "file_system",
"store_local": True,
"store_synchronous": synchronous_backup_media_store,
"store_remote": True,
"config": {
"directory": backup_media_store_path,
storage_providers = [
{
"module": "file_system",
"store_local": True,
"store_synchronous": synchronous_backup_media_store,
"store_remote": True,
"config": {"directory": backup_media_store_path},
}
}]
]
# This is a list of config that can be used to create the storage
# providers. The entries are tuples of (Class, class_config,
@ -165,18 +147,19 @@ class ContentRepositoryConfig(Config):
)
self.media_storage_providers.append(
(provider_class, parsed_config, wrapper_config,)
(provider_class, parsed_config, wrapper_config)
)
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.uploads_path = self.ensure_directory(config.get("uploads_path", "uploads"))
self.dynamic_thumbnails = config.get("dynamic_thumbnails", False)
self.thumbnail_requirements = parse_thumbnail_requirements(
config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES),
config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES)
)
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
try:
import lxml
lxml # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_LXML)
@ -199,17 +182,15 @@ class ContentRepositoryConfig(Config):
# we always blacklist '0.0.0.0' and '::', which are supposed to be
# unroutable addresses.
self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::'])
self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"])
self.url_preview_ip_range_whitelist = IPSet(
config.get("url_preview_ip_range_whitelist", ())
)
self.url_preview_url_blacklist = config.get(
"url_preview_url_blacklist", ()
)
self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
def default_config(self, data_dir_path, **kwargs):
def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store")
uploads_path = os.path.join(data_dir_path, "uploads")
@ -219,7 +200,8 @@ class ContentRepositoryConfig(Config):
# strip final NL
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
return r"""
return (
r"""
# Directory where uploaded images and attachments are stored.
#
media_store_path: "%(media_store)s"
@ -342,4 +324,6 @@ class ContentRepositoryConfig(Config):
# The largest allowed URL preview spidering size in bytes
#
#max_spider_size: 10M
""" % locals()
"""
% locals()
)

View file

@ -19,10 +19,8 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
def read_config(self, config):
self.enable_room_list_search = config.get(
"enable_room_list_search", True,
)
def read_config(self, config, **kwargs):
self.enable_room_list_search = config.get("enable_room_list_search", True)
alias_creation_rules = config.get("alias_creation_rules")
@ -33,11 +31,7 @@ class RoomDirectoryConfig(Config):
]
else:
self._alias_creation_rules = [
_RoomDirectoryRule(
"alias_creation_rules", {
"action": "allow",
}
)
_RoomDirectoryRule("alias_creation_rules", {"action": "allow"})
]
room_list_publication_rules = config.get("room_list_publication_rules")
@ -49,14 +43,10 @@ class RoomDirectoryConfig(Config):
]
else:
self._room_list_publication_rules = [
_RoomDirectoryRule(
"room_list_publication_rules", {
"action": "allow",
}
)
_RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
]
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Uncomment to disable searching the public room list. When disabled
# blocks searching local and remote room lists for local and remote
@ -178,8 +168,7 @@ class _RoomDirectoryRule(object):
self.action = action
else:
raise ConfigError(
"%s rules can only have action of 'allow'"
" or 'deny'" % (option_name,)
"%s rules can only have action of 'allow'" " or 'deny'" % (option_name,)
)
self._alias_matches_all = alias == "*"

View file

@ -18,7 +18,7 @@ from ._base import Config, ConfigError
class SAML2Config(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.saml2_enabled = False
saml2_config = config.get("saml2_config")
@ -34,6 +34,7 @@ class SAML2Config(Config):
self.saml2_enabled = True
import saml2.config
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(self._default_saml_config_dict())
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
@ -47,29 +48,26 @@ class SAML2Config(Config):
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError(
"saml2_config requires a public_baseurl to be set"
)
raise ConfigError("saml2_config requires a public_baseurl to be set")
metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
response_url = public_baseurl + "_matrix/saml2/authn_response"
return {
"entityid": metadata_url,
"service": {
"sp": {
"endpoints": {
"assertion_consumer_service": [
(response_url, saml2.BINDING_HTTP_POST),
],
(response_url, saml2.BINDING_HTTP_POST)
]
},
"required_attributes": ["uid"],
"optional_attributes": ["mail", "surname", "givenname"],
},
}
}
},
}
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
# Enable SAML2 for registration and login. Uses pysaml2.
#
@ -112,4 +110,6 @@ class SAML2Config(Config):
# # separate pysaml2 configuration file:
# #
# config_path: "%(config_dir_path)s/sp_conf.py"
""" % {"config_dir_path": config_dir_path}
""" % {
"config_dir_path": config_dir_path
}

View file

@ -34,14 +34,13 @@ logger = logging.Logger(__name__)
#
# We later check for errors when binding to 0.0.0.0 and ignore them if :: is also in
# in the list.
DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
DEFAULT_ROOM_VERSION = "4"
class ServerConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
@ -58,7 +57,6 @@ class ServerConfig(Config):
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
self.cpu_affinity = config.get("cpu_affinity")
# Whether to send federation traffic out in this process. This only
# applies to some federation traffic, and so shouldn't be used to
@ -81,27 +79,45 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API.
self.require_auth_for_profile_requests = config.get(
"require_auth_for_profile_requests", False,
"require_auth_for_profile_requests", False
)
# If set to 'True', requires authentication to access the server's
# public rooms directory through the client API, and forbids any other
# homeserver to fetch it via federation.
self.restrict_public_rooms_to_local_users = config.get(
"restrict_public_rooms_to_local_users", False,
)
if "restrict_public_rooms_to_local_users" in config and (
"allow_public_rooms_without_auth" in config
or "allow_public_rooms_over_federation" in config
):
raise ConfigError(
"Can't use 'restrict_public_rooms_to_local_users' if"
" 'allow_public_rooms_without_auth' and/or"
" 'allow_public_rooms_over_federation' is set."
)
default_room_version = config.get(
"default_room_version", DEFAULT_ROOM_VERSION,
)
# Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This
# flag is now obsolete but we need to check it for backward-compatibility.
if config.get("restrict_public_rooms_to_local_users", False):
self.allow_public_rooms_without_auth = False
self.allow_public_rooms_over_federation = False
else:
# If set to 'False', requires authentication to access the server's public
# rooms directory through the client API. Defaults to 'True'.
self.allow_public_rooms_without_auth = config.get(
"allow_public_rooms_without_auth", True
)
# If set to 'False', forbids any other homeserver to fetch the server's public
# rooms directory via federation. Defaults to 'True'.
self.allow_public_rooms_over_federation = config.get(
"allow_public_rooms_over_federation", True
)
default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION)
# Ensure room version is a str
default_room_version = str(default_room_version)
if default_room_version not in KNOWN_ROOM_VERSIONS:
raise ConfigError(
"Unknown default_room_version: %s, known room versions: %s" %
(default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
"Unknown default_room_version: %s, known room versions: %s"
% (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
)
# Get the actual room version object rather than just the identifier
@ -116,31 +132,25 @@ class ServerConfig(Config):
# Whether we should block invites sent to users on this server
# (other than those sent by local server admins)
self.block_non_admin_invites = config.get(
"block_non_admin_invites", False,
)
self.block_non_admin_invites = config.get("block_non_admin_invites", False)
# Whether to enable experimental MSC1849 (aka relations) support
self.experimental_msc1849_support_enabled = config.get(
"experimental_msc1849_support_enabled", False,
"experimental_msc1849_support_enabled", False
)
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
if self.limit_usage_by_mau:
self.max_mau_value = config.get(
"max_mau_value", 0,
)
self.max_mau_value = config.get("max_mau_value", 0)
self.mau_stats_only = config.get("mau_stats_only", False)
self.mau_limits_reserved_threepids = config.get(
"mau_limit_reserved_threepids", []
)
self.mau_trial_days = config.get(
"mau_trial_days", 0,
)
self.mau_trial_days = config.get("mau_trial_days", 0)
# Options to disable HS
self.hs_disabled = config.get("hs_disabled", False)
@ -153,9 +163,7 @@ class ServerConfig(Config):
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
federation_domain_whitelist = config.get(
"federation_domain_whitelist", None,
)
federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None:
# turn the whitelist into a hash for speed of lookup
@ -165,7 +173,7 @@ class ServerConfig(Config):
self.federation_domain_whitelist[domain] = True
self.federation_ip_range_blacklist = config.get(
"federation_ip_range_blacklist", [],
"federation_ip_range_blacklist", []
)
# Attempt to create an IPSet from the given ranges
@ -178,13 +186,12 @@ class ServerConfig(Config):
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
except Exception as e:
raise ConfigError(
"Invalid range(s) provided in "
"federation_ip_range_blacklist: %s" % e
"Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e
)
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit,
@ -195,7 +202,7 @@ class ServerConfig(Config):
# Whether to require a user to be in the room to add an alias to it.
# Defaults to True.
self.require_membership_for_aliases = config.get(
"require_membership_for_aliases", True,
"require_membership_for_aliases", True
)
# Whether to allow per-room membership profiles through the send of membership
@ -227,9 +234,9 @@ class ServerConfig(Config):
# if we still have an empty list of addresses, use the default list
if not bind_addresses:
if listener['type'] == 'metrics':
if listener["type"] == "metrics":
# the metrics listener doesn't support IPv6
bind_addresses.append('0.0.0.0')
bind_addresses.append("0.0.0.0")
else:
bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
@ -249,78 +256,80 @@ class ServerConfig(Config):
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
self.listeners.append({
"port": bind_port,
"bind_addresses": [bind_host],
"tls": True,
"type": "http",
"resources": [
{
"names": ["client"],
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
self.listeners.append(
{
"port": bind_port,
"bind_addresses": [bind_host],
"tls": True,
"type": "http",
"resources": [
{"names": ["client"], "compress": gzip_responses},
{"names": ["federation"], "compress": False},
],
}
)
unsecure_port = config.get("unsecure_port", bind_port - 400)
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_addresses": [bind_host],
"tls": False,
"type": "http",
"resources": [
{
"names": ["client"],
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
self.listeners.append(
{
"port": unsecure_port,
"bind_addresses": [bind_host],
"tls": False,
"type": "http",
"resources": [
{"names": ["client"], "compress": gzip_responses},
{"names": ["federation"], "compress": False},
],
}
)
manhole = config.get("manhole")
if manhole:
self.listeners.append({
"port": manhole,
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
"tls": False,
})
self.listeners.append(
{
"port": manhole,
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
"tls": False,
}
)
metrics_port = config.get("metrics_port")
if metrics_port:
logger.warn(
("The metrics_port configuration option is deprecated in Synapse 0.31 "
"in favour of a listener. Please see "
"http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
" on how to configure the new listener."))
(
"The metrics_port configuration option is deprecated in Synapse 0.31 "
"in favour of a listener. Please see "
"http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
" on how to configure the new listener."
)
)
self.listeners.append({
"port": metrics_port,
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
"tls": False,
"type": "http",
"resources": [
{
"names": ["metrics"],
"compress": False,
},
]
})
self.listeners.append(
{
"port": metrics_port,
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
"tls": False,
"type": "http",
"resources": [{"names": ["metrics"], "compress": False}],
}
)
_check_resource_config(self.listeners)
# An experimental option to try and periodically clean up extremities
# by sending dummy events.
self.cleanup_extremities_with_dummy_events = config.get(
"cleanup_extremities_with_dummy_events", False
)
def has_tls_listener(self):
return any(l["tls"] for l in self.listeners)
def default_config(self, server_name, data_dir_path, **kwargs):
def generate_config_section(
self, server_name, data_dir_path, open_private_ports, **kwargs
):
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@ -333,7 +342,15 @@ class ServerConfig(Config):
# Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
# default config string
default_room_version = DEFAULT_ROOM_VERSION
return """\
unsecure_http_binding = "port: %i\n tls: false" % (unsecure_port,)
if not open_private_ports:
unsecure_http_binding += (
"\n bind_addresses: ['::1', '127.0.0.1']"
)
return (
"""\
## Server ##
# The domain name of the server, with optional explicit port.
@ -347,29 +364,6 @@ class ServerConfig(Config):
#
pid_file: %(pid_file)s
# CPU affinity mask. Setting this restricts the CPUs on which the
# process will be scheduled. It is represented as a bitmask, with the
# lowest order bit corresponding to the first logical CPU and the
# highest order bit corresponding to the last logical CPU. Not all CPUs
# may exist on a given system but a mask may specify more CPUs than are
# present.
#
# For example:
# 0x00000001 is processor #0,
# 0x00000003 is processors #0 and #1,
# 0xFFFFFFFF is all processors (#0 through #31).
#
# Pinning a Python process to a single CPU is desirable, because Python
# is inherently single-threaded due to the GIL, and can suffer a
# 30-40%% slowdown due to cache blow-out and thread context switching
# if the scheduler happens to schedule the underlying threads across
# different cores. See
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
#
# This setting requires the affinity package to be installed!
#
#cpu_affinity: 0xFFFFFFFF
# The path to the web client which will be served at /_matrix/client/
# if 'webclient' is configured under the 'listeners' configuration.
#
@ -401,11 +395,15 @@ class ServerConfig(Config):
#
#require_auth_for_profile_requests: true
# If set to 'true', requires authentication to access the server's
# public rooms directory through the client API, and forbids any other
# homeserver to fetch it via federation. Defaults to 'false'.
# If set to 'false', requires authentication to access the server's public rooms
# directory through the client API. Defaults to 'true'.
#
#restrict_public_rooms_to_local_users: true
#allow_public_rooms_without_auth: false
# If set to 'false', forbids any other homeserver to fetch the server's public
# rooms directory via federation. Defaults to 'true'.
#
#allow_public_rooms_over_federation: false
# The default room version for newly created rooms.
#
@ -546,9 +544,7 @@ class ServerConfig(Config):
# If you plan to use a reverse proxy, please see
# https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
#
- port: %(unsecure_port)s
tls: false
bind_addresses: ['::1', '127.0.0.1']
- %(unsecure_http_binding)s
type: http
x_forwarded: true
@ -556,7 +552,7 @@ class ServerConfig(Config):
- names: [client, federation]
compress: false
# example additonal_resources:
# example additional_resources:
#
#additional_resources:
# "/_matrix/my/custom/endpoint":
@ -631,7 +627,9 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#allow_per_room_profiles: false
""" % locals()
"""
% locals()
)
def read_arguments(self, args):
if args.manhole is not None:
@ -643,17 +641,26 @@ class ServerConfig(Config):
def add_arguments(self, parser):
server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
" service on the given port.")
server_group.add_argument(
"-D",
"--daemonize",
action="store_true",
default=None,
help="Daemonize the home server",
)
server_group.add_argument(
"--print-pidfile",
action="store_true",
default=None,
help="Print the path to the pidfile just" " before daemonizing",
)
server_group.add_argument(
"--manhole",
metavar="PORT",
dest="manhole",
type=int,
help="Turn on the twisted telnet manhole" " service on the given port.",
)
def is_threepid_reserved(reserved_threepids, threepid):
@ -667,7 +674,7 @@ def is_threepid_reserved(reserved_threepids, threepid):
"""
for tp in reserved_threepids:
if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]:
return True
return False
@ -680,9 +687,7 @@ def read_gc_thresholds(thresholds):
return None
try:
assert len(thresholds) == 3
return (
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
)
return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2]))
except Exception:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
@ -700,22 +705,22 @@ def _warn_if_webclient_configured(listeners):
for listener in listeners:
for res in listener.get("resources", []):
for name in res.get("names", []):
if name == 'webclient':
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
KNOWN_RESOURCES = (
'client',
'consent',
'federation',
'keys',
'media',
'metrics',
'openid',
'replication',
'static',
'webclient',
"client",
"consent",
"federation",
"keys",
"media",
"metrics",
"openid",
"replication",
"static",
"webclient",
)
@ -729,11 +734,9 @@ def _check_resource_config(listeners):
for resource in resource_names:
if resource not in KNOWN_RESOURCES:
raise ConfigError(
"Unknown listener resource '%s'" % (resource, )
)
raise ConfigError("Unknown listener resource '%s'" % (resource,))
if resource == "consent":
try:
check_requirements('resources.consent')
check_requirements("resources.consent")
except DependencyException as e:
raise ConfigError(e.message)

View file

@ -58,6 +58,7 @@ class ServerNoticesConfig(Config):
The name to use for the server notices room.
None if server notices are not enabled.
"""
def __init__(self):
super(ServerNoticesConfig, self).__init__()
self.server_notices_mxid = None
@ -65,23 +66,17 @@ class ServerNoticesConfig(Config):
self.server_notices_mxid_avatar_url = None
self.server_notices_room_name = None
def read_config(self, config):
def read_config(self, config, **kwargs):
c = config.get("server_notices")
if c is None:
return
mxid_localpart = c['system_mxid_localpart']
self.server_notices_mxid = UserID(
mxid_localpart, self.server_name,
).to_string()
self.server_notices_mxid_display_name = c.get(
'system_mxid_display_name', None,
)
self.server_notices_mxid_avatar_url = c.get(
'system_mxid_avatar_url', None,
)
mxid_localpart = c["system_mxid_localpart"]
self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string()
self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None)
self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None)
# todo: i18n
self.server_notices_room_name = c.get('room_name', "Server Notices")
self.server_notices_room_name = c.get("room_name", "Server Notices")
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG

View file

@ -19,14 +19,14 @@ from ._base import Config
class SpamCheckerConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.spam_checker = None
provider = config.get("spam_checker", None)
if provider is not None:
self.spam_checker = load_module(provider)
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return """\
#spam_checker:
# module: "my_custom_project.SuperSpamChecker"

View file

@ -25,7 +25,7 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine
"""
def read_config(self, config):
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400
self.stats_retention = sys.maxsize
@ -42,7 +42,7 @@ class StatsConfig(Config):
/ 1000
)
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Local statistics collection. Used in populating the room directory.
#

View file

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.module_loader import load_module
from ._base import Config
class ThirdPartyRulesConfig(Config):
def read_config(self, config, **kwargs):
self.third_party_event_rules = None
provider = config.get("third_party_event_rules", None)
if provider is not None:
self.third_party_event_rules = load_module(provider)
def generate_config_section(self, **kwargs):
return """\
# Server admins can define a Python module that implements extra rules for
# allowing or denying incoming events. In order to work, this module needs to
# override the methods defined in synapse/events/third_party_rules.py.
#
# This feature is designed to be used in closed federations only, where each
# participating server enforces the same rules.
#
#third_party_event_rules:
# module: "my_custom_project.SuperRulesSet"
# config:
# example_option: 'things'
"""

View file

@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
def read_config(self, config):
def read_config(self, config, config_dir_path, **kwargs):
acme_config = config.get("acme", None)
if acme_config is None:
@ -42,14 +42,18 @@ class TlsConfig(Config):
self.acme_enabled = acme_config.get("enabled", False)
# hyperlink complains on py2 if this is not a Unicode
self.acme_url = six.text_type(acme_config.get(
"url", u"https://acme-v01.api.letsencrypt.org/directory"
))
self.acme_url = six.text_type(
acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
)
self.acme_port = acme_config.get("port", 80)
self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0'])
self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.acme_domain = acme_config.get("domain", config.get("server_name"))
self.acme_account_key_file = self.abspath(
acme_config.get("account_key_file", config_dir_path + "/client.key")
)
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
@ -74,12 +78,12 @@ class TlsConfig(Config):
# Whether to verify certificates on outbound federation traffic
self.federation_verify_certificates = config.get(
"federation_verify_certificates", True,
"federation_verify_certificates", True
)
# Whitelist of domains to not verify certificates for
fed_whitelist_entries = config.get(
"federation_certificate_verification_whitelist", [],
"federation_certificate_verification_whitelist", []
)
# Support globs (*) in whitelist values
@ -90,9 +94,7 @@ class TlsConfig(Config):
self.federation_certificate_verification_whitelist.append(entry_regex)
# List of custom certificate authorities for federation traffic validation
custom_ca_list = config.get(
"federation_custom_ca_list", None,
)
custom_ca_list = config.get("federation_custom_ca_list", None)
# Read in and parse custom CA certificates
self.federation_ca_trust_root = None
@ -101,8 +103,10 @@ class TlsConfig(Config):
# A trustroot cannot be generated without any CA certificates.
# Raise an error if this option has been specified without any
# corresponding certificates.
raise ConfigError("federation_custom_ca_list specified without "
"any certificate files")
raise ConfigError(
"federation_custom_ca_list specified without "
"any certificate files"
)
certs = []
for ca_file in custom_ca_list:
@ -114,8 +118,9 @@ class TlsConfig(Config):
cert_base = Certificate.loadPEM(content)
certs.append(cert_base)
except Exception as e:
raise ConfigError("Error parsing custom CA certificate file %s: %s"
% (ca_file, e))
raise ConfigError(
"Error parsing custom CA certificate file %s: %s" % (ca_file, e)
)
self.federation_ca_trust_root = trustRootFromCertificates(certs)
@ -146,17 +151,21 @@ class TlsConfig(Config):
return None
try:
with open(self.tls_certificate_file, 'rb') as f:
with open(self.tls_certificate_file, "rb") as f:
cert_pem = f.read()
except Exception as e:
raise ConfigError("Failed to read existing certificate file %s: %s"
% (self.tls_certificate_file, e))
raise ConfigError(
"Failed to read existing certificate file %s: %s"
% (self.tls_certificate_file, e)
)
try:
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
except Exception as e:
raise ConfigError("Failed to parse existing certificate file %s: %s"
% (self.tls_certificate_file, e))
raise ConfigError(
"Failed to parse existing certificate file %s: %s"
% (self.tls_certificate_file, e)
)
if not allow_self_signed:
if tls_certificate.get_subject() == tls_certificate.get_issuer():
@ -166,7 +175,7 @@ class TlsConfig(Config):
# YYYYMMDDhhmmssZ -- in UTC
expires_on = datetime.strptime(
tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
)
now = datetime.utcnow()
days_remaining = (expires_on - now).days
@ -191,7 +200,8 @@ class TlsConfig(Config):
except Exception as e:
logger.info(
"Unable to read TLS certificate (%s). Ignoring as no "
"tls listeners enabled.", e,
"tls listeners enabled.",
e,
)
self.tls_fingerprints = list(self._original_tls_fingerprints)
@ -205,18 +215,21 @@ class TlsConfig(Config):
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(
self, config_dir_path, server_name, data_dir_path, **kwargs
):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
default_acme_account_file = os.path.join(data_dir_path, "acme_account.key")
# this is to avoid the max line length. Sorrynotsorry
proxypassline = (
'ProxyPass /.well-known/acme-challenge '
'http://localhost:8009/.well-known/acme-challenge'
"ProxyPass /.well-known/acme-challenge "
"http://localhost:8009/.well-known/acme-challenge"
)
return (
@ -337,6 +350,13 @@ class TlsConfig(Config):
#
#domain: matrix.example.com
# file to use for the account key. This will be generated if it doesn't
# exist.
#
# If unspecified, we will use CONFDIR/client.key.
#
account_key_file: %(default_acme_account_file)s
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
# make HTTPS requests to this server will check that the TLS

View file

@ -21,19 +21,19 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API
"""
def read_config(self, config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = (
user_directory_config.get("enabled", True)
self.user_directory_search_enabled = user_directory_config.get(
"enabled", True
)
self.user_directory_search_all_users = (
user_directory_config.get("search_all_users", False)
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
def default_config(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#

View file

@ -16,18 +16,17 @@ from ._base import Config
class VoipConfig(Config):
def read_config(self, config):
def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(
config.get("turn_user_lifetime", "1h"),
config.get("turn_user_lifetime", "1h")
)
self.turn_allow_guests = config.get("turn_allow_guests", True)
def default_config(self, **kwargs):
def generate_config_section(self, **kwargs):
return """\
## TURN ##

View file

@ -21,7 +21,7 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
def read_config(self, config):
def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
@ -46,18 +46,19 @@ class WorkerConfig(Config):
self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
self.worker_cpu_affinity = config.get("worker_cpu_affinity")
# This option is really only here to support `--manhole` command line
# argument.
manhole = config.get("worker_manhole")
if manhole:
self.worker_listeners.append({
"port": manhole,
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
"tls": False,
})
self.worker_listeners.append(
{
"port": manhole,
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
"tls": False,
}
)
if self.worker_listeners:
for listener in self.worker_listeners:
@ -67,7 +68,7 @@ class WorkerConfig(Config):
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')
bind_addresses.append("")
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in

View file

@ -46,9 +46,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
if name not in hashes:
raise SynapseError(
400,
"Algorithm %s not in hashes %s" % (
name, list(hashes),
),
"Algorithm %s not in hashes %s" % (name, list(hashes)),
Codes.UNAUTHORIZED,
)
message_hash_base64 = hashes[name]
@ -56,9 +54,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
message_hash_bytes = decode_base64(message_hash_base64)
except Exception:
raise SynapseError(
400,
"Invalid base64: %s" % (message_hash_base64,),
Codes.UNAUTHORIZED,
400, "Invalid base64: %s" % (message_hash_base64,), Codes.UNAUTHORIZED
)
return message_hash_bytes == expected_hash
@ -135,8 +131,9 @@ def compute_event_signature(event_dict, signature_name, signing_key):
return redact_json["signatures"]
def add_hashes_and_signatures(event_dict, signature_name, signing_key,
hash_algorithm=hashlib.sha256):
def add_hashes_and_signatures(
event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256
):
"""Add content hash and sign the event
Args:
@ -153,7 +150,5 @@ def add_hashes_and_signatures(event_dict, signature_name, signing_key,
event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
event_dict["signatures"] = compute_event_signature(
event_dict,
signature_name=signature_name,
signing_key=signing_key,
event_dict, signature_name=signature_name, signing_key=signing_key
)

View file

@ -505,7 +505,7 @@ class BaseV2KeyFetcher(object):
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
ts_valid_until_ms = response_json[u"valid_until_ts"]
ts_valid_until_ms = response_json["valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
@ -614,10 +614,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(get_key, server)
for server in self.key_servers
],
[run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@ -630,9 +627,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
self, keys_to_fetch, key_server
):
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@ -661,9 +656,9 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
u"server_keys": {
"server_keys": {
server_name: {
key_id: {u"minimum_valid_until_ts": min_valid_ts}
key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, server_keys in keys_to_fetch.items()
@ -690,10 +685,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
try:
self._validate_perspectives_response(
key_server,
response,
)
self._validate_perspectives_response(key_server, response)
processed_response = yield self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms
@ -720,9 +712,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
defer.returnValue(keys)
def _validate_perspectives_response(
self, key_server, response,
):
def _validate_perspectives_response(self, key_server, response):
"""Optionally check the signature on the result of a /key/query request
Args:
@ -739,13 +729,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return
if (
u"signatures" not in response
or perspective_name not in response[u"signatures"]
"signatures" not in response
or perspective_name not in response["signatures"]
):
raise KeyLookupError("Response not signed by the notary server")
verified = False
for key_id in response[u"signatures"][perspective_name]:
for key_id in response["signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(response, perspective_name, perspective_keys[key_id])
verified = True
@ -754,7 +744,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
raise KeyLookupError(
"Response not signed with a known key: signed with: %r, known keys: %r"
% (
list(response[u"signatures"][perspective_name].keys()),
list(response["signatures"][perspective_name].keys()),
list(perspective_keys.keys()),
)
)
@ -826,7 +816,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.

View file

@ -85,17 +85,14 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
403, "Creation event's room_id domain does not match sender's"
)
room_version = event.content.get("room_version", "1")
if room_version not in KNOWN_ROOM_VERSIONS:
raise AuthError(
403,
"room appears to have unsupported version %s" % (
room_version,
))
403, "room appears to have unsupported version %s" % (room_version,)
)
# FIXME
logger.debug("Allowing! %s", event)
return
@ -103,46 +100,30 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise AuthError(
403,
"No create event in auth events",
)
raise AuthError(403, "No create event in auth events")
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
raise AuthError(403, "This room has been marked as unfederatable.")
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
raise AuthError(403, "Alias event must be a state event")
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
raise AuthError(403, "Alias event must have non-empty state_key")
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
403, "Alias event's state_key does not match sender's domain"
)
logger.debug("Allowing! %s", event)
return
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
if event.type == EventTypes.Member:
_is_membership_change_allowed(event, auth_events)
@ -159,9 +140,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You don't have permission to invite users",
)
raise AuthError(403, "You don't have permission to invite users")
else:
logger.debug("Allowing! %s", event)
return
@ -207,7 +186,7 @@ def _is_membership_change_allowed(event, auth_events):
# Check if this is the room creator joining:
if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
key = (EventTypes.Create, "")
create = auth_events.get(key)
if create and event.prev_event_ids()[0] == create.event_id:
if create.content["creator"] == event.state_key:
@ -219,38 +198,31 @@ def _is_membership_change_allowed(event, auth_events):
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
raise AuthError(403, "This room has been marked as unfederatable.")
# get info about the caller
key = (EventTypes.Member, event.user_id, )
key = (EventTypes.Member, event.user_id)
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
key = (EventTypes.Member, target_user_id)
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
key = (EventTypes.JoinRules, "")
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE)
else:
join_rule = JoinRules.INVITE
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(
target_user_id, auth_events
)
target_level = get_user_power_level(target_user_id, auth_events)
# FIXME (erikj): What should we do here as the default?
ban_level = _get_named_level(auth_events, "ban", 50)
@ -266,29 +238,26 @@ def _is_membership_change_allowed(event, auth_events):
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
},
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not _verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
if (
caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id
):
return
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
@ -296,19 +265,14 @@ def _is_membership_change_allowed(event, auth_events):
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
raise AuthError(403, "%s is already in the room." % target_user_id)
else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You don't have permission to invite users",
)
raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
@ -329,16 +293,12 @@ def _is_membership_change_allowed(event, auth_events):
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user %s." % (target_user_id,)
)
raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
raise AuthError(403, "You cannot kick user %s." % target_user_id)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
@ -347,21 +307,17 @@ def _is_membership_change_allowed(event, auth_events):
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)
return _check_joined_room(
member_event,
event.user_id,
event.room_id
)
return _check_joined_room(member_event, event.user_id, event.room_id)
def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
raise AuthError(
403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
)
def get_send_level(etype, state_key, power_levels_event):
@ -402,26 +358,21 @@ def get_send_level(etype, state_key, power_levels_event):
def _can_send_event(event, auth_events):
power_levels_event = _get_power_level_event(auth_events)
send_level = get_send_level(
event.type, event.get("state_key"), power_levels_event,
)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
"You don't have permission to post that to the room. "
+ "user_level (%d) < send_level (%d)" % (user_level, send_level),
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
raise AuthError(403, "You are not allowed to set others state")
return True
@ -459,10 +410,7 @@ def check_redaction(room_version, event, auth_events):
event.internal_metadata.recheck_redaction = True
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
raise AuthError(403, "You don't have permission to redact events")
def _check_power_levels(event, auth_events):
@ -479,7 +427,7 @@ def _check_power_levels(event, auth_events):
except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
key = (event.type, event.state_key)
current_state = auth_events.get(key)
if not current_state:
@ -500,16 +448,12 @@ def _check_power_levels(event, auth_events):
old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)):
levels_to_check.append(
(user, "users")
)
levels_to_check.append((user, "users"))
old_list = current_state.content.get("events", {})
new_list = event.content.get("events", {})
for ev_id in set(list(old_list) + list(new_list)):
levels_to_check.append(
(ev_id, "events")
)
levels_to_check.append((ev_id, "events"))
old_state = current_state.content
new_state = event.content
@ -540,7 +484,7 @@ def _check_power_levels(event, auth_events):
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
"to your own",
)
# Check if the old and new levels are greater than the user level
@ -550,8 +494,7 @@ def _check_power_levels(event, auth_events):
if old_level_too_big or new_level_too_big:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
"You don't have permission to add ops level greater " "than your own",
)
@ -587,10 +530,9 @@ def get_user_power_level(user_id, auth_events):
# some things which call this don't pass the create event: hack around
# that.
key = (EventTypes.Create, "", )
key = (EventTypes.Create, "")
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
if create_event is not None and create_event.content["creator"] == user_id:
return 100
else:
return 0
@ -636,9 +578,7 @@ def _verify_third_party_invite(event, auth_events):
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
invite_event = auth_events.get((EventTypes.ThirdPartyInvite, token))
if not invite_event:
return False
@ -661,8 +601,7 @@ def _verify_third_party_invite(event, auth_events):
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
key_name, decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
@ -671,7 +610,7 @@ def _verify_third_party_invite(event, auth_events):
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
except (KeyError, SignatureVerifyException):
continue
return False
@ -679,9 +618,7 @@ def _verify_third_party_invite(event, auth_events):
def get_public_keys(invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
o = {"public_key": invite_event.content["public_key"]}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
@ -702,22 +639,22 @@ def auth_types_for_event(event):
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.sender, ))
auth_types.append((EventTypes.Create, "", ))
auth_types.append((EventTypes.PowerLevels, ""))
auth_types.append((EventTypes.Member, event.sender))
auth_types.append((EventTypes.Create, ""))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, "", ))
auth_types.append((EventTypes.JoinRules, ""))
auth_types.append((EventTypes.Member, event.state_key, ))
auth_types.append((EventTypes.Member, event.state_key))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
event.content["third_party_invite"]["signed"]["token"],
)
auth_types.append(key)

View file

@ -92,6 +92,18 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "soft_failed", False)
def should_proactively_send(self):
"""Whether the event, if ours, should be sent to other clients and
servers.
This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event.
Returns:
bool
"""
return getattr(self, "proactively_send", True)
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
@ -115,25 +127,25 @@ def _event_dict_property(key):
except KeyError:
raise AttributeError(key)
return property(
getter,
setter,
delete,
)
return property(getter, setter, delete)
class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={},
internal_metadata_dict={}, rejected_reason=None):
def __init__(
self,
event_dict,
signatures={},
unsigned={},
internal_metadata_dict={},
rejected_reason=None,
):
self.signatures = signatures
self.unsigned = unsigned
self.rejected_reason = rejected_reason
self._event_dict = event_dict
self.internal_metadata = _EventInternalMetadata(
internal_metadata_dict
)
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
@ -156,10 +168,7 @@ class EventBase(object):
def get_dict(self):
d = dict(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": dict(self.unsigned),
})
d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
return d
@ -346,6 +355,7 @@ class FrozenEventV2(EventBase):
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
@ -402,6 +412,4 @@ def event_type_from_format_version(format_version):
elif format_version == EventFormatVersions.V3:
return FrozenEventV3
else:
raise Exception(
"No event format %r" % (format_version,)
)
raise Exception("No event format %r" % (format_version,))

View file

@ -78,7 +78,9 @@ class EventBuilder(object):
_redacts = attr.ib(default=None)
_origin_server_ts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
internal_metadata = attr.ib(
default=attr.Factory(lambda: _EventInternalMetadata({}))
)
@property
def state_key(self):
@ -102,11 +104,9 @@ class EventBuilder(object):
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids,
)
auth_ids = yield self._auth.compute_auth_events(
self, state_ids,
self.room_id, prev_event_ids
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
if self.format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
@ -115,9 +115,7 @@ class EventBuilder(object):
auth_events = auth_ids
prev_events = prev_event_ids
old_depth = yield self._store.get_max_depth_of(
prev_event_ids,
)
old_depth = yield self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
@ -217,9 +215,14 @@ class EventBuilderFactory(object):
)
def create_local_event_from_event_dict(clock, hostname, signing_key,
format_version, event_dict,
internal_metadata_dict=None):
def create_local_event_from_event_dict(
clock,
hostname,
signing_key,
format_version,
event_dict,
internal_metadata_dict=None,
):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
"""
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
)
raise Exception("No event format defined for version %r" % (format_version,))
if internal_metadata_dict is None:
internal_metadata_dict = {}
@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict.setdefault("signatures", {})
add_hashes_and_signatures(
event_dict,
hostname,
signing_key,
)
add_hashes_and_signatures(event_dict, hostname, signing_key)
return event_type_from_format_version(format_version)(
event_dict, internal_metadata_dict=internal_metadata_dict,
event_dict, internal_metadata_dict=internal_metadata_dict
)

View file

@ -88,8 +88,9 @@ class EventContext(object):
self.app_service = None
@staticmethod
def with_state(state_group, current_state_ids, prev_state_ids,
prev_group=None, delta_ids=None):
def with_state(
state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
):
context = EventContext()
# The current state including the current event
@ -132,17 +133,19 @@ class EventContext(object):
else:
prev_state_id = None
defer.returnValue({
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
"state_group": self.state_group,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None
})
defer.returnValue(
{
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
"state_group": self.state_group,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None,
}
)
@staticmethod
def deserialize(store, input):
@ -194,7 +197,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
@ -214,7 +217,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
@ -240,9 +243,7 @@ class EventContext(object):
if self.state_group is None:
return
self._current_state_ids = yield store.get_state_ids_for_group(
self.state_group,
)
self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
@ -252,8 +253,9 @@ class EventContext(object):
self._prev_state_ids = self._current_state_ids
@defer.inlineCallbacks
def update_state(self, state_group, prev_state_ids, current_state_ids,
prev_group, delta_ids):
def update_state(
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
):
"""Replace the state in the context
"""
@ -279,10 +281,7 @@ def _encode_state_dict(state_dict):
if state_dict is None:
return None
return [
(etype, state_key, v)
for (etype, state_key), v in iteritems(state_dict)
]
return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]
def _decode_state_dict(input):
@ -291,4 +290,4 @@ def _decode_state_dict(input):
if input is None:
return None
return frozendict({(etype, state_key,): v for etype, state_key, v in input})
return frozendict({(etype, state_key): v for etype, state_key, v in input})

View file

@ -60,7 +60,9 @@ class SpamChecker(object):
if self.spam_checker is None:
return True
return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
return self.spam_checker.user_may_invite(
inviter_userid, invitee_userid, room_id
)
def user_may_create_room(self, userid):
"""Checks if a given user may create a room

View file

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
class ThirdPartyEventRules(object):
"""Allows server admins to provide a Python module implementing an extra
set of rules to apply when processing events.
This is designed to help admins of closed federations with enforcing custom
behaviours.
"""
def __init__(self, hs):
self.third_party_rules = None
self.store = hs.get_datastore()
module = None
config = None
if hs.config.third_party_event_rules:
module, config = hs.config.third_party_event_rules
if module is not None:
self.third_party_rules = module(
config=config, http_client=hs.get_simple_http_client()
)
@defer.inlineCallbacks
def check_event_allowed(self, event, context):
"""Check if a provided event should be allowed in the given context.
Args:
event (synapse.events.EventBase): The event to be checked.
context (synapse.events.snapshot.EventContext): The context of the event.
Returns:
defer.Deferred[bool]: True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
defer.returnValue(True)
prev_state_ids = yield context.get_prev_state_ids(self.store)
# Retrieve the state events from the database.
state_events = {}
for key, event_id in prev_state_ids.items():
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
defer.returnValue(ret)
@defer.inlineCallbacks
def on_create_room(self, requester, config, is_requester_admin):
"""Intercept requests to create room to allow, deny or update the
request config.
Args:
requester (Requester)
config (dict): The creation config from the client.
is_requester_admin (bool): If the requester is an admin
Returns:
defer.Deferred
"""
if self.third_party_rules is None:
return
yield self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
@defer.inlineCallbacks
def check_threepid_can_be_invited(self, medium, address, room_id):
"""Check if a provided 3PID can be invited in the given room.
Args:
medium (str): The 3PID's medium.
address (str): The 3PID's address.
room_id (str): The room we want to invite the threepid to.
Returns:
defer.Deferred[bool], True if the 3PID can be invited, False if not.
"""
if self.third_party_rules is None:
defer.returnValue(True)
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values())
state_events = {}
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
ret = yield self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
defer.returnValue(ret)

View file

@ -31,7 +31,7 @@ from . import EventBase
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
def prune_event(event):
@ -51,6 +51,7 @@ def prune_event(event):
pruned_event_dict = prune_event_dict(event.get_dict())
from . import event_type_from_format_version
return event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)
@ -116,11 +117,7 @@ def prune_event_dict(event_dict):
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
allowed_fields = {
k: v
for k, v in event_dict.items()
if k in allowed_keys
}
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
allowed_fields["content"] = new_content
@ -205,7 +202,7 @@ def only_fields(dictionary, fields):
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
[f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
output = {}
@ -226,7 +223,10 @@ def format_event_for_client_v1(d):
d["user_id"] = sender
copy_keys = (
"age", "redacted_because", "replaces_state", "prev_content",
"age",
"redacted_because",
"replaces_state",
"prev_content",
"invite_room_state",
)
for key in copy_keys:
@ -238,8 +238,13 @@ def format_event_for_client_v1(d):
def format_event_for_client_v2(d):
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"origin", "prev_state",
"auth_events",
"prev_events",
"hashes",
"signatures",
"depth",
"origin",
"prev_state",
)
for key in drop_keys:
d.pop(key, None)
@ -252,9 +257,15 @@ def format_event_for_client_v2_without_room_id(d):
return d
def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None, only_event_fields=None, is_invite=False):
def serialize_event(
e,
time_now_ms,
as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
is_invite=False,
):
"""Serialize event for clients
Args:
@ -288,8 +299,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms,
event_format=event_format
e.unsigned["redacted_because"], time_now_ms, event_format=event_format
)
if token_id is not None:
@ -308,8 +318,9 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = event_format(d)
if only_event_fields:
if (not isinstance(only_event_fields, list) or
not all(isinstance(f, string_types) for f in only_event_fields)):
if not isinstance(only_event_fields, list) or not all(
isinstance(f, string_types) for f in only_event_fields
):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
@ -352,11 +363,9 @@ class EventClientSerializer(object):
# If MSC1849 is enabled then we need to look if thre are any relations
# we need to bundle in with the event
if self.experimental_msc1849_support_enabled and bundle_aggregations:
annotations = yield self.store.get_aggregation_groups_for_event(
event_id,
)
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f",
event_id, RelationTypes.REFERENCE, direction="f"
)
if annotations.chunk:
@ -383,9 +392,7 @@ class EventClientSerializer(object):
serialized_event["content"].pop("m.relates_to", None)
r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
}
r[RelationTypes.REPLACE] = {"event_id": edit.event_id}
defer.returnValue(serialized_event)
@ -401,6 +408,5 @@ class EventClientSerializer(object):
Deferred[list[dict]]: The list of serialized events
"""
return yieldable_gather_results(
self.serialize_event, events,
time_now=time_now, **kwargs
self.serialize_event, events, time_now=time_now, **kwargs
)

View file

@ -48,9 +48,7 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
event_strings = [
"origin",
]
event_strings = ["origin"]
for s in event_strings:
if not isinstance(getattr(event, s), string_types):
@ -62,8 +60,10 @@ class EventValidator(object):
if len(alias) > MAX_ALIAS_LENGTH:
raise SynapseError(
400,
("Can't create aliases longer than"
" %d characters" % (MAX_ALIAS_LENGTH,)),
(
"Can't create aliases longer than"
" %d characters" % (MAX_ALIAS_LENGTH,)
),
Codes.INVALID_PARAM,
)
@ -76,11 +76,7 @@ class EventValidator(object):
event (EventBuilder|FrozenEvent)
"""
strings = [
"room_id",
"sender",
"type",
]
strings = ["room_id", "sender", "type"]
if hasattr(event, "state_key"):
strings.append("state_key")
@ -93,10 +89,7 @@ class EventValidator(object):
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
strings = [
"body",
"msgtype",
]
strings = ["body", "msgtype"]
self._ensure_strings(event.content, strings)

View file

@ -44,8 +44,9 @@ class FederationBase(object):
self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
outlier=False, include_none=False):
def _check_sigs_and_hash_and_fetch(
self, origin, pdus, room_version, outlier=False, include_none=False
):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@ -79,9 +80,7 @@ class FederationBase(object):
if not res:
# Check local db.
res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
@ -98,23 +97,16 @@ class FederationBase(object):
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
"Failed to find copy of %s with valid signature", pdu.event_id
)
defer.returnValue(res)
handle = logcontext.preserve_fn(handle_check_result)
deferreds2 = [
handle(pdu, deferred)
for pdu, deferred in zip(pdus, deferreds)
]
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
deferreds2,
consumeErrors=True,
)
defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
if include_none:
@ -124,7 +116,7 @@ class FederationBase(object):
def _check_sigs_and_hash(self, room_version, pdu):
return logcontext.make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0],
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(self, room_version, pdus):
@ -159,11 +151,9 @@ class FederationBase(object):
# received event was probably a redacted copy (but we then use our
# *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu)
if (
set(redacted_event.keys()) == set(pdu.keys()) and
set(six.iterkeys(redacted_event.content))
== set(six.iterkeys(pdu.content))
):
if set(redacted_event.keys()) == set(pdu.keys()) and set(
six.iterkeys(redacted_event.content)
) == set(six.iterkeys(pdu.content)):
logger.info(
"Event %s seems to have been redacted; using our redacted "
"copy",
@ -172,14 +162,15 @@ class FederationBase(object):
else:
logger.warning(
"Event %s content has been tampered, redacting",
pdu.event_id, pdu.get_pdu_json(),
pdu.event_id,
)
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
pdu.event_id,
pdu.get_pdu_json(),
)
return prune_event(pdu)
@ -190,23 +181,24 @@ class FederationBase(object):
with logcontext.PreserveLoggingContext(ctx):
logger.warn(
"Signature check failed for %s: %s",
pdu.event_id, failure.getErrorMessage(),
pdu.event_id,
failure.getErrorMessage(),
)
return failure
for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu],
errbackArgs=[pdu],
callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
)
return deferreds
class PduToCheckSig(namedtuple("PduToCheckSig", [
"pdu", "redacted_pdu_json", "sender_domain", "deferreds",
])):
class PduToCheckSig(
namedtuple(
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
)
):
pass
@ -260,10 +252,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [
p for p in pdus_to_check
if not _is_invite_via_3pid(p.pdu)
]
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
more_deferreds = keyring.verify_json_objects_for_server(
[
@ -297,7 +286,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p for p in pdus_to_check
p
for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
@ -315,10 +305,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
def event_err(e, pdu_to_check):
errmsg = (
"event id %s: unable to verify signature for event id domain: %s" % (
pdu_to_check.pdu.event_id,
e.getErrorMessage(),
)
"event id %s: unable to verify signature for event id domain: %s"
% (pdu_to_check.pdu.event_id, e.getErrorMessage())
)
# XX as above: not really sure if these are the right codes
raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
@ -368,21 +356,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_dict(pdu_json, ('type', 'depth'))
assert_params_in_dict(pdu_json, ("type", "depth"))
depth = pdu_json['depth']
depth = pdu_json["depth"]
if not isinstance(depth, six.integer_types):
raise SynapseError(400, "Depth %r not an intger" % (depth, ),
Codes.BAD_JSON)
raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
if depth < 0:
raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
event = event_type_from_format_version(event_format_version)(
pdu_json,
)
event = event_type_from_format_version(event_format_version)(pdu_json)
event.internal_metadata.outlier = outlier

View file

@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
pass
@ -65,9 +66,7 @@ class FederationClient(FederationBase):
super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(
self._clear_tried_cache, 60 * 1000,
)
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@ -99,8 +98,14 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False, ignore_backoff=False):
def make_query(
self,
destination,
query_type,
args,
retry_on_dns_fail=False,
ignore_backoff=False,
):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
@ -120,7 +125,10 @@ class FederationClient(FederationBase):
sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
destination,
query_type,
args,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@ -137,9 +145,7 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(
destination, content, timeout
)
return self.transport_layer.query_client_keys(destination, content, timeout)
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
@ -147,9 +153,7 @@ class FederationClient(FederationBase):
server.
"""
sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
return self.transport_layer.query_user_devices(destination, user_id, timeout)
@log_function
def claim_client_keys(self, destination, content, timeout):
@ -164,9 +168,7 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(
destination, content, timeout
)
return self.transport_layer.claim_client_keys(destination, content, timeout)
@defer.inlineCallbacks
@log_function
@ -191,7 +193,8 @@ class FederationClient(FederationBase):
return
transaction_data = yield self.transport_layer.backfill(
dest, room_id, extremities, limit)
dest, room_id, extremities, limit
)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
@ -204,17 +207,19 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus),
consumeErrors=True,
).addErrback(unwrapFirstError))
pdus[:] = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True
).addErrback(unwrapFirstError)
)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, room_version, outlier=False,
timeout=None):
def get_pdu(
self, destinations, event_id, room_version, outlier=False, timeout=None
):
"""Requests the PDU with given origin and ID from the remote home
servers.
@ -255,7 +260,7 @@ class FederationClient(FederationBase):
try:
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
destination, event_id, timeout=timeout
)
logger.debug(
@ -282,8 +287,7 @@ class FederationClient(FederationBase):
except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
"Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue
except NotRetryingDestination as e:
@ -296,8 +300,7 @@ class FederationClient(FederationBase):
pdu_attempts[destination] = now
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
"Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue
@ -326,7 +329,7 @@ class FederationClient(FederationBase):
# we have most of the state and auth_chain already.
# However, this may 404 if the other side has an old synapse.
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id,
destination, room_id, event_id=event_id
)
state_event_ids = result["pdu_ids"]
@ -340,12 +343,10 @@ class FederationClient(FederationBase):
logger.warning(
"Failed to fetch missing state/auth events for %s: %s",
room_id,
failed_to_fetch
failed_to_fetch,
)
event_map = {
ev.event_id: ev for ev in fetched_events
}
event_map = {ev.event_id: ev for ev in fetched_events}
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [
@ -362,15 +363,14 @@ class FederationClient(FederationBase):
raise e
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
destination, room_id, event_id=event_id
)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [
event_from_pdu_json(p, format_ver, outlier=True)
for p in result["pdus"]
event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"]
]
auth_chain = [
@ -378,9 +378,9 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
seen_events = yield self.store.get_events([
ev.event_id for ev in itertools.chain(pdus, auth_chain)
])
seen_events = yield self.store.get_events(
[ev.event_id for ev in itertools.chain(pdus, auth_chain)]
)
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination,
@ -442,7 +442,7 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
batch = set(missing_events[i : i + batch_size])
deferreds = [
run_in_background(
@ -470,21 +470,17 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
res = yield self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(p, format_ver, outlier=True)
for p in res["auth_chain"]
event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain,
outlier=True, room_version=room_version,
destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@ -527,28 +523,26 @@ class FederationClient(FederationBase):
res = yield callback(destination)
defer.returnValue(res)
except InvalidResponseError as e:
logger.warn(
"Failed to %s via %s: %s",
description, destination, e,
)
logger.warn("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
logger.warn(
"Failed to %s via %s: %i %s",
description, destination, e.code, e.args[0],
description,
destination,
e.code,
e.args[0],
)
except Exception:
logger.warn(
"Failed to %s via %s",
description, destination, exc_info=1,
)
logger.warn("Failed to %s via %s", description, destination, exc_info=1)
raise RuntimeError("Failed to %s via any server" % (description, ))
raise RuntimeError("Failed to %s via any server" % (description,))
def make_membership_event(self, destinations, room_id, user_id, membership,
content, params):
def make_membership_event(
self, destinations, room_id, user_id, membership, content, params
):
"""
Creates an m.room.member event, with context, without participating in the room.
@ -584,14 +578,14 @@ class FederationClient(FederationBase):
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
"make_membership_event called with membership='%s', must be one of %s"
% (membership, ",".join(valid_memberships))
)
@defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership, params,
destination, room_id, user_id, membership, params
)
# Note: If not supplied, the room version may be either v1 or v2,
@ -614,16 +608,17 @@ class FederationClient(FederationBase):
pdu_dict["prev_state"] = []
ev = builder.create_local_event_from_event_dict(
self._clock, self.hostname, self.signing_key,
format_version=event_format, event_dict=pdu_dict,
self._clock,
self.hostname,
self.signing_key,
format_version=event_format,
event_dict=pdu_dict,
)
defer.returnValue(
(destination, ev, event_format)
)
defer.returnValue((destination, ev, event_format))
return self._try_destination_list(
"make_" + membership, destinations, send_request,
"make_" + membership, destinations, send_request
)
def send_join(self, destinations, pdu, event_format_version):
@ -655,9 +650,7 @@ class FederationClient(FederationBase):
create_event = e
break
else:
raise InvalidResponseError(
"no %s in auth chain" % (EventTypes.Create,),
)
raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,))
# the room version should be sane.
room_version = create_event.content.get("room_version", "1")
@ -665,9 +658,8 @@ class FederationClient(FederationBase):
# This shouldn't be possible, because the remote server should have
# rejected the join attempt during make_join.
raise InvalidResponseError(
"room appears to have unsupported version %s" % (
room_version,
))
"room appears to have unsupported version %s" % (room_version,)
)
@defer.inlineCallbacks
def send_request(destination):
@ -691,10 +683,7 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
room_version = None
for e in state:
@ -710,15 +699,13 @@ class FederationClient(FederationBase):
raise SynapseError(400, "No create event in state")
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
destination,
list(pdus.values()),
outlier=True,
room_version=room_version,
)
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
valid_pdus_map = {p.event_id: p for p in valid_pdus}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
@ -741,11 +728,14 @@ class FederationClient(FederationBase):
check_authchain_validity(signed_auth)
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
defer.returnValue(
{
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
}
)
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
@ -854,6 +844,7 @@ class FederationClient(FederationBase):
Fails with a ``RuntimeError`` if no servers were reachable.
"""
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
@ -869,14 +860,23 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_leave", destinations, send_request)
def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
third_party_instance_id=None):
def get_public_rooms(
self,
destination,
limit=None,
since_token=None,
search_filter=None,
include_all_networks=False,
third_party_instance_id=None,
):
if destination == self.server_name:
return
return self.transport_layer.get_public_rooms(
destination, limit, since_token, search_filter,
destination,
limit,
since_token,
search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
@ -891,9 +891,7 @@ class FederationClient(FederationBase):
"""
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
}
send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]}
code, content = yield self.transport_layer.send_query_auth(
destination=destination,
@ -905,13 +903,10 @@ class FederationClient(FederationBase):
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(e, format_ver)
for e in content["auth_chain"]
]
auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version,
destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@ -925,8 +920,16 @@ class FederationClient(FederationBase):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_missing_events(self, destination, room_id, earliest_events_ids,
latest_events, limit, min_depth, timeout):
def get_missing_events(
self,
destination,
room_id,
earliest_events_ids,
latest_events,
limit,
min_depth,
timeout,
):
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
@ -957,12 +960,11 @@ class FederationClient(FederationBase):
format_ver = room_version_to_event_format(room_version)
events = [
event_from_pdu_json(e, format_ver)
for e in content.get("events", [])
event_from_pdu_json(e, format_ver) for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version,
destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:
@ -982,17 +984,14 @@ class FederationClient(FederationBase):
try:
yield self.transport_layer.exchange_third_party_invite(
destination=destination,
room_id=room_id,
event_dict=event_dict,
destination=destination, room_id=room_id, event_dict=event_dict
)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_third_party_invite via %s: %s",
destination, str(e)
"Failed to send_third_party_invite via %s: %s", destination, str(e)
)
raise RuntimeError("Failed to send to any server.")

View file

@ -69,7 +69,6 @@ received_queries_counter = Counter(
class FederationServer(FederationBase):
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
@ -118,11 +117,13 @@ class FederationServer(FederationBase):
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (yield self._transaction_linearizer.queue(
(origin, transaction.transaction_id),
)):
with (
yield self._transaction_linearizer.queue(
(origin, transaction.transaction_id)
)
):
result = yield self._handle_incoming_transaction(
origin, transaction, request_time,
origin, transaction, request_time
)
defer.returnValue(result)
@ -144,7 +145,7 @@ class FederationServer(FederationBase):
if response:
logger.debug(
"[%s] We've already responded to this request",
transaction.transaction_id
transaction.transaction_id,
)
defer.returnValue(response)
return
@ -152,18 +153,15 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
# Reject if PDU count > 50 and EDU count > 100
if (len(transaction.pdus) > 50
or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
if len(transaction.pdus) > 50 or (
hasattr(transaction, "edus") and len(transaction.edus) > 100
):
logger.info(
"Transaction PDU or EDU count too large. Returning 400",
)
logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
yield self.transaction_actions.set_response(
origin,
transaction,
400, response
origin, transaction, 400, response
)
defer.returnValue((400, response))
@ -230,9 +228,7 @@ class FederationServer(FederationBase):
try:
yield self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
logger.warn(
"Ignoring PDUs for room %s from banned server", room_id,
)
logger.warn("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
@ -242,9 +238,7 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
with nested_logging_context(event_id):
try:
yield self._handle_received_pdu(
origin, pdu
)
yield self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
logger.warn("Error handling PDU %s: %s", event_id, e)
@ -259,29 +253,18 @@ class FederationServer(FederationBase):
)
yield concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(),
TRANSACTION_CONCURRENCY_LIMIT,
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
yield self.received_edu(
origin,
edu.edu_type,
edu.content
)
yield self.received_edu(origin, edu.edu_type, edu.content)
response = {
"pdus": pdu_results,
}
response = {"pdus": pdu_results}
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response(
origin,
transaction,
200, response
)
yield self.transaction_actions.set_response(origin, transaction, 200, response)
defer.returnValue((200, response))
@defer.inlineCallbacks
@ -311,7 +294,8 @@ class FederationServer(FederationBase):
resp = yield self._state_resp_cache.wrap(
(room_id, event_id),
self._on_context_state_request_compute,
room_id, event_id,
room_id,
event_id,
)
defer.returnValue((200, resp))
@ -328,24 +312,17 @@ class FederationServer(FederationBase):
if not in_room:
raise AuthError(403, "Host not in room.")
state_ids = yield self.handler.get_state_ids_for_pdu(
room_id, event_id,
)
state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
defer.returnValue((200, {
"pdu_ids": state_ids,
"auth_chain_ids": auth_chain_ids,
}))
defer.returnValue(
(200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
)
@defer.inlineCallbacks
def _on_context_state_request_compute(self, room_id, event_id):
pdus = yield self.handler.get_state_for_pdu(
room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
for event in auth_chain:
# We sign these again because there was a bug where we
@ -355,14 +332,16 @@ class FederationServer(FederationBase):
compute_event_signature(
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0]
self.hs.config.signing_key[0],
)
)
defer.returnValue({
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
})
defer.returnValue(
{
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}
)
@defer.inlineCallbacks
@log_function
@ -370,9 +349,7 @@ class FederationServer(FederationBase):
pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict()))
else:
defer.returnValue((404, ""))
@ -394,10 +371,9 @@ class FederationServer(FederationBase):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({
"event": pdu.get_pdu_json(time_now),
"room_version": room_version,
})
defer.returnValue(
{"event": pdu.get_pdu_json(time_now), "room_version": room_version}
)
@defer.inlineCallbacks
def on_invite_request(self, origin, content, room_version):
@ -431,12 +407,17 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
defer.returnValue(
(
200,
{
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
},
)
)
@defer.inlineCallbacks
def on_make_leave_request(self, origin, room_id, user_id):
@ -447,10 +428,9 @@ class FederationServer(FederationBase):
room_version = yield self.store.get_room_version(room_id)
time_now = self._clock.time_msec()
defer.returnValue({
"event": pdu.get_pdu_json(time_now),
"room_version": room_version,
})
defer.returnValue(
{"event": pdu.get_pdu_json(time_now), "room_version": room_version}
)
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content, room_id):
@ -475,9 +455,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
res = {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
defer.returnValue((200, res))
@defer.inlineCallbacks
@ -508,12 +486,11 @@ class FederationServer(FederationBase):
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(e, format_ver)
for e in content["auth_chain"]
event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True, room_version=room_version,
origin, auth_chain, outlier=True, room_version=room_version
)
ret = yield self.handler.on_query_auth(
@ -527,17 +504,12 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [
e.get_pdu_json(time_now)
for e in ret["auth_chain"]
],
"auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
"rejects": ret.get("rejects", []),
"missing": ret.get("missing", []),
}
defer.returnValue(
(200, send_content)
)
defer.returnValue((200, send_content))
@log_function
def on_query_client_keys(self, origin, content):
@ -566,20 +538,23 @@ class FederationServer(FederationBase):
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result)
for device_id, device_keys in iteritems(user_keys)
for key_id, _ in iteritems(device_keys)
)),
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result)
for device_id, device_keys in iteritems(user_keys)
for key_id, _ in iteritems(device_keys)
)
),
)
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit):
def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit
):
with (yield self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
@ -587,11 +562,13 @@ class FederationServer(FederationBase):
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d",
earliest_events, latest_events, limit,
earliest_events,
latest_events,
limit,
)
missing_events = yield self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit,
origin, room_id, earliest_events, latest_events, limit
)
if len(missing_events) < 5:
@ -603,9 +580,9 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
})
defer.returnValue(
{"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
)
@log_function
def on_openid_userinfo(self, token):
@ -666,22 +643,17 @@ class FederationServer(FederationBase):
# origin. See bug #1893. This is also true for some third party
# invites).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) in (
Membership.JOIN, Membership.INVITE,
)
pdu.type == "m.room.member"
and pdu.content
and pdu.content.get("membership", None)
in (Membership.JOIN, Membership.INVITE)
):
logger.info(
"Discarding PDU %s from invalid origin %s",
pdu.event_id, origin
"Discarding PDU %s from invalid origin %s", pdu.event_id, origin
)
return
else:
logger.info(
"Accepting join PDU %s from %s",
pdu.event_id, origin
)
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
room_version = yield self.store.get_room_version(pdu.room_id)
@ -690,33 +662,19 @@ class FederationServer(FederationBase):
try:
pdu = yield self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=pdu.event_id,
)
raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
yield self.handler.on_receive_pdu(
origin, pdu, sent_to_us_directly=True,
)
yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
@defer.inlineCallbacks
def exchange_third_party_invite(
self,
sender_user_id,
target_user_id,
room_id,
signed,
self, sender_user_id, target_user_id, room_id, signed
):
ret = yield self.handler.exchange_third_party_invite(
sender_user_id,
target_user_id,
room_id,
signed,
sender_user_id, target_user_id, room_id, signed
)
defer.returnValue(ret)
@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event):
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
if server_name[0] == '[':
if server_name[0] == "[":
return False
# check for ipv4 literals. We can just lift the routine from twisted.
@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event):
def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
logger.warn(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
return regex.match(server_name)
@ -815,6 +775,7 @@ class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
@ -848,9 +809,7 @@ class FederationHandlerRegistry(object):
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
raise KeyError("Already have a Query handler for %s" % (query_type,))
logger.info("Registering federation query handler for %r", query_type)
@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
handler = self.edu_handlers.get(edu_type)
if handler:
return super(ReplicationFederationHandlerRegistry, self).on_edu(
edu_type, origin, content,
edu_type, origin, content
)
return self._send_edu(
edu_type=edu_type,
origin=origin,
content=content,
)
return self._send_edu(edu_type=edu_type, origin=origin, content=content)
def on_query(self, query_type, args):
"""Overrides FederationHandlerRegistry
@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
if handler:
return handler(args)
return self._get_query_client(
query_type=query_type,
args=args,
)
return self._get_query_client(query_type=query_type, args=args)

View file

@ -46,12 +46,9 @@ class TransactionActions(object):
response code and response body.
"""
if not transaction.transaction_id:
raise RuntimeError("Cannot persist a transaction with no "
"transaction_id")
raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
return self.store.get_received_txn_response(
transaction.transaction_id, origin
)
return self.store.get_received_txn_response(transaction.transaction_id, origin)
@log_function
def set_response(self, origin, transaction, code, response):
@ -61,14 +58,10 @@ class TransactionActions(object):
Deferred
"""
if not transaction.transaction_id:
raise RuntimeError("Cannot persist a transaction with no "
"transaction_id")
raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
return self.store.set_received_txn_response(
transaction.transaction_id,
origin,
code,
response,
transaction.transaction_id, origin, code, response
)
@defer.inlineCallbacks

View file

@ -77,12 +77,22 @@ class FederationRemoteSendQueue(object):
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
"", [], lambda: len(queue))
LaterGauge(
"synapse_federation_send_queue_%s_size" % (queue_name,),
"",
[],
lambda: len(queue),
)
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
"edus", "device_messages", "pos_time", "presence_destinations",
"presence_map",
"presence_changed",
"keyed_edu",
"keyed_edu_changed",
"edus",
"device_messages",
"pos_time",
"presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
@ -121,9 +131,7 @@ class FederationRemoteSendQueue(object):
del self.presence_changed[key]
user_ids = set(
user_id
for uids in self.presence_changed.values()
for user_id in uids
user_id for uids in self.presence_changed.values() for user_id in uids
)
keys = self.presence_destinations.keys()
@ -285,19 +293,21 @@ class FederationRemoteSendQueue(object):
]
for (key, user_id) in dest_user_ids:
rows.append((key, PresenceRow(
state=self.presence_map[user_id],
)))
rows.append((key, PresenceRow(state=self.presence_map[user_id])))
# Fetch presence to send to destinations
i = self.presence_destinations.bisect_right(from_token)
j = self.presence_destinations.bisect_right(to_token) + 1
for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
rows.append((pos, PresenceDestinationsRow(
state=self.presence_map[user_id],
destinations=list(dests),
)))
rows.append(
(
pos,
PresenceDestinationsRow(
state=self.presence_map[user_id], destinations=list(dests)
),
)
)
# Fetch changes keyed edus
i = self.keyed_edu_changed.bisect_right(from_token)
@ -308,10 +318,14 @@ class FederationRemoteSendQueue(object):
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
for ((destination, edu_key), pos) in iteritems(keyed_edus):
rows.append((pos, KeyedEduRow(
key=edu_key,
edu=self.keyed_edu[(destination, edu_key)],
)))
rows.append(
(
pos,
KeyedEduRow(
key=edu_key, edu=self.keyed_edu[(destination, edu_key)]
),
)
)
# Fetch changed edus
i = self.edus.bisect_right(from_token)
@ -327,9 +341,7 @@ class FederationRemoteSendQueue(object):
device_messages = {v: k for k, v in self.device_messages.items()[i:j]}
for (destination, pos) in iteritems(device_messages):
rows.append((pos, DeviceRow(
destination=destination,
)))
rows.append((pos, DeviceRow(destination=destination)))
# Sort rows based on pos
rows.sort()
@ -377,16 +389,14 @@ class BaseFederationRow(object):
raise NotImplementedError()
class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
"state", # UserPresenceState
))):
class PresenceRow(
BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState
):
TypeId = "p"
@staticmethod
def from_data(data):
return PresenceRow(
state=UserPresenceState.from_dict(data)
)
return PresenceRow(state=UserPresenceState.from_dict(data))
def to_data(self):
return self.state.as_dict()
@ -395,33 +405,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
buff.presence.append(self.state)
class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
"state", # UserPresenceState
"destinations", # list[str]
))):
class PresenceDestinationsRow(
BaseFederationRow,
namedtuple(
"PresenceDestinationsRow",
("state", "destinations"), # UserPresenceState # list[str]
),
):
TypeId = "pd"
@staticmethod
def from_data(data):
return PresenceDestinationsRow(
state=UserPresenceState.from_dict(data["state"]),
destinations=data["dests"],
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
)
def to_data(self):
return {
"state": self.state.as_dict(),
"dests": self.destinations,
}
return {"state": self.state.as_dict(), "dests": self.destinations}
def add_to_buffer(self, buff):
buff.presence_destinations.append((self.state, self.destinations))
class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
"key", # tuple(str) - the edu key passed to send_edu
"edu", # Edu
))):
class KeyedEduRow(
BaseFederationRow,
namedtuple(
"KeyedEduRow",
("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu
),
):
"""Streams EDUs that have an associated key that is ued to clobber. For example,
typing EDUs clobber based on room_id.
"""
@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
@staticmethod
def from_data(data):
return KeyedEduRow(
key=tuple(data["key"]),
edu=Edu(**data["edu"]),
)
return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
def to_data(self):
return {
"key": self.key,
"edu": self.edu.get_internal_dict(),
}
return {"key": self.key, "edu": self.edu.get_internal_dict()}
def add_to_buffer(self, buff):
buff.keyed_edus.setdefault(
self.edu.destination, {}
)[self.key] = self.edu
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
class EduRow(BaseFederationRow, namedtuple("EduRow", (
"edu", # Edu
))):
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
"""Streams EDUs that don't have keys. See KeyedEduRow
"""
TypeId = "e"
@staticmethod
@ -465,13 +468,12 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
"destination", # str
))):
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))): # str
"""Streams the fact that either a) there is pending to device messages for
users on the remote, or b) a local users device has changed and needs to
be sent to the remote.
"""
TypeId = "d"
@staticmethod
@ -487,23 +489,20 @@ class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
TypeToRow = {
Row.TypeId: Row
for Row in (
PresenceRow,
PresenceDestinationsRow,
KeyedEduRow,
EduRow,
DeviceRow,
)
for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow)
}
ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState)
"presence_destinations", # list of tuples of UserPresenceState and destinations
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
"device_destinations", # set of destinations
))
ParsedFederationStreamData = namedtuple(
"ParsedFederationStreamData",
(
"presence", # list(UserPresenceState)
"presence_destinations", # list of tuples of UserPresenceState and destinations
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
"device_destinations", # set of destinations
),
)
def process_rows_for_federation(transaction_queue, rows):
@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows):
for state, destinations in buff.presence_destinations:
transaction_queue.send_presence_to_destinations(
states=[state], destinations=destinations,
states=[state], destinations=destinations
)
for destination, edu_map in iteritems(buff.keyed_edus):

View file

@ -44,8 +44,8 @@ sent_pdus_destination_dist_count = Counter(
)
sent_pdus_destination_dist_total = Counter(
"synapse_federation_client_sent_pdu_destinations:total", ""
"Total number of PDUs queued for sending across all destinations",
"synapse_federation_client_sent_pdu_destinations:total",
"" "Total number of PDUs queued for sending across all destinations",
)
@ -63,14 +63,15 @@ class FederationSender(object):
self._transaction_manager = TransactionManager(hs)
# map from destination to PerDestinationQueue
self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
LaterGauge(
"synapse_federation_transaction_queue_pending_destinations",
"",
[],
lambda: sum(
1 for d in self._per_destination_queues.values()
1
for d in self._per_destination_queues.values()
if d.transmission_loop_running
),
)
@ -108,8 +109,9 @@ class FederationSender(object):
# awaiting a call to flush_read_receipts_for_room. The presence of an entry
# here for a given room means that we are rate-limiting RR flushes to that room,
# and that there is a pending call to _flush_rrs_for_room in the system.
self._queues_awaiting_rr_flush_by_room = {
} # type: dict[str, set[PerDestinationQueue]]
self._queues_awaiting_rr_flush_by_room = (
{}
) # type: dict[str, set[PerDestinationQueue]]
self._rr_txn_interval_per_room_ms = (
1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second
@ -141,8 +143,7 @@ class FederationSender(object):
# fire off a processing loop in the background
run_as_background_process(
"process_event_queue_for_federation",
self._process_event_queue_loop,
"process_event_queue_for_federation", self._process_event_queue_loop
)
@defer.inlineCallbacks
@ -152,7 +153,7 @@ class FederationSender(object):
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
last_token, self._last_poked_id, limit=100,
last_token, self._last_poked_id, limit=100
)
logger.debug("Handling %s -> %s", last_token, next_token)
@ -168,6 +169,9 @@ class FederationSender(object):
if not is_mine and send_on_behalf_of is None:
return
if not event.internal_metadata.should_proactively_send():
return
try:
# Get the state from before the event.
# We need to make sure that this is the state from before
@ -176,7 +180,7 @@ class FederationSender(object):
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = yield self.state.get_current_hosts_in_room(
event.room_id, latest_event_ids=event.prev_event_ids(),
event.room_id, latest_event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
@ -206,37 +210,40 @@ class FederationSender(object):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
logcontext.run_in_background(handle_room_events, evs)
for evs in itervalues(events_by_room)
],
consumeErrors=True
))
yield self.store.update_federation_out_pos(
"events", next_token
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
logcontext.run_in_background(handle_room_events, evs)
for evs in itervalues(events_by_room)
],
consumeErrors=True,
)
)
yield self.store.update_federation_out_pos("events", next_token)
if events:
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_lag.labels(
"federation_sender").set(now - ts)
"federation_sender"
).set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
"federation_sender").set(ts)
"federation_sender"
).set(ts)
events_processed_counter.inc(len(events))
event_processing_loop_room_count.labels(
"federation_sender"
).inc(len(events_by_room))
event_processing_loop_room_count.labels("federation_sender").inc(
len(events_by_room)
)
event_processing_loop_counter.labels("federation_sender").inc()
synapse.metrics.event_processing_positions.labels(
"federation_sender").set(next_token)
"federation_sender"
).set(next_token)
finally:
self._is_processing = False
@ -309,9 +316,7 @@ class FederationSender(object):
if not domains:
return
queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(
room_id
)
queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id)
# if there is no flush yet scheduled, we will send out these receipts with
# immediate flushes, and schedule the next flush for this room.
@ -374,10 +379,9 @@ class FederationSender(object):
# updates in quick succession are correctly handled.
# We only want to send presence for our own users, so lets always just
# filter here just in case.
self.pending_presence.update({
state.user_id: state for state in states
if self.is_mine_id(state.user_id)
})
self.pending_presence.update(
{state.user_id: state for state in states if self.is_mine_id(state.user_id)}
)
# We then handle the new pending presence in batches, first figuring
# out the destinations we need to send each state to and then poking it

View file

@ -189,11 +189,21 @@ class PerDestinationQueue(object):
pending_pdus = []
while True:
device_message_edus, device_stream_id, dev_list_id = (
# We have to keep 2 free slots for presence and rr_edus
yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2)
# We have to keep 2 free slots for presence and rr_edus
limit = MAX_EDUS_PER_TRANSACTION - 2
device_update_edus, dev_list_id = (
yield self._get_device_update_edus(limit)
)
limit -= len(device_update_edus)
to_device_edus, device_stream_id = (
yield self._get_to_device_message_edus(limit)
)
pending_edus = device_update_edus + to_device_edus
# BEGIN CRITICAL SECTION
#
# In order to avoid a race condition, we need to make sure that
@ -208,10 +218,6 @@ class PerDestinationQueue(object):
# We can only include at most 50 PDUs per transactions
pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
pending_edus = []
# We can only include at most 100 EDUs per transactions
# rr_edus and pending_presence take at most one slot each
pending_edus.extend(self._get_rr_edus(force_flush=False))
pending_presence = self._pending_presence
self._pending_presence = {}
@ -232,7 +238,6 @@ class PerDestinationQueue(object):
)
)
pending_edus.extend(device_message_edus)
pending_edus.extend(
self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
)
@ -272,10 +277,13 @@ class PerDestinationQueue(object):
sent_edus_by_type.labels(edu.edu_type).inc()
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if device_message_edus:
if to_device_edus:
yield self._store.delete_device_msgs_for_remote(
self._destination, device_stream_id
)
# also mark the device updates as sent
if device_update_edus:
logger.info(
"Marking as sent %r %r", self._destination, dev_list_id
)
@ -347,12 +355,12 @@ class PerDestinationQueue(object):
return pending_edus
@defer.inlineCallbacks
def _get_new_device_messages(self, limit):
def _get_device_update_edus(self, limit):
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote(
self._destination, last_device_list, limit=limit,
self._destination, last_device_list, limit=limit
)
edus = [
Edu(
@ -366,15 +374,16 @@ class PerDestinationQueue(object):
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
defer.returnValue((edus, now_stream_id))
@defer.inlineCallbacks
def _get_to_device_message_edus(self, limit):
last_device_stream_id = self._last_device_stream_id
to_device_stream_id = self._store.get_to_device_stream_token()
contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
self._destination,
last_device_stream_id,
to_device_stream_id,
limit - len(edus),
self._destination, last_device_stream_id, to_device_stream_id, limit
)
edus.extend(
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
@ -382,6 +391,6 @@ class PerDestinationQueue(object):
content=content,
)
for content in contents
)
]
defer.returnValue((edus, stream_id, now_stream_id))
defer.returnValue((edus, stream_id))

View file

@ -29,9 +29,10 @@ class TransactionManager(object):
shared between PerDestinationQueue objects
"""
def __init__(self, hs):
self._server_name = hs.hostname
self.clock = hs.get_clock() # nb must be called this for @measure_func
self.clock = hs.get_clock() # nb must be called this for @measure_func
self._store = hs.get_datastore()
self._transaction_actions = TransactionActions(self._store)
self._transport_layer = hs.get_federation_transport_client()
@ -55,9 +56,9 @@ class TransactionManager(object):
txn_id = str(self._next_txn_id)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d)",
destination, txn_id,
"TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
destination,
txn_id,
len(pdus),
len(edus),
)
@ -79,9 +80,9 @@ class TransactionManager(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d)",
destination, txn_id,
"TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
destination,
txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
@ -112,20 +113,12 @@ class TransactionManager(object):
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
raise e
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
yield self._transaction_actions.delivered(
transaction, code, response
)
yield self._transaction_actions.delivered(transaction, code, response)
logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id)
@ -134,13 +127,18 @@ class TransactionManager(object):
if "error" in r:
logger.warn(
"TX [%s] {%s} Remote returned error for %s: %s",
destination, txn_id, e_id, r,
destination,
txn_id,
e_id,
r,
)
else:
for p in pdus:
logger.warn(
"TX [%s] {%s} Failed to send event %s",
destination, txn_id, p.event_id,
destination,
txn_id,
p.event_id,
)
success = False

View file

@ -48,12 +48,13 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
logger.debug("get_room_state dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state/%s", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
destination,
path=path,
args={"event_id": event_id},
try_trailing_slash_on_400=True,
)
@ -71,12 +72,13 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state_ids/%s", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
destination,
path=path,
args={"event_id": event_id},
try_trailing_slash_on_400=True,
)
@ -94,13 +96,11 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
path = _create_v1_path("/event/%s", event_id)
return self.client.get_json(
destination, path=path, timeout=timeout,
try_trailing_slash_on_400=True,
destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
)
@log_function
@ -119,7 +119,10 @@ class TransportLayerClient(object):
"""
logger.debug(
"backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s",
destination, room_id, repr(event_tuples), str(limit)
destination,
room_id,
repr(event_tuples),
str(limit),
)
if not event_tuples:
@ -128,16 +131,10 @@ class TransportLayerClient(object):
path = _create_v1_path("/backfill/%s", room_id)
args = {
"v": event_tuples,
"limit": [str(limit)],
}
args = {"v": event_tuples, "limit": [str(limit)]}
return self.client.get_json(
destination,
path=path,
args=args,
try_trailing_slash_on_400=True,
destination, path=path, args=args, try_trailing_slash_on_400=True
)
@defer.inlineCallbacks
@ -163,7 +160,8 @@ class TransportLayerClient(object):
"""
logger.debug(
"send_data dest=%s, txid=%s",
transaction.destination, transaction.transaction_id
transaction.destination,
transaction.transaction_id,
)
if transaction.destination == self.server_name:
@ -189,8 +187,9 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
def make_query(
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
):
path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json(
@ -235,8 +234,8 @@ class TransportLayerClient(object):
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
"make_membership_event called with membership='%s', must be one of %s"
% (membership, ",".join(valid_memberships))
)
path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
@ -268,9 +267,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
destination=destination, path=path, data=content
)
defer.returnValue(response)
@ -284,7 +281,6 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=content,
# we want to do our best to send this through. The problem is
# that if it fails, we won't retry it later, so if the remote
# server was just having a momentary blip, the room will be out of
@ -300,10 +296,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
defer.returnValue(response)
@ -314,26 +307,27 @@ class TransportLayerClient(object):
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None, include_all_networks=False,
third_party_instance_id=None):
def get_public_rooms(
self,
remote_server,
limit,
since_token,
search_filter=None,
include_all_networks=False,
third_party_instance_id=None,
):
path = _create_v1_path("/publicRooms")
args = {
"include_all_networks": "true" if include_all_networks else "false",
}
args = {"include_all_networks": "true" if include_all_networks else "false"}
if third_party_instance_id:
args["third_party_instance_id"] = third_party_instance_id,
args["third_party_instance_id"] = (third_party_instance_id,)
if limit:
args["limit"] = [str(limit)]
if since_token:
@ -342,10 +336,7 @@ class TransportLayerClient(object):
# TODO(erikj): Actually send the search_filter across federation.
response = yield self.client.get_json(
destination=remote_server,
path=path,
args=args,
ignore_backoff=True,
destination=remote_server, path=path, args=args, ignore_backoff=True
)
defer.returnValue(response)
@ -353,12 +344,10 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=event_dict,
destination=destination, path=path, data=event_dict
)
defer.returnValue(response)
@ -368,10 +357,7 @@ class TransportLayerClient(object):
def get_event_auth(self, destination, room_id, event_id):
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
path=path,
)
content = yield self.client.get_json(destination=destination, path=path)
defer.returnValue(content)
@ -381,9 +367,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
path=path,
data=content,
destination=destination, path=path, data=content
)
defer.returnValue(content)
@ -416,10 +400,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
timeout=timeout,
destination=destination, path=path, data=query_content, timeout=timeout
)
defer.returnValue(content)
@ -443,9 +424,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
path=path,
timeout=timeout,
destination=destination, path=path, timeout=timeout
)
defer.returnValue(content)
@ -479,18 +458,23 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
timeout=timeout,
destination=destination, path=path, data=query_content, timeout=timeout
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
path = _create_v1_path("/get_missing_events/%s", room_id,)
def get_missing_events(
self,
destination,
room_id,
earliest_events,
latest_events,
limit,
min_depth,
timeout,
):
path = _create_v1_path("/get_missing_events/%s", room_id)
content = yield self.client.post_json(
destination=destination,
@ -510,7 +494,7 @@ class TransportLayerClient(object):
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
path = _create_v1_path("/groups/%s/profile", group_id,)
path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.get_json(
destination=destination,
@ -529,7 +513,7 @@ class TransportLayerClient(object):
requester_user_id (str)
content (dict): The new profile of the group
"""
path = _create_v1_path("/groups/%s/profile", group_id,)
path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.post_json(
destination=destination,
@ -543,7 +527,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
path = _create_v1_path("/groups/%s/summary", group_id,)
path = _create_v1_path("/groups/%s/summary", group_id)
return self.client.get_json(
destination=destination,
@ -556,7 +540,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
path = _create_v1_path("/groups/%s/rooms", group_id,)
path = _create_v1_path("/groups/%s/rooms", group_id)
return self.client.get_json(
destination=destination,
@ -565,11 +549,12 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
content):
def add_room_to_group(
self, destination, group_id, requester_user_id, room_id, content
):
"""Add a room to a group
"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.post_json(
destination=destination,
@ -579,13 +564,13 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
config_key, content):
def update_room_in_group(
self, destination, group_id, requester_user_id, room_id, config_key, content
):
"""Update room in group
"""
path = _create_v1_path(
"/groups/%s/room/%s/config/%s",
group_id, room_id, config_key,
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key
)
return self.client.post_json(
@ -599,7 +584,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.delete_json(
destination=destination,
@ -612,7 +597,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
path = _create_v1_path("/groups/%s/users", group_id,)
path = _create_v1_path("/groups/%s/users", group_id)
return self.client.get_json(
destination=destination,
@ -625,7 +610,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
path = _create_v1_path("/groups/%s/invited_users", group_id,)
path = _create_v1_path("/groups/%s/invited_users", group_id)
return self.client.get_json(
destination=destination,
@ -638,16 +623,10 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
path = _create_v1_path(
"/groups/%s/users/%s/accept_invite",
group_id, user_id,
)
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
@ -657,14 +636,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
def invite_to_group(
self, destination, group_id, user_id, requester_user_id, content
):
"""Invite a user to a group
"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
@ -686,15 +664,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
def remove_user_from_group(self, destination, group_id, requester_user_id,
user_id, content):
def remove_user_from_group(
self, destination, group_id, requester_user_id, user_id, content
):
"""Remove a user fron a group
"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
@ -708,8 +684,9 @@ class TransportLayerClient(object):
)
@log_function
def remove_user_from_group_notification(self, destination, group_id, user_id,
content):
def remove_user_from_group_notification(
self, destination, group_id, user_id, content
):
"""Sent by group server to inform a user's server that they have been
kicked from the group.
"""
@ -717,10 +694,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
@ -732,24 +706,24 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
def update_group_summary_room(self, destination, group_id, user_id, room_id,
category_id, content):
def update_group_summary_room(
self, destination, group_id, user_id, room_id, category_id, content
):
"""Update a room entry in a group summary
"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
group_id,
category_id,
room_id,
)
else:
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
return self.client.post_json(
destination=destination,
@ -760,17 +734,20 @@ class TransportLayerClient(object):
)
@log_function
def delete_group_summary_room(self, destination, group_id, user_id, room_id,
category_id):
def delete_group_summary_room(
self, destination, group_id, user_id, room_id, category_id
):
"""Delete a room entry in a group summary
"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
group_id,
category_id,
room_id,
)
else:
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
return self.client.delete_json(
destination=destination,
@ -783,7 +760,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
path = _create_v1_path("/groups/%s/categories", group_id,)
path = _create_v1_path("/groups/%s/categories", group_id)
return self.client.get_json(
destination=destination,
@ -796,7 +773,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.get_json(
destination=destination,
@ -806,11 +783,12 @@ class TransportLayerClient(object):
)
@log_function
def update_group_category(self, destination, group_id, requester_user_id, category_id,
content):
def update_group_category(
self, destination, group_id, requester_user_id, category_id, content
):
"""Update a category in a group
"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.post_json(
destination=destination,
@ -821,11 +799,12 @@ class TransportLayerClient(object):
)
@log_function
def delete_group_category(self, destination, group_id, requester_user_id,
category_id):
def delete_group_category(
self, destination, group_id, requester_user_id, category_id
):
"""Delete a category in a group
"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.delete_json(
destination=destination,
@ -838,7 +817,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
path = _create_v1_path("/groups/%s/roles", group_id,)
path = _create_v1_path("/groups/%s/roles", group_id)
return self.client.get_json(
destination=destination,
@ -851,7 +830,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.get_json(
destination=destination,
@ -861,11 +840,12 @@ class TransportLayerClient(object):
)
@log_function
def update_group_role(self, destination, group_id, requester_user_id, role_id,
content):
def update_group_role(
self, destination, group_id, requester_user_id, role_id, content
):
"""Update a role in a group
"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.post_json(
destination=destination,
@ -879,7 +859,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.delete_json(
destination=destination,
@ -889,17 +869,17 @@ class TransportLayerClient(object):
)
@log_function
def update_group_summary_user(self, destination, group_id, requester_user_id,
user_id, role_id, content):
def update_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id, content
):
"""Update a users entry in a group
"""
if role_id:
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
)
else:
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
@ -910,11 +890,10 @@ class TransportLayerClient(object):
)
@log_function
def set_group_join_policy(self, destination, group_id, requester_user_id,
content):
def set_group_join_policy(self, destination, group_id, requester_user_id, content):
"""Sets the join policy for a group
"""
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
return self.client.put_json(
destination=destination,
@ -925,17 +904,17 @@ class TransportLayerClient(object):
)
@log_function
def delete_group_summary_user(self, destination, group_id, requester_user_id,
user_id, role_id):
def delete_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id
):
"""Delete a users entry in a group
"""
if role_id:
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
)
else:
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
return self.client.delete_json(
destination=destination,
@ -953,10 +932,7 @@ class TransportLayerClient(object):
content = {"user_ids": user_ids}
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
destination=destination, path=path, data=content, ignore_backoff=True
)
@ -975,9 +951,8 @@ def _create_v1_path(path, *args):
Returns:
str
"""
return (
FEDERATION_V1_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
return FEDERATION_V1_PREFIX + path % tuple(
urllib.parse.quote(arg, "") for arg in args
)
@ -996,7 +971,6 @@ def _create_v2_path(path, *args):
Returns:
str
"""
return (
FEDERATION_V2_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
return FEDERATION_V2_PREFIX + path % tuple(
urllib.parse.quote(arg, "") for arg in args
)

View file

@ -66,8 +66,7 @@ class TransportLayerServer(JsonResource):
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
config=hs.config.rc_federation,
self.clock, config=hs.config.rc_federation
)
self.register_servlets()
@ -84,11 +83,13 @@ class TransportLayerServer(JsonResource):
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
pass
@ -105,8 +106,8 @@ class Authenticator(object):
def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = {
"method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'),
"method": request.method.decode("ascii"),
"uri": request.uri.decode("ascii"),
"destination": self.server_name,
"signatures": {},
}
@ -120,7 +121,7 @@ class Authenticator(object):
if not auth_headers:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
for auth in auth_headers:
@ -130,14 +131,14 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if (
self.federation_domain_whitelist is not None and
origin not in self.federation_domain_whitelist
self.federation_domain_whitelist is not None
and origin not in self.federation_domain_whitelist
):
raise FederationDeniedError(origin)
if not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
yield self.keyring.verify_json_for_server(
@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes):
AuthenticationError if the header could not be parsed
"""
try:
header_str = header_bytes.decode('utf-8')
header_str = header_bytes.decode("utf-8")
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
if value.startswith('"'):
return value[1:-1]
else:
return value
@ -198,11 +199,11 @@ def _parse_auth_header(header_bytes):
except Exception as e:
logger.warn(
"Error parsing auth header '%s': %s",
header_bytes.decode('ascii', 'replace'),
header_bytes.decode("ascii", "replace"),
e,
)
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED,
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
@ -242,6 +243,7 @@ class BaseFederationServlet(object):
Exception: other exceptions will be caught, logged, and a 500 will be
returned.
"""
REQUIRE_AUTH = True
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
@ -293,9 +295,7 @@ class BaseFederationServlet(object):
origin, content, request.args, *args, **kwargs
)
else:
response = yield func(
origin, content, request.args, *args, **kwargs
)
response = yield func(origin, content, request.args, *args, **kwargs)
defer.returnValue(response)
@ -343,14 +343,12 @@ class FederationSendServlet(BaseFederationServlet):
try:
transaction_data = content
logger.debug(
"Decoded %s: %s",
transaction_id, str(transaction_data)
)
logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d)",
transaction_id, origin,
transaction_id,
origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
)
@ -361,8 +359,7 @@ class FederationSendServlet(BaseFederationServlet):
# Add some extra data to the transaction dict that isn't included
# in the request body.
transaction_data.update(
transaction_id=transaction_id,
destination=self.server_name
transaction_id=transaction_id, destination=self.server_name
)
except Exception as e:
@ -372,7 +369,7 @@ class FederationSendServlet(BaseFederationServlet):
try:
code, response = yield self.handler.on_incoming_transaction(
origin, transaction_data,
origin, transaction_data
)
except Exception:
logger.exception("on_incoming_transaction failed")
@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/?"
def on_GET(self, origin, content, query, context):
versions = [x.decode('ascii') for x in query[b"v"]]
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
if not limit:
@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet):
def on_GET(self, origin, content, query, query_type):
return self.handler.on_query_request(
query_type,
{k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
{k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
)
@ -456,15 +453,14 @@ class FederationMakeJoinServlet(BaseFederationServlet):
Deferred[(int, object)|None]: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
"""
versions = query.get(b'ver')
versions = query.get(b"ver")
if versions is not None:
supported_versions = [v.decode("utf-8") for v in versions]
else:
supported_versions = ["1"]
content = yield self.handler.on_make_join_request(
origin, context, user_id,
supported_versions=supported_versions,
origin, context, user_id, supported_versions=supported_versions
)
defer.returnValue((200, content))
@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_leave_request(
origin, context, user_id,
)
content = yield self.handler.on_make_leave_request(origin, context, user_id)
defer.returnValue((200, content))
@ -517,7 +511,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
# state resolution algorithm, and we don't use that for processing
# invites
content = yield self.handler.on_invite_request(
origin, content, room_version=RoomVersions.V1.identifier,
origin, content, room_version=RoomVersions.V1.identifier
)
# V1 federation API is defined to return a content of `[200, {...}]`
@ -545,7 +539,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
content = yield self.handler.on_invite_request(
origin, event, room_version=room_version,
origin, event, room_version=room_version
)
defer.returnValue((200, content))
@ -629,8 +623,10 @@ class On3pidBindServlet(BaseFederationServlet):
for invite in content["invites"]:
try:
if "signed" not in invite or "token" not in invite["signed"]:
message = ("Rejecting received notification of third-"
"party invite without signed: %s" % (invite,))
message = (
"Rejecting received notification of third-"
"party invite without signed: %s" % (invite,)
)
logger.info(message)
raise SynapseError(400, message)
yield self.handler.exchange_third_party_invite(
@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet):
def on_GET(self, origin, content, query):
token = query.get(b"access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
}))
defer.returnValue(
(401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"})
)
return
user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))
user_id = yield self.handler.on_openid_userinfo(token.decode("ascii"))
if user_id is None:
defer.returnValue((401, {
"errcode": "M_UNKNOWN_TOKEN",
"error": "Access Token unknown or expired"
}))
defer.returnValue(
(
401,
{
"errcode": "M_UNKNOWN_TOKEN",
"error": "Access Token unknown or expired",
},
)
)
defer.returnValue((200, {"sub": user_id}))
@ -720,15 +721,15 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access):
def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
super(PublicRoomList, self).__init__(
handler, authenticator, ratelimiter, server_name,
handler, authenticator, ratelimiter, server_name
)
self.deny_access = deny_access
self.allow_access = allow_access
@defer.inlineCallbacks
def on_GET(self, origin, content, query):
if self.deny_access:
if not self.allow_access:
raise FederationDeniedError(origin)
limit = parse_integer_from_args(query, "limit", 0)
@ -748,9 +749,7 @@ class PublicRoomList(BaseFederationServlet):
network_tuple = ThirdPartyInstanceID(None, None)
data = yield self.handler.get_local_public_room_list(
limit, since_token,
network_tuple=network_tuple,
from_federation=True,
limit, since_token, network_tuple=network_tuple, from_federation=True
)
defer.returnValue((200, data))
@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet):
REQUIRE_AUTH = False
def on_GET(self, origin, content, query):
return defer.succeed((200, {
"server": {
"name": "Synapse",
"version": get_version_string(synapse)
},
}))
return defer.succeed(
(
200,
{"server": {"name": "Synapse", "version": get_version_string(synapse)}},
)
)
class FederationGroupsProfileServlet(BaseFederationServlet):
"""Get/set the basic profile of a group on behalf of a user
"""
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@defer.inlineCallbacks
@ -780,9 +780,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_group_profile(
group_id, requester_user_id
)
new_content = yield self.handler.get_group_profile(group_id, requester_user_id)
defer.returnValue((200, new_content))
@ -808,9 +806,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_group_summary(
group_id, requester_user_id
)
new_content = yield self.handler.get_group_summary(group_id, requester_user_id)
defer.returnValue((200, new_content))
@ -818,6 +814,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
class FederationGroupsRoomsServlet(BaseFederationServlet):
"""Get the rooms in a group on behalf of a user
"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@defer.inlineCallbacks
@ -826,9 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_rooms_in_group(
group_id, requester_user_id
)
new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id)
defer.returnValue((200, new_content))
@ -836,6 +831,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
"""Add/remove room from group
"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@defer.inlineCallbacks
@ -857,7 +853,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.remove_room_from_group(
group_id, requester_user_id, room_id,
group_id, requester_user_id, room_id
)
defer.returnValue((200, new_content))
@ -866,6 +862,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
"""Update room config in group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)"
@ -878,7 +875,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
result = yield self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content,
group_id, requester_user_id, room_id, config_key, content
)
defer.returnValue((200, result))
@ -887,6 +884,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
class FederationGroupsUsersServlet(BaseFederationServlet):
"""Get the users in a group on behalf of a user
"""
PATH = "/groups/(?P<group_id>[^/]*)/users"
@defer.inlineCallbacks
@ -895,9 +893,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.get_users_in_group(
group_id, requester_user_id
)
new_content = yield self.handler.get_users_in_group(group_id, requester_user_id)
defer.returnValue((200, new_content))
@ -905,6 +901,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
"""Get the users that have been invited to a group
"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@defer.inlineCallbacks
@ -923,6 +920,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
class FederationGroupsInviteServlet(BaseFederationServlet):
"""Ask a group server to invite someone to the group
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@defer.inlineCallbacks
@ -932,7 +930,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.invite_to_group(
group_id, user_id, requester_user_id, content,
group_id, user_id, requester_user_id, content
)
defer.returnValue((200, new_content))
@ -941,6 +939,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
"""Accept an invitation from the group server
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@defer.inlineCallbacks
@ -948,9 +947,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
new_content = yield self.handler.accept_invite(
group_id, user_id, content,
)
new_content = yield self.handler.accept_invite(group_id, user_id, content)
defer.returnValue((200, new_content))
@ -958,6 +955,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
class FederationGroupsJoinServlet(BaseFederationServlet):
"""Attempt to join a group
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@defer.inlineCallbacks
@ -965,9 +963,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
new_content = yield self.handler.join_group(
group_id, user_id, content,
)
new_content = yield self.handler.join_group(group_id, user_id, content)
defer.returnValue((200, new_content))
@ -975,6 +971,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@defer.inlineCallbacks
@ -984,7 +981,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.remove_user_from_group(
group_id, user_id, requester_user_id, content,
group_id, user_id, requester_user_id, content
)
defer.returnValue((200, new_content))
@ -993,6 +990,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
"""A group server has invited a local user
"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@defer.inlineCallbacks
@ -1000,9 +998,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
new_content = yield self.handler.on_invite(
group_id, user_id, content,
)
new_content = yield self.handler.on_invite(group_id, user_id, content)
defer.returnValue((200, new_content))
@ -1010,6 +1006,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
"""A group server has removed a local user
"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@defer.inlineCallbacks
@ -1018,7 +1015,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
raise SynapseError(403, "user_id doesn't match origin")
new_content = yield self.handler.user_removed_from_group(
group_id, user_id, content,
group_id, user_id, content
)
defer.returnValue((200, new_content))
@ -1027,6 +1024,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
"""A group or user's server renews their attestation
"""
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
@defer.inlineCallbacks
@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
@ -1063,7 +1062,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.update_group_summary_room(
group_id, requester_user_id,
group_id,
requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
@ -1081,9 +1081,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.delete_group_summary_room(
group_id, requester_user_id,
room_id=room_id,
category_id=category_id,
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
defer.returnValue((200, resp))
@ -1092,9 +1090,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/categories/?"
)
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
@ -1102,9 +1099,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
resp = yield self.handler.get_group_categories(
group_id, requester_user_id,
)
resp = yield self.handler.get_group_categories(group_id, requester_user_id)
defer.returnValue((200, resp))
@ -1112,9 +1107,8 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
class FederationGroupsCategoryServlet(BaseFederationServlet):
"""Add/remove/get a category in a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
)
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id, category_id):
@ -1138,7 +1132,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.upsert_group_category(
group_id, requester_user_id, category_id, content,
group_id, requester_user_id, category_id, content
)
defer.returnValue((200, resp))
@ -1153,7 +1147,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.delete_group_category(
group_id, requester_user_id, category_id,
group_id, requester_user_id, category_id
)
defer.returnValue((200, resp))
@ -1162,9 +1156,8 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/roles/?"
)
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
@ -1172,9 +1165,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
resp = yield self.handler.get_group_roles(
group_id, requester_user_id,
)
resp = yield self.handler.get_group_roles(group_id, requester_user_id)
defer.returnValue((200, resp))
@ -1182,9 +1173,8 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
class FederationGroupsRoleServlet(BaseFederationServlet):
"""Add/remove/get a role in a group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
)
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id, role_id):
@ -1192,9 +1182,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
resp = yield self.handler.get_group_role(
group_id, requester_user_id, role_id
)
resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id)
defer.returnValue((200, resp))
@ -1208,7 +1196,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.update_group_role(
group_id, requester_user_id, role_id, content,
group_id, requester_user_id, role_id, content
)
defer.returnValue((200, resp))
@ -1223,7 +1211,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.delete_group_role(
group_id, requester_user_id, role_id,
group_id, requester_user_id, role_id
)
defer.returnValue((200, resp))
@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
- /groups/:group/summary/users/:user_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
@ -1252,7 +1241,8 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.update_group_summary_user(
group_id, requester_user_id,
group_id,
requester_user_id,
user_id=user_id,
role_id=role_id,
content=content,
@ -1270,9 +1260,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.delete_group_summary_user(
group_id, requester_user_id,
user_id=user_id,
role_id=role_id,
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
defer.returnValue((200, resp))
@ -1281,14 +1269,13 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
"""Get roles in a group
"""
PATH = (
"/get_groups_publicised"
)
PATH = "/get_groups_publicised"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
resp = yield self.handler.bulk_get_publicised_groups(
content["user_ids"], proxy=False,
content["user_ids"], proxy=False
)
defer.returnValue((200, resp))
@ -1297,6 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
"""Sets whether a group is joinable without an invite or knock
"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
@defer.inlineCallbacks
@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet):
Indicates to other servers how complex (and therefore likely
resource-intensive) a public room this server knows about is.
"""
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
@ -1325,9 +1314,7 @@ class RoomComplexityServlet(BaseFederationServlet):
store = self.handler.hs.get_datastore()
is_public = yield store.is_room_world_readable_or_publicly_joinable(
room_id
)
is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id)
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
@ -1362,13 +1349,9 @@ FEDERATION_SERVLET_CLASSES = (
RoomComplexityServlet,
)
OPENID_SERVLET_CLASSES = (
OpenIdUserInfo,
)
OPENID_SERVLET_CLASSES = (OpenIdUserInfo,)
ROOM_LIST_CLASSES = (
PublicRoomList,
)
ROOM_LIST_CLASSES = (PublicRoomList,)
GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsProfileServlet,
@ -1399,9 +1382,7 @@ GROUP_LOCAL_SERVLET_CLASSES = (
)
GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
)
GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,)
DEFAULT_SERVLET_GROUPS = (
"federation",
@ -1455,7 +1436,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
deny_access=hs.config.restrict_public_rooms_to_local_users,
allow_access=hs.config.allow_public_rooms_over_federation,
).register(resource)
if "group_server" in servlet_groups:

View file

@ -32,21 +32,11 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph.
"""
valid_keys = [
"origin",
"destination",
"edu_type",
"content",
]
valid_keys = ["origin", "destination", "edu_type", "content"]
required_keys = [
"edu_type",
]
required_keys = ["edu_type"]
internal_keys = [
"origin",
"destination",
]
internal_keys = ["origin", "destination"]
class Transaction(JsonEncodedObject):
@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject):
"edus",
]
internal_keys = [
"transaction_id",
"destination",
]
internal_keys = ["transaction_id", "destination"]
required_keys = [
"transaction_id",
@ -98,9 +85,7 @@ class Transaction(JsonEncodedObject):
del kwargs["edus"]
super(Transaction, self).__init__(
transaction_id=transaction_id,
pdus=pdus,
**kwargs
transaction_id=transaction_id, pdus=pdus, **kwargs
)
@staticmethod
@ -109,13 +94,9 @@ class Transaction(JsonEncodedObject):
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
raise KeyError(
"Require 'origin_server_ts' to construct a Transaction"
)
raise KeyError("Require 'origin_server_ts' to construct a Transaction")
if "transaction_id" not in kwargs:
raise KeyError(
"Require 'transaction_id' to construct a Transaction"
)
raise KeyError("Require 'transaction_id' to construct a Transaction")
kwargs["pdus"] = [p.get_pdu_json() for p in pdus]

View file

@ -42,7 +42,7 @@ from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import RequestSendFailed, SynapseError
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.logcontext import run_in_background
@ -65,6 +65,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning(object):
"""Creates and verifies group attestations.
"""
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
@ -113,11 +114,15 @@ class GroupAttestationSigning(object):
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
valid_until_ms = int(self.clock.time_msec() + validity_period)
return sign_json({
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": valid_until_ms,
}, self.server_name, self.signing_key)
return sign_json(
{
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": valid_until_ms,
},
self.server_name,
self.signing_key,
)
class GroupAttestionRenewer(object):
@ -132,9 +137,10 @@ class GroupAttestionRenewer(object):
self.is_mine_id = hs.is_mine_id
self.attestations = hs.get_groups_attestation_signing()
self._renew_attestations_loop = self.clock.looping_call(
self._start_renew_attestations, 30 * 60 * 1000,
)
if not hs.config.worker_app:
self._renew_attestations_loop = self.clock.looping_call(
self._start_renew_attestations, 30 * 60 * 1000
)
@defer.inlineCallbacks
def on_renew_attestation(self, group_id, user_id, content):
@ -146,9 +152,7 @@ class GroupAttestionRenewer(object):
raise SynapseError(400, "Neither user not group are on this server")
yield self.attestations.verify_attestation(
attestation,
user_id=user_id,
group_id=group_id,
attestation, user_id=user_id, group_id=group_id
)
yield self.store.update_remote_attestion(group_id, user_id, attestation)
@ -179,7 +183,8 @@ class GroupAttestionRenewer(object):
else:
logger.warn(
"Incorrectly trying to do attestations for user: %r in %r",
user_id, group_id,
user_id,
group_id,
)
yield self.store.remove_attestation_renewal(group_id, user_id)
return
@ -187,21 +192,20 @@ class GroupAttestionRenewer(object):
attestation = self.attestations.create_attestation(group_id, user_id)
yield self.transport_client.renew_group_attestation(
destination, group_id, user_id,
content={"attestation": attestation},
destination, group_id, user_id, content={"attestation": attestation}
)
yield self.store.update_attestation_renewal(
group_id, user_id, attestation
)
except RequestSendFailed as e:
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to renew attestation of %r in %r: %s",
user_id, group_id, e,
"Failed to renew attestation of %r in %r: %s", user_id, group_id, e
)
except Exception:
logger.exception("Error renewing attestation of %r in %r",
user_id, group_id)
logger.exception(
"Error renewing attestation of %r in %r", user_id, group_id
)
for row in rows:
group_id = row["group_id"]

View file

@ -54,8 +54,9 @@ class GroupsServerHandler(object):
hs.get_groups_attestation_renewer()
@defer.inlineCallbacks
def check_group_is_ours(self, group_id, requester_user_id,
and_exists=False, and_is_admin=None):
def check_group_is_ours(
self, group_id, requester_user_id, and_exists=False, and_is_admin=None
):
"""Check that the group is ours, and optionally if it exists.
If group does exist then return group.
@ -73,7 +74,9 @@ class GroupsServerHandler(object):
if and_exists and not group:
raise SynapseError(404, "Unknown group")
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
is_user_in_group = yield self.store.is_user_in_group(
requester_user_id, group_id
)
if group and not is_user_in_group and not group["is_public"]:
raise SynapseError(404, "Unknown group")
@ -96,25 +99,27 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
is_user_in_group = yield self.store.is_user_in_group(
requester_user_id, group_id
)
profile = yield self.get_group_profile(group_id, requester_user_id)
users, roles = yield self.store.get_users_for_summary_by_role(
group_id, include_private=is_user_in_group,
group_id, include_private=is_user_in_group
)
# TODO: Add profiles to users
rooms, categories = yield self.store.get_rooms_for_summary_by_category(
group_id, include_private=is_user_in_group,
group_id, include_private=is_user_in_group
)
for room_entry in rooms:
room_id = room_entry["room_id"]
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
room_id, len(joined_users), with_alias=False, allow_private=True,
room_id, len(joined_users), with_alias=False, allow_private=True
)
entry = dict(entry) # so we don't change whats cached
entry.pop("room_id", None)
@ -134,7 +139,7 @@ class GroupsServerHandler(object):
entry["attestation"] = attestation
else:
entry["attestation"] = self.attestations.create_attestation(
group_id, user_id,
group_id, user_id
)
user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
@ -143,34 +148,34 @@ class GroupsServerHandler(object):
users.sort(key=lambda e: e.get("order", 0))
membership_info = yield self.store.get_users_membership_info_in_group(
group_id, requester_user_id,
group_id, requester_user_id
)
defer.returnValue({
"profile": profile,
"users_section": {
"users": users,
"roles": roles,
"total_user_count_estimate": 0, # TODO
},
"rooms_section": {
"rooms": rooms,
"categories": categories,
"total_room_count_estimate": 0, # TODO
},
"user": membership_info,
})
defer.returnValue(
{
"profile": profile,
"users_section": {
"users": users,
"roles": roles,
"total_user_count_estimate": 0, # TODO
},
"rooms_section": {
"rooms": rooms,
"categories": categories,
"total_room_count_estimate": 0, # TODO
},
"user": membership_info,
}
)
@defer.inlineCallbacks
def update_group_summary_room(self, group_id, requester_user_id,
room_id, category_id, content):
def update_group_summary_room(
self, group_id, requester_user_id, room_id, category_id, content
):
"""Add/update a room to the group summary
"""
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
RoomID.from_string(room_id) # Ensure valid room id
@ -190,21 +195,17 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_summary_room(self, group_id, requester_user_id,
room_id, category_id):
def delete_group_summary_room(
self, group_id, requester_user_id, room_id, category_id
):
"""Remove a room from the summary
"""
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_room_from_summary(
group_id=group_id,
room_id=room_id,
category_id=category_id,
group_id=group_id, room_id=room_id, category_id=category_id
)
defer.returnValue({})
@ -223,9 +224,7 @@ class GroupsServerHandler(object):
join_policy = _parse_join_policy_from_contents(content)
if join_policy is None:
raise SynapseError(
400, "No value specified for 'm.join_policy'"
)
raise SynapseError(400, "No value specified for 'm.join_policy'")
yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
@ -237,9 +236,7 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = yield self.store.get_group_categories(
group_id=group_id,
)
categories = yield self.store.get_group_categories(group_id=group_id)
defer.returnValue({"categories": categories})
@defer.inlineCallbacks
@ -249,8 +246,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = yield self.store.get_group_category(
group_id=group_id,
category_id=category_id,
group_id=group_id, category_id=category_id
)
defer.returnValue(res)
@ -260,10 +256,7 @@ class GroupsServerHandler(object):
"""Add/Update a group category
"""
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
@ -283,15 +276,11 @@ class GroupsServerHandler(object):
"""Delete a group category
"""
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_group_category(
group_id=group_id,
category_id=category_id,
group_id=group_id, category_id=category_id
)
defer.returnValue({})
@ -302,9 +291,7 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = yield self.store.get_group_roles(
group_id=group_id,
)
roles = yield self.store.get_group_roles(group_id=group_id)
defer.returnValue({"roles": roles})
@defer.inlineCallbacks
@ -313,10 +300,7 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = yield self.store.get_group_role(
group_id=group_id,
role_id=role_id,
)
res = yield self.store.get_group_role(group_id=group_id, role_id=role_id)
defer.returnValue(res)
@defer.inlineCallbacks
@ -324,10 +308,7 @@ class GroupsServerHandler(object):
"""Add/update a role in a group
"""
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
@ -335,10 +316,7 @@ class GroupsServerHandler(object):
profile = content.get("profile")
yield self.store.upsert_group_role(
group_id=group_id,
role_id=role_id,
is_public=is_public,
profile=profile,
group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
)
defer.returnValue({})
@ -348,26 +326,21 @@ class GroupsServerHandler(object):
"""Remove role from group
"""
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_group_role(
group_id=group_id,
role_id=role_id,
)
yield self.store.remove_group_role(group_id=group_id, role_id=role_id)
defer.returnValue({})
@defer.inlineCallbacks
def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
content):
def update_group_summary_user(
self, group_id, requester_user_id, user_id, role_id, content
):
"""Add/update a users entry in the group summary
"""
yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
order = content.get("order", None)
@ -389,13 +362,11 @@ class GroupsServerHandler(object):
"""Remove a user from the group summary
"""
yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_user_from_summary(
group_id=group_id,
user_id=user_id,
role_id=role_id,
group_id=group_id, user_id=user_id, role_id=role_id
)
defer.returnValue({})
@ -411,8 +382,11 @@ class GroupsServerHandler(object):
if group:
cols = [
"name", "short_description", "long_description",
"avatar_url", "is_public",
"name",
"short_description",
"long_description",
"avatar_url",
"is_public",
]
group_description = {key: group[key] for key in cols}
group_description["is_openly_joinable"] = group["join_policy"] == "open"
@ -426,12 +400,11 @@ class GroupsServerHandler(object):
"""Update the group profile
"""
yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
profile = {}
for keyname in ("name", "avatar_url", "short_description",
"long_description"):
for keyname in ("name", "avatar_url", "short_description", "long_description"):
if keyname in content:
value = content[keyname]
if not isinstance(value, string_types):
@ -449,10 +422,12 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
is_user_in_group = yield self.store.is_user_in_group(
requester_user_id, group_id
)
user_results = yield self.store.get_users_in_group(
group_id, include_private=is_user_in_group,
group_id, include_private=is_user_in_group
)
chunk = []
@ -470,24 +445,25 @@ class GroupsServerHandler(object):
entry["is_privileged"] = bool(is_privileged)
if not self.is_mine_id(g_user_id):
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
attestation = yield self.store.get_remote_attestation(
group_id, g_user_id
)
if not attestation:
continue
entry["attestation"] = attestation
else:
entry["attestation"] = self.attestations.create_attestation(
group_id, g_user_id,
group_id, g_user_id
)
chunk.append(entry)
# TODO: If admin add lists of users whose attestations have timed out
defer.returnValue({
"chunk": chunk,
"total_user_count_estimate": len(user_results),
})
defer.returnValue(
{"chunk": chunk, "total_user_count_estimate": len(user_results)}
)
@defer.inlineCallbacks
def get_invited_users_in_group(self, group_id, requester_user_id):
@ -498,7 +474,9 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
is_user_in_group = yield self.store.is_user_in_group(
requester_user_id, group_id
)
if not is_user_in_group:
raise SynapseError(403, "User not in group")
@ -508,9 +486,7 @@ class GroupsServerHandler(object):
user_profiles = []
for user_id in invited_users:
user_profile = {
"user_id": user_id
}
user_profile = {"user_id": user_id}
try:
profile = yield self.profile_handler.get_profile_from_cache(user_id)
user_profile.update(profile)
@ -518,10 +494,9 @@ class GroupsServerHandler(object):
logger.warn("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile)
defer.returnValue({
"chunk": user_profiles,
"total_user_count_estimate": len(invited_users),
})
defer.returnValue(
{"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
)
@defer.inlineCallbacks
def get_rooms_in_group(self, group_id, requester_user_id):
@ -532,10 +507,12 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
is_user_in_group = yield self.store.is_user_in_group(
requester_user_id, group_id
)
room_results = yield self.store.get_rooms_in_group(
group_id, include_private=is_user_in_group,
group_id, include_private=is_user_in_group
)
chunk = []
@ -544,7 +521,7 @@ class GroupsServerHandler(object):
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
room_id, len(joined_users), with_alias=False, allow_private=True,
room_id, len(joined_users), with_alias=False, allow_private=True
)
if not entry:
@ -556,10 +533,9 @@ class GroupsServerHandler(object):
chunk.sort(key=lambda e: -e["num_joined_members"])
defer.returnValue({
"chunk": chunk,
"total_room_count_estimate": len(room_results),
})
defer.returnValue(
{"chunk": chunk, "total_room_count_estimate": len(room_results)}
)
@defer.inlineCallbacks
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
@ -578,8 +554,9 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
content):
def update_room_in_group(
self, group_id, requester_user_id, room_id, config_key, content
):
"""Update room in group
"""
RoomID.from_string(room_id) # Ensure valid room id
@ -592,8 +569,7 @@ class GroupsServerHandler(object):
is_public = _parse_visibility_dict(content)
yield self.store.update_room_in_group_visibility(
group_id, room_id,
is_public=is_public,
group_id, room_id, is_public=is_public
)
else:
raise SynapseError(400, "Uknown config option")
@ -625,10 +601,7 @@ class GroupsServerHandler(object):
# TODO: Check if user is already invited
content = {
"profile": {
"name": group["name"],
"avatar_url": group["avatar_url"],
},
"profile": {"name": group["name"], "avatar_url": group["avatar_url"]},
"inviter": requester_user_id,
}
@ -638,9 +611,7 @@ class GroupsServerHandler(object):
local_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content.update({
"attestation": local_attestation,
})
content.update({"attestation": local_attestation})
res = yield self.transport_client.invite_to_group_notification(
get_domain_from_id(user_id), group_id, user_id, content
@ -658,31 +629,24 @@ class GroupsServerHandler(object):
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
group_id=group_id,
remote_attestation, user_id=user_id, group_id=group_id
)
else:
remote_attestation = None
yield self.store.add_user_to_group(
group_id, user_id,
group_id,
user_id,
is_admin=False,
is_public=False, # TODO
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
elif res["state"] == "invite":
yield self.store.add_group_invite(
group_id, user_id,
)
defer.returnValue({
"state": "invite"
})
yield self.store.add_group_invite(group_id, user_id)
defer.returnValue({"state": "invite"})
elif res["state"] == "reject":
defer.returnValue({
"state": "reject"
})
defer.returnValue({"state": "reject"})
else:
raise SynapseError(502, "Unknown state returned by HS")
@ -693,16 +657,12 @@ class GroupsServerHandler(object):
See accept_invite, join_group.
"""
if not self.hs.is_mine_id(user_id):
local_attestation = self.attestations.create_attestation(
group_id, user_id,
)
local_attestation = self.attestations.create_attestation(group_id, user_id)
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
group_id=group_id,
remote_attestation, user_id=user_id, group_id=group_id
)
else:
local_attestation = None
@ -711,7 +671,8 @@ class GroupsServerHandler(object):
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_group(
group_id, user_id,
group_id,
user_id,
is_admin=False,
is_public=is_public,
local_attestation=local_attestation,
@ -731,17 +692,14 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_invited = yield self.store.is_user_invited_to_local_group(
group_id, requester_user_id,
group_id, requester_user_id
)
if not is_invited:
raise SynapseError(403, "User not invited to group")
local_attestation = yield self._add_user(group_id, requester_user_id, content)
defer.returnValue({
"state": "join",
"attestation": local_attestation,
})
defer.returnValue({"state": "join", "attestation": local_attestation})
@defer.inlineCallbacks
def join_group(self, group_id, requester_user_id, content):
@ -753,15 +711,12 @@ class GroupsServerHandler(object):
group_info = yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True
)
if group_info['join_policy'] != "open":
if group_info["join_policy"] != "open":
raise SynapseError(403, "Group is not publicly joinable")
local_attestation = yield self._add_user(group_id, requester_user_id, content)
defer.returnValue({
"state": "join",
"attestation": local_attestation,
})
defer.returnValue({"state": "join", "attestation": local_attestation})
@defer.inlineCallbacks
def knock(self, group_id, requester_user_id, content):
@ -800,9 +755,7 @@ class GroupsServerHandler(object):
is_kick = True
yield self.store.remove_user_from_group(
group_id, user_id,
)
yield self.store.remove_user_from_group(group_id, user_id)
if is_kick:
if self.hs.is_mine_id(user_id):
@ -830,19 +783,20 @@ class GroupsServerHandler(object):
if group:
raise SynapseError(400, "Group already exists")
is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
is_admin = yield self.auth.is_server_admin(
UserID.from_string(requester_user_id)
)
if not is_admin:
if not self.hs.config.enable_group_creation:
raise SynapseError(
403, "Only a server admin can create groups on this server",
403, "Only a server admin can create groups on this server"
)
localpart = group_id_obj.localpart
if not localpart.startswith(self.hs.config.group_creation_prefix):
raise SynapseError(
400,
"Can only create groups with prefix %r on this server" % (
self.hs.config.group_creation_prefix,
),
"Can only create groups with prefix %r on this server"
% (self.hs.config.group_creation_prefix,),
)
profile = content.get("profile", {})
@ -865,21 +819,19 @@ class GroupsServerHandler(object):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=requester_user_id,
group_id=group_id,
remote_attestation, user_id=requester_user_id, group_id=group_id
)
local_attestation = self.attestations.create_attestation(
group_id,
requester_user_id,
group_id, requester_user_id
)
else:
local_attestation = None
remote_attestation = None
yield self.store.add_user_to_group(
group_id, requester_user_id,
group_id,
requester_user_id,
is_admin=True,
is_public=True, # TODO
local_attestation=local_attestation,
@ -893,9 +845,7 @@ class GroupsServerHandler(object):
avatar_url=user_profile.get("avatar_url"),
)
defer.returnValue({
"group_id": group_id,
})
defer.returnValue({"group_id": group_id})
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
@ -911,29 +861,22 @@ class GroupsServerHandler(object):
Deferred
"""
yield self.check_group_is_ours(
group_id, requester_user_id,
and_exists=True,
)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups.
is_admin = yield self.store.is_user_admin_in_group(
group_id, requester_user_id
)
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin:
is_admin = yield self.auth.is_server_admin(
UserID.from_string(requester_user_id),
UserID.from_string(requester_user_id)
)
if not is_admin:
raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it
users = yield self.store.get_users_in_group(
group_id, include_private=True,
)
users = yield self.store.get_users_in_group(group_id, include_private=True)
@defer.inlineCallbacks
def _kick_user_from_group(user_id):
@ -989,9 +932,7 @@ def _parse_join_policy_dict(join_policy_dict):
return "invite"
if join_policy_type not in ("invite", "open"):
raise SynapseError(
400, "Synapse only supports 'invite'/'open' join rule"
)
raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule")
return join_policy_type
@ -1018,7 +959,5 @@ def _parse_visibility_dict(visibility):
return True
if vis_type not in ("public", "private"):
raise SynapseError(
400, "Synapse only supports 'public'/'private' visibility"
)
raise SynapseError(400, "Synapse only supports 'public'/'private' visibility")
return vis_type == "public"

View file

@ -94,14 +94,15 @@ class BaseHandler(object):
burst_count = self.hs.config.rc_message.burst_count
allowed, time_allowed = self.ratelimiter.can_do_action(
user_id, time_now,
user_id,
time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
retry_after_ms=int(1000 * (time_allowed - time_now))
)
@defer.inlineCallbacks
@ -139,7 +140,7 @@ class BaseHandler(object):
if member_event.content["membership"] not in {
Membership.JOIN,
Membership.INVITE
Membership.INVITE,
}:
continue
@ -156,8 +157,7 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
requester = synapse.types.create_requester(
target_user, is_guest=True)
requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
yield handler.update_membership(
requester,

View file

@ -20,7 +20,7 @@ class AccountDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
def get_current_key(self, direction='f'):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks
@ -34,29 +34,22 @@ class AccountDataEventSource(object):
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
results.append({
"type": "m.tag",
"content": {"tags": room_tags},
"room_id": room_id,
})
results.append(
{"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
)
account_data, room_account_data = (
yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
)
for account_data_type, content in account_data.items():
results.append({
"type": account_data_type,
"content": content,
})
results.append({"type": account_data_type, "content": content})
for room_id, account_data in room_account_data.items():
for account_data_type, content in account_data.items():
results.append({
"type": account_data_type,
"content": content,
"room_id": room_id,
})
results.append(
{"type": account_data_type, "content": content, "room_id": room_id}
)
defer.returnValue((results, current_stream_id))

View file

@ -49,12 +49,10 @@ class AccountValidityHandler(object):
app_name = self.hs.config.email_app_name
self._subject = self._account_validity.renew_email_subject % {
"app": app_name,
"app": app_name
}
self._from_string = self.hs.config.email_notif_from % {
"app": app_name,
}
self._from_string = self.hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
self._subject = self._account_validity.renew_email_subject
@ -69,10 +67,7 @@ class AccountValidityHandler(object):
)
# Check the renewal emails to send and send them every 30min.
self.clock.looping_call(
self.send_renewal_emails,
30 * 60 * 1000,
)
self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000)
@defer.inlineCallbacks
def send_renewal_emails(self):
@ -86,8 +81,7 @@ class AccountValidityHandler(object):
if expiring_users:
for user in expiring_users:
yield self._send_renewal_email(
user_id=user["user_id"],
expiration_ts=user["expiration_ts_ms"],
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
@defer.inlineCallbacks
@ -110,6 +104,9 @@ class AccountValidityHandler(object):
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
# account manually.
# We don't need to do a specific check to make sure the account isn't
# deactivated, as a deactivated account isn't supposed to have any
# email address attached to it.
if not addresses:
return
@ -143,32 +140,33 @@ class AccountValidityHandler(object):
for address in addresses:
raw_to = email.utils.parseaddr(address)[1]
multipart_msg = MIMEMultipart('alternative')
multipart_msg['Subject'] = self._subject
multipart_msg['From'] = self._from_string
multipart_msg['To'] = address
multipart_msg['Date'] = email.utils.formatdate()
multipart_msg['Message-ID'] = email.utils.make_msgid()
multipart_msg = MIMEMultipart("alternative")
multipart_msg["Subject"] = self._subject
multipart_msg["From"] = self._from_string
multipart_msg["To"] = address
multipart_msg["Date"] = email.utils.formatdate()
multipart_msg["Message-ID"] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
logger.info("Sending renewal email to %s", address)
yield make_deferred_yieldable(self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
reactor=self.hs.get_reactor(),
port=self.hs.config.email_smtp_port,
requireAuthentication=self.hs.config.email_smtp_user is not None,
username=self.hs.config.email_smtp_user,
password=self.hs.config.email_smtp_pass,
requireTransportSecurity=self.hs.config.require_transport_security
))
yield make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
raw_to,
multipart_msg.as_string().encode("utf8"),
reactor=self.hs.get_reactor(),
port=self.hs.config.email_smtp_port,
requireAuthentication=self.hs.config.email_smtp_user is not None,
username=self.hs.config.email_smtp_user,
password=self.hs.config.email_smtp_pass,
requireTransportSecurity=self.hs.config.require_transport_security,
)
)
yield self.store.set_renewal_mail_status(
user_id=user_id,
email_sent=True,
)
yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
@ -245,9 +243,7 @@ class AccountValidityHandler(object):
expiration_ts = self.clock.time_msec() + self._account_validity.period
yield self.store.set_account_validity_for_user(
user_id=user_id,
expiration_ts=expiration_ts,
email_sent=email_sent,
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
defer.returnValue(expiration_ts)

View file

@ -15,14 +15,9 @@
import logging
import attr
from zope.interface import implementer
import twisted
import twisted.internet.error
from twisted.internet import defer
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource
@ -30,27 +25,6 @@ from synapse.app import check_bind_error
logger = logging.getLogger(__name__)
try:
from txacme.interfaces import ICertificateStore
@attr.s
@implementer(ICertificateStore)
class ErsatzStore(object):
"""
A store that only stores in memory.
"""
certs = attr.ib(default=attr.Factory(dict))
def store(self, server_name, pem_objects):
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
except ImportError:
# txacme is missing
pass
class AcmeHandler(object):
def __init__(self, hs):
@ -60,6 +34,7 @@ class AcmeHandler(object):
@defer.inlineCallbacks
def start_listening(self):
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
# from eliot import add_destinations
@ -67,50 +42,27 @@ class AcmeHandler(object):
#
# add_destinations(TwistedDestination())
from txacme.challenges import HTTP01Responder
from txacme.service import AcmeIssuingService
from txacme.endpoint import load_or_create_client_key
from txacme.client import Client
from josepy.jwa import RS256
well_known = Resource()
self._store = ErsatzStore()
responder = HTTP01Responder()
self._issuer = AcmeIssuingService(
cert_store=self._store,
client_creator=(
lambda: Client.from_url(
reactor=self.reactor,
url=URL.from_text(self.hs.config.acme_url),
key=load_or_create_client_key(
FilePath(self.hs.config.config_dir_path)
),
alg=RS256,
)
),
clock=self.reactor,
responders=[responder],
self._issuer = acme_issuing_service.create_issuing_service(
self.reactor,
acme_url=self.hs.config.acme_url,
account_key_file=self.hs.config.acme_account_key_file,
well_known_resource=well_known,
)
well_known = Resource()
well_known.putChild(b'acme-challenge', responder.resource)
responder_resource = Resource()
responder_resource.putChild(b'.well-known', well_known)
responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
responder_resource.putChild(b".well-known", well_known)
responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain"))
srv = server.Site(responder_resource)
bind_addresses = self.hs.config.acme_bind_addresses
for host in bind_addresses:
logger.info(
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port,
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
self.reactor.listenTCP(
self.hs.config.acme_port,
srv,
interface=host,
)
self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
@ -132,7 +84,7 @@ class AcmeHandler(object):
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self._acme_domain)
cert_chain = self._store.certs[self._acme_domain]
cert_chain = self._issuer.cert_store.certs[self._acme_domain]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:

View file

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility function to create an ACME issuing service.
This file contains the unconditional imports on the acme and cryptography bits that we
only need (and may only have available) if we are doing ACME, so is designed to be
imported conditionally.
"""
import logging
import attr
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
from josepy.jwa import RS256
from txacme.challenges import HTTP01Responder
from txacme.client import Client
from txacme.interfaces import ICertificateStore
from txacme.service import AcmeIssuingService
from txacme.util import generate_private_key
from zope.interface import implementer
from twisted.internet import defer
from twisted.python.filepath import FilePath
from twisted.python.url import URL
logger = logging.getLogger(__name__)
def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
acme_url (str): URL to use to request certificates
account_key_file (str): where to store the account key
well_known_resource (twisted.web.IResource): web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
AcmeIssuingService
"""
responder = HTTP01Responder()
well_known_resource.putChild(b"acme-challenge", responder.resource)
store = ErsatzStore()
return AcmeIssuingService(
cert_store=store,
client_creator=(
lambda: Client.from_url(
reactor=reactor,
url=URL.from_text(acme_url),
key=load_or_create_client_key(account_key_file),
alg=RS256,
)
),
clock=reactor,
responders=[responder],
)
@attr.s
@implementer(ICertificateStore)
class ErsatzStore(object):
"""
A store that only stores in memory.
"""
certs = attr.ib(default=attr.Factory(dict))
def store(self, server_name, pem_objects):
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
def load_or_create_client_key(key_file):
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
key_file (str): name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename
acme_key_file = FilePath(key_file)
if acme_key_file.exists():
logger.info("Loading ACME account key from '%s'", acme_key_file)
key = serialization.load_pem_private_key(
acme_key_file.getContent(), password=None, backend=default_backend()
)
else:
logger.info("Saving new ACME account key to '%s'", acme_key_file)
key = generate_private_key("rsa")
acme_key_file.setContent(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
)
return JWKRSA(key=key)

View file

@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
def __init__(self, hs):
super(AdminHandler, self).__init__(hs)
@ -33,23 +32,17 @@ class AdminHandler(BaseHandler):
sessions = yield self.store.get_user_ip_and_agents(user)
for session in sessions:
connections.append({
"ip": session["ip"],
"last_seen": session["last_seen"],
"user_agent": session["user_agent"],
})
connections.append(
{
"ip": session["ip"],
"last_seen": session["last_seen"],
"user_agent": session["user_agent"],
}
)
ret = {
"user_id": user.to_string(),
"devices": {
"": {
"sessions": [
{
"connections": connections,
}
]
},
},
"devices": {"": {"sessions": [{"connections": connections}]}},
}
defer.returnValue(ret)

View file

@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed
class ApplicationServicesHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
@ -101,9 +100,10 @@ class ApplicationServicesHandler(object):
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
def start_scheduler():
return self.scheduler.start().addErrback(
log_failure, "Application Services Failure",
log_failure, "Application Services Failure"
)
run_as_background_process("as_scheduler", start_scheduler)
@ -118,10 +118,15 @@ class ApplicationServicesHandler(object):
for event in events:
yield handle_event(event)
yield make_deferred_yieldable(defer.gatherResults([
run_in_background(handle_room_events, evs)
for evs in itervalues(events_by_room)
], consumeErrors=True))
yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
for evs in itervalues(events_by_room)
],
consumeErrors=True,
)
)
yield self.store.set_appservice_last_pos(upper_bound)
@ -129,20 +134,23 @@ class ApplicationServicesHandler(object):
ts = yield self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_positions.labels(
"appservice_sender").set(upper_bound)
"appservice_sender"
).set(upper_bound)
events_processed_counter.inc(len(events))
event_processing_loop_room_count.labels(
"appservice_sender"
).inc(len(events_by_room))
event_processing_loop_room_count.labels("appservice_sender").inc(
len(events_by_room)
)
event_processing_loop_counter.labels("appservice_sender").inc()
synapse.metrics.event_processing_lag.labels(
"appservice_sender").set(now - ts)
"appservice_sender"
).set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
"appservice_sender").set(ts)
"appservice_sender"
).set(ts)
finally:
self.is_processing = False
@ -155,13 +163,9 @@ class ApplicationServicesHandler(object):
Returns:
True if this user exists on at least one application service.
"""
user_query_services = yield self._get_services_for_user(
user_id=user_id
)
user_query_services = yield self._get_services_for_user(user_id=user_id)
for user_service in user_query_services:
is_known_user = yield self.appservice_api.query_user(
user_service, user_id
)
is_known_user = yield self.appservice_api.query_user(user_service, user_id)
if is_known_user:
defer.returnValue(True)
defer.returnValue(False)
@ -179,9 +183,7 @@ class ApplicationServicesHandler(object):
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()
alias_query_services = [
s for s in services if (
s.is_interested_in_alias(room_alias_str)
)
s for s in services if (s.is_interested_in_alias(room_alias_str))
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
@ -189,22 +191,24 @@ class ApplicationServicesHandler(object):
)
if is_known_alias:
# the alias exists now so don't query more ASes.
result = yield self.store.get_association_from_room_alias(
room_alias
)
result = yield self.store.get_association_from_room_alias(room_alias)
defer.returnValue(result)
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield make_deferred_yieldable(defer.DeferredList([
run_in_background(
self.appservice_api.query_3pe,
service, kind, protocol, fields,
results = yield make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
self.appservice_api.query_3pe, service, kind, protocol, fields
)
for service in services
],
consumeErrors=True,
)
for service in services
], consumeErrors=True))
)
ret = []
for (success, result) in results:
@ -276,18 +280,12 @@ class ApplicationServicesHandler(object):
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
interested_list = [
s for s in services if (
s.is_interested_in_user(user_id)
)
]
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
return defer.succeed(interested_list)
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
return defer.succeed(interested_list)
@defer.inlineCallbacks

View file

@ -134,13 +134,9 @@ class AuthHandler(BaseHandler):
"""
# build a list of supported flows
flows = [
[login_type] for login_type in self._supported_login_types
]
flows = [[login_type] for login_type in self._supported_login_types]
result, params, _ = yield self.check_auth(
flows, request_body, clientip,
)
result, params, _ = yield self.check_auth(flows, request_body, clientip)
# find the completed login type
for login_type in self._supported_login_types:
@ -151,9 +147,7 @@ class AuthHandler(BaseHandler):
break
else:
# this can't happen
raise Exception(
"check_auth returned True but no successful login type",
)
raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
@ -215,11 +209,11 @@ class AuthHandler(BaseHandler):
authdict = None
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
del clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
if clientdict and "auth" in clientdict:
authdict = clientdict["auth"]
del clientdict["auth"]
if "session" in authdict:
sid = authdict["session"]
session = self._get_session_info(sid)
if len(clientdict) > 0:
@ -232,27 +226,27 @@ class AuthHandler(BaseHandler):
# on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
session['clientdict'] = clientdict
session["clientdict"] = clientdict
self._save_session(session)
elif 'clientdict' in session:
clientdict = session['clientdict']
elif "clientdict" in session:
clientdict = session["clientdict"]
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session),
self._auth_dict_for_flows(flows, session)
)
if 'creds' not in session:
session['creds'] = {}
creds = session['creds']
if "creds" not in session:
session["creds"] = {}
creds = session["creds"]
# check auth type currently being presented
errordict = {}
if 'type' in authdict:
login_type = authdict['type']
if "type" in authdict:
login_type = authdict["type"]
try:
result = yield self._check_auth_dict(
authdict, clientip, password_servlet=password_servlet,
authdict, clientip, password_servlet=password_servlet
)
if result:
creds[login_type] = result
@ -281,16 +275,15 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
creds, list(clientdict)
creds,
list(clientdict),
)
defer.returnValue((creds, clientdict, session['id']))
defer.returnValue((creds, clientdict, session["id"]))
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = list(creds)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(
ret,
)
raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@ -300,15 +293,13 @@ class AuthHandler(BaseHandler):
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
if 'session' not in authdict:
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(
authdict['session']
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
sess = self._get_session_info(authdict["session"])
if "creds" not in sess:
sess["creds"] = {}
creds = sess["creds"]
result = yield self.checkers[stagetype](authdict, clientip)
if result:
@ -329,10 +320,10 @@ class AuthHandler(BaseHandler):
not send a session ID, returns None.
"""
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
if clientdict and "auth" in clientdict:
authdict = clientdict["auth"]
if "session" in authdict:
sid = authdict["session"]
return sid
def set_session_data(self, session_id, key, value):
@ -347,7 +338,7 @@ class AuthHandler(BaseHandler):
value (any): The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value
sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
@ -360,7 +351,7 @@ class AuthHandler(BaseHandler):
default (any): Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
def _check_auth_dict(self, authdict, clientip, password_servlet=False):
@ -378,15 +369,13 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
login_type = authdict['type']
login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
# XXX: Temporary workaround for having Synapse handle password resets
# See AuthHandler.check_auth for further details
res = yield checker(
authdict,
clientip=clientip,
password_servlet=password_servlet,
authdict, clientip=clientip, password_servlet=password_servlet
)
defer.returnValue(res)
@ -408,13 +397,11 @@ class AuthHandler(BaseHandler):
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required",
errcode=Codes.CAPTCHA_NEEDED
400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s",
user_response, clientip
"Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
@ -424,34 +411,34 @@ class AuthHandler(BaseHandler):
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
'remoteip': clientip,
}
"secret": self.hs.config.recaptcha_private_key,
"response": user_response,
"remoteip": clientip,
},
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
if 'success' in resp_body:
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body['success'] else "Failed",
resp_body.get('hostname')
"Successful" if resp_body["success"] else "Failed",
resp_body.get("hostname"),
)
if resp_body['success']:
if resp_body["success"]:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
return self._check_threepid('email', authdict, **kwargs)
return self._check_threepid("email", authdict, **kwargs)
def _check_msisdn(self, authdict, **kwargs):
return self._check_threepid('msisdn', authdict)
return self._check_threepid("msisdn", authdict)
def _check_dummy_auth(self, authdict, **kwargs):
return defer.succeed(True)
@ -461,10 +448,10 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
if 'threepid_creds' not in authdict:
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
@ -482,31 +469,36 @@ class AuthHandler(BaseHandler):
validated=True,
)
threepid = {
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
} if row else None
threepid = (
{
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
}
if row
else None
)
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(400, "Password resets are not enabled on this homeserver")
raise SynapseError(
400, "Password resets are not enabled on this homeserver"
)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
if threepid['medium'] != medium:
if threepid["medium"] != medium:
raise LoginError(
401,
"Expecting threepid of type '%s', got '%s'" % (
medium, threepid['medium'],
),
errcode=Codes.UNAUTHORIZED
"Expecting threepid of type '%s', got '%s'"
% (medium, threepid["medium"]),
errcode=Codes.UNAUTHORIZED,
)
threepid['threepid_creds'] = authdict['threepid_creds']
threepid["threepid_creds"] = authdict["threepid_creds"]
defer.returnValue(threepid)
@ -520,13 +512,14 @@ class AuthHandler(BaseHandler):
"version": self.hs.config.user_consent_version,
"en": {
"name": self.hs.config.user_consent_policy_name,
"url": "%s_matrix/consent?v=%s" % (
"url": "%s_matrix/consent?v=%s"
% (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
},
},
},
}
}
}
def _auth_dict_for_flows(self, flows, session):
@ -547,9 +540,9 @@ class AuthHandler(BaseHandler):
params[stage] = get_params[stage]()
return {
"session": session['id'],
"session": session["id"],
"flows": [{"stages": f} for f in public_flows],
"params": params
"params": params,
}
def _get_session_info(self, session_id):
@ -560,9 +553,7 @@ class AuthHandler(BaseHandler):
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
}
self.sessions[session_id] = {"id": session_id}
return self.sessions[session_id]
@ -652,7 +643,8 @@ class AuthHandler(BaseHandler):
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
user_id,
user_infos.keys(),
)
defer.returnValue(result)
@ -690,12 +682,10 @@ class AuthHandler(BaseHandler):
user is too high too proceed.
"""
if username.startswith('@'):
if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(
username, self.hs.hostname
).to_string()
qualified_user_id = UserID(username, self.hs.hostname).to_string()
self.ratelimit_login_per_account(qualified_user_id)
@ -713,17 +703,15 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
if (hasattr(provider, "check_password")
and login_type == LoginType.PASSWORD):
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
is_valid = yield provider.check_password(
qualified_user_id, password,
)
is_valid = yield provider.check_password(qualified_user_id, password)
if is_valid:
defer.returnValue((qualified_user_id, None))
if (not hasattr(provider, "get_supported_login_types")
or not hasattr(provider, "check_auth")):
if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
):
# this password provider doesn't understand custom login types
continue
@ -744,15 +732,12 @@ class AuthHandler(BaseHandler):
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
400, "Missing parameters for login type %s: %s" % (
login_type,
missing_fields,
),
400,
"Missing parameters for login type %s: %s"
% (login_type, missing_fields),
)
result = yield provider.check_auth(
username, login_type, login_dict,
)
result = yield provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
@ -762,7 +747,7 @@ class AuthHandler(BaseHandler):
known_login_type = True
canonical_user_id = yield self._check_local_password(
qualified_user_id, password,
qualified_user_id, password
)
if canonical_user_id:
@ -773,7 +758,8 @@ class AuthHandler(BaseHandler):
# unknown username or invalid password.
self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), time_now_s=self._clock.time(),
qualified_user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
@ -781,10 +767,7 @@ class AuthHandler(BaseHandler):
# We raise a 403 here, but note that if we're doing user-interactive
# login, it turns all LoginErrors into a 401 anyway.
raise LoginError(
403, "Invalid password",
errcode=Codes.FORBIDDEN
)
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password):
@ -810,9 +793,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = yield provider.check_3pid_auth(
medium, address, password,
)
result = yield provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
@ -853,8 +834,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None):
access_token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token,
device_id)
yield self.store.add_access_token_to_user(user_id, access_token, device_id)
defer.returnValue(access_token)
@defer.inlineCallbacks
@ -896,12 +876,13 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"], )
str(user_info["user"]), (user_info["token_id"],)
)
@defer.inlineCallbacks
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
device_id=None):
def delete_access_tokens_for_user(
self, user_id, except_token_id=None, device_id=None
):
"""Invalidate access tokens belonging to a user
Args:
@ -915,7 +896,7 @@ class AuthHandler(BaseHandler):
Deferred
"""
tokens_and_devices = yield self.store.user_delete_access_tokens(
user_id, except_token_id=except_token_id, device_id=device_id,
user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any of our auth providers want to know about this
@ -923,14 +904,12 @@ class AuthHandler(BaseHandler):
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out(
user_id=user_id,
device_id=device_id,
access_token=token,
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices),
user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
@defer.inlineCallbacks
@ -944,12 +923,11 @@ class AuthHandler(BaseHandler):
# of specific types of threepid (and fixes the fact that checking
# for the presence of an email address during password reset was
# case sensitive).
if medium == 'email':
if medium == "email":
address = address.lower()
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
@ -973,22 +951,15 @@ class AuthHandler(BaseHandler):
"""
# 'Canonicalise' email addresses as per above
if medium == 'email':
if medium == "email":
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
result = yield identity_handler.try_unbind_threepid(
user_id,
{
'medium': medium,
'address': address,
'id_server': id_server,
},
user_id, {"medium": medium, "address": address, "id_server": id_server}
)
yield self.store.user_delete_threepid(
user_id, medium, address,
)
yield self.store.user_delete_threepid(user_id, medium, address)
defer.returnValue(result)
def _save_session(self, session):
@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(unicode): Hashed password.
"""
def _do_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
).decode('ascii')
).decode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
stored_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode('ascii')
stored_hash = stored_hash.encode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler):
for this user is too high too proceed.
"""
self._failed_attempts_ratelimiter.ratelimit(
user_id.lower(), time_now_s=self._clock.time(),
user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
self._account_ratelimiter.ratelimit(
user_id.lower(), time_now_s=self._clock.time(),
user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
update=True,
@ -1083,9 +1058,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
macaroon.add_first_party_caveat(
"nonce = %s" % (stringutils.random_string_with_symbols(16),)
)
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
@ -1116,7 +1091,8 @@ class MacaroonGenerator(object):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017, 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
@ -42,6 +44,8 @@ class DeactivateAccountHandler(BaseHandler):
# it left off (if it has work left to do).
hs.get_reactor().callWhenRunning(self._start_user_parting)
self._account_validity_enabled = hs.config.account_validity.enabled
@defer.inlineCallbacks
def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account
@ -75,9 +79,9 @@ class DeactivateAccountHandler(BaseHandler):
result = yield self._identity_handler.try_unbind_threepid(
user_id,
{
'medium': threepid['medium'],
'address': threepid['address'],
'id_server': id_server,
"medium": threepid["medium"],
"address": threepid["address"],
"id_server": id_server,
},
)
identity_server_supports_unbinding &= result
@ -86,7 +90,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server")
yield self.store.user_delete_threepid(
user_id, threepid['medium'], threepid['address'],
user_id, threepid["medium"], threepid["address"]
)
# delete any devices belonging to the user, which will also
@ -114,6 +118,13 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id)
# Mark the user as deactivated.
yield self.store.set_user_deactivated_status(user_id, True)
defer.returnValue(identity_server_supports_unbinding)
def _start_user_parting(self):
@ -173,5 +184,6 @@ class DeactivateAccountHandler(BaseHandler):
except Exception:
logger.exception(
"Failed to part user %r from room %r: ignoring and continuing",
user_id, room_id,
user_id,
room_id,
)

View file

@ -58,9 +58,7 @@ class DeviceWorkerHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id=None
)
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
@ -85,9 +83,7 @@ class DeviceWorkerHandler(BaseHandler):
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id,
)
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@ -114,13 +110,11 @@ class DeviceWorkerHandler(BaseHandler):
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key,
user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(
from_token.room_key
).stream
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
possibly_changed = set(changed)
possibly_left = set()
@ -206,10 +200,9 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = []
possibly_left = []
defer.returnValue({
"changed": list(possibly_joined),
"left": list(possibly_left),
})
defer.returnValue(
{"changed": list(possibly_joined), "left": list(possibly_left)}
)
class DeviceHandler(DeviceWorkerHandler):
@ -223,17 +216,18 @@ class DeviceHandler(DeviceWorkerHandler):
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update,
"m.device_list_update", self._edu_updater.incoming_device_list_update
)
federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
"user_devices", self.on_federation_query_user_devices
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
"""
If the given device has not been registered, register it with the
supplied display name.
@ -297,12 +291,10 @@ class DeviceHandler(DeviceWorkerHandler):
raise
yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
yield self.notify_device_update(user_id, [device_id])
@ -349,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
# considered as part of a critical path.
for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@ -372,9 +364,7 @@ class DeviceHandler(DeviceWorkerHandler):
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
user_id, device_id, new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
@ -404,29 +394,26 @@ class DeviceHandler(DeviceWorkerHandler):
for device_id in device_ids:
logger.debug(
"Notifying about update %r/%r, ID: %r", user_id, device_id,
position,
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
room_ids = yield self.store.get_rooms_for_user(user_id)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
if hosts:
logger.info("Sending device list update notif for %r to: %r", user_id, hosts)
logger.info(
"Sending device list update notif for %r to: %r", user_id, hosts
)
for host in hosts:
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
defer.returnValue(
{"user_id": user_id, "stream_id": stream_id, "devices": devices}
)
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@ -440,10 +427,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
class DeviceListEduUpdater(object):
@ -481,13 +465,15 @@ class DeviceListEduUpdater(object):
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
prev_ids = [str(p) for p in prev_ids] # They may come as ints
prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning(
"Got device list update edu for %r/%r from %r",
user_id, device_id, origin,
user_id,
device_id,
origin,
)
return
@ -497,13 +483,12 @@ class DeviceListEduUpdater(object):
# probably won't get any further updates.
logger.warning(
"Got device list update edu for %r/%r, but don't share a room",
user_id, device_id,
user_id,
device_id,
)
return
logger.debug(
"Received device list update for %r/%r", user_id, device_id,
)
logger.debug("Received device list update for %r/%r", user_id, device_id)
self._pending_updates.setdefault(user_id, []).append(
(device_id, stream_id, prev_ids, edu_content)
@ -525,7 +510,10 @@ class DeviceListEduUpdater(object):
for device_id, stream_id, prev_ids, content in pending_updates:
logger.debug(
"Handling update %r/%r, ID: %r, prev: %r ",
user_id, device_id, stream_id, prev_ids,
user_id,
device_id,
stream_id,
prev_ids,
)
# Given a list of updates we check if we need to resync. This
@ -540,13 +528,13 @@ class DeviceListEduUpdater(object):
try:
result = yield self.federation.query_user_devices(origin, user_id)
except (
NotRetryingDestination, RequestSendFailed, HttpResponseException,
NotRetryingDestination,
RequestSendFailed,
HttpResponseException,
):
# TODO: Remember that we are now out of sync and try again
# later
logger.warn(
"Failed to handle device list update for %s", user_id,
)
logger.warn("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@ -582,18 +570,21 @@ class DeviceListEduUpdater(object):
if len(devices) > 1000:
logger.warn(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id, len(devices)
user_id,
len(devices),
)
devices = []
for device in devices:
logger.debug(
"Handling resync update %r/%r, ID: %r",
user_id, device["device_id"], stream_id,
user_id,
device["device_id"],
stream_id,
)
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
user_id, devices, stream_id
)
device_ids = [device["device_id"] for device in devices]
yield self.device_handler.notify_device_update(user_id, device_ids)
@ -606,7 +597,7 @@ class DeviceListEduUpdater(object):
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
user_id, device_id, content, stream_id
)
yield self.device_handler.notify_device_update(
@ -624,14 +615,9 @@ class DeviceListEduUpdater(object):
"""
seen_updates = self._seen_updates.get(user_id, set())
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
logger.debug(
"Current extremity for %r: %r",
user_id, extremity,
)
logger.debug("Current extremity for %r: %r", user_id, extremity)
stream_id_in_updates = set() # stream_ids in updates list
for _, stream_id, prev_ids, _ in updates:

View file

@ -25,7 +25,6 @@ logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
def __init__(self, hs):
"""
Args:
@ -47,15 +46,15 @@ class DeviceMessageHandler(object):
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
origin, sender_user_id
origin,
sender_user_id,
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
messages_by_device = {

View file

@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
class DirectoryHandler(BaseHandler):
def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs)
@ -77,15 +76,19 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(400, "Failed to get server list")
yield self.store.create_room_alias_association(
room_alias,
room_id,
servers,
creator=creator,
room_alias, room_id, servers, creator=creator
)
@defer.inlineCallbacks
def create_association(self, requester, room_alias, room_id, servers=None,
send_event=True, check_membership=True):
def create_association(
self,
requester,
room_alias,
room_id,
servers=None,
send_event=True,
check_membership=True,
):
"""Attempt to create a new alias
Args:
@ -115,49 +118,40 @@ class DirectoryHandler(BaseHandler):
if service:
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400, "This application service has not reserved"
" this kind of alias.", errcode=Codes.EXCLUSIVE
400,
"This application service has not reserved" " this kind of alias.",
errcode=Codes.EXCLUSIVE,
)
else:
if self.require_membership and check_membership:
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
403,
"You must be in the room to create an alias for it",
403, "You must be in the room to create an alias for it"
)
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
raise AuthError(
403, "This user is not permitted to create this alias",
)
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
user_id, room_id, room_alias.to_string(),
user_id, room_id, room_alias.to_string()
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(
403, "Not allowed to create alias",
)
raise SynapseError(403, "Not allowed to create alias")
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
400,
"This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE,
)
yield self._create_association(room_alias, room_id, servers, creator=user_id)
if send_event:
yield self.send_room_alias_update_event(
requester,
room_id
)
yield self.send_room_alias_update_event(requester, room_id)
@defer.inlineCallbacks
def delete_association(self, requester, room_alias, send_event=True):
@ -194,34 +188,24 @@ class DirectoryHandler(BaseHandler):
raise
if not can_delete:
raise AuthError(
403, "You don't have permission to delete the alias.",
)
raise AuthError(403, "You don't have permission to delete the alias.")
can_delete = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
400,
"This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE,
)
room_id = yield self._delete_association(room_alias)
try:
if send_event:
yield self.send_room_alias_update_event(
requester,
room_id
)
yield self.send_room_alias_update_event(requester, room_id)
yield self._update_canonical_alias(
requester,
requester.user.to_string(),
room_id,
room_alias,
requester, requester.user.to_string(), room_id, room_alias
)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@ -234,7 +218,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE
errcode=Codes.EXCLUSIVE,
)
yield self._delete_association(room_alias)
@ -251,9 +235,7 @@ class DirectoryHandler(BaseHandler):
def get_association(self, room_alias):
room_id = None
if self.hs.is_mine(room_alias):
result = yield self.get_association_from_room_alias(
room_alias
)
result = yield self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
@ -263,9 +245,7 @@ class DirectoryHandler(BaseHandler):
result = yield self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
args={
"room_alias": room_alias.to_string(),
},
args={"room_alias": room_alias.to_string()},
retry_on_dns_fail=False,
ignore_backoff=True,
)
@ -284,7 +264,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
404,
"Room alias %s not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
Codes.NOT_FOUND,
)
users = yield self.state.get_current_users_in_room(room_id)
@ -293,41 +273,28 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
[self.server_name] +
[s for s in servers if s != self.server_name]
)
servers = [self.server_name] + [s for s in servers if s != self.server_name]
else:
servers = list(servers)
defer.returnValue({
"room_id": room_id,
"servers": servers,
})
defer.returnValue({"room_id": room_id, "servers": servers})
return
@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
raise SynapseError(
400, "Room Alias is not hosted on this Home Server"
)
raise SynapseError(400, "Room Alias is not hosted on this Home Server")
result = yield self.get_association_from_room_alias(
room_alias
)
result = yield self.get_association_from_room_alias(room_alias)
if result is not None:
defer.returnValue({
"room_id": result.room_id,
"servers": result.servers,
})
defer.returnValue({"room_id": result.room_id, "servers": result.servers})
else:
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
Codes.NOT_FOUND,
)
@defer.inlineCallbacks
@ -343,7 +310,7 @@ class DirectoryHandler(BaseHandler):
"sender": requester.user.to_string(),
"content": {"aliases": aliases},
},
ratelimit=False
ratelimit=False,
)
@defer.inlineCallbacks
@ -365,14 +332,12 @@ class DirectoryHandler(BaseHandler):
"sender": user_id,
"content": {},
},
ratelimit=False
ratelimit=False,
)
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
result = yield self.store.get_association_from_room_alias(
room_alias
)
result = yield self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
@ -421,8 +386,7 @@ class DirectoryHandler(BaseHandler):
if not self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403,
"This user is not permitted to publish rooms to the room list"
403, "This user is not permitted to publish rooms to the room list"
)
if requester.is_guest:
@ -434,8 +398,7 @@ class DirectoryHandler(BaseHandler):
if visibility == "public" and not self.enable_room_list_search:
# The room list has been disabled.
raise AuthError(
403,
"This user is not permitted to publish rooms to the room list"
403, "This user is not permitted to publish rooms to the room list"
)
room = yield self.store.get_room(room_id)
@ -452,20 +415,19 @@ class DirectoryHandler(BaseHandler):
room_aliases.append(canonical_alias)
if not self.config.is_publishing_room_allowed(
user_id, room_id, room_aliases,
user_id, room_id, room_aliases
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(
403, "Not allowed to publish room",
)
raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
def edit_published_appservice_room_list(self, appservice_id, network_id,
room_id, visibility):
def edit_published_appservice_room_list(
self, appservice_id, network_id, room_id, visibility
):
"""Add or remove a room from the appservice/network specific public
room list.

View file

@ -99,9 +99,7 @@ class E2eKeysHandler(object):
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
yield self.store.get_user_devices_from_cache(query_list)
)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
@ -126,9 +124,7 @@ class E2eKeysHandler(object):
destination_query = remote_queries_not_in_cache[destination]
try:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
@ -138,14 +134,17 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
], consumeErrors=True))
yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
)
)
defer.returnValue({
"device_keys": results, "failures": failures,
})
defer.returnValue({"device_keys": results, "failures": failures})
@defer.inlineCallbacks
def query_local_devices(self, query):
@ -165,8 +164,7 @@ class E2eKeysHandler(object):
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
if not device_ids:
@ -231,9 +229,7 @@ class E2eKeysHandler(object):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
@ -241,25 +237,29 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
run_in_background(claim_client_keys, destination)
for destination in remote_queries
], consumeErrors=True))
yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(claim_client_keys, destination)
for destination in remote_queries
],
consumeErrors=True,
)
)
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result)
for device_id, device_keys in iteritems(user_keys)
for key_id, _ in iteritems(device_keys)
)),
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result)
for device_id, device_keys in iteritems(user_keys)
for key_id, _ in iteritems(device_keys)
)
),
)
defer.returnValue({
"one_time_keys": json_result,
"failures": failures
})
defer.returnValue({"one_time_keys": json_result, "failures": failures})
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
@ -270,11 +270,13 @@ class E2eKeysHandler(object):
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
device_id,
user_id,
time_now,
)
# TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys,
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
@ -283,7 +285,7 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
yield self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys,
user_id, device_id, time_now, one_time_keys
)
# the device should have been registered already, but it may have been
@ -298,20 +300,22 @@ class E2eKeysHandler(object):
defer.returnValue({"one_time_key_counts": result})
@defer.inlineCallbacks
def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
one_time_keys):
def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys
):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(), device_id, user_id, time_now,
one_time_keys.keys(),
device_id,
user_id,
time_now,
)
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, key_obj
))
key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
@ -325,42 +329,35 @@ class E2eKeysHandler(object):
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
("One time key %s:%s already exists. "
"Old key: %s; new key: %r") %
(algorithm, key_id, ex_json, key)
(
"One time key %s:%s already exists. "
"Old key: %s; new key: %r"
)
% (algorithm, key_id, ex_json, key),
)
else:
new_keys.append((
algorithm, key_id, encode_canonical_json(key).decode('ascii')))
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)
yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
def _exception_to_failure(e):
if isinstance(e, CodeMessageException):
return {
"status": e.code, "message": str(e),
}
return {"status": e.code, "message": str(e)}
if isinstance(e, NotRetryingDestination):
return {
"status": 503, "message": "Not ready for retry",
}
return {"status": 503, "message": "Not ready for retry"}
if isinstance(e, FederationDeniedError):
return {
"status": 403, "message": "Federation Denied",
}
return {"status": 403, "message": "Federation Denied"}
# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
return {
"status": 503, "message": str(e),
}
return {"status": 503, "message": str(e)}
def _one_time_keys_match(old_key_json, new_key):

View file

@ -152,14 +152,14 @@ class E2eRoomKeysHandler(object):
else:
raise
if version_info['version'] != version:
if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
version_info = yield self.store.get_e2e_room_keys_version_info(
user_id, version,
user_id, version
)
# if we get this far, the version must exist
raise RoomKeysVersionError(current_version=version_info['version'])
raise RoomKeysVersionError(current_version=version_info["version"])
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
@ -168,8 +168,8 @@ class E2eRoomKeysHandler(object):
# go through the room_keys.
# XXX: this should/could be done concurrently, given we're in a lock.
for room_id, room in iteritems(room_keys['rooms']):
for session_id, session in iteritems(room['sessions']):
for room_id, room in iteritems(room_keys["rooms"]):
for session_id, session in iteritems(room["sessions"]):
yield self._upload_room_key(
user_id, version, room_id, session_id, session
)
@ -223,14 +223,14 @@ class E2eRoomKeysHandler(object):
# spelt out with if/elifs rather than nested boolean expressions
# purely for legibility.
if room_key['is_verified'] and not current_room_key['is_verified']:
if room_key["is_verified"] and not current_room_key["is_verified"]:
return True
elif (
room_key['first_message_index'] <
current_room_key['first_message_index']
room_key["first_message_index"]
< current_room_key["first_message_index"]
):
return True
elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
elif room_key["forwarded_count"] < current_room_key["forwarded_count"]:
return True
else:
return False
@ -328,16 +328,10 @@ class E2eRoomKeysHandler(object):
A deferred of an empty dict.
"""
if "version" not in version_info:
raise SynapseError(
400,
"Missing version in body",
Codes.MISSING_PARAM
)
raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM)
if version_info["version"] != version:
raise SynapseError(
400,
"Version in body does not match",
Codes.INVALID_PARAM
400, "Version in body does not match", Codes.INVALID_PARAM
)
with (yield self._upload_linearizer.queue(user_id)):
try:
@ -350,12 +344,10 @@ class E2eRoomKeysHandler(object):
else:
raise
if old_info["algorithm"] != version_info["algorithm"]:
raise SynapseError(
400,
"Algorithm does not match",
Codes.INVALID_PARAM
)
raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
yield self.store.update_e2e_room_keys_version(user_id, version, version_info)
yield self.store.update_e2e_room_keys_version(
user_id, version, version_info
)
defer.returnValue({})

View file

@ -31,7 +31,6 @@ logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler):
def __init__(self, hs):
super(EventStreamHandler, self).__init__(hs)
@ -53,9 +52,17 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True,
only_keys=None, room_id=None, is_guest=False):
def get_stream(
self,
auth_user_id,
pagin_config,
timeout=0,
as_client_event=True,
affect_presence=True,
only_keys=None,
room_id=None,
is_guest=False,
):
"""Fetches the events stream for a given user.
If `only_keys` is not None, events from keys will be sent down.
@ -73,7 +80,7 @@ class EventStreamHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence,
auth_user_id, affect_presence=affect_presence
)
with context:
if timeout:
@ -85,9 +92,12 @@ class EventStreamHandler(BaseHandler):
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout,
auth_user,
pagin_config,
timeout,
only_keys=only_keys,
is_guest=is_guest, explicit_room_id=room_id
is_guest=is_guest,
explicit_room_id=room_id,
)
# When the user joins a new room, or another user joins a currently
@ -102,17 +112,15 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.state.get_current_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
users = yield self.state.get_current_users_in_room(
event.room_id
)
states = yield presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
UserID.from_string(event.state_key),
as_event=True,
UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)
@ -121,7 +129,9 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
chunks = yield self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event,
events,
time_now,
as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
@ -137,7 +147,6 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
@ -164,16 +173,10 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
self.store,
user.to_string(),
[event],
is_peeking=is_peeking
self.store, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
raise AuthError(
403,
"You don't have permission to access that event."
)
raise AuthError(403, "You don't have permission to access that event.")
defer.returnValue(event)

File diff suppressed because it is too large Load diff

View file

@ -30,6 +30,7 @@ def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)(
@ -49,9 +50,7 @@ def _create_rerouter(func_name):
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
if e.code == 403:
raise e.to_synapse_error()
return failure
raise e.to_synapse_error()
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
@ -60,6 +59,7 @@ def _create_rerouter(func_name):
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
return f
@ -127,7 +127,7 @@ class GroupsLocalHandler(object):
)
else:
res = yield self.transport_client.get_group_summary(
get_domain_from_id(group_id), group_id, requester_user_id,
get_domain_from_id(group_id), group_id, requester_user_id
)
group_server_name = get_domain_from_id(group_id)
@ -184,7 +184,7 @@ class GroupsLocalHandler(object):
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
res = yield self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content,
get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@ -197,16 +197,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="join",
is_admin=True,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue(res)
@ -223,7 +222,7 @@ class GroupsLocalHandler(object):
group_server_name = get_domain_from_id(group_id)
res = yield self.transport_client.get_users_in_group(
get_domain_from_id(group_id), group_id, requester_user_id,
get_domain_from_id(group_id), group_id, requester_user_id
)
chunk = res["chunk"]
@ -252,9 +251,7 @@ class GroupsLocalHandler(object):
"""Request to join a group
"""
if self.is_mine_id(group_id):
yield self.groups_server_handler.join_group(
group_id, user_id, content
)
yield self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@ -262,7 +259,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation
res = yield self.transport_client.join_group(
get_domain_from_id(group_id), group_id, user_id, content,
get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@ -278,16 +275,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue({})
@ -296,9 +292,7 @@ class GroupsLocalHandler(object):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
yield self.groups_server_handler.accept_invite(
group_id, user_id, content
)
yield self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@ -306,7 +300,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation
res = yield self.transport_client.accept_group_invite(
get_domain_from_id(group_id), group_id, user_id, content,
get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@ -322,16 +316,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue({})
@ -339,17 +332,17 @@ class GroupsLocalHandler(object):
def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
content = {
"requester_user_id": requester_user_id,
"config": config,
}
content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.invite_to_group(
group_id, user_id, requester_user_id, content,
group_id, user_id, requester_user_id, content
)
else:
res = yield self.transport_client.invite_to_group(
get_domain_from_id(group_id), group_id, user_id, requester_user_id,
get_domain_from_id(group_id),
group_id,
user_id,
requester_user_id,
content,
)
@ -372,13 +365,12 @@ class GroupsLocalHandler(object):
local_profile["avatar_url"] = content["profile"]["avatar_url"]
token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]},
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
@ -393,25 +385,25 @@ class GroupsLocalHandler(object):
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
# TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down.
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content,
group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
res = yield self.transport_client.remove_user_from_group(
get_domain_from_id(group_id), group_id, requester_user_id,
user_id, content,
get_domain_from_id(group_id),
group_id,
requester_user_id,
user_id,
content,
)
defer.returnValue(res)
@ -422,12 +414,9 @@ class GroupsLocalHandler(object):
"""
# TODO: Check if user in group
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
@defer.inlineCallbacks
def get_joined_groups(self, user_id):
@ -447,7 +436,7 @@ class GroupsLocalHandler(object):
defer.returnValue({"groups": result})
else:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id],
get_domain_from_id(user_id), [user_id]
)
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
@ -462,9 +451,7 @@ class GroupsLocalHandler(object):
if self.hs.is_mine_id(user_id):
local_users.add(user_id)
else:
destinations.setdefault(
get_domain_from_id(user_id), set()
).add(user_id)
destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)
if not proxy and destinations:
raise SynapseError(400, "Some user_ids are not local")
@ -474,16 +461,14 @@ class GroupsLocalHandler(object):
for destination, dest_user_ids in iteritems(destinations):
try:
r = yield self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids),
destination, list(dest_user_ids)
)
results.update(r["users"])
except Exception:
failed_results.extend(dest_user_ids)
for uid in local_users:
results[uid] = yield self.store.get_publicised_groups_for_user(
uid
)
results[uid] = yield self.store.get_publicised_groups_for_user(uid)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)

View file

@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
@ -64,40 +63,38 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
if "id_server" in creds:
id_server = creds["id_server"]
elif "idServer" in creds:
id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")
if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
if "client_secret" in creds:
client_secret = creds["client_secret"]
elif "clientSecret" in creds:
client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")
if not self._should_trust_id_server(id_server):
logger.warn(
'%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server
"%s is not a trusted ID server: rejecting 3pid " + "credentials",
id_server,
)
defer.returnValue(None)
try:
data = yield self.http_client.get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'client_secret': client_secret}
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"),
{"sid": creds["sid"], "client_secret": client_secret},
)
except HttpResponseException as e:
logger.info("getValidated3pid failed with Matrix error: %r", e)
raise e.to_synapse_error()
if 'medium' in data:
if "medium" in data:
defer.returnValue(data)
defer.returnValue(None)
@ -106,30 +103,24 @@ class IdentityHandler(BaseHandler):
logger.debug("binding threepid %r to %s", creds, mxid)
data = None
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
if "id_server" in creds:
id_server = creds["id_server"]
elif "idServer" in creds:
id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")
if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
if "client_secret" in creds:
client_secret = creds["client_secret"]
elif "clientSecret" in creds:
client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")
try:
data = yield self.http_client.post_urlencoded_get_json(
"https://%s%s" % (
id_server, "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'client_secret': client_secret,
'mxid': mxid,
}
"https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
{"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
)
logger.debug("bound threepid %r to %s", creds, mxid)
@ -165,9 +156,7 @@ class IdentityHandler(BaseHandler):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
# We don't know where to unbind, so we don't have a choice but to return
@ -177,7 +166,7 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server,
mxid, threepid, id_server
)
defer.returnValue(changed)
@ -201,10 +190,7 @@ class IdentityHandler(BaseHandler):
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
content = {
"mxid": mxid,
"threepid": {
"medium": threepid["medium"],
"address": threepid["address"],
},
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
}
# we abuse the federation http client to sign the request, but we have to send it
@ -212,25 +198,19 @@ class IdentityHandler(BaseHandler):
# 'browser-like' HTTPS.
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method='POST',
url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
method="POST",
url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"),
content=content,
destination_is=id_server,
)
headers = {
b"Authorization": auth_headers,
}
headers = {b"Authorization": auth_headers}
try:
yield self.http_client.post_json_get_json(
url,
content,
headers,
)
yield self.http_client.post_json_get_json(url, content, headers)
changed = True
except HttpResponseException as e:
changed = False
if e.code in (400, 404, 501,):
if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code)
else:
@ -248,35 +228,27 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def requestEmailToken(
self,
id_server,
email,
client_secret,
send_attempt,
next_link=None,
self, id_server, email, client_secret, send_attempt, next_link=None
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
Codes.SERVER_NOT_TRUSTED
400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)
params = {
'email': email,
'client_secret': client_secret,
'send_attempt': send_attempt,
"email": email,
"client_secret": client_secret,
"send_attempt": send_attempt,
}
if next_link:
params.update({'next_link': next_link})
params.update({"next_link": next_link})
try:
data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/email/requestToken"
),
params
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"),
params,
)
defer.returnValue(data)
except HttpResponseException as e:
@ -285,30 +257,26 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def requestMsisdnToken(
self, id_server, country, phone_number,
client_secret, send_attempt, **kwargs
self, id_server, country, phone_number, client_secret, send_attempt, **kwargs
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
Codes.SERVER_NOT_TRUSTED
400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)
params = {
'country': country,
'phone_number': phone_number,
'client_secret': client_secret,
'send_attempt': send_attempt,
"country": country,
"phone_number": phone_number,
"client_secret": client_secret,
"send_attempt": send_attempt,
}
params.update(kwargs)
try:
data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/msisdn/requestToken"
),
params
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"),
params,
)
defer.returnValue(data)
except HttpResponseException as e:

View file

@ -44,8 +44,13 @@ class InitialSyncHandler(BaseHandler):
self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
def snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
as_client_event=True,
include_archived=False,
):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
@ -77,13 +82,22 @@ class InitialSyncHandler(BaseHandler):
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
return self.snapshot_cache.set(
now_ms,
key,
self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
),
)
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
def _snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
as_client_event=True,
include_archived=False,
):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
@ -128,8 +142,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
"public" if event.room_id in public_room_ids else "private"
),
}
@ -139,7 +152,7 @@ class InitialSyncHandler(BaseHandler):
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = yield self._event_serializer.serialize_event(
invite_event, time_now, as_client_event,
invite_event, time_now, as_client_event
)
rooms_ret.append(d)
@ -151,14 +164,12 @@ class InitialSyncHandler(BaseHandler):
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
self.state_handler.get_current_state,
event.room_id,
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
self.store.get_state_for_events,
[event.event_id],
self.store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@ -178,9 +189,7 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
messages = yield filter_events_for_client(self.store, user_id, messages)
start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token)
@ -189,8 +198,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
yield self._event_serializer.serialize_events(
messages, time_now=time_now,
as_client_event=as_client_event,
messages, time_now=time_now, as_client_event=as_client_event
)
),
"start": start_token.to_string(),
@ -200,23 +208,21 @@ class InitialSyncHandler(BaseHandler):
d["state"] = yield self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event
as_client_event=as_client_event,
)
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data_events.append(
{"type": "m.tag", "content": {"tags": tags}}
)
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
account_data_events.append(
{"type": account_data_type, "content": content}
)
d["account_data"] = account_data_events
except Exception:
@ -226,10 +232,7 @@ class InitialSyncHandler(BaseHandler):
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
account_data_events.append({"type": account_data_type, "content": content})
now = self.clock.time_msec()
@ -274,7 +277,7 @@ class InitialSyncHandler(BaseHandler):
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
room_id, user_id
)
is_peeking = member_event_id is None
@ -290,28 +293,21 @@ class InitialSyncHandler(BaseHandler):
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
account_data_events.append({"type": account_data_type, "content": content})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id],
)
def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
room_state = yield self.store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@ -319,14 +315,10 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
stream_token = yield self.store.get_stream_token_for_event(member_event_id)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
room_id, limit=limit, end_token=stream_token
)
messages = yield filter_events_for_client(
@ -338,34 +330,39 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": (yield self._event_serializer.serialize_events(
messages, time_now,
)),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": (yield self._event_serializer.serialize_events(
room_state.values(), time_now,
)),
"presence": [],
"receipts": [],
})
defer.returnValue(
{
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": (
yield self._event_serializer.serialize_events(
messages, time_now
)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": (
yield self._event_serializer.serialize_events(
room_state.values(), time_now
)
),
"presence": [],
"receipts": [],
}
)
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking
):
current_state = yield self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = yield self._event_serializer.serialize_events(
current_state.values(), time_now,
current_state.values(), time_now
)
now_token = yield self.hs.get_event_sources().get_current_token()
@ -375,7 +372,8 @@ class InitialSyncHandler(BaseHandler):
limit = 10
room_members = [
m for m in current_state.values()
m
for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
@ -389,8 +387,7 @@ class InitialSyncHandler(BaseHandler):
defer.returnValue([])
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
[m.user_id for m in room_members], as_event=True
)
defer.returnValue(states)
@ -398,8 +395,7 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_receipts():
receipts = yield self.store.get_linearized_receipts_for_room(
room_id,
to_key=now_token.receipt_key,
room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
@ -415,14 +411,14 @@ class InitialSyncHandler(BaseHandler):
room_id,
limit=limit,
end_token=now_token.room_key,
)
),
],
consumeErrors=True,
).addErrback(unwrapFirstError),
).addErrback(unwrapFirstError)
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace("room_key", token)
@ -433,9 +429,9 @@ class InitialSyncHandler(BaseHandler):
ret = {
"room_id": room_id,
"messages": {
"chunk": (yield self._event_serializer.serialize_events(
messages, time_now,
)),
"chunk": (
yield self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
@ -464,8 +460,8 @@ class InitialSyncHandler(BaseHandler):
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
visibility
and visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return

View file

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 - 2018 New Vector Ltd
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -33,9 +34,10 @@ from synapse.api.errors import (
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events.validator import EventValidator
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background
@ -59,8 +61,9 @@ class MessageHandler(object):
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key="", is_guest=False):
def get_room_data(
self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
):
""" Get data from a room.
Args:
@ -75,9 +78,7 @@ class MessageHandler(object):
)
if membership == Membership.JOIN:
data = yield self.state.get_current_state(
room_id, event_type, state_key
)
data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
@ -89,8 +90,12 @@ class MessageHandler(object):
@defer.inlineCallbacks
def get_state_events(
self, user_id, room_id, state_filter=StateFilter.all(),
at_token=None, is_guest=False,
self,
user_id,
room_id,
state_filter=StateFilter.all(),
at_token=None,
is_guest=False,
):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
@ -122,50 +127,48 @@ class MessageHandler(object):
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
room_id, end_token=at_token.room_key, limit=1,
room_id, end_token=at_token.room_key, limit=1
)
if not last_events:
raise NotFoundError("Can't find event for token %s" % (at_token, ))
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
self.store, user_id, last_events,
self.store, user_id, last_events
)
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
[event.event_id], state_filter=state_filter,
[event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
else:
raise AuthError(
403,
"User %s not allowed to view events in room %s at token %s" % (
user_id, room_id, at_token,
)
"User %s not allowed to view events in room %s at token %s"
% (user_id, room_id, at_token),
)
else:
membership, membership_event_id = (
yield self.auth.check_in_room_or_world_readable(
room_id, user_id,
)
yield self.auth.check_in_room_or_world_readable(room_id, user_id)
)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
room_id, state_filter=state_filter,
room_id, state_filter=state_filter
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[membership_event_id], state_filter=state_filter,
[membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(
room_state.values(), now,
room_state.values(),
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_aggregations=False,
@ -209,13 +212,15 @@ class MessageHandler(object):
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
defer.returnValue({
user_id: {
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
defer.returnValue(
{
user_id: {
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
for user_id, profile in iteritems(users_with_profile)
}
for user_id, profile in iteritems(users_with_profile)
})
)
class EventCreationHandler(object):
@ -248,6 +253,7 @@ class EventCreationHandler(object):
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@ -259,9 +265,28 @@ class EventCreationHandler(object):
if self._block_events_without_consent_error:
self._consent_uri_builder = ConsentURIBuilder(self.config)
if (
not self.config.worker_app
and self.config.cleanup_extremities_with_dummy_events
):
self.clock.looping_call(
lambda: run_as_background_process(
"send_dummy_events_to_fill_extremities",
self._send_dummy_events_to_fill_extremities,
),
5 * 60 * 1000,
)
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_events_and_hashes=None, require_consent=True):
def create_event(
self,
requester,
event_dict,
token_id=None,
txn_id=None,
prev_events_and_hashes=None,
require_consent=True,
):
"""
Given a dict from a client, create a new event.
@ -321,8 +346,7 @@ class EventCreationHandler(object):
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
"Failed to get profile information for %r: %s", target, e
)
is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
@ -358,16 +382,17 @@ class EventCreationHandler(object):
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event or prev_event.membership != Membership.JOIN:
logger.warning(
("Attempt to send `m.room.aliases` in room %s by user %s but"
" membership is %s"),
(
"Attempt to send `m.room.aliases` in room %s by user %s but"
" membership is %s"
),
event.room_id,
event.sender,
prev_event.membership if prev_event else None,
)
raise AuthError(
403,
"You must be in the room to create an alias for it",
403, "You must be in the room to create an alias for it"
)
self.validator.validate_new(event)
@ -434,8 +459,8 @@ class EventCreationHandler(object):
# exempt the system notices user
if (
self.config.server_notices_mxid is not None and
user_id == self.config.server_notices_mxid
self.config.server_notices_mxid is not None
and user_id == self.config.server_notices_mxid
):
return
@ -448,15 +473,10 @@ class EventCreationHandler(object):
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(
requester.user.localpart,
)
msg = self._block_events_without_consent_error % {
'consent_uri': consent_uri,
}
raise ConsentNotGivenError(
msg=msg,
consent_uri=consent_uri,
requester.user.localpart
)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
@ -471,8 +491,7 @@ class EventCreationHandler(object):
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
500, "Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender)
@ -484,15 +503,13 @@ class EventCreationHandler(object):
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
event.event_id, prev_state.event_id,
event.event_id,
prev_state.event_id,
)
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@defer.inlineCallbacks
@ -518,11 +535,7 @@ class EventCreationHandler(object):
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
Creates an event, then sends it.
@ -537,32 +550,25 @@ class EventCreationHandler(object):
# taking longer.
with (yield self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, string_types):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
requester, event, context, ratelimit=ratelimit
)
defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
def create_new_client_event(self, builder, requester=None,
prev_events_and_hashes=None):
def create_new_client_event(
self, builder, requester=None, prev_events_and_hashes=None
):
"""Create a new event for a local client
Args:
@ -582,22 +588,21 @@ class EventCreationHandler(object):
"""
if prev_events_and_hashes is not None:
assert len(prev_events_and_hashes) <= 10, \
"Attempting to create an event with %i prev_events" % (
len(prev_events_and_hashes),
assert len(prev_events_and_hashes) <= 10, (
"Attempting to create an event with %i prev_events"
% (len(prev_events_and_hashes),)
)
else:
prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id)
prev_events_and_hashes = yield self.store.get_prev_events_for_room(
builder.room_id
)
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes
]
event = yield builder.build(
prev_event_ids=[p for p, _ in prev_events],
)
event = yield builder.build(prev_event_ids=[p for p, _ in prev_events])
context = yield self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
@ -613,29 +618,19 @@ class EventCreationHandler(object):
aggregation_key = relation["key"]
already_exists = yield self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender,
relates_to, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
logger.debug(
"Created event %s",
event.event_id,
)
logger.debug("Created event %s", event.event_id)
defer.returnValue(
(event, context,)
)
defer.returnValue((event, context))
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[],
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@ -651,13 +646,22 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event
"""
if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
room_version = event.content.get(
"room_version", RoomVersions.V1.identifier
)
if event.is_state() and (event.type, event.state_key) == (
EventTypes.Create,
"",
):
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
room_version = yield self.store.get_room_version(event.room_id)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
@ -672,9 +676,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
yield self.action_generator.handle_push_actions_for_event(
event, context
)
yield self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@ -695,11 +697,7 @@ class EventCreationHandler(object):
return
yield self.persist_and_notify_client_event(
requester,
event,
context,
ratelimit=ratelimit,
extra_users=extra_users,
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
@ -708,18 +706,12 @@ class EventCreationHandler(object):
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
run_in_background(
self.store.remove_push_actions_from_staging,
event.event_id,
self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[],
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@ -744,20 +736,16 @@ class EventCreationHandler(object):
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
"Room alias %s does not point to the room" % (room_alias_str,),
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return (
e.type == EventTypes.Member and
e.sender == event.sender
)
return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids(self.store)
@ -787,26 +775,21 @@ class EventCreationHandler(object):
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
invitee.domain, event
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version(event.room_id)
if self.auth.check_redaction(room_version, event, auth_events=auth_events):
original_event = yield self.store.get_event(
@ -814,13 +797,10 @@ class EventCreationHandler(object):
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
allow_none=False
allow_none=False,
)
if event.user_id != original_event.user_id:
raise AuthError(
403,
"You don't have permission to redact events"
)
raise AuthError(403, "You don't have permission to redact events")
# We've already checked.
event.internal_metadata.recheck_redaction = False
@ -828,24 +808,18 @@ class EventCreationHandler(object):
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
raise AuthError(403, "Changing the room create event is forbidden")
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
yield self.pusher_pool.on_new_notifications(
event_stream_id, max_stream_id,
)
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
event, event_stream_id, max_stream_id, extra_users=extra_users
)
except Exception:
logger.exception("Error notifying about new room event")
@ -864,3 +838,54 @@ class EventCreationHandler(object):
yield presence.bump_presence_active_time(user)
except Exception:
logger.exception("Error bumping presence active time")
@defer.inlineCallbacks
def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
room_ids = yield self.store.get_rooms_with_many_extremities(
min_count=10, limit=5
)
for room_id in room_ids:
# For each room we need to find a joined member we can use to send
# the dummy event with.
prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
members = yield self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
user_id = None
for member in members:
if self.hs.is_mine_id(member):
user_id = member
break
if not user_id:
# We don't have a joined user.
# TODO: We should do something here to stop the room from
# appearing next time.
continue
requester = create_requester(user_id)
event, context = yield self.create_event(
requester,
{
"type": "org.matrix.dummy_event",
"content": {},
"room_id": room_id,
"sender": user_id,
},
prev_events_and_hashes=prev_events_and_hashes,
)
event.internal_metadata.proactively_send = False
yield self.send_nonmember_event(requester, event, context, ratelimit=False)

View file

@ -55,9 +55,7 @@ class PurgeStatus(object):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
return {
"status": PurgeStatus.STATUS_TEXT[self.status]
}
return {"status": PurgeStatus.STATUS_TEXT[self.status]}
class PaginationHandler(object):
@ -79,8 +77,7 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
def start_purge_history(self, room_id, token,
delete_local_events=False):
def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
Args:
@ -95,8 +92,7 @@ class PaginationHandler(object):
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
400,
"History purge already in progress for %s" % (room_id, ),
400, "History purge already in progress for %s" % (room_id,)
)
purge_id = random_string(16)
@ -107,14 +103,12 @@ class PaginationHandler(object):
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
self._purge_history,
purge_id, room_id, token, delete_local_events,
self._purge_history, purge_id, room_id, token, delete_local_events
)
return purge_id
@defer.inlineCallbacks
def _purge_history(self, purge_id, room_id, token,
delete_local_events):
def _purge_history(self, purge_id, room_id, token, delete_local_events):
"""Carry out a history purge on a room.
Args:
@ -130,16 +124,13 @@ class PaginationHandler(object):
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
yield self.store.purge_history(
room_id, token, delete_local_events,
)
yield self.store.purge_history(room_id, token, delete_local_events)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
f = Failure()
logger.error(
"[purge] failed",
exc_info=(f.type, f.value, f.getTracebackObject()),
"[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
@ -148,6 +139,7 @@ class PaginationHandler(object):
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
@ -162,8 +154,14 @@ class PaginationHandler(object):
return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True, event_filter=None):
def get_messages(
self,
requester,
room_id=None,
pagin_config=None,
as_client_event=True,
event_filter=None,
):
"""Get messages in a room.
Args:
@ -182,9 +180,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token_for_room(
room_id=room_id
)
yield self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@ -201,7 +197,7 @@ class PaginationHandler(object):
room_id, user_id
)
if source_config.direction == 'b':
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
@ -235,27 +231,24 @@ class PaginationHandler(object):
event_filter=event_filter,
)
next_token = pagin_config.from_token.copy_and_replace(
"room_key", next_key
)
next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
if events:
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
self.store,
user_id,
events,
is_peeking=(member_event_id is None),
self.store, user_id, events, is_peeking=(member_event_id is None)
)
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
defer.returnValue(
{
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
}
)
state = None
if event_filter and event_filter.lazy_load_members() and len(events) > 0:
@ -263,12 +256,11 @@ class PaginationHandler(object):
# FIXME: we also care about invite targets etc.
state_filter = StateFilter.from_types(
(EventTypes.Member, event.sender)
for event in events
(EventTypes.Member, event.sender) for event in events
)
state_ids = yield self.store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter,
events[0].event_id, state_filter=state_filter
)
if state_ids:
@ -280,8 +272,7 @@ class PaginationHandler(object):
chunk = {
"chunk": (
yield self._event_serializer.serialize_events(
events, time_now,
as_client_event=as_client_event,
events, time_now, as_client_event=as_client_event
)
),
"start": pagin_config.from_token.to_string(),
@ -291,8 +282,7 @@ class PaginationHandler(object):
if state:
chunk["state"] = (
yield self._event_serializer.serialize_events(
state, time_now,
as_client_event=as_client_event,
state, time_now, as_client_event=as_client_event
)
)

Some files were not shown because too many files have changed in this diff Show more