mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 12:36:02 -04:00
Merge branch 'develop' into rav/saml2_client
This commit is contained in:
commit
a4daa899ec
478 changed files with 18927 additions and 11500 deletions
|
@ -56,8 +56,9 @@ class ClientDirectoryServer(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
if "room_id" not in content:
|
||||
raise SynapseError(400, 'Missing params: ["room_id"]',
|
||||
errcode=Codes.BAD_JSON)
|
||||
raise SynapseError(
|
||||
400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON
|
||||
)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
logger.debug("Got room name: %s", room_alias.to_string())
|
||||
|
@ -89,13 +90,11 @@ class ClientDirectoryServer(RestServlet):
|
|||
try:
|
||||
service = yield self.auth.get_appservice_by_req(request)
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
yield dir_handler.delete_appservice_association(
|
||||
service, room_alias
|
||||
)
|
||||
yield dir_handler.delete_appservice_association(service, room_alias)
|
||||
logger.info(
|
||||
"Application service at %s deleted alias %s",
|
||||
service.url,
|
||||
room_alias.to_string()
|
||||
room_alias.to_string(),
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
except AuthError:
|
||||
|
@ -107,14 +106,10 @@ class ClientDirectoryServer(RestServlet):
|
|||
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
|
||||
yield dir_handler.delete_association(
|
||||
requester, room_alias
|
||||
)
|
||||
yield dir_handler.delete_association(requester, room_alias)
|
||||
|
||||
logger.info(
|
||||
"User %s deleted alias %s",
|
||||
user.to_string(),
|
||||
room_alias.to_string()
|
||||
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -135,9 +130,9 @@ class ClientDirectoryListServer(RestServlet):
|
|||
if room is None:
|
||||
raise NotFoundError("Unknown room")
|
||||
|
||||
defer.returnValue((200, {
|
||||
"visibility": "public" if room["is_public"] else "private"
|
||||
}))
|
||||
defer.returnValue(
|
||||
(200, {"visibility": "public" if room["is_public"] else "private"})
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id):
|
||||
|
@ -147,7 +142,7 @@ class ClientDirectoryListServer(RestServlet):
|
|||
visibility = content.get("visibility", "public")
|
||||
|
||||
yield self.handlers.directory_handler.edit_published_room_list(
|
||||
requester, room_id, visibility,
|
||||
requester, room_id, visibility
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -157,7 +152,7 @@ class ClientDirectoryListServer(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
yield self.handlers.directory_handler.edit_published_room_list(
|
||||
requester, room_id, "private",
|
||||
requester, room_id, "private"
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -191,7 +186,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
|
|||
)
|
||||
|
||||
yield self.handlers.directory_handler.edit_published_appservice_room_list(
|
||||
requester.app_service.id, network_id, room_id, visibility,
|
||||
requester.app_service.id, network_id, room_id, visibility
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -38,17 +38,14 @@ class EventStreamRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
is_guest = requester.is_guest
|
||||
room_id = None
|
||||
if is_guest:
|
||||
if b"room_id" not in request.args:
|
||||
raise SynapseError(400, "Guest users must specify room_id param")
|
||||
if b"room_id" in request.args:
|
||||
room_id = request.args[b"room_id"][0].decode('ascii')
|
||||
room_id = request.args[b"room_id"][0].decode("ascii")
|
||||
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
|
|
|
@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
|
|||
to a typed object.
|
||||
"""
|
||||
if "user" in submission:
|
||||
submission["identifier"] = {
|
||||
"type": "m.id.user",
|
||||
"user": submission["user"],
|
||||
}
|
||||
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
|
||||
del submission["user"]
|
||||
|
||||
if "medium" in submission and "address" in submission:
|
||||
|
@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
|
|||
|
||||
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
|
||||
|
||||
return {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": "msisdn",
|
||||
"address": msisdn,
|
||||
}
|
||||
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
|
@ -124,9 +117,9 @@ class LoginRestServlet(RestServlet):
|
|||
# login flow types returned.
|
||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||
|
||||
flows.extend((
|
||||
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||
))
|
||||
flows.extend(
|
||||
({"type": t} for t in self.auth_handler.get_supported_login_types())
|
||||
)
|
||||
|
||||
return (200, {"flows": flows})
|
||||
|
||||
|
@ -136,7 +129,8 @@ class LoginRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
self._address_ratelimiter.ratelimit(
|
||||
request.getClientIP(), time_now_s=self.hs.clock.time(),
|
||||
request.getClientIP(),
|
||||
time_now_s=self.hs.clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_address.per_second,
|
||||
burst_count=self.hs.config.rc_login_address.burst_count,
|
||||
update=True,
|
||||
|
@ -144,8 +138,9 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
try:
|
||||
if self.jwt_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.JWT_TYPE):
|
||||
if self.jwt_enabled and (
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
):
|
||||
result = yield self.do_jwt_login(login_submission)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
result = yield self.do_token_login(login_submission)
|
||||
|
@ -174,10 +169,10 @@ class LoginRestServlet(RestServlet):
|
|||
# field)
|
||||
logger.info(
|
||||
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
|
||||
login_submission.get('identifier'),
|
||||
login_submission.get('medium'),
|
||||
login_submission.get('address'),
|
||||
login_submission.get('user'),
|
||||
login_submission.get("identifier"),
|
||||
login_submission.get("medium"),
|
||||
login_submission.get("address"),
|
||||
login_submission.get("user"),
|
||||
)
|
||||
login_submission_legacy_convert(login_submission)
|
||||
|
||||
|
@ -194,13 +189,13 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
address = identifier.get('address')
|
||||
medium = identifier.get('medium')
|
||||
address = identifier.get("address")
|
||||
medium = identifier.get("medium")
|
||||
|
||||
if medium is None or address is None:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
if medium == 'email':
|
||||
if medium == "email":
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
|
@ -209,34 +204,28 @@ class LoginRestServlet(RestServlet):
|
|||
# Check for login providers that support 3pid login types
|
||||
canonical_user_id, callback_3pid = (
|
||||
yield self.auth_handler.check_password_provider_3pid(
|
||||
medium,
|
||||
address,
|
||||
login_submission["password"],
|
||||
medium, address, login_submission["password"]
|
||||
)
|
||||
)
|
||||
if canonical_user_id:
|
||||
# Authentication through password provider and 3pid succeeded
|
||||
result = yield self._register_device_with_callback(
|
||||
canonical_user_id, login_submission, callback_3pid,
|
||||
canonical_user_id, login_submission, callback_3pid
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
# No password providers were able to handle this 3pid
|
||||
# Check local store
|
||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
medium, address,
|
||||
medium, address
|
||||
)
|
||||
if not user_id:
|
||||
logger.warn(
|
||||
"unknown 3pid identifier medium %s, address %r",
|
||||
medium, address,
|
||||
"unknown 3pid identifier medium %s, address %r", medium, address
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
identifier = {
|
||||
"type": "m.id.user",
|
||||
"user": user_id,
|
||||
}
|
||||
identifier = {"type": "m.id.user", "user": user_id}
|
||||
|
||||
# by this point, the identifier should be an m.id.user: if it's anything
|
||||
# else, we haven't understood it.
|
||||
|
@ -246,22 +235,16 @@ class LoginRestServlet(RestServlet):
|
|||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
canonical_user_id, callback = yield self.auth_handler.validate_login(
|
||||
identifier["user"],
|
||||
login_submission,
|
||||
identifier["user"], login_submission
|
||||
)
|
||||
|
||||
result = yield self._register_device_with_callback(
|
||||
canonical_user_id, login_submission, callback,
|
||||
canonical_user_id, login_submission, callback
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _register_device_with_callback(
|
||||
self,
|
||||
user_id,
|
||||
login_submission,
|
||||
callback=None,
|
||||
):
|
||||
def _register_device_with_callback(self, user_id, login_submission, callback=None):
|
||||
""" Registers a device with a given user_id. Optionally run a callback
|
||||
function after registration has completed.
|
||||
|
||||
|
@ -277,7 +260,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name,
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -294,7 +277,7 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def do_token_login(self, login_submission):
|
||||
token = login_submission['token']
|
||||
token = login_submission["token"]
|
||||
auth_handler = self.auth_handler
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
|
@ -303,7 +286,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name,
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -320,15 +303,16 @@ class LoginRestServlet(RestServlet):
|
|||
token = login_submission.get("token", None)
|
||||
if token is None:
|
||||
raise LoginError(
|
||||
401, "Token field for JWT is missing",
|
||||
errcode=Codes.UNAUTHORIZED
|
||||
401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
|
||||
payload = jwt.decode(
|
||||
token, self.jwt_secret, algorithms=[self.jwt_algorithm]
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
|
||||
except InvalidTokenError:
|
||||
|
@ -346,7 +330,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
registered_user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -362,7 +346,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
registered_user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -376,6 +360,7 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
class BaseSsoRedirectServlet(RestServlet):
|
||||
"""Common base class for /login/sso/redirect impls"""
|
||||
|
||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||
|
||||
def on_GET(self, request):
|
||||
|
@ -401,21 +386,20 @@ class BaseSsoRedirectServlet(RestServlet):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CasRedirectServlet(RestServlet):
|
||||
class CasRedirectServlet(BaseSsoRedirectServlet):
|
||||
def __init__(self, hs):
|
||||
super(CasRedirectServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
|
||||
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
|
||||
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
|
||||
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
client_redirect_url_param = urllib.parse.urlencode({
|
||||
b"redirectUrl": client_redirect_url
|
||||
}).encode('ascii')
|
||||
hs_redirect_url = (self.cas_service_url +
|
||||
b"/_matrix/client/r0/login/cas/ticket")
|
||||
service_param = urllib.parse.urlencode({
|
||||
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
||||
}).encode('ascii')
|
||||
client_redirect_url_param = urllib.parse.urlencode(
|
||||
{b"redirectUrl": client_redirect_url}
|
||||
).encode("ascii")
|
||||
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
|
||||
service_param = urllib.parse.urlencode(
|
||||
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
|
||||
).encode("ascii")
|
||||
return b"%s/login?%s" % (self.cas_server_url, service_param)
|
||||
|
||||
|
||||
|
@ -436,7 +420,7 @@ class CasTicketServlet(RestServlet):
|
|||
uri = self.cas_server_url + "/proxyValidate"
|
||||
args = {
|
||||
"ticket": parse_string(request, "ticket", required=True),
|
||||
"service": self.cas_service_url
|
||||
"service": self.cas_service_url,
|
||||
}
|
||||
try:
|
||||
body = yield self._http_client.get_raw(uri, args)
|
||||
|
@ -463,7 +447,7 @@ class CasTicketServlet(RestServlet):
|
|||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return self._sso_auth_handler.on_successful_auth(
|
||||
user, request, client_redirect_url,
|
||||
user, request, client_redirect_url
|
||||
)
|
||||
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
|
@ -473,7 +457,7 @@ class CasTicketServlet(RestServlet):
|
|||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = (root[0].tag.endswith("authenticationSuccess"))
|
||||
success = root[0].tag.endswith("authenticationSuccess")
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
|
@ -491,11 +475,11 @@ class CasTicketServlet(RestServlet):
|
|||
raise Exception("CAS response does not contain user")
|
||||
except Exception:
|
||||
logger.error("Error parsing CAS response", exc_info=1)
|
||||
raise LoginError(401, "Invalid CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(401, "Unsuccessful CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(
|
||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
return user, attributes
|
||||
|
||||
|
||||
|
@ -507,11 +491,11 @@ class SAMLRedirectServlet(BaseSsoRedirectServlet):
|
|||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
reqid, info = self._saml_client.prepare_for_authenticate(
|
||||
relay_state=client_redirect_url,
|
||||
relay_state=client_redirect_url
|
||||
)
|
||||
|
||||
for key, value in info['headers']:
|
||||
if key == 'Location':
|
||||
for key, value in info["headers"]:
|
||||
if key == "Location":
|
||||
return value
|
||||
|
||||
# this shouldn't happen!
|
||||
|
@ -526,6 +510,7 @@ class SSOAuthHandler(object):
|
|||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self._hostname = hs.hostname
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
@ -534,8 +519,7 @@ class SSOAuthHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_successful_auth(
|
||||
self, username, request, client_redirect_url,
|
||||
user_display_name=None,
|
||||
self, username, request, client_redirect_url, user_display_name=None
|
||||
):
|
||||
"""Called once the user has successfully authenticated with the SSO.
|
||||
|
||||
|
|
|
@ -46,7 +46,8 @@ class LogoutRestServlet(RestServlet):
|
|||
yield self._auth_handler.delete_access_token(access_token)
|
||||
else:
|
||||
yield self._device_handler.delete_device(
|
||||
requester.user.to_string(), requester.device_id)
|
||||
requester.user.to_string(), requester.device_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class PresenceStatusRestServlet(RestServlet):
|
|||
|
||||
if requester.user != user:
|
||||
allowed = yield self.presence_handler.is_visible(
|
||||
observed_user=user, observer_user=requester.user,
|
||||
observed_user=user, observer_user=requester.user
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
|
|
|
@ -63,8 +63,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||
except Exception:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, requester, new_name, is_admin)
|
||||
yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -113,8 +112,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||
except Exception:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.profile_handler.set_avatar_url(
|
||||
user, requester, new_name, is_admin)
|
||||
yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -21,7 +21,11 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_json_value_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.push.baserules import BASE_RULE_IDS
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
|
||||
|
@ -32,7 +36,8 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc
|
|||
class PushRuleRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
|
||||
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
|
||||
"Unrecognised request: You probably wanted a trailing slash")
|
||||
"Unrecognised request: You probably wanted a trailing slash"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PushRuleRestServlet, self).__init__()
|
||||
|
@ -54,27 +59,25 @@ class PushRuleRestServlet(RestServlet):
|
|||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
||||
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
|
||||
raise SynapseError(400, "rule_id may not contain slashes")
|
||||
|
||||
content = parse_json_value_from_request(request)
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if 'attr' in spec:
|
||||
if "attr" in spec:
|
||||
yield self.set_rule_attr(user_id, spec, content)
|
||||
self.notify_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
if spec['rule_id'].startswith('.'):
|
||||
if spec["rule_id"].startswith("."):
|
||||
# Rule ids starting with '.' are reserved for server default rules.
|
||||
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
|
||||
|
||||
try:
|
||||
(conditions, actions) = _rule_tuple_from_request_object(
|
||||
spec['template'],
|
||||
spec['rule_id'],
|
||||
content,
|
||||
spec["template"], spec["rule_id"], content
|
||||
)
|
||||
except InvalidRuleException as e:
|
||||
raise SynapseError(400, str(e))
|
||||
|
@ -95,7 +98,7 @@ class PushRuleRestServlet(RestServlet):
|
|||
conditions=conditions,
|
||||
actions=actions,
|
||||
before=before,
|
||||
after=after
|
||||
after=after,
|
||||
)
|
||||
self.notify_user(user_id)
|
||||
except InconsistentRuleException as e:
|
||||
|
@ -118,9 +121,7 @@ class PushRuleRestServlet(RestServlet):
|
|||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
|
||||
try:
|
||||
yield self.store.delete_push_rule(
|
||||
user_id, namespaced_rule_id
|
||||
)
|
||||
yield self.store.delete_push_rule(user_id, namespaced_rule_id)
|
||||
self.notify_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
except StoreError as e:
|
||||
|
@ -149,10 +150,10 @@ class PushRuleRestServlet(RestServlet):
|
|||
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
||||
)
|
||||
|
||||
if path[0] == '':
|
||||
if path[0] == "":
|
||||
defer.returnValue((200, rules))
|
||||
elif path[0] == 'global':
|
||||
result = _filter_ruleset_with_path(rules['global'], path[1:])
|
||||
elif path[0] == "global":
|
||||
result = _filter_ruleset_with_path(rules["global"], path[1:])
|
||||
defer.returnValue((200, result))
|
||||
else:
|
||||
raise UnrecognizedRequestError()
|
||||
|
@ -162,12 +163,10 @@ class PushRuleRestServlet(RestServlet):
|
|||
|
||||
def notify_user(self, user_id):
|
||||
stream_id, _ = self.store.get_push_rules_stream_token()
|
||||
self.notifier.on_new_event(
|
||||
"push_rules_key", stream_id, users=[user_id]
|
||||
)
|
||||
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
|
||||
|
||||
def set_rule_attr(self, user_id, spec, val):
|
||||
if spec['attr'] == 'enabled':
|
||||
if spec["attr"] == "enabled":
|
||||
if isinstance(val, dict) and "enabled" in val:
|
||||
val = val["enabled"]
|
||||
if not isinstance(val, bool):
|
||||
|
@ -176,14 +175,12 @@ class PushRuleRestServlet(RestServlet):
|
|||
# bools directly, so let's not break them.
|
||||
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
return self.store.set_push_rule_enabled(
|
||||
user_id, namespaced_rule_id, val
|
||||
)
|
||||
elif spec['attr'] == 'actions':
|
||||
actions = val.get('actions')
|
||||
return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
|
||||
elif spec["attr"] == "actions":
|
||||
actions = val.get("actions")
|
||||
_check_actions(actions)
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
rule_id = spec['rule_id']
|
||||
rule_id = spec["rule_id"]
|
||||
is_default_rule = rule_id.startswith(".")
|
||||
if is_default_rule:
|
||||
if namespaced_rule_id not in BASE_RULE_IDS:
|
||||
|
@ -210,12 +207,12 @@ def _rule_spec_from_path(path):
|
|||
"""
|
||||
if len(path) < 2:
|
||||
raise UnrecognizedRequestError()
|
||||
if path[0] != 'pushrules':
|
||||
if path[0] != "pushrules":
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
scope = path[1]
|
||||
path = path[2:]
|
||||
if scope != 'global':
|
||||
if scope != "global":
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
if len(path) == 0:
|
||||
|
@ -229,56 +226,40 @@ def _rule_spec_from_path(path):
|
|||
|
||||
rule_id = path[0]
|
||||
|
||||
spec = {
|
||||
'scope': scope,
|
||||
'template': template,
|
||||
'rule_id': rule_id
|
||||
}
|
||||
spec = {"scope": scope, "template": template, "rule_id": rule_id}
|
||||
|
||||
path = path[1:]
|
||||
|
||||
if len(path) > 0 and len(path[0]) > 0:
|
||||
spec['attr'] = path[0]
|
||||
spec["attr"] = path[0]
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
|
||||
if rule_template in ['override', 'underride']:
|
||||
if 'conditions' not in req_obj:
|
||||
if rule_template in ["override", "underride"]:
|
||||
if "conditions" not in req_obj:
|
||||
raise InvalidRuleException("Missing 'conditions'")
|
||||
conditions = req_obj['conditions']
|
||||
conditions = req_obj["conditions"]
|
||||
for c in conditions:
|
||||
if 'kind' not in c:
|
||||
if "kind" not in c:
|
||||
raise InvalidRuleException("Condition without 'kind'")
|
||||
elif rule_template == 'room':
|
||||
conditions = [{
|
||||
'kind': 'event_match',
|
||||
'key': 'room_id',
|
||||
'pattern': rule_id
|
||||
}]
|
||||
elif rule_template == 'sender':
|
||||
conditions = [{
|
||||
'kind': 'event_match',
|
||||
'key': 'user_id',
|
||||
'pattern': rule_id
|
||||
}]
|
||||
elif rule_template == 'content':
|
||||
if 'pattern' not in req_obj:
|
||||
elif rule_template == "room":
|
||||
conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}]
|
||||
elif rule_template == "sender":
|
||||
conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}]
|
||||
elif rule_template == "content":
|
||||
if "pattern" not in req_obj:
|
||||
raise InvalidRuleException("Content rule missing 'pattern'")
|
||||
pat = req_obj['pattern']
|
||||
pat = req_obj["pattern"]
|
||||
|
||||
conditions = [{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.body',
|
||||
'pattern': pat
|
||||
}]
|
||||
conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}]
|
||||
else:
|
||||
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
|
||||
|
||||
if 'actions' not in req_obj:
|
||||
if "actions" not in req_obj:
|
||||
raise InvalidRuleException("No actions found")
|
||||
actions = req_obj['actions']
|
||||
actions = req_obj["actions"]
|
||||
|
||||
_check_actions(actions)
|
||||
|
||||
|
@ -290,9 +271,9 @@ def _check_actions(actions):
|
|||
raise InvalidRuleException("No actions found")
|
||||
|
||||
for a in actions:
|
||||
if a in ['notify', 'dont_notify', 'coalesce']:
|
||||
if a in ["notify", "dont_notify", "coalesce"]:
|
||||
pass
|
||||
elif isinstance(a, dict) and 'set_tweak' in a:
|
||||
elif isinstance(a, dict) and "set_tweak" in a:
|
||||
pass
|
||||
else:
|
||||
raise InvalidRuleException("Unrecognised action")
|
||||
|
@ -304,7 +285,7 @@ def _filter_ruleset_with_path(ruleset, path):
|
|||
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
||||
)
|
||||
|
||||
if path[0] == '':
|
||||
if path[0] == "":
|
||||
return ruleset
|
||||
template_kind = path[0]
|
||||
if template_kind not in ruleset:
|
||||
|
@ -314,13 +295,13 @@ def _filter_ruleset_with_path(ruleset, path):
|
|||
raise UnrecognizedRequestError(
|
||||
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
||||
)
|
||||
if path[0] == '':
|
||||
if path[0] == "":
|
||||
return ruleset[template_kind]
|
||||
rule_id = path[0]
|
||||
|
||||
the_rule = None
|
||||
for r in ruleset[template_kind]:
|
||||
if r['rule_id'] == rule_id:
|
||||
if r["rule_id"] == rule_id:
|
||||
the_rule = r
|
||||
if the_rule is None:
|
||||
raise NotFoundError
|
||||
|
@ -339,19 +320,19 @@ def _filter_ruleset_with_path(ruleset, path):
|
|||
|
||||
|
||||
def _priority_class_from_spec(spec):
|
||||
if spec['template'] not in PRIORITY_CLASS_MAP.keys():
|
||||
raise InvalidRuleException("Unknown template: %s" % (spec['template']))
|
||||
pc = PRIORITY_CLASS_MAP[spec['template']]
|
||||
if spec["template"] not in PRIORITY_CLASS_MAP.keys():
|
||||
raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
|
||||
pc = PRIORITY_CLASS_MAP[spec["template"]]
|
||||
|
||||
return pc
|
||||
|
||||
|
||||
def _namespaced_rule_id_from_spec(spec):
|
||||
return _namespaced_rule_id(spec, spec['rule_id'])
|
||||
return _namespaced_rule_id(spec, spec["rule_id"])
|
||||
|
||||
|
||||
def _namespaced_rule_id(spec, rule_id):
|
||||
return "global/%s/%s" % (spec['template'], rule_id)
|
||||
return "global/%s/%s" % (spec["template"], rule_id)
|
||||
|
||||
|
||||
class InvalidRuleException(Exception):
|
||||
|
|
|
@ -44,9 +44,7 @@ class PushersRestServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
|
||||
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(
|
||||
user.to_string()
|
||||
)
|
||||
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
||||
|
||||
allowed_keys = [
|
||||
"app_display_name",
|
||||
|
@ -87,50 +85,61 @@ class PushersSetRestServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
if ('pushkey' in content and 'app_id' in content
|
||||
and 'kind' in content and
|
||||
content['kind'] is None):
|
||||
if (
|
||||
"pushkey" in content
|
||||
and "app_id" in content
|
||||
and "kind" in content
|
||||
and content["kind"] is None
|
||||
):
|
||||
yield self.pusher_pool.remove_pusher(
|
||||
content['app_id'], content['pushkey'], user_id=user.to_string()
|
||||
content["app_id"], content["pushkey"], user_id=user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
assert_params_in_dict(
|
||||
content,
|
||||
['kind', 'app_id', 'app_display_name',
|
||||
'device_display_name', 'pushkey', 'lang', 'data']
|
||||
[
|
||||
"kind",
|
||||
"app_id",
|
||||
"app_display_name",
|
||||
"device_display_name",
|
||||
"pushkey",
|
||||
"lang",
|
||||
"data",
|
||||
],
|
||||
)
|
||||
|
||||
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
|
||||
logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"])
|
||||
logger.debug("Got pushers request with body: %r", content)
|
||||
|
||||
append = False
|
||||
if 'append' in content:
|
||||
append = content['append']
|
||||
if "append" in content:
|
||||
append = content["append"]
|
||||
|
||||
if not append:
|
||||
yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
||||
app_id=content['app_id'],
|
||||
pushkey=content['pushkey'],
|
||||
not_user_id=user.to_string()
|
||||
app_id=content["app_id"],
|
||||
pushkey=content["pushkey"],
|
||||
not_user_id=user.to_string(),
|
||||
)
|
||||
|
||||
try:
|
||||
yield self.pusher_pool.add_pusher(
|
||||
user_id=user.to_string(),
|
||||
access_token=requester.access_token_id,
|
||||
kind=content['kind'],
|
||||
app_id=content['app_id'],
|
||||
app_display_name=content['app_display_name'],
|
||||
device_display_name=content['device_display_name'],
|
||||
pushkey=content['pushkey'],
|
||||
lang=content['lang'],
|
||||
data=content['data'],
|
||||
profile_tag=content.get('profile_tag', ""),
|
||||
kind=content["kind"],
|
||||
app_id=content["app_id"],
|
||||
app_display_name=content["app_display_name"],
|
||||
device_display_name=content["device_display_name"],
|
||||
pushkey=content["pushkey"],
|
||||
lang=content["lang"],
|
||||
data=content["data"],
|
||||
profile_tag=content.get("profile_tag", ""),
|
||||
)
|
||||
except PusherConfigException as pce:
|
||||
raise SynapseError(400, "Config Error: " + str(pce),
|
||||
errcode=Codes.MISSING_PARAM)
|
||||
raise SynapseError(
|
||||
400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
|
@ -144,6 +153,7 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
"""
|
||||
To allow pusher to be delete by clicking a link (ie. GET request)
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/pushers/remove$", v1=True)
|
||||
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
|
||||
|
||||
|
@ -164,9 +174,7 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
|
||||
try:
|
||||
yield self.pusher_pool.remove_pusher(
|
||||
app_id=app_id,
|
||||
pushkey=pushkey,
|
||||
user_id=user.to_string(),
|
||||
app_id=app_id, pushkey=pushkey, user_id=user.to_string()
|
||||
)
|
||||
except StoreError as se:
|
||||
if se.code != 404:
|
||||
|
@ -177,9 +185,9 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (
|
||||
len(PushersRemoveRestServlet.SUCCESS_HTML),
|
||||
))
|
||||
request.setHeader(
|
||||
b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),)
|
||||
)
|
||||
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
|
||||
finish_request(request)
|
||||
defer.returnValue(None)
|
||||
|
|
|
@ -61,18 +61,16 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/createRoom"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
|
||||
http_server.register_paths("OPTIONS",
|
||||
client_patterns("/rooms(?:/.*)?$", v1=True),
|
||||
self.on_OPTIONS)
|
||||
http_server.register_paths(
|
||||
"OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS
|
||||
)
|
||||
# define CORS for /createRoom[/txnid]
|
||||
http_server.register_paths("OPTIONS",
|
||||
client_patterns("/createRoom(?:/.*)?$", v1=True),
|
||||
self.on_OPTIONS)
|
||||
http_server.register_paths(
|
||||
"OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS
|
||||
)
|
||||
|
||||
def on_PUT(self, request, txn_id):
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request
|
||||
)
|
||||
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -107,21 +105,23 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
|
||||
|
||||
# /room/$roomid/state/$eventtype/$statekey
|
||||
state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
|
||||
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
|
||||
state_key = (
|
||||
"/rooms/(?P<room_id>[^/]*)/state/"
|
||||
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$"
|
||||
)
|
||||
|
||||
http_server.register_paths("GET",
|
||||
client_patterns(state_key, v1=True),
|
||||
self.on_GET)
|
||||
http_server.register_paths("PUT",
|
||||
client_patterns(state_key, v1=True),
|
||||
self.on_PUT)
|
||||
http_server.register_paths("GET",
|
||||
client_patterns(no_state_key, v1=True),
|
||||
self.on_GET_no_state_key)
|
||||
http_server.register_paths("PUT",
|
||||
client_patterns(no_state_key, v1=True),
|
||||
self.on_PUT_no_state_key)
|
||||
http_server.register_paths(
|
||||
"GET", client_patterns(state_key, v1=True), self.on_GET
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT", client_patterns(state_key, v1=True), self.on_PUT
|
||||
)
|
||||
http_server.register_paths(
|
||||
"GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key
|
||||
)
|
||||
|
||||
def on_GET_no_state_key(self, request, room_id, event_type):
|
||||
return self.on_GET(request, room_id, event_type, "")
|
||||
|
@ -132,8 +132,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_type, state_key):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
format = parse_string(request, "format", default="content",
|
||||
allowed_values=["content", "event"])
|
||||
format = parse_string(
|
||||
request, "format", default="content", allowed_values=["content", "event"]
|
||||
)
|
||||
|
||||
msg_handler = self.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
|
@ -145,9 +146,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
)
|
||||
|
||||
if not data:
|
||||
raise SynapseError(
|
||||
404, "Event not found.", errcode=Codes.NOT_FOUND
|
||||
)
|
||||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
||||
if format == "event":
|
||||
event = format_event_for_client_v2(data.get_dict())
|
||||
|
@ -182,9 +181,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
)
|
||||
else:
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
|
||||
ret = {}
|
||||
|
@ -195,7 +192,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
|
||||
# TODO: Needs unit testing for generic events + feedback
|
||||
class RoomSendEventRestServlet(TransactionRestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomSendEventRestServlet, self).__init__(hs)
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
@ -203,7 +199,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server, with_get=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -218,13 +214,11 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
if b'ts' in request.args and requester.app_service:
|
||||
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||
if b"ts" in request.args and requester.app_service:
|
||||
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
|
||||
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"event_id": event.event_id}))
|
||||
|
@ -247,15 +241,12 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
|
||||
def register(self, http_server):
|
||||
# /join/$room_identifier[/$txn_id]
|
||||
PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
|
||||
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_identifier, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
try:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -268,7 +259,7 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
room_id = room_identifier
|
||||
try:
|
||||
remote_room_hosts = [
|
||||
x.decode('ascii') for x in request.args[b"server_name"]
|
||||
x.decode("ascii") for x in request.args[b"server_name"]
|
||||
]
|
||||
except Exception:
|
||||
remote_room_hosts = None
|
||||
|
@ -278,9 +269,9 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(400, "%s was not legal room ID or room alias" % (
|
||||
room_identifier,
|
||||
))
|
||||
raise SynapseError(
|
||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
)
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
|
@ -320,7 +311,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
# Option to allow servers to require auth when accessing
|
||||
# /publicRooms via CS API. This is especially helpful in private
|
||||
# federations.
|
||||
if self.hs.config.restrict_public_rooms_to_local_users:
|
||||
if not self.hs.config.allow_public_rooms_without_auth:
|
||||
raise
|
||||
|
||||
# We allow people to not be authed if they're just looking at our
|
||||
|
@ -339,14 +330,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
handler = self.hs.get_room_list_handler()
|
||||
if server:
|
||||
data = yield handler.get_remote_public_room_list(
|
||||
server,
|
||||
limit=limit,
|
||||
since_token=since_token,
|
||||
server, limit=limit, since_token=since_token
|
||||
)
|
||||
else:
|
||||
data = yield handler.get_local_public_room_list(
|
||||
limit=limit,
|
||||
since_token=since_token,
|
||||
limit=limit, since_token=since_token
|
||||
)
|
||||
|
||||
defer.returnValue((200, data))
|
||||
|
@ -439,16 +427,13 @@ class RoomMemberListRestServlet(RestServlet):
|
|||
chunk = []
|
||||
|
||||
for event in events:
|
||||
if (
|
||||
(membership and event['content'].get("membership") != membership) or
|
||||
(not_membership and event['content'].get("membership") == not_membership)
|
||||
if (membership and event["content"].get("membership") != membership) or (
|
||||
not_membership and event["content"].get("membership") == not_membership
|
||||
):
|
||||
continue
|
||||
chunk.append(event)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"chunk": chunk
|
||||
}))
|
||||
defer.returnValue((200, {"chunk": chunk}))
|
||||
|
||||
|
||||
# deprecated in favour of /members?membership=join?
|
||||
|
@ -466,12 +451,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
users_with_profile = yield self.message_handler.get_joined_members(
|
||||
requester, room_id,
|
||||
requester, room_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"joined": users_with_profile,
|
||||
}))
|
||||
defer.returnValue((200, {"joined": users_with_profile}))
|
||||
|
||||
|
||||
# TODO: Needs better unit testing
|
||||
|
@ -486,9 +469,7 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(
|
||||
request, default_limit=10,
|
||||
)
|
||||
pagination_config = PaginationConfig.from_request(request, default_limit=10)
|
||||
as_client_event = b"raw" not in request.args
|
||||
filter_bytes = parse_string(request, b"filter", encoding=None)
|
||||
if filter_bytes:
|
||||
|
@ -544,9 +525,7 @@ class RoomInitialSyncRestServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
content = yield self.initial_sync_handler.room_initial_sync(
|
||||
room_id=room_id,
|
||||
requester=requester,
|
||||
pagin_config=pagination_config,
|
||||
room_id=room_id, requester=requester, pagin_config=pagination_config
|
||||
)
|
||||
defer.returnValue((200, content))
|
||||
|
||||
|
@ -603,30 +582,24 @@ class RoomEventContextServlet(RestServlet):
|
|||
event_filter = None
|
||||
|
||||
results = yield self.room_context_handler.get_event_context(
|
||||
requester.user,
|
||||
room_id,
|
||||
event_id,
|
||||
limit,
|
||||
event_filter,
|
||||
requester.user, room_id, event_id, limit, event_filter
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise SynapseError(
|
||||
404, "Event not found.", errcode=Codes.NOT_FOUND
|
||||
)
|
||||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
results["events_before"] = yield self._event_serializer.serialize_events(
|
||||
results["events_before"], time_now,
|
||||
results["events_before"], time_now
|
||||
)
|
||||
results["event"] = yield self._event_serializer.serialize_event(
|
||||
results["event"], time_now,
|
||||
results["event"], time_now
|
||||
)
|
||||
results["events_after"] = yield self._event_serializer.serialize_events(
|
||||
results["events_after"], time_now,
|
||||
results["events_after"], time_now
|
||||
)
|
||||
results["state"] = yield self._event_serializer.serialize_events(
|
||||
results["state"], time_now,
|
||||
results["state"], time_now
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
@ -639,20 +612,14 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=False,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
|
||||
yield self.room_member_handler.forget(
|
||||
user=requester.user,
|
||||
room_id=room_id,
|
||||
)
|
||||
yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -664,7 +631,6 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||
|
||||
# TODO: Needs unit testing
|
||||
class RoomMembershipRestServlet(TransactionRestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMembershipRestServlet, self).__init__(hs)
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
|
@ -672,20 +638,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/[invite|join|leave]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
||||
"(?P<membership_action>join|invite|leave|ban|unban|kick)")
|
||||
PATTERNS = (
|
||||
"/rooms/(?P<room_id>[^/]*)/"
|
||||
"(?P<membership_action>join|invite|leave|ban|unban|kick)"
|
||||
)
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if requester.is_guest and membership_action not in {
|
||||
Membership.JOIN,
|
||||
Membership.LEAVE
|
||||
Membership.LEAVE,
|
||||
}:
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
|
@ -704,7 +669,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
content["address"],
|
||||
content["id_server"],
|
||||
requester,
|
||||
txn_id
|
||||
txn_id,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
return
|
||||
|
@ -715,8 +680,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
target = UserID.from_string(content["user_id"])
|
||||
|
||||
event_content = None
|
||||
if 'reason' in content and membership_action in ['kick', 'ban']:
|
||||
event_content = {'reason': content['reason']}
|
||||
if "reason" in content and membership_action in ["kick", "ban"]:
|
||||
event_content = {"reason": content["reason"]}
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
|
@ -755,7 +720,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -817,9 +782,7 @@ class RoomTypingRestServlet(RestServlet):
|
|||
)
|
||||
else:
|
||||
yield self.typing_handler.stopped_typing(
|
||||
target_user=target_user,
|
||||
auth_user=requester.user,
|
||||
room_id=room_id,
|
||||
target_user=target_user, auth_user=requester.user, room_id=room_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -841,9 +804,7 @@ class SearchRestServlet(RestServlet):
|
|||
|
||||
batch = parse_string(request, "next_batch")
|
||||
results = yield self.handlers.search_handler.search(
|
||||
requester.user,
|
||||
content,
|
||||
batch,
|
||||
requester.user, content, batch
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
@ -879,20 +840,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
|
|||
with_get: True to also register respective GET paths for the PUTs.
|
||||
"""
|
||||
http_server.register_paths(
|
||||
"POST",
|
||||
client_patterns(regex_string + "$", v1=True),
|
||||
servlet.on_POST
|
||||
"POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT",
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
||||
servlet.on_PUT
|
||||
servlet.on_PUT,
|
||||
)
|
||||
if with_get:
|
||||
http_server.register_paths(
|
||||
"GET",
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
||||
servlet.on_GET
|
||||
servlet.on_GET,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -34,8 +34,7 @@ class VoipRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
self.hs.config.turn_allow_guests
|
||||
request, self.hs.config.turn_allow_guests
|
||||
)
|
||||
|
||||
turnUris = self.hs.config.turn_uris
|
||||
|
@ -49,9 +48,7 @@ class VoipRestServlet(RestServlet):
|
|||
username = "%d:%s" % (expiry, requester.user.to_string())
|
||||
|
||||
mac = hmac.new(
|
||||
turnSecret.encode(),
|
||||
msg=username.encode(),
|
||||
digestmod=hashlib.sha1
|
||||
turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1
|
||||
)
|
||||
# We need to use standard padded base64 encoding here
|
||||
# encode_base64 because we need to add the standard padding to get the
|
||||
|
@ -65,12 +62,17 @@ class VoipRestServlet(RestServlet):
|
|||
else:
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
defer.returnValue((200, {
|
||||
'username': username,
|
||||
'password': password,
|
||||
'ttl': userLifetime / 1000,
|
||||
'uris': turnUris,
|
||||
}))
|
||||
defer.returnValue(
|
||||
(
|
||||
200,
|
||||
{
|
||||
"username": username,
|
||||
"password": password,
|
||||
"ttl": userLifetime / 1000,
|
||||
"uris": turnUris,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue