Merge branch 'develop' into rav/saml2_client
commit a4daa899ec
478 changed files with 18927 additions and 11500 deletions
@@ -17,14 +17,22 @@
 """ This is a reference implementation of a Matrix home server.
 """

+import sys
+
+# Check that we're not running on an unsupported Python version.
+if sys.version_info < (3, 5):
+    print("Synapse requires Python 3.5 or above.")
+    sys.exit(1)
+
 try:
     from twisted.internet import protocol
     from twisted.internet.protocol import Factory
     from twisted.names.dns import DNSDatagramProtocol
+
     protocol.Factory.noisy = False
     Factory.noisy = False
     DNSDatagramProtocol.noisy = False
 except ImportError:
     pass

-__version__ = "1.0.0rc3"
+__version__ = "1.0.0"
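
Note: setting `noisy = False` on these classes suppresses twisted's per-connection "Starting factory"/"Stopping factory" log lines process-wide. The same knob also works on a single factory; a minimal sketch (the class name is illustrative):

    from twisted.internet.protocol import Factory

    class QuietFactory(Factory):
        # twisted checks this attribute before logging factory start/stop events
        noisy = False
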
@@ -57,18 +57,18 @@ def request_registration(

     nonce = r.json()["nonce"]

-    mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
+    mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1)

-    mac.update(nonce.encode('utf8'))
+    mac.update(nonce.encode("utf8"))
     mac.update(b"\x00")
-    mac.update(user.encode('utf8'))
+    mac.update(user.encode("utf8"))
     mac.update(b"\x00")
-    mac.update(password.encode('utf8'))
+    mac.update(password.encode("utf8"))
     mac.update(b"\x00")
     mac.update(b"admin" if admin else b"notadmin")
     if user_type:
         mac.update(b"\x00")
-        mac.update(user_type.encode('utf8'))
+        mac.update(user_type.encode("utf8"))

     mac = mac.hexdigest()
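
Note: the digest above is the shared-secret registration MAC: the fields are joined with NUL bytes, keyed with the server's registration shared secret, and hashed with SHA-1. It can be reproduced outside Synapse; a minimal standalone sketch (the secret and nonce values are illustrative; a real nonce comes from the registration endpoint first):

    import hashlib
    import hmac

    def registration_mac(shared_secret, nonce, user, password, admin=False, user_type=None):
        # mirrors the field order used in the hunk above
        mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1)
        mac.update(nonce.encode("utf8"))
        mac.update(b"\x00")
        mac.update(user.encode("utf8"))
        mac.update(b"\x00")
        mac.update(password.encode("utf8"))
        mac.update(b"\x00")
        mac.update(b"admin" if admin else b"notadmin")
        if user_type:
            mac.update(b"\x00")
            mac.update(user_type.encode("utf8"))
        return mac.hexdigest()

    print(registration_mac("secret", "abcd", "alice", "wonderland"))
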
@@ -134,8 +134,9 @@ def register_new_user(user, password, server_location, shared_secret, admin, use
     else:
         admin = False

-    request_registration(user, password, server_location, shared_secret,
-                         bool(admin), user_type)
+    request_registration(
+        user, password, server_location, shared_secret, bool(admin), user_type
+    )


 def main():

@@ -189,7 +190,7 @@ def main():
     group.add_argument(
         "-c",
         "--config",
-        type=argparse.FileType('r'),
+        type=argparse.FileType("r"),
         help="Path to server config file. Used to read in shared secret.",
     )

@@ -200,7 +201,7 @@ def main():
     parser.add_argument(
         "server_url",
         default="https://localhost:8448",
-        nargs='?',
+        nargs="?",
         help="URL to use to talk to the home server. Defaults to "
         " 'https://localhost:8448'.",
     )

@@ -220,8 +221,9 @@ def main():
     if args.admin or args.no_admin:
         admin = args.admin

-    register_new_user(args.user, args.password, args.server_url, secret,
-                      admin, args.user_type)
+    register_new_user(
+        args.user, args.password, args.server_url, secret, admin, args.user_type
+    )


 if __name__ == "__main__":
@@ -36,8 +36,11 @@ logger = logging.getLogger(__name__)


 AuthEventTypes = (
-    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
-    EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
+    EventTypes.Create,
+    EventTypes.Member,
+    EventTypes.PowerLevels,
+    EventTypes.JoinRules,
+    EventTypes.RoomHistoryVisibility,
     EventTypes.ThirdPartyInvite,
 )

@@ -54,6 +57,7 @@ class Auth(object):
     FIXME: This class contains a mix of functions for authenticating users
     of our client-server API and authenticating events added to room graphs.
     """
+
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()

@@ -70,15 +74,12 @@ class Auth(object):
     def check_from_context(self, room_version, event, context, do_sig_check=True):
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         auth_events_ids = yield self.compute_auth_events(
-            event, prev_state_ids, for_verification=True,
+            event, prev_state_ids, for_verification=True
         )
         auth_events = yield self.store.get_events(auth_events_ids)
-        auth_events = {
-            (e.type, e.state_key): e for e in itervalues(auth_events)
-        }
+        auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
         self.check(
-            room_version, event,
-            auth_events=auth_events, do_sig_check=do_sig_check,
+            room_version, event, auth_events=auth_events, do_sig_check=do_sig_check
         )

     def check(self, room_version, event, auth_events, do_sig_check=True):

@@ -115,15 +116,10 @@ class Auth(object):
             the room.
         """
         if current_state:
-            member = current_state.get(
-                (EventTypes.Member, user_id),
-                None
-            )
+            member = current_state.get((EventTypes.Member, user_id), None)
         else:
             member = yield self.state.get_current_state(
-                room_id=room_id,
-                event_type=EventTypes.Member,
-                state_key=user_id
+                room_id=room_id, event_type=EventTypes.Member, state_key=user_id
             )

         self._check_joined_room(member, user_id, room_id)

@@ -143,23 +139,17 @@ class Auth(object):
             the room. This will be the leave event if they have left the room.
         """
         member = yield self.state.get_current_state(
-            room_id=room_id,
-            event_type=EventTypes.Member,
-            state_key=user_id
+            room_id=room_id, event_type=EventTypes.Member, state_key=user_id
         )
         membership = member.membership if member else None

         if membership not in (Membership.JOIN, Membership.LEAVE):
-            raise AuthError(403, "User %s not in room %s" % (
-                user_id, room_id
-            ))
+            raise AuthError(403, "User %s not in room %s" % (user_id, room_id))

         if membership == Membership.LEAVE:
             forgot = yield self.store.did_forget(user_id, room_id)
             if forgot:
-                raise AuthError(403, "User %s not in room %s" % (
-                    user_id, room_id
-                ))
+                raise AuthError(403, "User %s not in room %s" % (user_id, room_id))

         defer.returnValue(member)

@@ -171,9 +161,9 @@ class Auth(object):

     def _check_joined_room(self, member, user_id, room_id):
         if not member or member.membership != Membership.JOIN:
-            raise AuthError(403, "User %s not in room %s (%s)" % (
-                user_id, room_id, repr(member)
-            ))
+            raise AuthError(
+                403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
+            )

     def can_federate(self, event, auth_events):
         creation_event = auth_events.get((EventTypes.Create, ""))

@@ -185,11 +175,7 @@ class Auth(object):

     @defer.inlineCallbacks
     def get_user_by_req(
-        self,
-        request,
-        allow_guest=False,
-        rights="access",
-        allow_expired=False,
+        self, request, allow_guest=False, rights="access", allow_expired=False
     ):
         """ Get a registered user's ID.

@@ -209,9 +195,8 @@ class Auth(object):
         try:
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
-                b"User-Agent",
-                default=[b""]
-            )[0].decode('ascii', 'surrogateescape')
+                b"User-Agent", default=[b""]
+            )[0].decode("ascii", "surrogateescape")

             access_token = self.get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS

@@ -243,11 +228,12 @@ class Auth(object):
             if self._account_validity.enabled and not allow_expired:
                 user_id = user.to_string()
                 expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
-                if expiration_ts is not None and self.clock.time_msec() >= expiration_ts:
+                if (
+                    expiration_ts is not None
+                    and self.clock.time_msec() >= expiration_ts
+                ):
                     raise AuthError(
-                        403,
-                        "User account has expired",
-                        errcode=Codes.EXPIRED_ACCOUNT,
+                        403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )

             # device_id may not be present if get_user_by_access_token has been

@@ -265,18 +251,23 @@ class Auth(object):
             if is_guest and not allow_guest:
                 raise AuthError(
-                    403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+                    403,
+                    "Guest access not allowed",
+                    errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )

             request.authenticated_entity = user.to_string()

-            defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id, app_service=app_service)
+            defer.returnValue(
+                synapse.types.create_requester(
+                    user, token_id, is_guest, device_id, app_service=app_service
+                )
             )
         except KeyError:
             raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
-                errcode=Codes.MISSING_TOKEN
+                self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                "Missing access token.",
+                errcode=Codes.MISSING_TOKEN,
             )

     @defer.inlineCallbacks

@@ -297,20 +288,14 @@ class Auth(object):
         if b"user_id" not in request.args:
             defer.returnValue((app_service.sender, app_service))

-        user_id = request.args[b"user_id"][0].decode('utf8')
+        user_id = request.args[b"user_id"][0].decode("utf8")
         if app_service.sender == user_id:
             defer.returnValue((app_service.sender, app_service))

         if not app_service.is_interested_in_user(user_id):
-            raise AuthError(
-                403,
-                "Application service cannot masquerade as this user."
-            )
+            raise AuthError(403, "Application service cannot masquerade as this user.")
         if not (yield self.store.get_user_by_id(user_id)):
-            raise AuthError(
-                403,
-                "Application service has not registered this user"
-            )
+            raise AuthError(403, "Application service has not registered this user")
         defer.returnValue((user_id, app_service))

     @defer.inlineCallbacks

@@ -368,13 +353,13 @@ class Auth(object):
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS,
                 "Unknown user_id %s" % user_id,
-                errcode=Codes.UNKNOWN_TOKEN
+                errcode=Codes.UNKNOWN_TOKEN,
             )
         if not stored_user["is_guest"]:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS,
                 "Guest access token used for regular user",
-                errcode=Codes.UNKNOWN_TOKEN
+                errcode=Codes.UNKNOWN_TOKEN,
             )
         ret = {
             "user": user,

@@ -402,8 +387,9 @@ class Auth(object):
         ) as e:
             logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
             raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
-                errcode=Codes.UNKNOWN_TOKEN
+                self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN,
             )

     def _parse_and_validate_macaroon(self, token, rights="access"):

@@ -441,13 +427,13 @@ class Auth(object):
             guest = True

             self.validate_macaroon(
-                macaroon, rights, self.hs.config.expire_access_token,
-                user_id=user_id,
+                macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
             )
         except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
             raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
-                errcode=Codes.UNKNOWN_TOKEN
+                self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN,
             )

         if not has_expiry and rights == "access":

@@ -472,10 +458,11 @@ class Auth(object):
         user_prefix = "user_id = "
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix):]
+                return caveat.caveat_id[len(user_prefix) :]
         raise AuthError(
-            self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-            errcode=Codes.UNKNOWN_TOKEN
+            self.TOKEN_NOT_FOUND_HTTP_STATUS,
+            "No user caveat in macaroon",
+            errcode=Codes.UNKNOWN_TOKEN,
         )

     def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):

@@ -522,7 +509,7 @@ class Auth(object):
         prefix = "time < "
         if not caveat.startswith(prefix):
             return False
-        expiry = int(caveat[len(prefix):])
+        expiry = int(caveat[len(prefix) :])
         now = self.hs.get_clock().time_msec()
         return now < expiry
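
Note: the caveat check above slices the `time < ` prefix off the macaroon caveat and compares the remaining millisecond timestamp against the clock. A standalone sketch of the same logic (the caveat format is taken from the hunk above; the timestamps are illustrative):

    def verify_expiry_caveat(caveat, now_ms):
        # caveats look like "time < 1561500000000"; anything else is rejected
        prefix = "time < "
        if not caveat.startswith(prefix):
            return False
        expiry = int(caveat[len(prefix):])
        return now_ms < expiry

    assert verify_expiry_caveat("time < 2000", 1999)
    assert not verify_expiry_caveat("time < 2000", 2000)
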
@@ -554,14 +541,12 @@ class Auth(object):
                 raise AuthError(
                     self.TOKEN_NOT_FOUND_HTTP_STATUS,
                     "Unrecognised access token.",
-                    errcode=Codes.UNKNOWN_TOKEN
+                    errcode=Codes.UNKNOWN_TOKEN,
                 )
             request.authenticated_entity = service.sender
             return defer.succeed(service)
         except KeyError:
-            raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
-            )
+            raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.")

     def is_server_admin(self, user):
         """ Check if the given user is a local server admin.

@@ -581,19 +566,19 @@ class Auth(object):

         auth_ids = []

-        key = (EventTypes.PowerLevels, "", )
+        key = (EventTypes.PowerLevels, "")
         power_level_event_id = current_state_ids.get(key)

         if power_level_event_id:
             auth_ids.append(power_level_event_id)

-        key = (EventTypes.JoinRules, "", )
+        key = (EventTypes.JoinRules, "")
         join_rule_event_id = current_state_ids.get(key)

-        key = (EventTypes.Member, event.sender, )
+        key = (EventTypes.Member, event.sender)
         member_event_id = current_state_ids.get(key)

-        key = (EventTypes.Create, "", )
+        key = (EventTypes.Create, "")
         create_event_id = current_state_ids.get(key)
         if create_event_id:
             auth_ids.append(create_event_id)
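
Note: room state is keyed by `(event_type, state_key)` pairs, and the trailing commas dropped above are purely cosmetic; a trailing comma inside an already-parenthesised tuple does not change the value, so the lookup keys are identical before and after:

    # the trailing comma is stylistic only; both tuples compare equal
    assert ("m.room.power_levels", "",) == ("m.room.power_levels", "")
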
@@ -619,7 +604,7 @@ class Auth(object):
             auth_ids.append(member_event_id)

         if for_verification:
-            key = (EventTypes.Member, event.state_key, )
+            key = (EventTypes.Member, event.state_key)
             existing_event_id = current_state_ids.get(key)
             if existing_event_id:
                 auth_ids.append(existing_event_id)

@@ -628,7 +613,7 @@ class Auth(object):
         if "third_party_invite" in event.content:
             key = (
                 EventTypes.ThirdPartyInvite,
-                event.content["third_party_invite"]["signed"]["token"]
+                event.content["third_party_invite"]["signed"]["token"],
             )
             third_party_invite_id = current_state_ids.get(key)
             if third_party_invite_id:

@@ -684,7 +669,7 @@ class Auth(object):
         auth_events[(EventTypes.PowerLevels, "")] = power_level_event

         send_level = event_auth.get_send_level(
-            EventTypes.Aliases, "", power_level_event,
+            EventTypes.Aliases, "", power_level_event
         )
         user_level = event_auth.get_user_power_level(user_id, auth_events)

@@ -692,7 +677,7 @@ class Auth(object):
             raise AuthError(
                 403,
                 "This server requires you to be a moderator in the room to"
-                " edit its room list entry"
+                " edit its room list entry",
             )

     @staticmethod

@@ -742,7 +727,7 @@ class Auth(object):
             )
         parts = auth_headers[0].split(b" ")
         if parts[0] == b"Bearer" and len(parts) == 2:
-            return parts[1].decode('ascii')
+            return parts[1].decode("ascii")
         else:
             raise AuthError(
                 token_not_found_http_status,

@@ -755,10 +740,10 @@ class Auth(object):
             raise AuthError(
                 token_not_found_http_status,
                 "Missing access token.",
-                errcode=Codes.MISSING_TOKEN
+                errcode=Codes.MISSING_TOKEN,
             )

-        return query_params[0].decode('ascii')
+        return query_params[0].decode("ascii")

     @defer.inlineCallbacks
     def check_in_room_or_world_readable(self, room_id, user_id):

@@ -785,8 +770,8 @@ class Auth(object):
                 room_id, EventTypes.RoomHistoryVisibility, ""
             )
             if (
-                visibility and
-                visibility.content["history_visibility"] == "world_readable"
+                visibility
+                and visibility.content["history_visibility"] == "world_readable"
             ):
                 defer.returnValue((Membership.JOIN, None))
                 return

@@ -820,10 +805,11 @@ class Auth(object):

         if self.hs.config.hs_disabled:
             raise ResourceLimitError(
-                403, self.hs.config.hs_disabled_message,
+                403,
+                self.hs.config.hs_disabled_message,
                 errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
                 admin_contact=self.hs.config.admin_contact,
-                limit_type=self.hs.config.hs_disabled_limit_type
+                limit_type=self.hs.config.hs_disabled_limit_type,
             )
         if self.hs.config.limit_usage_by_mau is True:
             assert not (user_id and threepid)

@@ -848,8 +834,9 @@ class Auth(object):
             current_mau = yield self.store.get_monthly_active_count()
             if current_mau >= self.hs.config.max_mau_value:
                 raise ResourceLimitError(
-                    403, "Monthly Active User Limit Exceeded",
+                    403,
+                    "Monthly Active User Limit Exceeded",
                     admin_contact=self.hs.config.admin_contact,
                     errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
-                    limit_type="monthly_active_user"
+                    limit_type="monthly_active_user",
                 )
@@ -18,7 +18,7 @@
 """Contains constants from the specification."""

 # the "depth" field on events is limited to 2**63 - 1
-MAX_DEPTH = 2**63 - 1
+MAX_DEPTH = 2 ** 63 - 1

 # the maximum length for a room alias is 255 characters
 MAX_ALIAS_LENGTH = 255
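
Note: `2 ** 63 - 1` is the largest signed 64-bit integer, i.e. the biggest depth that fits in an 8-byte signed database column; quickly checked:

    import struct

    MAX_DEPTH = 2 ** 63 - 1
    # round-trips through a signed 64-bit field without overflow
    assert struct.unpack(">q", struct.pack(">q", MAX_DEPTH))[0] == MAX_DEPTH
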
@@ -30,39 +30,41 @@ MAX_USERID_LENGTH = 255
 class Membership(object):

     """Represents the membership states of a user in a room."""
-    INVITE = u"invite"
-    JOIN = u"join"
-    KNOCK = u"knock"
-    LEAVE = u"leave"
-    BAN = u"ban"
+
+    INVITE = "invite"
+    JOIN = "join"
+    KNOCK = "knock"
+    LEAVE = "leave"
+    BAN = "ban"
     LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)


 class PresenceState(object):
     """Represents the presence state of a user."""
-    OFFLINE = u"offline"
-    UNAVAILABLE = u"unavailable"
-    ONLINE = u"online"
+
+    OFFLINE = "offline"
+    UNAVAILABLE = "unavailable"
+    ONLINE = "online"


 class JoinRules(object):
-    PUBLIC = u"public"
-    KNOCK = u"knock"
-    INVITE = u"invite"
-    PRIVATE = u"private"
+    PUBLIC = "public"
+    KNOCK = "knock"
+    INVITE = "invite"
+    PRIVATE = "private"


 class LoginType(object):
-    PASSWORD = u"m.login.password"
-    EMAIL_IDENTITY = u"m.login.email.identity"
-    MSISDN = u"m.login.msisdn"
-    RECAPTCHA = u"m.login.recaptcha"
-    TERMS = u"m.login.terms"
-    DUMMY = u"m.login.dummy"
+    PASSWORD = "m.login.password"
+    EMAIL_IDENTITY = "m.login.email.identity"
+    MSISDN = "m.login.msisdn"
+    RECAPTCHA = "m.login.recaptcha"
+    TERMS = "m.login.terms"
+    DUMMY = "m.login.dummy"

     # Only for C/S API v1
-    APPLICATION_SERVICE = u"m.login.application_service"
-    SHARED_SECRET = u"org.matrix.login.shared_secret"
+    APPLICATION_SERVICE = "m.login.application_service"
+    SHARED_SECRET = "org.matrix.login.shared_secret"


 class EventTypes(object):

@@ -118,6 +120,7 @@ class UserTypes(object):
     """Allows for user type specific behaviour. With the benefit of hindsight
     'admin' and 'guest' users should also be UserTypes. Normal users are type None
     """
+
     SUPPORT = "support"
     ALL_USER_TYPES = (SUPPORT,)

@@ -125,6 +128,7 @@ class UserTypes(object):
 class RelationTypes(object):
     """The types of relations known to this server.
     """
+
     ANNOTATION = "m.annotation"
     REPLACE = "m.replace"
     REFERENCE = "m.reference"
@@ -70,6 +70,7 @@ class CodeMessageException(RuntimeError):
         code (int): HTTP error code
         msg (str): string describing the error
     """
+
     def __init__(self, code, msg):
         super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
         self.code = code

@@ -83,6 +84,7 @@ class SynapseError(CodeMessageException):
     Attributes:
         errcode (str): Matrix error code e.g 'M_FORBIDDEN'
     """
+
     def __init__(self, code, msg, errcode=Codes.UNKNOWN):
         """Constructs a synapse error.

@@ -95,10 +97,7 @@ class SynapseError(CodeMessageException):
         self.errcode = errcode

     def error_dict(self):
-        return cs_error(
-            self.msg,
-            self.errcode,
-        )
+        return cs_error(self.msg, self.errcode)


 class ProxiedRequestError(SynapseError):

@@ -107,27 +106,23 @@ class ProxiedRequestError(SynapseError):
     Attributes:
         errcode (str): Matrix error code e.g 'M_FORBIDDEN'
     """
+
     def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
-        super(ProxiedRequestError, self).__init__(
-            code, msg, errcode
-        )
+        super(ProxiedRequestError, self).__init__(code, msg, errcode)
         if additional_fields is None:
             self._additional_fields = {}
         else:
             self._additional_fields = dict(additional_fields)

     def error_dict(self):
-        return cs_error(
-            self.msg,
-            self.errcode,
-            **self._additional_fields
-        )
+        return cs_error(self.msg, self.errcode, **self._additional_fields)


 class ConsentNotGivenError(SynapseError):
     """The error returned to the client when the user has not consented to the
     privacy policy.
     """
+
     def __init__(self, msg, consent_uri):
         """Constructs a ConsentNotGivenError

@@ -136,22 +131,17 @@ class ConsentNotGivenError(SynapseError):
             consent_url (str): The URL where the user can give their consent
         """
         super(ConsentNotGivenError, self).__init__(
-            code=http_client.FORBIDDEN,
-            msg=msg,
-            errcode=Codes.CONSENT_NOT_GIVEN
+            code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
         )
         self._consent_uri = consent_uri

     def error_dict(self):
-        return cs_error(
-            self.msg,
-            self.errcode,
-            consent_uri=self._consent_uri
-        )
+        return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)


 class RegistrationError(SynapseError):
     """An error raised when a registration event fails."""
+
     pass
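
Note: every `error_dict` collapsed in this file funnels into `cs_error`, which just merges extra keyword arguments into the standard Matrix error body. A minimal re-implementation for illustration only (the real helper is defined elsewhere in this module):

    def cs_error(msg, code="M_UNKNOWN", **kwargs):
        # standard client-server error body: error message + errcode, plus extras
        err = {"error": msg, "errcode": code}
        err.update(kwargs)
        return err

    print(cs_error("Too Many Requests", "M_LIMIT_EXCEEDED", retry_after_ms=2000))
    # {'error': 'Too Many Requests', 'errcode': 'M_LIMIT_EXCEEDED', 'retry_after_ms': 2000}
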
@@ -190,15 +180,17 @@ class InteractiveAuthIncompleteError(Exception):
         result (dict): the server response to the request, which should be
             passed back to the client
     """
+
     def __init__(self, result):
         super(InteractiveAuthIncompleteError, self).__init__(
-            "Interactive auth not yet complete",
+            "Interactive auth not yet complete"
         )
         self.result = result


 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
+
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
             kwargs["errcode"] = Codes.UNRECOGNIZED

@@ -207,21 +199,14 @@ class UnrecognizedRequestError(SynapseError):
             message = "Unrecognized request"
         else:
             message = args[0]
-        super(UnrecognizedRequestError, self).__init__(
-            400,
-            message,
-            **kwargs
-        )
+        super(UnrecognizedRequestError, self).__init__(400, message, **kwargs)


 class NotFoundError(SynapseError):
     """An error indicating we can't find the thing you asked for"""
+
     def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
-        super(NotFoundError, self).__init__(
-            404,
-            msg,
-            errcode=errcode
-        )
+        super(NotFoundError, self).__init__(404, msg, errcode=errcode)


 class AuthError(SynapseError):

@@ -238,8 +223,11 @@ class ResourceLimitError(SynapseError):
     Any error raised when there is a problem with resource usage.
     For instance, the monthly active user limit for the server has been exceeded
     """
+
     def __init__(
-        self, code, msg,
+        self,
+        code,
+        msg,
         errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
         admin_contact=None,
         limit_type=None,

@@ -253,7 +241,7 @@ class ResourceLimitError(SynapseError):
             self.msg,
             self.errcode,
             admin_contact=self.admin_contact,
-            limit_type=self.limit_type
+            limit_type=self.limit_type,
         )

@@ -268,6 +256,7 @@ class EventSizeError(SynapseError):

 class EventStreamError(SynapseError):
     """An error raised when there a problem with the event stream."""
+
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
             kwargs["errcode"] = Codes.BAD_PAGINATION

@@ -276,47 +265,53 @@ class EventStreamError(SynapseError):

 class LoginError(SynapseError):
     """An error raised when there was a problem logging in."""
+
     pass


 class StoreError(SynapseError):
     """An error raised when there was a problem storing some data."""
+
     pass


 class InvalidCaptchaError(SynapseError):
-    def __init__(self, code=400, msg="Invalid captcha.", error_url=None,
-                 errcode=Codes.CAPTCHA_INVALID):
+    def __init__(
+        self,
+        code=400,
+        msg="Invalid captcha.",
+        error_url=None,
+        errcode=Codes.CAPTCHA_INVALID,
+    ):
         super(InvalidCaptchaError, self).__init__(code, msg, errcode)
         self.error_url = error_url

     def error_dict(self):
-        return cs_error(
-            self.msg,
-            self.errcode,
-            error_url=self.error_url,
-        )
+        return cs_error(self.msg, self.errcode, error_url=self.error_url)


 class LimitExceededError(SynapseError):
     """A client has sent too many requests and is being throttled.
     """
-    def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
-                 errcode=Codes.LIMIT_EXCEEDED):
+
+    def __init__(
+        self,
+        code=429,
+        msg="Too Many Requests",
+        retry_after_ms=None,
+        errcode=Codes.LIMIT_EXCEEDED,
+    ):
         super(LimitExceededError, self).__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms

     def error_dict(self):
-        return cs_error(
-            self.msg,
-            self.errcode,
-            retry_after_ms=self.retry_after_ms,
-        )
+        return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)


 class RoomKeysVersionError(SynapseError):
     """A client has tried to upload to a non-current version of the room_keys store
     """
+
     def __init__(self, current_version):
         """
         Args:

@@ -331,6 +326,7 @@ class RoomKeysVersionError(SynapseError):
 class UnsupportedRoomVersionError(SynapseError):
     """The client's request to create a room used a room version that the server does
     not support."""
+
     def __init__(self):
         super(UnsupportedRoomVersionError, self).__init__(
             code=400,

@@ -354,22 +350,19 @@ class IncompatibleRoomVersionError(SynapseError):
     Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join
     failing.
     """
+
     def __init__(self, room_version):
         super(IncompatibleRoomVersionError, self).__init__(
             code=400,
             msg="Your homeserver does not support the features required to "
-                "join this room",
+            "join this room",
             errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
         )

         self._room_version = room_version

     def error_dict(self):
-        return cs_error(
-            self.msg,
-            self.errcode,
-            room_version=self._room_version,
-        )
+        return cs_error(self.msg, self.errcode, room_version=self._room_version)


 class RequestSendFailed(RuntimeError):

@@ -380,11 +373,11 @@ class RequestSendFailed(RuntimeError):
     networking (e.g. DNS failures, connection timeouts etc), versus unexpected
     errors (like programming errors).
     """
+
     def __init__(self, inner_exception, can_retry):
         super(RequestSendFailed, self).__init__(
-            "Failed to send request: %s: %s" % (
-                type(inner_exception).__name__, inner_exception,
-            )
+            "Failed to send request: %s: %s"
+            % (type(inner_exception).__name__, inner_exception)
         )
         self.inner_exception = inner_exception
         self.can_retry = can_retry

@@ -428,7 +421,7 @@ class FederationError(RuntimeError):
         self.affected = affected
         self.source = source

-        msg = "%s %s: %s" % (level, code, reason,)
+        msg = "%s %s: %s" % (level, code, reason)
         super(FederationError, self).__init__(msg)

     def get_dict(self):

@@ -448,6 +441,7 @@ class HttpResponseException(CodeMessageException):
     Attributes:
         response (bytes): body of response
     """
+
     def __init__(self, code, msg, response):
         """

@@ -486,7 +480,7 @@ class HttpResponseException(CodeMessageException):
         if not isinstance(j, dict):
             j = {}

-        errcode = j.pop('errcode', Codes.UNKNOWN)
-        errmsg = j.pop('error', self.msg)
+        errcode = j.pop("errcode", Codes.UNKNOWN)
+        errmsg = j.pop("error", self.msg)

         return ProxiedRequestError(self.code, errmsg, errcode, j)
@@ -28,117 +28,55 @@ FILTER_SCHEMA = {
     "additionalProperties": False,
     "type": "object",
     "properties": {
-        "limit": {
-            "type": "number"
-        },
-        "senders": {
-            "$ref": "#/definitions/user_id_array"
-        },
-        "not_senders": {
-            "$ref": "#/definitions/user_id_array"
-        },
+        "limit": {"type": "number"},
+        "senders": {"$ref": "#/definitions/user_id_array"},
+        "not_senders": {"$ref": "#/definitions/user_id_array"},
         # TODO: We don't limit event type values but we probably should...
         # check types are valid event types
-        "types": {
-            "type": "array",
-            "items": {
-                "type": "string"
-            }
-        },
-        "not_types": {
-            "type": "array",
-            "items": {
-                "type": "string"
-            }
-        }
-    }
+        "types": {"type": "array", "items": {"type": "string"}},
+        "not_types": {"type": "array", "items": {"type": "string"}},
+    },
 }

 ROOM_FILTER_SCHEMA = {
     "additionalProperties": False,
     "type": "object",
     "properties": {
-        "not_rooms": {
-            "$ref": "#/definitions/room_id_array"
-        },
-        "rooms": {
-            "$ref": "#/definitions/room_id_array"
-        },
-        "ephemeral": {
-            "$ref": "#/definitions/room_event_filter"
-        },
-        "include_leave": {
-            "type": "boolean"
-        },
-        "state": {
-            "$ref": "#/definitions/room_event_filter"
-        },
-        "timeline": {
-            "$ref": "#/definitions/room_event_filter"
-        },
-        "account_data": {
-            "$ref": "#/definitions/room_event_filter"
-        },
-    }
+        "not_rooms": {"$ref": "#/definitions/room_id_array"},
+        "rooms": {"$ref": "#/definitions/room_id_array"},
+        "ephemeral": {"$ref": "#/definitions/room_event_filter"},
+        "include_leave": {"type": "boolean"},
+        "state": {"$ref": "#/definitions/room_event_filter"},
+        "timeline": {"$ref": "#/definitions/room_event_filter"},
+        "account_data": {"$ref": "#/definitions/room_event_filter"},
+    },
 }

 ROOM_EVENT_FILTER_SCHEMA = {
     "additionalProperties": False,
     "type": "object",
     "properties": {
-        "limit": {
-            "type": "number"
-        },
-        "senders": {
-            "$ref": "#/definitions/user_id_array"
-        },
-        "not_senders": {
-            "$ref": "#/definitions/user_id_array"
-        },
-        "types": {
-            "type": "array",
-            "items": {
-                "type": "string"
-            }
-        },
-        "not_types": {
-            "type": "array",
-            "items": {
-                "type": "string"
-            }
-        },
-        "rooms": {
-            "$ref": "#/definitions/room_id_array"
-        },
-        "not_rooms": {
-            "$ref": "#/definitions/room_id_array"
-        },
-        "contains_url": {
-            "type": "boolean"
-        },
-        "lazy_load_members": {
-            "type": "boolean"
-        },
-        "include_redundant_members": {
-            "type": "boolean"
-        },
-    }
+        "limit": {"type": "number"},
+        "senders": {"$ref": "#/definitions/user_id_array"},
+        "not_senders": {"$ref": "#/definitions/user_id_array"},
+        "types": {"type": "array", "items": {"type": "string"}},
+        "not_types": {"type": "array", "items": {"type": "string"}},
+        "rooms": {"$ref": "#/definitions/room_id_array"},
+        "not_rooms": {"$ref": "#/definitions/room_id_array"},
+        "contains_url": {"type": "boolean"},
+        "lazy_load_members": {"type": "boolean"},
+        "include_redundant_members": {"type": "boolean"},
+    },
 }

 USER_ID_ARRAY_SCHEMA = {
     "type": "array",
-    "items": {
-        "type": "string",
-        "format": "matrix_user_id"
-    }
+    "items": {"type": "string", "format": "matrix_user_id"},
 }

 ROOM_ID_ARRAY_SCHEMA = {
     "type": "array",
-    "items": {
-        "type": "string",
-        "format": "matrix_room_id"
-    }
+    "items": {"type": "string", "format": "matrix_room_id"},
 }

 USER_FILTER_SCHEMA = {

@@ -150,22 +88,13 @@ USER_FILTER_SCHEMA = {
         "user_id_array": USER_ID_ARRAY_SCHEMA,
         "filter": FILTER_SCHEMA,
         "room_filter": ROOM_FILTER_SCHEMA,
-        "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+        "room_event_filter": ROOM_EVENT_FILTER_SCHEMA,
     },
     "properties": {
-        "presence": {
-            "$ref": "#/definitions/filter"
-        },
-        "account_data": {
-            "$ref": "#/definitions/filter"
-        },
-        "room": {
-            "$ref": "#/definitions/room_filter"
-        },
-        "event_format": {
-            "type": "string",
-            "enum": ["client", "federation"]
-        },
+        "presence": {"$ref": "#/definitions/filter"},
+        "account_data": {"$ref": "#/definitions/filter"},
+        "room": {"$ref": "#/definitions/room_filter"},
+        "event_format": {"type": "string", "enum": ["client", "federation"]},
         "event_fields": {
             "type": "array",
             "items": {

@@ -177,26 +106,25 @@ USER_FILTER_SCHEMA = {
             #
             # Note that because this is a regular expression, we have to escape
             # each backslash in the pattern.
-            "pattern": r"^((?!\\\\).)*$"
-        }
-    }
+            "pattern": r"^((?!\\\\).)*$",
+        },
+    },
     },
-    "additionalProperties": False
+    "additionalProperties": False,
 }


-@FormatChecker.cls_checks('matrix_room_id')
+@FormatChecker.cls_checks("matrix_room_id")
 def matrix_room_id_validator(room_id_str):
     return RoomID.from_string(room_id_str)


-@FormatChecker.cls_checks('matrix_user_id')
+@FormatChecker.cls_checks("matrix_user_id")
 def matrix_user_id_validator(user_id_str):
     return UserID.from_string(user_id_str)


 class Filtering(object):
-
     def __init__(self, hs):
         super(Filtering, self).__init__()
         self.store = hs.get_datastore()

@@ -228,8 +156,9 @@ class Filtering(object):
         # individual top-level key e.g. public_user_data. Filters are made of
         # many definitions.
         try:
-            jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
-                                format_checker=FormatChecker())
+            jsonschema.validate(
+                user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker()
+            )
         except jsonschema.ValidationError as e:
             raise SynapseError(400, str(e))
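
Note: the validation above rejects unknown keys because the schemas set "additionalProperties": False. A standalone check with the same jsonschema call shape (the schema here is a trimmed illustrative copy, not the full USER_FILTER_SCHEMA):

    import jsonschema

    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"limit": {"type": "number"}},
    }

    jsonschema.validate({"limit": 10}, schema)  # passes
    try:
        jsonschema.validate({"unknown": True}, schema)
    except jsonschema.ValidationError as e:
        print(e.message)  # unknown keys are rejected
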
@@ -240,10 +169,9 @@ class FilterCollection(object):

         room_filter_json = self._filter_json.get("room", {})

-        self._room_filter = Filter({
-            k: v for k, v in room_filter_json.items()
-            if k in ("rooms", "not_rooms")
-        })
+        self._room_filter = Filter(
+            {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
+        )

         self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
         self._room_state_filter = Filter(room_filter_json.get("state", {}))

@@ -252,9 +180,7 @@ class FilterCollection(object):
         self._presence_filter = Filter(filter_json.get("presence", {}))
         self._account_data = Filter(filter_json.get("account_data", {}))

-        self.include_leave = filter_json.get("room", {}).get(
-            "include_leave", False
-        )
+        self.include_leave = filter_json.get("room", {}).get("include_leave", False)
         self.event_fields = filter_json.get("event_fields", [])
         self.event_format = filter_json.get("event_format", "client")

@@ -299,22 +225,22 @@ class FilterCollection(object):

     def blocks_all_presence(self):
         return (
-            self._presence_filter.filters_all_types() or
-            self._presence_filter.filters_all_senders()
+            self._presence_filter.filters_all_types()
+            or self._presence_filter.filters_all_senders()
         )

     def blocks_all_room_ephemeral(self):
         return (
-            self._room_ephemeral_filter.filters_all_types() or
-            self._room_ephemeral_filter.filters_all_senders() or
-            self._room_ephemeral_filter.filters_all_rooms()
+            self._room_ephemeral_filter.filters_all_types()
+            or self._room_ephemeral_filter.filters_all_senders()
+            or self._room_ephemeral_filter.filters_all_rooms()
         )

     def blocks_all_room_timeline(self):
         return (
-            self._room_timeline_filter.filters_all_types() or
-            self._room_timeline_filter.filters_all_senders() or
-            self._room_timeline_filter.filters_all_rooms()
+            self._room_timeline_filter.filters_all_types()
+            or self._room_timeline_filter.filters_all_senders()
+            or self._room_timeline_filter.filters_all_rooms()
         )

@@ -375,12 +301,7 @@ class Filter(object):
         # check if there is a string url field in the content for filtering purposes
         contains_url = isinstance(content.get("url"), text_type)

-        return self.check_fields(
-            room_id,
-            sender,
-            ev_type,
-            contains_url,
-        )
+        return self.check_fields(room_id, sender, ev_type, contains_url)

     def check_fields(self, room_id, sender, event_type, contains_url):
         """Checks whether the filter matches the given event fields.

@@ -391,7 +312,7 @@ class Filter(object):
         literal_keys = {
             "rooms": lambda v: room_id == v,
             "senders": lambda v: sender == v,
-            "types": lambda v: _matches_wildcard(event_type, v)
+            "types": lambda v: _matches_wildcard(event_type, v),
         }

         for name, match_func in literal_keys.items():
@@ -44,29 +44,25 @@ class Ratelimiter(object):
         """
         self.prune_message_counts(time_now_s)
         message_count, time_start, _ignored = self.message_counts.get(
-            key, (0., time_now_s, None),
+            key, (0.0, time_now_s, None)
         )
         time_delta = time_now_s - time_start
         sent_count = message_count - time_delta * rate_hz
         if sent_count < 0:
             allowed = True
             time_start = time_now_s
-            message_count = 1.
-        elif sent_count > burst_count - 1.:
+            message_count = 1.0
+        elif sent_count > burst_count - 1.0:
             allowed = False
         else:
             allowed = True
             message_count += 1

         if update:
-            self.message_counts[key] = (
-                message_count, time_start, rate_hz
-            )
+            self.message_counts[key] = (message_count, time_start, rate_hz)

         if rate_hz > 0:
-            time_allowed = (
-                time_start + (message_count - burst_count + 1) / rate_hz
-            )
+            time_allowed = time_start + (message_count - burst_count + 1) / rate_hz
             if time_allowed < time_now_s:
                 time_allowed = time_now_s
         else:

@@ -76,9 +72,7 @@ class Ratelimiter(object):

     def prune_message_counts(self, time_now_s):
         for key in list(self.message_counts.keys()):
-            message_count, time_start, rate_hz = (
-                self.message_counts[key]
-            )
+            message_count, time_start, rate_hz = self.message_counts[key]
             time_delta = time_now_s - time_start
             if message_count - time_delta * rate_hz > 0:
                 break

@@ -92,5 +86,5 @@ class Ratelimiter(object):

         if not allowed:
             raise LimitExceededError(
-                retry_after_ms=int(1000 * (time_allowed - time_now_s)),
+                retry_after_ms=int(1000 * (time_allowed - time_now_s))
             )
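
Note: the arithmetic above is a leaky bucket: `message_count - time_delta * rate_hz` is how much of the burst allowance is still occupied after the bucket has drained for `time_delta` seconds. A worked example with illustrative numbers:

    rate_hz = 0.5      # drains one action every 2 seconds, on average
    burst_count = 3.0  # the bucket holds at most 3 actions

    message_count, time_start = 3.0, 100.0  # three actions recorded at t=100s
    time_now_s = 103.0

    # after 3s of draining, 1.5 "actions" still occupy the bucket
    sent_count = message_count - (time_now_s - time_start) * rate_hz
    assert sent_count == 1.5

    # 1.5 <= burst_count - 1.0, so a fourth action is allowed and recorded
    assert not sent_count > burst_count - 1.0
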
@@ -19,9 +19,10 @@ class EventFormatVersions(object):
     """This is an internal enum for tracking the version of the event format,
     independently from the room version.
     """
-    V1 = 1   # $id:server event id format
-    V2 = 2   # MSC1659-style $hash event id format: introduced for room v3
-    V3 = 3   # MSC1884-style $hash format: introduced for room v4
+
+    V1 = 1  # $id:server event id format
+    V2 = 2  # MSC1659-style $hash event id format: introduced for room v3
+    V3 = 3  # MSC1884-style $hash format: introduced for room v4


 KNOWN_EVENT_FORMAT_VERSIONS = {

@@ -33,8 +34,9 @@ KNOWN_EVENT_FORMAT_VERSIONS = {

 class StateResolutionVersions(object):
     """Enum to identify the state resolution algorithms"""
-    V1 = 1   # room v1 state res
-    V2 = 2   # MSC1442 state res: room v2 and later
+
+    V1 = 1  # room v1 state res
+    V2 = 2  # MSC1442 state res: room v2 and later


 class RoomDisposition(object):

@@ -46,10 +48,10 @@ class RoomDisposition(object):
 class RoomVersion(object):
     """An object which describes the unique attributes of a room version."""

-    identifier = attr.ib()      # str; the identifier for this version
-    disposition = attr.ib()     # str; one of the RoomDispositions
-    event_format = attr.ib()    # int; one of the EventFormatVersions
-    state_res = attr.ib()       # int; one of the StateResolutionVersions
+    identifier = attr.ib()  # str; the identifier for this version
+    disposition = attr.ib()  # str; one of the RoomDispositions
+    event_format = attr.ib()  # int; one of the EventFormatVersions
+    state_res = attr.ib()  # int; one of the StateResolutionVersions
     enforce_key_validity = attr.ib()  # bool

@@ -92,11 +94,12 @@ class RoomVersions(object):


 KNOWN_ROOM_VERSIONS = {
-    v.identifier: v for v in (
+    v.identifier: v
+    for v in (
         RoomVersions.V1,
         RoomVersions.V2,
         RoomVersions.V3,
         RoomVersions.V4,
         RoomVersions.V5,
     )
-}   # type: dict[str, RoomVersion]
+}  # type: dict[str, RoomVersion]
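
Note: the comprehension builds the identifier-to-RoomVersion map that callers use for lookups; roughly like this (a sketch assuming synapse is importable; attribute values shown are not guaranteed here):

    from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

    v5 = KNOWN_ROOM_VERSIONS.get("5")  # None for unknown identifiers
    if v5 is not None:
        print(v5.disposition, v5.event_format, v5.enforce_key_validity)
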
@@ -42,13 +42,9 @@ class ConsentURIBuilder(object):
             hs_config (synapse.config.homeserver.HomeServerConfig):
         """
         if hs_config.form_secret is None:
-            raise ConfigError(
-                "form_secret not set in config",
-            )
+            raise ConfigError("form_secret not set in config")
         if hs_config.public_baseurl is None:
-            raise ConfigError(
-                "public_baseurl not set in config",
-            )
+            raise ConfigError("public_baseurl not set in config")

         self._hmac_secret = hs_config.form_secret.encode("utf-8")
         self._public_baseurl = hs_config.public_baseurl

@@ -64,15 +60,10 @@ class ConsentURIBuilder(object):
             (str) the URI where the user can do consent
         """
         mac = hmac.new(
-            key=self._hmac_secret,
-            msg=user_id.encode('ascii'),
-            digestmod=sha256,
+            key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256
         ).hexdigest()
         consent_uri = "%s_matrix/consent?%s" % (
             self._public_baseurl,
-            urlencode({
-                "u": user_id,
-                "h": mac
-            }),
+            urlencode({"u": user_id, "h": mac}),
         )
         return consent_uri
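
Note: the consent URI is just the public base URL plus an HMAC-SHA256 of the user ID keyed with form_secret, so it can be reproduced standalone; a sketch (the secret and base URL are illustrative placeholders):

    import hmac
    from hashlib import sha256
    from urllib.parse import urlencode

    form_secret = b"form-secret"             # stands in for hs_config.form_secret
    public_baseurl = "https://example.com/"  # stands in for hs_config.public_baseurl
    user_id = "@alice:example.com"

    mac = hmac.new(key=form_secret, msg=user_id.encode("ascii"), digestmod=sha256).hexdigest()
    print("%s_matrix/consent?%s" % (public_baseurl, urlencode({"u": user_id, "h": mac})))
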
@@ -43,7 +43,7 @@ def check_bind_error(e, address, bind_addresses):
         address (str): Address on which binding was attempted.
         bind_addresses (list): Addresses on which the service listens.
     """
-    if address == '0.0.0.0' and '::' in bind_addresses:
-        logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+    if address == "0.0.0.0" and "::" in bind_addresses:
+        logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]")
     else:
         raise e
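
Note: this special case exists because on a dual-stack host, binding `::` usually also accepts IPv4 traffic, so a subsequent bind of `0.0.0.0` on the same port fails; the guard downgrades exactly that failure to a warning. The calling pattern looks roughly like this (a sketch; the port and factory are illustrative, and `check_bind_error` is assumed importable from this module):

    from twisted.internet import error, protocol, reactor

    factory = protocol.ServerFactory()
    bind_addresses = ["::", "0.0.0.0"]
    for address in bind_addresses:
        try:
            reactor.listenTCP(8448, factory, interface=address)
        except error.CannotListenError as e:
            # tolerated when 0.0.0.0 fails while [::] is already bound
            check_bind_error(e, address, bind_addresses)
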
@@ -19,7 +19,6 @@ import signal
 import sys
 import traceback

-import psutil
 from daemonize import Daemonize

 from twisted.internet import defer, error, reactor

@@ -68,21 +67,13 @@ def start_worker_reactor(appname, config):
         gc_thresholds=config.gc_thresholds,
         pid_file=config.worker_pid_file,
         daemonize=config.worker_daemonize,
-        cpu_affinity=config.worker_cpu_affinity,
         print_pidfile=config.print_pidfile,
         logger=logger,
     )


 def start_reactor(
-    appname,
-    soft_file_limit,
-    gc_thresholds,
-    pid_file,
-    daemonize,
-    cpu_affinity,
-    print_pidfile,
-    logger,
+    appname, soft_file_limit, gc_thresholds, pid_file, daemonize, print_pidfile, logger
 ):
     """ Run the reactor in the main process

@@ -95,7 +86,6 @@ def start_reactor(
         gc_thresholds:
         pid_file (str): name of pid file to write to if daemonize is True
         daemonize (bool): true to run the reactor in a background process
-        cpu_affinity (int|None): cpu affinity mask
         print_pidfile (bool): whether to print the pid file, if daemonize is True
         logger (logging.Logger): logger instance to pass to Daemonize
     """

@@ -109,20 +99,6 @@ def start_reactor(
             # between the sentinel and `run` logcontexts.
             with PreserveLoggingContext():
                 logger.info("Running")
-                if cpu_affinity is not None:
-                    # Turn the bitmask into bits, reverse it so we go from 0 up
-                    mask_to_bits = bin(cpu_affinity)[2:][::-1]
-
-                    cpus = []
-                    cpu_num = 0
-
-                    for i in mask_to_bits:
-                        if i == "1":
-                            cpus.append(cpu_num)
-                        cpu_num += 1
-
-                    p = psutil.Process()
-                    p.cpu_affinity(cpus)
-
                 change_resource_limit(soft_file_limit)
                 if gc_thresholds:

@@ -149,10 +125,10 @@ def start_reactor(
 def quit_with_error(error_string):
     message_lines = error_string.split("\n")
     line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
-    sys.stderr.write("*" * line_length + '\n')
+    sys.stderr.write("*" * line_length + "\n")
     for line in message_lines:
         sys.stderr.write(" %s\n" % (line.rstrip(),))
-    sys.stderr.write("*" * line_length + '\n')
+    sys.stderr.write("*" * line_length + "\n")
     sys.exit(1)


@@ -178,14 +154,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
     r = []
     for address in bind_addresses:
         try:
-            r.append(
-                reactor.listenTCP(
-                    port,
-                    factory,
-                    backlog,
-                    address
-                )
-            )
+            r.append(reactor.listenTCP(port, factory, backlog, address))
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)

@@ -205,13 +174,7 @@ def listen_ssl(
     for address in bind_addresses:
         try:
             r.append(
-                reactor.listenSSL(
-                    port,
-                    factory,
-                    context_factory,
-                    backlog,
-                    address
-                )
+                reactor.listenSSL(port, factory, context_factory, backlog, address)
             )
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)

@@ -243,15 +206,13 @@ def refresh_certificate(hs):
         if isinstance(i.factory, TLSMemoryBIOFactory):
             addr = i.getHost()
             logger.info(
-                "Replacing TLS context factory on [%s]:%i", addr.host, addr.port,
+                "Replacing TLS context factory on [%s]:%i", addr.host, addr.port
             )
             # We want to replace TLS factories with a new one, with the new
             # TLS configuration. We do this by reaching in and pulling out
             # the wrappedFactory, and then re-wrapping it.
             i.factory = TLSMemoryBIOFactory(
-                hs.tls_server_context_factory,
-                False,
-                i.factory.wrappedFactory
+                hs.tls_server_context_factory, False, i.factory.wrappedFactory
             )
     logger.info("Context factories updated.")

@@ -267,6 +228,7 @@ def start(hs, listeners=None):
     try:
         # Set up the SIGHUP machinery.
         if hasattr(signal, "SIGHUP"):
+
             def handle_sighup(*args, **kwargs):
                 for i in _sighup_callbacks:
                     i(hs)

@@ -302,10 +264,8 @@ def setup_sentry(hs):
         return

     import sentry_sdk
-    sentry_sdk.init(
-        dsn=hs.config.sentry_dsn,
-        release=get_version_string(synapse),
-    )
+
+    sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse))

     # We set some default tags that give some context to this instance
     with sentry_sdk.configure_scope() as scope:

@@ -326,7 +286,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
     many DNS queries at once
     """
     new_resolver = _LimitedHostnameResolver(
-        reactor.nameResolver, max_dns_requests_in_flight,
+        reactor.nameResolver, max_dns_requests_in_flight
     )

     reactor.installNameResolver(new_resolver)

@@ -339,11 +299,17 @@ class _LimitedHostnameResolver(object):
     def __init__(self, resolver, max_dns_requests_in_flight):
         self._resolver = resolver
         self._limiter = Linearizer(
-            name="dns_client_limiter", max_count=max_dns_requests_in_flight,
+            name="dns_client_limiter", max_count=max_dns_requests_in_flight
         )

-    def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
-                        addressTypes=None, transportSemantics='TCP'):
+    def resolveHostName(
+        self,
+        resolutionReceiver,
+        hostName,
+        portNumber=0,
+        addressTypes=None,
+        transportSemantics="TCP",
+    ):
         # We need this function to return `resolutionReceiver` so we do all the
         # actual logic involving deferreds in a separate function.

@@ -363,8 +329,14 @@ class _LimitedHostnameResolver(object):
         return resolutionReceiver

     @defer.inlineCallbacks
-    def _resolve(self, resolutionReceiver, hostName, portNumber=0,
-                 addressTypes=None, transportSemantics='TCP'):
+    def _resolve(
+        self,
+        resolutionReceiver,
+        hostName,
+        portNumber=0,
+        addressTypes=None,
+        transportSemantics="TCP",
+    ):

         with (yield self._limiter.queue(())):
             # resolveHostName doesn't return a Deferred, so we need to hook into

@@ -374,8 +346,7 @@ class _LimitedHostnameResolver(object):
             receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)

             self._resolver.resolveHostName(
-                receiver, hostName, portNumber,
-                addressTypes, transportSemantics,
+                receiver, hostName, portNumber, addressTypes, transportSemantics
             )

             yield deferred
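
Note: `Linearizer(max_count=N)` acts as an asynchronous semaphore here, capping concurrent DNS resolutions at N. Twisted ships a similar primitive of its own; a sketch of the same idea with `DeferredSemaphore` (this is not Synapse's actual Linearizer):

    from twisted.internet import defer
    from twisted.internet.defer import DeferredSemaphore

    sem = DeferredSemaphore(100)  # plays the role of max_dns_requests_in_flight

    @defer.inlineCallbacks
    def limited_resolve(resolve, hostname):
        # at most 100 of these bodies run concurrently
        yield sem.acquire()
        try:
            result = yield resolve(hostname)  # resolve returns a Deferred
            defer.returnValue(result)
        finally:
            sem.release()
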
@@ -44,7 +44,9 @@ logger = logging.getLogger("synapse.app.appservice")


 class AppserviceSlaveStore(
-    DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
+    DirectoryStore,
+    SlavedEventStore,
+    SlavedApplicationServiceStore,
     SlavedRegistrationStore,
 ):
     pass

@@ -74,7 +76,7 @@ class AppserviceServer(HomeServer):
                     listener_config,
                     root_resource,
                     self.version_string,
-                )
+                ),
             )

         logger.info("Synapse appservice now listening on port %d", port)

@@ -88,18 +90,19 @@ class AppserviceServer(HomeServer):
                 listener["bind_addresses"],
                 listener["port"],
                 manhole(
-                    username="matrix",
-                    password="rabbithole",
-                    globals={"hs": self},
-                )
+                    username="matrix", password="rabbithole", globals={"hs": self}
+                ),
             )
         elif listener["type"] == "metrics":
             if not self.get_config().enable_metrics:
-                logger.warn(("Metrics listener configured, but "
-                             "enable_metrics is not True!"))
+                logger.warn(
+                    (
+                        "Metrics listener configured, but "
+                        "enable_metrics is not True!"
+                    )
+                )
             else:
-                _base.listen_metrics(listener["bind_addresses"],
-                                     listener["port"])
+                _base.listen_metrics(listener["bind_addresses"], listener["port"])
         else:
             logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -132,9 +135,7 @@ class ASReplicationHandler(ReplicationClientHandler):

 def start(config_options):
     try:
-        config = HomeServerConfig.load_config(
-            "Synapse appservice", config_options
-        )
+        config = HomeServerConfig.load_config("Synapse appservice", config_options)
     except ConfigError as e:
         sys.stderr.write("\n" + str(e) + "\n")
         sys.exit(1)

@@ -173,6 +174,6 @@ def start(config_options):
     _base.start_worker_reactor("synapse-appservice", config)


-if __name__ == '__main__':
+if __name__ == "__main__":
     with LoggingContext("main"):
         start(sys.argv[1:])
|
|
@@ -37,6 +37,7 @@ from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
@@ -52,6 +53,7 @@ from synapse.rest.client.v1.room import (
PublicRoomListRestServlet,
RoomEventContextServlet,
RoomMemberListRestServlet,
RoomMessageListRestServlet,
RoomStateRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
@@ -74,6 +76,7 @@ class ClientReaderSlavedStore(
SlavedDeviceStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
SlavedGroupServerStore,
SlavedAccountDataStore,
SlavedEventStore,
SlavedKeyStore,
@@ -109,6 +112,7 @@ class ClientReaderServer(HomeServer):
JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource)
RoomMessageListRestServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
@@ -118,9 +122,7 @@ class ClientReaderServer(HomeServer):
PushRuleRestServlet(self).register(resource)
VersionsRestServlet().register(resource)

resources.update({
"/_matrix/client": resource,
})
resources.update({"/_matrix/client": resource})

root_resource = create_resource_tree(resources, NoResource())

@@ -133,7 +135,7 @@ class ClientReaderServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)

logger.info("Synapse client reader now listening on port %d", port)
@@ -147,18 +149,19 @@ class ClientReaderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -170,9 +173,7 @@ class ClientReaderServer(HomeServer):

def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse client reader", config_options
)
config = HomeServerConfig.load_config("Synapse client reader", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -199,6 +200,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-client-reader", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -109,12 +109,14 @@ class EventCreatorServer(HomeServer):
ProfileAvatarURLRestServlet(self).register(resource)
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)

root_resource = create_resource_tree(resources, NoResource())

@@ -127,7 +129,7 @@ class EventCreatorServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)

logger.info("Synapse event creator now listening on port %d", port)
@@ -141,18 +143,19 @@ class EventCreatorServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -164,9 +167,7 @@ class EventCreatorServer(HomeServer):

def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse event creator", config_options
)
config = HomeServerConfig.load_config("Synapse event creator", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -198,6 +199,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-event-creator", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -86,19 +86,18 @@ class FederationReaderServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
if name == "openid" and "federation" not in res["names"]:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
resources.update({
FEDERATION_PREFIX: TransportLayerServer(
self,
servlet_groups=["openid"],
),
})
resources.update(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"]
)
}
)

if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
@@ -115,7 +114,7 @@ class FederationReaderServer(HomeServer):
root_resource,
self.version_string,
),
reactor=self.get_reactor()
reactor=self.get_reactor(),
)

logger.info("Synapse federation reader now listening on port %d", port)
@@ -129,18 +128,19 @@ class FederationReaderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -181,6 +181,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-federation-reader", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -52,8 +52,13 @@ logger = logging.getLogger("synapse.app.federation_sender")


class FederationSenderSlaveStore(
SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
SlavedDeviceInboxStore,
SlavedTransactionStore,
SlavedReceiptsStore,
SlavedEventStore,
SlavedRegistrationStore,
SlavedDeviceStore,
SlavedPresenceStore,
):
def __init__(self, db_conn, hs):
super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
@@ -65,10 +70,7 @@ class FederationSenderSlaveStore(
self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)

def _get_federation_out_pos(self, db_conn):
sql = (
"SELECT stream_id FROM federation_stream_position"
" WHERE type = ?"
)
sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?"
sql = self.database_engine.convert_param_style(sql)

txn = db_conn.cursor()
@@ -103,7 +105,7 @@ class FederationSenderServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)

logger.info("Synapse federation_sender now listening on port %d", port)
@@ -117,18 +119,19 @@ class FederationSenderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -151,7 +154,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
self.send_handler.process_replication_rows(stream_name, token, rows)

def get_streams_to_replicate(self):
args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
args = super(
FederationSenderReplicationHandler, self
).get_streams_to_replicate()
args.update(self.send_handler.stream_positions())
return args

@@ -203,6 +208,7 @@ class FederationSenderHandler(object):
"""Processes the replication stream and forwards the appropriate entries
to the federation sender.
"""

def __init__(self, hs, replication_client):
self.store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
@@ -241,7 +247,7 @@ class FederationSenderHandler(object):
# ... and when new receipts happen
elif stream_name == ReceiptsStream.NAME:
run_as_background_process(
"process_receipts_for_federation", self._on_new_receipts, rows,
"process_receipts_for_federation", self._on_new_receipts, rows
)

@defer.inlineCallbacks
@@ -278,12 +284,14 @@ class FederationSenderHandler(object):

# We ACK this token over replication so that the master can drop
# its in memory queues
self.replication_client.send_federation_ack(self.federation_position)
self.replication_client.send_federation_ack(
self.federation_position
)
self._last_ack = self.federation_position
except Exception:
logger.exception("Error updating federation stream position")


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -62,14 +62,11 @@ class PresenceStatusStubServlet(RestServlet):
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
headers = {
"Authorization": auth_headers,
}
headers = {"Authorization": auth_headers}

try:
result = yield self.http_client.get_json(
self.main_uri + request.uri.decode('ascii'),
headers=headers,
self.main_uri + request.uri.decode("ascii"), headers=headers
)
except HttpResponseException as e:
raise e.to_synapse_error()
@@ -105,18 +102,19 @@ class KeyUploadServlet(RestServlet):
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if (requester.device_id is not None and
device_id != requester.device_id):
logger.warning("Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id, device_id)
if requester.device_id is not None and device_id != requester.device_id:
logger.warning(
"Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id,
device_id,
)
else:
device_id = requester.device_id

if device_id is None:
raise SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
400, "To upload keys, you must pass device_id when authenticating"
)

if body:
@@ -124,13 +122,9 @@ class KeyUploadServlet(RestServlet):
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
headers = {
"Authorization": auth_headers,
}
headers = {"Authorization": auth_headers}
result = yield self.http_client.post_json_get_json(
self.main_uri + request.uri.decode('ascii'),
body,
headers=headers,
self.main_uri + request.uri.decode("ascii"), body, headers=headers
)

defer.returnValue((200, result))
@@ -171,12 +165,14 @@ class FrontendProxyServer(HomeServer):
if not self.config.use_presence:
PresenceStatusStubServlet(self).register(resource)

resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)

root_resource = create_resource_tree(resources, NoResource())

@@ -190,7 +186,7 @@ class FrontendProxyServer(HomeServer):
root_resource,
self.version_string,
),
reactor=self.get_reactor()
reactor=self.get_reactor(),
)

logger.info("Synapse client reader now listening on port %d", port)
@@ -204,18 +200,19 @@ class FrontendProxyServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -227,9 +224,7 @@ class FrontendProxyServer(HomeServer):

def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse frontend proxy", config_options
)
config = HomeServerConfig.load_config("Synapse frontend proxy", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -258,6 +253,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-frontend-proxy", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -101,13 +101,12 @@ class SynapseHomeServer(HomeServer):
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
resources.update(self._configure_named_resource(
name, res.get("compress", False),
))
resources.update(
self._configure_named_resource(name, res.get("compress", False))
)

additional_resources = listener_config.get("additional_resources", {})
logger.debug("Configuring additional resources: %r",
additional_resources)
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
@@ -174,60 +173,67 @@ class SynapseHomeServer(HomeServer):
if compress:
client_resource = gz_wrap(client_resource)

resources.update({
"/_matrix/client/api/v1": client_resource,
"/_synapse/password_reset": client_resource,
"/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
})
resources.update(
{
"/_matrix/client/api/v1": client_resource,
"/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
"/_synapse/admin": AdminRestResource(self),
}
)

if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource

resources["/_matrix/saml2"] = SAML2Resource(self)

if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource

consent_resource = ConsentResource(self)
if compress:
consent_resource = gz_wrap(consent_resource)
resources.update({
"/_matrix/consent": consent_resource,
})
resources.update({"/_matrix/consent": consent_resource})

if name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})

if name == "openid":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]),
})
resources.update(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"]
)
}
)

if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
})
resources.update(
{
STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
}
)

if name in ["media", "federation", "client"]:
if self.get_config().enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
resources.update(
{
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
}
)
elif name == "media":
raise ConfigError(
"'media' resource conflicts with enable_media_repo=False",
"'media' resource conflicts with enable_media_repo=False"
)

if name in ["keys", "federation"]:
@@ -258,18 +264,14 @@ class SynapseHomeServer(HomeServer):

for listener in listeners:
if listener["type"] == "http":
self._listening_services.extend(
self._listener_http(config, listener)
)
self._listening_services.extend(self._listener_http(config, listener))
elif listener["type"] == "manhole":
listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "replication":
services = listen_tcp(
@@ -278,16 +280,17 @@ class SynapseHomeServer(HomeServer):
ReplicationStreamProtocolFactory(self),
)
for s in services:
reactor.addSystemEventTrigger(
"before", "shutdown", s.stopListening,
)
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -313,7 +316,7 @@ current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
registered_reserved_users_mau_gauge = Gauge(
"synapse_admin_mau:registered_reserved_users",
"Registered users with reserved threepids"
"Registered users with reserved threepids",
)


@@ -328,8 +331,7 @@ def setup(config_options):
"""
try:
config = HomeServerConfig.load_or_generate_config(
"Synapse Homeserver",
config_options,
"Synapse Homeserver", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
@@ -340,10 +342,7 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)

synapse.config.logger.setup_logging(
config,
use_worker_options=False
)
synapse.config.logger.setup_logging(config, use_worker_options=False)

events.USE_FROZEN_DICTS = config.use_frozen_dicts

@@ -358,7 +357,7 @@ def setup(config_options):
database_engine=database_engine,
)

logger.info("Preparing database: %s...", config.database_config['name'])
logger.info("Preparing database: %s...", config.database_config["name"])

try:
with hs.get_db_conn(run_new_connection=False) as db_conn:
@@ -376,7 +375,7 @@ def setup(config_options):
)
sys.exit(1)

logger.info("Database prepared in %s.", config.database_config['name'])
logger.info("Database prepared in %s.", config.database_config["name"])

hs.setup()
hs.setup_master()
@@ -392,9 +391,7 @@ def setup(config_options):
acme = hs.get_acme_handler()

# Check how long the certificate is active for.
cert_days_remaining = hs.config.is_disk_cert_valid(
allow_self_signed=False
)
cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False)

# We want to reprovision if cert_days_remaining is None (meaning no
# certificate exists), or the days remaining number it returns
@@ -402,8 +399,8 @@ def setup(config_options):
provision = False

if (
cert_days_remaining is None or
cert_days_remaining < hs.config.acme_reprovision_threshold
cert_days_remaining is None
or cert_days_remaining < hs.config.acme_reprovision_threshold
):
provision = True

@@ -434,10 +431,7 @@ def setup(config_options):
yield do_acme()

# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(
reprovision_acme,
24 * 60 * 60 * 1000
)
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)

_base.start(hs, config.listeners)

@@ -464,6 +458,7 @@ class SynapseService(service.Service):
A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""

def __init__(self, config):
self.config = config

@@ -480,6 +475,7 @@ class SynapseService(service.Service):
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:

def profile(func):
from cProfile import Profile
from threading import current_thread
@@ -490,13 +486,14 @@ def run(hs):
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
hs.hostname, func.__name__, ident
))
profile.dump_stats(
"/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
)

return profiled

from twisted.python.threadpool import ThreadPool

ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)

@@ -541,7 +538,10 @@ def run(hs):
stats["total_room_count"] = room_count

stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
stats[
"daily_active_rooms"
] = yield hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()

r30_results = yield hs.get_datastore().count_r30_users()
@@ -565,8 +565,7 @@ def run(hs):
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
"https://matrix.org/report-usage-stats/push", stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
@@ -581,14 +580,11 @@ def run(hs):
logger.info("report_stats can use psutil")
stats_process.append(process)
except (AttributeError):
logger.warning(
"Unable to read memory/cpu stats. Disabling reporting."
)
logger.warning("Unable to read memory/cpu stats. Disabling reporting.")

def generate_user_daily_visit_stats():
return run_as_background_process(
"generate_user_daily_visits",
hs.get_datastore().generate_user_daily_visits,
"generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
)

# Rather than update on per session basis, batch up the requests.
@@ -599,9 +595,9 @@ def run(hs):
# monthly active user limiting functionality
def reap_monthly_active_users():
return run_as_background_process(
"reap_monthly_active_users",
hs.get_datastore().reap_monthly_active_users,
"reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
)

clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
reap_monthly_active_users()

@@ -619,8 +615,7 @@ def run(hs):

def start_generate_monthly_active_users():
return run_as_background_process(
"generate_monthly_active_users",
generate_monthly_active_users,
"generate_monthly_active_users", generate_monthly_active_users
)

start_generate_monthly_active_users()
@@ -646,7 +641,6 @@ def run(hs):
gc_thresholds=hs.config.gc_thresholds,
pid_file=hs.config.pid_file,
daemonize=hs.config.daemonize,
cpu_affinity=hs.config.cpu_affinity,
print_pidfile=hs.config.print_pidfile,
logger=logger,
)
@@ -660,5 +654,5 @@ def main():
run(hs)


if __name__ == '__main__':
if __name__ == "__main__":
main()
@@ -72,13 +72,15 @@ class MediaRepositoryServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "media":
media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
resources.update(
{
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
}
)

root_resource = create_resource_tree(resources, NoResource())

@@ -91,7 +93,7 @@ class MediaRepositoryServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)

logger.info("Synapse media repository now listening on port %d", port)
@@ -105,18 +107,19 @@ class MediaRepositoryServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -164,6 +167,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-media-repository", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -46,36 +46,27 @@ logger = logging.getLogger("synapse.app.pusher")


class PusherSlaveStore(
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
SlavedAccountDataStore
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore
):
update_pusher_last_stream_ordering_and_success = (
__func__(DataStore.update_pusher_last_stream_ordering_and_success)
update_pusher_last_stream_ordering_and_success = __func__(
DataStore.update_pusher_last_stream_ordering_and_success
)

update_pusher_failing_since = (
__func__(DataStore.update_pusher_failing_since)
update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since)

update_pusher_last_stream_ordering = __func__(
DataStore.update_pusher_last_stream_ordering
)

update_pusher_last_stream_ordering = (
__func__(DataStore.update_pusher_last_stream_ordering)
get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room)

set_throttle_params = __func__(DataStore.set_throttle_params)

get_time_of_last_push_action_before = __func__(
DataStore.get_time_of_last_push_action_before
)

get_throttle_params_by_room = (
__func__(DataStore.get_throttle_params_by_room)
)

set_throttle_params = (
__func__(DataStore.set_throttle_params)
)

get_time_of_last_push_action_before = (
__func__(DataStore.get_time_of_last_push_action_before)
)

get_profile_displayname = (
__func__(DataStore.get_profile_displayname)
)
get_profile_displayname = __func__(DataStore.get_profile_displayname)


class PusherServer(HomeServer):
@@ -105,7 +96,7 @@ class PusherServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)

logger.info("Synapse pusher now listening on port %d", port)
@@ -119,18 +110,19 @@ class PusherServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -161,9 +153,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events":
yield self.pusher_pool.on_new_notifications(
token, token,
)
yield self.pusher_pool.on_new_notifications(token, token)
elif stream_name == "receipts":
yield self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows)
@@ -188,9 +178,7 @@ class PusherReplicationHandler(ReplicationClientHandler):

def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse pusher", config_options
)
config = HomeServerConfig.load_config("Synapse pusher", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -234,6 +222,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-pusher", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
ps = start(sys.argv[1:])
@@ -98,10 +98,7 @@ class SynchrotronPresence(object):
self.notifier = hs.get_notifier()

active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
state.user_id: state
for state in active_presence
}
self.user_to_current_state = {state.user_id: state for state in active_presence}

# user_id -> last_sync_ms. Lists the users that have stopped syncing
# but we haven't notified the master of that yet
@@ -196,17 +193,26 @@ class SynchrotronPresence(object):
room_ids_to_states, users_to_states = parties

self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
"presence_key",
stream_id,
rooms=room_ids_to_states.keys(),
users=users_to_states.keys(),
)

@defer.inlineCallbacks
def process_replication_rows(self, token, rows):
states = [UserPresenceState(
row.user_id, row.state, row.last_active_ts,
row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
row.currently_active
) for row in rows]
states = [
UserPresenceState(
row.user_id,
row.state,
row.last_active_ts,
row.last_federation_update_ts,
row.last_user_sync_ts,
row.status_msg,
row.currently_active,
)
for row in rows
]

for state in states:
self.user_to_current_state[state.user_id] = state
@@ -217,7 +223,8 @@ class SynchrotronPresence(object):
def get_currently_syncing_users(self):
if self.hs.config.use_presence:
return [
user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
user_id
for user_id, count in iteritems(self.user_to_num_current_syncs)
if count > 0
]
else:
@@ -281,12 +288,14 @@ class SynchrotronServer(HomeServer):
events.register_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
RoomInitialSyncRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)

root_resource = create_resource_tree(resources, NoResource())

@@ -299,7 +308,7 @@ class SynchrotronServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)

logger.info("Synapse synchrotron now listening on port %d", port)
@@ -313,18 +322,19 @@ class SynchrotronServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -382,40 +392,36 @@ class SyncReplicationHandler(ReplicationClientHandler):
)
elif stream_name == "push_rules":
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows],
"push_rules_key", token, users=[row.user_id for row in rows]
)
elif stream_name in ("account_data", "tag_account_data",):
elif stream_name in ("account_data", "tag_account_data"):
self.notifier.on_new_event(
"account_data_key", token, users=[row.user_id for row in rows],
"account_data_key", token, users=[row.user_id for row in rows]
)
elif stream_name == "receipts":
self.notifier.on_new_event(
"receipt_key", token, rooms=[row.room_id for row in rows],
"receipt_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "typing":
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows],
"typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "to_device":
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
self.notifier.on_new_event(
"to_device_key", token, users=entities,
)
self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
room_ids = yield self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
self.notifier.on_new_event(
"device_list_key", token, rooms=all_room_ids,
)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows],
"groups_key", token, users=[row.user_id for row in rows]
)
except Exception:
logger.exception("Error processing replication")
@@ -423,9 +429,7 @@ class SyncReplicationHandler(ReplicationClientHandler):

def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse synchrotron", config_options
)
config = HomeServerConfig.load_config("Synapse synchrotron", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -453,6 +457,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-synchrotron", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -66,14 +66,16 @@ class UserDirectorySlaveStore(

events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)

@@ -110,12 +112,14 @@ class UserDirectoryServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
user_directory.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
resources.update(
{
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}
)

root_resource = create_resource_tree(resources, NoResource())

@@ -128,7 +132,7 @@ class UserDirectoryServer(HomeServer):
listener_config,
root_resource,
self.version_string,
)
),
)

logger.info("Synapse user_dir now listening on port %d", port)
@@ -142,18 +146,19 @@ class UserDirectoryServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
username="matrix", password="rabbithole", globals={"hs": self}
),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warn(("Metrics listener configured, but "
"enable_metrics is not True!"))
logger.warn(
(
"Metrics listener configured, but "
"enable_metrics is not True!"
)
)
else:
_base.listen_metrics(listener["bind_addresses"],
listener["port"])
_base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

@@ -186,9 +191,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):

def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse user directory", config_options
)
config = HomeServerConfig.load_config("Synapse user directory", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -227,6 +230,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-user-dir", config)


if __name__ == '__main__':
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
@@ -48,9 +48,7 @@ class AppServiceTransaction(object):
A Deferred which resolves to True if the transaction was sent.
"""
return as_api.push_bulk(
service=self.service,
events=self.events,
txn_id=self.id
service=self.service, events=self.events, txn_id=self.id
)

def complete(self, store):
@@ -64,10 +62,7 @@ class AppServiceTransaction(object):
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
return store.complete_appservice_txn(
service=self.service,
txn_id=self.id
)
return store.complete_appservice_txn(service=self.service, txn_id=self.id)


class ApplicationService(object):
@@ -76,6 +71,7 @@ class ApplicationService(object):

Provides methods to check if this service is "interested" in events.
"""

NS_USERS = "users"
NS_ALIASES = "aliases"
NS_ROOMS = "rooms"
@@ -84,9 +80,19 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]

def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None, rate_limited=True,
ip_range_whitelist=None):
def __init__(
self,
token,
hostname,
url=None,
namespaces=None,
hs_token=None,
sender=None,
id=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
):
self.token = token
self.url = url
self.hs_token = hs_token
@@ -128,9 +134,7 @@ class ApplicationService(object):
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
group_id = regex_obj.get("group_id")
if group_id:
if not isinstance(group_id, str):
@@ -153,9 +157,7 @@ class ApplicationService(object):
if isinstance(regex, string_types):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns
)
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces

def _matches_regex(self, test_string, namespace_key):
@@ -178,8 +180,9 @@ class ApplicationService(object):
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key
if (event.type == EventTypes.Member and
self.is_interested_in_user(event.state_key)):
if event.type == EventTypes.Member and self.is_interested_in_user(
event.state_key
):
defer.returnValue(True)

if not store:
@@ -32,19 +32,17 @@ logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
"synapse_appservice_api_sent_transactions",
"Number of /transactions/ requests sent",
["service"]
["service"],
)

failed_transactions_counter = Counter(
"synapse_appservice_api_failed_transactions",
"Number of /transactions/ requests that failed to send",
["service"]
["service"],
)

sent_events_counter = Counter(
"synapse_appservice_api_sent_events",
"Number of events sent to the AS",
["service"]
"synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
)

HOUR_IN_MS = 60 * 60 * 1000
@@ -92,8 +90,9 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()

self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
timeout_ms=HOUR_IN_MS)
self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
)

@defer.inlineCallbacks
def query_user(self, service, user_id):
@@ -102,9 +101,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
@@ -123,9 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
@@ -144,9 +139,7 @@ class ApplicationServiceApi(SimpleHttpClient):
elif kind == ThirdPartyEntityKind.LOCATION:
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind)
if service.url is None:
defer.returnValue([])

@@ -154,14 +147,13 @@ class ApplicationServiceApi(SimpleHttpClient):
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.parse.quote(protocol)
urllib.parse.quote(protocol),
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
"query_3pe to %s returned an invalid response %r", uri, response
)
defer.returnValue([])

@@ -171,8 +163,7 @@ class ApplicationServiceApi(SimpleHttpClient):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
"query_3pe to %s returned an invalid result %r", uri, r
)

defer.returnValue(ret)
@@ -189,27 +180,27 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol)
urllib.parse.quote(protocol),
)
try:
info = yield self.get_json(uri, {})

if not _is_valid_3pe_metadata(info):
logger.warning("query_3pe_protocol to %s did not return a"
" valid result", uri)
logger.warning(
"query_3pe_protocol to %s did not return a" " valid result", uri
)
defer.returnValue(None)

for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
if network_id is not None:
instance["instance_id"] = ThirdPartyInstanceID(
service.id, network_id,
service.id, network_id
).to_string()

defer.returnValue(info)
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex)
defer.returnValue(None)

key = (service.id, protocol)
@@ -223,22 +214,19 @@ class ApplicationServiceApi(SimpleHttpClient):
events = self._serialize(events)

if txn_id is None:
logger.warning("push_bulk: Missing txn ID sending events to %s",
service.url)
logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url
)
txn_id = str(0)
txn_id = str(txn_id)

uri = service.url + ("/transactions/%s" %
urllib.parse.quote(txn_id))
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
yield self.put_json(
uri=uri,
json_body={
"events": events
},
args={
"access_token": service.hs_token
})
json_body={"events": events},
args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
defer.returnValue(True)
@@ -252,6 +240,4 @@ class ApplicationServiceApi(SimpleHttpClient):

def _serialize(self, events):
time_now = self.clock.time_msec()
return [
serialize_event(e, time_now, as_client_event=True) for e in events
]
return [serialize_event(e, time_now, as_client_event=True) for e in events]
@@ -112,15 +112,14 @@ class _ServiceQueuer(object):
return

run_as_background_process(
"as-sender-%s" % (service.id, ),
self._send_request, service,
"as-sender-%s" % (service.id,), self._send_request, service
)

@defer.inlineCallbacks
def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert(service.id not in self.requests_in_flight)
assert service.id not in self.requests_in_flight

self.requests_in_flight.add(service.id)
try:
@@ -137,7 +136,6 @@ class _ServiceQueuer(object):


class _TransactionController(object):

def __init__(self, clock, store, as_api, recoverer_fn):
self.clock = clock
self.store = store
@@ -149,10 +147,7 @@ class _TransactionController(object):
@defer.inlineCallbacks
def send(self, service, events):
try:
txn = yield self.store.create_appservice_txn(
service=service,
events=events
)
txn = yield self.store.create_appservice_txn(service=service, events=events)
service_is_up = yield self._is_service_up(service)
if service_is_up:
sent = yield txn.send(self.as_api)
@@ -167,12 +162,12 @@ class _TransactionController(object):
@defer.inlineCallbacks
def on_recovered(self, recoverer):
self.recoverers.remove(recoverer)
logger.info("Successfully recovered application service AS ID %s",
recoverer.service.id)
logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id
)
logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state(
recoverer.service,
ApplicationServiceState.UP
recoverer.service, ApplicationServiceState.UP
)

def add_recoverers(self, recoverers):
@@ -184,13 +179,10 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _start_recoverer(self, service):
try:
yield self.store.set_appservice_state(
service,
ApplicationServiceState.DOWN
)
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
logger.info(
"Application service falling behind. Starting recoverer. AS ID %s",
service.id
service.id,
)
recoverer = self.recoverer_fn(service, self.on_recovered)
self.add_recoverers([recoverer])
@@ -205,19 +197,16 @@ class _TransactionController(object):


class _Recoverer(object):

@staticmethod
@defer.inlineCallbacks
def start(clock, store, as_api, callback):
services = yield store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
recoverers = [
_Recoverer(clock, store, as_api, s, callback) for s in services
]
services = yield store.get_appservices_by_state(ApplicationServiceState.DOWN)
recoverers = [_Recoverer(clock, store, as_api, s, callback) for s in services]
for r in recoverers:
logger.info("Starting recoverer for AS ID %s which was marked as "
"DOWN", r.service.id)
logger.info(
"Starting recoverer for AS ID %s which was marked as " "DOWN",
r.service.id,
)
r.recover()
defer.returnValue(recoverers)

@@ -232,9 +221,9 @@ class _Recoverer(object):
def recover(self):
def _retry():
run_as_background_process(
"as-recoverer-%s" % (self.service.id,),
self.retry,
"as-recoverer-%s" % (self.service.id,), self.retry
)

self.clock.call_later((2 ** self.backoff_counter), _retry)

def _backoff(self):
@@ -248,8 +237,9 @@ class _Recoverer(object):
try:
txn = yield self.store.get_oldest_unsent_txn(self.service)
if txn:
logger.info("Retrying transaction %s for AS ID %s",
txn.id, txn.service.id)
logger.info(
"Retrying transaction %s for AS ID %s", txn.id, txn.service.id
)
sent = yield txn.send(self.as_api)
if sent:
yield txn.complete(self.store)
|
||||
|
|
|
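A note on the retry logic kept intact by the `_Recoverer.recover` hunk above: the delay before each retry is `2 ** self.backoff_counter` seconds, so failed transactions back off exponentially. A minimal sketch of that schedule, assuming the counter is incremented once per failed attempt and capped elsewhere (the cap is outside the hunks shown):

# Illustrative only: the delay sequence implied by
# clock.call_later((2 ** self.backoff_counter), _retry) above.
def retry_delays(attempts=5):
    """Yield the delay, in seconds, before each successive retry."""
    for backoff_counter in range(1, attempts + 1):
        yield 2 ** backoff_counter

print(list(retry_delays()))  # [2, 4, 8, 16, 32]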
@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@ -134,11 +136,6 @@ class Config(object):
        with open(file_path) as file_stream:
            return file_stream.read()

    @staticmethod
    def read_config_file(file_path):
        with open(file_path) as file_stream:
            return yaml.safe_load(file_stream)

    def invoke_all(self, name, *args, **kargs):
        results = []
        for cls in type(self).mro():

@ -153,12 +150,12 @@ class Config(object):
        server_name,
        generate_secrets=False,
        report_stats=None,
        open_private_ports=False,
    ):
        """Build a default configuration file

        This is used both when the user explicitly asks us to generate a config file
        (eg with --generate_config), and before loading the config at runtime (to give
        a base which the config files override)
        This is used when the user explicitly asks us to generate a config file
        (eg with --generate_config).

        Args:
            config_dir_path (str): The path where the config files are kept. Used to

@ -177,25 +174,33 @@ class Config(object):
            report_stats (bool|None): Initial setting for the report_stats setting.
                If None, report_stats will be left unset.

            open_private_ports (bool): True to leave private ports (such as the non-TLS
                HTTP listener) open to the internet.

        Returns:
            str: the yaml config file
        """
        default_config = "\n\n".join(
        return "\n\n".join(
            dedent(conf)
            for conf in self.invoke_all(
                "default_config",
                "generate_config_section",
                config_dir_path=config_dir_path,
                data_dir_path=data_dir_path,
                server_name=server_name,
                generate_secrets=generate_secrets,
                report_stats=report_stats,
                open_private_ports=open_private_ports,
            )
        )

        return default_config

    @classmethod
    def load_config(cls, description, argv):
        """Parse the commandline and config files

        Doesn't support config-file-generation: used by the worker apps.

        Returns: Config object.
        """
        config_parser = argparse.ArgumentParser(description=description)
        config_parser.add_argument(
            "-c",

@ -210,7 +215,7 @@ class Config(object):
            "--keys-directory",
            metavar="DIRECTORY",
            help="Where files such as certs and signing keys are stored when"
            " their location is given explicitly in the config."
            " their location is not given explicitly in the config."
            " Defaults to the directory containing the last config file",
        )

@ -222,8 +227,19 @@ class Config(object):

        config_files = find_config_files(search_paths=config_args.config_path)

        obj.read_config_files(
            config_files, keys_directory=config_args.keys_directory, generate_keys=False
        if not config_files:
            config_parser.error("Must supply a config file.")

        if config_args.keys_directory:
            config_dir_path = config_args.keys_directory
        else:
            config_dir_path = os.path.dirname(config_files[-1])
        config_dir_path = os.path.abspath(config_dir_path)
        data_dir_path = os.getcwd()

        config_dict = read_config_files(config_files)
        obj.parse_config_dict(
            config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
        )

        obj.invoke_all("read_arguments", config_args)

@ -232,6 +248,12 @@ class Config(object):

    @classmethod
    def load_or_generate_config(cls, description, argv):
        """Parse the commandline and config files

        Supports generation of config files, so is used for the main homeserver app.

        Returns: Config object, or None if --generate-config or --generate-keys was set
        """
        config_parser = argparse.ArgumentParser(add_help=False)
        config_parser.add_argument(
            "-c",

@ -241,37 +263,74 @@ class Config(object):
            help="Specify config file. Can be given multiple times and"
            " may specify directories containing *.yaml files.",
        )
        config_parser.add_argument(

        generate_group = config_parser.add_argument_group("Config generation")
        generate_group.add_argument(
            "--generate-config",
            action="store_true",
            help="Generate a config file for the server name",
            help="Generate a config file, then exit.",
        )
        config_parser.add_argument(
            "--report-stats",
            action="store",
            help="Whether the generated config reports anonymized usage statistics",
            choices=["yes", "no"],
        )
        config_parser.add_argument(
        generate_group.add_argument(
            "--generate-missing-configs",
            "--generate-keys",
            action="store_true",
            help="Generate any missing key files then exit",
            help="Generate any missing additional config files, then exit.",
        )
        config_parser.add_argument(
        generate_group.add_argument(
            "-H", "--server-name", help="The server name to generate a config file for."
        )
        generate_group.add_argument(
            "--report-stats",
            action="store",
            help="Whether the generated config reports anonymized usage statistics.",
            choices=["yes", "no"],
        )
        generate_group.add_argument(
            "--config-directory",
            "--keys-directory",
            metavar="DIRECTORY",
            help="Used with 'generate-*' options to specify where files such as"
            " signing keys should be stored, unless explicitly"
            " specified in the config.",
            help=(
                "Specify where additional config files such as signing keys and log"
                " config should be stored. Defaults to the same directory as the last"
                " config file."
            ),
        )
        config_parser.add_argument(
            "-H", "--server-name", help="The server name to generate a config file for"
        generate_group.add_argument(
            "--data-directory",
            metavar="DIRECTORY",
            help=(
                "Specify where data such as the media store and database file should be"
                " stored. Defaults to the current working directory."
            ),
        )
        generate_group.add_argument(
            "--open-private-ports",
            action="store_true",
            help=(
                "Leave private ports (such as the non-TLS HTTP listener) open to the"
                " internet. Do not use this unless you know what you are doing."
            ),
        )

        config_args, remaining_args = config_parser.parse_known_args(argv)

        config_files = find_config_files(search_paths=config_args.config_path)

        generate_keys = config_args.generate_keys
        if not config_files:
            config_parser.error(
                "Must supply a config file.\nA config file can be automatically"
                ' generated using "--generate-config -H SERVER_NAME'
                ' -c CONFIG-FILE"'
            )

        if config_args.config_directory:
            config_dir_path = config_args.config_directory
        else:
            config_dir_path = os.path.dirname(config_files[-1])
        config_dir_path = os.path.abspath(config_dir_path)
        data_dir_path = os.getcwd()

        generate_missing_configs = config_args.generate_missing_configs

        obj = cls()

@ -281,19 +340,16 @@ class Config(object):
                "Please specify either --report-stats=yes or --report-stats=no\n\n"
                + MISSING_REPORT_STATS_SPIEL
            )
        if not config_files:
            config_parser.error(
                "Must supply a config file.\nA config file can be automatically"
                " generated using \"--generate-config -H SERVER_NAME"
                " -c CONFIG-FILE\""
            )

        (config_path,) = config_files
        if not cls.path_exists(config_path):
            if config_args.keys_directory:
                config_dir_path = config_args.keys_directory
            print("Generating config file %s" % (config_path,))

            if config_args.data_directory:
                data_dir_path = config_args.data_directory
            else:
                config_dir_path = os.path.dirname(config_path)
            config_dir_path = os.path.abspath(config_dir_path)
                data_dir_path = os.getcwd()
            data_dir_path = os.path.abspath(data_dir_path)

            server_name = config_args.server_name
            if not server_name:

@ -304,22 +360,21 @@ class Config(object):

            config_str = obj.generate_config(
                config_dir_path=config_dir_path,
                data_dir_path=os.getcwd(),
                data_dir_path=data_dir_path,
                server_name=server_name,
                report_stats=(config_args.report_stats == "yes"),
                generate_secrets=True,
                open_private_ports=config_args.open_private_ports,
            )

            if not cls.path_exists(config_dir_path):
                os.makedirs(config_dir_path)
            with open(config_path, "w") as config_file:
                config_file.write(
                    "# vim:ft=yaml\n\n"
                )
                config_file.write("# vim:ft=yaml\n\n")
                config_file.write(config_str)

            config = yaml.safe_load(config_str)
            obj.invoke_all("generate_files", config)
            config_dict = yaml.safe_load(config_str)
            obj.generate_missing_files(config_dict, config_dir_path)

            print(
                (

@ -333,12 +388,12 @@ class Config(object):
        else:
            print(
                (
                    "Config file %r already exists. Generating any missing key"
                    "Config file %r already exists. Generating any missing config"
                    " files."
                )
                % (config_path,)
            )
            generate_keys = True
            generate_missing_configs = True

        parser = argparse.ArgumentParser(
            parents=[config_parser],

@ -349,66 +404,63 @@ class Config(object):
        obj.invoke_all("add_arguments", parser)
        args = parser.parse_args(remaining_args)

        if not config_files:
            config_parser.error(
                "Must supply a config file.\nA config file can be automatically"
                " generated using \"--generate-config -H SERVER_NAME"
                " -c CONFIG-FILE\""
            )

        obj.read_config_files(
            config_files,
            keys_directory=config_args.keys_directory,
            generate_keys=generate_keys,
        )

        if generate_keys:
        config_dict = read_config_files(config_files)
        if generate_missing_configs:
            obj.generate_missing_files(config_dict, config_dir_path)
            return None

        obj.parse_config_dict(
            config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
        )
        obj.invoke_all("read_arguments", args)

        return obj

    def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
        if not keys_directory:
            keys_directory = os.path.dirname(config_files[-1])
    def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
        """Read the information from the config dict into this Config object.

        self.config_dir_path = os.path.abspath(keys_directory)
        Args:
            config_dict (dict): Configuration data, as read from the yaml

        specified_config = {}
        for config_file in config_files:
            yaml_config = self.read_config_file(config_file)
            specified_config.update(yaml_config)
            config_dir_path (str): The path where the config files are kept. Used to
                create filenames for things like the log config and the signing key.

        if "server_name" not in specified_config:
            raise ConfigError(MISSING_SERVER_NAME)

        server_name = specified_config["server_name"]
        config_string = self.generate_config(
            config_dir_path=self.config_dir_path,
            data_dir_path=os.getcwd(),
            server_name=server_name,
            generate_secrets=False,
            data_dir_path (str): The path where the data files are kept. Used to create
                filenames for things like the database and media store.
        """
        self.invoke_all(
            "read_config",
            config_dict,
            config_dir_path=config_dir_path,
            data_dir_path=data_dir_path,
        )
        config = yaml.safe_load(config_string)
        config.pop("log_config")
        config.update(specified_config)

        if "report_stats" not in config:
            raise ConfigError(
                MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
                + "\n"
                + MISSING_REPORT_STATS_SPIEL
            )
    def generate_missing_files(self, config_dict, config_dir_path):
        self.invoke_all("generate_files", config_dict, config_dir_path)

        if generate_keys:
            self.invoke_all("generate_files", config)
            return

        self.parse_config_dict(config)
def read_config_files(config_files):
    """Read the config files into a dict

    def parse_config_dict(self, config_dict):
        self.invoke_all("read_config", config_dict)
    Args:
        config_files (iterable[str]): A list of the config files to read

    Returns: dict
    """
    specified_config = {}
    for config_file in config_files:
        with open(config_file) as file_stream:
            yaml_config = yaml.safe_load(file_stream)
        specified_config.update(yaml_config)

    if "server_name" not in specified_config:
        raise ConfigError(MISSING_SERVER_NAME)

    if "report_stats" not in specified_config:
        raise ConfigError(
            MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_SPIEL
        )
    return specified_config


def find_config_files(search_paths):
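To summarise the refactor in this file: config generation is pulled out of `read_config_files`, which becomes a module-level helper that only merges the YAML files and validates required keys; the merged dict is then handed to `parse_config_dict`, which fans out to each config class via `invoke_all("read_config", ...)` over the class MRO. A rough sketch of the merge step under those assumptions (file names are hypothetical):

# Sketch of the module-level read_config_files() contract shown above:
# later files override earlier ones, and "server_name" must be present.
import yaml

def read_config_files_sketch(config_files):
    merged = {}
    for path in config_files:
        with open(path) as f:
            merged.update(yaml.safe_load(f))
    if "server_name" not in merged:
        raise ValueError("server_name is required")
    return merged

# e.g. config_dict = read_config_files_sketch(["homeserver.yaml", "secrets.yaml"])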
@ -18,17 +18,19 @@ from ._base import Config


class ApiConfig(Config):
    def read_config(self, config, **kwargs):
        self.room_invite_state_types = config.get(
            "room_invite_state_types",
            [
                EventTypes.JoinRules,
                EventTypes.CanonicalAlias,
                EventTypes.RoomAvatar,
                EventTypes.RoomEncryption,
                EventTypes.Name,
            ],
        )

    def read_config(self, config):
        self.room_invite_state_types = config.get("room_invite_state_types", [
            EventTypes.JoinRules,
            EventTypes.CanonicalAlias,
            EventTypes.RoomAvatar,
            EventTypes.RoomEncryption,
            EventTypes.Name,
        ])

    def default_config(cls, **kwargs):
    def generate_config_section(cls, **kwargs):
        return """\
        ## API Configuration ##

@ -40,4 +42,6 @@ class ApiConfig(Config):
        #   - "{RoomAvatar}"
        #   - "{RoomEncryption}"
        #   - "{Name}"
        """.format(**vars(EventTypes))
        """.format(
            **vars(EventTypes)
        )
@ -29,13 +29,12 @@ logger = logging.getLogger(__name__)


class AppServiceConfig(Config):

    def read_config(self, config):
    def read_config(self, config, **kwargs):
        self.app_service_config_files = config.get("app_service_config_files", [])
        self.notify_appservices = config.get("notify_appservices", True)
        self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)

    def default_config(cls, **kwargs):
    def generate_config_section(cls, **kwargs):
        return """\
        # A list of application service config files to use
        #

@ -53,9 +52,7 @@ class AppServiceConfig(Config):
def load_appservices(hostname, config_files):
    """Returns a list of Application Services from the config files."""
    if not isinstance(config_files, list):
        logger.warning(
            "Expected %s to be a list of AS config files.", config_files
        )
        logger.warning("Expected %s to be a list of AS config files.", config_files)
        return []

    # Dicts of value -> filename

@ -66,22 +63,20 @@ def load_appservices(hostname, config_files):

    for config_file in config_files:
        try:
            with open(config_file, 'r') as f:
                appservice = _load_appservice(
                    hostname, yaml.safe_load(f), config_file
                )
            with open(config_file, "r") as f:
                appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
                if appservice.id in seen_ids:
                    raise ConfigError(
                        "Cannot reuse ID across application services: "
                        "%s (files: %s, %s)" % (
                            appservice.id, config_file, seen_ids[appservice.id],
                        )
                        "%s (files: %s, %s)"
                        % (appservice.id, config_file, seen_ids[appservice.id])
                    )
                seen_ids[appservice.id] = config_file
                if appservice.token in seen_as_tokens:
                    raise ConfigError(
                        "Cannot reuse as_token across application services: "
                        "%s (files: %s, %s)" % (
                        "%s (files: %s, %s)"
                        % (
                            appservice.token,
                            config_file,
                            seen_as_tokens[appservice.token],

@ -98,28 +93,26 @@ def load_appservices(hostname, config_files):


def _load_appservice(hostname, as_info, config_filename):
    required_string_fields = [
        "id", "as_token", "hs_token", "sender_localpart"
    ]
    required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
    for field in required_string_fields:
        if not isinstance(as_info.get(field), string_types):
            raise KeyError("Required string field: '%s' (%s)" % (
                field, config_filename,
            ))
            raise KeyError(
                "Required string field: '%s' (%s)" % (field, config_filename)
            )

    # 'url' must either be a string or explicitly null, not missing
    # to avoid accidentally turning off push for ASes.
    if (not isinstance(as_info.get("url"), string_types) and
            as_info.get("url", "") is not None):
    if (
        not isinstance(as_info.get("url"), string_types)
        and as_info.get("url", "") is not None
    ):
        raise KeyError(
            "Required string field or explicit null: 'url' (%s)" % (config_filename,)
        )

    localpart = as_info["sender_localpart"]
    if urlparse.quote(localpart) != localpart:
        raise ValueError(
            "sender_localpart needs characters which are not URL encoded."
        )
        raise ValueError("sender_localpart needs characters which are not URL encoded.")
    user = UserID(localpart, hostname)
    user_id = user.to_string()


@ -138,13 +131,12 @@ def _load_appservice(hostname, as_info, config_filename):
    for regex_obj in as_info["namespaces"][ns]:
        if not isinstance(regex_obj, dict):
            raise ValueError(
                "Expected namespace entry in %s to be an object,"
                " but got %s", ns, regex_obj
                "Expected namespace entry in %s to be an object," " but got %s",
                ns,
                regex_obj,
            )
        if not isinstance(regex_obj.get("regex"), string_types):
            raise ValueError(
                "Missing/bad type 'regex' key in %s", regex_obj
            )
            raise ValueError("Missing/bad type 'regex' key in %s", regex_obj)
        if not isinstance(regex_obj.get("exclusive"), bool):
            raise ValueError(
                "Missing/bad type 'exclusive' key in %s", regex_obj

@ -167,10 +159,8 @@ def _load_appservice(hostname, as_info, config_filename):
    )

    ip_range_whitelist = None
    if as_info.get('ip_range_whitelist'):
        ip_range_whitelist = IPSet(
            as_info.get('ip_range_whitelist')
        )
    if as_info.get("ip_range_whitelist"):
        ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))

    return ApplicationService(
        token=as_info["as_token"],
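`_load_appservice` above enforces the shape of an appservice registration: four required string fields, a `url` that is either a string or an explicit null, and namespace entries that each carry a string `regex` and a boolean `exclusive`. A registration dict that would pass those checks might look like this (all values are placeholders, not taken from any real deployment):

# Hypothetical input satisfying the validation in _load_appservice.
as_info = {
    "id": "example-bridge",
    "as_token": "<as_token>",
    "hs_token": "<hs_token>",
    "sender_localpart": "examplebot",  # must survive URL-quoting unchanged
    "url": "http://localhost:9000",  # or None to disable pushing events
    "namespaces": {
        "users": [{"regex": "@example_.*", "exclusive": True}],
    },
}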
@ -16,8 +16,7 @@ from ._base import Config


class CaptchaConfig(Config):

    def read_config(self, config):
    def read_config(self, config, **kwargs):
        self.recaptcha_private_key = config.get("recaptcha_private_key")
        self.recaptcha_public_key = config.get("recaptcha_public_key")
        self.enable_registration_captcha = config.get(

@ -29,7 +28,7 @@ class CaptchaConfig(Config):
            "https://www.recaptcha.net/recaptcha/api/siteverify",
        )

    def default_config(self, **kwargs):
    def generate_config_section(self, **kwargs):
        return """\
        ## Captcha ##
        # See docs/CAPTCHA_SETUP for full details of configuring this.
@ -22,7 +22,7 @@ class CasConfig(Config):
        cas_server_url: URL of CAS server
    """

    def read_config(self, config):
    def read_config(self, config, **kwargs):
        cas_config = config.get("cas_config", None)
        if cas_config:
            self.cas_enabled = cas_config.get("enabled", True)

@ -35,7 +35,7 @@ class CasConfig(Config):
            self.cas_service_url = None
            self.cas_required_attributes = {}

    def default_config(self, config_dir_path, server_name, **kwargs):
    def generate_config_section(self, config_dir_path, server_name, **kwargs):
        return """
        # Enable CAS for registration and login.
        #
@ -84,35 +84,32 @@ class ConsentConfig(Config):
        self.user_consent_at_registration = False
        self.user_consent_policy_name = "Privacy Policy"

    def read_config(self, config):
    def read_config(self, config, **kwargs):
        consent_config = config.get("user_consent")
        if consent_config is None:
            return
        self.user_consent_version = str(consent_config["version"])
        self.user_consent_template_dir = self.abspath(
            consent_config["template_dir"]
        )
        self.user_consent_template_dir = self.abspath(consent_config["template_dir"])
        if not path.isdir(self.user_consent_template_dir):
            raise ConfigError(
                "Could not find template directory '%s'" % (
                    self.user_consent_template_dir,
                ),
                "Could not find template directory '%s'"
                % (self.user_consent_template_dir,)
            )
        self.user_consent_server_notice_content = consent_config.get(
            "server_notice_content",
            "server_notice_content"
        )
        self.block_events_without_consent_error = consent_config.get(
            "block_events_error",
            "block_events_error"
        )
        self.user_consent_server_notice_to_guests = bool(
            consent_config.get("send_server_notice_to_guests", False)
        )
        self.user_consent_at_registration = bool(
            consent_config.get("require_at_registration", False)
        )
        self.user_consent_server_notice_to_guests = bool(consent_config.get(
            "send_server_notice_to_guests", False,
        ))
        self.user_consent_at_registration = bool(consent_config.get(
            "require_at_registration", False,
        ))
        self.user_consent_policy_name = consent_config.get(
            "policy_name", "Privacy Policy",
            "policy_name", "Privacy Policy"
        )

    def default_config(self, **kwargs):
    def generate_config_section(self, **kwargs):
        return DEFAULT_CONFIG
@ -18,37 +18,30 @@ from ._base import Config


class DatabaseConfig(Config):

    def read_config(self, config):
        self.event_cache_size = self.parse_size(
            config.get("event_cache_size", "10K")
        )
    def read_config(self, config, **kwargs):
        self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))

        self.database_config = config.get("database")

        if self.database_config is None:
            self.database_config = {
                "name": "sqlite3",
                "args": {},
            }
            self.database_config = {"name": "sqlite3", "args": {}}

        name = self.database_config.get("name", None)
        if name == "psycopg2":
            pass
        elif name == "sqlite3":
            self.database_config.setdefault("args", {}).update({
                "cp_min": 1,
                "cp_max": 1,
                "check_same_thread": False,
            })
            self.database_config.setdefault("args", {}).update(
                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
            )
        else:
            raise RuntimeError("Unsupported database type '%s'" % (name,))

        self.set_databasepath(config.get("database_path"))

    def default_config(self, data_dir_path, **kwargs):
    def generate_config_section(self, data_dir_path, **kwargs):
        database_path = os.path.join(data_dir_path, "homeserver.db")
        return """\
        return (
            """\
        ## Database ##

        database:

@ -62,7 +55,9 @@ class DatabaseConfig(Config):
        # Number of events to cache in memory.
        #
        #event_cache_size: 10K
        """ % locals()
        """
            % locals()
        )

    def read_arguments(self, args):
        self.set_databasepath(args.database_path)

@ -77,6 +72,8 @@ class DatabaseConfig(Config):
    def add_arguments(self, parser):
        db_group = parser.add_argument_group("database")
        db_group.add_argument(
            "-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
            help="The path to a sqlite database to use."
            "-d",
            "--database-path",
            metavar="SQLITE_DATABASE_PATH",
            help="The path to a sqlite database to use.",
        )
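The `read_config` change above is a pure reflow of the sqlite3 defaults; behaviour is unchanged, and an empty config still ends up with a single-connection pool. A sketch of the resulting dict (computed by hand, not output captured from synapse):

# What DatabaseConfig.read_config produces when no "database" section is
# configured: sqlite3 with the pool pinned to one connection.
database_config = {"name": "sqlite3", "args": {}}
if database_config["name"] == "sqlite3":
    database_config.setdefault("args", {}).update(
        {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
    )
print(database_config)
# {'name': 'sqlite3', 'args': {'cp_min': 1, 'cp_max': 1, 'check_same_thread': False}}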
@ -19,18 +19,15 @@ from __future__ import print_function

# This file can't be called email.py because if it is, we cannot:
import email.utils
import logging
import os

import pkg_resources

from ._base import Config, ConfigError

logger = logging.getLogger(__name__)


class EmailConfig(Config):
    def read_config(self, config):
    def read_config(self, config, **kwargs):
        # TODO: We should separate better the email configuration from the notification
        # and account validity config.

@ -59,7 +56,7 @@ class EmailConfig(Config):
        if self.email_notif_from is not None:
            # make sure it's valid
            parsed = email.utils.parseaddr(self.email_notif_from)
            if parsed[1] == '':
            if parsed[1] == "":
                raise RuntimeError("Invalid notif_from address")

        template_dir = email_config.get("template_dir")

@ -68,27 +65,27 @@ class EmailConfig(Config):
        # (Note that loading as package_resources with jinja.PackageLoader doesn't
        # work for the same reason.)
        if not template_dir:
            template_dir = pkg_resources.resource_filename(
                'synapse', 'res/templates'
            )
            template_dir = pkg_resources.resource_filename("synapse", "res/templates")

        self.email_template_dir = os.path.abspath(template_dir)

        self.email_enable_notifs = email_config.get("enable_notifs", False)
        account_validity_renewal_enabled = config.get(
            "account_validity", {},
        ).get("renew_at")
        account_validity_renewal_enabled = config.get("account_validity", {}).get(
            "renew_at"
        )

        email_trust_identity_server_for_password_resets = email_config.get(
            "trust_identity_server_for_password_resets", False,
            "trust_identity_server_for_password_resets", False
        )
        self.email_password_reset_behaviour = (
            "remote" if email_trust_identity_server_for_password_resets else "local"
        )
        self.password_resets_were_disabled_due_to_email_config = False
        if self.email_password_reset_behaviour == "local" and email_config == {}:
            logger.warn(
                "User password resets have been disabled due to lack of email config"
            )
            # We cannot warn the user this has happened here
            # Instead do so when a user attempts to reset their password
            self.password_resets_were_disabled_due_to_email_config = True

            self.email_password_reset_behaviour = "off"

        # Get lifetime of a validation token in milliseconds

@ -104,62 +101,59 @@ class EmailConfig(Config):
            # make sure we can import the required deps
            import jinja2
            import bleach

            # prevent unused warnings
            jinja2
            bleach

        if self.email_password_reset_behaviour == "local":
            required = [
                "smtp_host",
                "smtp_port",
                "notif_from",
            ]
            required = ["smtp_host", "smtp_port", "notif_from"]

            missing = []
            for k in required:
                if k not in email_config:
                    missing.append(k)

            if (len(missing) > 0):
            if len(missing) > 0:
                raise RuntimeError(
                    "email.password_reset_behaviour is set to 'local' "
                    "but required keys are missing: %s" %
                    (", ".join(["email." + k for k in missing]),)
                    "but required keys are missing: %s"
                    % (", ".join(["email." + k for k in missing]),)
                )

            # Templates for password reset emails
            self.email_password_reset_template_html = email_config.get(
                "password_reset_template_html", "password_reset.html",
                "password_reset_template_html", "password_reset.html"
            )
            self.email_password_reset_template_text = email_config.get(
                "password_reset_template_text", "password_reset.txt",
                "password_reset_template_text", "password_reset.txt"
            )
            self.email_password_reset_failure_template = email_config.get(
                "password_reset_failure_template", "password_reset_failure.html",
                "password_reset_failure_template", "password_reset_failure.html"
            )
            # This template does not support any replaceable variables, so we will
            # read it from the disk once during setup
            email_password_reset_success_template = email_config.get(
                "password_reset_success_template", "password_reset_success.html",
                "password_reset_success_template", "password_reset_success.html"
            )

            # Check templates exist
            for f in [self.email_password_reset_template_html,
                      self.email_password_reset_template_text,
                      self.email_password_reset_failure_template,
                      email_password_reset_success_template]:
            for f in [
                self.email_password_reset_template_html,
                self.email_password_reset_template_text,
                self.email_password_reset_failure_template,
                email_password_reset_success_template,
            ]:
                p = os.path.join(self.email_template_dir, f)
                if not os.path.isfile(p):
                    raise ConfigError("Unable to find template file %s" % (p, ))
                    raise ConfigError("Unable to find template file %s" % (p,))

            # Retrieve content of web templates
            filepath = os.path.join(
                self.email_template_dir,
                email_password_reset_success_template,
                self.email_template_dir, email_password_reset_success_template
            )
            self.email_password_reset_success_html_content = self.read_file(
                filepath,
                "email.password_reset_template_success_html",
                filepath, "email.password_reset_template_success_html"
            )

            if config.get("public_baseurl") is None:

@ -183,10 +177,10 @@ class EmailConfig(Config):
                if k not in email_config:
                    missing.append(k)

            if (len(missing) > 0):
            if len(missing) > 0:
                raise RuntimeError(
                    "email.enable_notifs is True but required keys are missing: %s" %
                    (", ".join(["email." + k for k in missing]),)
                    "email.enable_notifs is True but required keys are missing: %s"
                    % (", ".join(["email." + k for k in missing]),)
                )

            if config.get("public_baseurl") is None:

@ -200,29 +194,27 @@ class EmailConfig(Config):
            for f in self.email_notif_template_text, self.email_notif_template_html:
                p = os.path.join(self.email_template_dir, f)
                if not os.path.isfile(p):
                    raise ConfigError("Unable to find email template file %s" % (p, ))
                    raise ConfigError("Unable to find email template file %s" % (p,))

            self.email_notif_for_new_users = email_config.get(
                "notif_for_new_users", True
            )
            self.email_riot_base_url = email_config.get(
                "riot_base_url", None
            )
            self.email_riot_base_url = email_config.get("riot_base_url", None)

        if account_validity_renewal_enabled:
            self.email_expiry_template_html = email_config.get(
                "expiry_template_html", "notice_expiry.html",
                "expiry_template_html", "notice_expiry.html"
            )
            self.email_expiry_template_text = email_config.get(
                "expiry_template_text", "notice_expiry.txt",
                "expiry_template_text", "notice_expiry.txt"
            )

            for f in self.email_expiry_template_text, self.email_expiry_template_html:
                p = os.path.join(self.email_template_dir, f)
                if not os.path.isfile(p):
                    raise ConfigError("Unable to find email template file %s" % (p, ))
                    raise ConfigError("Unable to find email template file %s" % (p,))

    def default_config(self, config_dir_path, server_name, **kwargs):
    def generate_config_section(self, config_dir_path, server_name, **kwargs):
        return """
        # Enable sending emails for password resets, notification events or
        # account expiry notices
|
|||
|
||||
|
||||
class GroupsConfig(Config):
|
||||
def read_config(self, config):
|
||||
def read_config(self, config, **kwargs):
|
||||
self.enable_group_creation = config.get("enable_group_creation", False)
|
||||
self.group_creation_prefix = config.get("group_creation_prefix", "")
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
# Uncomment to allow non-server-admin users to create groups on this server
|
||||
#
|
||||
|
|
|
@ -38,6 +38,7 @@ from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
from .stats import StatsConfig
from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig

@ -73,5 +74,6 @@ class HomeServerConfig(
    StatsConfig,
    ServerNoticesConfig,
    RoomDirectoryConfig,
    ThirdPartyRulesConfig,
):
    pass
@ -15,17 +15,15 @@

from ._base import Config, ConfigError

MISSING_JWT = (
    """Missing jwt library. This is required for jwt login.
MISSING_JWT = """Missing jwt library. This is required for jwt login.

    Install by running:
        pip install pyjwt
    """
)


class JWTConfig(Config):
    def read_config(self, config):
    def read_config(self, config, **kwargs):
        jwt_config = config.get("jwt_config", None)
        if jwt_config:
            self.jwt_enabled = jwt_config.get("enabled", False)

@ -34,6 +32,7 @@ class JWTConfig(Config):

        try:
            import jwt

            jwt  # To stop unused lint.
        except ImportError:
            raise ConfigError(MISSING_JWT)

@ -42,7 +41,7 @@ class JWTConfig(Config):
            self.jwt_secret = None
            self.jwt_algorithm = None

    def default_config(self, **kwargs):
    def generate_config_section(self, **kwargs):
        return """\
        # The JWT needs to contain a globally unique "sub" (subject) claim.
        #
@ -65,13 +65,18 @@ class TrustedKeyServer(object):


class KeyConfig(Config):
    def read_config(self, config):
    def read_config(self, config, config_dir_path, **kwargs):
        # the signing key can be specified inline or in a separate file
        if "signing_key" in config:
            self.signing_key = read_signing_keys([config["signing_key"]])
        else:
            self.signing_key_path = config["signing_key_path"]
            self.signing_key = self.read_signing_key(self.signing_key_path)
            signing_key_path = config.get("signing_key_path")
            if signing_key_path is None:
                signing_key_path = os.path.join(
                    config_dir_path, config["server_name"] + ".signing.key"
                )

            self.signing_key = self.read_signing_key(signing_key_path)

        self.old_signing_keys = self.read_old_signing_keys(
            config.get("old_signing_keys", {})

@ -117,7 +122,7 @@ class KeyConfig(Config):
        # falsification of values
        self.form_secret = config.get("form_secret", None)

    def default_config(
    def generate_config_section(
        self, config_dir_path, server_name, generate_secrets=False, **kwargs
    ):
        base_key_name = os.path.join(config_dir_path, server_name)

@ -237,10 +242,18 @@ class KeyConfig(Config):
        )
        return keys

    def generate_files(self, config):
        signing_key_path = config["signing_key_path"]
    def generate_files(self, config, config_dir_path):
        if "signing_key" in config:
            return

        signing_key_path = config.get("signing_key_path")
        if signing_key_path is None:
            signing_key_path = os.path.join(
                config_dir_path, config["server_name"] + ".signing.key"
            )

        if not self.path_exists(signing_key_path):
            print("Generating signing key file %s" % (signing_key_path,))
            with open(signing_key_path, "w") as signing_key_file:
                key_id = "a_" + random_string(4)
                write_signing_keys(signing_key_file, (generate_signing_key(key_id),))

@ -348,9 +361,8 @@ def _parse_key_servers(key_servers, federation_verify_certificates):

            result.verify_keys[key_id] = verify_key

        if (
            not federation_verify_certificates and
            not server.get("accept_keys_insecurely")
        if not federation_verify_certificates and not server.get(
            "accept_keys_insecurely"
        ):
            _assert_keyserver_has_verify_keys(result)
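The `KeyConfig` hunks above make `signing_key_path` optional by deriving a default from the config directory and the server name, in both `read_config` and `generate_files`. The derivation is a plain path join; a sketch with hypothetical values:

import os

# Default signing key location as derived in KeyConfig above; the
# directory and server name here are made up for illustration.
config_dir_path = "/etc/synapse"
server_name = "example.com"
signing_key_path = os.path.join(config_dir_path, server_name + ".signing.key")
print(signing_key_path)  # /etc/synapse/example.com.signing.key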
@ -29,7 +29,8 @@ from synapse.util.versionstring import get_version_string

from ._base import Config

DEFAULT_LOG_CONFIG = Template("""
DEFAULT_LOG_CONFIG = Template(
    """
version: 1

formatters:

@ -68,26 +69,29 @@ loggers:
root:
    level: INFO
    handlers: [file, console]
""")
"""
)


class LoggingConfig(Config):

    def read_config(self, config):
    def read_config(self, config, **kwargs):
        self.verbosity = config.get("verbose", 0)
        self.no_redirect_stdio = config.get("no_redirect_stdio", False)
        self.log_config = self.abspath(config.get("log_config"))
        self.log_file = self.abspath(config.get("log_file"))

    def default_config(self, config_dir_path, server_name, **kwargs):
    def generate_config_section(self, config_dir_path, server_name, **kwargs):
        log_config = os.path.join(config_dir_path, server_name + ".log.config")
        return """\
        return (
            """\
        ## Logging ##

        # A yaml python logging config file
        #
        log_config: "%(log_config)s"
        """ % locals()
        """
            % locals()
        )

    def read_arguments(self, args):
        if args.verbose is not None:

@ -102,32 +106,43 @@ class LoggingConfig(Config):
    def add_arguments(cls, parser):
        logging_group = parser.add_argument_group("logging")
        logging_group.add_argument(
            '-v', '--verbose', dest="verbose", action='count',
            "-v",
            "--verbose",
            dest="verbose",
            action="count",
            help="The verbosity level. Specify multiple times to increase "
            "verbosity. (Ignored if --log-config is specified.)"
            "verbosity. (Ignored if --log-config is specified.)",
        )
        logging_group.add_argument(
            '-f', '--log-file', dest="log_file",
            help="File to log to. (Ignored if --log-config is specified.)"
            "-f",
            "--log-file",
            dest="log_file",
            help="File to log to. (Ignored if --log-config is specified.)",
        )
        logging_group.add_argument(
            '--log-config', dest="log_config", default=None,
            help="Python logging config file"
            "--log-config",
            dest="log_config",
            default=None,
            help="Python logging config file",
        )
        logging_group.add_argument(
            '-n', '--no-redirect-stdio',
            action='store_true', default=None,
            help="Do not redirect stdout/stderr to the log"
            "-n",
            "--no-redirect-stdio",
            action="store_true",
            default=None,
            help="Do not redirect stdout/stderr to the log",
        )

    def generate_files(self, config):
    def generate_files(self, config, config_dir_path):
        log_config = config.get("log_config")
        if log_config and not os.path.exists(log_config):
            log_file = self.abspath("homeserver.log")
            print(
                "Generating log config file %s which will log to %s"
                % (log_config, log_file)
            )
            with open(log_config, "w") as log_config_file:
                log_config_file.write(
                    DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
                )
                log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))


def setup_logging(config, use_worker_options=False):

@ -143,10 +158,8 @@ def setup_logging(config, use_worker_options=False):
        register_sighup (func | None): Function to call to register a
            sighup handler.
    """
    log_config = (config.worker_log_config if use_worker_options
                  else config.log_config)
    log_file = (config.worker_log_file if use_worker_options
                else config.log_file)
    log_config = config.worker_log_config if use_worker_options else config.log_config
    log_file = config.worker_log_file if use_worker_options else config.log_file

    log_format = (
        "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"

@ -164,23 +177,23 @@ def setup_logging(config, use_worker_options=False):
    if config.verbosity > 1:
        level_for_storage = logging.DEBUG

    logger = logging.getLogger('')
    logger = logging.getLogger("")
    logger.setLevel(level)

    logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
    logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage)

    formatter = logging.Formatter(log_format)
    if log_file:
        # TODO: Customisable file size / backup count
        handler = logging.handlers.RotatingFileHandler(
            log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
            encoding='utf8'
            log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8"
        )

        def sighup(signum, stack):
            logger.info("Closing log file due to SIGHUP")
            handler.doRollover()
            logger.info("Opened new log file due to SIGHUP")

    else:
        handler = logging.StreamHandler()


@ -193,8 +206,9 @@ def setup_logging(config, use_worker_options=False):

        logger.addHandler(handler)
    else:

        def load_log_config():
            with open(log_config, 'r') as f:
            with open(log_config, "r") as f:
                logging.config.dictConfig(yaml.safe_load(f))

        def sighup(*args):

@ -209,10 +223,7 @@ def setup_logging(config, use_worker_options=False):
    # make sure that the first thing we log is a thing we can grep backwards
    # for
    logging.warn("***** STARTING SERVER *****")
    logging.warn(
        "Server %s version %s",
        sys.argv[0], get_version_string(synapse),
    )
    logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
    logging.info("Server hostname: %s", config.server_name)

    # It's critical to point twisted's internal logging somewhere, otherwise it

@ -242,8 +253,7 @@ def setup_logging(config, use_worker_options=False):
        return observer(event)

    globalLogBeginner.beginLoggingTo(
        [_log],
        redirectStandardIO=not config.no_redirect_stdio,
        [_log], redirectStandardIO=not config.no_redirect_stdio
    )
    if not config.no_redirect_stdio:
        print("Redirected stdout/stderr to logs")
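`generate_files` above fills in the default log config with `string.Template` substitution rather than `%`-formatting. A self-contained sketch of that mechanism, using a stub template rather than the full `DEFAULT_LOG_CONFIG`:

from string import Template

# Stub standing in for DEFAULT_LOG_CONFIG; only the substitution
# mechanism is the point here.
stub_log_config = Template("handlers:\n  file:\n    filename: ${log_file}\n")
print(stub_log_config.substitute(log_file="/var/log/synapse/homeserver.log"))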
@ -15,15 +15,13 @@

from ._base import Config, ConfigError

MISSING_SENTRY = (
    """Missing sentry-sdk library. This is required to enable sentry
MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry
    integration.
    """
)


class MetricsConfig(Config):
    def read_config(self, config):
    def read_config(self, config, **kwargs):
        self.enable_metrics = config.get("enable_metrics", False)
        self.report_stats = config.get("report_stats", None)
        self.metrics_port = config.get("metrics_port")

@ -39,10 +37,10 @@ class MetricsConfig(Config):
            self.sentry_dsn = config["sentry"].get("dsn")
            if not self.sentry_dsn:
                raise ConfigError(
                    "sentry.dsn field is required when sentry integration is enabled",
                    "sentry.dsn field is required when sentry integration is enabled"
                )

    def default_config(self, report_stats=None, **kwargs):
    def generate_config_section(self, report_stats=None, **kwargs):
        res = """\
        ## Metrics ###


@ -66,6 +64,6 @@ class MetricsConfig(Config):
        if report_stats is None:
            res += "# report_stats: true|false\n"
        else:
            res += "report_stats: %s\n" % ('true' if report_stats else 'false')
            res += "report_stats: %s\n" % ("true" if report_stats else "false")

        return res
@ -20,7 +20,7 @@ class PasswordConfig(Config):
    """Password login configuration
    """

    def read_config(self, config):
    def read_config(self, config, **kwargs):
        password_config = config.get("password_config", {})
        if password_config is None:
            password_config = {}

@ -28,7 +28,7 @@ class PasswordConfig(Config):
        self.password_enabled = password_config.get("enabled", True)
        self.password_pepper = password_config.get("pepper", "")

    def default_config(self, config_dir_path, server_name, **kwargs):
    def generate_config_section(self, config_dir_path, server_name, **kwargs):
        return """\
        password_config:
           # Uncomment to disable password login
@ -17,11 +17,11 @@ from synapse.util.module_loader import load_module

from ._base import Config

LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"


class PasswordAuthProviderConfig(Config):
    def read_config(self, config):
    def read_config(self, config, **kwargs):
        self.password_providers = []
        providers = []


@ -29,28 +29,24 @@ class PasswordAuthProviderConfig(Config):
        # param.
        ldap_config = config.get("ldap_config", {})
        if ldap_config.get("enabled", False):
            providers.append({
                'module': LDAP_PROVIDER,
                'config': ldap_config,
            })
            providers.append({"module": LDAP_PROVIDER, "config": ldap_config})

        providers.extend(config.get("password_providers", []))
        for provider in providers:
            mod_name = provider['module']
            mod_name = provider["module"]

            # This is for backwards compat when the ldap auth provider resided
            # in this package.
            if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
                mod_name = LDAP_PROVIDER

            (provider_class, provider_config) = load_module({
                "module": mod_name,
                "config": provider['config'],
            })
            (provider_class, provider_config) = load_module(
                {"module": mod_name, "config": provider["config"]}
            )

            self.password_providers.append((provider_class, provider_config))

    def default_config(self, **kwargs):
    def generate_config_section(self, **kwargs):
        return """\
        #password_providers:
        #    - module: "ldap_auth_provider.LdapAuthProvider"
@ -18,7 +18,7 @@ from ._base import Config


class PushConfig(Config):
    def read_config(self, config):
    def read_config(self, config, **kwargs):
        push_config = config.get("push", {})
        self.push_include_content = push_config.get("include_content", True)


@ -42,7 +42,7 @@ class PushConfig(Config):
        )
        self.push_include_content = not redact_content

    def default_config(self, config_dir_path, server_name, **kwargs):
    def generate_config_section(self, config_dir_path, server_name, **kwargs):
        return """
        # Clients requesting push notifications can either have the body of
        # the message sent in the notification poke along with other details
@ -36,7 +36,7 @@ class FederationRateLimitConfig(object):


class RatelimitConfig(Config):
    def read_config(self, config):
    def read_config(self, config, **kwargs):

        # Load the new-style messages config if it exists. Otherwise fall back
        # to the old method.

@ -80,7 +80,7 @@ class RatelimitConfig(Config):
            "federation_rr_transactions_per_room_per_second", 50
        )

    def default_config(self, **kwargs):
    def generate_config_section(self, **kwargs):
        return """\
        ## Ratelimiting ##

@ -23,7 +23,7 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
    def __init__(self, config, synapse_config):
        self.enabled = config.get("enabled", False)
        self.renew_by_email_enabled = ("renew_at" in config)
        self.renew_by_email_enabled = "renew_at" in config

        if self.enabled:
            if "period" in config:

@ -39,15 +39,14 @@ class AccountValidityConfig(Config):
            else:
                self.renew_email_subject = "Renew your %(app)s account"

        self.startup_job_max_delta = self.period * 10. / 100.
        self.startup_job_max_delta = self.period * 10.0 / 100.0

        if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
            raise ConfigError("Can't send renewal emails without 'public_baseurl'")


class RegistrationConfig(Config):

    def read_config(self, config):
    def read_config(self, config, **kwargs):
        self.enable_registration = bool(
            strtobool(str(config.get("enable_registration", False)))
        )

@ -57,7 +56,7 @@ class RegistrationConfig(Config):
        )

        self.account_validity = AccountValidityConfig(
            config.get("account_validity", {}), config,
            config.get("account_validity", {}), config
        )

        self.registrations_require_3pid = config.get("registrations_require_3pid", [])

@ -67,35 +66,37 @@ class RegistrationConfig(Config):

        self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
        self.trusted_third_party_id_servers = config.get(
            "trusted_third_party_id_servers",
            ["matrix.org", "vector.im"],
            "trusted_third_party_id_servers", ["matrix.org", "vector.im"]
        )
        self.default_identity_server = config.get("default_identity_server")
        self.allow_guest_access = config.get("allow_guest_access", False)

        self.invite_3pid_guest = (
            self.allow_guest_access and config.get("invite_3pid_guest", False)
        self.invite_3pid_guest = self.allow_guest_access and config.get(
            "invite_3pid_guest", False
        )

        self.auto_join_rooms = config.get("auto_join_rooms", [])
        for room_alias in self.auto_join_rooms:
            if not RoomAlias.is_valid(room_alias):
                raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
                raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
        self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)

        self.disable_msisdn_registration = (
            config.get("disable_msisdn_registration", False)
        self.disable_msisdn_registration = config.get(
            "disable_msisdn_registration", False
        )

    def default_config(self, generate_secrets=False, **kwargs):
    def generate_config_section(self, generate_secrets=False, **kwargs):
        if generate_secrets:
            registration_shared_secret = 'registration_shared_secret: "%s"' % (
                random_string_with_symbols(50),
            )
        else:
            registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>'
            registration_shared_secret = (
                "# registration_shared_secret: <PRIVATE STRING>"
            )

        return """\
        return (
            """\
        ## Registration ##
        #
        # Registration can be rate-limited using the parameters in the "Ratelimiting"

@ -217,17 +218,19 @@ class RegistrationConfig(Config):
        # users cannot be auto-joined since they do not exist.
        #
        #autocreate_auto_join_rooms: true
        """ % locals()
        """
            % locals()
        )

    def add_arguments(self, parser):
        reg_group = parser.add_argument_group("registration")
        reg_group.add_argument(
            "--enable-registration", action="store_true", default=None,
            help="Enable registration for new users."
            "--enable-registration",
            action="store_true",
            default=None,
            help="Enable registration for new users.",
        )

    def read_arguments(self, args):
        if args.enable_registration is not None:
            self.enable_registration = bool(
                strtobool(str(args.enable_registration))
            )
            self.enable_registration = bool(strtobool(str(args.enable_registration)))
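One line worth unpacking in the `AccountValidityConfig` hunk above: `startup_job_max_delta` is ten percent of the validity period (the change merely rewrites `10. / 100.` as `10.0 / 100.0`). For example, with a hypothetical 30-day period expressed in milliseconds:

# 10% of a 30-day validity period, as computed above.
period = 30 * 24 * 60 * 60 * 1000  # 2592000000 ms
startup_job_max_delta = period * 10.0 / 100.0
print(startup_job_max_delta)  # 259200000.0 ms, i.e. 3 days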
@ -20,27 +20,11 @@ from synapse.util.module_loader import load_module
|
|||
from ._base import Config, ConfigError
|
||||
|
||||
DEFAULT_THUMBNAIL_SIZES = [
|
||||
{
|
||||
"width": 32,
|
||||
"height": 32,
|
||||
"method": "crop",
|
||||
}, {
|
||||
"width": 96,
|
||||
"height": 96,
|
||||
"method": "crop",
|
||||
}, {
|
||||
"width": 320,
|
||||
"height": 240,
|
||||
"method": "scale",
|
||||
}, {
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
"method": "scale",
|
||||
}, {
|
||||
"width": 800,
|
||||
"height": 600,
|
||||
"method": "scale"
|
||||
},
|
||||
{"width": 32, "height": 32, "method": "crop"},
|
||||
{"width": 96, "height": 96, "method": "crop"},
|
||||
{"width": 320, "height": 240, "method": "scale"},
|
||||
{"width": 640, "height": 480, "method": "scale"},
|
||||
{"width": 800, "height": 600, "method": "scale"},
|
||||
]
|
||||
|
||||
THUMBNAIL_SIZE_YAML = """\
|
||||
|
@@ -49,19 +33,15 @@ THUMBNAIL_SIZE_YAML = """\
 #    method: %(method)s
 """

-MISSING_NETADDR = (
-    "Missing netaddr library. This is required for URL preview API."
-)
+MISSING_NETADDR = "Missing netaddr library. This is required for URL preview API."

-MISSING_LXML = (
-    """Missing lxml library. This is required for URL preview API.
+MISSING_LXML = """Missing lxml library. This is required for URL preview API.

     Install by running:
         pip install lxml

     Requires libxslt1-dev system package.
     """
-)


 ThumbnailRequirement = namedtuple(
@@ -69,7 +49,8 @@ ThumbnailRequirement = namedtuple(
 )

 MediaStorageProviderConfig = namedtuple(
-    "MediaStorageProviderConfig", (
+    "MediaStorageProviderConfig",
+    (
         "store_local",  # Whether to store newly uploaded local files
         "store_remote",  # Whether to store newly downloaded remote files
         "store_synchronous",  # Whether to wait for successful storage for local uploads

@@ -100,18 +81,19 @@ def parse_thumbnail_requirements(thumbnail_sizes):
     requirements.setdefault("image/gif", []).append(png_thumbnail)
     requirements.setdefault("image/png", []).append(png_thumbnail)
     return {
-        media_type: tuple(thumbnails)
-        for media_type, thumbnails in requirements.items()
+        media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
     }
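To see what the reformatted parse_thumbnail_requirements actually computes, here is a self-contained sketch of the same mapping; the ThumbnailRequirement fields mirror the namedtuple above, the rest is illustrative rather than Synapse's exact code:

    from collections import namedtuple

    ThumbnailRequirement = namedtuple(
        "ThumbnailRequirement", ["width", "height", "method", "media_type"]
    )

    def parse_thumbnail_requirements(thumbnail_sizes):
        # Each configured size is required as both a JPEG and a PNG thumbnail;
        # GIF sources are thumbnailed as PNG.
        requirements = {}
        for size in thumbnail_sizes:
            width, height, method = size["width"], size["height"], size["method"]
            jpeg = ThumbnailRequirement(width, height, method, "image/jpeg")
            png = ThumbnailRequirement(width, height, method, "image/png")
            requirements.setdefault("image/jpeg", []).append(jpeg)
            requirements.setdefault("image/gif", []).append(png)
            requirements.setdefault("image/png", []).append(png)
        return {m: tuple(t) for m, t in requirements.items()}

    reqs = parse_thumbnail_requirements([{"width": 32, "height": 32, "method": "crop"}])
    print(reqs["image/gif"][0].method)  # -> "crop"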
class ContentRepositoryConfig(Config):
|
||||
def read_config(self, config):
|
||||
def read_config(self, config, **kwargs):
|
||||
self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M"))
|
||||
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
|
||||
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
|
||||
|
||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||
self.media_store_path = self.ensure_directory(
|
||||
config.get("media_store_path", "media_store")
|
||||
)
|
||||
|
||||
backup_media_store_path = config.get("backup_media_store_path")
|
||||
|
||||
|
@ -127,15 +109,15 @@ class ContentRepositoryConfig(Config):
|
|||
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
|
||||
)
|
||||
|
||||
storage_providers = [{
|
||||
"module": "file_system",
|
||||
"store_local": True,
|
||||
"store_synchronous": synchronous_backup_media_store,
|
||||
"store_remote": True,
|
||||
"config": {
|
||||
"directory": backup_media_store_path,
|
||||
storage_providers = [
|
||||
{
|
||||
"module": "file_system",
|
||||
"store_local": True,
|
||||
"store_synchronous": synchronous_backup_media_store,
|
||||
"store_remote": True,
|
||||
"config": {"directory": backup_media_store_path},
|
||||
}
|
||||
}]
|
||||
]
|
||||
|
||||
# This is a list of config that can be used to create the storage
|
||||
# providers. The entries are tuples of (Class, class_config,
|
||||
|
@ -165,18 +147,19 @@ class ContentRepositoryConfig(Config):
|
|||
)
|
||||
|
||||
self.media_storage_providers.append(
|
||||
(provider_class, parsed_config, wrapper_config,)
|
||||
(provider_class, parsed_config, wrapper_config)
|
||||
)
|
||||
|
||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||
self.uploads_path = self.ensure_directory(config.get("uploads_path", "uploads"))
|
||||
self.dynamic_thumbnails = config.get("dynamic_thumbnails", False)
|
||||
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||
config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES),
|
||||
config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES)
|
||||
)
|
||||
self.url_preview_enabled = config.get("url_preview_enabled", False)
|
||||
if self.url_preview_enabled:
|
||||
try:
|
||||
import lxml
|
||||
|
||||
lxml # To stop unused lint.
|
||||
except ImportError:
|
||||
raise ConfigError(MISSING_LXML)
|
||||
|
@ -199,17 +182,15 @@ class ContentRepositoryConfig(Config):
|
|||
|
||||
# we always blacklist '0.0.0.0' and '::', which are supposed to be
|
||||
# unroutable addresses.
|
||||
self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::'])
|
||||
self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"])
|
||||
|
||||
self.url_preview_ip_range_whitelist = IPSet(
|
||||
config.get("url_preview_ip_range_whitelist", ())
|
||||
)
|
||||
|
||||
self.url_preview_url_blacklist = config.get(
|
||||
"url_preview_url_blacklist", ()
|
||||
)
|
||||
self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
|
||||
|
||||
def default_config(self, data_dir_path, **kwargs):
|
||||
def generate_config_section(self, data_dir_path, **kwargs):
|
||||
media_store = os.path.join(data_dir_path, "media_store")
|
||||
uploads_path = os.path.join(data_dir_path, "uploads")
|
||||
|
||||
|
@ -219,7 +200,8 @@ class ContentRepositoryConfig(Config):
|
|||
# strip final NL
|
||||
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
|
||||
|
||||
return r"""
|
||||
return (
|
||||
r"""
|
||||
# Directory where uploaded images and attachments are stored.
|
||||
#
|
||||
media_store_path: "%(media_store)s"
|
||||
|
@ -342,4 +324,6 @@ class ContentRepositoryConfig(Config):
|
|||
# The largest allowed URL preview spidering size in bytes
|
||||
#
|
||||
#max_spider_size: 10M
|
||||
""" % locals()
|
||||
"""
|
||||
% locals()
|
||||
)
|
||||
|
|
|
@@ -19,10 +19,8 @@ from ._base import Config, ConfigError


 class RoomDirectoryConfig(Config):
-    def read_config(self, config):
-        self.enable_room_list_search = config.get(
-            "enable_room_list_search", True,
-        )
+    def read_config(self, config, **kwargs):
+        self.enable_room_list_search = config.get("enable_room_list_search", True)

         alias_creation_rules = config.get("alias_creation_rules")
@@ -33,11 +31,7 @@ class RoomDirectoryConfig(Config):
             ]
         else:
             self._alias_creation_rules = [
-                _RoomDirectoryRule(
-                    "alias_creation_rules", {
-                        "action": "allow",
-                    }
-                )
+                _RoomDirectoryRule("alias_creation_rules", {"action": "allow"})
             ]

         room_list_publication_rules = config.get("room_list_publication_rules")

@@ -49,14 +43,10 @@ class RoomDirectoryConfig(Config):
             ]
         else:
             self._room_list_publication_rules = [
-                _RoomDirectoryRule(
-                    "room_list_publication_rules", {
-                        "action": "allow",
-                    }
-                )
+                _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
             ]

-    def default_config(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """
         # Uncomment to disable searching the public room list. When disabled
         # blocks searching local and remote room lists for local and remote
|
||||
|
@ -178,8 +168,7 @@ class _RoomDirectoryRule(object):
|
|||
self.action = action
|
||||
else:
|
||||
raise ConfigError(
|
||||
"%s rules can only have action of 'allow'"
|
||||
" or 'deny'" % (option_name,)
|
||||
"%s rules can only have action of 'allow'" " or 'deny'" % (option_name,)
|
||||
)
|
||||
|
||||
self._alias_matches_all = alias == "*"
|
||||
|
|
|
@ -18,7 +18,7 @@ from ._base import Config, ConfigError
|
|||
|
||||
|
||||
class SAML2Config(Config):
|
||||
def read_config(self, config):
|
||||
def read_config(self, config, **kwargs):
|
||||
self.saml2_enabled = False
|
||||
|
||||
saml2_config = config.get("saml2_config")
|
||||
|
@ -34,6 +34,7 @@ class SAML2Config(Config):
|
|||
self.saml2_enabled = True
|
||||
|
||||
import saml2.config
|
||||
|
||||
self.saml2_sp_config = saml2.config.SPConfig()
|
||||
self.saml2_sp_config.load(self._default_saml_config_dict())
|
||||
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
|
||||
|
@@ -47,29 +48,26 @@ class SAML2Config(Config):

         public_baseurl = self.public_baseurl
         if public_baseurl is None:
-            raise ConfigError(
-                "saml2_config requires a public_baseurl to be set"
-            )
+            raise ConfigError("saml2_config requires a public_baseurl to be set")

         metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
         response_url = public_baseurl + "_matrix/saml2/authn_response"
         return {
             "entityid": metadata_url,
-
             "service": {
                 "sp": {
                     "endpoints": {
                         "assertion_consumer_service": [
-                            (response_url, saml2.BINDING_HTTP_POST),
-                        ],
+                            (response_url, saml2.BINDING_HTTP_POST)
+                        ]
                     },
                     "required_attributes": ["uid"],
                     "optional_attributes": ["mail", "surname", "givenname"],
-                }
-            }
+                },
+            }
         }

-    def default_config(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
         # Enable SAML2 for registration and login. Uses pysaml2.
         #
|
@ -112,4 +110,6 @@ class SAML2Config(Config):
|
|||
# # separate pysaml2 configuration file:
|
||||
# #
|
||||
# config_path: "%(config_dir_path)s/sp_conf.py"
|
||||
""" % {"config_dir_path": config_dir_path}
|
||||
""" % {
|
||||
"config_dir_path": config_dir_path
|
||||
}
|
||||
|
|
|
@@ -34,14 +34,13 @@ logger = logging.Logger(__name__)
 #
 # We later check for errors when binding to 0.0.0.0 and ignore them if :: is
 # also in the list.
-DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
+DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]

 DEFAULT_ROOM_VERSION = "4"


 class ServerConfig(Config):
-
-    def read_config(self, config):
+    def read_config(self, config, **kwargs):
         self.server_name = config["server_name"]
         self.server_context = config.get("server_context", None)
@@ -58,7 +57,6 @@ class ServerConfig(Config):
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
         self.public_baseurl = config.get("public_baseurl")
-        self.cpu_affinity = config.get("cpu_affinity")

         # Whether to send federation traffic out in this process. This only
         # applies to some federation traffic, and so shouldn't be used to
@@ -81,27 +79,45 @@ class ServerConfig(Config):
         # Whether to require authentication to retrieve profile data (avatars,
         # display names) of other users through the client API.
         self.require_auth_for_profile_requests = config.get(
-            "require_auth_for_profile_requests", False,
+            "require_auth_for_profile_requests", False
         )

-        # If set to 'True', requires authentication to access the server's
-        # public rooms directory through the client API, and forbids any other
-        # homeserver to fetch it via federation.
-        self.restrict_public_rooms_to_local_users = config.get(
-            "restrict_public_rooms_to_local_users", False,
-        )
+        if "restrict_public_rooms_to_local_users" in config and (
+            "allow_public_rooms_without_auth" in config
+            or "allow_public_rooms_over_federation" in config
+        ):
+            raise ConfigError(
+                "Can't use 'restrict_public_rooms_to_local_users' if"
+                " 'allow_public_rooms_without_auth' and/or"
+                " 'allow_public_rooms_over_federation' is set."
+            )

-        default_room_version = config.get(
-            "default_room_version", DEFAULT_ROOM_VERSION,
-        )
+        # Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This
+        # flag is now obsolete but we need to check it for backward-compatibility.
+        if config.get("restrict_public_rooms_to_local_users", False):
+            self.allow_public_rooms_without_auth = False
+            self.allow_public_rooms_over_federation = False
+        else:
+            # If set to 'False', requires authentication to access the server's public
+            # rooms directory through the client API. Defaults to 'True'.
+            self.allow_public_rooms_without_auth = config.get(
+                "allow_public_rooms_without_auth", True
+            )
+            # If set to 'False', forbids any other homeserver to fetch the server's public
+            # rooms directory via federation. Defaults to 'True'.
+            self.allow_public_rooms_over_federation = config.get(
+                "allow_public_rooms_over_federation", True
+            )
+
+        default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION)

         # Ensure room version is a str
         default_room_version = str(default_room_version)

         if default_room_version not in KNOWN_ROOM_VERSIONS:
             raise ConfigError(
-                "Unknown default_room_version: %s, known room versions: %s" %
-                (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
+                "Unknown default_room_version: %s, known room versions: %s"
+                % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
             )

         # Get the actual room version object rather than just the identifier
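The backward-compatibility rule above is easier to read in isolation: the legacy flag, when present and true, wins over both newer options; otherwise each new option applies with its own default of True. A minimal sketch (the standalone function itself is hypothetical):

    def resolve_public_rooms_options(config):
        # Legacy flag forces both new options off.
        if config.get("restrict_public_rooms_to_local_users", False):
            return False, False
        return (
            config.get("allow_public_rooms_without_auth", True),
            config.get("allow_public_rooms_over_federation", True),
        )

    print(resolve_public_rooms_options({"restrict_public_rooms_to_local_users": True}))
    # -> (False, False)
    print(resolve_public_rooms_options({"allow_public_rooms_without_auth": False}))
    # -> (False, True)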
@@ -116,31 +132,25 @@ class ServerConfig(Config):

         # Whether we should block invites sent to users on this server
         # (other than those sent by local server admins)
-        self.block_non_admin_invites = config.get(
-            "block_non_admin_invites", False,
-        )
+        self.block_non_admin_invites = config.get("block_non_admin_invites", False)

         # Whether to enable experimental MSC1849 (aka relations) support
         self.experimental_msc1849_support_enabled = config.get(
-            "experimental_msc1849_support_enabled", False,
+            "experimental_msc1849_support_enabled", False
         )

         # Options to control access by tracking MAU
         self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
         self.max_mau_value = 0
         if self.limit_usage_by_mau:
-            self.max_mau_value = config.get(
-                "max_mau_value", 0,
-            )
+            self.max_mau_value = config.get("max_mau_value", 0)
         self.mau_stats_only = config.get("mau_stats_only", False)

         self.mau_limits_reserved_threepids = config.get(
             "mau_limit_reserved_threepids", []
         )

-        self.mau_trial_days = config.get(
-            "mau_trial_days", 0,
-        )
+        self.mau_trial_days = config.get("mau_trial_days", 0)

         # Options to disable HS
         self.hs_disabled = config.get("hs_disabled", False)
@@ -153,9 +163,7 @@ class ServerConfig(Config):

         # FIXME: federation_domain_whitelist needs sytests
         self.federation_domain_whitelist = None
-        federation_domain_whitelist = config.get(
-            "federation_domain_whitelist", None,
-        )
+        federation_domain_whitelist = config.get("federation_domain_whitelist", None)

         if federation_domain_whitelist is not None:
             # turn the whitelist into a hash for speed of lookup

@@ -165,7 +173,7 @@ class ServerConfig(Config):
                 self.federation_domain_whitelist[domain] = True

         self.federation_ip_range_blacklist = config.get(
-            "federation_ip_range_blacklist", [],
+            "federation_ip_range_blacklist", []
         )

         # Attempt to create an IPSet from the given ranges
@ -178,13 +186,12 @@ class ServerConfig(Config):
|
|||
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Invalid range(s) provided in "
|
||||
"federation_ip_range_blacklist: %s" % e
|
||||
"Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e
|
||||
)
|
||||
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != '/':
|
||||
self.public_baseurl += '/'
|
||||
if self.public_baseurl[-1] != "/":
|
||||
self.public_baseurl += "/"
|
||||
self.start_pushers = config.get("start_pushers", True)
|
||||
|
||||
# (undocumented) option for torturing the worker-mode replication a bit,
|
||||
|
@ -195,7 +202,7 @@ class ServerConfig(Config):
|
|||
# Whether to require a user to be in the room to add an alias to it.
|
||||
# Defaults to True.
|
||||
self.require_membership_for_aliases = config.get(
|
||||
"require_membership_for_aliases", True,
|
||||
"require_membership_for_aliases", True
|
||||
)
|
||||
|
||||
# Whether to allow per-room membership profiles through the send of membership
|
||||
|
@ -227,9 +234,9 @@ class ServerConfig(Config):
|
|||
|
||||
# if we still have an empty list of addresses, use the default list
|
||||
if not bind_addresses:
|
||||
if listener['type'] == 'metrics':
|
||||
if listener["type"] == "metrics":
|
||||
# the metrics listener doesn't support IPv6
|
||||
bind_addresses.append('0.0.0.0')
|
||||
bind_addresses.append("0.0.0.0")
|
||||
else:
|
||||
bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
|
||||
|
||||
|
@ -249,78 +256,80 @@ class ServerConfig(Config):
|
|||
bind_host = config.get("bind_host", "")
|
||||
gzip_responses = config.get("gzip_responses", True)
|
||||
|
||||
self.listeners.append({
|
||||
"port": bind_port,
|
||||
"bind_addresses": [bind_host],
|
||||
"tls": True,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{
|
||||
"names": ["client"],
|
||||
"compress": gzip_responses,
|
||||
},
|
||||
{
|
||||
"names": ["federation"],
|
||||
"compress": False,
|
||||
}
|
||||
]
|
||||
})
|
||||
self.listeners.append(
|
||||
{
|
||||
"port": bind_port,
|
||||
"bind_addresses": [bind_host],
|
||||
"tls": True,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{"names": ["client"], "compress": gzip_responses},
|
||||
{"names": ["federation"], "compress": False},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
unsecure_port = config.get("unsecure_port", bind_port - 400)
|
||||
if unsecure_port:
|
||||
self.listeners.append({
|
||||
"port": unsecure_port,
|
||||
"bind_addresses": [bind_host],
|
||||
"tls": False,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{
|
||||
"names": ["client"],
|
||||
"compress": gzip_responses,
|
||||
},
|
||||
{
|
||||
"names": ["federation"],
|
||||
"compress": False,
|
||||
}
|
||||
]
|
||||
})
|
||||
self.listeners.append(
|
||||
{
|
||||
"port": unsecure_port,
|
||||
"bind_addresses": [bind_host],
|
||||
"tls": False,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{"names": ["client"], "compress": gzip_responses},
|
||||
{"names": ["federation"], "compress": False},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
manhole = config.get("manhole")
|
||||
if manhole:
|
||||
self.listeners.append({
|
||||
"port": manhole,
|
||||
"bind_addresses": ["127.0.0.1"],
|
||||
"type": "manhole",
|
||||
"tls": False,
|
||||
})
|
||||
self.listeners.append(
|
||||
{
|
||||
"port": manhole,
|
||||
"bind_addresses": ["127.0.0.1"],
|
||||
"type": "manhole",
|
||||
"tls": False,
|
||||
}
|
||||
)
|
||||
|
||||
metrics_port = config.get("metrics_port")
|
||||
if metrics_port:
|
||||
logger.warn(
|
||||
("The metrics_port configuration option is deprecated in Synapse 0.31 "
|
||||
"in favour of a listener. Please see "
|
||||
"http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
|
||||
" on how to configure the new listener."))
|
||||
(
|
||||
"The metrics_port configuration option is deprecated in Synapse 0.31 "
|
||||
"in favour of a listener. Please see "
|
||||
"http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
|
||||
" on how to configure the new listener."
|
||||
)
|
||||
)
|
||||
|
||||
self.listeners.append({
|
||||
"port": metrics_port,
|
||||
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
|
||||
"tls": False,
|
||||
"type": "http",
|
||||
"resources": [
|
||||
{
|
||||
"names": ["metrics"],
|
||||
"compress": False,
|
||||
},
|
||||
]
|
||||
})
|
||||
self.listeners.append(
|
||||
{
|
||||
"port": metrics_port,
|
||||
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
|
||||
"tls": False,
|
||||
"type": "http",
|
||||
"resources": [{"names": ["metrics"], "compress": False}],
|
||||
}
|
||||
)
|
||||
|
||||
_check_resource_config(self.listeners)
|
||||
|
||||
# An experimental option to try and periodically clean up extremities
|
||||
# by sending dummy events.
|
||||
self.cleanup_extremities_with_dummy_events = config.get(
|
||||
"cleanup_extremities_with_dummy_events", False
|
||||
)
|
||||
|
||||
def has_tls_listener(self):
|
||||
return any(l["tls"] for l in self.listeners)
|
||||
|
||||
def default_config(self, server_name, data_dir_path, **kwargs):
|
||||
def generate_config_section(
|
||||
self, server_name, data_dir_path, open_private_ports, **kwargs
|
||||
):
|
||||
_, bind_port = parse_and_validate_server_name(server_name)
|
||||
if bind_port is not None:
|
||||
unsecure_port = bind_port - 400
|
||||
|
@ -333,7 +342,15 @@ class ServerConfig(Config):
|
|||
# Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
|
||||
# default config string
|
||||
default_room_version = DEFAULT_ROOM_VERSION
|
||||
return """\
|
||||
|
||||
unsecure_http_binding = "port: %i\n tls: false" % (unsecure_port,)
|
||||
if not open_private_ports:
|
||||
unsecure_http_binding += (
|
||||
"\n bind_addresses: ['::1', '127.0.0.1']"
|
||||
)
|
||||
|
||||
return (
|
||||
"""\
|
||||
## Server ##
|
||||
|
||||
# The domain name of the server, with optional explicit port.
|
||||
|
@ -347,29 +364,6 @@ class ServerConfig(Config):
|
|||
#
|
||||
pid_file: %(pid_file)s
|
||||
|
||||
# CPU affinity mask. Setting this restricts the CPUs on which the
|
||||
# process will be scheduled. It is represented as a bitmask, with the
|
||||
# lowest order bit corresponding to the first logical CPU and the
|
||||
# highest order bit corresponding to the last logical CPU. Not all CPUs
|
||||
# may exist on a given system but a mask may specify more CPUs than are
|
||||
# present.
|
||||
#
|
||||
# For example:
|
||||
# 0x00000001 is processor #0,
|
||||
# 0x00000003 is processors #0 and #1,
|
||||
# 0xFFFFFFFF is all processors (#0 through #31).
|
||||
#
|
||||
# Pinning a Python process to a single CPU is desirable, because Python
|
||||
# is inherently single-threaded due to the GIL, and can suffer a
|
||||
# 30-40%% slowdown due to cache blow-out and thread context switching
|
||||
# if the scheduler happens to schedule the underlying threads across
|
||||
# different cores. See
|
||||
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
|
||||
#
|
||||
# This setting requires the affinity package to be installed!
|
||||
#
|
||||
#cpu_affinity: 0xFFFFFFFF
|
||||
|
||||
# The path to the web client which will be served at /_matrix/client/
|
||||
# if 'webclient' is configured under the 'listeners' configuration.
|
||||
#
|
||||
|
@ -401,11 +395,15 @@ class ServerConfig(Config):
|
|||
#
|
||||
#require_auth_for_profile_requests: true
|
||||
|
||||
# If set to 'true', requires authentication to access the server's
|
||||
# public rooms directory through the client API, and forbids any other
|
||||
# homeserver to fetch it via federation. Defaults to 'false'.
|
||||
# If set to 'false', requires authentication to access the server's public rooms
|
||||
# directory through the client API. Defaults to 'true'.
|
||||
#
|
||||
#restrict_public_rooms_to_local_users: true
|
||||
#allow_public_rooms_without_auth: false
|
||||
|
||||
# If set to 'false', forbids any other homeserver to fetch the server's public
|
||||
# rooms directory via federation. Defaults to 'true'.
|
||||
#
|
||||
#allow_public_rooms_over_federation: false
|
||||
|
||||
# The default room version for newly created rooms.
|
||||
#
|
||||
|
@ -546,9 +544,7 @@ class ServerConfig(Config):
|
|||
# If you plan to use a reverse proxy, please see
|
||||
# https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
|
||||
#
|
||||
- port: %(unsecure_port)s
|
||||
tls: false
|
||||
bind_addresses: ['::1', '127.0.0.1']
|
||||
- %(unsecure_http_binding)s
|
||||
type: http
|
||||
x_forwarded: true
|
||||
|
||||
|
@ -556,7 +552,7 @@ class ServerConfig(Config):
|
|||
- names: [client, federation]
|
||||
compress: false
|
||||
|
||||
# example additonal_resources:
|
||||
# example additional_resources:
|
||||
#
|
||||
#additional_resources:
|
||||
# "/_matrix/my/custom/endpoint":
|
||||
|
@ -631,7 +627,9 @@ class ServerConfig(Config):
|
|||
# Defaults to 'true'.
|
||||
#
|
||||
#allow_per_room_profiles: false
|
||||
""" % locals()
|
||||
"""
|
||||
% locals()
|
||||
)
|
||||
|
||||
def read_arguments(self, args):
|
||||
if args.manhole is not None:
|
||||
|
@@ -643,17 +641,26 @@ class ServerConfig(Config):

     def add_arguments(self, parser):
         server_group = parser.add_argument_group("server")
-        server_group.add_argument("-D", "--daemonize", action='store_true',
-                                  default=None,
-                                  help="Daemonize the home server")
-        server_group.add_argument("--print-pidfile", action='store_true',
-                                  default=None,
-                                  help="Print the path to the pidfile just"
-                                  " before daemonizing")
-        server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
-                                  type=int,
-                                  help="Turn on the twisted telnet manhole"
-                                  " service on the given port.")
+        server_group.add_argument(
+            "-D",
+            "--daemonize",
+            action="store_true",
+            default=None,
+            help="Daemonize the home server",
+        )
+        server_group.add_argument(
+            "--print-pidfile",
+            action="store_true",
+            default=None,
+            help="Print the path to the pidfile just" " before daemonizing",
+        )
+        server_group.add_argument(
+            "--manhole",
+            metavar="PORT",
+            dest="manhole",
+            type=int,
+            help="Turn on the twisted telnet manhole" " service on the given port.",
+        )
@@ -667,7 +674,7 @@ def is_threepid_reserved(reserved_threepids, threepid):
     """

     for tp in reserved_threepids:
-        if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
+        if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]:
             return True
     return False

@@ -680,9 +687,7 @@ def read_gc_thresholds(thresholds):
         return None
     try:
         assert len(thresholds) == 3
-        return (
-            int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
-        )
+        return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2]))
     except Exception:
         raise ConfigError(
             "Value of `gc_threshold` must be a list of three integers if set"
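The parsed thresholds feed straight into gc.set_threshold. A runnable sketch of the same parsing and its use (ValueError stands in for Synapse's ConfigError):

    import gc

    def read_gc_thresholds(thresholds):
        # None means "leave the interpreter's gc thresholds alone";
        # otherwise exactly three ints are required, as gc.set_threshold expects.
        if thresholds is None:
            return None
        try:
            assert len(thresholds) == 3
            return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2]))
        except Exception:
            raise ValueError(
                "Value of `gc_threshold` must be a list of three integers if set"
            )

    parsed = read_gc_thresholds([700, 10, 10])
    if parsed is not None:
        gc.set_threshold(*parsed)
    print(gc.get_threshold())  # -> (700, 10, 10)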
@ -700,22 +705,22 @@ def _warn_if_webclient_configured(listeners):
|
|||
for listener in listeners:
|
||||
for res in listener.get("resources", []):
|
||||
for name in res.get("names", []):
|
||||
if name == 'webclient':
|
||||
if name == "webclient":
|
||||
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
|
||||
return
|
||||
|
||||
|
||||
 KNOWN_RESOURCES = (
-    'client',
-    'consent',
-    'federation',
-    'keys',
-    'media',
-    'metrics',
-    'openid',
-    'replication',
-    'static',
-    'webclient',
+    "client",
+    "consent",
+    "federation",
+    "keys",
+    "media",
+    "metrics",
+    "openid",
+    "replication",
+    "static",
+    "webclient",
 )
|
||||
|
||||
|
@@ -729,11 +734,9 @@ def _check_resource_config(listeners):

     for resource in resource_names:
         if resource not in KNOWN_RESOURCES:
-            raise ConfigError(
-                "Unknown listener resource '%s'" % (resource, )
-            )
+            raise ConfigError("Unknown listener resource '%s'" % (resource,))
         if resource == "consent":
             try:
-                check_requirements('resources.consent')
+                check_requirements("resources.consent")
             except DependencyException as e:
                 raise ConfigError(e.message)
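A quick illustration of the validation above, reimplemented against the same KNOWN_RESOURCES tuple; the consent dependency check is omitted and ValueError stands in for ConfigError:

    KNOWN_RESOURCES = (
        "client", "consent", "federation", "keys", "media",
        "metrics", "openid", "replication", "static", "webclient",
    )

    def check_resource_config(listeners):
        # Flatten every "names" list across all listener resource blocks and
        # reject anything that is not a known resource.
        for listener in listeners:
            for res in listener.get("resources", []):
                for name in res.get("names", []):
                    if name not in KNOWN_RESOURCES:
                        raise ValueError("Unknown listener resource '%s'" % (name,))

    check_resource_config([{"resources": [{"names": ["client", "federation"]}]}])  # ok
    try:
        check_resource_config([{"resources": [{"names": ["webclint"]}]}])
    except ValueError as e:
        print(e)  # Unknown listener resource 'webclint'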
|
|
|
@@ -58,6 +58,7 @@ class ServerNoticesConfig(Config):
         The name to use for the server notices room.
         None if server notices are not enabled.
     """
+
     def __init__(self):
         super(ServerNoticesConfig, self).__init__()
         self.server_notices_mxid = None

@@ -65,23 +66,17 @@ class ServerNoticesConfig(Config):
         self.server_notices_mxid_avatar_url = None
         self.server_notices_room_name = None

-    def read_config(self, config):
+    def read_config(self, config, **kwargs):
         c = config.get("server_notices")
         if c is None:
             return

-        mxid_localpart = c['system_mxid_localpart']
-        self.server_notices_mxid = UserID(
-            mxid_localpart, self.server_name,
-        ).to_string()
-        self.server_notices_mxid_display_name = c.get(
-            'system_mxid_display_name', None,
-        )
-        self.server_notices_mxid_avatar_url = c.get(
-            'system_mxid_avatar_url', None,
-        )
+        mxid_localpart = c["system_mxid_localpart"]
+        self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string()
+        self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None)
+        self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None)
         # todo: i18n
-        self.server_notices_room_name = c.get('room_name', "Server Notices")
+        self.server_notices_room_name = c.get("room_name", "Server Notices")

-    def default_config(self, **kwargs):
+    def generate_config_section(self, **kwargs):
         return DEFAULT_CONFIG
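For reference, UserID(localpart, server_name).to_string() just produces a standard Matrix user ID. A trivial sketch without Synapse's UserID type:

    def make_user_id(localpart, server_name):
        # Matrix user IDs have the form "@localpart:server_name".
        return "@%s:%s" % (localpart, server_name)

    print(make_user_id("notices", "example.com"))  # -> @notices:example.com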
|
|
|
@@ -19,14 +19,14 @@ from ._base import Config


 class SpamCheckerConfig(Config):
-    def read_config(self, config):
+    def read_config(self, config, **kwargs):
         self.spam_checker = None

         provider = config.get("spam_checker", None)
         if provider is not None:
             self.spam_checker = load_module(provider)

-    def default_config(self, **kwargs):
+    def generate_config_section(self, **kwargs):
         return """\
         #spam_checker:
         #  module: "my_custom_project.SuperSpamChecker"
|
|
|
@@ -25,7 +25,7 @@ class StatsConfig(Config):
     Configuration for the behaviour of synapse's stats engine
     """

-    def read_config(self, config):
+    def read_config(self, config, **kwargs):
         self.stats_enabled = True
         self.stats_bucket_size = 86400
         self.stats_retention = sys.maxsize

@@ -42,7 +42,7 @@ class StatsConfig(Config):
             / 1000
         )

-    def default_config(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """
         # Local statistics collection. Used in populating the room directory.
         #
 42  synapse/config/third_party_event_rules.py  (new file)

@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.module_loader import load_module
+
+from ._base import Config
+
+
+class ThirdPartyRulesConfig(Config):
+    def read_config(self, config, **kwargs):
+        self.third_party_event_rules = None
+
+        provider = config.get("third_party_event_rules", None)
+        if provider is not None:
+            self.third_party_event_rules = load_module(provider)
+
+    def generate_config_section(self, **kwargs):
+        return """\
+        # Server admins can define a Python module that implements extra rules for
+        # allowing or denying incoming events. In order to work, this module needs to
+        # override the methods defined in synapse/events/third_party_rules.py.
+        #
+        # This feature is designed to be used in closed federations only, where each
+        # participating server enforces the same rules.
+        #
+        #third_party_event_rules:
+        #  module: "my_custom_project.SuperRulesSet"
+        #  config:
+        #    example_option: 'things'
+        """
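load_module receives a dict with a dotted "module" path plus a "config" block, imports the class, and returns it with its parsed config. A rough approximation of that contract (not Synapse's exact helper):

    import importlib

    def load_module(provider):
        # provider looks like:
        #   {"module": "my_custom_project.SuperRulesSet", "config": {...}}
        module_name, class_name = provider["module"].rsplit(".", 1)
        module = importlib.import_module(module_name)
        cls = getattr(module, class_name)
        # Let the class validate/normalise its own config if it chooses to.
        if hasattr(cls, "parse_config"):
            config = cls.parse_config(provider.get("config", {}))
        else:
            config = provider.get("config", {})
        return cls, config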
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TlsConfig(Config):
|
||||
def read_config(self, config):
|
||||
def read_config(self, config, config_dir_path, **kwargs):
|
||||
|
||||
acme_config = config.get("acme", None)
|
||||
if acme_config is None:
|
||||
|
@ -42,14 +42,18 @@ class TlsConfig(Config):
|
|||
self.acme_enabled = acme_config.get("enabled", False)
|
||||
|
||||
# hyperlink complains on py2 if this is not a Unicode
|
||||
self.acme_url = six.text_type(acme_config.get(
|
||||
"url", u"https://acme-v01.api.letsencrypt.org/directory"
|
||||
))
|
||||
self.acme_url = six.text_type(
|
||||
acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
|
||||
)
|
||||
self.acme_port = acme_config.get("port", 80)
|
||||
self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0'])
|
||||
self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"])
|
||||
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
|
||||
self.acme_domain = acme_config.get("domain", config.get("server_name"))
|
||||
|
||||
self.acme_account_key_file = self.abspath(
|
||||
acme_config.get("account_key_file", config_dir_path + "/client.key")
|
||||
)
|
||||
|
||||
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
|
||||
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
|
||||
|
||||
|
@ -74,12 +78,12 @@ class TlsConfig(Config):
|
|||
|
||||
# Whether to verify certificates on outbound federation traffic
|
||||
self.federation_verify_certificates = config.get(
|
||||
"federation_verify_certificates", True,
|
||||
"federation_verify_certificates", True
|
||||
)
|
||||
|
||||
# Whitelist of domains to not verify certificates for
|
||||
fed_whitelist_entries = config.get(
|
||||
"federation_certificate_verification_whitelist", [],
|
||||
"federation_certificate_verification_whitelist", []
|
||||
)
|
||||
|
||||
# Support globs (*) in whitelist values
|
||||
|
@ -90,9 +94,7 @@ class TlsConfig(Config):
|
|||
self.federation_certificate_verification_whitelist.append(entry_regex)
|
||||
|
||||
# List of custom certificate authorities for federation traffic validation
|
||||
custom_ca_list = config.get(
|
||||
"federation_custom_ca_list", None,
|
||||
)
|
||||
custom_ca_list = config.get("federation_custom_ca_list", None)
|
||||
|
||||
# Read in and parse custom CA certificates
|
||||
self.federation_ca_trust_root = None
|
||||
|
@ -101,8 +103,10 @@ class TlsConfig(Config):
|
|||
# A trustroot cannot be generated without any CA certificates.
|
||||
# Raise an error if this option has been specified without any
|
||||
# corresponding certificates.
|
||||
raise ConfigError("federation_custom_ca_list specified without "
|
||||
"any certificate files")
|
||||
raise ConfigError(
|
||||
"federation_custom_ca_list specified without "
|
||||
"any certificate files"
|
||||
)
|
||||
|
||||
certs = []
|
||||
for ca_file in custom_ca_list:
|
||||
|
@ -114,8 +118,9 @@ class TlsConfig(Config):
|
|||
cert_base = Certificate.loadPEM(content)
|
||||
certs.append(cert_base)
|
||||
except Exception as e:
|
||||
raise ConfigError("Error parsing custom CA certificate file %s: %s"
|
||||
% (ca_file, e))
|
||||
raise ConfigError(
|
||||
"Error parsing custom CA certificate file %s: %s" % (ca_file, e)
|
||||
)
|
||||
|
||||
self.federation_ca_trust_root = trustRootFromCertificates(certs)
|
||||
|
||||
|
@@ -146,17 +151,21 @@ class TlsConfig(Config):
             return None

         try:
-            with open(self.tls_certificate_file, 'rb') as f:
+            with open(self.tls_certificate_file, "rb") as f:
                 cert_pem = f.read()
         except Exception as e:
-            raise ConfigError("Failed to read existing certificate file %s: %s"
-                              % (self.tls_certificate_file, e))
+            raise ConfigError(
+                "Failed to read existing certificate file %s: %s"
+                % (self.tls_certificate_file, e)
+            )

         try:
             tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
         except Exception as e:
-            raise ConfigError("Failed to parse existing certificate file %s: %s"
-                              % (self.tls_certificate_file, e))
+            raise ConfigError(
+                "Failed to parse existing certificate file %s: %s"
+                % (self.tls_certificate_file, e)
+            )

         if not allow_self_signed:
             if tls_certificate.get_subject() == tls_certificate.get_issuer():

@@ -166,7 +175,7 @@ class TlsConfig(Config):

             # YYYYMMDDhhmmssZ -- in UTC
             expires_on = datetime.strptime(
-                tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
+                tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
             )
             now = datetime.utcnow()
             days_remaining = (expires_on - now).days
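The expiry arithmetic above reduces to parsing the certificate's ASN.1 notAfter time. A standalone sketch with pyOpenSSL (assumes a cert.pem file exists on disk):

    from datetime import datetime

    from OpenSSL import crypto

    with open("cert.pem", "rb") as f:
        cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())

    # get_notAfter() returns bytes like b"20250630120000Z" (YYYYMMDDhhmmssZ, UTC)
    expires_on = datetime.strptime(
        cert.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
    )
    days_remaining = (expires_on - datetime.utcnow()).days
    print("certificate expires in %d days" % days_remaining)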
|
@ -191,7 +200,8 @@ class TlsConfig(Config):
|
|||
except Exception as e:
|
||||
logger.info(
|
||||
"Unable to read TLS certificate (%s). Ignoring as no "
|
||||
"tls listeners enabled.", e,
|
||||
"tls listeners enabled.",
|
||||
e,
|
||||
)
|
||||
|
||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||
|
@ -205,18 +215,21 @@ class TlsConfig(Config):
|
|||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
||||
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
||||
if sha256_fingerprint not in sha256_fingerprints:
|
||||
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
|
||||
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
def generate_config_section(
|
||||
self, config_dir_path, server_name, data_dir_path, **kwargs
|
||||
):
|
||||
base_key_name = os.path.join(config_dir_path, server_name)
|
||||
|
||||
tls_certificate_path = base_key_name + ".tls.crt"
|
||||
tls_private_key_path = base_key_name + ".tls.key"
|
||||
default_acme_account_file = os.path.join(data_dir_path, "acme_account.key")
|
||||
|
||||
# this is to avoid the max line length. Sorrynotsorry
|
||||
proxypassline = (
|
||||
'ProxyPass /.well-known/acme-challenge '
|
||||
'http://localhost:8009/.well-known/acme-challenge'
|
||||
"ProxyPass /.well-known/acme-challenge "
|
||||
"http://localhost:8009/.well-known/acme-challenge"
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -337,6 +350,13 @@ class TlsConfig(Config):
|
|||
#
|
||||
#domain: matrix.example.com
|
||||
|
||||
# file to use for the account key. This will be generated if it doesn't
|
||||
# exist.
|
||||
#
|
||||
# If unspecified, we will use CONFDIR/client.key.
|
||||
#
|
||||
account_key_file: %(default_acme_account_file)s
|
||||
|
||||
# List of allowed TLS fingerprints for this server to publish along
|
||||
# with the signing keys for this server. Other matrix servers that
|
||||
# make HTTPS requests to this server will check that the TLS
|
||||
|
|
|
@@ -21,19 +21,19 @@ class UserDirectoryConfig(Config):
     Configuration for the behaviour of the /user_directory API
     """

-    def read_config(self, config):
+    def read_config(self, config, **kwargs):
         self.user_directory_search_enabled = True
         self.user_directory_search_all_users = False
         user_directory_config = config.get("user_directory", None)
         if user_directory_config:
-            self.user_directory_search_enabled = (
-                user_directory_config.get("enabled", True)
+            self.user_directory_search_enabled = user_directory_config.get(
+                "enabled", True
             )
-            self.user_directory_search_all_users = (
-                user_directory_config.get("search_all_users", False)
+            self.user_directory_search_all_users = user_directory_config.get(
+                "search_all_users", False
             )

-    def default_config(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """
         # User Directory configuration
         #
|
|
|
@@ -16,18 +16,17 @@ from ._base import Config


 class VoipConfig(Config):
-
-    def read_config(self, config):
+    def read_config(self, config, **kwargs):
         self.turn_uris = config.get("turn_uris", [])
         self.turn_shared_secret = config.get("turn_shared_secret")
         self.turn_username = config.get("turn_username")
         self.turn_password = config.get("turn_password")
         self.turn_user_lifetime = self.parse_duration(
-            config.get("turn_user_lifetime", "1h"),
+            config.get("turn_user_lifetime", "1h")
         )
         self.turn_allow_guests = config.get("turn_allow_guests", True)

-    def default_config(self, **kwargs):
+    def generate_config_section(self, **kwargs):
         return """\
         ## TURN ##
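parse_duration accepts either a bare millisecond count or a suffixed string such as "1h". A rough sketch of such a parser; the accepted suffixes here are an assumption, not Synapse's exact set:

    def parse_duration(value):
        # Returns a duration in milliseconds. Bare ints are taken as ms;
        # strings may carry a unit suffix.
        if isinstance(value, int):
            return value
        units = {"s": 1000, "m": 60 * 1000, "h": 60 * 60 * 1000, "d": 24 * 60 * 60 * 1000}
        if value[-1] in units:
            return int(value[:-1]) * units[value[-1]]
        return int(value)

    print(parse_duration("1h"))  # -> 3600000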
|
|
|
@@ -21,7 +21,7 @@ class WorkerConfig(Config):
     They have their own pid_file and listener configuration. They use the
     replication_url to talk to the main synapse process."""

-    def read_config(self, config):
+    def read_config(self, config, **kwargs):
         self.worker_app = config.get("worker_app")

         # Canonicalise worker_app so that master always has None

@@ -46,18 +46,19 @@ class WorkerConfig(Config):
         self.worker_name = config.get("worker_name", self.worker_app)

         self.worker_main_http_uri = config.get("worker_main_http_uri", None)
         self.worker_cpu_affinity = config.get("worker_cpu_affinity")

         # This option is really only here to support `--manhole` command line
         # argument.
         manhole = config.get("worker_manhole")
         if manhole:
-            self.worker_listeners.append({
-                "port": manhole,
-                "bind_addresses": ["127.0.0.1"],
-                "type": "manhole",
-                "tls": False,
-            })
+            self.worker_listeners.append(
+                {
+                    "port": manhole,
+                    "bind_addresses": ["127.0.0.1"],
+                    "type": "manhole",
+                    "tls": False,
+                }
+            )

         if self.worker_listeners:
             for listener in self.worker_listeners:

@@ -67,7 +68,7 @@ class WorkerConfig(Config):
                 if bind_address:
                     bind_addresses.append(bind_address)
                 elif not bind_addresses:
-                    bind_addresses.append('')
+                    bind_addresses.append("")

     def read_arguments(self, args):
         # We support a bunch of command line arguments that override options in
|
|
|
@@ -46,9 +46,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     if name not in hashes:
         raise SynapseError(
             400,
-            "Algorithm %s not in hashes %s" % (
-                name, list(hashes),
-            ),
+            "Algorithm %s not in hashes %s" % (name, list(hashes)),
             Codes.UNAUTHORIZED,
         )
     message_hash_base64 = hashes[name]

@@ -56,9 +54,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
         message_hash_bytes = decode_base64(message_hash_base64)
     except Exception:
         raise SynapseError(
-            400,
-            "Invalid base64: %s" % (message_hash_base64,),
-            Codes.UNAUTHORIZED,
+            400, "Invalid base64: %s" % (message_hash_base64,), Codes.UNAUTHORIZED
         )
     return message_hash_bytes == expected_hash

@@ -135,8 +131,9 @@ def compute_event_signature(event_dict, signature_name, signing_key):
     return redact_json["signatures"]


-def add_hashes_and_signatures(event_dict, signature_name, signing_key,
-                              hash_algorithm=hashlib.sha256):
+def add_hashes_and_signatures(
+    event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256
+):
     """Add content hash and sign the event

     Args:

@@ -153,7 +150,5 @@ def add_hashes_and_signatures(event_dict, signature_name, signing_key,
     event_dict.setdefault("hashes", {})[name] = encode_base64(digest)

     event_dict["signatures"] = compute_event_signature(
-        event_dict,
-        signature_name=signature_name,
-        signing_key=signing_key,
+        event_dict, signature_name=signature_name, signing_key=signing_key
     )
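The content hash used by add_hashes_and_signatures is a digest over the event's canonical JSON with the "signatures", "unsigned" and "hashes" keys stripped. A hedged sketch of the hashing half, using the canonicaljson and unpaddedbase64 packages that Synapse depends on (treat the exact stripped keys as illustrative):

    import hashlib

    from canonicaljson import encode_canonical_json
    from unpaddedbase64 import encode_base64

    def compute_content_hash(event_dict, hash_algorithm=hashlib.sha256):
        # Work on a copy: the hash must not cover signatures/unsigned/hashes.
        event = dict(event_dict)
        event.pop("signatures", None)
        event.pop("unsigned", None)
        event.pop("hashes", None)
        hashed = hash_algorithm(encode_canonical_json(event))
        return hashed.name, hashed.digest()

    name, digest = compute_content_hash({"type": "m.room.message", "content": {}})
    print(name, encode_base64(digest))  # e.g. "sha256" plus unpadded base64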
|
|
|
@ -505,7 +505,7 @@ class BaseV2KeyFetcher(object):
|
|||
Returns:
|
||||
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
|
||||
"""
|
||||
ts_valid_until_ms = response_json[u"valid_until_ts"]
|
||||
ts_valid_until_ms = response_json["valid_until_ts"]
|
||||
|
||||
# start by extracting the keys from the response, since they may be required
|
||||
# to validate the signature on the response.
|
||||
|
@ -614,10 +614,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
|
||||
results = yield logcontext.make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(get_key, server)
|
||||
for server in self.key_servers
|
||||
],
|
||||
[run_in_background(get_key, server) for server in self.key_servers],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
|
@ -630,9 +627,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
defer.returnValue(union_of_keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_indirect(
|
||||
self, keys_to_fetch, key_server
|
||||
):
|
||||
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
|
||||
"""
|
||||
Args:
|
||||
keys_to_fetch (dict[str, dict[str, int]]):
|
||||
|
@ -661,9 +656,9 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
destination=perspective_name,
|
||||
path="/_matrix/key/v2/query",
|
||||
data={
|
||||
u"server_keys": {
|
||||
"server_keys": {
|
||||
server_name: {
|
||||
key_id: {u"minimum_valid_until_ts": min_valid_ts}
|
||||
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
||||
for key_id, min_valid_ts in server_keys.items()
|
||||
}
|
||||
for server_name, server_keys in keys_to_fetch.items()
|
||||
|
@ -690,10 +685,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
)
|
||||
|
||||
try:
|
||||
self._validate_perspectives_response(
|
||||
key_server,
|
||||
response,
|
||||
)
|
||||
self._validate_perspectives_response(key_server, response)
|
||||
|
||||
processed_response = yield self.process_v2_response(
|
||||
perspective_name, response, time_added_ms=time_now_ms
|
||||
|
@ -720,9 +712,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
|
||||
defer.returnValue(keys)
|
||||
|
||||
def _validate_perspectives_response(
|
||||
self, key_server, response,
|
||||
):
|
||||
def _validate_perspectives_response(self, key_server, response):
|
||||
"""Optionally check the signature on the result of a /key/query request
|
||||
|
||||
Args:
|
||||
|
@ -739,13 +729,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
return
|
||||
|
||||
if (
|
||||
u"signatures" not in response
|
||||
or perspective_name not in response[u"signatures"]
|
||||
"signatures" not in response
|
||||
or perspective_name not in response["signatures"]
|
||||
):
|
||||
raise KeyLookupError("Response not signed by the notary server")
|
||||
|
||||
verified = False
|
||||
for key_id in response[u"signatures"][perspective_name]:
|
||||
for key_id in response["signatures"][perspective_name]:
|
||||
if key_id in perspective_keys:
|
||||
verify_signed_json(response, perspective_name, perspective_keys[key_id])
|
||||
verified = True
|
||||
|
@ -754,7 +744,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
raise KeyLookupError(
|
||||
"Response not signed with a known key: signed with: %r, known keys: %r"
|
||||
% (
|
||||
list(response[u"signatures"][perspective_name].keys()),
|
||||
list(response["signatures"][perspective_name].keys()),
|
||||
list(perspective_keys.keys()),
|
||||
)
|
||||
)
|
||||
|
@ -826,7 +816,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||
path="/_matrix/key/v2/server/"
|
||||
+ urllib.parse.quote(requested_key_id),
|
||||
ignore_backoff=True,
|
||||
|
||||
# we only give the remote server 10s to respond. It should be an
|
||||
# easy request to handle, so if it doesn't reply within 10s, it's
|
||||
# probably not going to.
|
||||
|
|
|
@ -85,17 +85,14 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
|
|||
room_id_domain = get_domain_from_id(event.room_id)
|
||||
if room_id_domain != sender_domain:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Creation event's room_id domain does not match sender's"
|
||||
403, "Creation event's room_id domain does not match sender's"
|
||||
)
|
||||
|
||||
room_version = event.content.get("room_version", "1")
|
||||
if room_version not in KNOWN_ROOM_VERSIONS:
|
||||
raise AuthError(
|
||||
403,
|
||||
"room appears to have unsupported version %s" % (
|
||||
room_version,
|
||||
))
|
||||
403, "room appears to have unsupported version %s" % (room_version,)
|
||||
)
|
||||
# FIXME
|
||||
logger.debug("Allowing! %s", event)
|
||||
return
|
||||
|
@ -103,46 +100,30 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
|
|||
creation_event = auth_events.get((EventTypes.Create, ""), None)
|
||||
|
||||
if not creation_event:
|
||||
raise AuthError(
|
||||
403,
|
||||
"No create event in auth events",
|
||||
)
|
||||
raise AuthError(403, "No create event in auth events")
|
||||
|
||||
creating_domain = get_domain_from_id(event.room_id)
|
||||
originating_domain = get_domain_from_id(event.sender)
|
||||
if creating_domain != originating_domain:
|
||||
if not _can_federate(event, auth_events):
|
||||
raise AuthError(
|
||||
403,
|
||||
"This room has been marked as unfederatable."
|
||||
)
|
||||
raise AuthError(403, "This room has been marked as unfederatable.")
|
||||
|
||||
# FIXME: Temp hack
|
||||
if event.type == EventTypes.Aliases:
|
||||
if not event.is_state():
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event must be a state event",
|
||||
)
|
||||
raise AuthError(403, "Alias event must be a state event")
|
||||
if not event.state_key:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event must have non-empty state_key"
|
||||
)
|
||||
raise AuthError(403, "Alias event must have non-empty state_key")
|
||||
sender_domain = get_domain_from_id(event.sender)
|
||||
if event.state_key != sender_domain:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event's state_key does not match sender's domain"
|
||||
403, "Alias event's state_key does not match sender's domain"
|
||||
)
|
||||
logger.debug("Allowing! %s", event)
|
||||
return
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"Auth events: %s",
|
||||
[a.event_id for a in auth_events.values()]
|
||||
)
|
||||
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
_is_membership_change_allowed(event, auth_events)
|
||||
|
@ -159,9 +140,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
|
|||
invite_level = _get_named_level(auth_events, "invite", 0)
|
||||
|
||||
if user_level < invite_level:
|
||||
raise AuthError(
|
||||
403, "You don't have permission to invite users",
|
||||
)
|
||||
raise AuthError(403, "You don't have permission to invite users")
|
||||
else:
|
||||
logger.debug("Allowing! %s", event)
|
||||
return
|
||||
|
@@ -207,7 +186,7 @@ def _is_membership_change_allowed(event, auth_events):
     # Check if this is the room creator joining:
     if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership:
         # Get room creation event:
-        key = (EventTypes.Create, "", )
+        key = (EventTypes.Create, "")
         create = auth_events.get(key)
         if create and event.prev_event_ids()[0] == create.event_id:
             if create.content["creator"] == event.state_key:
@ -219,38 +198,31 @@ def _is_membership_change_allowed(event, auth_events):
|
|||
target_domain = get_domain_from_id(target_user_id)
|
||||
if creating_domain != target_domain:
|
||||
if not _can_federate(event, auth_events):
|
||||
raise AuthError(
|
||||
403,
|
||||
"This room has been marked as unfederatable."
|
||||
)
|
||||
raise AuthError(403, "This room has been marked as unfederatable.")
|
||||
|
||||
# get info about the caller
|
||||
key = (EventTypes.Member, event.user_id, )
|
||||
key = (EventTypes.Member, event.user_id)
|
||||
caller = auth_events.get(key)
|
||||
|
||||
caller_in_room = caller and caller.membership == Membership.JOIN
|
||||
caller_invited = caller and caller.membership == Membership.INVITE
|
||||
|
||||
# get info about the target
|
||||
key = (EventTypes.Member, target_user_id, )
|
||||
key = (EventTypes.Member, target_user_id)
|
||||
target = auth_events.get(key)
|
||||
|
||||
target_in_room = target and target.membership == Membership.JOIN
|
||||
target_banned = target and target.membership == Membership.BAN
|
||||
|
||||
key = (EventTypes.JoinRules, "", )
|
||||
key = (EventTypes.JoinRules, "")
|
||||
join_rule_event = auth_events.get(key)
|
||||
if join_rule_event:
|
||||
join_rule = join_rule_event.content.get(
|
||||
"join_rule", JoinRules.INVITE
|
||||
)
|
||||
join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE)
|
||||
else:
|
||||
join_rule = JoinRules.INVITE
|
||||
|
||||
user_level = get_user_power_level(event.user_id, auth_events)
|
||||
target_level = get_user_power_level(
|
||||
target_user_id, auth_events
|
||||
)
|
||||
target_level = get_user_power_level(target_user_id, auth_events)
|
||||
|
||||
# FIXME (erikj): What should we do here as the default?
|
||||
ban_level = _get_named_level(auth_events, "ban", 50)
|
||||
|
@ -266,29 +238,26 @@ def _is_membership_change_allowed(event, auth_events):
|
|||
"join_rule": join_rule,
|
||||
"target_user_id": target_user_id,
|
||||
"event.user_id": event.user_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
||||
if not _verify_third_party_invite(event, auth_events):
|
||||
raise AuthError(403, "You are not invited to this room.")
|
||||
if target_banned:
|
||||
raise AuthError(
|
||||
403, "%s is banned from the room" % (target_user_id,)
|
||||
)
|
||||
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
|
||||
return
|
||||
|
||||
if Membership.JOIN != membership:
|
||||
if (caller_invited
|
||||
and Membership.LEAVE == membership
|
||||
and target_user_id == event.user_id):
|
||||
if (
|
||||
caller_invited
|
||||
and Membership.LEAVE == membership
|
||||
and target_user_id == event.user_id
|
||||
):
|
||||
return
|
||||
|
||||
if not caller_in_room: # caller isn't joined
|
||||
raise AuthError(
|
||||
403,
|
||||
"%s not in room %s." % (event.user_id, event.room_id,)
|
||||
)
|
||||
raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
|
||||
|
||||
if Membership.INVITE == membership:
|
||||
# TODO (erikj): We should probably handle this more intelligently
|
||||
|
@ -296,19 +265,14 @@ def _is_membership_change_allowed(event, auth_events):
|
|||
|
||||
# Invites are valid iff caller is in the room and target isn't.
|
||||
if target_banned:
|
||||
raise AuthError(
|
||||
403, "%s is banned from the room" % (target_user_id,)
|
||||
)
|
||||
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
|
||||
elif target_in_room: # the target is already in the room.
|
||||
raise AuthError(403, "%s is already in the room." %
|
||||
target_user_id)
|
||||
raise AuthError(403, "%s is already in the room." % target_user_id)
|
||||
else:
|
||||
invite_level = _get_named_level(auth_events, "invite", 0)
|
||||
|
||||
if user_level < invite_level:
|
||||
raise AuthError(
|
||||
403, "You don't have permission to invite users",
|
||||
)
|
||||
raise AuthError(403, "You don't have permission to invite users")
|
||||
elif Membership.JOIN == membership:
|
||||
# Joins are valid iff caller == target and they were:
|
||||
# invited: They are accepting the invitation
|
||||
|
@ -329,16 +293,12 @@ def _is_membership_change_allowed(event, auth_events):
|
|||
elif Membership.LEAVE == membership:
|
||||
# TODO (erikj): Implement kicks.
|
||||
if target_banned and user_level < ban_level:
|
||||
raise AuthError(
|
||||
403, "You cannot unban user %s." % (target_user_id,)
|
||||
)
|
||||
raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
|
||||
elif target_user_id != event.user_id:
|
||||
kick_level = _get_named_level(auth_events, "kick", 50)
|
||||
|
||||
if user_level < kick_level or user_level <= target_level:
|
||||
raise AuthError(
|
||||
403, "You cannot kick user %s." % target_user_id
|
||||
)
|
||||
raise AuthError(403, "You cannot kick user %s." % target_user_id)
|
||||
elif Membership.BAN == membership:
|
||||
if user_level < ban_level or user_level <= target_level:
|
||||
raise AuthError(403, "You don't have permission to ban")
|
||||
|
@ -347,21 +307,17 @@ def _is_membership_change_allowed(event, auth_events):

def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)
return _check_joined_room(
member_event,
event.user_id,
event.room_id
)
return _check_joined_room(member_event, event.user_id, event.room_id)

def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
raise AuthError(
403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
)

def get_send_level(etype, state_key, power_levels_event):

@ -402,26 +358,21 @@ def get_send_level(etype, state_key, power_levels_event):
def _can_send_event(event, auth_events):
power_levels_event = _get_power_level_event(auth_events)
send_level = get_send_level(
event.type, event.get("state_key"), power_levels_event,
)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
"You don't have permission to post that to the room. "
+ "user_level (%d) < send_level (%d)" % (user_level, send_level),
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
raise AuthError(403, "You are not allowed to set others state")
return True
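get_send_level resolves the minimum level needed to send an event from the m.room.power_levels content, falling back to events_default (or state_default for state events). A rough sketch of that lookup, assuming the standard power-levels schema (the function body here is illustrative):

# Illustrative lookup over an m.room.power_levels content dict.
def get_send_level_sketch(etype, state_key, power_levels_content):
    levels = power_levels_content or {}
    # State events (state_key is not None) default higher than message events.
    if state_key is not None:
        default = levels.get("state_default", 50)
    else:
        default = levels.get("events_default", 0)
    return int(levels.get("events", {}).get(etype, default))

content = {"events": {"m.room.name": 50}, "events_default": 0}
assert get_send_level_sketch("m.room.name", None, content) == 50
assert get_send_level_sketch("m.room.message", None, content) == 0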
@ -459,10 +410,7 @@ def check_redaction(room_version, event, auth_events):
event.internal_metadata.recheck_redaction = True
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
raise AuthError(403, "You don't have permission to redact events")

def _check_power_levels(event, auth_events):

@ -479,7 +427,7 @@ def _check_power_levels(event, auth_events):
except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
key = (event.type, event.state_key)
current_state = auth_events.get(key)
if not current_state:

@ -500,16 +448,12 @@ def _check_power_levels(event, auth_events):
old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)):
levels_to_check.append(
(user, "users")
)
levels_to_check.append((user, "users"))
old_list = current_state.content.get("events", {})
new_list = event.content.get("events", {})
for ev_id in set(list(old_list) + list(new_list)):
levels_to_check.append(
(ev_id, "events")
)
levels_to_check.append((ev_id, "events"))
old_state = current_state.content
new_state = event.content

@ -540,7 +484,7 @@ def _check_power_levels(event, auth_events):
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
"to your own",
)
# Check if the old and new levels are greater than the user level

@ -550,8 +494,7 @@ def _check_power_levels(event, auth_events):
if old_level_too_big or new_level_too_big:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
"You don't have permission to add ops level greater " "than your own",
)

@ -587,10 +530,9 @@ def get_user_power_level(user_id, auth_events):
# some things which call this don't pass the create event: hack around
# that.
key = (EventTypes.Create, "", )
key = (EventTypes.Create, "")
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
if create_event is not None and create_event.content["creator"] == user_id:
return 100
else:
return 0

@ -636,9 +578,7 @@ def _verify_third_party_invite(event, auth_events):
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
invite_event = auth_events.get((EventTypes.ThirdPartyInvite, token))
if not invite_event:
return False

@ -661,8 +601,7 @@ def _verify_third_party_invite(event, auth_events):
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
key_name, decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)

@ -671,7 +610,7 @@ def _verify_third_party_invite(event, auth_events):
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
except (KeyError, SignatureVerifyException):
continue
return False
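The verification loop above is standard signedjson usage: decode each advertised ed25519 key and check the invite's signed block against it. The same check in isolation might look like this (the helper name is ours; the library calls are the real signedjson/unpaddedbase64 APIs used above):

from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64

def signature_is_valid(signed, server_name, key_name, public_key_base64):
    # key_name is e.g. "ed25519:0"; public_key_base64 is unpadded base64.
    if not key_name.startswith("ed25519:"):
        return False
    try:
        verify_key = decode_verify_key_bytes(key_name, decode_base64(public_key_base64))
        verify_signed_json(signed, server_name, verify_key)
        return True
    except (KeyError, SignatureVerifyException):
        return False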
@ -679,9 +618,7 @@ def _verify_third_party_invite(event, auth_events):
|
|||
def get_public_keys(invite_event):
|
||||
public_keys = []
|
||||
if "public_key" in invite_event.content:
|
||||
o = {
|
||||
"public_key": invite_event.content["public_key"],
|
||||
}
|
||||
o = {"public_key": invite_event.content["public_key"]}
|
||||
if "key_validity_url" in invite_event.content:
|
||||
o["key_validity_url"] = invite_event.content["key_validity_url"]
|
||||
public_keys.append(o)
|
||||
|
@ -702,22 +639,22 @@ def auth_types_for_event(event):
|
|||
|
||||
auth_types = []
|
||||
|
||||
auth_types.append((EventTypes.PowerLevels, "", ))
|
||||
auth_types.append((EventTypes.Member, event.sender, ))
|
||||
auth_types.append((EventTypes.Create, "", ))
|
||||
auth_types.append((EventTypes.PowerLevels, ""))
|
||||
auth_types.append((EventTypes.Member, event.sender))
|
||||
auth_types.append((EventTypes.Create, ""))
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
membership = event.content["membership"]
|
||||
if membership in [Membership.JOIN, Membership.INVITE]:
|
||||
auth_types.append((EventTypes.JoinRules, "", ))
|
||||
auth_types.append((EventTypes.JoinRules, ""))
|
||||
|
||||
auth_types.append((EventTypes.Member, event.state_key, ))
|
||||
auth_types.append((EventTypes.Member, event.state_key))
|
||||
|
||||
if membership == Membership.INVITE:
|
||||
if "third_party_invite" in event.content:
|
||||
key = (
|
||||
EventTypes.ThirdPartyInvite,
|
||||
event.content["third_party_invite"]["signed"]["token"]
|
||||
event.content["third_party_invite"]["signed"]["token"],
|
||||
)
|
||||
auth_types.append(key)
|
||||
|
||||
|
|
|
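For a plain join event, auth_types_for_event as rewritten above yields the power levels, the sender's membership, the create event, the join rules, and the target's membership. Sketched as the expected output (state keys follow the standard Matrix event types; purely illustrative):

# Expected auth types for a join where sender == state_key.
sender = "@alice:example.com"
expected = [
    ("m.room.power_levels", ""),
    ("m.room.member", sender),
    ("m.room.create", ""),
    ("m.room.join_rules", ""),  # added for joins and invites
    ("m.room.member", sender),  # the target's membership (the state_key)
]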
@ -92,6 +92,18 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "soft_failed", False)
def should_proactively_send(self):
"""Whether the event, if ours, should be sent to other clients and
servers.
This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event.
Returns:
bool
"""
return getattr(self, "proactively_send", True)

def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.

@ -115,25 +127,25 @@ def _event_dict_property(key):
except KeyError:
raise AttributeError(key)
return property(
getter,
setter,
delete,
)
return property(getter, setter, delete)
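_event_dict_property is a small property factory: each generated property proxies attribute access through to a key of the underlying event dict, converting KeyError into AttributeError so that hasattr works. A self-contained sketch of the same pattern:

# Minimal version of the dict-backed property pattern used above.
def dict_property(key):
    def getter(self):
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)  # lets hasattr() return False
    def setter(self, value):
        self._data[key] = value
    def delete(self):
        try:
            del self._data[key]
        except KeyError:
            raise AttributeError(key)
    return property(getter, setter, delete)

class Thing(object):
    def __init__(self, data):
        self._data = data
    depth = dict_property("depth")

t = Thing({"depth": 1})
assert t.depth == 1 and hasattr(t, "depth")
del t.depth
assert not hasattr(t, "depth")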
class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={},
internal_metadata_dict={}, rejected_reason=None):
def __init__(
self,
event_dict,
signatures={},
unsigned={},
internal_metadata_dict={},
rejected_reason=None,
):
self.signatures = signatures
self.unsigned = unsigned
self.rejected_reason = rejected_reason
self._event_dict = event_dict
self.internal_metadata = _EventInternalMetadata(
internal_metadata_dict
)
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")

@ -156,10 +168,7 @@ class EventBase(object):
def get_dict(self):
d = dict(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": dict(self.unsigned),
})
d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
return d

@ -346,6 +355,7 @@ class FrozenEventV2(EventBase):
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
format_version = EventFormatVersions.V3 # All events of this type are V3
@property

@ -402,6 +412,4 @@ def event_type_from_format_version(format_version):
elif format_version == EventFormatVersions.V3:
return FrozenEventV3
else:
raise Exception(
"No event format %r" % (format_version,)
)
raise Exception("No event format %r" % (format_version,))

@ -78,7 +78,9 @@ class EventBuilder(object):
_redacts = attr.ib(default=None)
_origin_server_ts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
internal_metadata = attr.ib(
default=attr.Factory(lambda: _EventInternalMetadata({}))
)
@property
def state_key(self):

@ -102,11 +104,9 @@ class EventBuilder(object):
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids,
)
auth_ids = yield self._auth.compute_auth_events(
self, state_ids,
self.room_id, prev_event_ids
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
if self.format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)

@ -115,9 +115,7 @@ class EventBuilder(object):
auth_events = auth_ids
prev_events = prev_event_ids
old_depth = yield self._store.get_max_depth_of(
prev_event_ids,
)
old_depth = yield self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not

@ -217,9 +215,14 @@ class EventBuilderFactory(object):
)

def create_local_event_from_event_dict(clock, hostname, signing_key,
format_version, event_dict,
internal_metadata_dict=None):
def create_local_event_from_event_dict(
clock,
hostname,
signing_key,
format_version,
event_dict,
internal_metadata_dict=None,
):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.

@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
"""
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
)
raise Exception("No event format defined for version %r" % (format_version,))
if internal_metadata_dict is None:
internal_metadata_dict = {}

@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict.setdefault("signatures", {})
add_hashes_and_signatures(
event_dict,
hostname,
signing_key,
)
add_hashes_and_signatures(event_dict, hostname, signing_key)
return event_type_from_format_version(format_version)(
event_dict, internal_metadata_dict=internal_metadata_dict,
event_dict, internal_metadata_dict=internal_metadata_dict
)

@ -88,8 +88,9 @@ class EventContext(object):
self.app_service = None
@staticmethod
def with_state(state_group, current_state_ids, prev_state_ids,
prev_group=None, delta_ids=None):
def with_state(
state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
):
context = EventContext()
# The current state including the current event

@ -132,17 +133,19 @@ class EventContext(object):
else:
prev_state_id = None
defer.returnValue({
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
"state_group": self.state_group,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None
})
defer.returnValue(
{
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
"state_group": self.state_group,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None,
}
)
@staticmethod
def deserialize(store, input):

@ -194,7 +197,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)

@ -214,7 +217,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)

@ -240,9 +243,7 @@ class EventContext(object):
if self.state_group is None:
return
self._current_state_ids = yield store.get_state_ids_for_group(
self.state_group,
)
self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)

@ -252,8 +253,9 @@ class EventContext(object):
self._prev_state_ids = self._current_state_ids
@defer.inlineCallbacks
def update_state(self, state_group, prev_state_ids, current_state_ids,
prev_group, delta_ids):
def update_state(
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
):
"""Replace the state in the context
"""

@ -279,10 +281,7 @@ def _encode_state_dict(state_dict):
if state_dict is None:
return None
return [
(etype, state_key, v)
for (etype, state_key), v in iteritems(state_dict)
]
return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]

def _decode_state_dict(input):

@ -291,4 +290,4 @@ def _decode_state_dict(input):
if input is None:
return None
return frozendict({(etype, state_key,): v for etype, state_key, v in input})
return frozendict({(etype, state_key): v for etype, state_key, v in input})

@ -60,7 +60,9 @@ class SpamChecker(object):
if self.spam_checker is None:
return True
return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
return self.spam_checker.user_may_invite(
inviter_userid, invitee_userid, room_id
)
def user_may_create_room(self, userid):
"""Checks if a given user may create a room

113 synapse/events/third_party_rules.py Normal file
@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

class ThirdPartyEventRules(object):
"""Allows server admins to provide a Python module implementing an extra
set of rules to apply when processing events.
This is designed to help admins of closed federations with enforcing custom
behaviours.
"""
def __init__(self, hs):
self.third_party_rules = None
self.store = hs.get_datastore()
module = None
config = None
if hs.config.third_party_event_rules:
module, config = hs.config.third_party_event_rules
if module is not None:
self.third_party_rules = module(
config=config, http_client=hs.get_simple_http_client()
)
@defer.inlineCallbacks
def check_event_allowed(self, event, context):
"""Check if a provided event should be allowed in the given context.
Args:
event (synapse.events.EventBase): The event to be checked.
context (synapse.events.snapshot.EventContext): The context of the event.
Returns:
defer.Deferred[bool]: True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
defer.returnValue(True)
prev_state_ids = yield context.get_prev_state_ids(self.store)
# Retrieve the state events from the database.
state_events = {}
for key, event_id in prev_state_ids.items():
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
defer.returnValue(ret)
@defer.inlineCallbacks
def on_create_room(self, requester, config, is_requester_admin):
"""Intercept requests to create room to allow, deny or update the
request config.
Args:
requester (Requester)
config (dict): The creation config from the client.
is_requester_admin (bool): If the requester is an admin
Returns:
defer.Deferred
"""
if self.third_party_rules is None:
return
yield self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
@defer.inlineCallbacks
def check_threepid_can_be_invited(self, medium, address, room_id):
"""Check if a provided 3PID can be invited in the given room.
Args:
medium (str): The 3PID's medium.
address (str): The 3PID's address.
room_id (str): The room we want to invite the threepid to.
Returns:
defer.Deferred[bool], True if the 3PID can be invited, False if not.
"""
if self.third_party_rules is None:
defer.returnValue(True)
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values())
state_events = {}
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
ret = yield self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
defer.returnValue(ret)
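The new wrapper only defines how a configured module is called; the module itself just needs to expose the three methods invoked above. A minimal module satisfying that interface could look like this (the class name and toy policy are illustrative; only the method names and signatures come from the calls above):

# Sketch of a module loadable via the third-party event rules config.
from twisted.internet import defer

class ExampleRulesModule(object):
    def __init__(self, config, http_client):
        self.config = config
        self.http_client = http_client

    def check_event_allowed(self, event, state_events):
        # Toy policy: block room name changes, allow everything else.
        return defer.succeed(event.type != "m.room.name")

    def on_create_room(self, requester, room_creation_config, is_requester_admin):
        return defer.succeed(None)

    def check_threepid_can_be_invited(self, medium, address, state_events):
        return defer.succeed(medium == "email")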
@ -31,7 +31,7 @@ from . import EventBase
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
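The regex splits a field path on dots that are not preceded by a backslash; only_fields (further down in this file) then strips the escaping to recover the literal key names. A quick demonstration of both steps:

import re

SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")

path = r"content.m\.relates_to.rel_type"
parts = SPLIT_FIELD_REGEX.split(path)
# The escaped dot in "m\.relates_to" is not treated as a separator...
assert parts == ["content", r"m\.relates_to", "rel_type"]
# ...and the escaping is then removed to recover the real key names.
assert [p.replace("\\.", ".") for p in parts] == ["content", "m.relates_to", "rel_type"]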
def prune_event(event):

@ -51,6 +51,7 @@ def prune_event(event):
pruned_event_dict = prune_event_dict(event.get_dict())
from . import event_type_from_format_version
return event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)

@ -116,11 +117,7 @@ def prune_event_dict(event_dict):
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
allowed_fields = {
k: v
for k, v in event_dict.items()
if k in allowed_keys
}
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
allowed_fields["content"] = new_content

@ -205,7 +202,7 @@ def only_fields(dictionary, fields):
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
[f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
output = {}

@ -226,7 +223,10 @@ def format_event_for_client_v1(d):
d["user_id"] = sender
copy_keys = (
"age", "redacted_because", "replaces_state", "prev_content",
"age",
"redacted_because",
"replaces_state",
"prev_content",
"invite_room_state",
)
for key in copy_keys:

@ -238,8 +238,13 @@ def format_event_for_client_v1(d):
def format_event_for_client_v2(d):
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"origin", "prev_state",
"auth_events",
"prev_events",
"hashes",
"signatures",
"depth",
"origin",
"prev_state",
)
for key in drop_keys:
d.pop(key, None)

@ -252,9 +257,15 @@ def format_event_for_client_v2_without_room_id(d):
return d

def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None, only_event_fields=None, is_invite=False):
def serialize_event(
e,
time_now_ms,
as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
is_invite=False,
):
"""Serialize event for clients
Args:

@ -288,8 +299,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms,
event_format=event_format
e.unsigned["redacted_because"], time_now_ms, event_format=event_format
)
if token_id is not None:

@ -308,8 +318,9 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = event_format(d)
if only_event_fields:
if (not isinstance(only_event_fields, list) or
not all(isinstance(f, string_types) for f in only_event_fields)):
if not isinstance(only_event_fields, list) or not all(
isinstance(f, string_types) for f in only_event_fields
):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)

@ -352,11 +363,9 @@ class EventClientSerializer(object):
# If MSC1849 is enabled then we need to look if there are any relations
# we need to bundle in with the event
if self.experimental_msc1849_support_enabled and bundle_aggregations:
annotations = yield self.store.get_aggregation_groups_for_event(
event_id,
)
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f",
event_id, RelationTypes.REFERENCE, direction="f"
)
if annotations.chunk:

@ -383,9 +392,7 @@ class EventClientSerializer(object):
serialized_event["content"].pop("m.relates_to", None)
r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
}
r[RelationTypes.REPLACE] = {"event_id": edit.event_id}
defer.returnValue(serialized_event)

@ -401,6 +408,5 @@ class EventClientSerializer(object):
Deferred[list[dict]]: The list of serialized events
"""
return yieldable_gather_results(
self.serialize_event, events,
time_now=time_now, **kwargs
self.serialize_event, events, time_now=time_now, **kwargs
)

@ -48,9 +48,7 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
event_strings = [
"origin",
]
event_strings = ["origin"]
for s in event_strings:
if not isinstance(getattr(event, s), string_types):

@ -62,8 +60,10 @@ class EventValidator(object):
if len(alias) > MAX_ALIAS_LENGTH:
raise SynapseError(
400,
("Can't create aliases longer than"
" %d characters" % (MAX_ALIAS_LENGTH,)),
(
"Can't create aliases longer than"
" %d characters" % (MAX_ALIAS_LENGTH,)
),
Codes.INVALID_PARAM,
)

@ -76,11 +76,7 @@ class EventValidator(object):
event (EventBuilder|FrozenEvent)
"""
strings = [
"room_id",
"sender",
"type",
]
strings = ["room_id", "sender", "type"]
if hasattr(event, "state_key"):
strings.append("state_key")

@ -93,10 +89,7 @@ class EventValidator(object):
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
strings = [
"body",
"msgtype",
]
strings = ["body", "msgtype"]
self._ensure_strings(event.content, strings)

@ -44,8 +44,9 @@ class FederationBase(object):
self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
outlier=False, include_none=False):
def _check_sigs_and_hash_and_fetch(
self, origin, pdus, room_version, outlier=False, include_none=False
):
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
the database and if not then request if from the originating server of

@ -79,9 +80,7 @@ class FederationBase(object):
if not res:
# Check local db.
res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:

@ -98,23 +97,16 @@ class FederationBase(object):
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
"Failed to find copy of %s with valid signature", pdu.event_id
)
defer.returnValue(res)
handle = logcontext.preserve_fn(handle_check_result)
deferreds2 = [
handle(pdu, deferred)
for pdu, deferred in zip(pdus, deferreds)
]
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
deferreds2,
consumeErrors=True,
)
defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
if include_none:

@ -124,7 +116,7 @@ class FederationBase(object):
def _check_sigs_and_hash(self, room_version, pdu):
return logcontext.make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0],
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(self, room_version, pdus):

@ -159,11 +151,9 @@ class FederationBase(object):
# received event was probably a redacted copy (but we then use our
# *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu)
if (
set(redacted_event.keys()) == set(pdu.keys()) and
set(six.iterkeys(redacted_event.content))
== set(six.iterkeys(pdu.content))
):
if set(redacted_event.keys()) == set(pdu.keys()) and set(
six.iterkeys(redacted_event.content)
) == set(six.iterkeys(pdu.content)):
logger.info(
"Event %s seems to have been redacted; using our redacted "
"copy",
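The comparison above distinguishes a legitimately redacted copy from a tampered one: pruning strips exactly what a redaction would, so if pruning changes neither the top-level keys nor the content keys, the event already arrived in redacted form. As a standalone predicate (with pruned_dict standing in for prune_event's output):

def looks_already_redacted(pdu_dict, pruned_dict):
    # If pruning removed nothing, the received event was already redacted,
    # so the hash mismatch is benign and the copy can be used as-is.
    return (
        set(pruned_dict.keys()) == set(pdu_dict.keys())
        and set(pruned_dict["content"].keys()) == set(pdu_dict["content"].keys())
    )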
@ -172,14 +162,15 @@ class FederationBase(object):
else:
logger.warning(
"Event %s content has been tampered, redacting",
pdu.event_id, pdu.get_pdu_json(),
pdu.event_id,
)
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
pdu.event_id,
pdu.get_pdu_json(),
)
return prune_event(pdu)

@ -190,23 +181,24 @@ class FederationBase(object):
with logcontext.PreserveLoggingContext(ctx):
logger.warn(
"Signature check failed for %s: %s",
pdu.event_id, failure.getErrorMessage(),
pdu.event_id,
failure.getErrorMessage(),
)
return failure
for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu],
errbackArgs=[pdu],
callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
)
return deferreds

class PduToCheckSig(namedtuple("PduToCheckSig", [
"pdu", "redacted_pdu_json", "sender_domain", "deferreds",
])):
class PduToCheckSig(
namedtuple(
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
)
):
pass

@ -260,10 +252,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [
p for p in pdus_to_check
if not _is_invite_via_3pid(p.pdu)
]
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
more_deferreds = keyring.verify_json_objects_for_server(
[

@ -297,7 +286,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p for p in pdus_to_check
p
for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]

@ -315,10 +305,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
def event_err(e, pdu_to_check):
errmsg = (
"event id %s: unable to verify signature for event id domain: %s" % (
pdu_to_check.pdu.event_id,
e.getErrorMessage(),
)
"event id %s: unable to verify signature for event id domain: %s"
% (pdu_to_check.pdu.event_id, e.getErrorMessage())
)
# XX as above: not really sure if these are the right codes
raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)

@ -368,21 +356,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_dict(pdu_json, ('type', 'depth'))
assert_params_in_dict(pdu_json, ("type", "depth"))
depth = pdu_json['depth']
depth = pdu_json["depth"]
if not isinstance(depth, six.integer_types):
raise SynapseError(400, "Depth %r not an integer" % (depth, ),
Codes.BAD_JSON)
raise SynapseError(400, "Depth %r not an integer" % (depth,), Codes.BAD_JSON)
if depth < 0:
raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
event = event_type_from_format_version(event_format_version)(
pdu_json,
)
event = event_type_from_format_version(event_format_version)(pdu_json)
event.internal_metadata.outlier = outlier
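event_from_pdu_json gates event construction on basic schema checks: depth must be present, an integer, and within [0, MAX_DEPTH]. Condensed into one validator (illustrative; ValueError stands in for SynapseError, and max_depth for the MAX_DEPTH constant this module imports):

def validate_depth(pdu_json, max_depth):
    # Mirrors the checks above: presence, type, then range.
    if "depth" not in pdu_json:
        raise ValueError("missing depth")
    depth = pdu_json["depth"]
    if not isinstance(depth, int):
        raise ValueError("Depth %r not an integer" % (depth,))
    if depth < 0:
        raise ValueError("Depth too small")
    if depth > max_depth:
        raise ValueError("Depth too large")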
@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
pass

@ -65,9 +66,7 @@ class FederationClient(FederationBase):
super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(
self._clear_tried_cache, 60 * 1000,
)
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()

@ -99,8 +98,14 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False, ignore_backoff=False):
def make_query(
self,
destination,
query_type,
args,
retry_on_dns_fail=False,
ignore_backoff=False,
):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.

@ -120,7 +125,10 @@ class FederationClient(FederationBase):
sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
destination,
query_type,
args,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)

@ -137,9 +145,7 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(
destination, content, timeout
)
return self.transport_layer.query_client_keys(destination, content, timeout)
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):

@ -147,9 +153,7 @@ class FederationClient(FederationBase):
server.
"""
sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
return self.transport_layer.query_user_devices(destination, user_id, timeout)
@log_function
def claim_client_keys(self, destination, content, timeout):

@ -164,9 +168,7 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(
destination, content, timeout
)
return self.transport_layer.claim_client_keys(destination, content, timeout)
@defer.inlineCallbacks
@log_function

@ -191,7 +193,8 @@ class FederationClient(FederationBase):
return
transaction_data = yield self.transport_layer.backfill(
dest, room_id, extremities, limit)
dest, room_id, extremities, limit
)
logger.debug("backfill transaction_data=%s", repr(transaction_data))

@ -204,17 +207,19 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus),
consumeErrors=True,
).addErrback(unwrapFirstError))
pdus[:] = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True
).addErrback(unwrapFirstError)
)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, room_version, outlier=False,
timeout=None):
def get_pdu(
self, destinations, event_id, room_version, outlier=False, timeout=None
):
"""Requests the PDU with given origin and ID from the remote home
servers.

@ -255,7 +260,7 @@ class FederationClient(FederationBase):
try:
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
destination, event_id, timeout=timeout
)
logger.debug(

@ -282,8 +287,7 @@ class FederationClient(FederationBase):
except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
"Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue
except NotRetryingDestination as e:

@ -296,8 +300,7 @@ class FederationClient(FederationBase):
pdu_attempts[destination] = now
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
"Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue

@ -326,7 +329,7 @@ class FederationClient(FederationBase):
# we have most of the state and auth_chain already.
# However, this may 404 if the other side has an old synapse.
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id,
destination, room_id, event_id=event_id
)
state_event_ids = result["pdu_ids"]

@ -340,12 +343,10 @@ class FederationClient(FederationBase):
logger.warning(
"Failed to fetch missing state/auth events for %s: %s",
room_id,
failed_to_fetch
failed_to_fetch,
)
event_map = {
ev.event_id: ev for ev in fetched_events
}
event_map = {ev.event_id: ev for ev in fetched_events}
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [

@ -362,15 +363,14 @@ class FederationClient(FederationBase):
raise e
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
destination, room_id, event_id=event_id
)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [
event_from_pdu_json(p, format_ver, outlier=True)
for p in result["pdus"]
event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"]
]
auth_chain = [

@ -378,9 +378,9 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
seen_events = yield self.store.get_events([
ev.event_id for ev in itertools.chain(pdus, auth_chain)
])
seen_events = yield self.store.get_events(
[ev.event_id for ev in itertools.chain(pdus, auth_chain)]
)
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination,

@ -442,7 +442,7 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
batch = set(missing_events[i : i + batch_size])
deferreds = [
run_in_background(

@ -470,21 +470,17 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
res = yield self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(p, format_ver, outlier=True)
for p in res["auth_chain"]
event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain,
outlier=True, room_version=room_version,
destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)

@ -527,28 +523,26 @@ class FederationClient(FederationBase):
res = yield callback(destination)
defer.returnValue(res)
except InvalidResponseError as e:
logger.warn(
"Failed to %s via %s: %s",
description, destination, e,
)
logger.warn("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
logger.warn(
"Failed to %s via %s: %i %s",
description, destination, e.code, e.args[0],
description,
destination,
e.code,
e.args[0],
)
except Exception:
logger.warn(
"Failed to %s via %s",
description, destination, exc_info=1,
)
logger.warn("Failed to %s via %s", description, destination, exc_info=1)
raise RuntimeError("Failed to %s via any server" % (description, ))
raise RuntimeError("Failed to %s via any server" % (description,))
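_try_destination_list drives the membership calls below: it applies the same request callback to each candidate server, demoting per-server failures to warnings (non-5xx HTTP responses are re-raised immediately) and only giving up once every destination has been tried. The control flow, reduced to a synchronous sketch of the pattern:

# Synchronous sketch of the retry-over-destinations pattern above.
def try_destination_list(description, destinations, send_request):
    for destination in destinations:
        try:
            return send_request(destination)
        except Exception as e:
            # The real code re-raises client errors and only retries 5xx ones.
            print("Failed to %s via %s: %s" % (description, destination, e))
    raise RuntimeError("Failed to %s via any server" % (description,))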
def make_membership_event(self, destinations, room_id, user_id, membership,
content, params):
def make_membership_event(
self, destinations, room_id, user_id, membership, content, params
):
"""
Creates an m.room.member event, with context, without participating in the room.

@ -584,14 +578,14 @@ class FederationClient(FederationBase):
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
"make_membership_event called with membership='%s', must be one of %s"
% (membership, ",".join(valid_memberships))
)
@defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership, params,
destination, room_id, user_id, membership, params
)
# Note: If not supplied, the room version may be either v1 or v2,

@ -614,16 +608,17 @@ class FederationClient(FederationBase):
pdu_dict["prev_state"] = []
ev = builder.create_local_event_from_event_dict(
self._clock, self.hostname, self.signing_key,
format_version=event_format, event_dict=pdu_dict,
self._clock,
self.hostname,
self.signing_key,
format_version=event_format,
event_dict=pdu_dict,
)
defer.returnValue(
(destination, ev, event_format)
)
defer.returnValue((destination, ev, event_format))
return self._try_destination_list(
"make_" + membership, destinations, send_request,
"make_" + membership, destinations, send_request
)
def send_join(self, destinations, pdu, event_format_version):

@ -655,9 +650,7 @@ class FederationClient(FederationBase):
create_event = e
break
else:
raise InvalidResponseError(
"no %s in auth chain" % (EventTypes.Create,),
)
raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,))
# the room version should be sane.
room_version = create_event.content.get("room_version", "1")

@ -665,9 +658,8 @@ class FederationClient(FederationBase):
# This shouldn't be possible, because the remote server should have
# rejected the join attempt during make_join.
raise InvalidResponseError(
"room appears to have unsupported version %s" % (
room_version,
))
"room appears to have unsupported version %s" % (room_version,)
)
@defer.inlineCallbacks
def send_request(destination):

@ -691,10 +683,7 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
room_version = None
for e in state:

@ -710,15 +699,13 @@ class FederationClient(FederationBase):
raise SynapseError(400, "No create event in state")
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
destination,
list(pdus.values()),
outlier=True,
room_version=room_version,
)
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
valid_pdus_map = {p.event_id: p for p in valid_pdus}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.

@ -741,11 +728,14 @@ class FederationClient(FederationBase):
check_authchain_validity(signed_auth)
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
defer.returnValue(
{
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
}
)
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks

@ -854,6 +844,7 @@ class FederationClient(FederationBase):
Fails with a ``RuntimeError`` if no servers were reachable.
"""
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()

@ -869,14 +860,23 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_leave", destinations, send_request)
def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
third_party_instance_id=None):
def get_public_rooms(
self,
destination,
limit=None,
since_token=None,
search_filter=None,
include_all_networks=False,
third_party_instance_id=None,
):
if destination == self.server_name:
return
return self.transport_layer.get_public_rooms(
destination, limit, since_token, search_filter,
destination,
limit,
since_token,
search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)

@ -891,9 +891,7 @@ class FederationClient(FederationBase):
"""
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
}
send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]}
code, content = yield self.transport_layer.send_query_auth(
destination=destination,

@ -905,13 +903,10 @@ class FederationClient(FederationBase):
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
event_from_pdu_json(e, format_ver)
for e in content["auth_chain"]
]
auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version,
destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)

@ -925,8 +920,16 @@ class FederationClient(FederationBase):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_missing_events(self, destination, room_id, earliest_events_ids,
latest_events, limit, min_depth, timeout):
def get_missing_events(
self,
destination,
room_id,
earliest_events_ids,
latest_events,
limit,
min_depth,
timeout,
):
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.

@ -957,12 +960,11 @@ class FederationClient(FederationBase):
format_ver = room_version_to_event_format(room_version)
events = [
event_from_pdu_json(e, format_ver)
for e in content.get("events", [])
event_from_pdu_json(e, format_ver) for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version,
destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:

@ -982,17 +984,14 @@ class FederationClient(FederationBase):
try:
yield self.transport_layer.exchange_third_party_invite(
destination=destination,
room_id=room_id,
event_dict=event_dict,
destination=destination, room_id=room_id, event_dict=event_dict
)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_third_party_invite via %s: %s",
destination, str(e)
"Failed to send_third_party_invite via %s: %s", destination, str(e)
)
raise RuntimeError("Failed to send to any server.")
@ -69,7 +69,6 @@ received_queries_counter = Counter(
|
|||
|
||||
|
||||
class FederationServer(FederationBase):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(FederationServer, self).__init__(hs)
|
||||
|
||||
|
@ -118,11 +117,13 @@ class FederationServer(FederationBase):
|
|||
|
||||
# use a linearizer to ensure that we don't process the same transaction
|
||||
# multiple times in parallel.
|
||||
with (yield self._transaction_linearizer.queue(
|
||||
(origin, transaction.transaction_id),
|
||||
)):
|
||||
with (
|
||||
yield self._transaction_linearizer.queue(
|
||||
(origin, transaction.transaction_id)
|
||||
)
|
||||
):
|
||||
result = yield self._handle_incoming_transaction(
|
||||
origin, transaction, request_time,
|
||||
origin, transaction, request_time
|
||||
)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
@ -144,7 +145,7 @@ class FederationServer(FederationBase):
|
|||
if response:
|
||||
logger.debug(
|
||||
"[%s] We've already responded to this request",
|
||||
transaction.transaction_id
|
||||
transaction.transaction_id,
|
||||
)
|
||||
defer.returnValue(response)
|
||||
return
|
||||
|
@ -152,18 +153,15 @@ class FederationServer(FederationBase):
|
|||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||
|
||||
# Reject if PDU count > 50 and EDU count > 100
|
||||
if (len(transaction.pdus) > 50
|
||||
or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
|
||||
if len(transaction.pdus) > 50 or (
|
||||
hasattr(transaction, "edus") and len(transaction.edus) > 100
|
||||
):
|
||||
|
||||
logger.info(
|
||||
"Transaction PDU or EDU count too large. Returning 400",
|
||||
)
|
||||
logger.info("Transaction PDU or EDU count too large. Returning 400")
|
||||
|
||||
response = {}
|
||||
yield self.transaction_actions.set_response(
|
||||
origin,
|
||||
transaction,
|
||||
400, response
|
||||
origin, transaction, 400, response
|
||||
)
|
||||
defer.returnValue((400, response))
|
||||
|
||||
|
@ -230,9 +228,7 @@ class FederationServer(FederationBase):
|
|||
try:
|
||||
yield self.check_server_matches_acl(origin_host, room_id)
|
||||
except AuthError as e:
|
||||
logger.warn(
|
||||
"Ignoring PDUs for room %s from banned server", room_id,
|
||||
)
|
||||
logger.warn("Ignoring PDUs for room %s from banned server", room_id)
|
||||
for pdu in pdus_by_room[room_id]:
|
||||
event_id = pdu.event_id
|
||||
pdu_results[event_id] = e.error_dict()
|
||||
|
@ -242,9 +238,7 @@ class FederationServer(FederationBase):
|
|||
event_id = pdu.event_id
|
||||
with nested_logging_context(event_id):
|
||||
try:
|
||||
yield self._handle_received_pdu(
|
||||
origin, pdu
|
||||
)
|
||||
yield self._handle_received_pdu(origin, pdu)
|
||||
pdu_results[event_id] = {}
|
||||
except FederationError as e:
|
||||
logger.warn("Error handling PDU %s: %s", event_id, e)
|
||||
|
@ -259,29 +253,18 @@ class FederationServer(FederationBase):
|
|||
)
|
||||
|
||||
yield concurrently_execute(
|
||||
process_pdus_for_room, pdus_by_room.keys(),
|
||||
TRANSACTION_CONCURRENCY_LIMIT,
|
||||
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
|
||||
)
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in (Edu(**x) for x in transaction.edus):
|
||||
yield self.received_edu(
|
||||
origin,
|
||||
edu.edu_type,
|
||||
edu.content
|
||||
)
|
||||
yield self.received_edu(origin, edu.edu_type, edu.content)
|
||||
|
||||
response = {
|
||||
"pdus": pdu_results,
|
||||
}
|
||||
response = {"pdus": pdu_results}
|
||||
|
||||
logger.debug("Returning: %s", str(response))
|
||||
|
||||
yield self.transaction_actions.set_response(
|
||||
origin,
|
||||
transaction,
|
||||
200, response
|
||||
)
|
||||
yield self.transaction_actions.set_response(origin, transaction, 200, response)
|
||||
defer.returnValue((200, response))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@@ -311,7 +294,8 @@ class FederationServer(FederationBase):
        resp = yield self._state_resp_cache.wrap(
            (room_id, event_id),
            self._on_context_state_request_compute,
            room_id, event_id,
            room_id,
            event_id,
        )

        defer.returnValue((200, resp))

@@ -328,24 +312,17 @@ class FederationServer(FederationBase):
        if not in_room:
            raise AuthError(403, "Host not in room.")

        state_ids = yield self.handler.get_state_ids_for_pdu(
            room_id, event_id,
        )
        state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
        auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)

        defer.returnValue((200, {
            "pdu_ids": state_ids,
            "auth_chain_ids": auth_chain_ids,
        }))
        defer.returnValue(
            (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
        )

    @defer.inlineCallbacks
    def _on_context_state_request_compute(self, room_id, event_id):
        pdus = yield self.handler.get_state_for_pdu(
            room_id, event_id,
        )
        auth_chain = yield self.store.get_auth_chain(
            [pdu.event_id for pdu in pdus]
        )
        pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
        auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])

        for event in auth_chain:
            # We sign these again because there was a bug where we

@@ -355,14 +332,16 @@ class FederationServer(FederationBase):
                    compute_event_signature(
                        event.get_pdu_json(),
                        self.hs.hostname,
                        self.hs.config.signing_key[0]
                        self.hs.config.signing_key[0],
                    )
                )

        defer.returnValue({
            "pdus": [pdu.get_pdu_json() for pdu in pdus],
            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
        })
        defer.returnValue(
            {
                "pdus": [pdu.get_pdu_json() for pdu in pdus],
                "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
            }
        )

    @defer.inlineCallbacks
    @log_function

@@ -370,9 +349,7 @@ class FederationServer(FederationBase):
        pdu = yield self.handler.get_persisted_pdu(origin, event_id)

        if pdu:
            defer.returnValue(
                (200, self._transaction_from_pdus([pdu]).get_dict())
            )
            defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict()))
        else:
            defer.returnValue((404, ""))

@@ -394,10 +371,9 @@ class FederationServer(FederationBase):

        pdu = yield self.handler.on_make_join_request(room_id, user_id)
        time_now = self._clock.time_msec()
        defer.returnValue({
            "event": pdu.get_pdu_json(time_now),
            "room_version": room_version,
        })
        defer.returnValue(
            {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
        )

    @defer.inlineCallbacks
    def on_invite_request(self, origin, content, room_version):

@@ -431,12 +407,17 @@ class FederationServer(FederationBase):
        logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
        res_pdus = yield self.handler.on_send_join_request(origin, pdu)
        time_now = self._clock.time_msec()
        defer.returnValue((200, {
            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
            "auth_chain": [
                p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
            ],
        }))
        defer.returnValue(
            (
                200,
                {
                    "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
                    "auth_chain": [
                        p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
                    ],
                },
            )
        )

    @defer.inlineCallbacks
    def on_make_leave_request(self, origin, room_id, user_id):

@@ -447,10 +428,9 @@ class FederationServer(FederationBase):
        room_version = yield self.store.get_room_version(room_id)

        time_now = self._clock.time_msec()
        defer.returnValue({
            "event": pdu.get_pdu_json(time_now),
            "room_version": room_version,
        })
        defer.returnValue(
            {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
        )

    @defer.inlineCallbacks
    def on_send_leave_request(self, origin, content, room_id):

@@ -475,9 +455,7 @@ class FederationServer(FederationBase):

            time_now = self._clock.time_msec()
            auth_pdus = yield self.handler.on_event_auth(event_id)
            res = {
                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
            }
            res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
        defer.returnValue((200, res))

    @defer.inlineCallbacks

@@ -508,12 +486,11 @@ class FederationServer(FederationBase):
            format_ver = room_version_to_event_format(room_version)

            auth_chain = [
                event_from_pdu_json(e, format_ver)
                for e in content["auth_chain"]
                event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
            ]

            signed_auth = yield self._check_sigs_and_hash_and_fetch(
                origin, auth_chain, outlier=True, room_version=room_version,
                origin, auth_chain, outlier=True, room_version=room_version
            )

            ret = yield self.handler.on_query_auth(

@@ -527,17 +504,12 @@ class FederationServer(FederationBase):

            time_now = self._clock.time_msec()
            send_content = {
                "auth_chain": [
                    e.get_pdu_json(time_now)
                    for e in ret["auth_chain"]
                ],
                "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
                "rejects": ret.get("rejects", []),
                "missing": ret.get("missing", []),
            }

        defer.returnValue(
            (200, send_content)
        )
        defer.returnValue((200, send_content))

    @log_function
    def on_query_client_keys(self, origin, content):

@@ -566,20 +538,23 @@ class FederationServer(FederationBase):

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join((
                "%s for %s:%s" % (key_id, user_id, device_id)
                for user_id, user_keys in iteritems(json_result)
                for device_id, device_keys in iteritems(user_keys)
                for key_id, _ in iteritems(device_keys)
            )),
            ",".join(
                (
                    "%s for %s:%s" % (key_id, user_id, device_id)
                    for user_id, user_keys in iteritems(json_result)
                    for device_id, device_keys in iteritems(user_keys)
                    for key_id, _ in iteritems(device_keys)
                )
            ),
        )

        defer.returnValue({"one_time_keys": json_result})

    @defer.inlineCallbacks
    @log_function
    def on_get_missing_events(self, origin, room_id, earliest_events,
                              latest_events, limit):
    def on_get_missing_events(
        self, origin, room_id, earliest_events, latest_events, limit
    ):
        with (yield self._server_linearizer.queue((origin, room_id))):
            origin_host, _ = parse_server_name(origin)
            yield self.check_server_matches_acl(origin_host, room_id)

@@ -587,11 +562,13 @@ class FederationServer(FederationBase):
            logger.info(
                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                " limit: %d",
                earliest_events, latest_events, limit,
                earliest_events,
                latest_events,
                limit,
            )

            missing_events = yield self.handler.on_get_missing_events(
                origin, room_id, earliest_events, latest_events, limit,
                origin, room_id, earliest_events, latest_events, limit
            )

            if len(missing_events) < 5:

@@ -603,9 +580,9 @@ class FederationServer(FederationBase):

        time_now = self._clock.time_msec()

        defer.returnValue({
            "events": [ev.get_pdu_json(time_now) for ev in missing_events],
        })
        defer.returnValue(
            {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
        )

    @log_function
    def on_openid_userinfo(self, token):

@@ -666,22 +643,17 @@ class FederationServer(FederationBase):
            # origin. See bug #1893. This is also true for some third party
            # invites).
            if not (
                pdu.type == 'm.room.member' and
                pdu.content and
                pdu.content.get("membership", None) in (
                    Membership.JOIN, Membership.INVITE,
                )
                pdu.type == "m.room.member"
                and pdu.content
                and pdu.content.get("membership", None)
                in (Membership.JOIN, Membership.INVITE)
            ):
                logger.info(
                    "Discarding PDU %s from invalid origin %s",
                    pdu.event_id, origin
                    "Discarding PDU %s from invalid origin %s", pdu.event_id, origin
                )
                return
            else:
                logger.info(
                    "Accepting join PDU %s from %s",
                    pdu.event_id, origin
                )
                logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)

        # We've already checked that we know the room version by this point
        room_version = yield self.store.get_room_version(pdu.room_id)

@@ -690,33 +662,19 @@ class FederationServer(FederationBase):
        try:
            pdu = yield self._check_sigs_and_hash(room_version, pdu)
        except SynapseError as e:
            raise FederationError(
                "ERROR",
                e.code,
                e.msg,
                affected=pdu.event_id,
            )
            raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)

        yield self.handler.on_receive_pdu(
            origin, pdu, sent_to_us_directly=True,
        )
        yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)

    def __str__(self):
        return "<ReplicationLayer(%s)>" % self.server_name

    @defer.inlineCallbacks
    def exchange_third_party_invite(
        self,
        sender_user_id,
        target_user_id,
        room_id,
        signed,
        self, sender_user_id, target_user_id, room_id, signed
    ):
        ret = yield self.handler.exchange_third_party_invite(
            sender_user_id,
            target_user_id,
            room_id,
            signed,
            sender_user_id, target_user_id, room_id, signed
        )
        defer.returnValue(ret)
@@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event):
    allow_ip_literals = True
    if not allow_ip_literals:
        # check for ipv6 literals. These start with '['.
        if server_name[0] == '[':
        if server_name[0] == "[":
            return False

        # check for ipv4 literals. We can just lift the routine from twisted.

@@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event):

def _acl_entry_matches(server_name, acl_entry):
    if not isinstance(acl_entry, six.string_types):
        logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
        logger.warn(
            "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
        )
        return False
    regex = glob_to_regex(acl_entry)
    return regex.match(server_name)
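# Illustrative sketch (not part of the diff above): glob_to_regex comes from
# synapse's utilities and is not shown here, so this approximation uses the
# standard library's fnmatch to get the same glob-style matching idea.
import fnmatch
import re


def acl_entry_matches(server_name, acl_entry):
    # A server ACL entry is a glob such as "*.example.com" or "evil.example".
    regex = re.compile(fnmatch.translate(acl_entry))
    return bool(regex.match(server_name))


assert acl_entry_matches("matrix.example.com", "*.example.com")
assert not acl_entry_matches("evil.example", "*.example.com")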
@@ -815,6 +775,7 @@ class FederationHandlerRegistry(object):
    """Allows classes to register themselves as handlers for a given EDU or
    query type for incoming federation traffic.
    """

    def __init__(self):
        self.edu_handlers = {}
        self.query_handlers = {}

@@ -848,9 +809,7 @@ class FederationHandlerRegistry(object):
        on and the result used as the response to the query request.
        """
        if query_type in self.query_handlers:
            raise KeyError(
                "Already have a Query handler for %s" % (query_type,)
            )
            raise KeyError("Already have a Query handler for %s" % (query_type,))

        logger.info("Registering federation query handler for %r", query_type)
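# Illustrative sketch (not part of the diff above) of the registry pattern:
# handlers are keyed by query type and double registration raises, exactly
# as the KeyError above shows.  The class is a simplified stand-in.
class QueryRegistry(object):
    def __init__(self):
        self.query_handlers = {}

    def register_query_handler(self, query_type, handler):
        if query_type in self.query_handlers:
            raise KeyError("Already have a Query handler for %s" % (query_type,))
        self.query_handlers[query_type] = handler

    def on_query(self, query_type, args):
        return self.query_handlers[query_type](args)


registry = QueryRegistry()
registry.register_query_handler("profile", lambda args: {"displayname": "demo"})
assert registry.on_query("profile", {})["displayname"] == "demo"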
@@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
        handler = self.edu_handlers.get(edu_type)
        if handler:
            return super(ReplicationFederationHandlerRegistry, self).on_edu(
                edu_type, origin, content,
                edu_type, origin, content
            )

        return self._send_edu(
            edu_type=edu_type,
            origin=origin,
            content=content,
        )
        return self._send_edu(edu_type=edu_type, origin=origin, content=content)

    def on_query(self, query_type, args):
        """Overrides FederationHandlerRegistry

@@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
        if handler:
            return handler(args)

        return self._get_query_client(
            query_type=query_type,
            args=args,
        )
        return self._get_query_client(query_type=query_type, args=args)
@@ -46,12 +46,9 @@ class TransactionActions(object):
            response code and response body.
        """
        if not transaction.transaction_id:
            raise RuntimeError("Cannot persist a transaction with no "
                               "transaction_id")
            raise RuntimeError("Cannot persist a transaction with no " "transaction_id")

        return self.store.get_received_txn_response(
            transaction.transaction_id, origin
        )
        return self.store.get_received_txn_response(transaction.transaction_id, origin)

    @log_function
    def set_response(self, origin, transaction, code, response):
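# Illustrative sketch (not part of the diff above): the lookup being
# reformatted deduplicates incoming transactions by (origin, transaction_id),
# replaying a stored response instead of re-processing.  An in-memory dict
# stands in for the datastore methods named above.
_responses = {}


def have_responded(origin, transaction_id):
    """Return the stored (code, body) for a transaction, or None."""
    return _responses.get((origin, transaction_id))


def set_response(origin, transaction_id, code, body):
    _responses[(origin, transaction_id)] = (code, body)


set_response("remote.example", "txn1", 200, {"pdus": {}})
assert have_responded("remote.example", "txn1") == (200, {"pdus": {}})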
@@ -61,14 +58,10 @@ class TransactionActions(object):
            Deferred
        """
        if not transaction.transaction_id:
            raise RuntimeError("Cannot persist a transaction with no "
                               "transaction_id")
            raise RuntimeError("Cannot persist a transaction with no " "transaction_id")

        return self.store.set_received_txn_response(
            transaction.transaction_id,
            origin,
            code,
            response,
            transaction.transaction_id, origin, code, response
        )

    @defer.inlineCallbacks
@@ -77,12 +77,22 @@ class FederationRemoteSendQueue(object):
        # lambda binds to the queue rather than to the name of the queue which
        # changes. ARGH.
        def register(name, queue):
            LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
                       "", [], lambda: len(queue))
            LaterGauge(
                "synapse_federation_send_queue_%s_size" % (queue_name,),
                "",
                [],
                lambda: len(queue),
            )

        for queue_name in [
            "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
            "edus", "device_messages", "pos_time", "presence_destinations",
            "presence_map",
            "presence_changed",
            "keyed_edu",
            "keyed_edu_changed",
            "edus",
            "device_messages",
            "pos_time",
            "presence_destinations",
        ]:
            register(queue_name, getattr(self, queue_name))
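# Illustrative sketch (not part of the diff above) of why the `register`
# helper exists: a lambda created in a loop closes over the loop *variable*,
# not its current value, so every gauge would report the last queue.  A
# function parameter pins the value per call.
late_bound = []
for queue_name in ["edus", "pos_time"]:
    late_bound.append(lambda: queue_name)


def register(name):
    return lambda: name


pinned = [register(name) for name in ["edus", "pos_time"]]

assert [f() for f in late_bound] == ["pos_time", "pos_time"]
assert [f() for f in pinned] == ["edus", "pos_time"]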
@@ -121,9 +131,7 @@ class FederationRemoteSendQueue(object):
            del self.presence_changed[key]

        user_ids = set(
            user_id
            for uids in self.presence_changed.values()
            for user_id in uids
            user_id for uids in self.presence_changed.values() for user_id in uids
        )

        keys = self.presence_destinations.keys()

@@ -285,19 +293,21 @@ class FederationRemoteSendQueue(object):
        ]

        for (key, user_id) in dest_user_ids:
            rows.append((key, PresenceRow(
                state=self.presence_map[user_id],
            )))
            rows.append((key, PresenceRow(state=self.presence_map[user_id])))

        # Fetch presence to send to destinations
        i = self.presence_destinations.bisect_right(from_token)
        j = self.presence_destinations.bisect_right(to_token) + 1

        for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
            rows.append((pos, PresenceDestinationsRow(
                state=self.presence_map[user_id],
                destinations=list(dests),
            )))
            rows.append(
                (
                    pos,
                    PresenceDestinationsRow(
                        state=self.presence_map[user_id], destinations=list(dests)
                    ),
                )
            )

        # Fetch changes keyed edus
        i = self.keyed_edu_changed.bisect_right(from_token)

@@ -308,10 +318,14 @@ class FederationRemoteSendQueue(object):
        keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}

        for ((destination, edu_key), pos) in iteritems(keyed_edus):
            rows.append((pos, KeyedEduRow(
                key=edu_key,
                edu=self.keyed_edu[(destination, edu_key)],
            )))
            rows.append(
                (
                    pos,
                    KeyedEduRow(
                        key=edu_key, edu=self.keyed_edu[(destination, edu_key)]
                    ),
                )
            )

        # Fetch changed edus
        i = self.edus.bisect_right(from_token)

@@ -327,9 +341,7 @@ class FederationRemoteSendQueue(object):
        device_messages = {v: k for k, v in self.device_messages.items()[i:j]}

        for (destination, pos) in iteritems(device_messages):
            rows.append((pos, DeviceRow(
                destination=destination,
            )))
            rows.append((pos, DeviceRow(destination=destination)))

        # Sort rows based on pos
        rows.sort()

@@ -377,16 +389,14 @@ class BaseFederationRow(object):
        raise NotImplementedError()


class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
    "state",  # UserPresenceState
))):
class PresenceRow(
    BaseFederationRow, namedtuple("PresenceRow", ("state",))  # UserPresenceState
):
    TypeId = "p"

    @staticmethod
    def from_data(data):
        return PresenceRow(
            state=UserPresenceState.from_dict(data)
        )
        return PresenceRow(state=UserPresenceState.from_dict(data))

    def to_data(self):
        return self.state.as_dict()

@@ -395,33 +405,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
        buff.presence.append(self.state)


class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
    "state",  # UserPresenceState
    "destinations",  # list[str]
))):
class PresenceDestinationsRow(
    BaseFederationRow,
    namedtuple(
        "PresenceDestinationsRow",
        ("state", "destinations"),  # UserPresenceState  # list[str]
    ),
):
    TypeId = "pd"

    @staticmethod
    def from_data(data):
        return PresenceDestinationsRow(
            state=UserPresenceState.from_dict(data["state"]),
            destinations=data["dests"],
            state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
        )

    def to_data(self):
        return {
            "state": self.state.as_dict(),
            "dests": self.destinations,
        }
        return {"state": self.state.as_dict(), "dests": self.destinations}

    def add_to_buffer(self, buff):
        buff.presence_destinations.append((self.state, self.destinations))


class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
    "key",  # tuple(str) - the edu key passed to send_edu
    "edu",  # Edu
))):
class KeyedEduRow(
    BaseFederationRow,
    namedtuple(
        "KeyedEduRow",
        ("key", "edu"),  # tuple(str) - the edu key passed to send_edu  # Edu
    ),
):
    """Streams EDUs that have an associated key that is used to clobber. For example,
    typing EDUs clobber based on room_id.
    """
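# Illustrative sketch (not part of the diff above) of the clobbering the
# docstring describes: queueing a keyed EDU overwrites any earlier EDU with
# the same key, so only the latest typing state per room is ever sent.
keyed_edus = {}


def queue_keyed_edu(destination, key, edu):
    keyed_edus.setdefault(destination, {})[key] = edu


queue_keyed_edu("remote.example", ("m.typing", "!room:a"), {"typing": True})
queue_keyed_edu("remote.example", ("m.typing", "!room:a"), {"typing": False})
assert keyed_edus["remote.example"][("m.typing", "!room:a")] == {"typing": False}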
@@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (

    @staticmethod
    def from_data(data):
        return KeyedEduRow(
            key=tuple(data["key"]),
            edu=Edu(**data["edu"]),
        )
        return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))

    def to_data(self):
        return {
            "key": self.key,
            "edu": self.edu.get_internal_dict(),
        }
        return {"key": self.key, "edu": self.edu.get_internal_dict()}

    def add_to_buffer(self, buff):
        buff.keyed_edus.setdefault(
            self.edu.destination, {}
        )[self.key] = self.edu
        buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu


class EduRow(BaseFederationRow, namedtuple("EduRow", (
    "edu",  # Edu
))):
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
    """Streams EDUs that don't have keys. See KeyedEduRow
    """

    TypeId = "e"

    @staticmethod

@@ -465,13 +468,12 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
        buff.edus.setdefault(self.edu.destination, []).append(self.edu)


class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
    "destination",  # str
))):
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))):  # str
    """Streams the fact that either a) there are pending to-device messages for
    users on the remote, or b) a local user's device has changed and needs to
    be sent to the remote.
    """

    TypeId = "d"

    @staticmethod

@@ -487,23 +489,20 @@ class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (

TypeToRow = {
    Row.TypeId: Row
    for Row in (
        PresenceRow,
        PresenceDestinationsRow,
        KeyedEduRow,
        EduRow,
        DeviceRow,
    )
    for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow)
}


ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
    "presence",  # list(UserPresenceState)
    "presence_destinations",  # list of tuples of UserPresenceState and destinations
    "keyed_edus",  # dict of destination -> { key -> Edu }
    "edus",  # dict of destination -> [Edu]
    "device_destinations",  # set of destinations
))
ParsedFederationStreamData = namedtuple(
    "ParsedFederationStreamData",
    (
        "presence",  # list(UserPresenceState)
        "presence_destinations",  # list of tuples of UserPresenceState and destinations
        "keyed_edus",  # dict of destination -> { key -> Edu }
        "edus",  # dict of destination -> [Edu]
        "device_destinations",  # set of destinations
    ),
)


def process_rows_for_federation(transaction_queue, rows):
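# Illustrative sketch (not part of the diff above): stream rows round-trip
# through a short TypeId, which is looked up in TypeToRow to rebuild the row
# class on the other side.  DemoRow is a stand-in for the row types above.
from collections import namedtuple


class DemoRow(namedtuple("DemoRow", ("edu",))):
    TypeId = "e"

    @staticmethod
    def from_data(data):
        return DemoRow(edu=data)

    def to_data(self):
        return self.edu


TYPE_TO_ROW = {Row.TypeId: Row for Row in (DemoRow,)}

type_id, data = "e", {"edu_type": "m.typing"}
row = TYPE_TO_ROW[type_id].from_data(data)
assert row.to_data() == data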
@@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows):

    for state, destinations in buff.presence_destinations:
        transaction_queue.send_presence_to_destinations(
            states=[state], destinations=destinations,
            states=[state], destinations=destinations
        )

    for destination, edu_map in iteritems(buff.keyed_edus):
@@ -44,8 +44,8 @@ sent_pdus_destination_dist_count = Counter(
)

sent_pdus_destination_dist_total = Counter(
    "synapse_federation_client_sent_pdu_destinations:total", ""
    "Total number of PDUs queued for sending across all destinations",
    "synapse_federation_client_sent_pdu_destinations:total",
    "" "Total number of PDUs queued for sending across all destinations",
)

@@ -63,14 +63,15 @@ class FederationSender(object):
        self._transaction_manager = TransactionManager(hs)

        # map from destination to PerDestinationQueue
        self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
        self._per_destination_queues = {}  # type: dict[str, PerDestinationQueue]

        LaterGauge(
            "synapse_federation_transaction_queue_pending_destinations",
            "",
            [],
            lambda: sum(
                1 for d in self._per_destination_queues.values()
                1
                for d in self._per_destination_queues.values()
                if d.transmission_loop_running
            ),
        )

@@ -108,8 +109,9 @@ class FederationSender(object):
        # awaiting a call to flush_read_receipts_for_room. The presence of an entry
        # here for a given room means that we are rate-limiting RR flushes to that room,
        # and that there is a pending call to _flush_rrs_for_room in the system.
        self._queues_awaiting_rr_flush_by_room = {
        } # type: dict[str, set[PerDestinationQueue]]
        self._queues_awaiting_rr_flush_by_room = (
            {}
        )  # type: dict[str, set[PerDestinationQueue]]

        self._rr_txn_interval_per_room_ms = (
            1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second

@@ -141,8 +143,7 @@ class FederationSender(object):

        # fire off a processing loop in the background
        run_as_background_process(
            "process_event_queue_for_federation",
            self._process_event_queue_loop,
            "process_event_queue_for_federation", self._process_event_queue_loop
        )

    @defer.inlineCallbacks

@@ -152,7 +153,7 @@ class FederationSender(object):
            while True:
                last_token = yield self.store.get_federation_out_pos("events")
                next_token, events = yield self.store.get_all_new_events_stream(
                    last_token, self._last_poked_id, limit=100,
                    last_token, self._last_poked_id, limit=100
                )

                logger.debug("Handling %s -> %s", last_token, next_token)

@@ -168,6 +169,9 @@ class FederationSender(object):
                if not is_mine and send_on_behalf_of is None:
                    return

                if not event.internal_metadata.should_proactively_send():
                    return

                try:
                    # Get the state from before the event.
                    # We need to make sure that this is the state from before

@@ -176,7 +180,7 @@ class FederationSender(object):
                    # banned then it won't receive the event because it won't
                    # be in the room after the ban.
                    destinations = yield self.state.get_current_hosts_in_room(
                        event.room_id, latest_event_ids=event.prev_event_ids(),
                        event.room_id, latest_event_ids=event.prev_event_ids()
                    )
                except Exception:
                    logger.exception(

@@ -206,37 +210,40 @@ class FederationSender(object):
                for event in events:
                    events_by_room.setdefault(event.room_id, []).append(event)

                yield logcontext.make_deferred_yieldable(defer.gatherResults(
                    [
                        logcontext.run_in_background(handle_room_events, evs)
                        for evs in itervalues(events_by_room)
                    ],
                    consumeErrors=True
                ))

                yield self.store.update_federation_out_pos(
                    "events", next_token
                yield logcontext.make_deferred_yieldable(
                    defer.gatherResults(
                        [
                            logcontext.run_in_background(handle_room_events, evs)
                            for evs in itervalues(events_by_room)
                        ],
                        consumeErrors=True,
                    )
                )

                yield self.store.update_federation_out_pos("events", next_token)

                if events:
                    now = self.clock.time_msec()
                    ts = yield self.store.get_received_ts(events[-1].event_id)

                    synapse.metrics.event_processing_lag.labels(
                        "federation_sender").set(now - ts)
                        "federation_sender"
                    ).set(now - ts)
                    synapse.metrics.event_processing_last_ts.labels(
                        "federation_sender").set(ts)
                        "federation_sender"
                    ).set(ts)

                    events_processed_counter.inc(len(events))

                    event_processing_loop_room_count.labels(
                        "federation_sender"
                    ).inc(len(events_by_room))
                    event_processing_loop_room_count.labels("federation_sender").inc(
                        len(events_by_room)
                    )

                    event_processing_loop_counter.labels("federation_sender").inc()

                    synapse.metrics.event_processing_positions.labels(
                        "federation_sender").set(next_token)
                        "federation_sender"
                    ).set(next_token)

        finally:
            self._is_processing = False

@@ -309,9 +316,7 @@ class FederationSender(object):
        if not domains:
            return

        queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(
            room_id
        )
        queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id)

        # if there is no flush yet scheduled, we will send out these receipts with
        # immediate flushes, and schedule the next flush for this room.

@@ -374,10 +379,9 @@ class FederationSender(object):
        # updates in quick succession are correctly handled.
        # We only want to send presence for our own users, so lets always just
        # filter here just in case.
        self.pending_presence.update({
            state.user_id: state for state in states
            if self.is_mine_id(state.user_id)
        })
        self.pending_presence.update(
            {state.user_id: state for state in states if self.is_mine_id(state.user_id)}
        )

        # We then handle the new pending presence in batches, first figuring
        # out the destinations we need to send each state to and then poking it
@@ -189,11 +189,21 @@ class PerDestinationQueue(object):

            pending_pdus = []
            while True:
                device_message_edus, device_stream_id, dev_list_id = (
                    # We have to keep 2 free slots for presence and rr_edus
                    yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2)
                # We have to keep 2 free slots for presence and rr_edus
                limit = MAX_EDUS_PER_TRANSACTION - 2

                device_update_edus, dev_list_id = (
                    yield self._get_device_update_edus(limit)
                )

                limit -= len(device_update_edus)

                to_device_edus, device_stream_id = (
                    yield self._get_to_device_message_edus(limit)
                )

                pending_edus = device_update_edus + to_device_edus

                # BEGIN CRITICAL SECTION
                #
                # In order to avoid a race condition, we need to make sure that
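# Illustrative sketch (not part of the diff above) of the budgeting the new
# code introduces: two of the 100 EDU slots are reserved for receipts and
# presence, device-list updates are fetched first, and to-device messages
# get whatever budget remains.  The fetch callables are hypothetical.
MAX_EDUS_PER_TRANSACTION = 100


def collect_edus(fetch_device_updates, fetch_to_device):
    limit = MAX_EDUS_PER_TRANSACTION - 2  # slots kept for rr_edus/presence
    device_update_edus = fetch_device_updates(limit)
    limit -= len(device_update_edus)
    to_device_edus = fetch_to_device(limit)
    return device_update_edus + to_device_edus


edus = collect_edus(lambda n: ["dev"] * min(3, n), lambda n: ["msg"] * min(5, n))
assert len(edus) == 8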
@@ -208,10 +218,6 @@ class PerDestinationQueue(object):
                # We can only include at most 50 PDUs per transactions
                pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]

                pending_edus = []

                # We can only include at most 100 EDUs per transactions
                # rr_edus and pending_presence take at most one slot each
                pending_edus.extend(self._get_rr_edus(force_flush=False))
                pending_presence = self._pending_presence
                self._pending_presence = {}

@@ -232,7 +238,6 @@ class PerDestinationQueue(object):
                        )
                    )

                pending_edus.extend(device_message_edus)
                pending_edus.extend(
                    self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
                )

@@ -272,10 +277,13 @@ class PerDestinationQueue(object):
                        sent_edus_by_type.labels(edu.edu_type).inc()
                    # Remove the acknowledged device messages from the database
                    # Only bother if we actually sent some device messages
                    if device_message_edus:
                    if to_device_edus:
                        yield self._store.delete_device_msgs_for_remote(
                            self._destination, device_stream_id
                        )

                    # also mark the device updates as sent
                    if device_update_edus:
                        logger.info(
                            "Marking as sent %r %r", self._destination, dev_list_id
                        )

@@ -347,12 +355,12 @@ class PerDestinationQueue(object):
        return pending_edus

    @defer.inlineCallbacks
    def _get_new_device_messages(self, limit):
    def _get_device_update_edus(self, limit):
        last_device_list = self._last_device_list_stream_id

        # Retrieve list of new device updates to send to the destination
        now_stream_id, results = yield self._store.get_devices_by_remote(
            self._destination, last_device_list, limit=limit,
            self._destination, last_device_list, limit=limit
        )
        edus = [
            Edu(

@@ -366,15 +374,16 @@ class PerDestinationQueue(object):

        assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"

        defer.returnValue((edus, now_stream_id))

    @defer.inlineCallbacks
    def _get_to_device_message_edus(self, limit):
        last_device_stream_id = self._last_device_stream_id
        to_device_stream_id = self._store.get_to_device_stream_token()
        contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
            self._destination,
            last_device_stream_id,
            to_device_stream_id,
            limit - len(edus),
            self._destination, last_device_stream_id, to_device_stream_id, limit
        )
        edus.extend(
        edus = [
            Edu(
                origin=self._server_name,
                destination=self._destination,

@@ -382,6 +391,6 @@ class PerDestinationQueue(object):
                content=content,
            )
            for content in contents
        )
        ]

        defer.returnValue((edus, stream_id, now_stream_id))
        defer.returnValue((edus, stream_id))
@@ -29,9 +29,10 @@ class TransactionManager(object):

    shared between PerDestinationQueue objects
    """

    def __init__(self, hs):
        self._server_name = hs.hostname
        self.clock = hs.get_clock() # nb must be called this for @measure_func
        self.clock = hs.get_clock()  # nb must be called this for @measure_func
        self._store = hs.get_datastore()
        self._transaction_actions = TransactionActions(self._store)
        self._transport_layer = hs.get_federation_transport_client()

@@ -55,9 +56,9 @@ class TransactionManager(object):
        txn_id = str(self._next_txn_id)

        logger.debug(
            "TX [%s] {%s} Attempting new transaction"
            " (pdus: %d, edus: %d)",
            destination, txn_id,
            "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
            destination,
            txn_id,
            len(pdus),
            len(edus),
        )

@@ -79,9 +80,9 @@ class TransactionManager(object):

        logger.debug("TX [%s] Persisted transaction", destination)
        logger.info(
            "TX [%s] {%s} Sending transaction [%s],"
            " (PDUs: %d, EDUs: %d)",
            destination, txn_id,
            "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
            destination,
            txn_id,
            transaction.transaction_id,
            len(pdus),
            len(edus),
@@ -112,20 +113,12 @@ class TransactionManager(object):
            response = e.response

            if e.code in (401, 404, 429) or 500 <= e.code:
                logger.info(
                    "TX [%s] {%s} got %d response",
                    destination, txn_id, code
                )
                logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
                raise e

        logger.info(
            "TX [%s] {%s} got %d response",
            destination, txn_id, code
        )
        logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)

        yield self._transaction_actions.delivered(
            transaction, code, response
        )
        yield self._transaction_actions.delivered(transaction, code, response)

        logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id)
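# Illustrative sketch (not part of the diff above): the branch above appears
# to re-raise for codes where the transaction should not be marked as
# delivered (auth failures, unknown endpoints, rate limiting, 5xx), while
# other codes fall through to `delivered`.  Predicate pulled out for
# clarity; the name is hypothetical.
def should_reraise(code):
    return code in (401, 404, 429) or 500 <= code


assert should_reraise(429)
assert should_reraise(502)
assert not should_reraise(400)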
@@ -134,13 +127,18 @@ class TransactionManager(object):
                if "error" in r:
                    logger.warn(
                        "TX [%s] {%s} Remote returned error for %s: %s",
                        destination, txn_id, e_id, r,
                        destination,
                        txn_id,
                        e_id,
                        r,
                    )
        else:
            for p in pdus:
                logger.warn(
                    "TX [%s] {%s} Failed to send event %s",
                    destination, txn_id, p.event_id,
                    destination,
                    txn_id,
                    p.event_id,
                )
            success = False
@@ -48,12 +48,13 @@ class TransportLayerClient(object):
        Returns:
            Deferred: Results in a dict received from the remote homeserver.
        """
        logger.debug("get_room_state dest=%s, room=%s",
                     destination, room_id)
        logger.debug("get_room_state dest=%s, room=%s", destination, room_id)

        path = _create_v1_path("/state/%s", room_id)
        return self.client.get_json(
            destination, path=path, args={"event_id": event_id},
            destination,
            path=path,
            args={"event_id": event_id},
            try_trailing_slash_on_400=True,
        )

@@ -71,12 +72,13 @@ class TransportLayerClient(object):
        Returns:
            Deferred: Results in a dict received from the remote homeserver.
        """
        logger.debug("get_room_state_ids dest=%s, room=%s",
                     destination, room_id)
        logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)

        path = _create_v1_path("/state_ids/%s", room_id)
        return self.client.get_json(
            destination, path=path, args={"event_id": event_id},
            destination,
            path=path,
            args={"event_id": event_id},
            try_trailing_slash_on_400=True,
        )

@@ -94,13 +96,11 @@ class TransportLayerClient(object):
        Returns:
            Deferred: Results in a dict received from the remote homeserver.
        """
        logger.debug("get_pdu dest=%s, event_id=%s",
                     destination, event_id)
        logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)

        path = _create_v1_path("/event/%s", event_id)
        return self.client.get_json(
            destination, path=path, timeout=timeout,
            try_trailing_slash_on_400=True,
            destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
        )

    @log_function

@@ -119,7 +119,10 @@ class TransportLayerClient(object):
        """
        logger.debug(
            "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s",
            destination, room_id, repr(event_tuples), str(limit)
            destination,
            room_id,
            repr(event_tuples),
            str(limit),
        )

        if not event_tuples:

@@ -128,16 +131,10 @@ class TransportLayerClient(object):

        path = _create_v1_path("/backfill/%s", room_id)

        args = {
            "v": event_tuples,
            "limit": [str(limit)],
        }
        args = {"v": event_tuples, "limit": [str(limit)]}

        return self.client.get_json(
            destination,
            path=path,
            args=args,
            try_trailing_slash_on_400=True,
            destination, path=path, args=args, try_trailing_slash_on_400=True
        )

    @defer.inlineCallbacks

@@ -163,7 +160,8 @@ class TransportLayerClient(object):
        """
        logger.debug(
            "send_data dest=%s, txid=%s",
            transaction.destination, transaction.transaction_id
            transaction.destination,
            transaction.transaction_id,
        )

        if transaction.destination == self.server_name:

@@ -189,8 +187,9 @@ class TransportLayerClient(object):

    @defer.inlineCallbacks
    @log_function
    def make_query(self, destination, query_type, args, retry_on_dns_fail,
                   ignore_backoff=False):
    def make_query(
        self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
    ):
        path = _create_v1_path("/query/%s", query_type)

        content = yield self.client.get_json(

@@ -235,8 +234,8 @@ class TransportLayerClient(object):
        valid_memberships = {Membership.JOIN, Membership.LEAVE}
        if membership not in valid_memberships:
            raise RuntimeError(
                "make_membership_event called with membership='%s', must be one of %s" %
                (membership, ",".join(valid_memberships))
                "make_membership_event called with membership='%s', must be one of %s"
                % (membership, ",".join(valid_memberships))
            )
        path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)

@@ -268,9 +267,7 @@ class TransportLayerClient(object):
        path = _create_v1_path("/send_join/%s/%s", room_id, event_id)

        response = yield self.client.put_json(
            destination=destination,
            path=path,
            data=content,
            destination=destination, path=path, data=content
        )

        defer.returnValue(response)

@@ -284,7 +281,6 @@ class TransportLayerClient(object):
            destination=destination,
            path=path,
            data=content,

            # we want to do our best to send this through. The problem is
            # that if it fails, we won't retry it later, so if the remote
            # server was just having a momentary blip, the room will be out of

@@ -300,10 +296,7 @@ class TransportLayerClient(object):
        path = _create_v1_path("/invite/%s/%s", room_id, event_id)

        response = yield self.client.put_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )

        defer.returnValue(response)

@@ -314,26 +307,27 @@ class TransportLayerClient(object):
        path = _create_v2_path("/invite/%s/%s", room_id, event_id)

        response = yield self.client.put_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )

        defer.returnValue(response)

    @defer.inlineCallbacks
    @log_function
    def get_public_rooms(self, remote_server, limit, since_token,
                         search_filter=None, include_all_networks=False,
                         third_party_instance_id=None):
    def get_public_rooms(
        self,
        remote_server,
        limit,
        since_token,
        search_filter=None,
        include_all_networks=False,
        third_party_instance_id=None,
    ):
        path = _create_v1_path("/publicRooms")

        args = {
            "include_all_networks": "true" if include_all_networks else "false",
        }
        args = {"include_all_networks": "true" if include_all_networks else "false"}
        if third_party_instance_id:
            args["third_party_instance_id"] = third_party_instance_id,
            args["third_party_instance_id"] = (third_party_instance_id,)
        if limit:
            args["limit"] = [str(limit)]
        if since_token:
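# Illustrative sketch (not part of the diff above): the old line ended in a
# bare trailing comma, which makes the assigned value a one-element tuple.
# Black does not change behaviour here; it only makes the tuple explicit
# with parentheses.
args = {}
third_party_instance_id = "instance_1"
args["third_party_instance_id"] = third_party_instance_id,
assert args["third_party_instance_id"] == ("instance_1",)  # a tuple, not a str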
@@ -342,10 +336,7 @@ class TransportLayerClient(object):
        # TODO(erikj): Actually send the search_filter across federation.

        response = yield self.client.get_json(
            destination=remote_server,
            path=path,
            args=args,
            ignore_backoff=True,
            destination=remote_server, path=path, args=args, ignore_backoff=True
        )

        defer.returnValue(response)

@@ -353,12 +344,10 @@ class TransportLayerClient(object):
    @defer.inlineCallbacks
    @log_function
    def exchange_third_party_invite(self, destination, room_id, event_dict):
        path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
        path = _create_v1_path("/exchange_third_party_invite/%s", room_id)

        response = yield self.client.put_json(
            destination=destination,
            path=path,
            data=event_dict,
            destination=destination, path=path, data=event_dict
        )

        defer.returnValue(response)

@@ -368,10 +357,7 @@ class TransportLayerClient(object):
    def get_event_auth(self, destination, room_id, event_id):
        path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)

        content = yield self.client.get_json(
            destination=destination,
            path=path,
        )
        content = yield self.client.get_json(destination=destination, path=path)

        defer.returnValue(content)

@@ -381,9 +367,7 @@ class TransportLayerClient(object):
        path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)

        content = yield self.client.post_json(
            destination=destination,
            path=path,
            data=content,
            destination=destination, path=path, data=content
        )

        defer.returnValue(content)

@@ -416,10 +400,7 @@ class TransportLayerClient(object):
        path = _create_v1_path("/user/keys/query")

        content = yield self.client.post_json(
            destination=destination,
            path=path,
            data=query_content,
            timeout=timeout,
            destination=destination, path=path, data=query_content, timeout=timeout
        )
        defer.returnValue(content)

@@ -443,9 +424,7 @@ class TransportLayerClient(object):
        path = _create_v1_path("/user/devices/%s", user_id)

        content = yield self.client.get_json(
            destination=destination,
            path=path,
            timeout=timeout,
            destination=destination, path=path, timeout=timeout
        )
        defer.returnValue(content)

@@ -479,18 +458,23 @@ class TransportLayerClient(object):
        path = _create_v1_path("/user/keys/claim")

        content = yield self.client.post_json(
            destination=destination,
            path=path,
            data=query_content,
            timeout=timeout,
            destination=destination, path=path, data=query_content, timeout=timeout
        )
        defer.returnValue(content)

    @defer.inlineCallbacks
    @log_function
    def get_missing_events(self, destination, room_id, earliest_events,
                           latest_events, limit, min_depth, timeout):
        path = _create_v1_path("/get_missing_events/%s", room_id,)
    def get_missing_events(
        self,
        destination,
        room_id,
        earliest_events,
        latest_events,
        limit,
        min_depth,
        timeout,
    ):
        path = _create_v1_path("/get_missing_events/%s", room_id)

        content = yield self.client.post_json(
            destination=destination,

@@ -510,7 +494,7 @@ class TransportLayerClient(object):
    def get_group_profile(self, destination, group_id, requester_user_id):
        """Get a group profile
        """
        path = _create_v1_path("/groups/%s/profile", group_id,)
        path = _create_v1_path("/groups/%s/profile", group_id)

        return self.client.get_json(
            destination=destination,
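# Illustrative sketch (not part of the diff above): _create_v1_path itself is
# not shown in this diff.  Based on how it is called here, a helper of this
# shape would interpolate percent-encoded arguments into a path template
# under the v1 federation prefix; the constant and encoding details are
# assumptions, not Synapse's exact implementation.
from urllib.parse import quote

FEDERATION_V1_PREFIX = "/_matrix/federation/v1"


def create_v1_path(path_template, *args):
    return FEDERATION_V1_PREFIX + path_template % tuple(
        quote(str(arg), safe="") for arg in args
    )


assert (
    create_v1_path("/groups/%s/profile", "+group:example.org")
    == "/_matrix/federation/v1/groups/%2Bgroup%3Aexample.org/profile"
)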
@@ -529,7 +513,7 @@ class TransportLayerClient(object):
            requester_user_id (str)
            content (dict): The new profile of the group
        """
        path = _create_v1_path("/groups/%s/profile", group_id,)
        path = _create_v1_path("/groups/%s/profile", group_id)

        return self.client.post_json(
            destination=destination,

@@ -543,7 +527,7 @@ class TransportLayerClient(object):
    def get_group_summary(self, destination, group_id, requester_user_id):
        """Get a group summary
        """
        path = _create_v1_path("/groups/%s/summary", group_id,)
        path = _create_v1_path("/groups/%s/summary", group_id)

        return self.client.get_json(
            destination=destination,

@@ -556,7 +540,7 @@ class TransportLayerClient(object):
    def get_rooms_in_group(self, destination, group_id, requester_user_id):
        """Get all rooms in a group
        """
        path = _create_v1_path("/groups/%s/rooms", group_id,)
        path = _create_v1_path("/groups/%s/rooms", group_id)

        return self.client.get_json(
            destination=destination,

@@ -565,11 +549,12 @@ class TransportLayerClient(object):
            ignore_backoff=True,
        )

    def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
                          content):
    def add_room_to_group(
        self, destination, group_id, requester_user_id, room_id, content
    ):
        """Add a room to a group
        """
        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)

        return self.client.post_json(
            destination=destination,

@@ -579,13 +564,13 @@ class TransportLayerClient(object):
            ignore_backoff=True,
        )

    def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
                             config_key, content):
    def update_room_in_group(
        self, destination, group_id, requester_user_id, room_id, config_key, content
    ):
        """Update room in group
        """
        path = _create_v1_path(
            "/groups/%s/room/%s/config/%s",
            group_id, room_id, config_key,
            "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
        )

        return self.client.post_json(

@@ -599,7 +584,7 @@ class TransportLayerClient(object):
    def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
        """Remove a room from a group
        """
        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)

        return self.client.delete_json(
            destination=destination,

@@ -612,7 +597,7 @@ class TransportLayerClient(object):
    def get_users_in_group(self, destination, group_id, requester_user_id):
        """Get users in a group
        """
        path = _create_v1_path("/groups/%s/users", group_id,)
        path = _create_v1_path("/groups/%s/users", group_id)

        return self.client.get_json(
            destination=destination,

@@ -625,7 +610,7 @@ class TransportLayerClient(object):
    def get_invited_users_in_group(self, destination, group_id, requester_user_id):
        """Get users that have been invited to a group
        """
        path = _create_v1_path("/groups/%s/invited_users", group_id,)
        path = _create_v1_path("/groups/%s/invited_users", group_id)

        return self.client.get_json(
            destination=destination,

@@ -638,16 +623,10 @@ class TransportLayerClient(object):
    def accept_group_invite(self, destination, group_id, user_id, content):
        """Accept a group invite
        """
        path = _create_v1_path(
            "/groups/%s/users/%s/accept_invite",
            group_id, user_id,
        )
        path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)

        return self.client.post_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )

    @log_function

@@ -657,14 +636,13 @@ class TransportLayerClient(object):
        path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)

        return self.client.post_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )

    @log_function
    def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
    def invite_to_group(
        self, destination, group_id, user_id, requester_user_id, content
    ):
        """Invite a user to a group
        """
        path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)

@@ -686,15 +664,13 @@ class TransportLayerClient(object):
        path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)

        return self.client.post_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )

    @log_function
    def remove_user_from_group(self, destination, group_id, requester_user_id,
                               user_id, content):
    def remove_user_from_group(
        self, destination, group_id, requester_user_id, user_id, content
    ):
        """Remove a user from a group
        """
        path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)

@@ -708,8 +684,9 @@ class TransportLayerClient(object):
        )

    @log_function
    def remove_user_from_group_notification(self, destination, group_id, user_id,
                                            content):
    def remove_user_from_group_notification(
        self, destination, group_id, user_id, content
    ):
        """Sent by group server to inform a user's server that they have been
        kicked from the group.
        """

@@ -717,10 +694,7 @@ class TransportLayerClient(object):
        path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)

        return self.client.post_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )

    @log_function

@@ -732,24 +706,24 @@ class TransportLayerClient(object):
        path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)

        return self.client.post_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )

    @log_function
    def update_group_summary_room(self, destination, group_id, user_id, room_id,
                                  category_id, content):
    def update_group_summary_room(
        self, destination, group_id, user_id, room_id, category_id, content
    ):
        """Update a room entry in a group summary
        """
        if category_id:
            path = _create_v1_path(
                "/groups/%s/summary/categories/%s/rooms/%s",
                group_id, category_id, room_id,
                group_id,
                category_id,
                room_id,
            )
        else:
            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)

        return self.client.post_json(
            destination=destination,

@@ -760,17 +734,20 @@ class TransportLayerClient(object):
        )

    @log_function
    def delete_group_summary_room(self, destination, group_id, user_id, room_id,
                                  category_id):
    def delete_group_summary_room(
        self, destination, group_id, user_id, room_id, category_id
    ):
        """Delete a room entry in a group summary
        """
        if category_id:
            path = _create_v1_path(
                "/groups/%s/summary/categories/%s/rooms/%s",
                group_id, category_id, room_id,
                group_id,
                category_id,
                room_id,
            )
        else:
            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)

        return self.client.delete_json(
            destination=destination,

@@ -783,7 +760,7 @@ class TransportLayerClient(object):
    def get_group_categories(self, destination, group_id, requester_user_id):
        """Get all categories in a group
        """
        path = _create_v1_path("/groups/%s/categories", group_id,)
        path = _create_v1_path("/groups/%s/categories", group_id)

        return self.client.get_json(
            destination=destination,

@@ -796,7 +773,7 @@ class TransportLayerClient(object):
    def get_group_category(self, destination, group_id, requester_user_id, category_id):
        """Get category info in a group
        """
        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)

        return self.client.get_json(
            destination=destination,

@@ -806,11 +783,12 @@ class TransportLayerClient(object):
        )

    @log_function
    def update_group_category(self, destination, group_id, requester_user_id, category_id,
                              content):
    def update_group_category(
        self, destination, group_id, requester_user_id, category_id, content
    ):
        """Update a category in a group
        """
        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)

        return self.client.post_json(
            destination=destination,

@@ -821,11 +799,12 @@ class TransportLayerClient(object):
        )

    @log_function
    def delete_group_category(self, destination, group_id, requester_user_id,
                              category_id):
    def delete_group_category(
        self, destination, group_id, requester_user_id, category_id
    ):
        """Delete a category in a group
        """
        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)

        return self.client.delete_json(
            destination=destination,

@@ -838,7 +817,7 @@ class TransportLayerClient(object):
    def get_group_roles(self, destination, group_id, requester_user_id):
        """Get all roles in a group
        """
        path = _create_v1_path("/groups/%s/roles", group_id,)
        path = _create_v1_path("/groups/%s/roles", group_id)

        return self.client.get_json(
            destination=destination,
@@ -851,7 +830,7 @@ class TransportLayerClient(object):
    def get_group_role(self, destination, group_id, requester_user_id, role_id):
        """Get a role's info
        """
        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)

        return self.client.get_json(
            destination=destination,

@@ -861,11 +840,12 @@ class TransportLayerClient(object):
        )

    @log_function
    def update_group_role(self, destination, group_id, requester_user_id, role_id,
                          content):
    def update_group_role(
        self, destination, group_id, requester_user_id, role_id, content
    ):
        """Update a role in a group
        """
        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)

        return self.client.post_json(
            destination=destination,

@@ -879,7 +859,7 @@ class TransportLayerClient(object):
    def delete_group_role(self, destination, group_id, requester_user_id, role_id):
        """Delete a role in a group
        """
        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)

        return self.client.delete_json(
            destination=destination,

@@ -889,17 +869,17 @@ class TransportLayerClient(object):
        )

    @log_function
    def update_group_summary_user(self, destination, group_id, requester_user_id,
                                  user_id, role_id, content):
    def update_group_summary_user(
        self, destination, group_id, requester_user_id, user_id, role_id, content
    ):
        """Update a user's entry in a group
        """
        if role_id:
            path = _create_v1_path(
                "/groups/%s/summary/roles/%s/users/%s",
                group_id, role_id, user_id,
                "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
            )
        else:
            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)

        return self.client.post_json(
            destination=destination,

@@ -910,11 +890,10 @@ class TransportLayerClient(object):
        )

    @log_function
    def set_group_join_policy(self, destination, group_id, requester_user_id,
                              content):
    def set_group_join_policy(self, destination, group_id, requester_user_id, content):
        """Sets the join policy for a group
        """
        path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
        path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)

        return self.client.put_json(
            destination=destination,

@@ -925,17 +904,17 @@ class TransportLayerClient(object):
        )

    @log_function
    def delete_group_summary_user(self, destination, group_id, requester_user_id,
                                  user_id, role_id):
    def delete_group_summary_user(
        self, destination, group_id, requester_user_id, user_id, role_id
    ):
        """Delete a user's entry in a group
        """
        if role_id:
            path = _create_v1_path(
                "/groups/%s/summary/roles/%s/users/%s",
                group_id, role_id, user_id,
                "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
            )
        else:
            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)

        return self.client.delete_json(
            destination=destination,

@@ -953,10 +932,7 @@ class TransportLayerClient(object):
        content = {"user_ids": user_ids}

        return self.client.post_json(
            destination=destination,
            path=path,
            data=content,
            ignore_backoff=True,
            destination=destination, path=path, data=content, ignore_backoff=True
        )
|
||||
|
@ -975,9 +951,8 @@ def _create_v1_path(path, *args):
|
|||
Returns:
|
||||
str
|
||||
"""
|
||||
return (
|
||||
FEDERATION_V1_PREFIX
|
||||
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
|
||||
return FEDERATION_V1_PREFIX + path % tuple(
|
||||
urllib.parse.quote(arg, "") for arg in args
|
||||
)
|
||||
|
||||
|
||||
|
@ -996,7 +971,6 @@ def _create_v2_path(path, *args):
|
|||
Returns:
|
||||
str
|
||||
"""
|
||||
return (
|
||||
FEDERATION_V2_PREFIX
|
||||
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
|
||||
return FEDERATION_V2_PREFIX + path % tuple(
|
||||
urllib.parse.quote(arg, "") for arg in args
|
||||
)
|
||||
|
|
|
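For context on the two path helpers above: both simply interpolate percent-encoded arguments into a template and prepend the federation API prefix. A minimal runnable sketch follows; the prefix string used here is an assumption, since the actual constant is defined outside this diff.

# Minimal sketch of how the helpers above build federation request paths.
# FEDERATION_V1_PREFIX's value is assumed, not shown in this diff.
import urllib.parse

FEDERATION_V1_PREFIX = "/_matrix/federation/v1"  # assumed value

def _create_v1_path(path, *args):
    # Same body as in the diff: percent-encode each arg, then interpolate.
    return FEDERATION_V1_PREFIX + path % tuple(
        urllib.parse.quote(arg, "") for arg in args
    )

print(_create_v1_path("/groups/%s/roles/%s", "+g:example.org", "role one"))
# -> /_matrix/federation/v1/groups/%2Bg%3Aexample.org/roles/role%20one

Quoting with an empty "safe" set means even ':' and '+' in Matrix identifiers are escaped, so an ID can never inject extra path segments into the URL.
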
@@ -66,8 +66,7 @@ class TransportLayerServer(JsonResource):

        self.authenticator = Authenticator(hs)
        self.ratelimiter = FederationRateLimiter(
            self.clock,
            config=hs.config.rc_federation,
            self.clock, config=hs.config.rc_federation
        )

        self.register_servlets()

@@ -84,11 +83,13 @@ class TransportLayerServer(JsonResource):

class AuthenticationError(SynapseError):
    """There was a problem authenticating the request"""

    pass


class NoAuthenticationError(AuthenticationError):
    """The request had no authentication information"""

    pass


@@ -105,8 +106,8 @@ class Authenticator(object):
    def authenticate_request(self, request, content):
        now = self._clock.time_msec()
        json_request = {
            "method": request.method.decode('ascii'),
            "uri": request.uri.decode('ascii'),
            "method": request.method.decode("ascii"),
            "uri": request.uri.decode("ascii"),
            "destination": self.server_name,
            "signatures": {},
        }

@@ -120,7 +121,7 @@ class Authenticator(object):

        if not auth_headers:
            raise NoAuthenticationError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
                401, "Missing Authorization headers", Codes.UNAUTHORIZED
            )

        for auth in auth_headers:

@@ -130,14 +131,14 @@ class Authenticator(object):
                json_request["signatures"].setdefault(origin, {})[key] = sig

        if (
            self.federation_domain_whitelist is not None and
            origin not in self.federation_domain_whitelist
            self.federation_domain_whitelist is not None
            and origin not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(origin)

        if not json_request["signatures"]:
            raise NoAuthenticationError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
                401, "Missing Authorization headers", Codes.UNAUTHORIZED
            )

        yield self.keyring.verify_json_for_server(

@@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes):
        AuthenticationError if the header could not be parsed
    """
    try:
        header_str = header_bytes.decode('utf-8')
        header_str = header_bytes.decode("utf-8")
        params = header_str.split(" ")[1].split(",")
        param_dict = dict(kv.split("=") for kv in params)

        def strip_quotes(value):
            if value.startswith("\""):
            if value.startswith('"'):
                return value[1:-1]
            else:
                return value

@@ -198,11 +199,11 @@ def _parse_auth_header(header_bytes):
    except Exception as e:
        logger.warn(
            "Error parsing auth header '%s': %s",
            header_bytes.decode('ascii', 'replace'),
            header_bytes.decode("ascii", "replace"),
            e,
        )
        raise AuthenticationError(
            400, "Malformed Authorization header", Codes.UNAUTHORIZED,
            400, "Malformed Authorization header", Codes.UNAUTHORIZED
        )


@@ -242,6 +243,7 @@ class BaseFederationServlet(object):
        Exception: other exceptions will be caught, logged, and a 500 will be
            returned.
    """

    REQUIRE_AUTH = True

    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version

@@ -293,9 +295,7 @@ class BaseFederationServlet(object):
                    origin, content, request.args, *args, **kwargs
                )
            else:
                response = yield func(
                    origin, content, request.args, *args, **kwargs
                )
                response = yield func(origin, content, request.args, *args, **kwargs)

            defer.returnValue(response)

@@ -343,14 +343,12 @@ class FederationSendServlet(BaseFederationServlet):
        try:
            transaction_data = content

            logger.debug(
                "Decoded %s: %s",
                transaction_id, str(transaction_data)
            )
            logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))

            logger.info(
                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
                transaction_id, origin,
                transaction_id,
                origin,
                len(transaction_data.get("pdus", [])),
                len(transaction_data.get("edus", [])),
            )

@@ -361,8 +359,7 @@ class FederationSendServlet(BaseFederationServlet):
            # Add some extra data to the transaction dict that isn't included
            # in the request body.
            transaction_data.update(
                transaction_id=transaction_id,
                destination=self.server_name
                transaction_id=transaction_id, destination=self.server_name
            )

        except Exception as e:

@@ -372,7 +369,7 @@ class FederationSendServlet(BaseFederationServlet):

        try:
            code, response = yield self.handler.on_incoming_transaction(
                origin, transaction_data,
                origin, transaction_data
            )
        except Exception:
            logger.exception("on_incoming_transaction failed")

@@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet):
    PATH = "/backfill/(?P<context>[^/]*)/?"

    def on_GET(self, origin, content, query, context):
        versions = [x.decode('ascii') for x in query[b"v"]]
        versions = [x.decode("ascii") for x in query[b"v"]]
        limit = parse_integer_from_args(query, "limit", None)

        if not limit:

@@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet):
    def on_GET(self, origin, content, query, query_type):
        return self.handler.on_query_request(
            query_type,
            {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
            {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
        )


@@ -456,15 +453,14 @@ class FederationMakeJoinServlet(BaseFederationServlet):
            Deferred[(int, object)|None]: either (response code, response object) to
                return a JSON response, or None if the request has already been handled.
        """
        versions = query.get(b'ver')
        versions = query.get(b"ver")
        if versions is not None:
            supported_versions = [v.decode("utf-8") for v in versions]
        else:
            supported_versions = ["1"]

        content = yield self.handler.on_make_join_request(
            origin, context, user_id,
            supported_versions=supported_versions,
            origin, context, user_id, supported_versions=supported_versions
        )
        defer.returnValue((200, content))

@@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):

    @defer.inlineCallbacks
    def on_GET(self, origin, content, query, context, user_id):
        content = yield self.handler.on_make_leave_request(
            origin, context, user_id,
        )
        content = yield self.handler.on_make_leave_request(origin, context, user_id)
        defer.returnValue((200, content))


@@ -517,7 +511,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
        # state resolution algorithm, and we don't use that for processing
        # invites
        content = yield self.handler.on_invite_request(
            origin, content, room_version=RoomVersions.V1.identifier,
            origin, content, room_version=RoomVersions.V1.identifier
        )

        # V1 federation API is defined to return a content of `[200, {...}]`

@@ -545,7 +539,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
        event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state

        content = yield self.handler.on_invite_request(
            origin, event, room_version=room_version,
            origin, event, room_version=room_version
        )
        defer.returnValue((200, content))

@@ -629,8 +623,10 @@ class On3pidBindServlet(BaseFederationServlet):
        for invite in content["invites"]:
            try:
                if "signed" not in invite or "token" not in invite["signed"]:
                    message = ("Rejecting received notification of third-"
                               "party invite without signed: %s" % (invite,))
                    message = (
                        "Rejecting received notification of third-"
                        "party invite without signed: %s" % (invite,)
                    )
                    logger.info(message)
                    raise SynapseError(400, message)
                yield self.handler.exchange_third_party_invite(

@@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet):
    def on_GET(self, origin, content, query):
        token = query.get(b"access_token", [None])[0]
        if token is None:
            defer.returnValue((401, {
                "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
            }))
            defer.returnValue(
                (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"})
            )
            return

        user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))
        user_id = yield self.handler.on_openid_userinfo(token.decode("ascii"))

        if user_id is None:
            defer.returnValue((401, {
                "errcode": "M_UNKNOWN_TOKEN",
                "error": "Access Token unknown or expired"
            }))
            defer.returnValue(
                (
                    401,
                    {
                        "errcode": "M_UNKNOWN_TOKEN",
                        "error": "Access Token unknown or expired",
                    },
                )
            )

        defer.returnValue((200, {"sub": user_id}))

@@ -720,15 +721,15 @@ class PublicRoomList(BaseFederationServlet):

    PATH = "/publicRooms"

    def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access):
    def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
        super(PublicRoomList, self).__init__(
            handler, authenticator, ratelimiter, server_name,
            handler, authenticator, ratelimiter, server_name
        )
        self.deny_access = deny_access
        self.allow_access = allow_access

    @defer.inlineCallbacks
    def on_GET(self, origin, content, query):
        if self.deny_access:
        if not self.allow_access:
            raise FederationDeniedError(origin)

        limit = parse_integer_from_args(query, "limit", 0)

@@ -748,9 +749,7 @@ class PublicRoomList(BaseFederationServlet):
            network_tuple = ThirdPartyInstanceID(None, None)

        data = yield self.handler.get_local_public_room_list(
            limit, since_token,
            network_tuple=network_tuple,
            from_federation=True,
            limit, since_token, network_tuple=network_tuple, from_federation=True
        )
        defer.returnValue((200, data))

@@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet):
    REQUIRE_AUTH = False

    def on_GET(self, origin, content, query):
        return defer.succeed((200, {
            "server": {
                "name": "Synapse",
                "version": get_version_string(synapse)
            },
        }))
        return defer.succeed(
            (
                200,
                {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
            )
        )


class FederationGroupsProfileServlet(BaseFederationServlet):
    """Get/set the basic profile of a group on behalf of a user
    """

    PATH = "/groups/(?P<group_id>[^/]*)/profile"

    @defer.inlineCallbacks

@@ -780,9 +780,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
        if get_domain_from_id(requester_user_id) != origin:
            raise SynapseError(403, "requester_user_id doesn't match origin")

        new_content = yield self.handler.get_group_profile(
            group_id, requester_user_id
        )
        new_content = yield self.handler.get_group_profile(group_id, requester_user_id)

        defer.returnValue((200, new_content))

@@ -808,9 +806,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
        if get_domain_from_id(requester_user_id) != origin:
            raise SynapseError(403, "requester_user_id doesn't match origin")

        new_content = yield self.handler.get_group_summary(
            group_id, requester_user_id
        )
        new_content = yield self.handler.get_group_summary(group_id, requester_user_id)

        defer.returnValue((200, new_content))

@@ -818,6 +814,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
class FederationGroupsRoomsServlet(BaseFederationServlet):
    """Get the rooms in a group on behalf of a user
    """

    PATH = "/groups/(?P<group_id>[^/]*)/rooms"

    @defer.inlineCallbacks

@@ -826,9 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
        if get_domain_from_id(requester_user_id) != origin:
            raise SynapseError(403, "requester_user_id doesn't match origin")

        new_content = yield self.handler.get_rooms_in_group(
            group_id, requester_user_id
        )
        new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id)

        defer.returnValue((200, new_content))

@@ -836,6 +831,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
    """Add/remove room from group
    """

    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"

    @defer.inlineCallbacks

@@ -857,7 +853,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
            raise SynapseError(403, "requester_user_id doesn't match origin")

        new_content = yield self.handler.remove_room_from_group(
            group_id, requester_user_id, room_id,
            group_id, requester_user_id, room_id
        )

        defer.returnValue((200, new_content))

@@ -866,6 +862,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
    """Update room config in group
    """

    PATH = (
        "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
        "/config/(?P<config_key>[^/]*)"

@@ -878,7 +875,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
            raise SynapseError(403, "requester_user_id doesn't match origin")

        result = yield self.groups_handler.update_room_in_group(
            group_id, requester_user_id, room_id, config_key, content,
            group_id, requester_user_id, room_id, config_key, content
        )

        defer.returnValue((200, result))

@@ -887,6 +884,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
class FederationGroupsUsersServlet(BaseFederationServlet):
    """Get the users in a group on behalf of a user
    """

    PATH = "/groups/(?P<group_id>[^/]*)/users"

    @defer.inlineCallbacks

@@ -895,9 +893,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
        if get_domain_from_id(requester_user_id) != origin:
            raise SynapseError(403, "requester_user_id doesn't match origin")

        new_content = yield self.handler.get_users_in_group(
            group_id, requester_user_id
        )
        new_content = yield self.handler.get_users_in_group(group_id, requester_user_id)

        defer.returnValue((200, new_content))

@@ -905,6 +901,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
    """Get the users that have been invited to a group
    """

    PATH = "/groups/(?P<group_id>[^/]*)/invited_users"

    @defer.inlineCallbacks

@@ -923,6 +920,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
class FederationGroupsInviteServlet(BaseFederationServlet):
    """Ask a group server to invite someone to the group
    """

    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"

    @defer.inlineCallbacks

@@ -932,7 +930,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
            raise SynapseError(403, "requester_user_id doesn't match origin")

        new_content = yield self.handler.invite_to_group(
            group_id, user_id, requester_user_id, content,
            group_id, user_id, requester_user_id, content
        )

        defer.returnValue((200, new_content))

@@ -941,6 +939,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
    """Accept an invitation from the group server
    """

    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"

    @defer.inlineCallbacks

@@ -948,9 +947,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
        if get_domain_from_id(user_id) != origin:
            raise SynapseError(403, "user_id doesn't match origin")

        new_content = yield self.handler.accept_invite(
            group_id, user_id, content,
        )
        new_content = yield self.handler.accept_invite(group_id, user_id, content)

        defer.returnValue((200, new_content))

@@ -958,6 +955,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
class FederationGroupsJoinServlet(BaseFederationServlet):
    """Attempt to join a group
    """

    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"

    @defer.inlineCallbacks

@@ -965,9 +963,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
        if get_domain_from_id(user_id) != origin:
            raise SynapseError(403, "user_id doesn't match origin")

        new_content = yield self.handler.join_group(
            group_id, user_id, content,
        )
        new_content = yield self.handler.join_group(group_id, user_id, content)

        defer.returnValue((200, new_content))

@@ -975,6 +971,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
    """Leave or kick a user from the group
    """

    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"

    @defer.inlineCallbacks

@@ -984,7 +981,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
            raise SynapseError(403, "requester_user_id doesn't match origin")

        new_content = yield self.handler.remove_user_from_group(
            group_id, user_id, requester_user_id, content,
            group_id, user_id, requester_user_id, content
        )

        defer.returnValue((200, new_content))

@@ -993,6 +990,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
    """A group server has invited a local user
    """

    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"

    @defer.inlineCallbacks

@@ -1000,9 +998,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
        if get_domain_from_id(group_id) != origin:
            raise SynapseError(403, "group_id doesn't match origin")

        new_content = yield self.handler.on_invite(
            group_id, user_id, content,
        )
        new_content = yield self.handler.on_invite(group_id, user_id, content)

        defer.returnValue((200, new_content))

@@ -1010,6 +1006,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
    """A group server has removed a local user
    """

    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"

    @defer.inlineCallbacks

@@ -1018,7 +1015,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
            raise SynapseError(403, "user_id doesn't match origin")

        new_content = yield self.handler.user_removed_from_group(
            group_id, user_id, content,
            group_id, user_id, content
        )

        defer.returnValue((200, new_content))

@@ -1027,6 +1024,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
    """A group or user's server renews their attestation
    """

    PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"

    @defer.inlineCallbacks

@@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
        - /groups/:group/summary/rooms/:room_id
        - /groups/:group/summary/categories/:category/rooms/:room_id
    """

    PATH = (
        "/groups/(?P<group_id>[^/]*)/summary"
        "(/categories/(?P<category_id>[^/]+))?"

@@ -1063,7 +1062,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
            raise SynapseError(400, "category_id cannot be empty string")

        resp = yield self.handler.update_group_summary_room(
            group_id, requester_user_id,
            group_id,
            requester_user_id,
            room_id=room_id,
            category_id=category_id,
            content=content,

@@ -1081,9 +1081,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
            raise SynapseError(400, "category_id cannot be empty string")

        resp = yield self.handler.delete_group_summary_room(
            group_id, requester_user_id,
            room_id=room_id,
            category_id=category_id,
            group_id, requester_user_id, room_id=room_id, category_id=category_id
        )

        defer.returnValue((200, resp))

@@ -1092,9 +1090,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
class FederationGroupsCategoriesServlet(BaseFederationServlet):
    """Get all categories for a group
    """
    PATH = (
        "/groups/(?P<group_id>[^/]*)/categories/?"
    )

    PATH = "/groups/(?P<group_id>[^/]*)/categories/?"

    @defer.inlineCallbacks
    def on_GET(self, origin, content, query, group_id):

@@ -1102,9 +1099,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
        if get_domain_from_id(requester_user_id) != origin:
            raise SynapseError(403, "requester_user_id doesn't match origin")

        resp = yield self.handler.get_group_categories(
            group_id, requester_user_id,
        )
        resp = yield self.handler.get_group_categories(group_id, requester_user_id)

        defer.returnValue((200, resp))

@@ -1112,9 +1107,8 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
class FederationGroupsCategoryServlet(BaseFederationServlet):
    """Add/remove/get a category in a group
    """
    PATH = (
        "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
    )

    PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"

    @defer.inlineCallbacks
    def on_GET(self, origin, content, query, group_id, category_id):

@@ -1138,7 +1132,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
            raise SynapseError(400, "category_id cannot be empty string")

        resp = yield self.handler.upsert_group_category(
            group_id, requester_user_id, category_id, content,
            group_id, requester_user_id, category_id, content
        )

        defer.returnValue((200, resp))

@@ -1153,7 +1147,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
            raise SynapseError(400, "category_id cannot be empty string")

        resp = yield self.handler.delete_group_category(
            group_id, requester_user_id, category_id,
            group_id, requester_user_id, category_id
        )

        defer.returnValue((200, resp))

@@ -1162,9 +1156,8 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
class FederationGroupsRolesServlet(BaseFederationServlet):
    """Get roles in a group
    """
    PATH = (
        "/groups/(?P<group_id>[^/]*)/roles/?"
    )

    PATH = "/groups/(?P<group_id>[^/]*)/roles/?"

    @defer.inlineCallbacks
    def on_GET(self, origin, content, query, group_id):

@@ -1172,9 +1165,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
        if get_domain_from_id(requester_user_id) != origin:
            raise SynapseError(403, "requester_user_id doesn't match origin")

        resp = yield self.handler.get_group_roles(
            group_id, requester_user_id,
        )
        resp = yield self.handler.get_group_roles(group_id, requester_user_id)

        defer.returnValue((200, resp))

@@ -1182,9 +1173,8 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
class FederationGroupsRoleServlet(BaseFederationServlet):
    """Add/remove/get a role in a group
    """
    PATH = (
        "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
    )

    PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"

    @defer.inlineCallbacks
    def on_GET(self, origin, content, query, group_id, role_id):

@@ -1192,9 +1182,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
        if get_domain_from_id(requester_user_id) != origin:
            raise SynapseError(403, "requester_user_id doesn't match origin")

        resp = yield self.handler.get_group_role(
            group_id, requester_user_id, role_id
        )
        resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id)

        defer.returnValue((200, resp))

@@ -1208,7 +1196,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
            raise SynapseError(400, "role_id cannot be empty string")

        resp = yield self.handler.update_group_role(
            group_id, requester_user_id, role_id, content,
            group_id, requester_user_id, role_id, content
        )

        defer.returnValue((200, resp))

@@ -1223,7 +1211,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
            raise SynapseError(400, "role_id cannot be empty string")

        resp = yield self.handler.delete_group_role(
            group_id, requester_user_id, role_id,
            group_id, requester_user_id, role_id
        )

        defer.returnValue((200, resp))

@@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
        - /groups/:group/summary/users/:user_id
        - /groups/:group/summary/roles/:role/users/:user_id
    """

    PATH = (
        "/groups/(?P<group_id>[^/]*)/summary"
        "(/roles/(?P<role_id>[^/]+))?"

@@ -1252,7 +1241,8 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
            raise SynapseError(400, "role_id cannot be empty string")

        resp = yield self.handler.update_group_summary_user(
            group_id, requester_user_id,
            group_id,
            requester_user_id,
            user_id=user_id,
            role_id=role_id,
            content=content,

@@ -1270,9 +1260,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
            raise SynapseError(400, "role_id cannot be empty string")

        resp = yield self.handler.delete_group_summary_user(
            group_id, requester_user_id,
            user_id=user_id,
            role_id=role_id,
            group_id, requester_user_id, user_id=user_id, role_id=role_id
        )

        defer.returnValue((200, resp))

@@ -1281,14 +1269,13 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
    """Get roles in a group
    """
    PATH = (
        "/get_groups_publicised"
    )

    PATH = "/get_groups_publicised"

    @defer.inlineCallbacks
    def on_POST(self, origin, content, query):
        resp = yield self.handler.bulk_get_publicised_groups(
            content["user_ids"], proxy=False,
            content["user_ids"], proxy=False
        )

        defer.returnValue((200, resp))

@@ -1297,6 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
    """Sets whether a group is joinable without an invite or knock
    """

    PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"

    @defer.inlineCallbacks

@@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet):
    Indicates to other servers how complex (and therefore likely
    resource-intensive) a public room this server knows about is.
    """

    PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
    PREFIX = FEDERATION_UNSTABLE_PREFIX

@@ -1325,9 +1314,7 @@ class RoomComplexityServlet(BaseFederationServlet):

        store = self.handler.hs.get_datastore()

        is_public = yield store.is_room_world_readable_or_publicly_joinable(
            room_id
        )
        is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id)

        if not is_public:
            raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)

@@ -1362,13 +1349,9 @@ FEDERATION_SERVLET_CLASSES = (
    RoomComplexityServlet,
)

OPENID_SERVLET_CLASSES = (
    OpenIdUserInfo,
)
OPENID_SERVLET_CLASSES = (OpenIdUserInfo,)

ROOM_LIST_CLASSES = (
    PublicRoomList,
)
ROOM_LIST_CLASSES = (PublicRoomList,)

GROUP_SERVER_SERVLET_CLASSES = (
    FederationGroupsProfileServlet,

@@ -1399,9 +1382,7 @@ GROUP_LOCAL_SERVLET_CLASSES = (
)


GROUP_ATTESTATION_SERVLET_CLASSES = (
    FederationGroupsRenewAttestaionServlet,
)
GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,)

DEFAULT_SERVLET_GROUPS = (
    "federation",

@@ -1455,7 +1436,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
            authenticator=authenticator,
            ratelimiter=ratelimiter,
            server_name=hs.hostname,
            deny_access=hs.config.restrict_public_rooms_to_local_users,
            allow_access=hs.config.allow_public_rooms_over_federation,
        ).register(resource)

        if "group_server" in servlet_groups:

@@ -32,21 +32,11 @@ class Edu(JsonEncodedObject):
        internal ID or previous references graph.
    """

    valid_keys = [
        "origin",
        "destination",
        "edu_type",
        "content",
    ]
    valid_keys = ["origin", "destination", "edu_type", "content"]

    required_keys = [
        "edu_type",
    ]
    required_keys = ["edu_type"]

    internal_keys = [
        "origin",
        "destination",
    ]
    internal_keys = ["origin", "destination"]


class Transaction(JsonEncodedObject):

@@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject):
        "edus",
    ]

    internal_keys = [
        "transaction_id",
        "destination",
    ]
    internal_keys = ["transaction_id", "destination"]

    required_keys = [
        "transaction_id",

@@ -98,9 +85,7 @@ class Transaction(JsonEncodedObject):
            del kwargs["edus"]

        super(Transaction, self).__init__(
            transaction_id=transaction_id,
            pdus=pdus,
            **kwargs
            transaction_id=transaction_id, pdus=pdus, **kwargs
        )

    @staticmethod

@@ -109,13 +94,9 @@ class Transaction(JsonEncodedObject):
        transaction_id and origin_server_ts keys.
        """
        if "origin_server_ts" not in kwargs:
            raise KeyError(
                "Require 'origin_server_ts' to construct a Transaction"
            )
            raise KeyError("Require 'origin_server_ts' to construct a Transaction")
        if "transaction_id" not in kwargs:
            raise KeyError(
                "Require 'transaction_id' to construct a Transaction"
            )
            raise KeyError("Require 'transaction_id' to construct a Transaction")

        kwargs["pdus"] = [p.get_pdu_json() for p in pdus]

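A small illustration of the Transaction contract enforced by create_new above: both keys must be supplied by the caller before a transaction can be built. This is only a sketch; the field values are invented for the example.

# Sketch of the create_new precondition shown in the diff above.
kwargs = {
    "origin": "example.org",            # illustrative
    "destination": "matrix.org",        # illustrative
    "origin_server_ts": 1560000000000,  # milliseconds since epoch
    "transaction_id": "txn-1",
}
for required in ("origin_server_ts", "transaction_id"):
    if required not in kwargs:
        # Mirrors the KeyError raised by Transaction.create_new
        raise KeyError("Require '%s' to construct a Transaction" % (required,))
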
@@ -42,7 +42,7 @@ from signedjson.sign import sign_json

from twisted.internet import defer

from synapse.api.errors import RequestSendFailed, SynapseError
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.logcontext import run_in_background

@@ -65,6 +65,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning(object):
    """Creates and verifies group attestations.
    """

    def __init__(self, hs):
        self.keyring = hs.get_keyring()
        self.clock = hs.get_clock()

@@ -113,11 +114,15 @@ class GroupAttestationSigning(object):
        validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
        valid_until_ms = int(self.clock.time_msec() + validity_period)

        return sign_json({
            "group_id": group_id,
            "user_id": user_id,
            "valid_until_ms": valid_until_ms,
        }, self.server_name, self.signing_key)
        return sign_json(
            {
                "group_id": group_id,
                "user_id": user_id,
                "valid_until_ms": valid_until_ms,
            },
            self.server_name,
            self.signing_key,
        )


class GroupAttestionRenewer(object):

@@ -132,9 +137,10 @@ class GroupAttestionRenewer(object):
        self.is_mine_id = hs.is_mine_id
        self.attestations = hs.get_groups_attestation_signing()

        self._renew_attestations_loop = self.clock.looping_call(
            self._start_renew_attestations, 30 * 60 * 1000,
        )
        if not hs.config.worker_app:
            self._renew_attestations_loop = self.clock.looping_call(
                self._start_renew_attestations, 30 * 60 * 1000
            )

    @defer.inlineCallbacks
    def on_renew_attestation(self, group_id, user_id, content):

@@ -146,9 +152,7 @@ class GroupAttestionRenewer(object):
            raise SynapseError(400, "Neither user not group are on this server")

        yield self.attestations.verify_attestation(
            attestation,
            user_id=user_id,
            group_id=group_id,
            attestation, user_id=user_id, group_id=group_id
        )

        yield self.store.update_remote_attestion(group_id, user_id, attestation)

@@ -179,7 +183,8 @@ class GroupAttestionRenewer(object):
                else:
                    logger.warn(
                        "Incorrectly trying to do attestations for user: %r in %r",
                        user_id, group_id,
                        user_id,
                        group_id,
                    )
                    yield self.store.remove_attestation_renewal(group_id, user_id)
                    return

@@ -187,21 +192,20 @@ class GroupAttestionRenewer(object):
                attestation = self.attestations.create_attestation(group_id, user_id)

                yield self.transport_client.renew_group_attestation(
                    destination, group_id, user_id,
                    content={"attestation": attestation},
                    destination, group_id, user_id, content={"attestation": attestation}
                )

                yield self.store.update_attestation_renewal(
                    group_id, user_id, attestation
                )
            except RequestSendFailed as e:
            except (RequestSendFailed, HttpResponseException) as e:
                logger.warning(
                    "Failed to renew attestation of %r in %r: %s",
                    user_id, group_id, e,
                    "Failed to renew attestation of %r in %r: %s", user_id, group_id, e
                )
            except Exception:
                logger.exception("Error renewing attestation of %r in %r",
                                 user_id, group_id)
                logger.exception(
                    "Error renewing attestation of %r in %r", user_id, group_id
                )

        for row in rows:
            group_id = row["group_id"]

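For context, the attestation created above is just a signed-JSON object. Here is a self-contained sketch using the signedjson library that this file imports; the IDs, server name, and timestamp are invented for illustration, and key handling in Synapse itself differs (it uses the homeserver's configured signing key).

# Sketch of the attestation payload signed by create_attestation above.
from signedjson.key import generate_signing_key, get_verify_key
from signedjson.sign import sign_json, verify_signed_json

signing_key = generate_signing_key("a_key_id")  # throwaway key for the example
attestation = sign_json(
    {
        "group_id": "+g:example.org",        # illustrative
        "user_id": "@alice:example.org",     # illustrative
        "valid_until_ms": 1560000000000,     # now + jittered validity period
    },
    "example.org",
    signing_key,
)
# The receiving side verifies against the origin server's public key,
# which is roughly what verify_attestation does via the keyring:
verify_signed_json(attestation, "example.org", get_verify_key(signing_key))
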
@@ -54,8 +54,9 @@ class GroupsServerHandler(object):
        hs.get_groups_attestation_renewer()

    @defer.inlineCallbacks
    def check_group_is_ours(self, group_id, requester_user_id,
                            and_exists=False, and_is_admin=None):
    def check_group_is_ours(
        self, group_id, requester_user_id, and_exists=False, and_is_admin=None
    ):
        """Check that the group is ours, and optionally if it exists.

        If group does exist then return group.

@@ -73,7 +74,9 @@ class GroupsServerHandler(object):
        if and_exists and not group:
            raise SynapseError(404, "Unknown group")

        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
        is_user_in_group = yield self.store.is_user_in_group(
            requester_user_id, group_id
        )
        if group and not is_user_in_group and not group["is_public"]:
            raise SynapseError(404, "Unknown group")

@@ -96,25 +99,27 @@ class GroupsServerHandler(object):
        """
        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
        is_user_in_group = yield self.store.is_user_in_group(
            requester_user_id, group_id
        )

        profile = yield self.get_group_profile(group_id, requester_user_id)

        users, roles = yield self.store.get_users_for_summary_by_role(
            group_id, include_private=is_user_in_group,
            group_id, include_private=is_user_in_group
        )

        # TODO: Add profiles to users

        rooms, categories = yield self.store.get_rooms_for_summary_by_category(
            group_id, include_private=is_user_in_group,
            group_id, include_private=is_user_in_group
        )

        for room_entry in rooms:
            room_id = room_entry["room_id"]
            joined_users = yield self.store.get_users_in_room(room_id)
            entry = yield self.room_list_handler.generate_room_entry(
                room_id, len(joined_users), with_alias=False, allow_private=True,
                room_id, len(joined_users), with_alias=False, allow_private=True
            )
            entry = dict(entry)  # so we don't change whats cached
            entry.pop("room_id", None)

@@ -134,7 +139,7 @@ class GroupsServerHandler(object):
                entry["attestation"] = attestation
            else:
                entry["attestation"] = self.attestations.create_attestation(
                    group_id, user_id,
                    group_id, user_id
                )

            user_profile = yield self.profile_handler.get_profile_from_cache(user_id)

@@ -143,34 +148,34 @@ class GroupsServerHandler(object):
        users.sort(key=lambda e: e.get("order", 0))

        membership_info = yield self.store.get_users_membership_info_in_group(
            group_id, requester_user_id,
            group_id, requester_user_id
        )

        defer.returnValue({
            "profile": profile,
            "users_section": {
                "users": users,
                "roles": roles,
                "total_user_count_estimate": 0,  # TODO
            },
            "rooms_section": {
                "rooms": rooms,
                "categories": categories,
                "total_room_count_estimate": 0,  # TODO
            },
            "user": membership_info,
        })
        defer.returnValue(
            {
                "profile": profile,
                "users_section": {
                    "users": users,
                    "roles": roles,
                    "total_user_count_estimate": 0,  # TODO
                },
                "rooms_section": {
                    "rooms": rooms,
                    "categories": categories,
                    "total_room_count_estimate": 0,  # TODO
                },
                "user": membership_info,
            }
        )

    @defer.inlineCallbacks
    def update_group_summary_room(self, group_id, requester_user_id,
                                  room_id, category_id, content):
    def update_group_summary_room(
        self, group_id, requester_user_id, room_id, category_id, content
    ):
        """Add/update a room to the group summary
        """
        yield self.check_group_is_ours(
            group_id,
            requester_user_id,
            and_exists=True,
            and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        RoomID.from_string(room_id)  # Ensure valid room id

@@ -190,21 +195,17 @@ class GroupsServerHandler(object):
        defer.returnValue({})

    @defer.inlineCallbacks
    def delete_group_summary_room(self, group_id, requester_user_id,
                                  room_id, category_id):
    def delete_group_summary_room(
        self, group_id, requester_user_id, room_id, category_id
    ):
        """Remove a room from the summary
        """
        yield self.check_group_is_ours(
            group_id,
            requester_user_id,
            and_exists=True,
            and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        yield self.store.remove_room_from_summary(
            group_id=group_id,
            room_id=room_id,
            category_id=category_id,
            group_id=group_id, room_id=room_id, category_id=category_id
        )

        defer.returnValue({})

@@ -223,9 +224,7 @@ class GroupsServerHandler(object):

        join_policy = _parse_join_policy_from_contents(content)
        if join_policy is None:
            raise SynapseError(
                400, "No value specified for 'm.join_policy'"
            )
            raise SynapseError(400, "No value specified for 'm.join_policy'")

        yield self.store.set_group_join_policy(group_id, join_policy=join_policy)

@@ -237,9 +236,7 @@ class GroupsServerHandler(object):
        """
        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        categories = yield self.store.get_group_categories(
            group_id=group_id,
        )
        categories = yield self.store.get_group_categories(group_id=group_id)
        defer.returnValue({"categories": categories})

    @defer.inlineCallbacks

@@ -249,8 +246,7 @@ class GroupsServerHandler(object):
        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        res = yield self.store.get_group_category(
            group_id=group_id,
            category_id=category_id,
            group_id=group_id, category_id=category_id
        )

        defer.returnValue(res)

@@ -260,10 +256,7 @@ class GroupsServerHandler(object):
        """Add/Update a group category
        """
        yield self.check_group_is_ours(
            group_id,
            requester_user_id,
            and_exists=True,
            and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        is_public = _parse_visibility_from_contents(content)

@@ -283,15 +276,11 @@ class GroupsServerHandler(object):
        """Delete a group category
        """
        yield self.check_group_is_ours(
            group_id,
            requester_user_id,
            and_exists=True,
            and_is_admin=requester_user_id
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        yield self.store.remove_group_category(
            group_id=group_id,
            category_id=category_id,
            group_id=group_id, category_id=category_id
        )

        defer.returnValue({})

@@ -302,9 +291,7 @@ class GroupsServerHandler(object):
        """
        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        roles = yield self.store.get_group_roles(
            group_id=group_id,
        )
        roles = yield self.store.get_group_roles(group_id=group_id)
        defer.returnValue({"roles": roles})

    @defer.inlineCallbacks

@@ -313,10 +300,7 @@ class GroupsServerHandler(object):
        """
        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        res = yield self.store.get_group_role(
            group_id=group_id,
            role_id=role_id,
        )
        res = yield self.store.get_group_role(group_id=group_id, role_id=role_id)
        defer.returnValue(res)

    @defer.inlineCallbacks

@@ -324,10 +308,7 @@ class GroupsServerHandler(object):
        """Add/update a role in a group
        """
        yield self.check_group_is_ours(
            group_id,
            requester_user_id,
            and_exists=True,
            and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        is_public = _parse_visibility_from_contents(content)

@@ -335,10 +316,7 @@ class GroupsServerHandler(object):
        profile = content.get("profile")

        yield self.store.upsert_group_role(
            group_id=group_id,
            role_id=role_id,
            is_public=is_public,
            profile=profile,
            group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
        )

        defer.returnValue({})

@@ -348,26 +326,21 @@ class GroupsServerHandler(object):
        """Remove role from group
        """
        yield self.check_group_is_ours(
            group_id,
            requester_user_id,
            and_exists=True,
            and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        yield self.store.remove_group_role(group_id=group_id, role_id=role_id)

        defer.returnValue({})

    @defer.inlineCallbacks
    def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
                                  content):
    def update_group_summary_user(
        self, group_id, requester_user_id, user_id, role_id, content
    ):
        """Add/update a users entry in the group summary
        """
        yield self.check_group_is_ours(
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        order = content.get("order", None)

@@ -389,13 +362,11 @@ class GroupsServerHandler(object):
        """Remove a user from the group summary
        """
        yield self.check_group_is_ours(
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        yield self.store.remove_user_from_summary(
            group_id=group_id,
            user_id=user_id,
            role_id=role_id,
            group_id=group_id, user_id=user_id, role_id=role_id
        )

        defer.returnValue({})

@@ -411,8 +382,11 @@ class GroupsServerHandler(object):

        if group:
            cols = [
                "name", "short_description", "long_description",
                "avatar_url", "is_public",
                "name",
                "short_description",
                "long_description",
                "avatar_url",
                "is_public",
            ]
            group_description = {key: group[key] for key in cols}
            group_description["is_openly_joinable"] = group["join_policy"] == "open"

@@ -426,12 +400,11 @@ class GroupsServerHandler(object):
        """Update the group profile
        """
        yield self.check_group_is_ours(
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
        )

        profile = {}
        for keyname in ("name", "avatar_url", "short_description",
                        "long_description"):
        for keyname in ("name", "avatar_url", "short_description", "long_description"):
            if keyname in content:
                value = content[keyname]
                if not isinstance(value, string_types):

@@ -449,10 +422,12 @@ class GroupsServerHandler(object):

        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
        is_user_in_group = yield self.store.is_user_in_group(
            requester_user_id, group_id
        )

        user_results = yield self.store.get_users_in_group(
            group_id, include_private=is_user_in_group,
            group_id, include_private=is_user_in_group
        )

        chunk = []

@@ -470,24 +445,25 @@ class GroupsServerHandler(object):
            entry["is_privileged"] = bool(is_privileged)

            if not self.is_mine_id(g_user_id):
                attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
                attestation = yield self.store.get_remote_attestation(
                    group_id, g_user_id
                )
                if not attestation:
                    continue

                entry["attestation"] = attestation
            else:
                entry["attestation"] = self.attestations.create_attestation(
                    group_id, g_user_id,
                    group_id, g_user_id
                )

            chunk.append(entry)

        # TODO: If admin add lists of users whose attestations have timed out

        defer.returnValue({
            "chunk": chunk,
            "total_user_count_estimate": len(user_results),
        })
        defer.returnValue(
            {"chunk": chunk, "total_user_count_estimate": len(user_results)}
        )

    @defer.inlineCallbacks
    def get_invited_users_in_group(self, group_id, requester_user_id):

@@ -498,7 +474,9 @@ class GroupsServerHandler(object):

        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
        is_user_in_group = yield self.store.is_user_in_group(
            requester_user_id, group_id
        )

        if not is_user_in_group:
            raise SynapseError(403, "User not in group")

@@ -508,9 +486,7 @@ class GroupsServerHandler(object):
        user_profiles = []

        for user_id in invited_users:
            user_profile = {
                "user_id": user_id
            }
            user_profile = {"user_id": user_id}
            try:
                profile = yield self.profile_handler.get_profile_from_cache(user_id)
                user_profile.update(profile)

@@ -518,10 +494,9 @@ class GroupsServerHandler(object):
                logger.warn("Error getting profile for %s: %s", user_id, e)
            user_profiles.append(user_profile)

        defer.returnValue({
            "chunk": user_profiles,
            "total_user_count_estimate": len(invited_users),
        })
        defer.returnValue(
            {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
        )

    @defer.inlineCallbacks
    def get_rooms_in_group(self, group_id, requester_user_id):

@@ -532,10 +507,12 @@ class GroupsServerHandler(object):

        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
        is_user_in_group = yield self.store.is_user_in_group(
            requester_user_id, group_id
        )

        room_results = yield self.store.get_rooms_in_group(
            group_id, include_private=is_user_in_group,
            group_id, include_private=is_user_in_group
        )

        chunk = []

@@ -544,7 +521,7 @@ class GroupsServerHandler(object):

            joined_users = yield self.store.get_users_in_room(room_id)
            entry = yield self.room_list_handler.generate_room_entry(
                room_id, len(joined_users), with_alias=False, allow_private=True,
                room_id, len(joined_users), with_alias=False, allow_private=True
            )

            if not entry:

@@ -556,10 +533,9 @@ class GroupsServerHandler(object):

        chunk.sort(key=lambda e: -e["num_joined_members"])

        defer.returnValue({
            "chunk": chunk,
            "total_room_count_estimate": len(room_results),
        })
        defer.returnValue(
            {"chunk": chunk, "total_room_count_estimate": len(room_results)}
        )

    @defer.inlineCallbacks
    def add_room_to_group(self, group_id, requester_user_id, room_id, content):

@@ -578,8 +554,9 @@ class GroupsServerHandler(object):
        defer.returnValue({})

    @defer.inlineCallbacks
    def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
                             content):
    def update_room_in_group(
        self, group_id, requester_user_id, room_id, config_key, content
    ):
        """Update room in group
        """
        RoomID.from_string(room_id)  # Ensure valid room id

@@ -592,8 +569,7 @@ class GroupsServerHandler(object):
            is_public = _parse_visibility_dict(content)

            yield self.store.update_room_in_group_visibility(
                group_id, room_id,
                is_public=is_public,
                group_id, room_id, is_public=is_public
            )
        else:
            raise SynapseError(400, "Uknown config option")

@@ -625,10 +601,7 @@ class GroupsServerHandler(object):
        # TODO: Check if user is already invited

        content = {
            "profile": {
                "name": group["name"],
                "avatar_url": group["avatar_url"],
            },
            "profile": {"name": group["name"], "avatar_url": group["avatar_url"]},
            "inviter": requester_user_id,
        }

@@ -638,9 +611,7 @@ class GroupsServerHandler(object):
            local_attestation = None
        else:
            local_attestation = self.attestations.create_attestation(group_id, user_id)
            content.update({
                "attestation": local_attestation,
            })
            content.update({"attestation": local_attestation})

        res = yield self.transport_client.invite_to_group_notification(
            get_domain_from_id(user_id), group_id, user_id, content

@@ -658,31 +629,24 @@ class GroupsServerHandler(object):
                remote_attestation = res["attestation"]

                yield self.attestations.verify_attestation(
                    remote_attestation,
                    user_id=user_id,
                    group_id=group_id,
                    remote_attestation, user_id=user_id, group_id=group_id
                )
            else:
                remote_attestation = None

            yield self.store.add_user_to_group(
                group_id, user_id,
                group_id,
                user_id,
                is_admin=False,
                is_public=False,  # TODO
                local_attestation=local_attestation,
                remote_attestation=remote_attestation,
            )
        elif res["state"] == "invite":
            yield self.store.add_group_invite(
                group_id, user_id,
            )
            defer.returnValue({
                "state": "invite"
            })
            yield self.store.add_group_invite(group_id, user_id)
            defer.returnValue({"state": "invite"})
        elif res["state"] == "reject":
            defer.returnValue({
                "state": "reject"
            })
            defer.returnValue({"state": "reject"})
        else:
            raise SynapseError(502, "Unknown state returned by HS")

@@ -693,16 +657,12 @@ class GroupsServerHandler(object):
        See accept_invite, join_group.
        """
        if not self.hs.is_mine_id(user_id):
            local_attestation = self.attestations.create_attestation(
                group_id, user_id,
            )
            local_attestation = self.attestations.create_attestation(group_id, user_id)

            remote_attestation = content["attestation"]

            yield self.attestations.verify_attestation(
                remote_attestation,
                user_id=user_id,
                group_id=group_id,
                remote_attestation, user_id=user_id, group_id=group_id
            )
        else:
            local_attestation = None

@@ -711,7 +671,8 @@ class GroupsServerHandler(object):

        is_public = _parse_visibility_from_contents(content)

        yield self.store.add_user_to_group(
            group_id, user_id,
            group_id,
            user_id,
            is_admin=False,
            is_public=is_public,
            local_attestation=local_attestation,

@@ -731,17 +692,14 @@ class GroupsServerHandler(object):
        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

        is_invited = yield self.store.is_user_invited_to_local_group(
|
||||
group_id, requester_user_id,
|
||||
group_id, requester_user_id
|
||||
)
|
||||
if not is_invited:
|
||||
raise SynapseError(403, "User not invited to group")
|
||||
|
||||
local_attestation = yield self._add_user(group_id, requester_user_id, content)
|
||||
|
||||
defer.returnValue({
|
||||
"state": "join",
|
||||
"attestation": local_attestation,
|
||||
})
|
||||
defer.returnValue({"state": "join", "attestation": local_attestation})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def join_group(self, group_id, requester_user_id, content):
|
||||
|
@ -753,15 +711,12 @@ class GroupsServerHandler(object):
|
|||
group_info = yield self.check_group_is_ours(
|
||||
group_id, requester_user_id, and_exists=True
|
||||
)
|
||||
if group_info['join_policy'] != "open":
|
||||
if group_info["join_policy"] != "open":
|
||||
raise SynapseError(403, "Group is not publicly joinable")
|
||||
|
||||
local_attestation = yield self._add_user(group_id, requester_user_id, content)
|
||||
|
||||
defer.returnValue({
|
||||
"state": "join",
|
||||
"attestation": local_attestation,
|
||||
})
|
||||
defer.returnValue({"state": "join", "attestation": local_attestation})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def knock(self, group_id, requester_user_id, content):
|
||||
|
@ -800,9 +755,7 @@ class GroupsServerHandler(object):
|
|||
|
||||
is_kick = True
|
||||
|
||||
yield self.store.remove_user_from_group(
|
||||
group_id, user_id,
|
||||
)
|
||||
yield self.store.remove_user_from_group(group_id, user_id)
|
||||
|
||||
if is_kick:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
|
@ -830,19 +783,20 @@ class GroupsServerHandler(object):
|
|||
if group:
|
||||
raise SynapseError(400, "Group already exists")
|
||||
|
||||
is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
|
||||
is_admin = yield self.auth.is_server_admin(
|
||||
UserID.from_string(requester_user_id)
|
||||
)
|
||||
if not is_admin:
|
||||
if not self.hs.config.enable_group_creation:
|
||||
raise SynapseError(
|
||||
403, "Only a server admin can create groups on this server",
|
||||
403, "Only a server admin can create groups on this server"
|
||||
)
|
||||
localpart = group_id_obj.localpart
|
||||
if not localpart.startswith(self.hs.config.group_creation_prefix):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Can only create groups with prefix %r on this server" % (
|
||||
self.hs.config.group_creation_prefix,
|
||||
),
|
||||
"Can only create groups with prefix %r on this server"
|
||||
% (self.hs.config.group_creation_prefix,),
|
||||
)
|
||||
|
||||
profile = content.get("profile", {})
|
||||
|
@ -865,21 +819,19 @@ class GroupsServerHandler(object):
|
|||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=requester_user_id,
|
||||
group_id=group_id,
|
||||
remote_attestation, user_id=requester_user_id, group_id=group_id
|
||||
)
|
||||
|
||||
local_attestation = self.attestations.create_attestation(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
group_id, requester_user_id
|
||||
)
|
||||
else:
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, requester_user_id,
|
||||
group_id,
|
||||
requester_user_id,
|
||||
is_admin=True,
|
||||
is_public=True, # TODO
|
||||
local_attestation=local_attestation,
|
||||
|
@ -893,9 +845,7 @@ class GroupsServerHandler(object):
|
|||
avatar_url=user_profile.get("avatar_url"),
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"group_id": group_id,
|
||||
})
|
||||
defer.returnValue({"group_id": group_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group(self, group_id, requester_user_id):
|
||||
|
@ -911,29 +861,22 @@ class GroupsServerHandler(object):
|
|||
Deferred
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(
|
||||
group_id, requester_user_id,
|
||||
and_exists=True,
|
||||
)
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
# Only server admins or group admins can delete groups.
|
||||
|
||||
is_admin = yield self.store.is_user_admin_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
|
||||
|
||||
if not is_admin:
|
||||
is_admin = yield self.auth.is_server_admin(
|
||||
UserID.from_string(requester_user_id),
|
||||
UserID.from_string(requester_user_id)
|
||||
)
|
||||
|
||||
if not is_admin:
|
||||
raise SynapseError(403, "User is not an admin")
|
||||
|
||||
# Before deleting the group lets kick everyone out of it
|
||||
users = yield self.store.get_users_in_group(
|
||||
group_id, include_private=True,
|
||||
)
|
||||
users = yield self.store.get_users_in_group(group_id, include_private=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _kick_user_from_group(user_id):
|
||||
|
@ -989,9 +932,7 @@ def _parse_join_policy_dict(join_policy_dict):
|
|||
return "invite"
|
||||
|
||||
if join_policy_type not in ("invite", "open"):
|
||||
raise SynapseError(
|
||||
400, "Synapse only supports 'invite'/'open' join rule"
|
||||
)
|
||||
raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule")
|
||||
return join_policy_type
|
||||
|
||||
|
||||
|
@ -1018,7 +959,5 @@ def _parse_visibility_dict(visibility):
|
|||
return True
|
||||
|
||||
if vis_type not in ("public", "private"):
|
||||
raise SynapseError(
|
||||
400, "Synapse only supports 'public'/'private' visibility"
|
||||
)
|
||||
raise SynapseError(400, "Synapse only supports 'public'/'private' visibility")
|
||||
return vis_type == "public"
|
||||
|
|
|
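The two module-level helpers in the hunks above are small, total parsers: an absent or empty dict falls back to a safe default ("invite"-only joining, public visibility), and anything outside the supported vocabulary is rejected with a 400. A standalone sketch of that behaviour, with SynapseError stubbed so the snippet runs outside the homeserver (defaults inferred from the hunks, not a verbatim copy):

class SynapseError(Exception):
    """Stub of Synapse's error type, just enough for this sketch."""

    def __init__(self, code, msg):
        super().__init__(msg)
        self.code = code


def parse_join_policy(join_policy_dict):
    join_policy_type = join_policy_dict.get("type")
    if not join_policy_type:
        return "invite"  # default: members join by invitation only
    if join_policy_type not in ("invite", "open"):
        raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule")
    return join_policy_type


def parse_visibility(visibility):
    vis_type = visibility.get("type")
    if not vis_type:
        return True  # default: public
    if vis_type not in ("public", "private"):
        raise SynapseError(400, "Synapse only supports 'public'/'private' visibility")
    return vis_type == "public"


assert parse_join_policy({}) == "invite"
assert parse_visibility({"type": "private"}) is False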
@ -94,14 +94,15 @@ class BaseHandler(object):
            burst_count = self.hs.config.rc_message.burst_count

        allowed, time_allowed = self.ratelimiter.can_do_action(
            user_id, time_now,
            user_id,
            time_now,
            rate_hz=messages_per_second,
            burst_count=burst_count,
            update=update,
        )
        if not allowed:
            raise LimitExceededError(
                retry_after_ms=int(1000 * (time_allowed - time_now)),
                retry_after_ms=int(1000 * (time_allowed - time_now))
            )

    @defer.inlineCallbacks

@ -139,7 +140,7 @@ class BaseHandler(object):

                if member_event.content["membership"] not in {
                    Membership.JOIN,
                    Membership.INVITE
                    Membership.INVITE,
                }:
                    continue

@ -156,8 +157,7 @@ class BaseHandler(object):
                # and having homeservers have their own users leave keeps more
                # of that decision-making and control local to the guest-having
                # homeserver.
                requester = synapse.types.create_requester(
                    target_user, is_guest=True)
                requester = synapse.types.create_requester(target_user, is_guest=True)
                handler = self.hs.get_room_member_handler()
                yield handler.update_membership(
                    requester,
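The ratelimit hunk above drives a token-bucket style interface: can_do_action returns whether the action is allowed now, plus the time at which it next would be, which the caller converts into a retry_after_ms. A minimal, self-contained sketch of that interface (illustrative only; Synapse's real Ratelimiter lives in synapse.api.ratelimiting and differs in detail):

import time


class TokenBucketRatelimiter:
    """Allow up to `burst_count` actions, refilled at `rate_hz` tokens/second."""

    def __init__(self):
        self.buckets = {}  # key -> (remaining_tokens, last_update_s)

    def can_do_action(self, key, time_now_s=None, rate_hz=0.17, burst_count=10, update=True):
        now = time.time() if time_now_s is None else time_now_s
        tokens, last = self.buckets.get(key, (float(burst_count), now))
        # refill since the last observation, capped at the burst size
        tokens = min(float(burst_count), tokens + (now - last) * rate_hz)
        allowed = tokens >= 1.0
        if allowed and update:
            tokens -= 1.0
        if update:
            self.buckets[key] = (tokens, now)
        # when the next action would be permitted (now, if already allowed)
        time_allowed = now if allowed else now + (1.0 - tokens) / rate_hz
        return allowed, time_allowed


limiter = TokenBucketRatelimiter()
allowed, time_allowed = limiter.can_do_action("@user:example.com", rate_hz=1.0, burst_count=3)
retry_after_ms = int(1000 * (time_allowed - time.time()))  # mirrors the hunk above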
@ -20,7 +20,7 @@ class AccountDataEventSource(object):
    def __init__(self, hs):
        self.store = hs.get_datastore()

    def get_current_key(self, direction='f'):
    def get_current_key(self, direction="f"):
        return self.store.get_max_account_data_stream_id()

    @defer.inlineCallbacks

@ -34,29 +34,22 @@ class AccountDataEventSource(object):
        tags = yield self.store.get_updated_tags(user_id, last_stream_id)

        for room_id, room_tags in tags.items():
            results.append({
                "type": "m.tag",
                "content": {"tags": room_tags},
                "room_id": room_id,
            })
            results.append(
                {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
            )

        account_data, room_account_data = (
            yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
        )

        for account_data_type, content in account_data.items():
            results.append({
                "type": account_data_type,
                "content": content,
            })
            results.append({"type": account_data_type, "content": content})

        for room_id, account_data in room_account_data.items():
            for account_data_type, content in account_data.items():
                results.append({
                    "type": account_data_type,
                    "content": content,
                    "room_id": room_id,
                })
                results.append(
                    {"type": account_data_type, "content": content, "room_id": room_id}
                )

        defer.returnValue((results, current_stream_id))
@ -49,12 +49,10 @@ class AccountValidityHandler(object):
                app_name = self.hs.config.email_app_name

                self._subject = self._account_validity.renew_email_subject % {
                    "app": app_name,
                    "app": app_name
                }

                self._from_string = self.hs.config.email_notif_from % {
                    "app": app_name,
                }
                self._from_string = self.hs.config.email_notif_from % {"app": app_name}
            except Exception:
                # If substitution failed, fall back to the bare strings.
                self._subject = self._account_validity.renew_email_subject

@ -69,10 +67,7 @@ class AccountValidityHandler(object):
            )

            # Check the renewal emails to send and send them every 30min.
            self.clock.looping_call(
                self.send_renewal_emails,
                30 * 60 * 1000,
            )
            self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000)

    @defer.inlineCallbacks
    def send_renewal_emails(self):

@ -86,8 +81,7 @@ class AccountValidityHandler(object):
        if expiring_users:
            for user in expiring_users:
                yield self._send_renewal_email(
                    user_id=user["user_id"],
                    expiration_ts=user["expiration_ts_ms"],
                    user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
                )

    @defer.inlineCallbacks

@ -110,6 +104,9 @@ class AccountValidityHandler(object):
        # Stop right here if the user doesn't have at least one email address.
        # In this case, they will have to ask their server admin to renew their
        # account manually.
        # We don't need to do a specific check to make sure the account isn't
        # deactivated, as a deactivated account isn't supposed to have any
        # email address attached to it.
        if not addresses:
            return

@ -143,32 +140,33 @@ class AccountValidityHandler(object):
        for address in addresses:
            raw_to = email.utils.parseaddr(address)[1]

            multipart_msg = MIMEMultipart('alternative')
            multipart_msg['Subject'] = self._subject
            multipart_msg['From'] = self._from_string
            multipart_msg['To'] = address
            multipart_msg['Date'] = email.utils.formatdate()
            multipart_msg['Message-ID'] = email.utils.make_msgid()
            multipart_msg = MIMEMultipart("alternative")
            multipart_msg["Subject"] = self._subject
            multipart_msg["From"] = self._from_string
            multipart_msg["To"] = address
            multipart_msg["Date"] = email.utils.formatdate()
            multipart_msg["Message-ID"] = email.utils.make_msgid()
            multipart_msg.attach(text_part)
            multipart_msg.attach(html_part)

            logger.info("Sending renewal email to %s", address)

            yield make_deferred_yieldable(self.sendmail(
                self.hs.config.email_smtp_host,
                self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
                reactor=self.hs.get_reactor(),
                port=self.hs.config.email_smtp_port,
                requireAuthentication=self.hs.config.email_smtp_user is not None,
                username=self.hs.config.email_smtp_user,
                password=self.hs.config.email_smtp_pass,
                requireTransportSecurity=self.hs.config.require_transport_security
            ))
            yield make_deferred_yieldable(
                self.sendmail(
                    self.hs.config.email_smtp_host,
                    self._raw_from,
                    raw_to,
                    multipart_msg.as_string().encode("utf8"),
                    reactor=self.hs.get_reactor(),
                    port=self.hs.config.email_smtp_port,
                    requireAuthentication=self.hs.config.email_smtp_user is not None,
                    username=self.hs.config.email_smtp_user,
                    password=self.hs.config.email_smtp_pass,
                    requireTransportSecurity=self.hs.config.require_transport_security,
                )
            )

        yield self.store.set_renewal_mail_status(
            user_id=user_id,
            email_sent=True,
        )
        yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)

    @defer.inlineCallbacks
    def _get_email_addresses_for_user(self, user_id):

@ -245,9 +243,7 @@ class AccountValidityHandler(object):
        expiration_ts = self.clock.time_msec() + self._account_validity.period

        yield self.store.set_account_validity_for_user(
            user_id=user_id,
            expiration_ts=expiration_ts,
            email_sent=email_sent,
            user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
        )

        defer.returnValue(expiration_ts)
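The renewal-email hunk above assembles a multipart/alternative message by hand: plain-text and HTML parts attached in ascending order of preference, with standard Date and Message-ID headers. A self-contained sketch of the same construction using only the standard library (the addresses, subject, and bodies are placeholder values, not Synapse's configuration):

import email.utils
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

text_part = MIMEText("Please renew your account.", "plain")
html_part = MIMEText("<p>Please renew your account.</p>", "html")

multipart_msg = MIMEMultipart("alternative")
multipart_msg["Subject"] = "Renew your account"
multipart_msg["From"] = "Example Server <noreply@example.com>"
multipart_msg["To"] = "user@example.com"
multipart_msg["Date"] = email.utils.formatdate()
multipart_msg["Message-ID"] = email.utils.make_msgid()
multipart_msg.attach(text_part)  # fallback for plain-text mail clients
multipart_msg.attach(html_part)  # preferred rich rendering, attached last

raw_bytes = multipart_msg.as_string().encode("utf8")  # what gets handed to sendmail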
@ -15,14 +15,9 @@

import logging

import attr
from zope.interface import implementer

import twisted
import twisted.internet.error
from twisted.internet import defer
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource

@ -30,27 +25,6 @@ from synapse.app import check_bind_error

logger = logging.getLogger(__name__)

try:
    from txacme.interfaces import ICertificateStore

    @attr.s
    @implementer(ICertificateStore)
    class ErsatzStore(object):
        """
        A store that only stores in memory.
        """

        certs = attr.ib(default=attr.Factory(dict))

        def store(self, server_name, pem_objects):
            self.certs[server_name] = [o.as_bytes() for o in pem_objects]
            return defer.succeed(None)


except ImportError:
    # txacme is missing
    pass


class AcmeHandler(object):
    def __init__(self, hs):

@ -60,6 +34,7 @@ class AcmeHandler(object):

    @defer.inlineCallbacks
    def start_listening(self):
        from synapse.handlers import acme_issuing_service

        # Configure logging for txacme, if you need to debug
        # from eliot import add_destinations

@ -67,50 +42,27 @@ class AcmeHandler(object):
        #
        # add_destinations(TwistedDestination())

        from txacme.challenges import HTTP01Responder
        from txacme.service import AcmeIssuingService
        from txacme.endpoint import load_or_create_client_key
        from txacme.client import Client
        from josepy.jwa import RS256
        well_known = Resource()

        self._store = ErsatzStore()
        responder = HTTP01Responder()

        self._issuer = AcmeIssuingService(
            cert_store=self._store,
            client_creator=(
                lambda: Client.from_url(
                    reactor=self.reactor,
                    url=URL.from_text(self.hs.config.acme_url),
                    key=load_or_create_client_key(
                        FilePath(self.hs.config.config_dir_path)
                    ),
                    alg=RS256,
                )
            ),
            clock=self.reactor,
            responders=[responder],
        self._issuer = acme_issuing_service.create_issuing_service(
            self.reactor,
            acme_url=self.hs.config.acme_url,
            account_key_file=self.hs.config.acme_account_key_file,
            well_known_resource=well_known,
        )

        well_known = Resource()
        well_known.putChild(b'acme-challenge', responder.resource)
        responder_resource = Resource()
        responder_resource.putChild(b'.well-known', well_known)
        responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))

        responder_resource.putChild(b".well-known", well_known)
        responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain"))
        srv = server.Site(responder_resource)

        bind_addresses = self.hs.config.acme_bind_addresses
        for host in bind_addresses:
            logger.info(
                "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port,
                "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
            )
            try:
                self.reactor.listenTCP(
                    self.hs.config.acme_port,
                    srv,
                    interface=host,
                )
                self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
            except twisted.internet.error.CannotListenError as e:
                check_bind_error(e, host, bind_addresses)

@ -132,7 +84,7 @@ class AcmeHandler(object):
            logger.exception("Fail!")
            raise
        logger.warning("Reprovisioned %s, saving.", self._acme_domain)
        cert_chain = self._store.certs[self._acme_domain]
        cert_chain = self._issuer.cert_store.certs[self._acme_domain]

        try:
            with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
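Beyond the Black reformat, this file's diff moves every unconditional txacme/cryptography import out into the new acme_issuing_service module (next file), so acme.py itself stays importable on servers without those optional dependencies and only pulls them in from start_listening. A generic sketch of that probe-then-lazy-import pattern (the module name is illustrative, not Synapse's):

import importlib


def load_optional_module(name):
    """Return the named module if it is importable, else None (feature stays off)."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# e.g. only start the ACME machinery when its heavy dependencies are present:
if load_optional_module("txacme") is not None:
    pass  # safe to import the issuing-service module and start listening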
117  synapse/handlers/acme_issuing_service.py  Normal file

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility function to create an ACME issuing service.

This file contains the unconditional imports on the acme and cryptography bits that we
only need (and may only have available) if we are doing ACME, so is designed to be
imported conditionally.
"""
import logging

import attr
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
from josepy.jwa import RS256
from txacme.challenges import HTTP01Responder
from txacme.client import Client
from txacme.interfaces import ICertificateStore
from txacme.service import AcmeIssuingService
from txacme.util import generate_private_key
from zope.interface import implementer

from twisted.internet import defer
from twisted.python.filepath import FilePath
from twisted.python.url import URL

logger = logging.getLogger(__name__)


def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
    """Create an ACME issuing service, and attach it to a web Resource

    Args:
        reactor: twisted reactor
        acme_url (str): URL to use to request certificates
        account_key_file (str): where to store the account key
        well_known_resource (twisted.web.IResource): web resource for .well-known.
            we will attach a child resource for "acme-challenge".

    Returns:
        AcmeIssuingService
    """
    responder = HTTP01Responder()

    well_known_resource.putChild(b"acme-challenge", responder.resource)

    store = ErsatzStore()

    return AcmeIssuingService(
        cert_store=store,
        client_creator=(
            lambda: Client.from_url(
                reactor=reactor,
                url=URL.from_text(acme_url),
                key=load_or_create_client_key(account_key_file),
                alg=RS256,
            )
        ),
        clock=reactor,
        responders=[responder],
    )


@attr.s
@implementer(ICertificateStore)
class ErsatzStore(object):
    """
    A store that only stores in memory.
    """

    certs = attr.ib(default=attr.Factory(dict))

    def store(self, server_name, pem_objects):
        self.certs[server_name] = [o.as_bytes() for o in pem_objects]
        return defer.succeed(None)


def load_or_create_client_key(key_file):
    """Load the ACME account key from a file, creating it if it does not exist.

    Args:
        key_file (str): name of the file to use as the account key
    """
    # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
    # hardcode the 'client.key' filename
    acme_key_file = FilePath(key_file)
    if acme_key_file.exists():
        logger.info("Loading ACME account key from '%s'", acme_key_file)
        key = serialization.load_pem_private_key(
            acme_key_file.getContent(), password=None, backend=default_backend()
        )
    else:
        logger.info("Saving new ACME account key to '%s'", acme_key_file)
        key = generate_private_key("rsa")
        acme_key_file.setContent(
            key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )
    return JWKRSA(key=key)
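load_or_create_client_key above persists the ACME account key as unencrypted PEM and reloads it across restarts, so the server keeps the same Let's Encrypt account identity. A standalone sketch of the same load-or-create logic using only the `cryptography` package (txacme's FilePath/JWKRSA wrappers and key helper are omitted; the file path is a placeholder):

import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


def load_or_create_key(path):
    """Load an RSA private key from `path`, generating and saving one if absent."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return serialization.load_pem_private_key(
                f.read(), password=None, backend=default_backend()
            )
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    with open(path, "wb") as f:
        f.write(
            key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )
    return key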
@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)


class AdminHandler(BaseHandler):

    def __init__(self, hs):
        super(AdminHandler, self).__init__(hs)

@ -33,23 +32,17 @@ class AdminHandler(BaseHandler):

        sessions = yield self.store.get_user_ip_and_agents(user)
        for session in sessions:
            connections.append({
                "ip": session["ip"],
                "last_seen": session["last_seen"],
                "user_agent": session["user_agent"],
            })
            connections.append(
                {
                    "ip": session["ip"],
                    "last_seen": session["last_seen"],
                    "user_agent": session["user_agent"],
                }
            )

        ret = {
            "user_id": user.to_string(),
            "devices": {
                "": {
                    "sessions": [
                        {
                            "connections": connections,
                        }
                    ]
                },
            },
            "devices": {"": {"sessions": [{"connections": connections}]}},
        }

        defer.returnValue(ret)
@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed


class ApplicationServicesHandler(object):

    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.is_mine_id = hs.is_mine_id

@ -101,9 +100,10 @@ class ApplicationServicesHandler(object):
                            yield self._check_user_exists(event.state_key)

                        if not self.started_scheduler:

                            def start_scheduler():
                                return self.scheduler.start().addErrback(
                                    log_failure, "Application Services Failure",
                                    log_failure, "Application Services Failure"
                                )

                            run_as_background_process("as_scheduler", start_scheduler)

@ -118,10 +118,15 @@ class ApplicationServicesHandler(object):
                        for event in events:
                            yield handle_event(event)

                    yield make_deferred_yieldable(defer.gatherResults([
                        run_in_background(handle_room_events, evs)
                        for evs in itervalues(events_by_room)
                    ], consumeErrors=True))
                    yield make_deferred_yieldable(
                        defer.gatherResults(
                            [
                                run_in_background(handle_room_events, evs)
                                for evs in itervalues(events_by_room)
                            ],
                            consumeErrors=True,
                        )
                    )

                    yield self.store.set_appservice_last_pos(upper_bound)

@ -129,20 +134,23 @@ class ApplicationServicesHandler(object):
                    ts = yield self.store.get_received_ts(events[-1].event_id)

                    synapse.metrics.event_processing_positions.labels(
                        "appservice_sender").set(upper_bound)
                        "appservice_sender"
                    ).set(upper_bound)

                    events_processed_counter.inc(len(events))

                    event_processing_loop_room_count.labels(
                        "appservice_sender"
                    ).inc(len(events_by_room))
                    event_processing_loop_room_count.labels("appservice_sender").inc(
                        len(events_by_room)
                    )

                    event_processing_loop_counter.labels("appservice_sender").inc()

                    synapse.metrics.event_processing_lag.labels(
                        "appservice_sender").set(now - ts)
                        "appservice_sender"
                    ).set(now - ts)
                    synapse.metrics.event_processing_last_ts.labels(
                        "appservice_sender").set(ts)
                        "appservice_sender"
                    ).set(ts)
            finally:
                self.is_processing = False

@ -155,13 +163,9 @@ class ApplicationServicesHandler(object):
        Returns:
            True if this user exists on at least one application service.
        """
        user_query_services = yield self._get_services_for_user(
            user_id=user_id
        )
        user_query_services = yield self._get_services_for_user(user_id=user_id)
        for user_service in user_query_services:
            is_known_user = yield self.appservice_api.query_user(
                user_service, user_id
            )
            is_known_user = yield self.appservice_api.query_user(user_service, user_id)
            if is_known_user:
                defer.returnValue(True)
        defer.returnValue(False)

@ -179,9 +183,7 @@ class ApplicationServicesHandler(object):
        room_alias_str = room_alias.to_string()
        services = self.store.get_app_services()
        alias_query_services = [
            s for s in services if (
                s.is_interested_in_alias(room_alias_str)
            )
            s for s in services if (s.is_interested_in_alias(room_alias_str))
        ]
        for alias_service in alias_query_services:
            is_known_alias = yield self.appservice_api.query_alias(

@ -189,22 +191,24 @@ class ApplicationServicesHandler(object):
            )
            if is_known_alias:
                # the alias exists now so don't query more ASes.
                result = yield self.store.get_association_from_room_alias(
                    room_alias
                )
                result = yield self.store.get_association_from_room_alias(room_alias)
                defer.returnValue(result)

    @defer.inlineCallbacks
    def query_3pe(self, kind, protocol, fields):
        services = yield self._get_services_for_3pn(protocol)

        results = yield make_deferred_yieldable(defer.DeferredList([
            run_in_background(
                self.appservice_api.query_3pe,
                service, kind, protocol, fields,
            )
            for service in services
        ], consumeErrors=True))
        results = yield make_deferred_yieldable(
            defer.DeferredList(
                [
                    run_in_background(
                        self.appservice_api.query_3pe, service, kind, protocol, fields
                    )
                    for service in services
                ],
                consumeErrors=True,
            )
        )

        ret = []
        for (success, result) in results:

@ -276,18 +280,12 @@ class ApplicationServicesHandler(object):

    def _get_services_for_user(self, user_id):
        services = self.store.get_app_services()
        interested_list = [
            s for s in services if (
                s.is_interested_in_user(user_id)
            )
        ]
        interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
        return defer.succeed(interested_list)

    def _get_services_for_3pn(self, protocol):
        services = self.store.get_app_services()
        interested_list = [
            s for s in services if s.is_interested_in_protocol(protocol)
        ]
        interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
        return defer.succeed(interested_list)

    @defer.inlineCallbacks
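Several of the hunks above reshape the same fan-out idiom: start one Deferred per room (or per application service), then wait on all of them with consumeErrors=True so that one failure is captured in the results rather than cancelling or masking the rest. A stripped-down sketch of the pattern (Synapse's run_in_background / make_deferred_yieldable log-context wrappers are deliberately omitted; the handler is a stand-in):

from twisted.internet import defer


def handle_room_events(room_id, events):
    # Stand-in for the real per-room handler; returns a fired Deferred.
    return defer.succeed((room_id, len(events)))


@defer.inlineCallbacks
def process(events_by_room):
    # One Deferred per room, all running "concurrently" from the reactor's view.
    results = yield defer.gatherResults(
        [
            handle_room_events(room_id, evs)
            for room_id, evs in events_by_room.items()
        ],
        consumeErrors=True,  # keep other rooms' results if one room fails
    )
    defer.returnValue(results)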
@ -134,13 +134,9 @@ class AuthHandler(BaseHandler):
        """

        # build a list of supported flows
        flows = [
            [login_type] for login_type in self._supported_login_types
        ]
        flows = [[login_type] for login_type in self._supported_login_types]

        result, params, _ = yield self.check_auth(
            flows, request_body, clientip,
        )
        result, params, _ = yield self.check_auth(flows, request_body, clientip)

        # find the completed login type
        for login_type in self._supported_login_types:

@ -151,9 +147,7 @@ class AuthHandler(BaseHandler):
            break
        else:
            # this can't happen
            raise Exception(
                "check_auth returned True but no successful login type",
            )
            raise Exception("check_auth returned True but no successful login type")

        # check that the UI auth matched the access token
        if user_id != requester.user.to_string():

@ -215,11 +209,11 @@ class AuthHandler(BaseHandler):

        authdict = None
        sid = None
        if clientdict and 'auth' in clientdict:
            authdict = clientdict['auth']
            del clientdict['auth']
            if 'session' in authdict:
                sid = authdict['session']
        if clientdict and "auth" in clientdict:
            authdict = clientdict["auth"]
            del clientdict["auth"]
            if "session" in authdict:
                sid = authdict["session"]
        session = self._get_session_info(sid)

        if len(clientdict) > 0:

@ -232,27 +226,27 @@ class AuthHandler(BaseHandler):
            # on a home server.
            # Revisit: Assuming the REST APIs do sensible validation, the data
            # isn't arbitrary.
            session['clientdict'] = clientdict
            session["clientdict"] = clientdict
            self._save_session(session)
        elif 'clientdict' in session:
            clientdict = session['clientdict']
        elif "clientdict" in session:
            clientdict = session["clientdict"]

        if not authdict:
            raise InteractiveAuthIncompleteError(
                self._auth_dict_for_flows(flows, session),
                self._auth_dict_for_flows(flows, session)
            )

        if 'creds' not in session:
            session['creds'] = {}
        creds = session['creds']
        if "creds" not in session:
            session["creds"] = {}
        creds = session["creds"]

        # check auth type currently being presented
        errordict = {}
        if 'type' in authdict:
            login_type = authdict['type']
        if "type" in authdict:
            login_type = authdict["type"]
            try:
                result = yield self._check_auth_dict(
                    authdict, clientip, password_servlet=password_servlet,
                    authdict, clientip, password_servlet=password_servlet
                )
                if result:
                    creds[login_type] = result

@ -281,16 +275,15 @@ class AuthHandler(BaseHandler):
            # and is not sensitive).
            logger.info(
                "Auth completed with creds: %r. Client dict has keys: %r",
                creds, list(clientdict)
                creds,
                list(clientdict),
            )
            defer.returnValue((creds, clientdict, session['id']))
            defer.returnValue((creds, clientdict, session["id"]))

        ret = self._auth_dict_for_flows(flows, session)
        ret['completed'] = list(creds)
        ret["completed"] = list(creds)
        ret.update(errordict)
        raise InteractiveAuthIncompleteError(
            ret,
        )
        raise InteractiveAuthIncompleteError(ret)

    @defer.inlineCallbacks
    def add_oob_auth(self, stagetype, authdict, clientip):

@ -300,15 +293,13 @@ class AuthHandler(BaseHandler):
        """
        if stagetype not in self.checkers:
            raise LoginError(400, "", Codes.MISSING_PARAM)
        if 'session' not in authdict:
        if "session" not in authdict:
            raise LoginError(400, "", Codes.MISSING_PARAM)

        sess = self._get_session_info(
            authdict['session']
        )
        if 'creds' not in sess:
            sess['creds'] = {}
        creds = sess['creds']
        sess = self._get_session_info(authdict["session"])
        if "creds" not in sess:
            sess["creds"] = {}
        creds = sess["creds"]

        result = yield self.checkers[stagetype](authdict, clientip)
        if result:

@ -329,10 +320,10 @@ class AuthHandler(BaseHandler):
            not send a session ID, returns None.
        """
        sid = None
        if clientdict and 'auth' in clientdict:
            authdict = clientdict['auth']
            if 'session' in authdict:
                sid = authdict['session']
        if clientdict and "auth" in clientdict:
            authdict = clientdict["auth"]
            if "session" in authdict:
                sid = authdict["session"]
        return sid

    def set_session_data(self, session_id, key, value):

@ -347,7 +338,7 @@ class AuthHandler(BaseHandler):
            value (any): The data to store
        """
        sess = self._get_session_info(session_id)
        sess.setdefault('serverdict', {})[key] = value
        sess.setdefault("serverdict", {})[key] = value
        self._save_session(sess)

    def get_session_data(self, session_id, key, default=None):

@ -360,7 +351,7 @@ class AuthHandler(BaseHandler):
            default (any): Value to return if the key has not been set
        """
        sess = self._get_session_info(session_id)
        return sess.setdefault('serverdict', {}).get(key, default)
        return sess.setdefault("serverdict", {}).get(key, default)

    @defer.inlineCallbacks
    def _check_auth_dict(self, authdict, clientip, password_servlet=False):

@ -378,15 +369,13 @@ class AuthHandler(BaseHandler):
            SynapseError if there was a problem with the request
            LoginError if there was an authentication problem.
        """
        login_type = authdict['type']
        login_type = authdict["type"]
        checker = self.checkers.get(login_type)
        if checker is not None:
            # XXX: Temporary workaround for having Synapse handle password resets
            # See AuthHandler.check_auth for further details
            res = yield checker(
                authdict,
                clientip=clientip,
                password_servlet=password_servlet,
                authdict, clientip=clientip, password_servlet=password_servlet
            )
            defer.returnValue(res)

@ -408,13 +397,11 @@ class AuthHandler(BaseHandler):
            # Client tried to provide captcha but didn't give the parameter:
            # bad request.
            raise LoginError(
                400, "Captcha response is required",
                errcode=Codes.CAPTCHA_NEEDED
                400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
            )

        logger.info(
            "Submitting recaptcha response %s with remoteip %s",
            user_response, clientip
            "Submitting recaptcha response %s with remoteip %s", user_response, clientip
        )

        # TODO: get this from the homeserver rather than creating a new one for

@ -424,34 +411,34 @@ class AuthHandler(BaseHandler):
            resp_body = yield client.post_urlencoded_get_json(
                self.hs.config.recaptcha_siteverify_api,
                args={
                    'secret': self.hs.config.recaptcha_private_key,
                    'response': user_response,
                    'remoteip': clientip,
                }
                    "secret": self.hs.config.recaptcha_private_key,
                    "response": user_response,
                    "remoteip": clientip,
                },
            )
        except PartialDownloadError as pde:
            # Twisted is silly
            data = pde.response
            resp_body = json.loads(data)

        if 'success' in resp_body:
        if "success" in resp_body:
            # Note that we do NOT check the hostname here: we explicitly
            # intend the CAPTCHA to be presented by whatever client the
            # user is using, we just care that they have completed a CAPTCHA.
            logger.info(
                "%s reCAPTCHA from hostname %s",
                "Successful" if resp_body['success'] else "Failed",
                resp_body.get('hostname')
                "Successful" if resp_body["success"] else "Failed",
                resp_body.get("hostname"),
            )
            if resp_body['success']:
            if resp_body["success"]:
                defer.returnValue(True)
        raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

    def _check_email_identity(self, authdict, **kwargs):
        return self._check_threepid('email', authdict, **kwargs)
        return self._check_threepid("email", authdict, **kwargs)

    def _check_msisdn(self, authdict, **kwargs):
        return self._check_threepid('msisdn', authdict)
        return self._check_threepid("msisdn", authdict)

    def _check_dummy_auth(self, authdict, **kwargs):
        return defer.succeed(True)

@ -461,10 +448,10 @@ class AuthHandler(BaseHandler):

    @defer.inlineCallbacks
    def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
        if 'threepid_creds' not in authdict:
        if "threepid_creds" not in authdict:
            raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)

        threepid_creds = authdict['threepid_creds']
        threepid_creds = authdict["threepid_creds"]

        identity_handler = self.hs.get_handlers().identity_handler

@ -482,31 +469,36 @ class AuthHandler(BaseHandler):
                validated=True,
            )

            threepid = {
                "medium": row["medium"],
                "address": row["address"],
                "validated_at": row["validated_at"],
            } if row else None
            threepid = (
                {
                    "medium": row["medium"],
                    "address": row["address"],
                    "validated_at": row["validated_at"],
                }
                if row
                else None
            )

            if row:
                # Valid threepid returned, delete from the db
                yield self.store.delete_threepid_session(threepid_creds["sid"])
        else:
            raise SynapseError(400, "Password resets are not enabled on this homeserver")
            raise SynapseError(
                400, "Password resets are not enabled on this homeserver"
            )

        if not threepid:
            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

        if threepid['medium'] != medium:
        if threepid["medium"] != medium:
            raise LoginError(
                401,
                "Expecting threepid of type '%s', got '%s'" % (
                    medium, threepid['medium'],
                ),
                errcode=Codes.UNAUTHORIZED
                "Expecting threepid of type '%s', got '%s'"
                % (medium, threepid["medium"]),
                errcode=Codes.UNAUTHORIZED,
            )

        threepid['threepid_creds'] = authdict['threepid_creds']
        threepid["threepid_creds"] = authdict["threepid_creds"]

        defer.returnValue(threepid)

@ -520,13 +512,14 @@ class AuthHandler(BaseHandler):
                    "version": self.hs.config.user_consent_version,
                    "en": {
                        "name": self.hs.config.user_consent_policy_name,
                        "url": "%s_matrix/consent?v=%s" % (
                        "url": "%s_matrix/consent?v=%s"
                        % (
                            self.hs.config.public_baseurl,
                            self.hs.config.user_consent_version,
                        ),
                    },
                }
            }
        }

    def _auth_dict_for_flows(self, flows, session):

@ -547,9 +540,9 @@ class AuthHandler(BaseHandler):
                    params[stage] = get_params[stage]()

        return {
            "session": session['id'],
            "session": session["id"],
            "flows": [{"stages": f} for f in public_flows],
            "params": params
            "params": params,
        }

    def _get_session_info(self, session_id):

@ -560,9 +553,7 @@ class AuthHandler(BaseHandler):
            # create a new session
            while session_id is None or session_id in self.sessions:
                session_id = stringutils.random_string(24)
            self.sessions[session_id] = {
                "id": session_id,
            }
            self.sessions[session_id] = {"id": session_id}

        return self.sessions[session_id]

@ -652,7 +643,8 @@ class AuthHandler(BaseHandler):
                logger.warn(
                    "Attempted to login as %s but it matches more than one user "
                    "inexactly: %r",
                    user_id, user_infos.keys()
                    user_id,
                    user_infos.keys(),
                )
        defer.returnValue(result)

@ -690,12 +682,10 @@ class AuthHandler(BaseHandler):
            user is too high to proceed.
        """

        if username.startswith('@'):
        if username.startswith("@"):
            qualified_user_id = username
        else:
            qualified_user_id = UserID(
                username, self.hs.hostname
            ).to_string()
            qualified_user_id = UserID(username, self.hs.hostname).to_string()

        self.ratelimit_login_per_account(qualified_user_id)

@ -713,17 +703,15 @@ class AuthHandler(BaseHandler):
            raise SynapseError(400, "Missing parameter: password")

        for provider in self.password_providers:
            if (hasattr(provider, "check_password")
                    and login_type == LoginType.PASSWORD):
            if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
                known_login_type = True
                is_valid = yield provider.check_password(
                    qualified_user_id, password,
                )
                is_valid = yield provider.check_password(qualified_user_id, password)
                if is_valid:
                    defer.returnValue((qualified_user_id, None))

            if (not hasattr(provider, "get_supported_login_types")
                    or not hasattr(provider, "check_auth")):
            if not hasattr(provider, "get_supported_login_types") or not hasattr(
                provider, "check_auth"
            ):
                # this password provider doesn't understand custom login types
                continue

@ -744,15 +732,12 @@ class AuthHandler(BaseHandler):
                login_dict[f] = login_submission[f]
            if missing_fields:
                raise SynapseError(
                    400, "Missing parameters for login type %s: %s" % (
                        login_type,
                        missing_fields,
                    ),
                    400,
                    "Missing parameters for login type %s: %s"
                    % (login_type, missing_fields),
                )

            result = yield provider.check_auth(
                username, login_type, login_dict,
            )
            result = yield provider.check_auth(username, login_type, login_dict)
            if result:
                if isinstance(result, str):
                    result = (result, None)

@ -762,7 +747,7 @@ class AuthHandler(BaseHandler):
            known_login_type = True

            canonical_user_id = yield self._check_local_password(
                qualified_user_id, password,
                qualified_user_id, password
            )

            if canonical_user_id:

@ -773,7 +758,8 @@ class AuthHandler(BaseHandler):

        # unknown username or invalid password.
        self._failed_attempts_ratelimiter.ratelimit(
            qualified_user_id.lower(), time_now_s=self._clock.time(),
            qualified_user_id.lower(),
            time_now_s=self._clock.time(),
            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
            update=True,

@ -781,10 +767,7 @@ class AuthHandler(BaseHandler):

        # We raise a 403 here, but note that if we're doing user-interactive
        # login, it turns all LoginErrors into a 401 anyway.
        raise LoginError(
            403, "Invalid password",
            errcode=Codes.FORBIDDEN
        )
        raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)

    @defer.inlineCallbacks
    def check_password_provider_3pid(self, medium, address, password):

@ -810,9 +793,7 @@ class AuthHandler(BaseHandler):
            # success, to a str (which is the user_id) or a tuple of
            # (user_id, callback_func), where callback_func should be run
            # after we've finished everything else
            result = yield provider.check_3pid_auth(
                medium, address, password,
            )
            result = yield provider.check_3pid_auth(medium, address, password)
            if result:
                # Check if the return value is a str or a tuple
                if isinstance(result, str):

@ -853,8 +834,7 @@ class AuthHandler(BaseHandler):
    @defer.inlineCallbacks
    def issue_access_token(self, user_id, device_id=None):
        access_token = self.macaroon_gen.generate_access_token(user_id)
        yield self.store.add_access_token_to_user(user_id, access_token,
                                                  device_id)
        yield self.store.add_access_token_to_user(user_id, access_token, device_id)
        defer.returnValue(access_token)

    @defer.inlineCallbacks

@ -896,12 +876,13 @@ class AuthHandler(BaseHandler):
        # delete pushers associated with this access token
        if user_info["token_id"] is not None:
            yield self.hs.get_pusherpool().remove_pushers_by_access_token(
                str(user_info["user"]), (user_info["token_id"], )
                str(user_info["user"]), (user_info["token_id"],)
            )

    @defer.inlineCallbacks
    def delete_access_tokens_for_user(self, user_id, except_token_id=None,
                                      device_id=None):
    def delete_access_tokens_for_user(
        self, user_id, except_token_id=None, device_id=None
    ):
        """Invalidate access tokens belonging to a user

        Args:

@ -915,7 +896,7 @@ class AuthHandler(BaseHandler):
            Deferred
        """
        tokens_and_devices = yield self.store.user_delete_access_tokens(
            user_id, except_token_id=except_token_id, device_id=device_id,
            user_id, except_token_id=except_token_id, device_id=device_id
        )

        # see if any of our auth providers want to know about this

@ -923,14 +904,12 @@ class AuthHandler(BaseHandler):
            if hasattr(provider, "on_logged_out"):
                for token, token_id, device_id in tokens_and_devices:
                    yield provider.on_logged_out(
                        user_id=user_id,
                        device_id=device_id,
                        access_token=token,
                        user_id=user_id, device_id=device_id, access_token=token
                    )

        # delete pushers associated with the access tokens
        yield self.hs.get_pusherpool().remove_pushers_by_access_token(
            user_id, (token_id for _, token_id, _ in tokens_and_devices),
            user_id, (token_id for _, token_id, _ in tokens_and_devices)
        )

    @defer.inlineCallbacks

@ -944,12 +923,11 @@ class AuthHandler(BaseHandler):
        # of specific types of threepid (and fixes the fact that checking
        # for the presence of an email address during password reset was
        # case sensitive).
        if medium == 'email':
        if medium == "email":
            address = address.lower()

        yield self.store.user_add_threepid(
            user_id, medium, address, validated_at,
            self.hs.get_clock().time_msec()
            user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
        )

    @defer.inlineCallbacks

@ -973,22 +951,15 @@ class AuthHandler(BaseHandler):
        """

        # 'Canonicalise' email addresses as per above
        if medium == 'email':
        if medium == "email":
            address = address.lower()

        identity_handler = self.hs.get_handlers().identity_handler
        result = yield identity_handler.try_unbind_threepid(
            user_id,
            {
                'medium': medium,
                'address': address,
                'id_server': id_server,
            },
            user_id, {"medium": medium, "address": address, "id_server": id_server}
        )

        yield self.store.user_delete_threepid(
            user_id, medium, address,
        )
        yield self.store.user_delete_threepid(user_id, medium, address)
        defer.returnValue(result)

    def _save_session(self, session):

@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler):
        Returns:
            Deferred(unicode): Hashed password.
        """

        def _do_hash():
            # Normalise the Unicode in the password
            pw = unicodedata.normalize("NFKC", password)

            return bcrypt.hashpw(
                pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
                pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
                bcrypt.gensalt(self.bcrypt_rounds),
            ).decode('ascii')
            ).decode("ascii")

        return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)

@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler):
        Returns:
            Deferred(bool): Whether self.hash(password) == stored_hash.
        """

        def _do_validate_hash():
            # Normalise the Unicode in the password
            pw = unicodedata.normalize("NFKC", password)

            return bcrypt.checkpw(
                pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
                stored_hash
                pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
                stored_hash,
            )

        if stored_hash:
            if not isinstance(stored_hash, bytes):
                stored_hash = stored_hash.encode('ascii')
                stored_hash = stored_hash.encode("ascii")

            return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
        else:

@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler):
            for this user is too high to proceed.
        """
        self._failed_attempts_ratelimiter.ratelimit(
            user_id.lower(), time_now_s=self._clock.time(),
            user_id.lower(),
            time_now_s=self._clock.time(),
            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
            update=False,
        )

        self._account_ratelimiter.ratelimit(
            user_id.lower(), time_now_s=self._clock.time(),
            user_id.lower(),
            time_now_s=self._clock.time(),
            rate_hz=self.hs.config.rc_login_account.per_second,
            burst_count=self.hs.config.rc_login_account.burst_count,
            update=True,

@ -1083,9 +1058,9 @@ class MacaroonGenerator(object):
        macaroon.add_first_party_caveat("type = access")
        # Include a nonce, to make sure that each login gets a different
        # access token.
        macaroon.add_first_party_caveat("nonce = %s" % (
            stringutils.random_string_with_symbols(16),
        ))
        macaroon.add_first_party_caveat(
            "nonce = %s" % (stringutils.random_string_with_symbols(16),)
        )
        for caveat in extra_caveats:
            macaroon.add_first_party_caveat(caveat)
        return macaroon.serialize()

@ -1116,7 +1091,8 @@ class MacaroonGenerator(object):
        macaroon = pymacaroons.Macaroon(
            location=self.hs.config.server_name,
            identifier="key",
            key=self.hs.config.macaroon_secret_key)
            key=self.hs.config.macaroon_secret_key,
        )
        macaroon.add_first_party_caveat("gen = 1")
        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
        return macaroon
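The hash/validate hunks above combine three ingredients: NFKC normalisation of the password, a server-wide "pepper" appended before hashing, and bcrypt with a configurable work factor. A minimal standalone sketch of that scheme (the pepper value and round count are illustrative; Synapse reads both from homeserver config and runs the work on a thread so the reactor is not blocked):

import unicodedata

import bcrypt

PEPPER = "example-pepper"  # assumption: a server-wide secret from config
ROUNDS = 12


def hash_password(password):
    # Normalise first so visually-identical Unicode inputs hash identically.
    pw = unicodedata.normalize("NFKC", password)
    return bcrypt.hashpw(
        pw.encode("utf8") + PEPPER.encode("utf8"), bcrypt.gensalt(ROUNDS)
    ).decode("ascii")


def validate_password(password, stored_hash):
    pw = unicodedata.normalize("NFKC", password)
    return bcrypt.checkpw(
        pw.encode("utf8") + PEPPER.encode("utf8"), stored_hash.encode("ascii")
    )


assert validate_password("s3cret", hash_password("s3cret"))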
@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017, 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)

class DeactivateAccountHandler(BaseHandler):
    """Handler which deals with deactivating user accounts."""

    def __init__(self, hs):
        super(DeactivateAccountHandler, self).__init__(hs)
        self._auth_handler = hs.get_auth_handler()

@ -42,6 +44,8 @@ class DeactivateAccountHandler(BaseHandler):
        # it left off (if it has work left to do).
        hs.get_reactor().callWhenRunning(self._start_user_parting)

        self._account_validity_enabled = hs.config.account_validity.enabled

    @defer.inlineCallbacks
    def deactivate_account(self, user_id, erase_data, id_server=None):
        """Deactivate a user's account

@ -75,9 +79,9 @@ class DeactivateAccountHandler(BaseHandler):
                result = yield self._identity_handler.try_unbind_threepid(
                    user_id,
                    {
                        'medium': threepid['medium'],
                        'address': threepid['address'],
                        'id_server': id_server,
                        "medium": threepid["medium"],
                        "address": threepid["address"],
                        "id_server": id_server,
                    },
                )
                identity_server_supports_unbinding &= result

@ -86,7 +90,7 @@ class DeactivateAccountHandler(BaseHandler):
                logger.exception("Failed to remove threepid from ID server")
                raise SynapseError(400, "Failed to remove threepid from ID server")
            yield self.store.user_delete_threepid(
                user_id, threepid['medium'], threepid['address'],
                user_id, threepid["medium"], threepid["address"]
            )

        # delete any devices belonging to the user, which will also

@ -114,6 +118,13 @@ class DeactivateAccountHandler(BaseHandler):
        # parts users from rooms (if it isn't already running)
        self._start_user_parting()

        # Remove all information on the user from the account_validity table.
        if self._account_validity_enabled:
            yield self.store.delete_account_validity_for_user(user_id)

        # Mark the user as deactivated.
        yield self.store.set_user_deactivated_status(user_id, True)

        defer.returnValue(identity_server_supports_unbinding)

    def _start_user_parting(self):

@ -173,5 +184,6 @@ class DeactivateAccountHandler(BaseHandler):
            except Exception:
                logger.exception(
                    "Failed to part user %r from room %r: ignoring and continuing",
                    user_id, room_id,
                    user_id,
                    room_id,
                )
@@ -58,9 +58,7 @@ class DeviceWorkerHandler(BaseHandler):

device_map = yield self.store.get_devices_by_user(user_id)

ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id=None
)
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)

devices = list(device_map.values())
for device in devices:
@@ -85,9 +83,7 @@ class DeviceWorkerHandler(BaseHandler):
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id,
)
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)

@@ -114,13 +110,11 @@ class DeviceWorkerHandler(BaseHandler):
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)

member_events = yield self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key,
user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)

stream_ordering = RoomStreamToken.parse_stream_token(
from_token.room_key
).stream
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream

possibly_changed = set(changed)
possibly_left = set()
@@ -206,10 +200,9 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = []
possibly_left = []

defer.returnValue({
"changed": list(possibly_joined),
"left": list(possibly_left),
})
defer.returnValue(
{"changed": list(possibly_joined), "left": list(possibly_left)}
)


class DeviceHandler(DeviceWorkerHandler):
@@ -223,17 +216,18 @@ class DeviceHandler(DeviceWorkerHandler):
federation_registry = hs.get_federation_registry()

federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update,
"m.device_list_update", self._edu_updater.incoming_device_list_update
)
federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
"user_devices", self.on_federation_query_user_devices
)

hs.get_distributor().observe("user_left_room", self.user_left_room)

@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
"""
If the given device has not been registered, register it with the
supplied display name.
@@ -297,12 +291,10 @@ class DeviceHandler(DeviceWorkerHandler):
raise

yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
user_id, device_id=device_id
)

yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)

yield self.notify_device_update(user_id, [device_id])

@@ -349,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
# considered as part of a critical path.
for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@@ -372,9 +364,7 @@ class DeviceHandler(DeviceWorkerHandler):

try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
user_id, device_id, new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
@@ -404,29 +394,26 @@ class DeviceHandler(DeviceWorkerHandler):

for device_id in device_ids:
logger.debug(
"Notifying about update %r/%r, ID: %r", user_id, device_id,
position,
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
)

room_ids = yield self.store.get_rooms_for_user(user_id)

yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)

if hosts:
logger.info("Sending device list update notif for %r to: %r", user_id, hosts)
logger.info(
"Sending device list update notif for %r to: %r", user_id, hosts
)
for host in hosts:
self.federation_sender.send_device_messages(host)

@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
defer.returnValue(
{"user_id": user_id, "stream_id": stream_id, "devices": devices}
)

@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -440,10 +427,7 @@ class DeviceHandler(DeviceWorkerHandler):

def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})


class DeviceListEduUpdater(object):
@@ -481,13 +465,15 @@ class DeviceListEduUpdater(object):
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
prev_ids = [str(p) for p in prev_ids] # They may come as ints
prev_ids = [str(p) for p in prev_ids]  # They may come as ints

if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning(
"Got device list update edu for %r/%r from %r",
user_id, device_id, origin,
user_id,
device_id,
origin,
)
return

@@ -497,13 +483,12 @@ class DeviceListEduUpdater(object):
# probably won't get any further updates.
logger.warning(
"Got device list update edu for %r/%r, but don't share a room",
user_id, device_id,
user_id,
device_id,
)
return

logger.debug(
"Received device list update for %r/%r", user_id, device_id,
)
logger.debug("Received device list update for %r/%r", user_id, device_id)

self._pending_updates.setdefault(user_id, []).append(
(device_id, stream_id, prev_ids, edu_content)
@@ -525,7 +510,10 @@ class DeviceListEduUpdater(object):
for device_id, stream_id, prev_ids, content in pending_updates:
logger.debug(
"Handling update %r/%r, ID: %r, prev: %r ",
user_id, device_id, stream_id, prev_ids,
user_id,
device_id,
stream_id,
prev_ids,
)

# Given a list of updates we check if we need to resync. This
@@ -540,13 +528,13 @@ class DeviceListEduUpdater(object):
try:
result = yield self.federation.query_user_devices(origin, user_id)
except (
NotRetryingDestination, RequestSendFailed, HttpResponseException,
NotRetryingDestination,
RequestSendFailed,
HttpResponseException,
):
# TODO: Remember that we are now out of sync and try again
# later
logger.warn(
"Failed to handle device list update for %s", user_id,
)
logger.warn("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@@ -582,18 +570,21 @@ class DeviceListEduUpdater(object):
if len(devices) > 1000:
logger.warn(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
user_id, len(devices)
user_id,
len(devices),
)
devices = []

for device in devices:
logger.debug(
"Handling resync update %r/%r, ID: %r",
user_id, device["device_id"], stream_id,
user_id,
device["device_id"],
stream_id,
)

yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
user_id, devices, stream_id
)
device_ids = [device["device_id"] for device in devices]
yield self.device_handler.notify_device_update(user_id, device_ids)
@@ -606,7 +597,7 @@ class DeviceListEduUpdater(object):
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
user_id, device_id, content, stream_id
)

yield self.device_handler.notify_device_update(
@@ -624,14 +615,9 @@ class DeviceListEduUpdater(object):
"""
seen_updates = self._seen_updates.get(user_id, set())

extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)

logger.debug(
"Current extremity for %r: %r",
user_id, extremity,
)
logger.debug("Current extremity for %r: %r", user_id, extremity)

stream_id_in_updates = set() # stream_ids in updates list
for _, stream_id, prev_ids, _ in updates:

@@ -25,7 +25,6 @@ logger = logging.getLogger(__name__)


class DeviceMessageHandler(object):

def __init__(self, hs):
"""
Args:
@@ -47,15 +46,15 @@ class DeviceMessageHandler(object):
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
origin, sender_user_id
origin,
sender_user_id,
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")

messages_by_device = {

@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)


class DirectoryHandler(BaseHandler):

def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs)

@@ -77,15 +76,19 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(400, "Failed to get server list")

yield self.store.create_room_alias_association(
room_alias,
room_id,
servers,
creator=creator,
room_alias, room_id, servers, creator=creator
)

@defer.inlineCallbacks
def create_association(self, requester, room_alias, room_id, servers=None,
send_event=True, check_membership=True):
def create_association(
self,
requester,
room_alias,
room_id,
servers=None,
send_event=True,
check_membership=True,
):
"""Attempt to create a new alias

Args:
@@ -115,49 +118,40 @@ class DirectoryHandler(BaseHandler):
if service:
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400, "This application service has not reserved"
" this kind of alias.", errcode=Codes.EXCLUSIVE
400,
"This application service has not reserved" " this kind of alias.",
errcode=Codes.EXCLUSIVE,
)
else:
if self.require_membership and check_membership:
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
403,
"You must be in the room to create an alias for it",
403, "You must be in the room to create an alias for it"
)

if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
raise AuthError(
403, "This user is not permitted to create this alias",
)
raise AuthError(403, "This user is not permitted to create this alias")

if not self.config.is_alias_creation_allowed(
user_id, room_id, room_alias.to_string(),
user_id, room_id, room_alias.to_string()
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(
403, "Not allowed to create alias",
)
raise SynapseError(403, "Not allowed to create alias")

can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
400,
"This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE,
)

yield self._create_association(room_alias, room_id, servers, creator=user_id)
if send_event:
yield self.send_room_alias_update_event(
requester,
room_id
)
yield self.send_room_alias_update_event(requester, room_id)

@defer.inlineCallbacks
def delete_association(self, requester, room_alias, send_event=True):
@@ -194,34 +188,24 @@ class DirectoryHandler(BaseHandler):
raise

if not can_delete:
raise AuthError(
403, "You don't have permission to delete the alias.",
)
raise AuthError(403, "You don't have permission to delete the alias.")

can_delete = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
400,
"This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE,
)

room_id = yield self._delete_association(room_alias)

try:
if send_event:
yield self.send_room_alias_update_event(
requester,
room_id
)
yield self.send_room_alias_update_event(requester, room_id)

yield self._update_canonical_alias(
requester,
requester.user.to_string(),
room_id,
room_alias,
requester, requester.user.to_string(), room_id, room_alias
)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -234,7 +218,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE
errcode=Codes.EXCLUSIVE,
)
yield self._delete_association(room_alias)

@@ -251,9 +235,7 @@ class DirectoryHandler(BaseHandler):
def get_association(self, room_alias):
room_id = None
if self.hs.is_mine(room_alias):
result = yield self.get_association_from_room_alias(
room_alias
)
result = yield self.get_association_from_room_alias(room_alias)

if result:
room_id = result.room_id
@@ -263,9 +245,7 @@ class DirectoryHandler(BaseHandler):
result = yield self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
args={
"room_alias": room_alias.to_string(),
},
args={"room_alias": room_alias.to_string()},
retry_on_dns_fail=False,
ignore_backoff=True,
)
@@ -284,7 +264,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
404,
"Room alias %s not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
Codes.NOT_FOUND,
)

users = yield self.state.get_current_users_in_room(room_id)
@@ -293,41 +273,28 @@ class DirectoryHandler(BaseHandler):

# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
[self.server_name] +
[s for s in servers if s != self.server_name]
)
servers = [self.server_name] + [s for s in servers if s != self.server_name]
else:
servers = list(servers)

defer.returnValue({
"room_id": room_id,
"servers": servers,
})
defer.returnValue({"room_id": room_id, "servers": servers})
return

@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
raise SynapseError(
400, "Room Alias is not hosted on this Home Server"
)
raise SynapseError(400, "Room Alias is not hosted on this Home Server")

result = yield self.get_association_from_room_alias(
room_alias
)
result = yield self.get_association_from_room_alias(room_alias)

if result is not None:
defer.returnValue({
"room_id": result.room_id,
"servers": result.servers,
})
defer.returnValue({"room_id": result.room_id, "servers": result.servers})
else:
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
Codes.NOT_FOUND,
)

@defer.inlineCallbacks
@@ -343,7 +310,7 @@ class DirectoryHandler(BaseHandler):
"sender": requester.user.to_string(),
"content": {"aliases": aliases},
},
ratelimit=False
ratelimit=False,
)

@defer.inlineCallbacks
@@ -365,14 +332,12 @@ class DirectoryHandler(BaseHandler):
"sender": user_id,
"content": {},
},
ratelimit=False
ratelimit=False,
)

@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
result = yield self.store.get_association_from_room_alias(
room_alias
)
result = yield self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
@@ -421,8 +386,7 @@ class DirectoryHandler(BaseHandler):

if not self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403,
"This user is not permitted to publish rooms to the room list"
403, "This user is not permitted to publish rooms to the room list"
)

if requester.is_guest:
@@ -434,8 +398,7 @@ class DirectoryHandler(BaseHandler):
if visibility == "public" and not self.enable_room_list_search:
# The room list has been disabled.
raise AuthError(
403,
"This user is not permitted to publish rooms to the room list"
403, "This user is not permitted to publish rooms to the room list"
)

room = yield self.store.get_room(room_id)
@@ -452,20 +415,19 @@ class DirectoryHandler(BaseHandler):
room_aliases.append(canonical_alias)

if not self.config.is_publishing_room_allowed(
user_id, room_id, room_aliases,
user_id, room_id, room_aliases
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(
403, "Not allowed to publish room",
)
raise SynapseError(403, "Not allowed to publish room")

yield self.store.set_room_is_public(room_id, making_public)

@defer.inlineCallbacks
def edit_published_appservice_room_list(self, appservice_id, network_id,
room_id, visibility):
def edit_published_appservice_room_list(
self, appservice_id, network_id, room_id, visibility
):
"""Add or remove a room from the appservice/network specific public
room list.

@@ -99,9 +99,7 @@ class E2eKeysHandler(object):
query_list.append((user_id, None))

user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
yield self.store.get_user_devices_from_cache(query_list)
)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
@@ -126,9 +124,7 @@ class E2eKeysHandler(object):
destination_query = remote_queries_not_in_cache[destination]
try:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
destination, {"device_keys": destination_query}, timeout=timeout
)

for user_id, keys in remote_result["device_keys"].items():
@@ -138,14 +134,17 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)

yield make_deferred_yieldable(defer.gatherResults([
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
], consumeErrors=True))
yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
)
)

defer.returnValue({
"device_keys": results, "failures": failures,
})
defer.returnValue({"device_keys": results, "failures": failures})

@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -165,8 +164,7 @@ class E2eKeysHandler(object):
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")

if not device_ids:
@@ -231,9 +229,7 @@ class E2eKeysHandler(object):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
@@ -241,25 +237,29 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)

yield make_deferred_yieldable(defer.gatherResults([
run_in_background(claim_client_keys, destination)
for destination in remote_queries
], consumeErrors=True))
yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(claim_client_keys, destination)
for destination in remote_queries
],
consumeErrors=True,
)
)

logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result)
for device_id, device_keys in iteritems(user_keys)
for key_id, _ in iteritems(device_keys)
)),
",".join(
(
"%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result)
for device_id, device_keys in iteritems(user_keys)
for key_id, _ in iteritems(device_keys)
)
),
)

defer.returnValue({
"one_time_keys": json_result,
"failures": failures
})
defer.returnValue({"one_time_keys": json_result, "failures": failures})

@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
@@ -270,11 +270,13 @@ class E2eKeysHandler(object):
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
device_id,
user_id,
time_now,
)
# TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys,
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
@@ -283,7 +285,7 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
yield self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys,
user_id, device_id, time_now, one_time_keys
)

# the device should have been registered already, but it may have been
@@ -298,20 +300,22 @@ class E2eKeysHandler(object):
defer.returnValue({"one_time_key_counts": result})

@defer.inlineCallbacks
def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
one_time_keys):
def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys
):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(), device_id, user_id, time_now,
one_time_keys.keys(),
device_id,
user_id,
time_now,
)

# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, key_obj
))
key_list.append((algorithm, key_id, key_obj))

# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
@@ -325,42 +329,35 @@ class E2eKeysHandler(object):
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
("One time key %s:%s already exists. "
"Old key: %s; new key: %r") %
(algorithm, key_id, ex_json, key)
(
"One time key %s:%s already exists. "
"Old key: %s; new key: %r"
)
% (algorithm, key_id, ex_json, key),
)
else:
new_keys.append((
algorithm, key_id, encode_canonical_json(key).decode('ascii')))
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)

yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)
yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)


def _exception_to_failure(e):
if isinstance(e, CodeMessageException):
return {
"status": e.code, "message": str(e),
}
return {"status": e.code, "message": str(e)}

if isinstance(e, NotRetryingDestination):
return {
"status": 503, "message": "Not ready for retry",
}
return {"status": 503, "message": "Not ready for retry"}

if isinstance(e, FederationDeniedError):
return {
"status": 403, "message": "Federation Denied",
}
return {"status": 403, "message": "Federation Denied"}

# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
return {
"status": 503, "message": str(e),
}
return {"status": 503, "message": str(e)}


def _one_time_keys_match(old_key_json, new_key):

@@ -152,14 +152,14 @@ class E2eRoomKeysHandler(object):
else:
raise

if version_info['version'] != version:
if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
version_info = yield self.store.get_e2e_room_keys_version_info(
user_id, version,
user_id, version
)
# if we get this far, the version must exist
raise RoomKeysVersionError(current_version=version_info['version'])
raise RoomKeysVersionError(current_version=version_info["version"])
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
@@ -168,8 +168,8 @@ class E2eRoomKeysHandler(object):

# go through the room_keys.
# XXX: this should/could be done concurrently, given we're in a lock.
for room_id, room in iteritems(room_keys['rooms']):
for session_id, session in iteritems(room['sessions']):
for room_id, room in iteritems(room_keys["rooms"]):
for session_id, session in iteritems(room["sessions"]):
yield self._upload_room_key(
user_id, version, room_id, session_id, session
)
@@ -223,14 +223,14 @@ class E2eRoomKeysHandler(object):
# spelt out with if/elifs rather than nested boolean expressions
# purely for legibility.

if room_key['is_verified'] and not current_room_key['is_verified']:
if room_key["is_verified"] and not current_room_key["is_verified"]:
return True
elif (
room_key['first_message_index'] <
current_room_key['first_message_index']
room_key["first_message_index"]
< current_room_key["first_message_index"]
):
return True
elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
elif room_key["forwarded_count"] < current_room_key["forwarded_count"]:
return True
else:
return False
@@ -328,16 +328,10 @@ class E2eRoomKeysHandler(object):
A deferred of an empty dict.
"""
if "version" not in version_info:
raise SynapseError(
400,
"Missing version in body",
Codes.MISSING_PARAM
)
raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM)
if version_info["version"] != version:
raise SynapseError(
400,
"Version in body does not match",
Codes.INVALID_PARAM
400, "Version in body does not match", Codes.INVALID_PARAM
)
with (yield self._upload_linearizer.queue(user_id)):
try:
@@ -350,12 +344,10 @@ class E2eRoomKeysHandler(object):
else:
raise
if old_info["algorithm"] != version_info["algorithm"]:
raise SynapseError(
400,
"Algorithm does not match",
Codes.INVALID_PARAM
)
raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)

yield self.store.update_e2e_room_keys_version(user_id, version, version_info)
yield self.store.update_e2e_room_keys_version(
user_id, version, version_info
)

defer.returnValue({})

@@ -31,7 +31,6 @@ logger = logging.getLogger(__name__)


class EventStreamHandler(BaseHandler):

def __init__(self, hs):
super(EventStreamHandler, self).__init__(hs)

@@ -53,9 +52,17 @@ class EventStreamHandler(BaseHandler):

@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True,
only_keys=None, room_id=None, is_guest=False):
def get_stream(
self,
auth_user_id,
pagin_config,
timeout=0,
as_client_event=True,
affect_presence=True,
only_keys=None,
room_id=None,
is_guest=False,
):
"""Fetches the events stream for a given user.

If `only_keys` is not None, events from keys will be sent down.
@@ -73,7 +80,7 @@ class EventStreamHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()

context = yield presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence,
auth_user_id, affect_presence=affect_presence
)
with context:
if timeout:
@@ -85,9 +92,12 @@ class EventStreamHandler(BaseHandler):
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))

events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout,
auth_user,
pagin_config,
timeout,
only_keys=only_keys,
is_guest=is_guest, explicit_room_id=room_id
is_guest=is_guest,
explicit_room_id=room_id,
)

# When the user joins a new room, or another user joins a currently
@@ -102,17 +112,15 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.state.get_current_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
users = yield self.state.get_current_users_in_room(
event.room_id
)
states = yield presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:

ev = yield presence_handler.get_state(
UserID.from_string(event.state_key),
as_event=True,
UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)

@@ -121,7 +129,9 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()

chunks = yield self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event,
events,
time_now,
as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
@@ -137,7 +147,6 @@ class EventStreamHandler(BaseHandler):


class EventHandler(BaseHandler):

@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
@@ -164,16 +173,10 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users

filtered = yield filter_events_for_client(
self.store,
user.to_string(),
[event],
is_peeking=is_peeking
self.store, user.to_string(), [event], is_peeking=is_peeking
)

if not filtered:
raise AuthError(
403,
"You don't have permission to access that event."
)
raise AuthError(403, "You don't have permission to access that event.")

defer.returnValue(event)

File diff suppressed because it is too large

@@ -30,6 +30,7 @@ def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""

def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)(
@@ -49,9 +50,7 @@ def _create_rerouter(func_name):
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
if e.code == 403:
raise e.to_synapse_error()
return failure
raise e.to_synapse_error()

def request_failed_errback(failure):
failure.trap(RequestSendFailed)
@@ -60,6 +59,7 @@ def _create_rerouter(func_name):
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d

return f


@@ -127,7 +127,7 @@ class GroupsLocalHandler(object):
)
else:
res = yield self.transport_client.get_group_summary(
get_domain_from_id(group_id), group_id, requester_user_id,
get_domain_from_id(group_id), group_id, requester_user_id
)

group_server_name = get_domain_from_id(group_id)
@@ -184,7 +184,7 @@ class GroupsLocalHandler(object):
content["user_profile"] = yield self.profile_handler.get_profile(user_id)

res = yield self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content,
get_domain_from_id(group_id), group_id, user_id, content
)

remote_attestation = res["attestation"]
@@ -197,16 +197,15 @@ class GroupsLocalHandler(object):

is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="join",
is_admin=True,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])

defer.returnValue(res)

@@ -223,7 +222,7 @@ class GroupsLocalHandler(object):
group_server_name = get_domain_from_id(group_id)

res = yield self.transport_client.get_users_in_group(
get_domain_from_id(group_id), group_id, requester_user_id,
get_domain_from_id(group_id), group_id, requester_user_id
)

chunk = res["chunk"]
@@ -252,9 +251,7 @@ class GroupsLocalHandler(object):
"""Request to join a group
"""
if self.is_mine_id(group_id):
yield self.groups_server_handler.join_group(
group_id, user_id, content
)
yield self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -262,7 +259,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation

res = yield self.transport_client.join_group(
get_domain_from_id(group_id), group_id, user_id, content,
get_domain_from_id(group_id), group_id, user_id, content
)

remote_attestation = res["attestation"]
@@ -278,16 +275,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)

token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])

defer.returnValue({})

@@ -296,9 +292,7 @@ class GroupsLocalHandler(object):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
yield self.groups_server_handler.accept_invite(
group_id, user_id, content
)
yield self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -306,7 +300,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation

res = yield self.transport_client.accept_group_invite(
get_domain_from_id(group_id), group_id, user_id, content,
get_domain_from_id(group_id), group_id, user_id, content
)

remote_attestation = res["attestation"]
@@ -322,16 +316,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)

token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])

defer.returnValue({})

@@ -339,17 +332,17 @@ class GroupsLocalHandler(object):
def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
content = {
"requester_user_id": requester_user_id,
"config": config,
}
content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.invite_to_group(
group_id, user_id, requester_user_id, content,
group_id, user_id, requester_user_id, content
)
else:
res = yield self.transport_client.invite_to_group(
get_domain_from_id(group_id), group_id, user_id, requester_user_id,
get_domain_from_id(group_id),
group_id,
user_id,
requester_user_id,
content,
)

@@ -372,13 +365,12 @@ class GroupsLocalHandler(object):
local_profile["avatar_url"] = content["profile"]["avatar_url"]

token = yield self.store.register_user_group_membership(
group_id, user_id,
group_id,
user_id,
membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]},
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
@@ -393,25 +385,25 @@ class GroupsLocalHandler(object):
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])

# TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down.

if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content,
group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
res = yield self.transport_client.remove_user_from_group(
get_domain_from_id(group_id), group_id, requester_user_id,
user_id, content,
get_domain_from_id(group_id),
group_id,
requester_user_id,
user_id,
content,
)

defer.returnValue(res)
@@ -422,12 +414,9 @@ class GroupsLocalHandler(object):
"""
# TODO: Check if user in group
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])

@defer.inlineCallbacks
def get_joined_groups(self, user_id):
@@ -447,7 +436,7 @@ class GroupsLocalHandler(object):
defer.returnValue({"groups": result})
else:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id],
get_domain_from_id(user_id), [user_id]
)
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
@@ -462,9 +451,7 @@ class GroupsLocalHandler(object):
if self.hs.is_mine_id(user_id):
local_users.add(user_id)
else:
destinations.setdefault(
get_domain_from_id(user_id), set()
).add(user_id)
destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)

if not proxy and destinations:
raise SynapseError(400, "Some user_ids are not local")
@@ -474,16 +461,14 @@ class GroupsLocalHandler(object):
for destination, dest_user_ids in iteritems(destinations):
try:
r = yield self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids),
destination, list(dest_user_ids)
)
results.update(r["users"])
except Exception:
failed_results.extend(dest_user_ids)

for uid in local_users:
results[uid] = yield self.store.get_publicised_groups_for_user(
uid
)
results[uid] = yield self.store.get_publicised_groups_for_user(uid)

# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)

@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)


class IdentityHandler(BaseHandler):

def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)

@@ -64,40 +63,38 @@ class IdentityHandler(BaseHandler):

@defer.inlineCallbacks
def threepid_from_creds(self, creds):
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
if "id_server" in creds:
id_server = creds["id_server"]
elif "idServer" in creds:
id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")

if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
if "client_secret" in creds:
client_secret = creds["client_secret"]
elif "clientSecret" in creds:
client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")

if not self._should_trust_id_server(id_server):
logger.warn(
'%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server
"%s is not a trusted ID server: rejecting 3pid " + "credentials",
id_server,
)
defer.returnValue(None)

try:
data = yield self.http_client.get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'client_secret': client_secret}
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"),
{"sid": creds["sid"], "client_secret": client_secret},
)
except HttpResponseException as e:
logger.info("getValidated3pid failed with Matrix error: %r", e)
raise e.to_synapse_error()

if 'medium' in data:
if "medium" in data:
defer.returnValue(data)
defer.returnValue(None)

@@ -106,30 +103,24 @@ class IdentityHandler(BaseHandler):
logger.debug("binding threepid %r to %s", creds, mxid)
data = None

if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
if "id_server" in creds:
id_server = creds["id_server"]
elif "idServer" in creds:
id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")

if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
if "client_secret" in creds:
client_secret = creds["client_secret"]
elif "clientSecret" in creds:
client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")

try:
data = yield self.http_client.post_urlencoded_get_json(
"https://%s%s" % (
id_server, "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'client_secret': client_secret,
'mxid': mxid,
}
"https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
{"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
)
logger.debug("bound threepid %r to %s", creds, mxid)

@@ -165,9 +156,7 @@ class IdentityHandler(BaseHandler):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)

# We don't know where to unbind, so we don't have a choice but to return
@@ -177,7 +166,7 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server,
mxid, threepid, id_server
)

defer.returnValue(changed)
@@ -201,10 +190,7 @@ class IdentityHandler(BaseHandler):
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
content = {
"mxid": mxid,
"threepid": {
"medium": threepid["medium"],
"address": threepid["address"],
},
"threepid": {"medium": threepid["medium"], "address": threepid["address"]},
}

# we abuse the federation http client to sign the request, but we have to send it
@@ -212,25 +198,19 @@ class IdentityHandler(BaseHandler):
# 'browser-like' HTTPS.
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method='POST',
url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
method="POST",
url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"),
content=content,
destination_is=id_server,
)
headers = {
b"Authorization": auth_headers,
}
headers = {b"Authorization": auth_headers}

try:
yield self.http_client.post_json_get_json(
url,
content,
headers,
)
yield self.http_client.post_json_get_json(url, content, headers)
changed = True
except HttpResponseException as e:
changed = False
if e.code in (400, 404, 501,):
if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code)
else:
@@ -248,35 +228,27 @@ class IdentityHandler(BaseHandler):

@defer.inlineCallbacks
def requestEmailToken(
self,
id_server,
email,
client_secret,
send_attempt,
next_link=None,
self, id_server, email, client_secret, send_attempt, next_link=None
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
Codes.SERVER_NOT_TRUSTED
400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)

params = {
'email': email,
'client_secret': client_secret,
'send_attempt': send_attempt,
"email": email,
"client_secret": client_secret,
"send_attempt": send_attempt,
}

if next_link:
params.update({'next_link': next_link})
params.update({"next_link": next_link})

try:
data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/email/requestToken"
),
params
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"),
params,
)
defer.returnValue(data)
except HttpResponseException as e:
@@ -285,30 +257,26 @@ class IdentityHandler(BaseHandler):

@defer.inlineCallbacks
def requestMsisdnToken(
self, id_server, country, phone_number,
client_secret, send_attempt, **kwargs
self, id_server, country, phone_number, client_secret, send_attempt, **kwargs
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
Codes.SERVER_NOT_TRUSTED
400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)

params = {
'country': country,
'phone_number': phone_number,
'client_secret': client_secret,
'send_attempt': send_attempt,
"country": country,
"phone_number": phone_number,
"client_secret": client_secret,
"send_attempt": send_attempt,
}
params.update(kwargs)

try:
data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/msisdn/requestToken"
),
params
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"),
params,
)
defer.returnValue(data)
except HttpResponseException as e:

@@ -44,8 +44,13 @@ class InitialSyncHandler(BaseHandler):
         self.snapshot_cache = SnapshotCache()
         self._event_serializer = hs.get_event_client_serializer()

-    def snapshot_all_rooms(self, user_id=None, pagin_config=None,
-                           as_client_event=True, include_archived=False):
+    def snapshot_all_rooms(
+        self,
+        user_id=None,
+        pagin_config=None,
+        as_client_event=True,
+        include_archived=False,
+    ):
         """Retrieve a snapshot of all rooms the user is invited or has joined.

         This snapshot may include messages for all rooms where the user is

@@ -77,13 +82,22 @@ class InitialSyncHandler(BaseHandler):
         if result is not None:
             return result

-        return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
-            user_id, pagin_config, as_client_event, include_archived
-        ))
+        return self.snapshot_cache.set(
+            now_ms,
+            key,
+            self._snapshot_all_rooms(
+                user_id, pagin_config, as_client_event, include_archived
+            ),
+        )

     @defer.inlineCallbacks
-    def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
-                            as_client_event=True, include_archived=False):
+    def _snapshot_all_rooms(
+        self,
+        user_id=None,
+        pagin_config=None,
+        as_client_event=True,
+        include_archived=False,
+    ):

         memberships = [Membership.INVITE, Membership.JOIN]
         if include_archived:

@@ -128,8 +142,7 @@ class InitialSyncHandler(BaseHandler):
                 "room_id": event.room_id,
                 "membership": event.membership,
                 "visibility": (
-                    "public" if event.room_id in public_room_ids
-                    else "private"
+                    "public" if event.room_id in public_room_ids else "private"
                 ),
             }

@@ -139,7 +152,7 @@ class InitialSyncHandler(BaseHandler):

                 invite_event = yield self.store.get_event(event.event_id)
                 d["invite"] = yield self._event_serializer.serialize_event(
-                    invite_event, time_now, as_client_event,
+                    invite_event, time_now, as_client_event
                 )

             rooms_ret.append(d)

@@ -151,14 +164,12 @@ class InitialSyncHandler(BaseHandler):
                 if event.membership == Membership.JOIN:
                     room_end_token = now_token.room_key
                     deferred_room_state = run_in_background(
-                        self.state_handler.get_current_state,
-                        event.room_id,
+                        self.state_handler.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
                     room_end_token = "s%d" % (event.stream_ordering,)
                     deferred_room_state = run_in_background(
-                        self.store.get_state_for_events,
-                        [event.event_id],
+                        self.store.get_state_for_events, [event.event_id]
                     )
                     deferred_room_state.addCallback(
                         lambda states: states[event.event_id]

@@ -178,9 +189,7 @@ class InitialSyncHandler(BaseHandler):
                     )
                 ).addErrback(unwrapFirstError)

-                messages = yield filter_events_for_client(
-                    self.store, user_id, messages
-                )
+                messages = yield filter_events_for_client(self.store, user_id, messages)

                 start_token = now_token.copy_and_replace("room_key", token)
                 end_token = now_token.copy_and_replace("room_key", room_end_token)

@@ -189,8 +198,7 @@ class InitialSyncHandler(BaseHandler):
                 d["messages"] = {
                     "chunk": (
                         yield self._event_serializer.serialize_events(
-                            messages, time_now=time_now,
-                            as_client_event=as_client_event,
+                            messages, time_now=time_now, as_client_event=as_client_event
                         )
                     ),
                     "start": start_token.to_string(),

@@ -200,23 +208,21 @@ class InitialSyncHandler(BaseHandler):
                 d["state"] = yield self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
-                    as_client_event=as_client_event
+                    as_client_event=as_client_event,
                 )

                 account_data_events = []
                 tags = tags_by_room.get(event.room_id)
                 if tags:
-                    account_data_events.append({
-                        "type": "m.tag",
-                        "content": {"tags": tags},
-                    })
+                    account_data_events.append(
+                        {"type": "m.tag", "content": {"tags": tags}}
+                    )

                 account_data = account_data_by_room.get(event.room_id, {})
                 for account_data_type, content in account_data.items():
-                    account_data_events.append({
-                        "type": account_data_type,
-                        "content": content,
-                    })
+                    account_data_events.append(
+                        {"type": account_data_type, "content": content}
+                    )

                 d["account_data"] = account_data_events
             except Exception:

@@ -226,10 +232,7 @@ class InitialSyncHandler(BaseHandler):

         account_data_events = []
         for account_data_type, content in account_data.items():
-            account_data_events.append({
-                "type": account_data_type,
-                "content": content,
-            })
+            account_data_events.append({"type": account_data_type, "content": content})

         now = self.clock.time_msec()

@@ -274,7 +277,7 @@ class InitialSyncHandler(BaseHandler):
         user_id = requester.user.to_string()

         membership, member_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id,
+            room_id, user_id
         )
         is_peeking = member_event_id is None

@@ -290,28 +293,21 @@ class InitialSyncHandler(BaseHandler):
         account_data_events = []
         tags = yield self.store.get_tags_for_room(user_id, room_id)
         if tags:
-            account_data_events.append({
-                "type": "m.tag",
-                "content": {"tags": tags},
-            })
+            account_data_events.append({"type": "m.tag", "content": {"tags": tags}})

         account_data = yield self.store.get_account_data_for_room(user_id, room_id)
         for account_data_type, content in account_data.items():
-            account_data_events.append({
-                "type": account_data_type,
-                "content": content,
-            })
+            account_data_events.append({"type": account_data_type, "content": content})

         result["account_data"] = account_data_events

         defer.returnValue(result)

     @defer.inlineCallbacks
-    def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
-                                  membership, member_event_id, is_peeking):
-        room_state = yield self.store.get_state_for_events(
-            [member_event_id],
-        )
+    def _room_initial_sync_parted(
+        self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+    ):
+        room_state = yield self.store.get_state_for_events([member_event_id])

         room_state = room_state[member_event_id]

@@ -319,14 +315,10 @@ class InitialSyncHandler(BaseHandler):
         if limit is None:
             limit = 10

-        stream_token = yield self.store.get_stream_token_for_event(
-            member_event_id
-        )
+        stream_token = yield self.store.get_stream_token_for_event(member_event_id)

         messages, token = yield self.store.get_recent_events_for_room(
-            room_id,
-            limit=limit,
-            end_token=stream_token
+            room_id, limit=limit, end_token=stream_token
         )

         messages = yield filter_events_for_client(

@@ -338,34 +330,39 @@ class InitialSyncHandler(BaseHandler):

         time_now = self.clock.time_msec()

-        defer.returnValue({
-            "membership": membership,
-            "room_id": room_id,
-            "messages": {
-                "chunk": (yield self._event_serializer.serialize_events(
-                    messages, time_now,
-                )),
-                "start": start_token.to_string(),
-                "end": end_token.to_string(),
-            },
-            "state": (yield self._event_serializer.serialize_events(
-                room_state.values(), time_now,
-            )),
-            "presence": [],
-            "receipts": [],
-        })
+        defer.returnValue(
+            {
+                "membership": membership,
+                "room_id": room_id,
+                "messages": {
+                    "chunk": (
+                        yield self._event_serializer.serialize_events(
+                            messages, time_now
+                        )
+                    ),
+                    "start": start_token.to_string(),
+                    "end": end_token.to_string(),
+                },
+                "state": (
+                    yield self._event_serializer.serialize_events(
+                        room_state.values(), time_now
+                    )
+                ),
+                "presence": [],
+                "receipts": [],
+            }
+        )

     @defer.inlineCallbacks
-    def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
-                                  membership, is_peeking):
-        current_state = yield self.state.get_current_state(
-            room_id=room_id,
-        )
+    def _room_initial_sync_joined(
+        self, user_id, room_id, pagin_config, membership, is_peeking
+    ):
+        current_state = yield self.state.get_current_state(room_id=room_id)

         # TODO: These concurrently
         time_now = self.clock.time_msec()
         state = yield self._event_serializer.serialize_events(
-            current_state.values(), time_now,
+            current_state.values(), time_now
         )

         now_token = yield self.hs.get_event_sources().get_current_token()

@@ -375,7 +372,8 @@ class InitialSyncHandler(BaseHandler):
             limit = 10

         room_members = [
-            m for m in current_state.values()
+            m
+            for m in current_state.values()
             if m.type == EventTypes.Member
             and m.content["membership"] == Membership.JOIN
         ]

@@ -389,8 +387,7 @@ class InitialSyncHandler(BaseHandler):
                 defer.returnValue([])

             states = yield presence_handler.get_states(
-                [m.user_id for m in room_members],
-                as_event=True,
+                [m.user_id for m in room_members], as_event=True
             )

             defer.returnValue(states)

@@ -398,8 +395,7 @@ class InitialSyncHandler(BaseHandler):
         @defer.inlineCallbacks
         def get_receipts():
             receipts = yield self.store.get_linearized_receipts_for_room(
-                room_id,
-                to_key=now_token.receipt_key,
+                room_id, to_key=now_token.receipt_key
             )
             if not receipts:
                 receipts = []

@@ -415,14 +411,14 @@ class InitialSyncHandler(BaseHandler):
                         room_id,
                         limit=limit,
                         end_token=now_token.room_key,
-                    )
+                    ),
                 ],
                 consumeErrors=True,
-            ).addErrback(unwrapFirstError),
+            ).addErrback(unwrapFirstError)
         )

         messages = yield filter_events_for_client(
-            self.store, user_id, messages, is_peeking=is_peeking,
+            self.store, user_id, messages, is_peeking=is_peeking
         )

         start_token = now_token.copy_and_replace("room_key", token)

@@ -433,9 +429,9 @@ class InitialSyncHandler(BaseHandler):
         ret = {
             "room_id": room_id,
             "messages": {
-                "chunk": (yield self._event_serializer.serialize_events(
-                    messages, time_now,
-                )),
+                "chunk": (
+                    yield self._event_serializer.serialize_events(messages, time_now)
+                ),
                 "start": start_token.to_string(),
                 "end": end_token.to_string(),
             },

@@ -464,8 +460,8 @@ class InitialSyncHandler(BaseHandler):
             room_id, EventTypes.RoomHistoryVisibility, ""
         )
         if (
-            visibility and
-            visibility.content["history_visibility"] == "world_readable"
+            visibility
+            and visibility.content["history_visibility"] == "world_readable"
         ):
             defer.returnValue((Membership.JOIN, None))
             return
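Several hunks above lean on the defer.gatherResults(...).addErrback(unwrapFirstError) idiom: run a batch of Deferreds concurrently and, if any fails, re-raise the first underlying error rather than Twisted's FirstError wrapper. A minimal sketch of the pattern, with a stand-in for synapse.util.unwrapFirstError and placeholder fetch functions:

from twisted.internet import defer


def unwrap_first_error(failure):
    # gatherResults(consumeErrors=True) wraps the first failure in a
    # FirstError; surface the original exception instead, which is what
    # synapse.util.unwrapFirstError does.
    failure.trap(defer.FirstError)
    return failure.value.subFailure


@defer.inlineCallbacks
def gather_room_data(get_presence, get_receipts):
    # get_presence and get_receipts are placeholders for functions
    # returning Deferreds, as in _room_initial_sync_joined above.
    presence, receipts = yield defer.gatherResults(
        [get_presence(), get_receipts()], consumeErrors=True
    ).addErrback(unwrap_first_error)
    defer.returnValue((presence, receipts))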
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2017 - 2018 New Vector Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -33,9 +34,10 @@ from synapse.api.errors import (
 from synapse.api.room_versions import RoomVersions
 from synapse.api.urls import ConsentURIBuilder
 from synapse.events.validator import EventValidator
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, UserID
+from synapse.types import RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.logcontext import run_in_background

@@ -59,8 +61,9 @@ class MessageHandler(object):
         self._event_serializer = hs.get_event_client_serializer()

     @defer.inlineCallbacks
-    def get_room_data(self, user_id=None, room_id=None,
-                      event_type=None, state_key="", is_guest=False):
+    def get_room_data(
+        self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
+    ):
         """ Get data from a room.

         Args:

@@ -75,9 +78,7 @@ class MessageHandler(object):
         )

         if membership == Membership.JOIN:
-            data = yield self.state.get_current_state(
-                room_id, event_type, state_key
-            )
+            data = yield self.state.get_current_state(room_id, event_type, state_key)
         elif membership == Membership.LEAVE:
             key = (event_type, state_key)
             room_state = yield self.store.get_state_for_events(

@@ -89,8 +90,12 @@ class MessageHandler(object):

     @defer.inlineCallbacks
     def get_state_events(
-        self, user_id, room_id, state_filter=StateFilter.all(),
-        at_token=None, is_guest=False,
+        self,
+        user_id,
+        room_id,
+        state_filter=StateFilter.all(),
+        at_token=None,
+        is_guest=False,
     ):
         """Retrieve all state events for a given room. If the user is
         joined to the room then return the current state. If the user has

@@ -122,50 +127,48 @@ class MessageHandler(object):
             # does not reliably give you the state at the given stream position.
             # (https://github.com/matrix-org/synapse/issues/3305)
             last_events, _ = yield self.store.get_recent_events_for_room(
-                room_id, end_token=at_token.room_key, limit=1,
+                room_id, end_token=at_token.room_key, limit=1
             )

             if not last_events:
-                raise NotFoundError("Can't find event for token %s" % (at_token, ))
+                raise NotFoundError("Can't find event for token %s" % (at_token,))

             visible_events = yield filter_events_for_client(
-                self.store, user_id, last_events,
+                self.store, user_id, last_events
             )

             event = last_events[0]
             if visible_events:
                 room_state = yield self.store.get_state_for_events(
-                    [event.event_id], state_filter=state_filter,
+                    [event.event_id], state_filter=state_filter
                 )
                 room_state = room_state[event.event_id]
             else:
                 raise AuthError(
                     403,
-                    "User %s not allowed to view events in room %s at token %s" % (
-                        user_id, room_id, at_token,
-                    )
+                    "User %s not allowed to view events in room %s at token %s"
+                    % (user_id, room_id, at_token),
                 )
         else:
             membership, membership_event_id = (
-                yield self.auth.check_in_room_or_world_readable(
-                    room_id, user_id,
-                )
+                yield self.auth.check_in_room_or_world_readable(room_id, user_id)
             )

             if membership == Membership.JOIN:
                 state_ids = yield self.store.get_filtered_current_state_ids(
-                    room_id, state_filter=state_filter,
+                    room_id, state_filter=state_filter
                 )
                 room_state = yield self.store.get_events(state_ids.values())
             elif membership == Membership.LEAVE:
                 room_state = yield self.store.get_state_for_events(
-                    [membership_event_id], state_filter=state_filter,
+                    [membership_event_id], state_filter=state_filter
                 )
                 room_state = room_state[membership_event_id]

         now = self.clock.time_msec()
         events = yield self._event_serializer.serialize_events(
-            room_state.values(), now,
+            room_state.values(),
+            now,
             # We don't bother bundling aggregations in when asked for state
             # events, as clients won't use them.
             bundle_aggregations=False,

@@ -209,13 +212,15 @@ class MessageHandler(object):
                 # Loop fell through, AS has no interested users in room
                 raise AuthError(403, "Appservice not in room")

-        defer.returnValue({
-            user_id: {
-                "avatar_url": profile.avatar_url,
-                "display_name": profile.display_name,
+        defer.returnValue(
+            {
+                user_id: {
+                    "avatar_url": profile.avatar_url,
+                    "display_name": profile.display_name,
+                }
+                for user_id, profile in iteritems(users_with_profile)
             }
-            for user_id, profile in iteritems(users_with_profile)
-        })
+        )


 class EventCreationHandler(object):

@@ -248,6 +253,7 @@ class EventCreationHandler(object):
         self.action_generator = hs.get_action_generator()

         self.spam_checker = hs.get_spam_checker()
+        self.third_party_event_rules = hs.get_third_party_event_rules()

         self._block_events_without_consent_error = (
             self.config.block_events_without_consent_error

@@ -259,9 +265,28 @@ class EventCreationHandler(object):
         if self._block_events_without_consent_error:
             self._consent_uri_builder = ConsentURIBuilder(self.config)

+        if (
+            not self.config.worker_app
+            and self.config.cleanup_extremities_with_dummy_events
+        ):
+            self.clock.looping_call(
+                lambda: run_as_background_process(
+                    "send_dummy_events_to_fill_extremities",
+                    self._send_dummy_events_to_fill_extremities,
+                ),
+                5 * 60 * 1000,
+            )
+
     @defer.inlineCallbacks
-    def create_event(self, requester, event_dict, token_id=None, txn_id=None,
-                     prev_events_and_hashes=None, require_consent=True):
+    def create_event(
+        self,
+        requester,
+        event_dict,
+        token_id=None,
+        txn_id=None,
+        prev_events_and_hashes=None,
+        require_consent=True,
+    ):
         """
         Given a dict from a client, create a new event.

@@ -321,8 +346,7 @@ class EventCreationHandler(object):
                     content["avatar_url"] = yield profile.get_avatar_url(target)
                 except Exception as e:
                     logger.info(
-                        "Failed to get profile information for %r: %s",
-                        target, e
+                        "Failed to get profile information for %r: %s", target, e
                     )

         is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)

@@ -358,16 +382,17 @@ class EventCreationHandler(object):
             prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
             if not prev_event or prev_event.membership != Membership.JOIN:
                 logger.warning(
-                    ("Attempt to send `m.room.aliases` in room %s by user %s but"
-                     " membership is %s"),
+                    (
+                        "Attempt to send `m.room.aliases` in room %s by user %s but"
+                        " membership is %s"
+                    ),
                     event.room_id,
                     event.sender,
                     prev_event.membership if prev_event else None,
                 )

                 raise AuthError(
-                    403,
-                    "You must be in the room to create an alias for it",
+                    403, "You must be in the room to create an alias for it"
                 )

         self.validator.validate_new(event)

@@ -434,8 +459,8 @@ class EventCreationHandler(object):

         # exempt the system notices user
         if (
-            self.config.server_notices_mxid is not None and
-            user_id == self.config.server_notices_mxid
+            self.config.server_notices_mxid is not None
+            and user_id == self.config.server_notices_mxid
         ):
             return

@@ -448,15 +473,10 @@ class EventCreationHandler(object):
             return

         consent_uri = self._consent_uri_builder.build_user_consent_uri(
-            requester.user.localpart,
-        )
-        msg = self._block_events_without_consent_error % {
-            'consent_uri': consent_uri,
-        }
-        raise ConsentNotGivenError(
-            msg=msg,
-            consent_uri=consent_uri,
+            requester.user.localpart
         )
+        msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
+        raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)

     @defer.inlineCallbacks
     def send_nonmember_event(self, requester, event, context, ratelimit=True):

@@ -471,8 +491,7 @@ class EventCreationHandler(object):
         """
         if event.type == EventTypes.Member:
             raise SynapseError(
-                500,
-                "Tried to send member event through non-member codepath"
+                500, "Tried to send member event through non-member codepath"
             )

         user = UserID.from_string(event.sender)

@@ -484,15 +503,13 @@ class EventCreationHandler(object):
         if prev_state is not None:
             logger.info(
                 "Not bothering to persist state event %s duplicated by %s",
-                event.event_id, prev_state.event_id,
+                event.event_id,
+                prev_state.event_id,
             )
             defer.returnValue(prev_state)

         yield self.handle_new_client_event(
-            requester=requester,
-            event=event,
-            context=context,
-            ratelimit=ratelimit,
+            requester=requester, event=event, context=context, ratelimit=ratelimit
         )

     @defer.inlineCallbacks

@@ -518,11 +535,7 @@ class EventCreationHandler(object):

     @defer.inlineCallbacks
     def create_and_send_nonmember_event(
-        self,
-        requester,
-        event_dict,
-        ratelimit=True,
-        txn_id=None
+        self, requester, event_dict, ratelimit=True, txn_id=None
     ):
         """
         Creates an event, then sends it.

@@ -537,32 +550,25 @@ class EventCreationHandler(object):
         # taking longer.
         with (yield self.limiter.queue(event_dict["room_id"])):
             event, context = yield self.create_event(
-                requester,
-                event_dict,
-                token_id=requester.access_token_id,
-                txn_id=txn_id
+                requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
             )

             spam_error = self.spam_checker.check_event_for_spam(event)
             if spam_error:
                 if not isinstance(spam_error, string_types):
                     spam_error = "Spam is not permitted here"
-                raise SynapseError(
-                    403, spam_error, Codes.FORBIDDEN
-                )
+                raise SynapseError(403, spam_error, Codes.FORBIDDEN)

             yield self.send_nonmember_event(
-                requester,
-                event,
-                context,
-                ratelimit=ratelimit,
+                requester, event, context, ratelimit=ratelimit
             )
         defer.returnValue(event)

     @measure_func("create_new_client_event")
     @defer.inlineCallbacks
-    def create_new_client_event(self, builder, requester=None,
-                                prev_events_and_hashes=None):
+    def create_new_client_event(
+        self, builder, requester=None, prev_events_and_hashes=None
+    ):
         """Create a new event for a local client

         Args:

@@ -582,22 +588,21 @@ class EventCreationHandler(object):
         """

         if prev_events_and_hashes is not None:
-            assert len(prev_events_and_hashes) <= 10, \
-                "Attempting to create an event with %i prev_events" % (
-                    len(prev_events_and_hashes),
+            assert len(prev_events_and_hashes) <= 10, (
+                "Attempting to create an event with %i prev_events"
+                % (len(prev_events_and_hashes),)
             )
         else:
-            prev_events_and_hashes = \
-                yield self.store.get_prev_events_for_room(builder.room_id)
+            prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+                builder.room_id
+            )

         prev_events = [
             (event_id, prev_hashes)
             for event_id, prev_hashes, _ in prev_events_and_hashes
         ]

-        event = yield builder.build(
-            prev_event_ids=[p for p, _ in prev_events],
-        )
+        event = yield builder.build(prev_event_ids=[p for p, _ in prev_events])
         context = yield self.state.compute_event_context(event)
         if requester:
             context.app_service = requester.app_service

@@ -613,29 +618,19 @@ class EventCreationHandler(object):
             aggregation_key = relation["key"]

             already_exists = yield self.store.has_user_annotated_event(
-                relates_to, event.type, aggregation_key, event.sender,
+                relates_to, event.type, aggregation_key, event.sender
             )
             if already_exists:
                 raise SynapseError(400, "Can't send same reaction twice")

-        logger.debug(
-            "Created event %s",
-            event.event_id,
-        )
+        logger.debug("Created event %s", event.event_id)

-        defer.returnValue(
-            (event, context,)
-        )
+        defer.returnValue((event, context))

     @measure_func("handle_new_client_event")
     @defer.inlineCallbacks
     def handle_new_client_event(
-        self,
-        requester,
-        event,
-        context,
-        ratelimit=True,
-        extra_users=[],
+        self, requester, event, context, ratelimit=True, extra_users=[]
     ):
         """Processes a new event. This includes checking auth, persisting it,
         notifying users, sending to remote servers, etc.

@@ -651,13 +646,22 @@ class EventCreationHandler(object):
             extra_users (list(UserID)): Any extra users to notify about event
         """

-        if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
-            room_version = event.content.get(
-                "room_version", RoomVersions.V1.identifier
-            )
+        if event.is_state() and (event.type, event.state_key) == (
+            EventTypes.Create,
+            "",
+        ):
+            room_version = event.content.get("room_version", RoomVersions.V1.identifier)
         else:
             room_version = yield self.store.get_room_version(event.room_id)

+        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+            event, context
+        )
+        if not event_allowed:
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
+
         try:
             yield self.auth.check_from_context(room_version, event, context)
         except AuthError as err:

@@ -672,9 +676,7 @@ class EventCreationHandler(object):
             logger.exception("Failed to encode content: %r", event.content)
             raise

-        yield self.action_generator.handle_push_actions_for_event(
-            event, context
-        )
+        yield self.action_generator.handle_push_actions_for_event(event, context)

         # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
         # hack around with a try/finally instead.

@@ -695,11 +697,7 @@ class EventCreationHandler(object):
                 return

             yield self.persist_and_notify_client_event(
-                requester,
-                event,
-                context,
-                ratelimit=ratelimit,
-                extra_users=extra_users,
+                requester, event, context, ratelimit=ratelimit, extra_users=extra_users
             )

             success = True

@@ -708,18 +706,12 @@ class EventCreationHandler(object):
                 # Ensure that we actually remove the entries in the push actions
                 # staging area, if we calculated them.
                 run_in_background(
-                    self.store.remove_push_actions_from_staging,
-                    event.event_id,
+                    self.store.remove_push_actions_from_staging, event.event_id
                 )

     @defer.inlineCallbacks
     def persist_and_notify_client_event(
-        self,
-        requester,
-        event,
-        context,
-        ratelimit=True,
-        extra_users=[],
+        self, requester, event, context, ratelimit=True, extra_users=[]
     ):
         """Called when we have fully built the event, have already
         calculated the push actions for the event, and checked auth.

@@ -744,20 +736,16 @@ class EventCreationHandler(object):
                 if mapping["room_id"] != event.room_id:
                     raise SynapseError(
                         400,
-                        "Room alias %s does not point to the room" % (
-                            room_alias_str,
-                        )
+                        "Room alias %s does not point to the room" % (room_alias_str,),
                     )

         federation_handler = self.hs.get_handlers().federation_handler

         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.INVITE:
+
                 def is_inviter_member_event(e):
-                    return (
-                        e.type == EventTypes.Member and
-                        e.sender == event.sender
-                    )
+                    return e.type == EventTypes.Member and e.sender == event.sender

                 current_state_ids = yield context.get_current_state_ids(self.store)

@@ -787,26 +775,21 @@ class EventCreationHandler(object):
                     # to get them to sign the event.

                     returned_invite = yield federation_handler.send_invite(
-                        invitee.domain,
-                        event,
+                        invitee.domain, event
                     )

                     event.unsigned.pop("room_state", None)

                     # TODO: Make sure the signatures actually are correct.
-                    event.signatures.update(
-                        returned_invite.signatures
-                    )
+                    event.signatures.update(returned_invite.signatures)

         if event.type == EventTypes.Redaction:
             prev_state_ids = yield context.get_prev_state_ids(self.store)
             auth_events_ids = yield self.auth.compute_auth_events(
-                event, prev_state_ids, for_verification=True,
+                event, prev_state_ids, for_verification=True
             )
             auth_events = yield self.store.get_events(auth_events_ids)
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_events.values()
-            }
+            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
             room_version = yield self.store.get_room_version(event.room_id)
             if self.auth.check_redaction(room_version, event, auth_events=auth_events):
                 original_event = yield self.store.get_event(

@@ -814,13 +797,10 @@ class EventCreationHandler(object):
                     check_redacted=False,
                     get_prev_content=False,
                     allow_rejected=False,
-                    allow_none=False
+                    allow_none=False,
                 )
                 if event.user_id != original_event.user_id:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to redact events"
-                    )
+                    raise AuthError(403, "You don't have permission to redact events")

                 # We've already checked.
                 event.internal_metadata.recheck_redaction = False

@@ -828,24 +808,18 @@ class EventCreationHandler(object):
         if event.type == EventTypes.Create:
             prev_state_ids = yield context.get_prev_state_ids(self.store)
             if prev_state_ids:
-                raise AuthError(
-                    403,
-                    "Changing the room create event is forbidden",
-                )
+                raise AuthError(403, "Changing the room create event is forbidden")

         (event_stream_id, max_stream_id) = yield self.store.persist_event(
             event, context=context
         )

-        yield self.pusher_pool.on_new_notifications(
-            event_stream_id, max_stream_id,
-        )
+        yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)

         def _notify():
             try:
                 self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id,
-                    extra_users=extra_users
+                    event, event_stream_id, max_stream_id, extra_users=extra_users
                 )
             except Exception:
                 logger.exception("Error notifying about new room event")

@@ -864,3 +838,54 @@ class EventCreationHandler(object):
                 yield presence.bump_presence_active_time(user)
             except Exception:
                 logger.exception("Error bumping presence active time")
+
+    @defer.inlineCallbacks
+    def _send_dummy_events_to_fill_extremities(self):
+        """Background task to send dummy events into rooms that have a large
+        number of extremities
+        """
+
+        room_ids = yield self.store.get_rooms_with_many_extremities(
+            min_count=10, limit=5
+        )
+
+        for room_id in room_ids:
+            # For each room we need to find a joined member we can use to send
+            # the dummy event with.
+
+            prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
+
+            latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
+
+            members = yield self.state.get_current_users_in_room(
+                room_id, latest_event_ids=latest_event_ids
+            )
+
+            user_id = None
+            for member in members:
+                if self.hs.is_mine_id(member):
+                    user_id = member
+                    break
+
+            if not user_id:
+                # We don't have a joined user.
+                # TODO: We should do something here to stop the room from
+                # appearing next time.
+                continue
+
+            requester = create_requester(user_id)
+
+            event, context = yield self.create_event(
+                requester,
+                {
+                    "type": "org.matrix.dummy_event",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": user_id,
+                },
+                prev_events_and_hashes=prev_events_and_hashes,
+            )
+
+            event.internal_metadata.proactively_send = False
+
+            yield self.send_nonmember_event(requester, event, context, ratelimit=False)
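The constructor change above wires _send_dummy_events_to_fill_extremities onto a five-minute timer via clock.looping_call, which takes milliseconds and wraps Twisted's LoopingCall; run_as_background_process adds logging and metrics context around each run. A minimal sketch of the same scheduling pattern in plain Twisted (the task body is a placeholder):

from twisted.internet import reactor, task


def send_dummy_events():
    # Placeholder for the real _send_dummy_events_to_fill_extremities body.
    print("sending dummy events into rooms with many forward extremities")


loop = task.LoopingCall(send_dummy_events)
loop.start(5 * 60)  # seconds here; the handler passes 5 * 60 * 1000 ms
reactor.run()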
@@ -55,9 +55,7 @@ class PurgeStatus(object):
         self.status = PurgeStatus.STATUS_ACTIVE

     def asdict(self):
-        return {
-            "status": PurgeStatus.STATUS_TEXT[self.status]
-        }
+        return {"status": PurgeStatus.STATUS_TEXT[self.status]}


 class PaginationHandler(object):

@@ -79,8 +77,7 @@ class PaginationHandler(object):
         self._purges_by_id = {}
         self._event_serializer = hs.get_event_client_serializer()

-    def start_purge_history(self, room_id, token,
-                            delete_local_events=False):
+    def start_purge_history(self, room_id, token, delete_local_events=False):
         """Start off a history purge on a room.

         Args:

@@ -95,8 +92,7 @@ class PaginationHandler(object):
         """
         if room_id in self._purges_in_progress_by_room:
             raise SynapseError(
-                400,
-                "History purge already in progress for %s" % (room_id, ),
+                400, "History purge already in progress for %s" % (room_id,)
             )

         purge_id = random_string(16)

@@ -107,14 +103,12 @@ class PaginationHandler(object):

         self._purges_by_id[purge_id] = PurgeStatus()
         run_in_background(
-            self._purge_history,
-            purge_id, room_id, token, delete_local_events,
+            self._purge_history, purge_id, room_id, token, delete_local_events
         )
         return purge_id

     @defer.inlineCallbacks
-    def _purge_history(self, purge_id, room_id, token,
-                       delete_local_events):
+    def _purge_history(self, purge_id, room_id, token, delete_local_events):
         """Carry out a history purge on a room.

         Args:

@@ -130,16 +124,13 @@ class PaginationHandler(object):
         self._purges_in_progress_by_room.add(room_id)
         try:
             with (yield self.pagination_lock.write(room_id)):
-                yield self.store.purge_history(
-                    room_id, token, delete_local_events,
-                )
+                yield self.store.purge_history(room_id, token, delete_local_events)
             logger.info("[purge] complete")
             self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
         except Exception:
             f = Failure()
             logger.error(
-                "[purge] failed",
-                exc_info=(f.type, f.value, f.getTracebackObject()),
+                "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
             )
             self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
         finally:

@@ -148,6 +139,7 @@ class PaginationHandler(object):
             # remove the purge from the list 24 hours after it completes
             def clear_purge():
                 del self._purges_by_id[purge_id]
+
             self.hs.get_reactor().callLater(24 * 3600, clear_purge)

     def get_purge_status(self, purge_id):

@@ -162,8 +154,14 @@ class PaginationHandler(object):
         return self._purges_by_id.get(purge_id)

     @defer.inlineCallbacks
-    def get_messages(self, requester, room_id=None, pagin_config=None,
-                     as_client_event=True, event_filter=None):
+    def get_messages(
+        self,
+        requester,
+        room_id=None,
+        pagin_config=None,
+        as_client_event=True,
+        event_filter=None,
+    ):
         """Get messages in a room.

         Args:

@@ -182,9 +180,7 @@ class PaginationHandler(object):
             room_token = pagin_config.from_token.room_key
         else:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token_for_room(
-                    room_id=room_id
-                )
+                yield self.hs.get_event_sources().get_current_token_for_pagination()
             )
             room_token = pagin_config.from_token.room_key

@@ -201,7 +197,7 @@ class PaginationHandler(object):
                 room_id, user_id
             )

-            if source_config.direction == 'b':
+            if source_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:

@@ -235,27 +231,24 @@ class PaginationHandler(object):
                 event_filter=event_filter,
             )

-            next_token = pagin_config.from_token.copy_and_replace(
-                "room_key", next_key
-            )
+            next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)

         if events:
             if event_filter:
                 events = event_filter.filter(events)

             events = yield filter_events_for_client(
-                self.store,
-                user_id,
-                events,
-                is_peeking=(member_event_id is None),
+                self.store, user_id, events, is_peeking=(member_event_id is None)
             )

         if not events:
-            defer.returnValue({
-                "chunk": [],
-                "start": pagin_config.from_token.to_string(),
-                "end": next_token.to_string(),
-            })
+            defer.returnValue(
+                {
+                    "chunk": [],
+                    "start": pagin_config.from_token.to_string(),
+                    "end": next_token.to_string(),
+                }
+            )

         state = None
         if event_filter and event_filter.lazy_load_members() and len(events) > 0:

@@ -263,12 +256,11 @@ class PaginationHandler(object):

             # FIXME: we also care about invite targets etc.
             state_filter = StateFilter.from_types(
-                (EventTypes.Member, event.sender)
-                for event in events
+                (EventTypes.Member, event.sender) for event in events
             )

             state_ids = yield self.store.get_state_ids_for_event(
-                events[0].event_id, state_filter=state_filter,
+                events[0].event_id, state_filter=state_filter
             )

             if state_ids:

@@ -280,8 +272,7 @@ class PaginationHandler(object):
         chunk = {
             "chunk": (
                 yield self._event_serializer.serialize_events(
-                    events, time_now,
-                    as_client_event=as_client_event,
+                    events, time_now, as_client_event=as_client_event
                 )
             ),
             "start": pagin_config.from_token.to_string(),

@@ -291,8 +282,7 @@ class PaginationHandler(object):
         if state:
             chunk["state"] = (
                 yield self._event_serializer.serialize_events(
-                    state, time_now,
-                    as_client_event=as_client_event,
+                    state, time_now, as_client_event=as_client_event
                 )
             )
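The purge plumbing above follows a simple bookkeeping pattern: mint a random id, record a status object under it, kick the real work into the background, and let callers poll get_purge_status while the status moves from active to complete or failed. A rough sketch of that pattern, using plain threads rather than Twisted purely for brevity (all names here are illustrative):

import threading
import uuid

_purges_by_id = {}


def start_purge(work):
    # Mirrors start_purge_history: return an id immediately, do the work later.
    purge_id = uuid.uuid4().hex[:16]
    _purges_by_id[purge_id] = {"status": "active"}

    def _run():
        try:
            work()
            _purges_by_id[purge_id]["status"] = "complete"
        except Exception:
            _purges_by_id[purge_id]["status"] = "failed"

    threading.Thread(target=_run, daemon=True).start()
    return purge_id


def get_purge_status(purge_id):
    return _purges_by_id.get(purge_id)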
Some files were not shown because too many files have changed in this diff