Merge branch 'develop' of github.com:matrix-org/synapse into erikj/replication_noop

This commit is contained in:
Erik Johnston 2016-10-11 14:08:29 +01:00
commit 3061dac53e
28 changed files with 605 additions and 337 deletions

View File

@ -1,3 +1,86 @@
Changes in synapse v0.18.1 (2016-10-0)
======================================
No changes since v0.18.1-rc1
Changes in synapse v0.18.1-rc1 (2016-09-30)
===========================================
Features:
* Add total_room_count_estimate to ``/publicRooms`` (PR #1133)
Changes:
* Time out typing over federation (PR #1140)
* Restructure LDAP authentication (PR #1153)
Bug fixes:
* Fix 3pid invites when server is already in the room (PR #1136)
* Fix upgrading with SQLite taking lots of CPU for a few days
after upgrade (PR #1144)
* Fix upgrading from very old database versions (PR #1145)
* Fix port script to work with recently added tables (PR #1146)
Changes in synapse v0.18.0 (2016-09-19)
=======================================
The release includes major changes to the state storage database schemas, which
significantly reduce database size. Synapse will attempt to upgrade the current
data in the background. Servers with large SQLite database may experience
degradation of performance while this upgrade is in progress, therefore you may
want to consider migrating to using Postgres before upgrading very large SQLite
databases
Changes:
* Make public room search case insensitive (PR #1127)
Bug fixes:
* Fix and clean up publicRooms pagination (PR #1129)
Changes in synapse v0.18.0-rc1 (2016-09-16)
===========================================
Features:
* Add ``only=highlight`` on ``/notifications`` (PR #1081)
* Add server param to /publicRooms (PR #1082)
* Allow clients to ask for the whole of a single state event (PR #1094)
* Add is_direct param to /createRoom (PR #1108)
* Add pagination support to publicRooms (PR #1121)
* Add very basic filter API to /publicRooms (PR #1126)
* Add basic direct to device messaging support for E2E (PR #1074, #1084, #1104,
#1111)
Changes:
* Move to storing state_groups_state as deltas, greatly reducing DB size (PR
#1065)
* Reduce amount of state pulled out of the DB during common requests (PR #1069)
* Allow PDF to be rendered from media repo (PR #1071)
* Reindex state_groups_state after pruning (PR #1085)
* Clobber EDUs in send queue (PR #1095)
* Conform better to the CAS protocol specification (PR #1100)
* Limit how often we ask for keys from dead servers (PR #1114)
Bug fixes:
* Fix /notifications API when used with ``from`` param (PR #1080)
* Fix backfill when cannot find an event. (PR #1107)
Changes in synapse v0.17.3 (2016-09-09) Changes in synapse v0.17.3 (2016-09-09)
======================================= =======================================

View File

@ -18,7 +18,9 @@
<div class="summarytext">{{ summary_text }}</div> <div class="summarytext">{{ summary_text }}</div>
</td> </td>
<td class="logo"> <td class="logo">
{% if app_name == "Vector" %} {% if app_name == "Riot" %}
<img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% else %} {% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>

View File

@ -39,6 +39,7 @@ BOOLEAN_COLUMNS = {
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],
"presence_stream": ["currently_active"], "presence_stream": ["currently_active"],
"public_room_list_stream": ["visibility"],
} }
@ -71,6 +72,14 @@ APPEND_ONLY_TABLES = [
"event_to_state_groups", "event_to_state_groups",
"rejections", "rejections",
"event_search", "event_search",
"presence_stream",
"push_rules_stream",
"current_state_resets",
"ex_outlier_stream",
"cache_invalidation_stream",
"public_room_list_stream",
"state_group_edges",
"stream_ordering_to_exterm",
] ]

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.17.3" __version__ = "0.18.1"

View File

@ -653,7 +653,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_appservice_user_id(self, request): def _get_appservice_user_id(self, request):
app_service = yield self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
get_access_token_from_request( get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
@ -855,13 +855,12 @@ class Auth(object):
} }
defer.returnValue(user_info) defer.returnValue(user_info)
@defer.inlineCallbacks
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: try:
token = get_access_token_from_request( token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
service = yield self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,)) logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError( raise AuthError(
@ -870,7 +869,7 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
request.authenticated_entity = service.sender request.authenticated_entity = service.sender
defer.returnValue(service) return defer.succeed(service)
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
@ -1002,16 +1001,6 @@ class Auth(object):
403, 403,
"You are not allowed to set others state" "You are not allowed to set others state"
) )
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True return True

View File

@ -136,9 +136,7 @@ class FederationClient(FederationBase):
sent_edus_counter.inc() sent_edus_counter.inc()
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu, key=key) self._transaction_queue.enqueue_edu(edu, key=key)
return defer.succeed(None)
@log_function @log_function
def send_device_messages(self, destination): def send_device_messages(self, destination):

View File

@ -55,8 +55,14 @@ class BaseHandler(object):
def ratelimit(self, requester): def ratelimit(self, requester):
time_now = self.clock.time() time_now = self.clock.time()
user_id = requester.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service is not None:
return # do not ratelimit app service senders
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
requester.user.to_string(), time_now, user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count, burst_count=self.hs.config.rc_message_burst_count,
) )

View File

@ -59,7 +59,7 @@ class ApplicationServicesHandler(object):
Args: Args:
current_id(int): The current maximum ID. current_id(int): The current maximum ID.
""" """
services = yield self.store.get_app_services() services = self.store.get_app_services()
if not services or not self.notify_appservices: if not services or not self.notify_appservices:
return return
@ -142,7 +142,7 @@ class ApplicationServicesHandler(object):
association can be found. association can be found.
""" """
room_alias_str = room_alias.to_string() room_alias_str = room_alias.to_string()
services = yield self.store.get_app_services() services = self.store.get_app_services()
alias_query_services = [ alias_query_services = [
s for s in services if ( s for s in services if (
s.is_interested_in_alias(room_alias_str) s.is_interested_in_alias(room_alias_str)
@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_3pe_protocols(self, only_protocol=None): def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services() services = self.store.get_app_services()
protocols = {} protocols = {}
# Collect up all the individual protocol responses out of the ASes # Collect up all the individual protocol responses out of the ASes
@ -224,7 +224,7 @@ class ApplicationServicesHandler(object):
list<ApplicationService>: A list of services interested in this list<ApplicationService>: A list of services interested in this
event based on the service regex. event based on the service regex.
""" """
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if ( s for s in services if (
yield s.is_interested(event, self.store) yield s.is_interested(event, self.store)
@ -232,23 +232,21 @@ class ApplicationServicesHandler(object):
] ]
defer.returnValue(interested_list) defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_user(self, user_id): def _get_services_for_user(self, user_id):
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if ( s for s in services if (
s.is_interested_in_user(user_id) s.is_interested_in_user(user_id)
) )
] ]
defer.returnValue(interested_list) return defer.succeed(interested_list)
@defer.inlineCallbacks
def _get_services_for_3pn(self, protocol): def _get_services_for_3pn(self, protocol):
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if s.is_interested_in_protocol(protocol) s for s in services if s.is_interested_in_protocol(protocol)
] ]
defer.returnValue(interested_list) return defer.succeed(interested_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_unknown_user(self, user_id): def _is_unknown_user(self, user_id):
@ -264,7 +262,7 @@ class ApplicationServicesHandler(object):
return return
# user not found; could be the AS though, so check. # user not found; could be the AS though, so check.
services = yield self.store.get_app_services() services = self.store.get_app_services()
service_list = [s for s in services if s.sender == user_id] service_list = [s for s in services if s.sender == user_id]
defer.returnValue(len(service_list) == 0) defer.returnValue(len(service_list) == 0)

View File

@ -31,6 +31,7 @@ import simplejson
try: try:
import ldap3 import ldap3
import ldap3.core.exceptions
except ImportError: except ImportError:
ldap3 = None ldap3 = None
pass pass
@ -58,7 +59,6 @@ class AuthHandler(BaseHandler):
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled self.ldap_enabled = hs.config.ldap_enabled
if self.ldap_enabled: if self.ldap_enabled:
@ -148,13 +148,30 @@ class AuthHandler(BaseHandler):
creds = session['creds'] creds = session['creds']
# check auth type currently being presented # check auth type currently being presented
errordict = {}
if 'type' in authdict: if 'type' in authdict:
if authdict['type'] not in self.checkers: login_type = authdict['type']
if login_type not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED) raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip) try:
result = yield self.checkers[login_type](authdict, clientip)
if result: if result:
creds[authdict['type']] = result creds[login_type] = result
self._save_session(session) self._save_session(session)
except LoginError, e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
# got a 401 with a 'flows' field.
# (https://github.com/vector-im/vector-web/issues/2447).
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
raise e
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
@ -163,6 +180,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id'])) defer.returnValue((False, ret, clientdict, session['id']))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -430,37 +448,40 @@ class AuthHandler(BaseHandler):
defer.Deferred: (str) canonical_user_id, or None if zero or defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches multiple matches
""" """
try:
res = yield self._find_user_id_and_pwd_hash(user_id) res = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
defer.returnValue(res[0]) defer.returnValue(res[0])
except LoginError:
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will throw if there are multiple inexact matches. insensitively, but will return None if there are multiple inexact
matches.
Returns: Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)` tuple: A 2-tuple of `(canonical_user_id, password_hash)`
None: if there is not exactly one match
""" """
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos: if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) elif len(user_infos) == 1:
# a single match (possibly not exact)
if len(user_infos) > 1: result = user_infos.popitem()
if user_id not in user_infos: elif user_id in user_infos:
# multiple matches, but one is exact
result = (user_id, user_infos[user_id])
else:
# multiple matches, none of them exact
logger.warn( logger.warn(
"Attempted to login as %s but it matches more than one user " "Attempted to login as %s but it matches more than one user "
"inexactly: %r", "inexactly: %r",
user_id, user_infos.keys() user_id, user_infos.keys()
) )
raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(result)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
@ -474,36 +495,185 @@ class AuthHandler(BaseHandler):
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id
Raises: Raises:
LoginError if the password was incorrect LoginError if login fails
""" """
valid_ldap = yield self._check_ldap_password(user_id, password) valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap: if valid_ldap:
defer.returnValue(user_id) defer.returnValue(user_id)
result = yield self._check_local_password(user_id, password) canonical_user_id = yield self._check_local_password(user_id, password)
defer.returnValue(result)
if canonical_user_id:
defer.returnValue(canonical_user_id)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
# into a 401 anyway.
raise LoginError(
403, "Invalid password",
errcode=Codes.FORBIDDEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
user_id is checked case insensitively, but will throw if there are user_id is checked case insensitively, but will return None if there are
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (str): complete @user:id user_id (str): complete @user:id
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id, or None if unknown user / bad password
Raises:
LoginError if the password was incorrect
""" """
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash) result = self.validate_hash(password, password_hash)
if not result: if not result:
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(None)
defer.returnValue(user_id) defer.returnValue(user_id)
def _ldap_simple_bind(self, server, localpart, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.
Returns True, LDAP3Connection
if the bind was successful
Returns False, None
if an error occured
"""
try:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established LDAP connection in simple bind mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
conn
)
if conn.bind():
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
return True, conn
# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
localpart, conn.result['description']
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _ldap_authenticated_search(self, server, localpart, password):
""" Attempt to login with the preconfigured bind_dn
and then continue searching and filtering within
the base_dn
Returns (True, LDAP3Connection)
if a single matching DN within the base was found
that matched the filter expression, and with which
a successful bind was achieved
The LDAP3Connection returned is the instance that was used to
verify the password not the one using the configured bind_dn.
Returns (False, None)
if an error occured
"""
try:
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established LDAP connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in search mode through StartTLS: %s",
conn
)
if not conn.bind():
logger.warn(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
conn.unbind()
return False, None
# construct search_filter like (uid=localpart)
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
# combine with the AND expression
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug(
"LDAP search filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query
)
if len(conn.response) == 1:
# GOOD: found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('LDAP search found dn: %s', user_dn)
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
conn.unbind()
return self._ldap_simple_bind(server, localpart, password)
else:
# BAD: found 0 or > 1 results, abort!
if len(conn.response) == 0:
logger.info(
"LDAP search returned no results for '%s'",
localpart
)
else:
logger.info(
"LDAP search returned too many (%s) results for '%s'",
len(conn.response), localpart
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_ldap_password(self, user_id, password): def _check_ldap_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server """ Attempt to authenticate a user against an LDAP Server
@ -516,106 +686,62 @@ class AuthHandler(BaseHandler):
if not ldap3 or not self.ldap_enabled: if not ldap3 or not self.ldap_enabled:
defer.returnValue(False) defer.returnValue(False)
if self.ldap_mode not in LDAPMode.LIST: localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting LDAP connection with %s",
self.ldap_uri
)
if self.ldap_mode == LDAPMode.SIMPLE:
result, conn = self._ldap_simple_bind(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP authentication method simple bind returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
result, conn = self._ldap_authenticated_search(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP auth method authenticated search returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
else:
raise RuntimeError( raise RuntimeError(
'Invalid ldap mode specified: {mode}'.format( 'Invalid LDAP mode specified: {mode}'.format(
mode=self.ldap_mode mode=self.ldap_mode
) )
) )
try: try:
server = ldap3.Server(self.ldap_uri) logger.info(
logger.debug( "User authenticated against LDAP server: %s",
"Attempting ldap connection with %s",
self.ldap_uri
)
localpart = UserID.from_string(user_id).localpart
if self.ldap_mode == LDAPMode.SIMPLE:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established ldap connection in simple mode: %s",
conn conn
) )
except NameError:
if self.ldap_start_tls: logger.warn("Authentication method yielded no LDAP connection, aborting!")
conn.start_tls()
logger.debug(
"Upgraded ldap connection in simple mode through StartTLS: %s",
conn
)
conn.bind()
elif self.ldap_mode == LDAPMode.SEARCH:
# connect with preconfigured credentials and search for local user
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established ldap connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in search mode through StartTLS: %s",
conn
)
conn.bind()
# find matching dn
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug("ldap search filter: %s", query)
result = conn.search(self.ldap_base, query)
if result and len(conn.response) == 1:
# found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('ldap search found dn: %s', user_dn)
# unbind and reconnect, rebind with found dn
conn.unbind()
conn = ldap3.Connection(
server,
user_dn,
password,
auto_bind=True
)
else:
# found 0 or > 1 results, abort!
logger.warn(
"ldap search returned unexpected (%d!=1) amount of results",
len(conn.response)
)
defer.returnValue(False) defer.returnValue(False)
logger.info( # check if user with user_id exists
"User authenticated against ldap server: %s", if (yield self.check_user_exists(user_id)):
conn # exists, authentication complete
) conn.unbind()
defer.returnValue(True)
# check for existing account, if none exists, create one else:
if not (yield self.check_user_exists(user_id)): # does not exist, fetch metadata for account creation from
# query user metadata for account creation # existing ldap connection
query = "({prop}={value})".format( query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'], prop=self.ldap_attributes['uid'],
value=localpart value=localpart
@ -626,9 +752,12 @@ class AuthHandler(BaseHandler):
filter=query, filter=query,
user_filter=self.ldap_filter user_filter=self.ldap_filter
) )
logger.debug("ldap registration filter: %s", query) logger.debug(
"ldap registration filter: %s",
query
)
result = conn.search( conn.search(
search_base=self.ldap_base, search_base=self.ldap_base,
search_filter=query, search_filter=query,
attributes=[ attributes=[
@ -651,20 +780,27 @@ class AuthHandler(BaseHandler):
# TODO: bind email, set displayname with data from ldap directory # TODO: bind email, set displayname with data from ldap directory
logger.info( logger.info(
"ldap registration successful: %d: %s (%s, %)", "Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id, user_id,
localpart, localpart,
name, name,
mail mail
) )
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(True) defer.returnValue(True)
else:
if len(conn.response) == 0:
logger.warn("LDAP registration failed, no result.")
else:
logger.warn(
"LDAP registration failed, too many results (%s)",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(False)
except ldap3.core.exceptions.LDAPException as e: except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e) logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False) defer.returnValue(False)

View File

@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler):
result = yield as_handler.query_room_alias_exists(room_alias) result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def can_modify_alias(self, alias, user_id=None): def can_modify_alias(self, alias, user_id=None):
# Any application service "interested" in an alias they are regexing on # Any application service "interested" in an alias they are regexing on
# can modify the alias. # can modify the alias.
# Users can only modify the alias if ALL the interested services have # Users can only modify the alias if ALL the interested services have
# non-exclusive locks on the alias (or there are no interested services) # non-exclusive locks on the alias (or there are no interested services)
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string()) s for s in services if s.is_interested_in_alias(alias.to_string())
] ]
@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler):
for service in interested_services: for service in interested_services:
if user_id == service.sender: if user_id == service.sender:
# this user IS the app service so they can do whatever they like # this user IS the app service so they can do whatever they like
defer.returnValue(True) return defer.succeed(True)
return
elif service.is_exclusive_alias(alias.to_string()): elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias. # another service has an exclusive lock on this alias.
defer.returnValue(False) return defer.succeed(False)
return
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
defer.returnValue(True) return defer.succeed(True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id): def _user_can_delete_alias(self, alias, user_id):

View File

@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["displayname"]) defer.returnValue(result["displayname"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname): def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
"""target_user is the user whose displayname is to be changed; """target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != requester.user: if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '': if new_displayname == '':
@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["avatar_url"]) defer.returnValue(result["avatar_url"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url): def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != requester.user: if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(

View File

@ -19,7 +19,6 @@ import urllib
from twisted.internet import defer from twisted.internet import defer
import synapse.types
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
@ -194,7 +193,7 @@ class RegistrationHandler(BaseHandler):
def appservice_register(self, user_localpart, as_token): def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname) user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
service = yield self.store.get_app_service_by_token(as_token) service = self.store.get_app_service_by_token(as_token)
if not service: if not service:
raise AuthError(403, "Invalid application service token.") raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id): if not service.is_interested_in_user(user_id):
@ -305,11 +304,10 @@ class RegistrationHandler(BaseHandler):
# XXX: This should be a deferred list, shouldn't it? # XXX: This should be a deferred list, shouldn't it?
yield identity_handler.bind_threepid(c, user_id) yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
# valid user IDs must not clash with any user ID namespaces claimed by # valid user IDs must not clash with any user ID namespaces claimed by
# application services. # application services.
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services s for s in services
if s.is_interested_in_user(user_id) if s.is_interested_in_user(user_id)
@ -371,7 +369,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_in_ms, def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
password_hash=None): password_hash=None):
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -418,9 +416,8 @@ class RegistrationHandler(BaseHandler):
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname( yield profile_handler.set_displayname(
user, requester, displayname user, requester, displayname, by_admin=True,
) )
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))

View File

@ -437,7 +437,7 @@ class RoomEventSource(object):
logger.warn("Stream has topological part!!!! %r", from_key) logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,) from_key = "s%s" % (from_token.stream,)
app_service = yield self.store.get_app_service_by_user_id( app_service = self.store.get_app_service_by_user_id(
user.to_string() user.to_string()
) )
if app_service: if app_service:

View File

@ -788,7 +788,7 @@ class SyncHandler(object):
assert since_token assert since_token
app_service = yield self.store.get_app_service_by_user_id(user_id) app_service = self.store.get_app_service_by_user_id(user_id)
if app_service: if app_service:
rooms = yield self.store.get_app_service_rooms(app_service) rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms) joined_room_ids = set(r.room_id for r in rooms)

View File

@ -16,10 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import ( from synapse.util.logcontext import preserve_fn
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
import logging import logging
@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
# How often we expect remote servers to resend us presence.
FEDERATION_TIMEOUT = 60 * 1000
# How often to resend typing across federation.
FEDERATION_PING_INTERVAL = 40 * 1000
class TypingHandler(object): class TypingHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -44,7 +50,10 @@ class TypingHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.wheel_timer = WheelTimer(bucket_size=5000)
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
@ -53,7 +62,7 @@ class TypingHandler(object):
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop self._member_typing_until = {} # clock time we expect to stop
self._member_typing_timer = {} # deferreds to manage theabove self._member_last_federation_poke = {}
# map room IDs to serial numbers # map room IDs to serial numbers
self._room_serials = {} self._room_serials = {}
@ -61,12 +70,41 @@ class TypingHandler(object):
# map room IDs to sets of users currently typing # map room IDs to sets of users currently typing
self._room_typing = {} self._room_typing = {}
def tearDown(self): self.clock.looping_call(
"""Cancels all the pending timers. self._handle_timeouts,
Normally this shouldn't be needed, but it's required from unit tests 5000,
to avoid a "Reactor was unclean" warning.""" )
for t in self._member_typing_timer.values():
self.clock.cancel_call_later(t) def _handle_timeouts(self):
logger.info("Checking for typing timeouts")
now = self.clock.time_msec()
members = set(self.wheel_timer.fetch(now))
for member in members:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
continue
until = self._member_typing_until.get(member, None)
if not until or until < now:
logger.info("Timing out typing for: %s", member.user_id)
preserve_fn(self._stopped_typing)(member)
continue
# Check if we need to resend a keep alive over federation for this
# user.
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
preserve_fn(self._push_remote)(
member=member,
typing=True
)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
@defer.inlineCallbacks @defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout): def started_typing(self, target_user, auth_user, room_id, timeout):
@ -85,23 +123,17 @@ class TypingHandler(object):
"%s has started typing in %s", target_user_id, room_id "%s has started typing in %s", target_user_id, room_id
) )
until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user_id=target_user_id) member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until was_present = member.user_id in self._room_typing.get(room_id, set())
if member in self._member_typing_timer: now = self.clock.time_msec()
self.clock.cancel_call_later(self._member_typing_timer[member]) self._member_typing_until[member] = now + timeout
def _cb(): self.wheel_timer.insert(
logger.debug( now=now,
"%s has timed out in %s", target_user.to_string(), room_id obj=member,
) then=now + timeout,
self._stopped_typing(member)
self._member_typing_until[member] = until
self._member_typing_timer[member] = self.clock.call_later(
timeout / 1000.0, _cb
) )
if was_present: if was_present:
@ -109,8 +141,7 @@ class TypingHandler(object):
defer.returnValue(None) defer.returnValue(None)
yield self._push_update( yield self._push_update(
room_id=room_id, member=member,
user_id=target_user_id,
typing=True, typing=True,
) )
@ -133,10 +164,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id) member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
del self._member_typing_timer[member]
yield self._stopped_typing(member) yield self._stopped_typing(member)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -148,50 +175,52 @@ class TypingHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _stopped_typing(self, member): def _stopped_typing(self, member):
if member not in self._member_typing_until: if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point # No point
defer.returnValue(None) defer.returnValue(None)
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
yield self._push_update( yield self._push_update(
room_id=member.room_id, member=member,
user_id=member.user_id,
typing=False, typing=False,
) )
del self._member_typing_until[member]
if member in self._member_typing_timer:
# Don't cancel it - either it already expired, or the real
# stopped_typing() will cancel it
del self._member_typing_timer[member]
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing): def _push_update(self, member, typing):
users = yield self.state.get_current_user_in_room(room_id) if self.hs.is_mine_id(member.user_id):
domains = set(get_domain_from_id(u) for u in users) # Only send updates for changes to our own users.
yield self._push_remote(member, typing)
deferreds = [] self._push_update_local(
for domain in domains: member=member,
if domain == self.server_name:
preserve_fn(self._push_update_local)(
room_id=room_id,
user_id=user_id,
typing=typing typing=typing
) )
else:
deferreds.append(preserve_fn(self.federation.send_edu)( @defer.inlineCallbacks
def _push_remote(self, member, typing):
users = yield self.state.get_current_user_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
self.wheel_timer.insert(
now=now,
obj=member,
then=now + FEDERATION_PING_INTERVAL,
)
for domain in set(get_domain_from_id(u) for u in users):
if domain != self.server_name:
self.federation.send_edu(
destination=domain, destination=domain,
edu_type="m.typing", edu_type="m.typing",
content={ content={
"room_id": room_id, "room_id": member.room_id,
"user_id": user_id, "user_id": member.user_id,
"typing": typing, "typing": typing,
}, },
key=(room_id, user_id), key=member,
))
yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -199,6 +228,8 @@ class TypingHandler(object):
room_id = content["room_id"] room_id = content["room_id"]
user_id = content["user_id"] user_id = content["user_id"]
member = RoomMember(user_id=user_id, room_id=room_id)
# Check that the string is a valid user id # Check that the string is a valid user id
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -213,25 +244,31 @@ class TypingHandler(object):
domains = set(get_domain_from_id(u) for u in users) domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains: if self.server_name in domains:
logger.info("Got typing update from %s: %r", user_id, content)
now = self.clock.time_msec()
self._member_typing_until[member] = now + FEDERATION_TIMEOUT
self.wheel_timer.insert(
now=now,
obj=member,
then=now + FEDERATION_TIMEOUT,
)
self._push_update_local( self._push_update_local(
room_id=room_id, member=member,
user_id=user_id,
typing=content["typing"] typing=content["typing"]
) )
def _push_update_local(self, room_id, user_id, typing): def _push_update_local(self, member, typing):
room_set = self._room_typing.setdefault(room_id, set()) room_set = self._room_typing.setdefault(member.room_id, set())
if typing: if typing:
room_set.add(user_id) room_set.add(member.user_id)
else: else:
room_set.discard(user_id) room_set.discard(member.user_id)
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[member.room_id] = self._latest_room_serial
with PreserveLoggingContext():
self.notifier.on_new_event( self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id] "typing_key", self._latest_room_serial, rooms=[member.room_id]
) )
def get_all_typing_updates(self, last_id, current_id): def get_all_typing_updates(self, last_id, current_id):

View File

@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.types import create_requester
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -391,15 +392,16 @@ class CreateUserRestServlet(ClientV1RestServlet):
user_json = parse_json_object_from_request(request) user_json = parse_json_object_from_request(request)
access_token = get_access_token_from_request(request) access_token = get_access_token_from_request(request)
app_service = yield self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
access_token access_token
) )
if not app_service: if not app_service:
raise SynapseError(403, "Invalid application service token.") raise SynapseError(403, "Invalid application service token.")
logger.debug("creating user: %s", user_json) requester = create_requester(app_service.sender)
response = yield self._do_create(user_json) logger.debug("creating user: %s", user_json)
response = yield self._do_create(requester, user_json)
defer.returnValue((200, response)) defer.returnValue((200, response))
@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
return 403, {} return 403, {}
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_create(self, user_json): def _do_create(self, requester, user_json):
yield run_on_reactor() yield run_on_reactor()
if "localpart" not in user_json: if "localpart" not in user_json:
@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user( user_id, token = yield handler.get_or_create_user(
requester=requester,
localpart=localpart, localpart=localpart,
displayname=displayname, displayname=displayname,
duration_in_ms=(duration_seconds * 1000), duration_in_ms=(duration_seconds * 1000),

View File

@ -705,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet):
yield self.presence_handler.bump_presence_active_time(requester.user) yield self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]: if content["typing"]:
yield self.typing_handler.started_typing( yield self.typing_handler.started_typing(
target_user=target_user, target_user=target_user,
auth_user=requester.user, auth_user=requester.user,
room_id=room_id, room_id=room_id,
timeout=content.get("timeout", 30000), timeout=timeout,
) )
else: else:
yield self.typing_handler.stopped_typing( yield self.typing_handler.stopped_typing(

View File

@ -77,8 +77,10 @@ SUCCESS_TEMPLATE = """
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script> <script>
if (window.onAuthDone != undefined) { if (window.onAuthDone) {
window.onAuthDone(); window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
window.opener.postMessage("authDone", "*");
} }
</script> </script>
</head> </head>

View File

@ -37,7 +37,7 @@ class ApplicationServiceStore(SQLBaseStore):
) )
def get_app_services(self): def get_app_services(self):
return defer.succeed(self.services_cache) return self.services_cache
def get_app_service_by_user_id(self, user_id): def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID. """Retrieve an application service from their user ID.
@ -54,8 +54,8 @@ class ApplicationServiceStore(SQLBaseStore):
""" """
for service in self.services_cache: for service in self.services_cache:
if service.sender == user_id: if service.sender == user_id:
return defer.succeed(service) return service
return defer.succeed(None) return None
def get_app_service_by_token(self, token): def get_app_service_by_token(self, token):
"""Get the application service with the given appservice token. """Get the application service with the given appservice token.
@ -67,8 +67,8 @@ class ApplicationServiceStore(SQLBaseStore):
""" """
for service in self.services_cache: for service in self.services_cache:
if service.token == token: if service.token == token:
return defer.succeed(service) return service
return defer.succeed(None) return None
def get_app_service_rooms(self, service): def get_app_service_rooms(self, service):
"""Get a list of RoomsForUser for this application service. """Get a list of RoomsForUser for this application service.
@ -163,7 +163,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
["as_id"] ["as_id"]
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
as_list = yield self.get_app_services() as_list = self.get_app_services()
services = [] services = []
for res in results: for res in results:

View File

@ -1355,39 +1355,53 @@ class EventsStore(SQLBaseStore):
min_stream_id = rows[-1][0] min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows] event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids) rows_to_update = []
rows = [] chunks = [
for event in events: event_ids[i:i + 100]
for i in xrange(0, len(event_ids), 100)
]
for chunk in chunks:
ev_rows = self._simple_select_many_txn(
txn,
table="event_json",
column="event_id",
iterable=chunk,
retcols=["event_id", "json"],
keyvalues={},
)
for row in ev_rows:
event_id = row["event_id"]
event_json = json.loads(row["json"])
try: try:
event_id = event.event_id origin_server_ts = event_json["origin_server_ts"]
origin_server_ts = event.origin_server_ts
except (KeyError, AttributeError): except (KeyError, AttributeError):
# If the event is missing a necessary field then # If the event is missing a necessary field then
# skip over it. # skip over it.
continue continue
rows.append((origin_server_ts, event_id)) rows_to_update.append((origin_server_ts, event_id))
sql = ( sql = (
"UPDATE events SET origin_server_ts = ? WHERE event_id = ?" "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
) )
for index in range(0, len(rows), INSERT_CLUMP_SIZE): for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
clump = rows[index:index + INSERT_CLUMP_SIZE] clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump) txn.executemany(sql, clump)
progress = { progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id, "max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows) "rows_inserted": rows_inserted + len(rows_to_update)
} }
self._background_update_progress_txn( self._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
) )
return len(rows) return len(rows_to_update)
result = yield self.runInteraction( result = yield self.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn

View File

@ -307,6 +307,9 @@ class StateStore(SQLBaseStore):
def _get_state_groups_from_groups_txn(self, txn, groups, types=None): def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
results = {group: {} for group in groups} results = {group: {} for group in groups}
if types is not None:
types = list(set(types)) # deduplicate types list
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is # Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in # a temporary hack until we can add the right indices in
@ -375,10 +378,35 @@ class StateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions # We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy) # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups: for group in groups:
group_tree = [group]
next_group = group next_group = group
while next_group: while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indicies (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
if types:
args.extend(i for typ in types for i in typ)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? %s" % (where_clause,),
args
)
rows = txn.fetchall()
results[group].update({
(typ, state_key): event_id
for typ, state_key, event_id in rows
if (typ, state_key) not in results[group]
})
# If the lengths match then we must have all the types,
# so no need to go walk further down the tree.
if types is not None and len(results[group]) == len(types):
break
next_group = self._simple_select_one_onecol_txn( next_group = self._simple_select_one_onecol_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
@ -386,28 +414,6 @@ class StateStore(SQLBaseStore):
retcol="prev_state_group", retcol="prev_state_group",
allow_none=True, allow_none=True,
) )
if next_group:
group_tree.append(next_group)
sql = ("""
SELECT type, state_key, event_id FROM state_groups_state
INNER JOIN (
SELECT type, state_key, max(state_group) as state_group
FROM state_groups_state
WHERE state_group IN (%s) %s
GROUP BY type, state_key
) USING (type, state_key, state_group);
""") % (",".join("?" for _ in group_tree), where_clause,)
args = list(group_tree)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
return results return results

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from .. import unittest from .. import unittest
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID from synapse.types import UserID, create_requester
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "someone" local_part = "someone"
display_name = "someone" display_name = "someone"
user_id = "@someone:test" user_id = "@someone:test"
requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms) requester, local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "frank" local_part = "frank"
display_name = "Frank" display_name = "Frank"
user_id = "@frank:test" user_id = "@frank:test"
requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms) requester, local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')

View File

@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
from synapse.handlers.typing import RoomMember from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple.to_string()) member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000 self.handler._member_typing_until[member] = 1002000
self.handler._member_typing_timer[member] = ( self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
self.clock.call_later(1002, lambda: 0)
)
self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
}, },
}]) }])
self.clock.advance_time(11) self.clock.advance_time(16)
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]), call('typing_key', 2, rooms=[self.room_id]),

View File

@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase):
) )
self.request.args = {} self.request.args = {}
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
)
self.registration_handler = Mock() self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
# do the dance to hook it up to the hs global self.appservice = Mock(sender="@as:test")
self.handlers = Mock( self.datastore = Mock(
auth_handler=self.auth_handler, get_app_service_by_token=Mock(return_value=self.appservice)
)
# do the dance to hook things up to the hs global
handlers = Mock(
registration_handler=self.registration_handler, registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
) )
self.hs = Mock() self.hs = Mock()
self.hs.hostname = "supergbig~testing~thing.com" self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=handlers)
self.hs.config.enable_registration = True
# init the thing we're testing
self.servlet = CreateUserRestServlet(self.hs) self.servlet = CreateUserRestServlet(self.hs)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
# Need another user to make notifications actually work # Need another user to make notifications actually work
yield self.join(self.room_id, user="@jim:red") yield self.join(self.room_id, user="@jim:red")
def tearDown(self):
self.hs.get_typing_handler().tearDown()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_typing(self): def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger( (code, _) = yield self.mock_resource.trigger(
@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
self.clock.advance_time(31) self.clock.advance_time(36)
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)

View File

@ -19,7 +19,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.appservice = None self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock( self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice)) side_effect=lambda x: self.appservice)
) )
self.auth_result = (False, None, None, None) self.auth_result = (False, None, None, None)

View File

@ -71,14 +71,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
outfile.write(yaml.dump(as_yaml)) outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token) self.as_yaml_files.append(as_token)
@defer.inlineCallbacks
def test_retrieve_unknown_service_token(self): def test_retrieve_unknown_service_token(self):
service = yield self.store.get_app_service_by_token("invalid_token") service = self.store.get_app_service_by_token("invalid_token")
self.assertEquals(service, None) self.assertEquals(service, None)
@defer.inlineCallbacks
def test_retrieval_of_service(self): def test_retrieval_of_service(self):
stored_service = yield self.store.get_app_service_by_token( stored_service = self.store.get_app_service_by_token(
self.as_token self.as_token
) )
self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.token, self.as_token)
@ -97,9 +95,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
[] []
) )
@defer.inlineCallbacks
def test_retrieval_of_all_services(self): def test_retrieval_of_all_services(self):
services = yield self.store.get_app_services() services = self.store.get_app_services()
self.assertEquals(len(services), 3) self.assertEquals(len(services), 3)

View File

@ -220,6 +220,7 @@ class MockClock(object):
# list of lists of [absolute_time, callback, expired] in no particular # list of lists of [absolute_time, callback, expired] in no particular
# order # order
self.timers = [] self.timers = []
self.loopers = []
def time(self): def time(self):
return self.now return self.now
@ -240,7 +241,7 @@ class MockClock(object):
return t return t
def looping_call(self, function, interval): def looping_call(self, function, interval):
pass self.loopers.append([function, interval / 1000., self.now])
def cancel_call_later(self, timer, ignore_errs=False): def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]: if timer[2]:
@ -269,6 +270,12 @@ class MockClock(object):
else: else:
self.timers.append(t) self.timers.append(t)
for looped in self.loopers:
func, interval, last = looped
if last + interval < self.now:
func()
looped[2] = self.now
def advance_time_msec(self, ms): def advance_time_msec(self, ms):
self.advance_time(ms / 1000.) self.advance_time(ms / 1000.)