mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Merge branch 'develop' into dinsic
This commit is contained in:
commit
75b25b3f1f
@ -46,6 +46,7 @@ class Codes(object):
|
|||||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||||
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||||
|
THREEPID_DENIED = "M_THREEPID_DENIED"
|
||||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
|
|
||||||
|
@ -31,6 +31,8 @@ class RegistrationConfig(Config):
|
|||||||
strtobool(str(config["disable_registration"]))
|
strtobool(str(config["disable_registration"]))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||||
|
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
@ -52,6 +54,23 @@ class RegistrationConfig(Config):
|
|||||||
# Enable registration for new users.
|
# Enable registration for new users.
|
||||||
enable_registration: False
|
enable_registration: False
|
||||||
|
|
||||||
|
# The user must provide all of the below types of 3PID when registering.
|
||||||
|
#
|
||||||
|
# registrations_require_3pid:
|
||||||
|
# - email
|
||||||
|
# - msisdn
|
||||||
|
|
||||||
|
# Mandate that users are only allowed to associate certain formats of
|
||||||
|
# 3PIDs with accounts on this server.
|
||||||
|
#
|
||||||
|
# allowed_local_3pids:
|
||||||
|
# - medium: email
|
||||||
|
# pattern: ".*@matrix\\.org"
|
||||||
|
# - medium: email
|
||||||
|
# pattern: ".*@vector\\.im"
|
||||||
|
# - medium: msisdn
|
||||||
|
# pattern: "\\+44"
|
||||||
|
|
||||||
# If set, allows registration by anyone who also has the shared
|
# If set, allows registration by anyone who also has the shared
|
||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
@ -16,6 +16,8 @@
|
|||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from synapse.util.module_loader import load_module
|
||||||
|
|
||||||
|
|
||||||
MISSING_NETADDR = (
|
MISSING_NETADDR = (
|
||||||
"Missing netaddr library. This is required for URL preview API."
|
"Missing netaddr library. This is required for URL preview API."
|
||||||
@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
|
|||||||
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
|
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MediaStorageProviderConfig = namedtuple(
|
||||||
|
"MediaStorageProviderConfig", (
|
||||||
|
"store_local", # Whether to store newly uploaded local files
|
||||||
|
"store_remote", # Whether to store newly downloaded remote files
|
||||||
|
"store_synchronous", # Whether to wait for successful storage for local uploads
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_thumbnail_requirements(thumbnail_sizes):
|
def parse_thumbnail_requirements(thumbnail_sizes):
|
||||||
""" Takes a list of dictionaries with "width", "height", and "method" keys
|
""" Takes a list of dictionaries with "width", "height", and "method" keys
|
||||||
@ -73,14 +83,59 @@ class ContentRepositoryConfig(Config):
|
|||||||
|
|
||||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||||
|
|
||||||
self.backup_media_store_path = config.get("backup_media_store_path")
|
backup_media_store_path = config.get("backup_media_store_path")
|
||||||
if self.backup_media_store_path:
|
|
||||||
self.backup_media_store_path = self.ensure_directory(
|
synchronous_backup_media_store = config.get(
|
||||||
self.backup_media_store_path
|
"synchronous_backup_media_store", False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.synchronous_backup_media_store = config.get(
|
storage_providers = config.get("media_storage_providers", [])
|
||||||
"synchronous_backup_media_store", False
|
|
||||||
|
if backup_media_store_path:
|
||||||
|
if storage_providers:
|
||||||
|
raise ConfigError(
|
||||||
|
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
|
||||||
|
)
|
||||||
|
|
||||||
|
storage_providers = [{
|
||||||
|
"module": "file_system",
|
||||||
|
"store_local": True,
|
||||||
|
"store_synchronous": synchronous_backup_media_store,
|
||||||
|
"store_remote": True,
|
||||||
|
"config": {
|
||||||
|
"directory": backup_media_store_path,
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
|
||||||
|
# This is a list of config that can be used to create the storage
|
||||||
|
# providers. The entries are tuples of (Class, class_config,
|
||||||
|
# MediaStorageProviderConfig), where Class is the class of the provider,
|
||||||
|
# the class_config the config to pass to it, and
|
||||||
|
# MediaStorageProviderConfig are options for StorageProviderWrapper.
|
||||||
|
#
|
||||||
|
# We don't create the storage providers here as not all workers need
|
||||||
|
# them to be started.
|
||||||
|
self.media_storage_providers = []
|
||||||
|
|
||||||
|
for provider_config in storage_providers:
|
||||||
|
# We special case the module "file_system" so as not to need to
|
||||||
|
# expose FileStorageProviderBackend
|
||||||
|
if provider_config["module"] == "file_system":
|
||||||
|
provider_config["module"] = (
|
||||||
|
"synapse.rest.media.v1.storage_provider"
|
||||||
|
".FileStorageProviderBackend"
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_class, parsed_config = load_module(provider_config)
|
||||||
|
|
||||||
|
wrapper_config = MediaStorageProviderConfig(
|
||||||
|
provider_config.get("store_local", False),
|
||||||
|
provider_config.get("store_remote", False),
|
||||||
|
provider_config.get("store_synchronous", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.media_storage_providers.append(
|
||||||
|
(provider_class, parsed_config, wrapper_config,)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||||
@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
|
|||||||
# Directory where uploaded images and attachments are stored.
|
# Directory where uploaded images and attachments are stored.
|
||||||
media_store_path: "%(media_store)s"
|
media_store_path: "%(media_store)s"
|
||||||
|
|
||||||
# A secondary directory where uploaded images and attachments are
|
# Media storage providers allow media to be stored in different
|
||||||
# stored as a backup.
|
# locations.
|
||||||
# backup_media_store_path: "%(media_store)s"
|
# media_storage_providers:
|
||||||
|
# - module: file_system
|
||||||
# Whether to wait for successful write to backup media store before
|
# # Whether to write new local files.
|
||||||
# returning successfully.
|
# store_local: false
|
||||||
# synchronous_backup_media_store: false
|
# # Whether to write new remote media
|
||||||
|
# store_remote: false
|
||||||
|
# # Whether to block upload requests waiting for write to this
|
||||||
|
# # provider to complete
|
||||||
|
# store_synchronous: false
|
||||||
|
# config:
|
||||||
|
# directory: /mnt/some/other/directory
|
||||||
|
|
||||||
# Directory where in-progress uploads are stored.
|
# Directory where in-progress uploads are stored.
|
||||||
uploads_path: "%(uploads_path)s"
|
uploads_path: "%(uploads_path)s"
|
||||||
|
@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
|
|||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
for c in threepidCreds:
|
for c in threepidCreds:
|
||||||
logger.info("validating theeepidcred sid %s on id server %s",
|
logger.info("validating threepidcred sid %s on id server %s",
|
||||||
c['sid'], c['idServer'])
|
c['sid'], c['idServer'])
|
||||||
try:
|
try:
|
||||||
identity_handler = self.hs.get_handlers().identity_handler
|
identity_handler = self.hs.get_handlers().identity_handler
|
||||||
@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
|
|||||||
logger.info("got threepid with medium '%s' and address '%s'",
|
logger.info("got threepid with medium '%s' and address '%s'",
|
||||||
threepid['medium'], threepid['address'])
|
threepid['medium'], threepid['address'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
|
||||||
|
raise RegistrationError(
|
||||||
|
403, "Third party identifier is not allowed"
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def bind_emails(self, user_id, threepidCreds):
|
def bind_emails(self, user_id, threepidCreds):
|
||||||
"""Links emails with a user ID and informs an identity server.
|
"""Links emails with a user ID and informs an identity server.
|
||||||
|
@ -15,6 +15,9 @@
|
|||||||
|
|
||||||
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def flatten(items):
|
def flatten(items):
|
||||||
@ -153,7 +156,11 @@ class CallbackMetric(BaseMetric):
|
|||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
try:
|
||||||
value = self.callback()
|
value = self.callback()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to render %s", self.name)
|
||||||
|
return ["# FAILED to render " + self.name]
|
||||||
|
|
||||||
if self.is_scalar():
|
if self.is_scalar():
|
||||||
return list(self._render_for_labels([], value))
|
return list(self._render_for_labels([], value))
|
||||||
|
@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
|
||||||
|
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||||
|
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||||
|
|
||||||
|
flows = []
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
return (
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
200,
|
if not require_msisdn:
|
||||||
{"flows": [
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.RECAPTCHA,
|
"type": LoginType.RECAPTCHA,
|
||||||
"stages": [
|
"stages": [
|
||||||
@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||||||
LoginType.PASSWORD
|
LoginType.PASSWORD
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
])
|
||||||
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
|
if not require_email and not require_msisdn:
|
||||||
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.RECAPTCHA,
|
"type": LoginType.RECAPTCHA,
|
||||||
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
||||||
}
|
}
|
||||||
]}
|
])
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return (
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
200,
|
if require_email or not require_msisdn:
|
||||||
{"flows": [
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.EMAIL_IDENTITY,
|
"type": LoginType.EMAIL_IDENTITY,
|
||||||
"stages": [
|
"stages": [
|
||||||
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
||||||
]
|
]
|
||||||
},
|
}
|
||||||
|
])
|
||||||
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
|
if not require_email and not require_msisdn:
|
||||||
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.PASSWORD
|
"type": LoginType.PASSWORD
|
||||||
}
|
}
|
||||||
]}
|
])
|
||||||
)
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -195,15 +195,20 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
event_dict = {
|
||||||
event = yield msg_handler.create_and_send_nonmember_event(
|
|
||||||
requester,
|
|
||||||
{
|
|
||||||
"type": event_type,
|
"type": event_type,
|
||||||
"content": content,
|
"content": content,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": requester.user.to_string(),
|
"sender": requester.user.to_string(),
|
||||||
},
|
}
|
||||||
|
|
||||||
|
if 'ts' in request.args and requester.app_service:
|
||||||
|
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||||
|
|
||||||
|
msg_handler = self.handlers.message_handler
|
||||||
|
event = yield msg_handler.create_and_send_nonmember_event(
|
||||||
|
requester,
|
||||||
|
event_dict,
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
|||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
from ._base import client_v2_patterns, interactive_auth_handler
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
|||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
|||||||
if absent:
|
if absent:
|
||||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
|||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
|||||||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||||
)
|
)
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
|
|
||||||
from ._base import client_v2_patterns, interactive_auth_handler
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
|||||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
|||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
|
|||||||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||||
show_msisdn = True
|
show_msisdn = True
|
||||||
|
|
||||||
|
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||||
|
# where we required 3PID for registration but the user didn't give one
|
||||||
|
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||||
|
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||||
|
|
||||||
|
flows = []
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
flows = [
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
[LoginType.RECAPTCHA],
|
if not require_email and not require_msisdn:
|
||||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
flows.extend([[LoginType.RECAPTCHA]])
|
||||||
]
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
|
if not require_msisdn:
|
||||||
|
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
|
||||||
|
|
||||||
if show_msisdn:
|
if show_msisdn:
|
||||||
|
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||||
|
if not require_email:
|
||||||
|
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
|
||||||
|
# always let users provide both MSISDN & email
|
||||||
flows.extend([
|
flows.extend([
|
||||||
[LoginType.MSISDN, LoginType.RECAPTCHA],
|
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
flows = [
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
[LoginType.DUMMY],
|
if not require_email and not require_msisdn:
|
||||||
[LoginType.EMAIL_IDENTITY],
|
flows.extend([[LoginType.DUMMY]])
|
||||||
]
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
|
if not require_msisdn:
|
||||||
|
flows.extend([[LoginType.EMAIL_IDENTITY]])
|
||||||
|
|
||||||
if show_msisdn:
|
if show_msisdn:
|
||||||
|
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||||
|
if not require_email or require_msisdn:
|
||||||
|
flows.extend([[LoginType.MSISDN]])
|
||||||
|
# always let users provide both MSISDN & email
|
||||||
flows.extend([
|
flows.extend([
|
||||||
[LoginType.MSISDN],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
|
||||||
])
|
])
|
||||||
|
|
||||||
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check that we're not trying to register a denied 3pid.
|
||||||
|
#
|
||||||
|
# the user-facing checks will probably already have happened in
|
||||||
|
# /register/email/requestToken when we requested a 3pid, but that's not
|
||||||
|
# guaranteed.
|
||||||
|
|
||||||
|
if auth_result:
|
||||||
|
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
|
||||||
|
if login_type in auth_result:
|
||||||
|
medium = auth_result[login_type].threepid['medium']
|
||||||
|
address = auth_result[login_type].threepid['address']
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, medium, address):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed",
|
||||||
|
Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
|
@ -27,9 +27,7 @@ from .identicon_resource import IdenticonResource
|
|||||||
from .preview_url_resource import PreviewUrlResource
|
from .preview_url_resource import PreviewUrlResource
|
||||||
from .filepath import MediaFilePaths
|
from .filepath import MediaFilePaths
|
||||||
from .thumbnailer import Thumbnailer
|
from .thumbnailer import Thumbnailer
|
||||||
from .storage_provider import (
|
from .storage_provider import StorageProviderWrapper
|
||||||
StorageProviderWrapper, FileStorageProviderBackend,
|
|
||||||
)
|
|
||||||
from .media_storage import MediaStorage
|
from .media_storage import MediaStorage
|
||||||
|
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
@ -84,17 +82,13 @@ class MediaRepository(object):
|
|||||||
# potentially upload to.
|
# potentially upload to.
|
||||||
storage_providers = []
|
storage_providers = []
|
||||||
|
|
||||||
# TODO: Move this into config and allow other storage providers to be
|
for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
|
||||||
# defined.
|
backend = clz(hs, provider_config)
|
||||||
if hs.config.backup_media_store_path:
|
|
||||||
backend = FileStorageProviderBackend(
|
|
||||||
self.primary_base_path, hs.config.backup_media_store_path,
|
|
||||||
)
|
|
||||||
provider = StorageProviderWrapper(
|
provider = StorageProviderWrapper(
|
||||||
backend,
|
backend,
|
||||||
store=True,
|
store_local=wrapper_config.store_local,
|
||||||
store_synchronous=hs.config.synchronous_backup_media_store,
|
store_remote=wrapper_config.store_remote,
|
||||||
store_remote=True,
|
store_synchronous=wrapper_config.store_synchronous,
|
||||||
)
|
)
|
||||||
storage_providers.append(provider)
|
storage_providers.append(provider)
|
||||||
|
|
||||||
|
@ -164,6 +164,14 @@ class MediaStorage(object):
|
|||||||
str
|
str
|
||||||
"""
|
"""
|
||||||
if file_info.url_cache:
|
if file_info.url_cache:
|
||||||
|
if file_info.thumbnail:
|
||||||
|
return self.filepaths.url_cache_thumbnail_rel(
|
||||||
|
media_id=file_info.file_id,
|
||||||
|
width=file_info.thumbnail_width,
|
||||||
|
height=file_info.thumbnail_height,
|
||||||
|
content_type=file_info.thumbnail_type,
|
||||||
|
method=file_info.thumbnail_method,
|
||||||
|
)
|
||||||
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
|
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
|
||||||
|
|
||||||
if file_info.server_name:
|
if file_info.server_name:
|
||||||
|
@ -17,6 +17,7 @@ from twisted.internet import defer, threads
|
|||||||
|
|
||||||
from .media_storage import FileResponder
|
from .media_storage import FileResponder
|
||||||
|
|
||||||
|
from synapse.config._base import Config
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -64,19 +65,19 @@ class StorageProviderWrapper(StorageProvider):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
backend (StorageProvider)
|
backend (StorageProvider)
|
||||||
store (bool): Whether to store new files or not.
|
store_local (bool): Whether to store new local files or not.
|
||||||
store_synchronous (bool): Whether to wait for file to be successfully
|
store_synchronous (bool): Whether to wait for file to be successfully
|
||||||
uploaded, or todo the upload in the backgroud.
|
uploaded, or todo the upload in the backgroud.
|
||||||
store_remote (bool): Whether remote media should be uploaded
|
store_remote (bool): Whether remote media should be uploaded
|
||||||
"""
|
"""
|
||||||
def __init__(self, backend, store, store_synchronous, store_remote):
|
def __init__(self, backend, store_local, store_synchronous, store_remote):
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
self.store = store
|
self.store_local = store_local
|
||||||
self.store_synchronous = store_synchronous
|
self.store_synchronous = store_synchronous
|
||||||
self.store_remote = store_remote
|
self.store_remote = store_remote
|
||||||
|
|
||||||
def store_file(self, path, file_info):
|
def store_file(self, path, file_info):
|
||||||
if not self.store:
|
if not file_info.server_name and not self.store_local:
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
||||||
if file_info.server_name and not self.store_remote:
|
if file_info.server_name and not self.store_remote:
|
||||||
@ -97,13 +98,13 @@ class FileStorageProviderBackend(StorageProvider):
|
|||||||
"""A storage provider that stores files in a directory on a filesystem.
|
"""A storage provider that stores files in a directory on a filesystem.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_directory (str): Base path of the local media repository
|
hs (HomeServer)
|
||||||
base_directory (str): Base path to store new files
|
config: The config returned by `parse_config`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cache_directory, base_directory):
|
def __init__(self, hs, config):
|
||||||
self.cache_directory = cache_directory
|
self.cache_directory = hs.config.media_store_path
|
||||||
self.base_directory = base_directory
|
self.base_directory = config
|
||||||
|
|
||||||
def store_file(self, path, file_info):
|
def store_file(self, path, file_info):
|
||||||
"""See StorageProvider.store_file"""
|
"""See StorageProvider.store_file"""
|
||||||
@ -125,3 +126,15 @@ class FileStorageProviderBackend(StorageProvider):
|
|||||||
backup_fname = os.path.join(self.base_directory, path)
|
backup_fname = os.path.join(self.base_directory, path)
|
||||||
if os.path.isfile(backup_fname):
|
if os.path.isfile(backup_fname):
|
||||||
return FileResponder(open(backup_fname, "rb"))
|
return FileResponder(open(backup_fname, "rb"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(config):
|
||||||
|
"""Called on startup to parse config supplied. This should parse
|
||||||
|
the config and raise if there is a problem.
|
||||||
|
|
||||||
|
The returned value is passed into the constructor.
|
||||||
|
|
||||||
|
In this case we only care about a single param, the directory, so let's
|
||||||
|
just pull that out.
|
||||||
|
"""
|
||||||
|
return Config.ensure_directory(config["directory"])
|
||||||
|
@ -67,7 +67,7 @@ class ThumbnailResource(Resource):
|
|||||||
yield self._respond_local_thumbnail(
|
yield self._respond_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type
|
request, media_id, width, height, method, m_type
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(None, media_id)
|
||||||
else:
|
else:
|
||||||
if self.dynamic_thumbnails:
|
if self.dynamic_thumbnails:
|
||||||
yield self._select_or_generate_remote_thumbnail(
|
yield self._select_or_generate_remote_thumbnail(
|
||||||
@ -79,7 +79,7 @@ class ThumbnailResource(Resource):
|
|||||||
request, server_name, media_id,
|
request, server_name, media_id,
|
||||||
width, height, method, m_type
|
width, height, method, m_type
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(None, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_local_thumbnail(self, request, media_id, width, height,
|
def _respond_local_thumbnail(self, request, media_id, width, height,
|
||||||
|
139
synapse/util/file_consumer.py
Normal file
139
synapse/util/file_consumer.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import threads, reactor
|
||||||
|
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
|
|
||||||
|
import Queue
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundFileConsumer(object):
|
||||||
|
"""A consumer that writes to a file like object. Supports both push
|
||||||
|
and pull producers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_obj (file): The file like object to write to. Closed when
|
||||||
|
finished.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# For PushProducers pause if we have this many unwritten slices
|
||||||
|
_PAUSE_ON_QUEUE_SIZE = 5
|
||||||
|
# And resume once the size of the queue is less than this
|
||||||
|
_RESUME_ON_QUEUE_SIZE = 2
|
||||||
|
|
||||||
|
def __init__(self, file_obj):
|
||||||
|
self._file_obj = file_obj
|
||||||
|
|
||||||
|
# Producer we're registered with
|
||||||
|
self._producer = None
|
||||||
|
|
||||||
|
# True if PushProducer, false if PullProducer
|
||||||
|
self.streaming = False
|
||||||
|
|
||||||
|
# For PushProducers, indicates whether we've paused the producer and
|
||||||
|
# need to call resumeProducing before we get more data.
|
||||||
|
self._paused_producer = False
|
||||||
|
|
||||||
|
# Queue of slices of bytes to be written. When producer calls
|
||||||
|
# unregister a final None is sent.
|
||||||
|
self._bytes_queue = Queue.Queue()
|
||||||
|
|
||||||
|
# Deferred that is resolved when finished writing
|
||||||
|
self._finished_deferred = None
|
||||||
|
|
||||||
|
# If the _writer thread throws an exception it gets stored here.
|
||||||
|
self._write_exception = None
|
||||||
|
|
||||||
|
def registerProducer(self, producer, streaming):
|
||||||
|
"""Part of IConsumer interface
|
||||||
|
|
||||||
|
Args:
|
||||||
|
producer (IProducer)
|
||||||
|
streaming (bool): True if push based producer, False if pull
|
||||||
|
based.
|
||||||
|
"""
|
||||||
|
if self._producer:
|
||||||
|
raise Exception("registerProducer called twice")
|
||||||
|
|
||||||
|
self._producer = producer
|
||||||
|
self.streaming = streaming
|
||||||
|
self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
|
||||||
|
if not streaming:
|
||||||
|
self._producer.resumeProducing()
|
||||||
|
|
||||||
|
def unregisterProducer(self):
|
||||||
|
"""Part of IProducer interface
|
||||||
|
"""
|
||||||
|
self._producer = None
|
||||||
|
if not self._finished_deferred.called:
|
||||||
|
self._bytes_queue.put_nowait(None)
|
||||||
|
|
||||||
|
def write(self, bytes):
|
||||||
|
"""Part of IProducer interface
|
||||||
|
"""
|
||||||
|
if self._write_exception:
|
||||||
|
raise self._write_exception
|
||||||
|
|
||||||
|
if self._finished_deferred.called:
|
||||||
|
raise Exception("consumer has closed")
|
||||||
|
|
||||||
|
self._bytes_queue.put_nowait(bytes)
|
||||||
|
|
||||||
|
# If this is a PushProducer and the queue is getting behind
|
||||||
|
# then we pause the producer.
|
||||||
|
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
|
||||||
|
self._paused_producer = True
|
||||||
|
self._producer.pauseProducing()
|
||||||
|
|
||||||
|
def _writer(self):
|
||||||
|
"""This is run in a background thread to write to the file.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while self._producer or not self._bytes_queue.empty():
|
||||||
|
# If we've paused the producer check if we should resume the
|
||||||
|
# producer.
|
||||||
|
if self._producer and self._paused_producer:
|
||||||
|
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
|
||||||
|
reactor.callFromThread(self._resume_paused_producer)
|
||||||
|
|
||||||
|
bytes = self._bytes_queue.get()
|
||||||
|
|
||||||
|
# If we get a None (or empty list) then that's a signal used
|
||||||
|
# to indicate we should check if we should stop.
|
||||||
|
if bytes:
|
||||||
|
self._file_obj.write(bytes)
|
||||||
|
|
||||||
|
# If its a pull producer then we need to explicitly ask for
|
||||||
|
# more stuff.
|
||||||
|
if not self.streaming and self._producer:
|
||||||
|
reactor.callFromThread(self._producer.resumeProducing)
|
||||||
|
except Exception as e:
|
||||||
|
self._write_exception = e
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._file_obj.close()
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
"""Returns a deferred that resolves when finished writing to file
|
||||||
|
"""
|
||||||
|
return make_deferred_yieldable(self._finished_deferred)
|
||||||
|
|
||||||
|
def _resume_paused_producer(self):
|
||||||
|
"""Gets called if we should resume producing after being paused
|
||||||
|
"""
|
||||||
|
if self._paused_producer and self._producer:
|
||||||
|
self._paused_producer = False
|
||||||
|
self._producer.resumeProducing()
|
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
|||||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
# total number of times we have hit this block
|
# total number of times we have hit this block
|
||||||
response_count = metrics.register_counter(
|
block_counter = metrics.register_counter(
|
||||||
"block_count",
|
"block_count",
|
||||||
labels=["block_name"],
|
labels=["block_name"],
|
||||||
alternative_names=(
|
alternative_names=(
|
||||||
@ -76,7 +76,7 @@ block_db_txn_count = metrics.register_counter(
|
|||||||
block_db_txn_duration = metrics.register_counter(
|
block_db_txn_duration = metrics.register_counter(
|
||||||
"block_db_txn_duration_seconds", labels=["block_name"],
|
"block_db_txn_duration_seconds", labels=["block_name"],
|
||||||
alternative_names=(
|
alternative_names=(
|
||||||
metrics.name_prefix + "_block_db_txn_count:total",
|
metrics.name_prefix + "_block_db_txn_duration:total",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -131,6 +131,8 @@ class Measure(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
duration = self.clock.time_msec() - self.start
|
duration = self.clock.time_msec() - self.start
|
||||||
|
|
||||||
|
block_counter.inc(self.name)
|
||||||
block_timer.inc_by(duration, self.name)
|
block_timer.inc_by(duration, self.name)
|
||||||
|
|
||||||
context = LoggingContext.current_context()
|
context = LoggingContext.current_context()
|
||||||
|
48
synapse/util/threepids.py
Normal file
48
synapse/util/threepids.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_3pid_allowed(hs, medium, address):
|
||||||
|
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
medium (str): 3pid medium - e.g. email, msisdn
|
||||||
|
address (str): address within that medium (e.g. "wotan@matrix.org")
|
||||||
|
msisdns need to first have been canonicalised
|
||||||
|
Returns:
|
||||||
|
bool: whether the 3PID medium/address is allowed to be added to this HS
|
||||||
|
"""
|
||||||
|
|
||||||
|
if hs.config.allowed_local_3pids:
|
||||||
|
for constraint in hs.config.allowed_local_3pids:
|
||||||
|
logger.debug(
|
||||||
|
"Checking 3PID %s (%s) against %s (%s)",
|
||||||
|
address, medium, constraint['pattern'], constraint['medium'],
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
medium == constraint['medium'] and
|
||||||
|
re.match(constraint['pattern'], address)
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
@ -15,6 +15,8 @@
|
|||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
|||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||||
listener = reactor.listenUNIX("\0xxx", server_factory)
|
# XXX: mktemp is unsafe and should never be used. but we're just a test.
|
||||||
|
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
|
||||||
|
listener = reactor.listenUNIX(path, server_factory)
|
||||||
self.addCleanup(listener.stopListening)
|
self.addCleanup(listener.stopListening)
|
||||||
self.streamer = server_factory.streamer
|
self.streamer = server_factory.streamer
|
||||||
|
|
||||||
@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
|||||||
client_factory = ReplicationClientFactory(
|
client_factory = ReplicationClientFactory(
|
||||||
self.hs, "client_name", self.replication_handler
|
self.hs, "client_name", self.replication_handler
|
||||||
)
|
)
|
||||||
client_connector = reactor.connectUNIX("\0xxx", client_factory)
|
client_connector = reactor.connectUNIX(path, client_factory)
|
||||||
self.addCleanup(client_factory.stopTrying)
|
self.addCleanup(client_factory.stopTrying)
|
||||||
self.addCleanup(client_connector.disconnect)
|
self.addCleanup(client_connector.disconnect)
|
||||||
|
|
||||||
|
@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||||
self.hs.config.enable_registration = True
|
self.hs.config.enable_registration = True
|
||||||
|
self.hs.config.registrations_require_3pid = []
|
||||||
self.hs.config.auto_join_rooms = []
|
self.hs.config.auto_join_rooms = []
|
||||||
|
|
||||||
# init the thing we're testing
|
# init the thing we're testing
|
||||||
|
176
tests/util/test_file_consumer.py
Normal file
176
tests/util/test_file_consumer.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from twisted.internet import defer, reactor
|
||||||
|
from mock import NonCallableMock
|
||||||
|
|
||||||
|
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
from StringIO import StringIO
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class FileConsumerTests(unittest.TestCase):
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_pull_consumer(self):
|
||||||
|
string_file = StringIO()
|
||||||
|
consumer = BackgroundFileConsumer(string_file)
|
||||||
|
|
||||||
|
try:
|
||||||
|
producer = DummyPullProducer()
|
||||||
|
|
||||||
|
yield producer.register_with_consumer(consumer)
|
||||||
|
|
||||||
|
yield producer.write_and_wait("Foo")
|
||||||
|
|
||||||
|
self.assertEqual(string_file.getvalue(), "Foo")
|
||||||
|
|
||||||
|
yield producer.write_and_wait("Bar")
|
||||||
|
|
||||||
|
self.assertEqual(string_file.getvalue(), "FooBar")
|
||||||
|
finally:
|
||||||
|
consumer.unregisterProducer()
|
||||||
|
|
||||||
|
yield consumer.wait()
|
||||||
|
|
||||||
|
self.assertTrue(string_file.closed)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_push_consumer(self):
|
||||||
|
string_file = BlockingStringWrite()
|
||||||
|
consumer = BackgroundFileConsumer(string_file)
|
||||||
|
|
||||||
|
try:
|
||||||
|
producer = NonCallableMock(spec_set=[])
|
||||||
|
|
||||||
|
consumer.registerProducer(producer, True)
|
||||||
|
|
||||||
|
consumer.write("Foo")
|
||||||
|
yield string_file.wait_for_n_writes(1)
|
||||||
|
|
||||||
|
self.assertEqual(string_file.buffer, "Foo")
|
||||||
|
|
||||||
|
consumer.write("Bar")
|
||||||
|
yield string_file.wait_for_n_writes(2)
|
||||||
|
|
||||||
|
self.assertEqual(string_file.buffer, "FooBar")
|
||||||
|
finally:
|
||||||
|
consumer.unregisterProducer()
|
||||||
|
|
||||||
|
yield consumer.wait()
|
||||||
|
|
||||||
|
self.assertTrue(string_file.closed)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_push_producer_feedback(self):
|
||||||
|
string_file = BlockingStringWrite()
|
||||||
|
consumer = BackgroundFileConsumer(string_file)
|
||||||
|
|
||||||
|
try:
|
||||||
|
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
|
||||||
|
|
||||||
|
resume_deferred = defer.Deferred()
|
||||||
|
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
|
||||||
|
|
||||||
|
consumer.registerProducer(producer, True)
|
||||||
|
|
||||||
|
number_writes = 0
|
||||||
|
with string_file.write_lock:
|
||||||
|
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
|
||||||
|
consumer.write("Foo")
|
||||||
|
number_writes += 1
|
||||||
|
|
||||||
|
producer.pauseProducing.assert_called_once()
|
||||||
|
|
||||||
|
yield string_file.wait_for_n_writes(number_writes)
|
||||||
|
|
||||||
|
yield resume_deferred
|
||||||
|
producer.resumeProducing.assert_called_once()
|
||||||
|
finally:
|
||||||
|
consumer.unregisterProducer()
|
||||||
|
|
||||||
|
yield consumer.wait()
|
||||||
|
|
||||||
|
self.assertTrue(string_file.closed)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyPullProducer(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.consumer = None
|
||||||
|
self.deferred = defer.Deferred()
|
||||||
|
|
||||||
|
def resumeProducing(self):
|
||||||
|
d = self.deferred
|
||||||
|
self.deferred = defer.Deferred()
|
||||||
|
d.callback(None)
|
||||||
|
|
||||||
|
def write_and_wait(self, bytes):
|
||||||
|
d = self.deferred
|
||||||
|
self.consumer.write(bytes)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def register_with_consumer(self, consumer):
|
||||||
|
d = self.deferred
|
||||||
|
self.consumer = consumer
|
||||||
|
self.consumer.registerProducer(self, False)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class BlockingStringWrite(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.buffer = ""
|
||||||
|
self.closed = False
|
||||||
|
self.write_lock = threading.Lock()
|
||||||
|
|
||||||
|
self._notify_write_deferred = None
|
||||||
|
self._number_of_writes = 0
|
||||||
|
|
||||||
|
def write(self, bytes):
|
||||||
|
with self.write_lock:
|
||||||
|
self.buffer += bytes
|
||||||
|
self._number_of_writes += 1
|
||||||
|
|
||||||
|
reactor.callFromThread(self._notify_write)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
def _notify_write(self):
|
||||||
|
"Called by write to indicate a write happened"
|
||||||
|
with self.write_lock:
|
||||||
|
if not self._notify_write_deferred:
|
||||||
|
return
|
||||||
|
d = self._notify_write_deferred
|
||||||
|
self._notify_write_deferred = None
|
||||||
|
d.callback(None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def wait_for_n_writes(self, n):
|
||||||
|
"Wait for n writes to have happened"
|
||||||
|
while True:
|
||||||
|
with self.write_lock:
|
||||||
|
if n <= self._number_of_writes:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._notify_write_deferred:
|
||||||
|
self._notify_write_deferred = defer.Deferred()
|
||||||
|
|
||||||
|
d = self._notify_write_deferred
|
||||||
|
|
||||||
|
yield d
|
Loading…
Reference in New Issue
Block a user