mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-22 16:41:00 -05:00
Merge branch 'release-v0.18.5' of github.com:matrix-org/synapse
This commit is contained in:
commit
f5a4001bb1
17
.travis.yml
Normal file
17
.travis.yml
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
sudo: false
|
||||||
|
language: python
|
||||||
|
python: 2.7
|
||||||
|
|
||||||
|
# tell travis to cache ~/.cache/pip
|
||||||
|
cache: pip
|
||||||
|
|
||||||
|
env:
|
||||||
|
- TOX_ENV=packaging
|
||||||
|
- TOX_ENV=pep8
|
||||||
|
- TOX_ENV=py27
|
||||||
|
|
||||||
|
install:
|
||||||
|
- pip install tox
|
||||||
|
|
||||||
|
script:
|
||||||
|
- tox -e $TOX_ENV
|
65
CHANGES.rst
65
CHANGES.rst
@ -1,3 +1,68 @@
|
|||||||
|
Changes in synapse v0.18.5 (2016-12-16)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix federation /backfill returning events it shouldn't (PR #1700)
|
||||||
|
* Fix crash in url preview (PR #1701)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.18.5-rc3 (2016-12-13)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add support for E2E for guests (PR #1653)
|
||||||
|
* Add new API appservice specific public room list (PR #1676)
|
||||||
|
* Add new room membership APIs (PR #1680)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Enable guest access for private rooms by default (PR #653)
|
||||||
|
* Limit the number of events that can be created on a given room concurrently
|
||||||
|
(PR #1620)
|
||||||
|
* Log the args that we have on UI auth completion (PR #1649)
|
||||||
|
* Stop generating refresh_tokens (PR #1654)
|
||||||
|
* Stop putting a time caveat on access tokens (PR #1656)
|
||||||
|
* Remove unspecced GET endpoints for e2e keys (PR #1694)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix handling of 500 and 429's over federation (PR #1650)
|
||||||
|
* Fix Content-Type header parsing (PR #1660)
|
||||||
|
* Fix error when previewing sites that include unicode, thanks to kyrias (PR
|
||||||
|
#1664)
|
||||||
|
* Fix some cases where we drop read receipts (PR #1678)
|
||||||
|
* Fix bug where calls to ``/sync`` didn't correctly timeout (PR #1683)
|
||||||
|
* Fix bug where E2E key query would fail if a single remote host failed (PR
|
||||||
|
#1686)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.18.5-rc2 (2016-11-24)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Don't send old events over federation, fixes bug in -rc1.
|
||||||
|
|
||||||
|
Changes in synapse v0.18.5-rc1 (2016-11-24)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Implement "event_fields" in filters (PR #1638)
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Use external ldap auth pacakge (PR #1628)
|
||||||
|
* Split out federation transaction sending to a worker (PR #1635)
|
||||||
|
* Fail with a coherent error message if `/sync?filter=` is invalid (PR #1636)
|
||||||
|
* More efficient notif count queries (PR #1644)
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.18.4 (2016-11-22)
|
Changes in synapse v0.18.4 (2016-11-22)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
31
README.rst
31
README.rst
@ -120,6 +120,7 @@ Installing prerequisites on Mac OS X::
|
|||||||
xcode-select --install
|
xcode-select --install
|
||||||
sudo easy_install pip
|
sudo easy_install pip
|
||||||
sudo pip install virtualenv
|
sudo pip install virtualenv
|
||||||
|
brew install pkg-config libffi
|
||||||
|
|
||||||
Installing prerequisites on Raspbian::
|
Installing prerequisites on Raspbian::
|
||||||
|
|
||||||
@ -136,6 +137,10 @@ Installing prerequisites on openSUSE::
|
|||||||
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
|
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
|
||||||
python-devel libffi-devel libopenssl-devel libjpeg62-devel
|
python-devel libffi-devel libopenssl-devel libjpeg62-devel
|
||||||
|
|
||||||
|
Installing prerequisites on OpenBSD::
|
||||||
|
doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
|
||||||
|
libxslt
|
||||||
|
|
||||||
To install the synapse homeserver run::
|
To install the synapse homeserver run::
|
||||||
|
|
||||||
virtualenv -p python2.7 ~/.synapse
|
virtualenv -p python2.7 ~/.synapse
|
||||||
@ -369,6 +374,32 @@ Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Mo
|
|||||||
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
|
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
|
||||||
- Packages: ``pkg install py27-matrix-synapse``
|
- Packages: ``pkg install py27-matrix-synapse``
|
||||||
|
|
||||||
|
|
||||||
|
OpenBSD
|
||||||
|
-------
|
||||||
|
|
||||||
|
There is currently no port for OpenBSD. Additionally, OpenBSD's security
|
||||||
|
settings require a slightly more difficult installation process.
|
||||||
|
|
||||||
|
1) Create a new directory in ``/usr/local`` called ``_synapse``. Also, create a
|
||||||
|
new user called ``_synapse`` and set that directory as the new user's home.
|
||||||
|
This is required because, by default, OpenBSD only allows binaries which need
|
||||||
|
write and execute permissions on the same memory space to be run from
|
||||||
|
``/usr/local``.
|
||||||
|
2) ``su`` to the new ``_synapse`` user and change to their home directory.
|
||||||
|
3) Create a new virtualenv: ``virtualenv -p python2.7 ~/.synapse``
|
||||||
|
4) Source the virtualenv configuration located at
|
||||||
|
``/usr/local/_synapse/.synapse/bin/activate``. This is done in ``ksh`` by
|
||||||
|
using the ``.`` command, rather than ``bash``'s ``source``.
|
||||||
|
5) Optionally, use ``pip`` to install ``lxml``, which Synapse needs to parse
|
||||||
|
webpages for their titles.
|
||||||
|
6) Use ``pip`` to install this repository: ``pip install
|
||||||
|
https://github.com/matrix-org/synapse/tarball/master``
|
||||||
|
7) Optionally, change ``_synapse``'s shell to ``/bin/false`` to reduce the
|
||||||
|
chance of a compromised Synapse server being used to take over your box.
|
||||||
|
|
||||||
|
After this, you may proceed with the rest of the install directions.
|
||||||
|
|
||||||
NixOS
|
NixOS
|
||||||
-----
|
-----
|
||||||
|
|
||||||
|
@ -22,3 +22,4 @@ export SYNAPSE_CACHE_FACTOR=1
|
|||||||
--federation-reader \
|
--federation-reader \
|
||||||
--client-reader \
|
--client-reader \
|
||||||
--appservice \
|
--appservice \
|
||||||
|
--federation-sender \
|
||||||
|
@ -15,6 +15,6 @@ tox -e py27 --notest -v
|
|||||||
|
|
||||||
TOX_BIN=$TOX_DIR/py27/bin
|
TOX_BIN=$TOX_DIR/py27/bin
|
||||||
$TOX_BIN/pip install setuptools
|
$TOX_BIN/pip install setuptools
|
||||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
{ python synapse/python_dependencies.py
|
||||||
$TOX_BIN/pip install lxml
|
echo lxml psycopg2
|
||||||
$TOX_BIN/pip install psycopg2
|
} | xargs $TOX_BIN/pip install
|
||||||
|
73
setup.py
73
setup.py
@ -23,6 +23,45 @@ import sys
|
|||||||
here = os.path.abspath(os.path.dirname(__file__))
|
here = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
# Some notes on `setup.py test`:
|
||||||
|
#
|
||||||
|
# Once upon a time we used to try to make `setup.py test` run `tox` to run the
|
||||||
|
# tests. That's a bad idea for three reasons:
|
||||||
|
#
|
||||||
|
# 1: `setup.py test` is supposed to find out whether the tests work in the
|
||||||
|
# *current* environmentt, not whatever tox sets up.
|
||||||
|
# 2: Empirically, trying to install tox during the test run wasn't working ("No
|
||||||
|
# module named virtualenv").
|
||||||
|
# 3: The tox documentation advises against it[1].
|
||||||
|
#
|
||||||
|
# Even further back in time, we used to use setuptools_trial [2]. That has its
|
||||||
|
# own set of issues: for instance, it requires installation of Twisted to build
|
||||||
|
# an sdist (because the recommended mode of usage is to add it to
|
||||||
|
# `setup_requires`). That in turn means that in order to successfully run tox
|
||||||
|
# you have to have the python header files installed for whichever version of
|
||||||
|
# python tox uses (which is python3 on recent ubuntus, for example).
|
||||||
|
#
|
||||||
|
# So, for now at least, we stick with what appears to be the convention among
|
||||||
|
# Twisted projects, and don't attempt to do anything when someone runs
|
||||||
|
# `setup.py test`; instead we direct people to run `trial` directly if they
|
||||||
|
# care.
|
||||||
|
#
|
||||||
|
# [1]: http://tox.readthedocs.io/en/2.5.0/example/basic.html#integration-with-setup-py-test-command
|
||||||
|
# [2]: https://pypi.python.org/pypi/setuptools_trial
|
||||||
|
class TestCommand(Command):
|
||||||
|
user_options = []
|
||||||
|
|
||||||
|
def initialize_options(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finalize_options(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
print ("""Synapse's tests cannot be run via setup.py. To run them, try:
|
||||||
|
PYTHONPATH="." trial tests
|
||||||
|
""")
|
||||||
|
|
||||||
def read_file(path_segments):
|
def read_file(path_segments):
|
||||||
"""Read a file from the package. Takes a list of strings to join to
|
"""Read a file from the package. Takes a list of strings to join to
|
||||||
make the path"""
|
make the path"""
|
||||||
@ -39,38 +78,6 @@ def exec_file(path_segments):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class Tox(Command):
|
|
||||||
user_options = [('tox-args=', 'a', "Arguments to pass to tox")]
|
|
||||||
|
|
||||||
def initialize_options(self):
|
|
||||||
self.tox_args = None
|
|
||||||
|
|
||||||
def finalize_options(self):
|
|
||||||
self.test_args = []
|
|
||||||
self.test_suite = True
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
#import here, cause outside the eggs aren't loaded
|
|
||||||
try:
|
|
||||||
import tox
|
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
self.distribution.fetch_build_eggs("tox")
|
|
||||||
import tox
|
|
||||||
except:
|
|
||||||
raise RuntimeError(
|
|
||||||
"The tests need 'tox' to run. Please install 'tox'."
|
|
||||||
)
|
|
||||||
import shlex
|
|
||||||
args = self.tox_args
|
|
||||||
if args:
|
|
||||||
args = shlex.split(self.tox_args)
|
|
||||||
else:
|
|
||||||
args = []
|
|
||||||
errno = tox.cmdline(args=args)
|
|
||||||
sys.exit(errno)
|
|
||||||
|
|
||||||
|
|
||||||
version = exec_file(("synapse", "__init__.py"))["__version__"]
|
version = exec_file(("synapse", "__init__.py"))["__version__"]
|
||||||
dependencies = exec_file(("synapse", "python_dependencies.py"))
|
dependencies = exec_file(("synapse", "python_dependencies.py"))
|
||||||
long_description = read_file(("README.rst",))
|
long_description = read_file(("README.rst",))
|
||||||
@ -86,5 +93,5 @@ setup(
|
|||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
scripts=["synctl"] + glob.glob("scripts/*"),
|
scripts=["synctl"] + glob.glob("scripts/*"),
|
||||||
cmdclass={'test': Tox},
|
cmdclass={'test': TestCommand},
|
||||||
)
|
)
|
||||||
|
@ -16,4 +16,4 @@
|
|||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.18.4"
|
__version__ = "0.18.5"
|
||||||
|
@ -39,6 +39,9 @@ AuthEventTypes = (
|
|||||||
EventTypes.ThirdPartyInvite,
|
EventTypes.ThirdPartyInvite,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# guests always get this device id.
|
||||||
|
GUEST_DEVICE_ID = "guest_device"
|
||||||
|
|
||||||
|
|
||||||
class Auth(object):
|
class Auth(object):
|
||||||
"""
|
"""
|
||||||
@ -51,17 +54,6 @@ class Auth(object):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||||
# Docs for these currently lives at
|
|
||||||
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
|
||||||
# In addition, we have type == delete_pusher which grants access only to
|
|
||||||
# delete pushers.
|
|
||||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
|
||||||
"gen = ",
|
|
||||||
"guest = ",
|
|
||||||
"type = ",
|
|
||||||
"time < ",
|
|
||||||
"user_id = ",
|
|
||||||
])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_from_context(self, event, context, do_sig_check=True):
|
def check_from_context(self, event, context, do_sig_check=True):
|
||||||
@ -685,31 +677,28 @@ class Auth(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_user_by_access_token(self, token, rights="access"):
|
def get_user_by_access_token(self, token, rights="access"):
|
||||||
""" Get a registered user's ID.
|
""" Validate access token and get user_id from it
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token (str): The access token to get the user by.
|
token (str): The access token to get the user by.
|
||||||
|
rights (str): The operation being performed; the access token must
|
||||||
|
allow this.
|
||||||
Returns:
|
Returns:
|
||||||
dict : dict that includes the user and the ID of their access token.
|
dict : dict that includes the user and the ID of their access token.
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
ret = yield self.get_user_from_macaroon(token, rights)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
except AuthError:
|
except Exception: # deserialize can throw more-or-less anything
|
||||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
# doesn't look like a macaroon: treat it as an opaque token which
|
||||||
# have been re-issued as macaroons.
|
# must be in the database.
|
||||||
if self.hs.config.expire_access_token:
|
# TODO: it would be nice to get rid of this, but apparently some
|
||||||
raise
|
# people use access tokens which aren't macaroons
|
||||||
ret = yield self._look_up_user_by_access_token(token)
|
r = yield self._look_up_user_by_access_token(token)
|
||||||
|
defer.returnValue(r)
|
||||||
|
|
||||||
defer.returnValue(ret)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_user_from_macaroon(self, macaroon_str, rights="access"):
|
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
|
||||||
|
|
||||||
user_id = self.get_user_id_from_macaroon(macaroon)
|
user_id = self.get_user_id_from_macaroon(macaroon)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
@ -724,11 +713,36 @@ class Auth(object):
|
|||||||
guest = True
|
guest = True
|
||||||
|
|
||||||
if guest:
|
if guest:
|
||||||
|
# Guest access tokens are not stored in the database (there can
|
||||||
|
# only be one access token per guest, anyway).
|
||||||
|
#
|
||||||
|
# In order to prevent guest access tokens being used as regular
|
||||||
|
# user access tokens (and hence getting around the invalidation
|
||||||
|
# process), we look up the user id and check that it is indeed
|
||||||
|
# a guest user.
|
||||||
|
#
|
||||||
|
# It would of course be much easier to store guest access
|
||||||
|
# tokens in the database as well, but that would break existing
|
||||||
|
# guest tokens.
|
||||||
|
stored_user = yield self.store.get_user_by_id(user_id)
|
||||||
|
if not stored_user:
|
||||||
|
raise AuthError(
|
||||||
|
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||||
|
"Unknown user_id %s" % user_id,
|
||||||
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
|
)
|
||||||
|
if not stored_user["is_guest"]:
|
||||||
|
raise AuthError(
|
||||||
|
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||||
|
"Guest access token used for regular user",
|
||||||
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
|
)
|
||||||
ret = {
|
ret = {
|
||||||
"user": user,
|
"user": user,
|
||||||
"is_guest": True,
|
"is_guest": True,
|
||||||
"token_id": None,
|
"token_id": None,
|
||||||
"device_id": None,
|
# all guests get the same device id
|
||||||
|
"device_id": GUEST_DEVICE_ID,
|
||||||
}
|
}
|
||||||
elif rights == "delete_pusher":
|
elif rights == "delete_pusher":
|
||||||
# We don't store these tokens in the database
|
# We don't store these tokens in the database
|
||||||
@ -750,7 +764,7 @@ class Auth(object):
|
|||||||
# macaroon. They probably should be.
|
# macaroon. They probably should be.
|
||||||
# TODO: build the dictionary from the macaroon once the
|
# TODO: build the dictionary from the macaroon once the
|
||||||
# above are fixed
|
# above are fixed
|
||||||
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
ret = yield self._look_up_user_by_access_token(token)
|
||||||
if ret["user"] != user:
|
if ret["user"] != user:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Macaroon user (%s) != DB user (%s)",
|
"Macaroon user (%s) != DB user (%s)",
|
||||||
@ -798,27 +812,38 @@ class Auth(object):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
macaroon(pymacaroons.Macaroon): The macaroon to validate
|
macaroon(pymacaroons.Macaroon): The macaroon to validate
|
||||||
type_string(str): The kind of token required (e.g. "access", "refresh",
|
type_string(str): The kind of token required (e.g. "access",
|
||||||
"delete_pusher")
|
"delete_pusher")
|
||||||
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
||||||
This should really always be True, but no clients currently implement
|
|
||||||
token refresh, so we can't enforce expiry yet.
|
|
||||||
user_id (str): The user_id required
|
user_id (str): The user_id required
|
||||||
"""
|
"""
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
|
|
||||||
|
# the verifier runs a test for every caveat on the macaroon, to check
|
||||||
|
# that it is met for the current request. Each caveat must match at
|
||||||
|
# least one of the predicates specified by satisfy_exact or
|
||||||
|
# specify_general.
|
||||||
v.satisfy_exact("gen = 1")
|
v.satisfy_exact("gen = 1")
|
||||||
v.satisfy_exact("type = " + type_string)
|
v.satisfy_exact("type = " + type_string)
|
||||||
v.satisfy_exact("user_id = %s" % user_id)
|
v.satisfy_exact("user_id = %s" % user_id)
|
||||||
v.satisfy_exact("guest = true")
|
v.satisfy_exact("guest = true")
|
||||||
|
|
||||||
|
# verify_expiry should really always be True, but there exist access
|
||||||
|
# tokens in the wild which expire when they should not, so we can't
|
||||||
|
# enforce expiry yet (so we have to allow any caveat starting with
|
||||||
|
# 'time < ' in access tokens).
|
||||||
|
#
|
||||||
|
# On the other hand, short-term login tokens (as used by CAS login, for
|
||||||
|
# example) have an expiry time which we do want to enforce.
|
||||||
|
|
||||||
if verify_expiry:
|
if verify_expiry:
|
||||||
v.satisfy_general(self._verify_expiry)
|
v.satisfy_general(self._verify_expiry)
|
||||||
else:
|
else:
|
||||||
v.satisfy_general(lambda c: c.startswith("time < "))
|
v.satisfy_general(lambda c: c.startswith("time < "))
|
||||||
|
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
# access_tokens include a nonce for uniqueness: any value is acceptable
|
||||||
|
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||||
|
|
||||||
v = pymacaroons.Verifier()
|
|
||||||
v.satisfy_general(self._verify_recognizes_caveats)
|
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
def _verify_expiry(self, caveat):
|
def _verify_expiry(self, caveat):
|
||||||
@ -829,15 +854,6 @@ class Auth(object):
|
|||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
def _verify_recognizes_caveats(self, caveat):
|
|
||||||
first_space = caveat.find(" ")
|
|
||||||
if first_space < 0:
|
|
||||||
return False
|
|
||||||
second_space = caveat.find(" ", first_space + 1)
|
|
||||||
if second_space < 0:
|
|
||||||
return False
|
|
||||||
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _look_up_user_by_access_token(self, token):
|
def _look_up_user_by_access_token(self, token):
|
||||||
ret = yield self.store.get_user_by_access_token(token)
|
ret = yield self.store.get_user_by_access_token(token)
|
||||||
|
@ -39,6 +39,7 @@ class Codes(object):
|
|||||||
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
||||||
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
||||||
MISSING_PARAM = "M_MISSING_PARAM"
|
MISSING_PARAM = "M_MISSING_PARAM"
|
||||||
|
INVALID_PARAM = "M_INVALID_PARAM"
|
||||||
TOO_LARGE = "M_TOO_LARGE"
|
TOO_LARGE = "M_TOO_LARGE"
|
||||||
EXCLUSIVE = "M_EXCLUSIVE"
|
EXCLUSIVE = "M_EXCLUSIVE"
|
||||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||||
|
@ -71,6 +71,21 @@ class Filtering(object):
|
|||||||
if key in user_filter_json["room"]:
|
if key in user_filter_json["room"]:
|
||||||
self._check_definition(user_filter_json["room"][key])
|
self._check_definition(user_filter_json["room"][key])
|
||||||
|
|
||||||
|
if "event_fields" in user_filter_json:
|
||||||
|
if type(user_filter_json["event_fields"]) != list:
|
||||||
|
raise SynapseError(400, "event_fields must be a list of strings")
|
||||||
|
for field in user_filter_json["event_fields"]:
|
||||||
|
if not isinstance(field, basestring):
|
||||||
|
raise SynapseError(400, "Event field must be a string")
|
||||||
|
# Don't allow '\\' in event field filters. This makes matching
|
||||||
|
# events a lot easier as we can then use a negative lookbehind
|
||||||
|
# assertion to split '\.' If we allowed \\ then it would
|
||||||
|
# incorrectly split '\\.' See synapse.events.utils.serialize_event
|
||||||
|
if r'\\' in field:
|
||||||
|
raise SynapseError(
|
||||||
|
400, r'The escape character \ cannot itself be escaped'
|
||||||
|
)
|
||||||
|
|
||||||
def _check_definition_room_lists(self, definition):
|
def _check_definition_room_lists(self, definition):
|
||||||
"""Check that "rooms" and "not_rooms" are lists of room ids if they
|
"""Check that "rooms" and "not_rooms" are lists of room ids if they
|
||||||
are present
|
are present
|
||||||
@ -152,6 +167,7 @@ class FilterCollection(object):
|
|||||||
self.include_leave = filter_json.get("room", {}).get(
|
self.include_leave = filter_json.get("room", {}).get(
|
||||||
"include_leave", False
|
"include_leave", False
|
||||||
)
|
)
|
||||||
|
self.event_fields = filter_json.get("event_fields", [])
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
||||||
@ -186,6 +202,26 @@ class FilterCollection(object):
|
|||||||
def filter_room_account_data(self, events):
|
def filter_room_account_data(self, events):
|
||||||
return self._room_account_data.filter(self._room_filter.filter(events))
|
return self._room_account_data.filter(self._room_filter.filter(events))
|
||||||
|
|
||||||
|
def blocks_all_presence(self):
|
||||||
|
return (
|
||||||
|
self._presence_filter.filters_all_types() or
|
||||||
|
self._presence_filter.filters_all_senders()
|
||||||
|
)
|
||||||
|
|
||||||
|
def blocks_all_room_ephemeral(self):
|
||||||
|
return (
|
||||||
|
self._room_ephemeral_filter.filters_all_types() or
|
||||||
|
self._room_ephemeral_filter.filters_all_senders() or
|
||||||
|
self._room_ephemeral_filter.filters_all_rooms()
|
||||||
|
)
|
||||||
|
|
||||||
|
def blocks_all_room_timeline(self):
|
||||||
|
return (
|
||||||
|
self._room_timeline_filter.filters_all_types() or
|
||||||
|
self._room_timeline_filter.filters_all_senders() or
|
||||||
|
self._room_timeline_filter.filters_all_rooms()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Filter(object):
|
class Filter(object):
|
||||||
def __init__(self, filter_json):
|
def __init__(self, filter_json):
|
||||||
@ -202,6 +238,15 @@ class Filter(object):
|
|||||||
|
|
||||||
self.contains_url = self.filter_json.get("contains_url", None)
|
self.contains_url = self.filter_json.get("contains_url", None)
|
||||||
|
|
||||||
|
def filters_all_types(self):
|
||||||
|
return "*" in self.not_types
|
||||||
|
|
||||||
|
def filters_all_senders(self):
|
||||||
|
return "*" in self.not_senders
|
||||||
|
|
||||||
|
def filters_all_rooms(self):
|
||||||
|
return "*" in self.not_rooms
|
||||||
|
|
||||||
def check(self, event):
|
def check(self, event):
|
||||||
"""Checks whether the filter matches the given event.
|
"""Checks whether the filter matches the given event.
|
||||||
|
|
||||||
|
331
synapse/app/federation_sender.py
Normal file
331
synapse/app/federation_sender.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
from synapse.config.logger import setup_logging
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.crypto import context_factory
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
from synapse.federation import send_queue
|
||||||
|
from synapse.federation.units import Edu
|
||||||
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
|
||||||
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
|
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.storage.presence import UserPresenceState
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.rlimit import change_resource_limit
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
|
from synapse import events
|
||||||
|
|
||||||
|
from twisted.internet import reactor, defer
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from daemonize import Daemonize
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
|
logger = logging.getLogger("synapse.app.appservice")
|
||||||
|
|
||||||
|
|
||||||
|
class FederationSenderSlaveStore(
|
||||||
|
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
|
||||||
|
SlavedRegistrationStore,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FederationSenderServer(HomeServer):
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def _listen_http(self, listener_config):
|
||||||
|
port = listener_config["port"]
|
||||||
|
bind_address = listener_config.get("bind_address", "")
|
||||||
|
site_tag = listener_config.get("tag", port)
|
||||||
|
resources = {}
|
||||||
|
for res in listener_config["resources"]:
|
||||||
|
for name in res["names"]:
|
||||||
|
if name == "metrics":
|
||||||
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=bind_address
|
||||||
|
)
|
||||||
|
logger.info("Synapse federation_sender now listening on port %d", port)
|
||||||
|
|
||||||
|
def start_listening(self, listeners):
|
||||||
|
for listener in listeners:
|
||||||
|
if listener["type"] == "http":
|
||||||
|
self._listen_http(listener)
|
||||||
|
elif listener["type"] == "manhole":
|
||||||
|
reactor.listenTCP(
|
||||||
|
listener["port"],
|
||||||
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
|
interface=listener.get("bind_address", '127.0.0.1')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(self):
|
||||||
|
http_client = self.get_simple_http_client()
|
||||||
|
store = self.get_datastore()
|
||||||
|
replication_url = self.config.worker_replication_url
|
||||||
|
send_handler = FederationSenderHandler(self)
|
||||||
|
|
||||||
|
send_handler.on_start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
args = store.stream_positions()
|
||||||
|
args.update((yield send_handler.stream_positions()))
|
||||||
|
args["timeout"] = 30000
|
||||||
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
|
yield store.process_replication(result)
|
||||||
|
yield send_handler.process_replication(result)
|
||||||
|
except:
|
||||||
|
logger.exception("Error replicating from %r", replication_url)
|
||||||
|
yield sleep(30)
|
||||||
|
|
||||||
|
|
||||||
|
def start(config_options):
|
||||||
|
try:
|
||||||
|
config = HomeServerConfig.load_config(
|
||||||
|
"Synapse federation sender", config_options
|
||||||
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
assert config.worker_app == "synapse.app.federation_sender"
|
||||||
|
|
||||||
|
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||||
|
|
||||||
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
if config.send_federation:
|
||||||
|
sys.stderr.write(
|
||||||
|
"\nThe send_federation must be disabled in the main synapse process"
|
||||||
|
"\nbefore they can be run in a separate worker."
|
||||||
|
"\nPlease add ``send_federation: false`` to the main config"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Force the pushers to start since they will be disabled in the main config
|
||||||
|
config.send_federation = True
|
||||||
|
|
||||||
|
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||||
|
|
||||||
|
ps = FederationSenderServer(
|
||||||
|
config.server_name,
|
||||||
|
db_config=config.database_config,
|
||||||
|
tls_server_context_factory=tls_server_context_factory,
|
||||||
|
config=config,
|
||||||
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
|
database_engine=database_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
ps.setup()
|
||||||
|
ps.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
with LoggingContext("run"):
|
||||||
|
logger.info("Running")
|
||||||
|
change_resource_limit(config.soft_file_limit)
|
||||||
|
if config.gc_thresholds:
|
||||||
|
gc.set_threshold(*config.gc_thresholds)
|
||||||
|
reactor.run()
|
||||||
|
|
||||||
|
def start():
|
||||||
|
ps.replicate()
|
||||||
|
ps.get_datastore().start_profiling()
|
||||||
|
ps.get_state_handler().start_caching()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
if config.worker_daemonize:
|
||||||
|
daemon = Daemonize(
|
||||||
|
app="synapse-federation-sender",
|
||||||
|
pid=config.worker_pid_file,
|
||||||
|
action=run,
|
||||||
|
auto_close_fds=False,
|
||||||
|
verbose=True,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
daemon.start()
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
class FederationSenderHandler(object):
|
||||||
|
"""Processes the replication stream and forwards the appropriate entries
|
||||||
|
to the federation sender.
|
||||||
|
"""
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.federation_sender = hs.get_federation_sender()
|
||||||
|
|
||||||
|
self._room_serials = {}
|
||||||
|
self._room_typing = {}
|
||||||
|
|
||||||
|
def on_start(self):
|
||||||
|
# There may be some events that are persisted but haven't been sent,
|
||||||
|
# so send them now.
|
||||||
|
self.federation_sender.notify_new_events(
|
||||||
|
self.store.get_room_max_stream_ordering()
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def stream_positions(self):
|
||||||
|
stream_id = yield self.store.get_federation_out_pos("federation")
|
||||||
|
defer.returnValue({
|
||||||
|
"federation": stream_id,
|
||||||
|
|
||||||
|
# Ack stuff we've "processed", this should only be called from
|
||||||
|
# one process.
|
||||||
|
"federation_ack": stream_id,
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def process_replication(self, result):
|
||||||
|
# The federation stream contains things that we want to send out, e.g.
|
||||||
|
# presence, typing, etc.
|
||||||
|
fed_stream = result.get("federation")
|
||||||
|
if fed_stream:
|
||||||
|
latest_id = int(fed_stream["position"])
|
||||||
|
|
||||||
|
# The federation stream containis a bunch of different types of
|
||||||
|
# rows that need to be handled differently. We parse the rows, put
|
||||||
|
# them into the appropriate collection and then send them off.
|
||||||
|
presence_to_send = {}
|
||||||
|
keyed_edus = {}
|
||||||
|
edus = {}
|
||||||
|
failures = {}
|
||||||
|
device_destinations = set()
|
||||||
|
|
||||||
|
# Parse the rows in the stream
|
||||||
|
for row in fed_stream["rows"]:
|
||||||
|
position, typ, content_js = row
|
||||||
|
content = json.loads(content_js)
|
||||||
|
|
||||||
|
if typ == send_queue.PRESENCE_TYPE:
|
||||||
|
destination = content["destination"]
|
||||||
|
state = UserPresenceState.from_dict(content["state"])
|
||||||
|
|
||||||
|
presence_to_send.setdefault(destination, []).append(state)
|
||||||
|
elif typ == send_queue.KEYED_EDU_TYPE:
|
||||||
|
key = content["key"]
|
||||||
|
edu = Edu(**content["edu"])
|
||||||
|
|
||||||
|
keyed_edus.setdefault(
|
||||||
|
edu.destination, {}
|
||||||
|
)[(edu.destination, tuple(key))] = edu
|
||||||
|
elif typ == send_queue.EDU_TYPE:
|
||||||
|
edu = Edu(**content)
|
||||||
|
|
||||||
|
edus.setdefault(edu.destination, []).append(edu)
|
||||||
|
elif typ == send_queue.FAILURE_TYPE:
|
||||||
|
destination = content["destination"]
|
||||||
|
failure = content["failure"]
|
||||||
|
|
||||||
|
failures.setdefault(destination, []).append(failure)
|
||||||
|
elif typ == send_queue.DEVICE_MESSAGE_TYPE:
|
||||||
|
device_destinations.add(content["destination"])
|
||||||
|
else:
|
||||||
|
raise Exception("Unrecognised federation type: %r", typ)
|
||||||
|
|
||||||
|
# We've finished collecting, send everything off
|
||||||
|
for destination, states in presence_to_send.items():
|
||||||
|
self.federation_sender.send_presence(destination, states)
|
||||||
|
|
||||||
|
for destination, edu_map in keyed_edus.items():
|
||||||
|
for key, edu in edu_map.items():
|
||||||
|
self.federation_sender.send_edu(
|
||||||
|
edu.destination, edu.edu_type, edu.content, key=key,
|
||||||
|
)
|
||||||
|
|
||||||
|
for destination, edu_list in edus.items():
|
||||||
|
for edu in edu_list:
|
||||||
|
self.federation_sender.send_edu(
|
||||||
|
edu.destination, edu.edu_type, edu.content, key=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
for destination, failure_list in failures.items():
|
||||||
|
for failure in failure_list:
|
||||||
|
self.federation_sender.send_failure(destination, failure)
|
||||||
|
|
||||||
|
for destination in device_destinations:
|
||||||
|
self.federation_sender.send_device_messages(destination)
|
||||||
|
|
||||||
|
# Record where we are in the stream.
|
||||||
|
yield self.store.update_federation_out_pos(
|
||||||
|
"federation", latest_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# We also need to poke the federation sender when new events happen
|
||||||
|
event_stream = result.get("events")
|
||||||
|
if event_stream:
|
||||||
|
latest_pos = event_stream["position"]
|
||||||
|
self.federation_sender.notify_new_events(latest_pos)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with LoggingContext("main"):
|
||||||
|
start(sys.argv[1:])
|
@ -89,6 +89,9 @@ class ApplicationService(object):
|
|||||||
self.namespaces = self._check_namespaces(namespaces)
|
self.namespaces = self._check_namespaces(namespaces)
|
||||||
self.id = id
|
self.id = id
|
||||||
|
|
||||||
|
if "|" in self.id:
|
||||||
|
raise Exception("application service ID cannot contain '|' character")
|
||||||
|
|
||||||
# .protocols is a publicly visible field
|
# .protocols is a publicly visible field
|
||||||
if protocols:
|
if protocols:
|
||||||
self.protocols = set(protocols)
|
self.protocols = set(protocols)
|
||||||
|
@ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException
|
|||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
@ -177,6 +178,13 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
" valid result", uri)
|
" valid result", uri)
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
for instance in info.get("instances", []):
|
||||||
|
network_id = instance.get("network_id", None)
|
||||||
|
if network_id is not None:
|
||||||
|
instance["instance_id"] = ThirdPartyInstanceID(
|
||||||
|
service.id, network_id,
|
||||||
|
).to_string()
|
||||||
|
|
||||||
defer.returnValue(info)
|
defer.returnValue(info)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning("query_3pe_protocol to %s threw exception %s",
|
logger.warning("query_3pe_protocol to %s threw exception %s",
|
||||||
|
@ -50,6 +50,7 @@ handlers:
|
|||||||
console:
|
console:
|
||||||
class: logging.StreamHandler
|
class: logging.StreamHandler
|
||||||
formatter: precise
|
formatter: precise
|
||||||
|
filters: [context]
|
||||||
|
|
||||||
loggers:
|
loggers:
|
||||||
synapse:
|
synapse:
|
||||||
|
@ -27,12 +27,18 @@ class PasswordAuthProviderConfig(Config):
|
|||||||
ldap_config = config.get("ldap_config", {})
|
ldap_config = config.get("ldap_config", {})
|
||||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||||
if self.ldap_enabled:
|
if self.ldap_enabled:
|
||||||
from synapse.util.ldap_auth_provider import LdapAuthProvider
|
from ldap_auth_provider import LdapAuthProvider
|
||||||
parsed_config = LdapAuthProvider.parse_config(ldap_config)
|
parsed_config = LdapAuthProvider.parse_config(ldap_config)
|
||||||
self.password_providers.append((LdapAuthProvider, parsed_config))
|
self.password_providers.append((LdapAuthProvider, parsed_config))
|
||||||
|
|
||||||
providers = config.get("password_providers", [])
|
providers = config.get("password_providers", [])
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
|
# This is for backwards compat when the ldap auth provider resided
|
||||||
|
# in this package.
|
||||||
|
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||||
|
from ldap_auth_provider import LdapAuthProvider
|
||||||
|
provider_class = LdapAuthProvider
|
||||||
|
else:
|
||||||
# We need to import the module, and then pick the class out of
|
# We need to import the module, and then pick the class out of
|
||||||
# that, so we split based on the last dot.
|
# that, so we split based on the last dot.
|
||||||
module, clz = provider['module'].rsplit(".", 1)
|
module, clz = provider['module'].rsplit(".", 1)
|
||||||
@ -50,7 +56,7 @@ class PasswordAuthProviderConfig(Config):
|
|||||||
def default_config(self, **kwargs):
|
def default_config(self, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
# password_providers:
|
# password_providers:
|
||||||
# - module: "synapse.util.ldap_auth_provider.LdapAuthProvider"
|
# - module: "ldap_auth_provider.LdapAuthProvider"
|
||||||
# config:
|
# config:
|
||||||
# enabled: true
|
# enabled: true
|
||||||
# uri: "ldap://ldap.example.com:389"
|
# uri: "ldap://ldap.example.com:389"
|
||||||
|
@ -32,7 +32,6 @@ class RegistrationConfig(Config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||||
@ -55,11 +54,6 @@ class RegistrationConfig(Config):
|
|||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
|
||||||
# Sets the expiry for the short term user creation in
|
|
||||||
# milliseconds. For instance the bellow duration is two weeks
|
|
||||||
# in milliseconds.
|
|
||||||
user_creation_max_duration: 1209600000
|
|
||||||
|
|
||||||
# Set the number of bcrypt rounds used to generate password hash.
|
# Set the number of bcrypt rounds used to generate password hash.
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
# Larger numbers increase the work factor needed to generate the hash.
|
||||||
# The default number of rounds is 12.
|
# The default number of rounds is 12.
|
||||||
|
@ -30,6 +30,11 @@ class ServerConfig(Config):
|
|||||||
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
|
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
|
||||||
self.public_baseurl = config.get("public_baseurl")
|
self.public_baseurl = config.get("public_baseurl")
|
||||||
|
|
||||||
|
# Whether to send federation traffic out in this process. This only
|
||||||
|
# applies to some federation traffic, and so shouldn't be used to
|
||||||
|
# "disable" federation
|
||||||
|
self.send_federation = config.get("send_federation", True)
|
||||||
|
|
||||||
if self.public_baseurl is not None:
|
if self.public_baseurl is not None:
|
||||||
if self.public_baseurl[-1] != '/':
|
if self.public_baseurl[-1] != '/':
|
||||||
self.public_baseurl += '/'
|
self.public_baseurl += '/'
|
||||||
|
@ -16,6 +16,17 @@
|
|||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from . import EventBase
|
from . import EventBase
|
||||||
|
|
||||||
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
||||||
|
# (?<!stuff) matches if the current position in the string is not preceded
|
||||||
|
# by a match for 'stuff'.
|
||||||
|
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
|
||||||
|
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
|
||||||
|
SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
|
||||||
|
|
||||||
|
|
||||||
def prune_event(event):
|
def prune_event(event):
|
||||||
""" Returns a pruned version of the given event, which removes all keys we
|
""" Returns a pruned version of the given event, which removes all keys we
|
||||||
@ -97,6 +108,83 @@ def prune_event(event):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_field(src, dst, field):
|
||||||
|
"""Copy the field in 'src' to 'dst'.
|
||||||
|
|
||||||
|
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
|
||||||
|
then dst={"foo":{"bar":5}}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src(dict): The dict to read from.
|
||||||
|
dst(dict): The dict to modify.
|
||||||
|
field(list<str>): List of keys to drill down to in 'src'.
|
||||||
|
"""
|
||||||
|
if len(field) == 0: # this should be impossible
|
||||||
|
return
|
||||||
|
if len(field) == 1: # common case e.g. 'origin_server_ts'
|
||||||
|
if field[0] in src:
|
||||||
|
dst[field[0]] = src[field[0]]
|
||||||
|
return
|
||||||
|
|
||||||
|
# Else is a nested field e.g. 'content.body'
|
||||||
|
# Pop the last field as that's the key to move across and we need the
|
||||||
|
# parent dict in order to access the data. Drill down to the right dict.
|
||||||
|
key_to_move = field.pop(-1)
|
||||||
|
sub_dict = src
|
||||||
|
for sub_field in field: # e.g. sub_field => "content"
|
||||||
|
if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
|
||||||
|
sub_dict = sub_dict[sub_field]
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
if key_to_move not in sub_dict:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Insert the key into the output dictionary, creating nested objects
|
||||||
|
# as required. We couldn't do this any earlier or else we'd need to delete
|
||||||
|
# the empty objects if the key didn't exist.
|
||||||
|
sub_out_dict = dst
|
||||||
|
for sub_field in field:
|
||||||
|
sub_out_dict = sub_out_dict.setdefault(sub_field, {})
|
||||||
|
sub_out_dict[key_to_move] = sub_dict[key_to_move]
|
||||||
|
|
||||||
|
|
||||||
|
def only_fields(dictionary, fields):
|
||||||
|
"""Return a new dict with only the fields in 'dictionary' which are present
|
||||||
|
in 'fields'.
|
||||||
|
|
||||||
|
If there are no event fields specified then all fields are included.
|
||||||
|
The entries may include '.' charaters to indicate sub-fields.
|
||||||
|
So ['content.body'] will include the 'body' field of the 'content' object.
|
||||||
|
A literal '.' character in a field name may be escaped using a '\'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dictionary(dict): The dictionary to read from.
|
||||||
|
fields(list<str>): A list of fields to copy over. Only shallow refs are
|
||||||
|
taken.
|
||||||
|
Returns:
|
||||||
|
dict: A new dictionary with only the given fields. If fields was empty,
|
||||||
|
the same dictionary is returned.
|
||||||
|
"""
|
||||||
|
if len(fields) == 0:
|
||||||
|
return dictionary
|
||||||
|
|
||||||
|
# for each field, convert it:
|
||||||
|
# ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
|
||||||
|
split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
|
||||||
|
|
||||||
|
# for each element of the output array of arrays:
|
||||||
|
# remove escaping so we can use the right key names.
|
||||||
|
split_fields[:] = [
|
||||||
|
[f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
output = {}
|
||||||
|
for field_array in split_fields:
|
||||||
|
_copy_field(dictionary, output, field_array)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def format_event_raw(d):
|
def format_event_raw(d):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d):
|
|||||||
|
|
||||||
def serialize_event(e, time_now_ms, as_client_event=True,
|
def serialize_event(e, time_now_ms, as_client_event=True,
|
||||||
event_format=format_event_for_client_v1,
|
event_format=format_event_for_client_v1,
|
||||||
token_id=None):
|
token_id=None, only_event_fields=None):
|
||||||
# FIXME(erikj): To handle the case of presence events and the like
|
# FIXME(erikj): To handle the case of presence events and the like
|
||||||
if not isinstance(e, EventBase):
|
if not isinstance(e, EventBase):
|
||||||
return e
|
return e
|
||||||
@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True,
|
|||||||
d["unsigned"]["transaction_id"] = txn_id
|
d["unsigned"]["transaction_id"] = txn_id
|
||||||
|
|
||||||
if as_client_event:
|
if as_client_event:
|
||||||
return event_format(d)
|
d = event_format(d)
|
||||||
else:
|
|
||||||
|
if only_event_fields:
|
||||||
|
if (not isinstance(only_event_fields, list) or
|
||||||
|
not all(isinstance(f, basestring) for f in only_event_fields)):
|
||||||
|
raise TypeError("only_event_fields must be a list of strings")
|
||||||
|
d = only_fields(d, only_event_fields)
|
||||||
|
|
||||||
return d
|
return d
|
||||||
|
@ -17,10 +17,9 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .replication import ReplicationLayer
|
from .replication import ReplicationLayer
|
||||||
from .transport.client import TransportLayerClient
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_http_replication(homeserver):
|
def initialize_http_replication(hs):
|
||||||
transport = TransportLayerClient(homeserver)
|
transport = hs.get_federation_transport_client()
|
||||||
|
|
||||||
return ReplicationLayer(homeserver, transport)
|
return ReplicationLayer(hs, transport)
|
||||||
|
@ -18,7 +18,6 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from .federation_base import FederationBase
|
from .federation_base import FederationBase
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
from .units import Edu
|
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException, HttpResponseException, SynapseError,
|
CodeMessageException, HttpResponseException, SynapseError,
|
||||||
@ -45,10 +44,6 @@ logger = logging.getLogger(__name__)
|
|||||||
# synapse.federation.federation_client is a silly name
|
# synapse.federation.federation_client is a silly name
|
||||||
metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
|
metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
|
||||||
|
|
||||||
sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
|
|
||||||
|
|
||||||
sent_edus_counter = metrics.register_counter("sent_edus")
|
|
||||||
|
|
||||||
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
||||||
|
|
||||||
|
|
||||||
@ -92,63 +87,6 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
self._get_pdu_cache.start()
|
self._get_pdu_cache.start()
|
||||||
|
|
||||||
@log_function
|
|
||||||
def send_pdu(self, pdu, destinations):
|
|
||||||
"""Informs the replication layer about a new PDU generated within the
|
|
||||||
home server that should be transmitted to others.
|
|
||||||
|
|
||||||
TODO: Figure out when we should actually resolve the deferred.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pdu (Pdu): The new Pdu.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred: Completes when we have successfully processed the PDU
|
|
||||||
and replicated it to any interested remote home servers.
|
|
||||||
"""
|
|
||||||
order = self._order
|
|
||||||
self._order += 1
|
|
||||||
|
|
||||||
sent_pdus_destination_dist.inc_by(len(destinations))
|
|
||||||
|
|
||||||
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
|
|
||||||
|
|
||||||
# TODO, add errback, etc.
|
|
||||||
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"[%s] transaction_layer.enqueue_pdu... done",
|
|
||||||
pdu.event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_presence(self, destination, states):
|
|
||||||
if destination != self.server_name:
|
|
||||||
self._transaction_queue.enqueue_presence(destination, states)
|
|
||||||
|
|
||||||
@log_function
|
|
||||||
def send_edu(self, destination, edu_type, content, key=None):
|
|
||||||
edu = Edu(
|
|
||||||
origin=self.server_name,
|
|
||||||
destination=destination,
|
|
||||||
edu_type=edu_type,
|
|
||||||
content=content,
|
|
||||||
)
|
|
||||||
|
|
||||||
sent_edus_counter.inc()
|
|
||||||
|
|
||||||
self._transaction_queue.enqueue_edu(edu, key=key)
|
|
||||||
|
|
||||||
@log_function
|
|
||||||
def send_device_messages(self, destination):
|
|
||||||
"""Sends the device messages in the local database to the remote
|
|
||||||
destination"""
|
|
||||||
self._transaction_queue.enqueue_device_messages(destination)
|
|
||||||
|
|
||||||
@log_function
|
|
||||||
def send_failure(self, failure, destination):
|
|
||||||
self._transaction_queue.enqueue_failure(failure, destination)
|
|
||||||
return defer.succeed(None)
|
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def make_query(self, destination, query_type, args,
|
def make_query(self, destination, query_type, args,
|
||||||
retry_on_dns_fail=False):
|
retry_on_dns_fail=False):
|
||||||
@ -717,12 +655,15 @@ class FederationClient(FederationBase):
|
|||||||
raise RuntimeError("Failed to send to any server.")
|
raise RuntimeError("Failed to send to any server.")
|
||||||
|
|
||||||
def get_public_rooms(self, destination, limit=None, since_token=None,
|
def get_public_rooms(self, destination, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None):
|
||||||
if destination == self.server_name:
|
if destination == self.server_name:
|
||||||
return
|
return
|
||||||
|
|
||||||
return self.transport_layer.get_public_rooms(
|
return self.transport_layer.get_public_rooms(
|
||||||
destination, limit, since_token, search_filter
|
destination, limit, since_token, search_filter,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -20,8 +20,6 @@ a given transport.
|
|||||||
from .federation_client import FederationClient
|
from .federation_client import FederationClient
|
||||||
from .federation_server import FederationServer
|
from .federation_server import FederationServer
|
||||||
|
|
||||||
from .transaction_queue import TransactionQueue
|
|
||||||
|
|
||||||
from .persistence import TransactionActions
|
from .persistence import TransactionActions
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer):
|
|||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
|
|
||||||
self.transaction_actions = TransactionActions(self.store)
|
self.transaction_actions = TransactionActions(self.store)
|
||||||
self._transaction_queue = TransactionQueue(hs, transport_layer)
|
|
||||||
|
|
||||||
self._order = 0
|
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
|
298
synapse/federation/send_queue.py
Normal file
298
synapse/federation/send_queue.py
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""A federation sender that forwards things to be sent across replication to
|
||||||
|
a worker process.
|
||||||
|
|
||||||
|
It assumes there is a single worker process feeding off of it.
|
||||||
|
|
||||||
|
Each row in the replication stream consists of a type and some json, where the
|
||||||
|
types indicate whether they are presence, or edus, etc.
|
||||||
|
|
||||||
|
Ephemeral or non-event data are queued up in-memory. When the worker requests
|
||||||
|
updates since a particular point, all in-memory data since before that point is
|
||||||
|
dropped. We also expire things in the queue after 5 minutes, to ensure that a
|
||||||
|
dead worker doesn't cause the queues to grow limitlessly.
|
||||||
|
|
||||||
|
Events are replicated via a separate events stream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .units import Edu
|
||||||
|
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
import synapse.metrics
|
||||||
|
|
||||||
|
from blist import sorteddict
|
||||||
|
import ujson
|
||||||
|
|
||||||
|
|
||||||
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
PRESENCE_TYPE = "p"
|
||||||
|
KEYED_EDU_TYPE = "k"
|
||||||
|
EDU_TYPE = "e"
|
||||||
|
FAILURE_TYPE = "f"
|
||||||
|
DEVICE_MESSAGE_TYPE = "d"
|
||||||
|
|
||||||
|
|
||||||
|
class FederationRemoteSendQueue(object):
|
||||||
|
"""A drop in replacement for TransactionQueue"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.server_name = hs.hostname
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self.presence_map = {}
|
||||||
|
self.presence_changed = sorteddict()
|
||||||
|
|
||||||
|
self.keyed_edu = {}
|
||||||
|
self.keyed_edu_changed = sorteddict()
|
||||||
|
|
||||||
|
self.edus = sorteddict()
|
||||||
|
|
||||||
|
self.failures = sorteddict()
|
||||||
|
|
||||||
|
self.device_messages = sorteddict()
|
||||||
|
|
||||||
|
self.pos = 1
|
||||||
|
self.pos_time = sorteddict()
|
||||||
|
|
||||||
|
# EVERYTHING IS SAD. In particular, python only makes new scopes when
|
||||||
|
# we make a new function, so we need to make a new function so the inner
|
||||||
|
# lambda binds to the queue rather than to the name of the queue which
|
||||||
|
# changes. ARGH.
|
||||||
|
def register(name, queue):
|
||||||
|
metrics.register_callback(
|
||||||
|
queue_name + "_size",
|
||||||
|
lambda: len(queue),
|
||||||
|
)
|
||||||
|
|
||||||
|
for queue_name in [
|
||||||
|
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
||||||
|
"edus", "failures", "device_messages", "pos_time",
|
||||||
|
]:
|
||||||
|
register(queue_name, getattr(self, queue_name))
|
||||||
|
|
||||||
|
self.clock.looping_call(self._clear_queue, 30 * 1000)
|
||||||
|
|
||||||
|
def _next_pos(self):
|
||||||
|
pos = self.pos
|
||||||
|
self.pos += 1
|
||||||
|
self.pos_time[self.clock.time_msec()] = pos
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def _clear_queue(self):
|
||||||
|
"""Clear the queues for anything older than N minutes"""
|
||||||
|
|
||||||
|
FIVE_MINUTES_AGO = 5 * 60 * 1000
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
keys = self.pos_time.keys()
|
||||||
|
time = keys.bisect_left(now - FIVE_MINUTES_AGO)
|
||||||
|
if not keys[:time]:
|
||||||
|
return
|
||||||
|
|
||||||
|
position_to_delete = max(keys[:time])
|
||||||
|
for key in keys[:time]:
|
||||||
|
del self.pos_time[key]
|
||||||
|
|
||||||
|
self._clear_queue_before_pos(position_to_delete)
|
||||||
|
|
||||||
|
def _clear_queue_before_pos(self, position_to_delete):
|
||||||
|
"""Clear all the queues from before a given position"""
|
||||||
|
with Measure(self.clock, "send_queue._clear"):
|
||||||
|
# Delete things out of presence maps
|
||||||
|
keys = self.presence_changed.keys()
|
||||||
|
i = keys.bisect_left(position_to_delete)
|
||||||
|
for key in keys[:i]:
|
||||||
|
del self.presence_changed[key]
|
||||||
|
|
||||||
|
user_ids = set(
|
||||||
|
user_id for uids in self.presence_changed.values() for _, user_id in uids
|
||||||
|
)
|
||||||
|
|
||||||
|
to_del = [
|
||||||
|
user_id for user_id in self.presence_map if user_id not in user_ids
|
||||||
|
]
|
||||||
|
for user_id in to_del:
|
||||||
|
del self.presence_map[user_id]
|
||||||
|
|
||||||
|
# Delete things out of keyed edus
|
||||||
|
keys = self.keyed_edu_changed.keys()
|
||||||
|
i = keys.bisect_left(position_to_delete)
|
||||||
|
for key in keys[:i]:
|
||||||
|
del self.keyed_edu_changed[key]
|
||||||
|
|
||||||
|
live_keys = set()
|
||||||
|
for edu_key in self.keyed_edu_changed.values():
|
||||||
|
live_keys.add(edu_key)
|
||||||
|
|
||||||
|
to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
|
||||||
|
for edu_key in to_del:
|
||||||
|
del self.keyed_edu[edu_key]
|
||||||
|
|
||||||
|
# Delete things out of edu map
|
||||||
|
keys = self.edus.keys()
|
||||||
|
i = keys.bisect_left(position_to_delete)
|
||||||
|
for key in keys[:i]:
|
||||||
|
del self.edus[key]
|
||||||
|
|
||||||
|
# Delete things out of failure map
|
||||||
|
keys = self.failures.keys()
|
||||||
|
i = keys.bisect_left(position_to_delete)
|
||||||
|
for key in keys[:i]:
|
||||||
|
del self.failures[key]
|
||||||
|
|
||||||
|
# Delete things out of device map
|
||||||
|
keys = self.device_messages.keys()
|
||||||
|
i = keys.bisect_left(position_to_delete)
|
||||||
|
for key in keys[:i]:
|
||||||
|
del self.device_messages[key]
|
||||||
|
|
||||||
|
def notify_new_events(self, current_id):
|
||||||
|
"""As per TransactionQueue"""
|
||||||
|
# We don't need to replicate this as it gets sent down a different
|
||||||
|
# stream.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def send_edu(self, destination, edu_type, content, key=None):
|
||||||
|
"""As per TransactionQueue"""
|
||||||
|
pos = self._next_pos()
|
||||||
|
|
||||||
|
edu = Edu(
|
||||||
|
origin=self.server_name,
|
||||||
|
destination=destination,
|
||||||
|
edu_type=edu_type,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
if key:
|
||||||
|
assert isinstance(key, tuple)
|
||||||
|
self.keyed_edu[(destination, key)] = edu
|
||||||
|
self.keyed_edu_changed[pos] = (destination, key)
|
||||||
|
else:
|
||||||
|
self.edus[pos] = edu
|
||||||
|
|
||||||
|
def send_presence(self, destination, states):
|
||||||
|
"""As per TransactionQueue"""
|
||||||
|
pos = self._next_pos()
|
||||||
|
|
||||||
|
self.presence_map.update({
|
||||||
|
state.user_id: state
|
||||||
|
for state in states
|
||||||
|
})
|
||||||
|
|
||||||
|
self.presence_changed[pos] = [
|
||||||
|
(destination, state.user_id) for state in states
|
||||||
|
]
|
||||||
|
|
||||||
|
def send_failure(self, failure, destination):
|
||||||
|
"""As per TransactionQueue"""
|
||||||
|
pos = self._next_pos()
|
||||||
|
|
||||||
|
self.failures[pos] = (destination, str(failure))
|
||||||
|
|
||||||
|
def send_device_messages(self, destination):
|
||||||
|
"""As per TransactionQueue"""
|
||||||
|
pos = self._next_pos()
|
||||||
|
self.device_messages[pos] = destination
|
||||||
|
|
||||||
|
def get_current_token(self):
|
||||||
|
return self.pos - 1
|
||||||
|
|
||||||
|
def get_replication_rows(self, token, limit, federation_ack=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
token (int)
|
||||||
|
limit (int)
|
||||||
|
federation_ack (int): Optional. The position where the worker is
|
||||||
|
explicitly acknowledged it has handled. Allows us to drop
|
||||||
|
data from before that point
|
||||||
|
"""
|
||||||
|
# TODO: Handle limit.
|
||||||
|
|
||||||
|
# To handle restarts where we wrap around
|
||||||
|
if token > self.pos:
|
||||||
|
token = -1
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
# There should be only one reader, so lets delete everything its
|
||||||
|
# acknowledged its seen.
|
||||||
|
if federation_ack:
|
||||||
|
self._clear_queue_before_pos(federation_ack)
|
||||||
|
|
||||||
|
# Fetch changed presence
|
||||||
|
keys = self.presence_changed.keys()
|
||||||
|
i = keys.bisect_right(token)
|
||||||
|
dest_user_ids = set(
|
||||||
|
(pos, dest_user_id)
|
||||||
|
for pos in keys[i:]
|
||||||
|
for dest_user_id in self.presence_changed[pos]
|
||||||
|
)
|
||||||
|
|
||||||
|
for (key, (dest, user_id)) in dest_user_ids:
|
||||||
|
rows.append((key, PRESENCE_TYPE, ujson.dumps({
|
||||||
|
"destination": dest,
|
||||||
|
"state": self.presence_map[user_id].as_dict(),
|
||||||
|
})))
|
||||||
|
|
||||||
|
# Fetch changes keyed edus
|
||||||
|
keys = self.keyed_edu_changed.keys()
|
||||||
|
i = keys.bisect_right(token)
|
||||||
|
keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
|
||||||
|
|
||||||
|
for (pos, (destination, edu_key)) in keyed_edus:
|
||||||
|
rows.append(
|
||||||
|
(pos, KEYED_EDU_TYPE, ujson.dumps({
|
||||||
|
"key": edu_key,
|
||||||
|
"edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
|
||||||
|
}))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch changed edus
|
||||||
|
keys = self.edus.keys()
|
||||||
|
i = keys.bisect_right(token)
|
||||||
|
edus = set((k, self.edus[k]) for k in keys[i:])
|
||||||
|
|
||||||
|
for (pos, edu) in edus:
|
||||||
|
rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
|
||||||
|
|
||||||
|
# Fetch changed failures
|
||||||
|
keys = self.failures.keys()
|
||||||
|
i = keys.bisect_right(token)
|
||||||
|
failures = set((k, self.failures[k]) for k in keys[i:])
|
||||||
|
|
||||||
|
for (pos, (destination, failure)) in failures:
|
||||||
|
rows.append((pos, FAILURE_TYPE, ujson.dumps({
|
||||||
|
"destination": destination,
|
||||||
|
"failure": failure,
|
||||||
|
})))
|
||||||
|
|
||||||
|
# Fetch changed device messages
|
||||||
|
keys = self.device_messages.keys()
|
||||||
|
i = keys.bisect_right(token)
|
||||||
|
device_messages = set((k, self.device_messages[k]) for k in keys[i:])
|
||||||
|
|
||||||
|
for (pos, destination) in device_messages:
|
||||||
|
rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
|
||||||
|
"destination": destination,
|
||||||
|
})))
|
||||||
|
|
||||||
|
# Sort rows based on pos
|
||||||
|
rows.sort()
|
||||||
|
|
||||||
|
return rows
|
@ -19,6 +19,7 @@ from twisted.internet import defer
|
|||||||
from .persistence import TransactionActions
|
from .persistence import TransactionActions
|
||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
@ -26,6 +27,7 @@ from synapse.util.retryutils import (
|
|||||||
get_retry_limiter, NotRetryingDestination,
|
get_retry_limiter, NotRetryingDestination,
|
||||||
)
|
)
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.handlers.presence import format_user_presence_state
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
|
||||||
|
sent_pdus_destination_dist = client_metrics.register_distribution(
|
||||||
|
"sent_pdu_destinations"
|
||||||
|
)
|
||||||
|
sent_edus_counter = client_metrics.register_counter("sent_edus")
|
||||||
|
|
||||||
|
|
||||||
class TransactionQueue(object):
|
class TransactionQueue(object):
|
||||||
"""This class makes sure we only have one transaction in flight at
|
"""This class makes sure we only have one transaction in flight at
|
||||||
@ -44,13 +52,14 @@ class TransactionQueue(object):
|
|||||||
It batches pending PDUs into single transactions.
|
It batches pending PDUs into single transactions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, transport_layer):
|
def __init__(self, hs):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
self.transaction_actions = TransactionActions(self.store)
|
self.transaction_actions = TransactionActions(self.store)
|
||||||
|
|
||||||
self.transport_layer = transport_layer
|
self.transport_layer = hs.get_federation_transport_client()
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@ -95,6 +104,11 @@ class TransactionQueue(object):
|
|||||||
# HACK to get unique tx id
|
# HACK to get unique tx id
|
||||||
self._next_txn_id = int(self.clock.time_msec())
|
self._next_txn_id = int(self.clock.time_msec())
|
||||||
|
|
||||||
|
self._order = 1
|
||||||
|
|
||||||
|
self._is_processing = False
|
||||||
|
self._last_poked_id = -1
|
||||||
|
|
||||||
def can_send_to(self, destination):
|
def can_send_to(self, destination):
|
||||||
"""Can we send messages to the given server?
|
"""Can we send messages to the given server?
|
||||||
|
|
||||||
@ -115,11 +129,61 @@ class TransactionQueue(object):
|
|||||||
else:
|
else:
|
||||||
return not destination.startswith("localhost")
|
return not destination.startswith("localhost")
|
||||||
|
|
||||||
def enqueue_pdu(self, pdu, destinations, order):
|
@defer.inlineCallbacks
|
||||||
|
def notify_new_events(self, current_id):
|
||||||
|
"""This gets called when we have some new events we might want to
|
||||||
|
send out to other servers.
|
||||||
|
"""
|
||||||
|
self._last_poked_id = max(current_id, self._last_poked_id)
|
||||||
|
|
||||||
|
if self._is_processing:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._is_processing = True
|
||||||
|
while True:
|
||||||
|
last_token = yield self.store.get_federation_out_pos("events")
|
||||||
|
next_token, events = yield self.store.get_all_new_events_stream(
|
||||||
|
last_token, self._last_poked_id, limit=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Handling %s -> %s", last_token, next_token)
|
||||||
|
|
||||||
|
if not events and next_token >= self._last_poked_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
users_in_room = yield self.state.get_current_user_in_room(
|
||||||
|
event.room_id, latest_event_ids=[event.event_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
destinations = set(
|
||||||
|
get_domain_from_id(user_id) for user_id in users_in_room
|
||||||
|
)
|
||||||
|
|
||||||
|
if event.type == EventTypes.Member:
|
||||||
|
if event.content["membership"] == Membership.JOIN:
|
||||||
|
destinations.add(get_domain_from_id(event.state_key))
|
||||||
|
|
||||||
|
logger.debug("Sending %s to %r", event, destinations)
|
||||||
|
|
||||||
|
self._send_pdu(event, destinations)
|
||||||
|
|
||||||
|
yield self.store.update_federation_out_pos(
|
||||||
|
"events", next_token
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._is_processing = False
|
||||||
|
|
||||||
|
def _send_pdu(self, pdu, destinations):
|
||||||
# We loop through all destinations to see whether we already have
|
# We loop through all destinations to see whether we already have
|
||||||
# a transaction in progress. If we do, stick it in the pending_pdus
|
# a transaction in progress. If we do, stick it in the pending_pdus
|
||||||
# table and we'll get back to it later.
|
# table and we'll get back to it later.
|
||||||
|
|
||||||
|
order = self._order
|
||||||
|
self._order += 1
|
||||||
|
|
||||||
destinations = set(destinations)
|
destinations = set(destinations)
|
||||||
destinations = set(
|
destinations = set(
|
||||||
dest for dest in destinations if self.can_send_to(dest)
|
dest for dest in destinations if self.can_send_to(dest)
|
||||||
@ -130,6 +194,8 @@ class TransactionQueue(object):
|
|||||||
if not destinations:
|
if not destinations:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
sent_pdus_destination_dist.inc_by(len(destinations))
|
||||||
|
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
||||||
(pdu, order)
|
(pdu, order)
|
||||||
@ -139,7 +205,10 @@ class TransactionQueue(object):
|
|||||||
self._attempt_new_transaction, destination
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_presence(self, destination, states):
|
def send_presence(self, destination, states):
|
||||||
|
if not self.can_send_to(destination):
|
||||||
|
return
|
||||||
|
|
||||||
self.pending_presence_by_dest.setdefault(destination, {}).update({
|
self.pending_presence_by_dest.setdefault(destination, {}).update({
|
||||||
state.user_id: state for state in states
|
state.user_id: state for state in states
|
||||||
})
|
})
|
||||||
@ -148,12 +217,19 @@ class TransactionQueue(object):
|
|||||||
self._attempt_new_transaction, destination
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_edu(self, edu, key=None):
|
def send_edu(self, destination, edu_type, content, key=None):
|
||||||
destination = edu.destination
|
edu = Edu(
|
||||||
|
origin=self.server_name,
|
||||||
|
destination=destination,
|
||||||
|
edu_type=edu_type,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
if not self.can_send_to(destination):
|
if not self.can_send_to(destination):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
sent_edus_counter.inc()
|
||||||
|
|
||||||
if key:
|
if key:
|
||||||
self.pending_edus_keyed_by_dest.setdefault(
|
self.pending_edus_keyed_by_dest.setdefault(
|
||||||
destination, {}
|
destination, {}
|
||||||
@ -165,7 +241,7 @@ class TransactionQueue(object):
|
|||||||
self._attempt_new_transaction, destination
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_failure(self, failure, destination):
|
def send_failure(self, failure, destination):
|
||||||
if destination == self.server_name or destination == "localhost":
|
if destination == self.server_name or destination == "localhost":
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -180,7 +256,7 @@ class TransactionQueue(object):
|
|||||||
self._attempt_new_transaction, destination
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
if destination == self.server_name or destination == "localhost":
|
if destination == self.server_name or destination == "localhost":
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -191,6 +267,9 @@ class TransactionQueue(object):
|
|||||||
self._attempt_new_transaction, destination
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_current_token(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _attempt_new_transaction(self, destination):
|
def _attempt_new_transaction(self, destination):
|
||||||
# list of (pending_pdu, deferred, order)
|
# list of (pending_pdu, deferred, order)
|
||||||
@ -383,6 +462,13 @@ class TransactionQueue(object):
|
|||||||
code = e.code
|
code = e.code
|
||||||
response = e.response
|
response = e.response
|
||||||
|
|
||||||
|
if e.code == 429 or 500 <= e.code:
|
||||||
|
logger.info(
|
||||||
|
"TX [%s] {%s} got %d response",
|
||||||
|
destination, txn_id, code
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"TX [%s] {%s} got %d response",
|
"TX [%s] {%s} got %d response",
|
||||||
destination, txn_id, code
|
destination, txn_id, code
|
||||||
|
@ -249,10 +249,15 @@ class TransportLayerClient(object):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_public_rooms(self, remote_server, limit, since_token,
|
def get_public_rooms(self, remote_server, limit, since_token,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None):
|
||||||
path = PREFIX + "/publicRooms"
|
path = PREFIX + "/publicRooms"
|
||||||
|
|
||||||
args = {}
|
args = {
|
||||||
|
"include_all_networks": "true" if include_all_networks else "false",
|
||||||
|
}
|
||||||
|
if third_party_instance_id:
|
||||||
|
args["third_party_instance_id"] = third_party_instance_id,
|
||||||
if limit:
|
if limit:
|
||||||
args["limit"] = [str(limit)]
|
args["limit"] = [str(limit)]
|
||||||
if since_token:
|
if since_token:
|
||||||
|
@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError
|
|||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
||||||
|
parse_boolean_from_args,
|
||||||
)
|
)
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet):
|
|||||||
def on_GET(self, origin, content, query):
|
def on_GET(self, origin, content, query):
|
||||||
limit = parse_integer_from_args(query, "limit", 0)
|
limit = parse_integer_from_args(query, "limit", 0)
|
||||||
since_token = parse_string_from_args(query, "since", None)
|
since_token = parse_string_from_args(query, "since", None)
|
||||||
|
include_all_networks = parse_boolean_from_args(
|
||||||
|
query, "include_all_networks", False
|
||||||
|
)
|
||||||
|
third_party_instance_id = parse_string_from_args(
|
||||||
|
query, "third_party_instance_id", None
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_all_networks:
|
||||||
|
network_tuple = None
|
||||||
|
elif third_party_instance_id:
|
||||||
|
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
|
||||||
|
else:
|
||||||
|
network_tuple = ThirdPartyInstanceID(None, None)
|
||||||
|
|
||||||
data = yield self.room_list_handler.get_local_public_room_list(
|
data = yield self.room_list_handler.get_local_public_room_list(
|
||||||
limit, since_token
|
limit, since_token,
|
||||||
|
network_tuple=network_tuple
|
||||||
)
|
)
|
||||||
defer.returnValue((200, data))
|
defer.returnValue((200, data))
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ from .profile import ProfileHandler
|
|||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
from .admin import AdminHandler
|
from .admin import AdminHandler
|
||||||
from .identity import IdentityHandler
|
from .identity import IdentityHandler
|
||||||
from .receipts import ReceiptsHandler
|
|
||||||
from .search import SearchHandler
|
from .search import SearchHandler
|
||||||
|
|
||||||
|
|
||||||
@ -56,7 +55,6 @@ class Handlers(object):
|
|||||||
self.profile_handler = ProfileHandler(hs)
|
self.profile_handler = ProfileHandler(hs)
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
self.admin_handler = AdminHandler(hs)
|
self.admin_handler = AdminHandler(hs)
|
||||||
self.receipts_handler = ReceiptsHandler(hs)
|
|
||||||
self.identity_handler = IdentityHandler(hs)
|
self.identity_handler = IdentityHandler(hs)
|
||||||
self.search_handler = SearchHandler(hs)
|
self.search_handler = SearchHandler(hs)
|
||||||
self.room_context_handler = RoomContextHandler(hs)
|
self.room_context_handler = RoomContextHandler(hs)
|
||||||
|
@ -61,6 +61,8 @@ class AuthHandler(BaseHandler):
|
|||||||
for module, config in hs.config.password_providers
|
for module, config in hs.config.password_providers
|
||||||
]
|
]
|
||||||
|
|
||||||
|
logger.info("Extra password_providers: %r", self.password_providers)
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@ -160,7 +162,15 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
for f in flows:
|
for f in flows:
|
||||||
if len(set(f) - set(creds.keys())) == 0:
|
if len(set(f) - set(creds.keys())) == 0:
|
||||||
logger.info("Auth completed with creds: %r", creds)
|
# it's very useful to know what args are stored, but this can
|
||||||
|
# include the password in the case of registering, so only log
|
||||||
|
# the keys (confusingly, clientdict may contain a password
|
||||||
|
# param, creds is just what the user authed as for UI auth
|
||||||
|
# and is not sensitive).
|
||||||
|
logger.info(
|
||||||
|
"Auth completed with creds: %r. Client dict has keys: %r",
|
||||||
|
creds, clientdict.keys()
|
||||||
|
)
|
||||||
defer.returnValue((True, creds, clientdict, session['id']))
|
defer.returnValue((True, creds, clientdict, session['id']))
|
||||||
|
|
||||||
ret = self._auth_dict_for_flows(flows, session)
|
ret = self._auth_dict_for_flows(flows, session)
|
||||||
@ -378,12 +388,10 @@ class AuthHandler(BaseHandler):
|
|||||||
return self._check_password(user_id, password)
|
return self._check_password(user_id, password)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
def get_access_token_for_user_id(self, user_id, device_id=None,
|
||||||
initial_display_name=None):
|
initial_display_name=None):
|
||||||
"""
|
"""
|
||||||
Gets login tuple for the user with the given user ID.
|
Creates a new access token for the user with the given user ID.
|
||||||
|
|
||||||
Creates a new access/refresh token for the user.
|
|
||||||
|
|
||||||
The user is assumed to have been authenticated by some other
|
The user is assumed to have been authenticated by some other
|
||||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||||
@ -398,16 +406,13 @@ class AuthHandler(BaseHandler):
|
|||||||
initial_display_name (str): display name to associate with the
|
initial_display_name (str): display name to associate with the
|
||||||
device if it needs re-registering
|
device if it needs re-registering
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
The refresh token for the user's session.
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem storing the token.
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||||
access_token = yield self.issue_access_token(user_id, device_id)
|
access_token = yield self.issue_access_token(user_id, device_id)
|
||||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
|
||||||
|
|
||||||
# the device *should* have been registered before we got here; however,
|
# the device *should* have been registered before we got here; however,
|
||||||
# it's possible we raced against a DELETE operation. The thing we
|
# it's possible we raced against a DELETE operation. The thing we
|
||||||
@ -418,7 +423,7 @@ class AuthHandler(BaseHandler):
|
|||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((access_token, refresh_token))
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_user_exists(self, user_id):
|
def check_user_exists(self, user_id):
|
||||||
@ -529,35 +534,19 @@ class AuthHandler(BaseHandler):
|
|||||||
device_id)
|
device_id)
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def generate_access_token(self, user_id, extra_caveats=None):
|
||||||
def issue_refresh_token(self, user_id, device_id=None):
|
|
||||||
refresh_token = self.generate_refresh_token(user_id)
|
|
||||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
|
|
||||||
device_id)
|
|
||||||
defer.returnValue(refresh_token)
|
|
||||||
|
|
||||||
def generate_access_token(self, user_id, extra_caveats=None,
|
|
||||||
duration_in_ms=(60 * 60 * 1000)):
|
|
||||||
extra_caveats = extra_caveats or []
|
extra_caveats = extra_caveats or []
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
now = self.hs.get_clock().time_msec()
|
# Include a nonce, to make sure that each login gets a different
|
||||||
expiry = now + duration_in_ms
|
# access token.
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("nonce = %s" % (
|
||||||
|
stringutils.random_string_with_symbols(16),
|
||||||
|
))
|
||||||
for caveat in extra_caveats:
|
for caveat in extra_caveats:
|
||||||
macaroon.add_first_party_caveat(caveat)
|
macaroon.add_first_party_caveat(caveat)
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def generate_refresh_token(self, user_id):
|
|
||||||
m = self._generate_base_macaroon(user_id)
|
|
||||||
m.add_first_party_caveat("type = refresh")
|
|
||||||
# Important to add a nonce, because otherwise every refresh token for a
|
|
||||||
# user will be the same.
|
|
||||||
m.add_first_party_caveat("nonce = %s" % (
|
|
||||||
stringutils.random_string_with_symbols(16),
|
|
||||||
))
|
|
||||||
return m.serialize()
|
|
||||||
|
|
||||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = login")
|
macaroon.add_first_party_caveat("type = login")
|
||||||
|
@ -34,9 +34,9 @@ class DeviceMessageHandler(object):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_federation_sender()
|
||||||
|
|
||||||
self.federation.register_edu_handler(
|
hs.get_replication_layer().register_edu_handler(
|
||||||
"m.direct_to_device", self.on_direct_to_device_edu
|
"m.direct_to_device", self.on_direct_to_device_edu
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -339,3 +339,22 @@ class DirectoryHandler(BaseHandler):
|
|||||||
yield self.auth.check_can_change_room_list(room_id, requester.user)
|
yield self.auth.check_can_change_room_list(room_id, requester.user)
|
||||||
|
|
||||||
yield self.store.set_room_is_public(room_id, visibility == "public")
|
yield self.store.set_room_is_public(room_id, visibility == "public")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def edit_published_appservice_room_list(self, appservice_id, network_id,
|
||||||
|
room_id, visibility):
|
||||||
|
"""Add or remove a room from the appservice/network specific public
|
||||||
|
room list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
appservice_id (str): ID of the appservice that owns the list
|
||||||
|
network_id (str): The ID of the network the list is associated with
|
||||||
|
room_id (str)
|
||||||
|
visibility (str): either "public" or "private"
|
||||||
|
"""
|
||||||
|
if visibility not in ["public", "private"]:
|
||||||
|
raise SynapseError(400, "Invalid visibility setting")
|
||||||
|
|
||||||
|
yield self.store.set_room_is_public_appservice(
|
||||||
|
room_id, appservice_id, network_id, visibility == "public"
|
||||||
|
)
|
||||||
|
@ -111,6 +111,11 @@ class E2eKeysHandler(object):
|
|||||||
failures[destination] = {
|
failures[destination] = {
|
||||||
"status": 503, "message": "Not ready for retry",
|
"status": 503, "message": "Not ready for retry",
|
||||||
}
|
}
|
||||||
|
except Exception as e:
|
||||||
|
# include ConnectionRefused and other errors
|
||||||
|
failures[destination] = {
|
||||||
|
"status": 503, "message": e.message
|
||||||
|
}
|
||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults([
|
yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
preserve_fn(do_remote_query)(destination)
|
preserve_fn(do_remote_query)(destination)
|
||||||
@ -222,6 +227,11 @@ class E2eKeysHandler(object):
|
|||||||
failures[destination] = {
|
failures[destination] = {
|
||||||
"status": 503, "message": "Not ready for retry",
|
"status": 503, "message": "Not ready for retry",
|
||||||
}
|
}
|
||||||
|
except Exception as e:
|
||||||
|
# include ConnectionRefused and other errors
|
||||||
|
failures[destination] = {
|
||||||
|
"status": 503, "message": e.message
|
||||||
|
}
|
||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults([
|
yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
preserve_fn(claim_client_keys)(destination)
|
preserve_fn(claim_client_keys)(destination)
|
||||||
|
@ -80,22 +80,6 @@ class FederationHandler(BaseHandler):
|
|||||||
# When joining a room we need to queue any events for that room up
|
# When joining a room we need to queue any events for that room up
|
||||||
self.room_queues = {}
|
self.room_queues = {}
|
||||||
|
|
||||||
def handle_new_event(self, event, destinations):
|
|
||||||
""" Takes in an event from the client to server side, that has already
|
|
||||||
been authed and handled by the state module, and sends it to any
|
|
||||||
remote home servers that may be interested.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The event to send
|
|
||||||
destinations: A list of destinations to send it to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred: Resolved when it has successfully been queued for
|
|
||||||
processing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.replication_layer.send_pdu(event, destinations)
|
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
|
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
|
||||||
@ -268,9 +252,12 @@ class FederationHandler(BaseHandler):
|
|||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Parses mapping `event_id -> (type, state_key) -> state event_id`
|
||||||
|
# to get all state ids that we're interested in.
|
||||||
event_map = yield self.store.get_events([
|
event_map = yield self.store.get_events([
|
||||||
e_id for key_to_eid in event_to_state_ids.values()
|
e_id
|
||||||
for key, e_id in key_to_eid
|
for key_to_eid in event_to_state_ids.values()
|
||||||
|
for key, e_id in key_to_eid.items()
|
||||||
if key[0] != EventTypes.Member or check_match(key[1])
|
if key[0] != EventTypes.Member or check_match(key[1])
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -830,25 +817,6 @@ class FederationHandler(BaseHandler):
|
|||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield user_joined_room(self.distributor, user, event.room_id)
|
||||||
|
|
||||||
new_pdu = event
|
|
||||||
|
|
||||||
users_in_room = yield self.store.get_joined_users_from_context(event, context)
|
|
||||||
|
|
||||||
destinations = set(
|
|
||||||
get_domain_from_id(user_id) for user_id in users_in_room
|
|
||||||
if not self.hs.is_mine_id(user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
destinations.discard(origin)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"on_send_join_request: Sending event: %s, signatures: %s",
|
|
||||||
event.event_id,
|
|
||||||
event.signatures,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
|
||||||
|
|
||||||
state_ids = context.prev_state_ids.values()
|
state_ids = context.prev_state_ids.values()
|
||||||
auth_chain = yield self.store.get_auth_chain(set(
|
auth_chain = yield self.store.get_auth_chain(set(
|
||||||
[event.event_id] + state_ids
|
[event.event_id] + state_ids
|
||||||
@ -1055,24 +1023,6 @@ class FederationHandler(BaseHandler):
|
|||||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
new_pdu = event
|
|
||||||
|
|
||||||
users_in_room = yield self.store.get_joined_users_from_context(event, context)
|
|
||||||
|
|
||||||
destinations = set(
|
|
||||||
get_domain_from_id(user_id) for user_id in users_in_room
|
|
||||||
if not self.hs.is_mine_id(user_id)
|
|
||||||
)
|
|
||||||
destinations.discard(origin)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"on_send_leave_request: Sending event: %s, signatures: %s",
|
|
||||||
event.event_id,
|
|
||||||
event.signatures,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
|
||||||
|
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -372,11 +372,12 @@ class InitialSyncHandler(BaseHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_receipts():
|
def get_receipts():
|
||||||
receipts_handler = self.hs.get_handlers().receipts_handler
|
receipts = yield self.store.get_linearized_receipts_for_room(
|
||||||
receipts = yield receipts_handler.get_receipts_for_room(
|
|
||||||
room_id,
|
room_id,
|
||||||
now_token.receipt_key
|
to_key=now_token.receipt_key,
|
||||||
)
|
)
|
||||||
|
if not receipts:
|
||||||
|
receipts = []
|
||||||
defer.returnValue(receipts)
|
defer.returnValue(receipts)
|
||||||
|
|
||||||
presence, receipts, (messages, token) = yield defer.gatherResults(
|
presence, receipts, (messages, token) = yield defer.gatherResults(
|
||||||
|
@ -22,9 +22,9 @@ from synapse.events.utils import serialize_event
|
|||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.push.action_generator import ActionGenerator
|
from synapse.push.action_generator import ActionGenerator
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
UserID, RoomAlias, RoomStreamToken, get_domain_from_id
|
UserID, RoomAlias, RoomStreamToken,
|
||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor, ReadWriteLock
|
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
@ -50,6 +50,10 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
self.pagination_lock = ReadWriteLock()
|
self.pagination_lock = ReadWriteLock()
|
||||||
|
|
||||||
|
# We arbitrarily limit concurrent event creation for a room to 5.
|
||||||
|
# This is to stop us from diverging history *too* much.
|
||||||
|
self.limiter = Limiter(max_count=5)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def purge_history(self, room_id, event_id):
|
def purge_history(self, room_id, event_id):
|
||||||
event = yield self.store.get_event(event_id)
|
event = yield self.store.get_event(event_id)
|
||||||
@ -191,6 +195,7 @@ class MessageHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
builder = self.event_builder_factory.new(event_dict)
|
builder = self.event_builder_factory.new(event_dict)
|
||||||
|
|
||||||
|
with (yield self.limiter.queue(builder.room_id)):
|
||||||
self.validator.validate_new(builder)
|
self.validator.validate_new(builder)
|
||||||
|
|
||||||
if builder.type == EventTypes.Member:
|
if builder.type == EventTypes.Member:
|
||||||
@ -221,6 +226,7 @@ class MessageHandler(BaseHandler):
|
|||||||
builder=builder,
|
builder=builder,
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((event, context))
|
defer.returnValue((event, context))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -599,13 +605,6 @@ class MessageHandler(BaseHandler):
|
|||||||
event_stream_id, max_stream_id
|
event_stream_id, max_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
users_in_room = yield self.store.get_joined_users_from_context(event, context)
|
|
||||||
|
|
||||||
destinations = [
|
|
||||||
get_domain_from_id(user_id) for user_id in users_in_room
|
|
||||||
if not self.hs.is_mine_id(user_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _notify():
|
def _notify():
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
@ -618,7 +617,3 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
# If invite, remove room_state from unsigned before sending.
|
# If invite, remove room_state from unsigned before sending.
|
||||||
event.unsigned.pop("invite_room_state", None)
|
event.unsigned.pop("invite_room_state", None)
|
||||||
|
|
||||||
preserve_fn(federation_handler.handle_new_event)(
|
|
||||||
event, destinations=destinations,
|
|
||||||
)
|
|
||||||
|
@ -91,28 +91,29 @@ class PresenceHandler(object):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.wheel_timer = WheelTimer()
|
self.wheel_timer = WheelTimer()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.federation = hs.get_replication_layer()
|
self.replication = hs.get_replication_layer()
|
||||||
|
self.federation = hs.get_federation_sender()
|
||||||
|
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
self.federation.register_edu_handler(
|
self.replication.register_edu_handler(
|
||||||
"m.presence", self.incoming_presence
|
"m.presence", self.incoming_presence
|
||||||
)
|
)
|
||||||
self.federation.register_edu_handler(
|
self.replication.register_edu_handler(
|
||||||
"m.presence_invite",
|
"m.presence_invite",
|
||||||
lambda origin, content: self.invite_presence(
|
lambda origin, content: self.invite_presence(
|
||||||
observed_user=UserID.from_string(content["observed_user"]),
|
observed_user=UserID.from_string(content["observed_user"]),
|
||||||
observer_user=UserID.from_string(content["observer_user"]),
|
observer_user=UserID.from_string(content["observer_user"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.federation.register_edu_handler(
|
self.replication.register_edu_handler(
|
||||||
"m.presence_accept",
|
"m.presence_accept",
|
||||||
lambda origin, content: self.accept_presence(
|
lambda origin, content: self.accept_presence(
|
||||||
observed_user=UserID.from_string(content["observed_user"]),
|
observed_user=UserID.from_string(content["observed_user"]),
|
||||||
observer_user=UserID.from_string(content["observer_user"]),
|
observer_user=UserID.from_string(content["observer_user"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.federation.register_edu_handler(
|
self.replication.register_edu_handler(
|
||||||
"m.presence_deny",
|
"m.presence_deny",
|
||||||
lambda origin, content: self.deny_presence(
|
lambda origin, content: self.deny_presence(
|
||||||
observed_user=UserID.from_string(content["observed_user"]),
|
observed_user=UserID.from_string(content["observed_user"]),
|
||||||
|
@ -33,8 +33,8 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
self.server_name = hs.config.server_name
|
self.server_name = hs.config.server_name
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_federation_sender()
|
||||||
self.federation.register_edu_handler(
|
hs.get_replication_layer().register_edu_handler(
|
||||||
"m.receipt", self._received_remote_receipt
|
"m.receipt", self._received_remote_receipt
|
||||||
)
|
)
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
@ -100,7 +100,7 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
|
|
||||||
if not res:
|
if not res:
|
||||||
# res will be None if this read receipt is 'old'
|
# res will be None if this read receipt is 'old'
|
||||||
defer.returnValue(False)
|
continue
|
||||||
|
|
||||||
stream_id, max_persisted_id = res
|
stream_id, max_persisted_id = res
|
||||||
|
|
||||||
@ -109,6 +109,10 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
if max_batch_id is None or max_persisted_id > max_batch_id:
|
if max_batch_id is None or max_persisted_id > max_batch_id:
|
||||||
max_batch_id = max_persisted_id
|
max_batch_id = max_persisted_id
|
||||||
|
|
||||||
|
if min_batch_id is None:
|
||||||
|
# no new receipts
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
affected_room_ids = list(set([r["room_id"] for r in receipts]))
|
affected_room_ids = list(set([r["room_id"] for r in receipts]))
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
|
@ -81,7 +81,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
"User ID already taken.",
|
"User ID already taken.",
|
||||||
errcode=Codes.USER_IN_USE,
|
errcode=Codes.USER_IN_USE,
|
||||||
)
|
)
|
||||||
user_data = yield self.auth.get_user_from_macaroon(guest_access_token)
|
user_data = yield self.auth.get_user_by_access_token(guest_access_token)
|
||||||
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
|
def get_or_create_user(self, requester, localpart, displayname,
|
||||||
password_hash=None):
|
password_hash=None):
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
token = self.auth_handler().generate_access_token(
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
user_id, None, duration_in_ms)
|
|
||||||
|
|
||||||
if need_register:
|
if need_register:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
|
@ -44,16 +44,19 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
"join_rules": JoinRules.INVITE,
|
"join_rules": JoinRules.INVITE,
|
||||||
"history_visibility": "shared",
|
"history_visibility": "shared",
|
||||||
"original_invitees_have_ops": False,
|
"original_invitees_have_ops": False,
|
||||||
|
"guest_can_join": True,
|
||||||
},
|
},
|
||||||
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
|
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
|
||||||
"join_rules": JoinRules.INVITE,
|
"join_rules": JoinRules.INVITE,
|
||||||
"history_visibility": "shared",
|
"history_visibility": "shared",
|
||||||
"original_invitees_have_ops": True,
|
"original_invitees_have_ops": True,
|
||||||
|
"guest_can_join": True,
|
||||||
},
|
},
|
||||||
RoomCreationPreset.PUBLIC_CHAT: {
|
RoomCreationPreset.PUBLIC_CHAT: {
|
||||||
"join_rules": JoinRules.PUBLIC,
|
"join_rules": JoinRules.PUBLIC,
|
||||||
"history_visibility": "shared",
|
"history_visibility": "shared",
|
||||||
"original_invitees_have_ops": False,
|
"original_invitees_have_ops": False,
|
||||||
|
"guest_can_join": False,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -336,6 +339,13 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
content={"history_visibility": config["history_visibility"]}
|
content={"history_visibility": config["history_visibility"]}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config["guest_can_join"]:
|
||||||
|
if (EventTypes.GuestAccess, '') not in initial_state:
|
||||||
|
yield send(
|
||||||
|
etype=EventTypes.GuestAccess,
|
||||||
|
content={"guest_access": "can_join"}
|
||||||
|
)
|
||||||
|
|
||||||
for (etype, state_key), content in initial_state.items():
|
for (etype, state_key), content in initial_state.items():
|
||||||
yield send(
|
yield send(
|
||||||
etype=etype,
|
etype=etype,
|
||||||
|
@ -22,6 +22,7 @@ from synapse.api.constants import (
|
|||||||
)
|
)
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from unpaddedbase64 import encode_base64, decode_base64
|
from unpaddedbase64 import encode_base64, decode_base64
|
||||||
@ -34,6 +35,10 @@ logger = logging.getLogger(__name__)
|
|||||||
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
|
# This is used to indicate we should only return rooms published to the main list.
|
||||||
|
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||||
|
|
||||||
|
|
||||||
class RoomListHandler(BaseHandler):
|
class RoomListHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomListHandler, self).__init__(hs)
|
super(RoomListHandler, self).__init__(hs)
|
||||||
@ -41,22 +46,43 @@ class RoomListHandler(BaseHandler):
|
|||||||
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
|
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
|
||||||
|
|
||||||
def get_local_public_room_list(self, limit=None, since_token=None,
|
def get_local_public_room_list(self, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None,
|
||||||
if search_filter:
|
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||||
# We explicitly don't bother caching searches.
|
"""Generate a local public room list.
|
||||||
return self._get_public_room_list(limit, since_token, search_filter)
|
|
||||||
|
There are multiple different lists: the main one plus one per third
|
||||||
|
party network. A client can ask for a specific list or to return all.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit (int)
|
||||||
|
since_token (str)
|
||||||
|
search_filter (dict)
|
||||||
|
network_tuple (ThirdPartyInstanceID): Which public list to use.
|
||||||
|
This can be (None, None) to indicate the main list, or a particular
|
||||||
|
appservice and network id to use an appservice specific one.
|
||||||
|
Setting to None returns all public rooms across all lists.
|
||||||
|
"""
|
||||||
|
if search_filter or (network_tuple and network_tuple.appservice_id is not None):
|
||||||
|
# We explicitly don't bother caching searches or requests for
|
||||||
|
# appservice specific lists.
|
||||||
|
return self._get_public_room_list(
|
||||||
|
limit, since_token, search_filter, network_tuple=network_tuple,
|
||||||
|
)
|
||||||
|
|
||||||
result = self.response_cache.get((limit, since_token))
|
result = self.response_cache.get((limit, since_token))
|
||||||
if not result:
|
if not result:
|
||||||
result = self.response_cache.set(
|
result = self.response_cache.set(
|
||||||
(limit, since_token),
|
(limit, since_token),
|
||||||
self._get_public_room_list(limit, since_token)
|
self._get_public_room_list(
|
||||||
|
limit, since_token, network_tuple=network_tuple
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_public_room_list(self, limit=None, since_token=None,
|
def _get_public_room_list(self, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None,
|
||||||
|
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||||
if since_token and since_token != "END":
|
if since_token and since_token != "END":
|
||||||
since_token = RoomListNextBatch.from_token(since_token)
|
since_token = RoomListNextBatch.from_token(since_token)
|
||||||
else:
|
else:
|
||||||
@ -73,14 +99,15 @@ class RoomListHandler(BaseHandler):
|
|||||||
current_public_id = yield self.store.get_current_public_room_stream_id()
|
current_public_id = yield self.store.get_current_public_room_stream_id()
|
||||||
public_room_stream_id = since_token.public_room_stream_id
|
public_room_stream_id = since_token.public_room_stream_id
|
||||||
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
|
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
|
||||||
public_room_stream_id, current_public_id
|
public_room_stream_id, current_public_id,
|
||||||
|
network_tuple=network_tuple,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
stream_token = yield self.store.get_room_max_stream_ordering()
|
stream_token = yield self.store.get_room_max_stream_ordering()
|
||||||
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
|
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
|
||||||
|
|
||||||
room_ids = yield self.store.get_public_room_ids_at_stream_id(
|
room_ids = yield self.store.get_public_room_ids_at_stream_id(
|
||||||
public_room_stream_id
|
public_room_stream_id, network_tuple=network_tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We want to return rooms in a particular order: the number of joined
|
# We want to return rooms in a particular order: the number of joined
|
||||||
@ -311,7 +338,8 @@ class RoomListHandler(BaseHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None,):
|
||||||
if search_filter:
|
if search_filter:
|
||||||
# We currently don't support searching across federation, so we have
|
# We currently don't support searching across federation, so we have
|
||||||
# to do it manually without pagination
|
# to do it manually without pagination
|
||||||
@ -320,6 +348,8 @@ class RoomListHandler(BaseHandler):
|
|||||||
|
|
||||||
res = yield self._get_remote_list_cached(
|
res = yield self._get_remote_list_cached(
|
||||||
server_name, limit=limit, since_token=since_token,
|
server_name, limit=limit, since_token=since_token,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if search_filter:
|
if search_filter:
|
||||||
@ -332,22 +362,30 @@ class RoomListHandler(BaseHandler):
|
|||||||
defer.returnValue(res)
|
defer.returnValue(res)
|
||||||
|
|
||||||
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
|
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None,):
|
||||||
repl_layer = self.hs.get_replication_layer()
|
repl_layer = self.hs.get_replication_layer()
|
||||||
if search_filter:
|
if search_filter:
|
||||||
# We can't cache when asking for search
|
# We can't cache when asking for search
|
||||||
return repl_layer.get_public_rooms(
|
return repl_layer.get_public_rooms(
|
||||||
server_name, limit=limit, since_token=since_token,
|
server_name, limit=limit, since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter, include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.remote_response_cache.get((server_name, limit, since_token))
|
key = (
|
||||||
|
server_name, limit, since_token, include_all_networks,
|
||||||
|
third_party_instance_id,
|
||||||
|
)
|
||||||
|
result = self.remote_response_cache.get(key)
|
||||||
if not result:
|
if not result:
|
||||||
result = self.remote_response_cache.set(
|
result = self.remote_response_cache.set(
|
||||||
(server_name, limit, since_token),
|
key,
|
||||||
repl_layer.get_public_rooms(
|
repl_layer.get_public_rooms(
|
||||||
server_name, limit=limit, since_token=since_token,
|
server_name, limit=limit, since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
@ -277,6 +277,7 @@ class SyncHandler(object):
|
|||||||
"""
|
"""
|
||||||
with Measure(self.clock, "load_filtered_recents"):
|
with Measure(self.clock, "load_filtered_recents"):
|
||||||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||||
|
block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
|
||||||
|
|
||||||
if recents is None or newly_joined_room or timeline_limit < len(recents):
|
if recents is None or newly_joined_room or timeline_limit < len(recents):
|
||||||
limited = True
|
limited = True
|
||||||
@ -293,7 +294,7 @@ class SyncHandler(object):
|
|||||||
else:
|
else:
|
||||||
recents = []
|
recents = []
|
||||||
|
|
||||||
if not limited:
|
if not limited or block_all_timeline:
|
||||||
defer.returnValue(TimelineBatch(
|
defer.returnValue(TimelineBatch(
|
||||||
events=recents,
|
events=recents,
|
||||||
prev_batch=now_token,
|
prev_batch=now_token,
|
||||||
@ -509,6 +510,7 @@ class SyncHandler(object):
|
|||||||
Returns:
|
Returns:
|
||||||
Deferred(SyncResult)
|
Deferred(SyncResult)
|
||||||
"""
|
"""
|
||||||
|
logger.info("Calculating sync response for %r", sync_config.user)
|
||||||
|
|
||||||
# NB: The now_token gets changed by some of the generate_sync_* methods,
|
# NB: The now_token gets changed by some of the generate_sync_* methods,
|
||||||
# this is due to some of the underlying streams not supporting the ability
|
# this is due to some of the underlying streams not supporting the ability
|
||||||
@ -531,6 +533,11 @@ class SyncHandler(object):
|
|||||||
)
|
)
|
||||||
newly_joined_rooms, newly_joined_users = res
|
newly_joined_rooms, newly_joined_users = res
|
||||||
|
|
||||||
|
block_all_presence_data = (
|
||||||
|
since_token is None and
|
||||||
|
sync_config.filter_collection.blocks_all_presence()
|
||||||
|
)
|
||||||
|
if not block_all_presence_data:
|
||||||
yield self._generate_sync_entry_for_presence(
|
yield self._generate_sync_entry_for_presence(
|
||||||
sync_result_builder, newly_joined_rooms, newly_joined_users
|
sync_result_builder, newly_joined_rooms, newly_joined_users
|
||||||
)
|
)
|
||||||
@ -569,16 +576,20 @@ class SyncHandler(object):
|
|||||||
# We only delete messages when a new message comes in, but that's
|
# We only delete messages when a new message comes in, but that's
|
||||||
# fine so long as we delete them at some point.
|
# fine so long as we delete them at some point.
|
||||||
|
|
||||||
logger.debug("Deleting messages up to %d", since_stream_id)
|
deleted = yield self.store.delete_messages_for_device(
|
||||||
yield self.store.delete_messages_for_device(
|
|
||||||
user_id, device_id, since_stream_id
|
user_id, device_id, since_stream_id
|
||||||
)
|
)
|
||||||
|
logger.info("Deleted %d to-device messages up to %d",
|
||||||
|
deleted, since_stream_id)
|
||||||
|
|
||||||
logger.debug("Getting messages up to %d", now_token.to_device_key)
|
|
||||||
messages, stream_id = yield self.store.get_new_messages_for_device(
|
messages, stream_id = yield self.store.get_new_messages_for_device(
|
||||||
user_id, device_id, since_stream_id, now_token.to_device_key
|
user_id, device_id, since_stream_id, now_token.to_device_key
|
||||||
)
|
)
|
||||||
logger.debug("Got messages up to %d: %r", stream_id, messages)
|
|
||||||
|
logger.info(
|
||||||
|
"Returning %d to-device messages between %d and %d (current token: %d)",
|
||||||
|
len(messages), since_stream_id, stream_id, now_token.to_device_key
|
||||||
|
)
|
||||||
sync_result_builder.now_token = now_token.copy_and_replace(
|
sync_result_builder.now_token = now_token.copy_and_replace(
|
||||||
"to_device_key", stream_id
|
"to_device_key", stream_id
|
||||||
)
|
)
|
||||||
@ -709,7 +720,14 @@ class SyncHandler(object):
|
|||||||
`(newly_joined_rooms, newly_joined_users)`
|
`(newly_joined_rooms, newly_joined_users)`
|
||||||
"""
|
"""
|
||||||
user_id = sync_result_builder.sync_config.user.to_string()
|
user_id = sync_result_builder.sync_config.user.to_string()
|
||||||
|
block_all_room_ephemeral = (
|
||||||
|
sync_result_builder.since_token is None and
|
||||||
|
sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
|
||||||
|
)
|
||||||
|
|
||||||
|
if block_all_room_ephemeral:
|
||||||
|
ephemeral_by_room = {}
|
||||||
|
else:
|
||||||
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
||||||
sync_result_builder.sync_config,
|
sync_result_builder.sync_config,
|
||||||
now_token=sync_result_builder.now_token,
|
now_token=sync_result_builder.now_token,
|
||||||
|
@ -55,9 +55,9 @@ class TypingHandler(object):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||||
|
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_federation_sender()
|
||||||
|
|
||||||
self.federation.register_edu_handler("m.typing", self._recv_edu)
|
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
|
||||||
|
|
||||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ from synapse.api.errors import (
|
|||||||
|
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
|
|
||||||
|
import cgi
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
@ -292,12 +293,7 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
c_type = response.headers.getRawHeaders("Content-Type")
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
if "application/json" not in c_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Content-Type not application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield preserve_context_over_fn(readBody, response)
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
@ -342,12 +338,7 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
c_type = response.headers.getRawHeaders("Content-Type")
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
if "application/json" not in c_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Content-Type not application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield preserve_context_over_fn(readBody, response)
|
||||||
|
|
||||||
@ -400,12 +391,7 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
c_type = response.headers.getRawHeaders("Content-Type")
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
if "application/json" not in c_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Content-Type not application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield preserve_context_over_fn(readBody, response)
|
||||||
|
|
||||||
@ -525,3 +511,29 @@ def _flatten_response_never_received(e):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return "%s: %s" % (type(e).__name__, e.message,)
|
return "%s: %s" % (type(e).__name__, e.message,)
|
||||||
|
|
||||||
|
|
||||||
|
def check_content_type_is_json(headers):
|
||||||
|
"""
|
||||||
|
Check that a set of HTTP headers have a Content-Type header, and that it
|
||||||
|
is application/json.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
headers (twisted.web.http_headers.Headers): headers to check
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if the
|
||||||
|
|
||||||
|
"""
|
||||||
|
c_type = headers.getRawHeaders("Content-Type")
|
||||||
|
if c_type is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No Content-Type header"
|
||||||
|
)
|
||||||
|
|
||||||
|
c_type = c_type[0] # only the first header
|
||||||
|
val, options = cgi.parse_header(c_type)
|
||||||
|
if val != "application/json":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Content-Type not application/json: was '%s'" % c_type
|
||||||
|
)
|
||||||
|
@ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False):
|
|||||||
parameter is present and not one of "true" or "false".
|
parameter is present and not one of "true" or "false".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if name in request.args:
|
return parse_boolean_from_args(request.args, name, default, required)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_boolean_from_args(args, name, default=None, required=False):
|
||||||
|
if name in args:
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
"true": True,
|
"true": True,
|
||||||
"false": False,
|
"false": False,
|
||||||
}[request.args[name][0]]
|
}[args[name][0]]
|
||||||
except:
|
except:
|
||||||
message = (
|
message = (
|
||||||
"Boolean query parameter %r must be one of"
|
"Boolean query parameter %r must be one of"
|
||||||
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
|
||||||
|
from synapse.util import DeferredTimedOutError
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||||
@ -143,6 +144,12 @@ class Notifier(object):
|
|||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.appservice_handler = hs.get_application_service_handler()
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
|
|
||||||
|
if hs.should_send_federation():
|
||||||
|
self.federation_sender = hs.get_federation_sender()
|
||||||
|
else:
|
||||||
|
self.federation_sender = None
|
||||||
|
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
|
|
||||||
self.clock.looping_call(
|
self.clock.looping_call(
|
||||||
@ -220,6 +227,9 @@ class Notifier(object):
|
|||||||
# poke any interested application service.
|
# poke any interested application service.
|
||||||
self.appservice_handler.notify_interested_services(room_stream_id)
|
self.appservice_handler.notify_interested_services(room_stream_id)
|
||||||
|
|
||||||
|
if self.federation_sender:
|
||||||
|
self.federation_sender.notify_new_events(room_stream_id)
|
||||||
|
|
||||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||||
self._user_joined_room(event.state_key, event.room_id)
|
self._user_joined_room(event.state_key, event.room_id)
|
||||||
|
|
||||||
@ -285,14 +295,7 @@ class Notifier(object):
|
|||||||
|
|
||||||
result = None
|
result = None
|
||||||
if timeout:
|
if timeout:
|
||||||
# Will be set to a _NotificationListener that we'll be waiting on.
|
end_time = self.clock.time_msec() + timeout
|
||||||
# Allows us to cancel it.
|
|
||||||
listener = None
|
|
||||||
|
|
||||||
def timed_out():
|
|
||||||
if listener:
|
|
||||||
listener.deferred.cancel()
|
|
||||||
timer = self.clock.call_later(timeout / 1000., timed_out)
|
|
||||||
|
|
||||||
prev_token = from_token
|
prev_token = from_token
|
||||||
while not result:
|
while not result:
|
||||||
@ -303,6 +306,10 @@ class Notifier(object):
|
|||||||
if result:
|
if result:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
if end_time <= now:
|
||||||
|
break
|
||||||
|
|
||||||
# Now we wait for the _NotifierUserStream to be told there
|
# Now we wait for the _NotifierUserStream to be told there
|
||||||
# is a new token.
|
# is a new token.
|
||||||
# We need to supply the token we supplied to callback so
|
# We need to supply the token we supplied to callback so
|
||||||
@ -310,11 +317,14 @@ class Notifier(object):
|
|||||||
prev_token = current_token
|
prev_token = current_token
|
||||||
listener = user_stream.new_listener(prev_token)
|
listener = user_stream.new_listener(prev_token)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
yield listener.deferred
|
yield self.clock.time_bound_deferred(
|
||||||
|
listener.deferred,
|
||||||
|
time_out=(end_time - now) / 1000.
|
||||||
|
)
|
||||||
|
except DeferredTimedOutError:
|
||||||
|
break
|
||||||
except defer.CancelledError:
|
except defer.CancelledError:
|
||||||
break
|
break
|
||||||
|
|
||||||
self.clock.cancel_call_later(timer, ignore_errs=True)
|
|
||||||
else:
|
else:
|
||||||
current_token = user_stream.current_token
|
current_token = user_stream.current_token
|
||||||
result = yield callback(from_token, current_token)
|
result = yield callback(from_token, current_token)
|
||||||
@ -483,22 +493,27 @@ class Notifier(object):
|
|||||||
"""
|
"""
|
||||||
listener = _NotificationListener(None)
|
listener = _NotificationListener(None)
|
||||||
|
|
||||||
def timed_out():
|
end_time = self.clock.time_msec() + timeout
|
||||||
listener.deferred.cancel()
|
|
||||||
|
|
||||||
timer = self.clock.call_later(timeout / 1000., timed_out)
|
|
||||||
while True:
|
while True:
|
||||||
listener.deferred = self.replication_deferred.observe()
|
listener.deferred = self.replication_deferred.observe()
|
||||||
result = yield callback()
|
result = yield callback()
|
||||||
if result:
|
if result:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
if end_time <= now:
|
||||||
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
yield listener.deferred
|
yield self.clock.time_bound_deferred(
|
||||||
|
listener.deferred,
|
||||||
|
time_out=(end_time - now) / 1000.
|
||||||
|
)
|
||||||
|
except DeferredTimedOutError:
|
||||||
|
break
|
||||||
except defer.CancelledError:
|
except defer.CancelledError:
|
||||||
break
|
break
|
||||||
|
|
||||||
self.clock.cancel_call_later(timer, ignore_errs=True)
|
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
@ -87,12 +87,12 @@ class BulkPushRuleEvaluator:
|
|||||||
condition_cache = {}
|
condition_cache = {}
|
||||||
|
|
||||||
for uid, rules in self.rules_by_user.items():
|
for uid, rules in self.rules_by_user.items():
|
||||||
display_name = None
|
display_name = room_members.get(uid, {}).get("display_name", None)
|
||||||
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
|
if not display_name:
|
||||||
if member_ev_id:
|
# Handle the case where we are pushing a membership event to
|
||||||
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
|
# that user, as they might not be already joined.
|
||||||
if member_ev:
|
if event.type == EventTypes.Member and event.state_key == uid:
|
||||||
display_name = member_ev.content.get("displayname", None)
|
display_name = event.content.get("displayname", None)
|
||||||
|
|
||||||
filtered = filtered_by_user[uid]
|
filtered = filtered_by_user[uid]
|
||||||
if len(filtered) == 0:
|
if len(filtered) == 0:
|
||||||
|
@ -49,8 +49,8 @@ CONDITIONAL_REQUIREMENTS = {
|
|||||||
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
||||||
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
||||||
},
|
},
|
||||||
"ldap": {
|
"matrix-synapse-ldap3": {
|
||||||
"ldap3>=1.0": ["ldap3>=1.0"],
|
"matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"],
|
||||||
},
|
},
|
||||||
"psutil": {
|
"psutil": {
|
||||||
"psutil>=2.0.0": ["psutil>=2.0.0"],
|
"psutil>=2.0.0": ["psutil>=2.0.0"],
|
||||||
@ -69,6 +69,7 @@ def requirements(config=None, include_conditional=False):
|
|||||||
def github_link(project, version, egg):
|
def github_link(project, version, egg):
|
||||||
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
|
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
|
||||||
|
|
||||||
|
|
||||||
DEPENDENCY_LINKS = {
|
DEPENDENCY_LINKS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,6 +157,7 @@ def list_requirements():
|
|||||||
result.append(requirement)
|
result.append(requirement)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
sys.stdout.writelines(req + "\n" for req in list_requirements())
|
sys.stdout.writelines(req + "\n" for req in list_requirements())
|
||||||
|
60
synapse/replication/expire_cache.py
Normal file
60
synapse/replication/expire_cache.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.http.server import respond_with_json_bytes, request_handler
|
||||||
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
|
||||||
|
|
||||||
|
class ExpireCacheResource(Resource):
|
||||||
|
"""
|
||||||
|
HTTP endpoint for expiring storage caches.
|
||||||
|
|
||||||
|
POST /_synapse/replication/expire_cache HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"invalidate": [
|
||||||
|
{
|
||||||
|
"name": "func_name",
|
||||||
|
"keys": ["key1", "key2"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
Resource.__init__(self) # Resource is old-style, so no super()
|
||||||
|
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.version_string = hs.version_string
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
def render_POST(self, request):
|
||||||
|
self._async_render_POST(request)
|
||||||
|
return NOT_DONE_YET
|
||||||
|
|
||||||
|
@request_handler()
|
||||||
|
def _async_render_POST(self, request):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
for row in content["invalidate"]:
|
||||||
|
name = row["name"]
|
||||||
|
keys = tuple(row["keys"])
|
||||||
|
|
||||||
|
getattr(self.store, name).invalidate(keys)
|
||||||
|
|
||||||
|
respond_with_json_bytes(request, 200, "{}")
|
@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
|
|||||||
from synapse.http.server import request_handler, finish_request
|
from synapse.http.server import request_handler, finish_request
|
||||||
from synapse.replication.pusher_resource import PusherResource
|
from synapse.replication.pusher_resource import PusherResource
|
||||||
from synapse.replication.presence_resource import PresenceResource
|
from synapse.replication.presence_resource import PresenceResource
|
||||||
|
from synapse.replication.expire_cache import ExpireCacheResource
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
@ -44,6 +45,7 @@ STREAM_NAMES = (
|
|||||||
("caches",),
|
("caches",),
|
||||||
("to_device",),
|
("to_device",),
|
||||||
("public_rooms",),
|
("public_rooms",),
|
||||||
|
("federation",),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -116,11 +118,14 @@ class ReplicationResource(Resource):
|
|||||||
self.sources = hs.get_event_sources()
|
self.sources = hs.get_event_sources()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.typing_handler = hs.get_typing_handler()
|
self.typing_handler = hs.get_typing_handler()
|
||||||
|
self.federation_sender = hs.get_federation_sender()
|
||||||
self.notifier = hs.notifier
|
self.notifier = hs.notifier
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.config = hs.get_config()
|
||||||
|
|
||||||
self.putChild("remove_pushers", PusherResource(hs))
|
self.putChild("remove_pushers", PusherResource(hs))
|
||||||
self.putChild("syncing_users", PresenceResource(hs))
|
self.putChild("syncing_users", PresenceResource(hs))
|
||||||
|
self.putChild("expire_cache", ExpireCacheResource(hs))
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
@ -134,6 +139,7 @@ class ReplicationResource(Resource):
|
|||||||
pushers_token = self.store.get_pushers_stream_token()
|
pushers_token = self.store.get_pushers_stream_token()
|
||||||
caches_token = self.store.get_cache_stream_token()
|
caches_token = self.store.get_cache_stream_token()
|
||||||
public_rooms_token = self.store.get_current_public_room_stream_id()
|
public_rooms_token = self.store.get_current_public_room_stream_id()
|
||||||
|
federation_token = self.federation_sender.get_current_token()
|
||||||
|
|
||||||
defer.returnValue(_ReplicationToken(
|
defer.returnValue(_ReplicationToken(
|
||||||
room_stream_token,
|
room_stream_token,
|
||||||
@ -148,6 +154,7 @@ class ReplicationResource(Resource):
|
|||||||
caches_token,
|
caches_token,
|
||||||
int(stream_token.to_device_key),
|
int(stream_token.to_device_key),
|
||||||
int(public_rooms_token),
|
int(public_rooms_token),
|
||||||
|
int(federation_token),
|
||||||
))
|
))
|
||||||
|
|
||||||
@request_handler()
|
@request_handler()
|
||||||
@ -164,8 +171,13 @@ class ReplicationResource(Resource):
|
|||||||
}
|
}
|
||||||
request_streams["streams"] = parse_string(request, "streams")
|
request_streams["streams"] = parse_string(request, "streams")
|
||||||
|
|
||||||
|
federation_ack = parse_integer(request, "federation_ack", None)
|
||||||
|
|
||||||
def replicate():
|
def replicate():
|
||||||
return self.replicate(request_streams, limit)
|
return self.replicate(
|
||||||
|
request_streams, limit,
|
||||||
|
federation_ack=federation_ack
|
||||||
|
)
|
||||||
|
|
||||||
writer = yield self.notifier.wait_for_replication(replicate, timeout)
|
writer = yield self.notifier.wait_for_replication(replicate, timeout)
|
||||||
result = writer.finish()
|
result = writer.finish()
|
||||||
@ -183,7 +195,7 @@ class ReplicationResource(Resource):
|
|||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def replicate(self, request_streams, limit):
|
def replicate(self, request_streams, limit, federation_ack=None):
|
||||||
writer = _Writer()
|
writer = _Writer()
|
||||||
current_token = yield self.current_replication_token()
|
current_token = yield self.current_replication_token()
|
||||||
logger.debug("Replicating up to %r", current_token)
|
logger.debug("Replicating up to %r", current_token)
|
||||||
@ -202,6 +214,7 @@ class ReplicationResource(Resource):
|
|||||||
yield self.caches(writer, current_token, limit, request_streams)
|
yield self.caches(writer, current_token, limit, request_streams)
|
||||||
yield self.to_device(writer, current_token, limit, request_streams)
|
yield self.to_device(writer, current_token, limit, request_streams)
|
||||||
yield self.public_rooms(writer, current_token, limit, request_streams)
|
yield self.public_rooms(writer, current_token, limit, request_streams)
|
||||||
|
self.federation(writer, current_token, limit, request_streams, federation_ack)
|
||||||
self.streams(writer, current_token, request_streams)
|
self.streams(writer, current_token, request_streams)
|
||||||
|
|
||||||
logger.debug("Replicated %d rows", writer.total)
|
logger.debug("Replicated %d rows", writer.total)
|
||||||
@ -462,7 +475,24 @@ class ReplicationResource(Resource):
|
|||||||
)
|
)
|
||||||
upto_token = _position_from_rows(public_rooms_rows, current_position)
|
upto_token = _position_from_rows(public_rooms_rows, current_position)
|
||||||
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
|
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
|
||||||
"position", "room_id", "visibility"
|
"position", "room_id", "visibility", "appservice_id", "network_id",
|
||||||
|
), position=upto_token)
|
||||||
|
|
||||||
|
def federation(self, writer, current_token, limit, request_streams, federation_ack):
|
||||||
|
if self.config.send_federation:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_position = current_token.federation
|
||||||
|
|
||||||
|
federation = request_streams.get("federation")
|
||||||
|
|
||||||
|
if federation is not None and federation != current_position:
|
||||||
|
federation_rows = self.federation_sender.get_replication_rows(
|
||||||
|
federation, limit, federation_ack=federation_ack,
|
||||||
|
)
|
||||||
|
upto_token = _position_from_rows(federation_rows, current_position)
|
||||||
|
writer.write_header_and_rows("federation", federation_rows, (
|
||||||
|
"position", "type", "content",
|
||||||
), position=upto_token)
|
), position=upto_token)
|
||||||
|
|
||||||
|
|
||||||
@ -497,6 +527,7 @@ class _Writer(object):
|
|||||||
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
||||||
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
||||||
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
|
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
|
||||||
|
"federation",
|
||||||
))):
|
))):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
|
||||||
|
@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore):
|
|||||||
else:
|
else:
|
||||||
self._cache_id_gen = None
|
self._cache_id_gen = None
|
||||||
|
|
||||||
|
self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
|
||||||
|
self.http_client = hs.get_simple_http_client()
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
pos = {}
|
pos = {}
|
||||||
if self._cache_id_gen:
|
if self._cache_id_gen:
|
||||||
@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore):
|
|||||||
logger.info("Got unexpected cache_func: %r", cache_func)
|
logger.info("Got unexpected cache_func: %r", cache_func)
|
||||||
self._cache_id_gen.advance(int(stream["position"]))
|
self._cache_id_gen.advance(int(stream["position"]))
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||||
|
txn.call_after(cache_func.invalidate, keys)
|
||||||
|
txn.call_after(self._send_invalidation_poke, cache_func, keys)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _send_invalidation_poke(self, cache_func, keys):
|
||||||
|
try:
|
||||||
|
yield self.http_client.post_json_get_json(self.expire_cache_url, {
|
||||||
|
"invalidate": [{
|
||||||
|
"name": cache_func.__name__,
|
||||||
|
"keys": list(keys),
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
except:
|
||||||
|
logger.exception("Failed to poke on expire_cache")
|
||||||
|
@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
|
|||||||
"DeviceInboxStreamChangeCache",
|
"DeviceInboxStreamChangeCache",
|
||||||
self._device_inbox_id_gen.get_current_token()
|
self._device_inbox_id_gen.get_current_token()
|
||||||
)
|
)
|
||||||
|
self._device_federation_outbox_stream_cache = StreamChangeCache(
|
||||||
|
"DeviceFederationOutboxStreamChangeCache",
|
||||||
|
self._device_inbox_id_gen.get_current_token()
|
||||||
|
)
|
||||||
|
|
||||||
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
|
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
|
||||||
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
|
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
|
||||||
|
get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
|
||||||
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
|
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
|
||||||
|
delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
||||||
@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
|
|||||||
self._device_inbox_id_gen.advance(int(stream["position"]))
|
self._device_inbox_id_gen.advance(int(stream["position"]))
|
||||||
for row in stream["rows"]:
|
for row in stream["rows"]:
|
||||||
stream_id = row[0]
|
stream_id = row[0]
|
||||||
user_id = row[1]
|
entity = row[1]
|
||||||
|
|
||||||
|
if entity.startswith("@"):
|
||||||
self._device_inbox_stream_cache.entity_has_changed(
|
self._device_inbox_stream_cache.entity_has_changed(
|
||||||
user_id, stream_id
|
entity, stream_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._device_federation_outbox_stream_cache.entity_has_changed(
|
||||||
|
entity, stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return super(SlavedDeviceInboxStore, self).process_replication(result)
|
return super(SlavedDeviceInboxStore, self).process_replication(result)
|
||||||
|
@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
|
|||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# So, um, we want to borrow a load of functions intended for reading from
|
# So, um, we want to borrow a load of functions intended for reading from
|
||||||
# a DataStore, but we don't want to take functions that either write to the
|
# a DataStore, but we don't want to take functions that either write to the
|
||||||
@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore):
|
|||||||
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
|
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
|
||||||
|
|
||||||
|
get_federation_out_pos = DataStore.get_federation_out_pos.__func__
|
||||||
|
update_federation_out_pos = DataStore.update_federation_out_pos.__func__
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedEventStore, self).stream_positions()
|
result = super(SlavedEventStore, self).stream_positions()
|
||||||
result["events"] = self._stream_id_gen.get_current_token()
|
result["events"] = self._stream_id_gen.get_current_token()
|
||||||
@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore):
|
|||||||
stream = result.get("events")
|
stream = result.get("events")
|
||||||
if stream:
|
if stream:
|
||||||
self._stream_id_gen.advance(int(stream["position"]))
|
self._stream_id_gen.advance(int(stream["position"]))
|
||||||
|
|
||||||
|
if stream["rows"]:
|
||||||
|
logger.info("Got %d event rows", len(stream["rows"]))
|
||||||
|
|
||||||
for row in stream["rows"]:
|
for row in stream["rows"]:
|
||||||
self._process_replication_row(
|
self._process_replication_row(
|
||||||
row, backfilled=False, state_resets=state_resets
|
row, backfilled=False, state_resets=state_resets
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
|
from synapse.storage.room import RoomStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
|
||||||
@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore):
|
|||||||
DataStore.get_current_public_room_stream_id.__func__
|
DataStore.get_current_public_room_stream_id.__func__
|
||||||
)
|
)
|
||||||
get_public_room_ids_at_stream_id = (
|
get_public_room_ids_at_stream_id = (
|
||||||
DataStore.get_public_room_ids_at_stream_id.__func__
|
RoomStore.__dict__["get_public_room_ids_at_stream_id"]
|
||||||
)
|
)
|
||||||
get_public_room_ids_at_stream_id_txn = (
|
get_public_room_ids_at_stream_id_txn = (
|
||||||
DataStore.get_public_room_ids_at_stream_id_txn.__func__
|
DataStore.get_public_room_ids_at_stream_id_txn.__func__
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.transactions import TransactionStore
|
from synapse.storage.transactions import TransactionStore
|
||||||
@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore
|
|||||||
class TransactionStore(BaseSlavedStore):
|
class TransactionStore(BaseSlavedStore):
|
||||||
get_destination_retry_timings = TransactionStore.__dict__[
|
get_destination_retry_timings = TransactionStore.__dict__[
|
||||||
"get_destination_retry_timings"
|
"get_destination_retry_timings"
|
||||||
].orig
|
]
|
||||||
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
|
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
|
||||||
|
set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
|
||||||
|
_set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
|
||||||
|
|
||||||
# For now, don't record the destination rety timings
|
prep_send_transaction = DataStore.prep_send_transaction.__func__
|
||||||
def set_destination_retry_timings(*args, **kwargs):
|
delivered_txn = DataStore.delivered_txn.__func__
|
||||||
return defer.succeed(None)
|
|
||||||
|
@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
|
|||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
ClientDirectoryServer(hs).register(http_server)
|
ClientDirectoryServer(hs).register(http_server)
|
||||||
ClientDirectoryListServer(hs).register(http_server)
|
ClientDirectoryListServer(hs).register(http_server)
|
||||||
|
ClientAppserviceDirectoryListServer(hs).register(http_server)
|
||||||
|
|
||||||
|
|
||||||
class ClientDirectoryServer(ClientV1RestServlet):
|
class ClientDirectoryServer(ClientV1RestServlet):
|
||||||
@ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns(
|
||||||
|
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ClientAppserviceDirectoryListServer, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
|
def on_PUT(self, request, network_id, room_id):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
visibility = content.get("visibility", "public")
|
||||||
|
return self._edit(request, network_id, room_id, visibility)
|
||||||
|
|
||||||
|
def on_DELETE(self, request, network_id, room_id):
|
||||||
|
return self._edit(request, network_id, room_id, "private")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _edit(self, request, network_id, room_id, visibility):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
if not requester.app_service:
|
||||||
|
raise AuthError(
|
||||||
|
403, "Only appservices can edit the appservice published room list"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.handlers.directory_handler.edit_published_appservice_room_list(
|
||||||
|
requester.app_service.id, network_id, room_id, visibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
password=login_submission["password"],
|
password=login_submission["password"],
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
access_token, refresh_token = (
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
yield auth_handler.get_login_tuple_for_user_id(
|
|
||||||
user_id, device_id,
|
user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name")
|
login_submission.get("initial_device_display_name"),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
access_token, refresh_token = (
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
yield auth_handler.get_login_tuple_for_user_id(
|
|
||||||
user_id, device_id,
|
user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name")
|
login_submission.get("initial_device_display_name"),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
device_id = yield self._register_device(
|
device_id = yield self._register_device(
|
||||||
registered_user_id, login_submission
|
registered_user_id, login_submission
|
||||||
)
|
)
|
||||||
access_token, refresh_token = (
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
yield auth_handler.get_login_tuple_for_user_id(
|
|
||||||
registered_user_id, device_id,
|
registered_user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name")
|
login_submission.get("initial_device_display_name"),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": registered_user_id,
|
"user_id": registered_user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(CreateUserRestServlet, self).__init__(hs)
|
super(CreateUserRestServlet, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||||||
if "displayname" not in user_json:
|
if "displayname" not in user_json:
|
||||||
raise SynapseError(400, "Expected 'displayname' key.")
|
raise SynapseError(400, "Expected 'displayname' key.")
|
||||||
|
|
||||||
if "duration_seconds" not in user_json:
|
|
||||||
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
|
||||||
|
|
||||||
localpart = user_json["localpart"].encode("utf-8")
|
localpart = user_json["localpart"].encode("utf-8")
|
||||||
displayname = user_json["displayname"].encode("utf-8")
|
displayname = user_json["displayname"].encode("utf-8")
|
||||||
duration_seconds = 0
|
|
||||||
try:
|
|
||||||
duration_seconds = int(user_json["duration_seconds"])
|
|
||||||
except ValueError:
|
|
||||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
|
||||||
if duration_seconds > self.direct_user_creation_max_duration:
|
|
||||||
duration_seconds = self.direct_user_creation_max_duration
|
|
||||||
password_hash = user_json["password_hash"].encode("utf-8") \
|
password_hash = user_json["password_hash"].encode("utf-8") \
|
||||||
if user_json.get("password_hash") else None
|
if user_json.get("password_hash") else None
|
||||||
|
|
||||||
@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||||||
requester=requester,
|
requester=requester,
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
displayname=displayname,
|
displayname=displayname,
|
||||||
duration_in_ms=(duration_seconds * 1000),
|
|
||||||
password_hash=password_hash
|
password_hash=password_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError
|
|||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.types import UserID, RoomID, RoomAlias
|
from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
|
||||||
from synapse.events.utils import serialize_event, format_event_for_client_v2
|
from synapse.events.utils import serialize_event, format_event_for_client_v2
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_json_object_from_request, parse_string, parse_integer
|
parse_json_object_from_request, parse_string, parse_integer
|
||||||
@ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||||||
since_token = content.get("since", None)
|
since_token = content.get("since", None)
|
||||||
search_filter = content.get("filter", None)
|
search_filter = content.get("filter", None)
|
||||||
|
|
||||||
|
include_all_networks = content.get("include_all_networks", False)
|
||||||
|
third_party_instance_id = content.get("third_party_instance_id", None)
|
||||||
|
|
||||||
|
if include_all_networks:
|
||||||
|
network_tuple = None
|
||||||
|
if third_party_instance_id is not None:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Can't use include_all_networks with an explicit network"
|
||||||
|
)
|
||||||
|
elif third_party_instance_id is None:
|
||||||
|
network_tuple = ThirdPartyInstanceID(None, None)
|
||||||
|
else:
|
||||||
|
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
|
||||||
|
|
||||||
handler = self.hs.get_room_list_handler()
|
handler = self.hs.get_room_list_handler()
|
||||||
if server:
|
if server:
|
||||||
data = yield handler.get_remote_public_room_list(
|
data = yield handler.get_remote_public_room_list(
|
||||||
@ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
since_token=since_token,
|
since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data = yield handler.get_local_public_room_list(
|
data = yield handler.get_local_public_room_list(
|
||||||
limit=limit,
|
limit=limit,
|
||||||
since_token=since_token,
|
since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter,
|
||||||
|
network_tuple=network_tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, data))
|
defer.returnValue((200, data))
|
||||||
@ -369,6 +386,24 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, room_id):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"joined": users_with_profile
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
# TODO: Needs better unit testing
|
# TODO: Needs better unit testing
|
||||||
class RoomMessageListRestServlet(ClientV1RestServlet):
|
class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
|
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
|
||||||
@ -692,6 +727,22 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||||||
defer.returnValue((200, results))
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
|
|
||||||
|
class JoinedRoomsRestServlet(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns("/joined_rooms$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(JoinedRoomsRestServlet, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
|
rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
|
||||||
|
room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
|
||||||
|
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
|
||||||
|
|
||||||
|
|
||||||
def register_txn_path(servlet, regex_string, http_server, with_get=False):
|
def register_txn_path(servlet, regex_string, http_server, with_get=False):
|
||||||
"""Registers a transaction-based path.
|
"""Registers a transaction-based path.
|
||||||
|
|
||||||
@ -727,6 +778,7 @@ def register_servlets(hs, http_server):
|
|||||||
RoomStateEventRestServlet(hs).register(http_server)
|
RoomStateEventRestServlet(hs).register(http_server)
|
||||||
RoomCreateRestServlet(hs).register(http_server)
|
RoomCreateRestServlet(hs).register(http_server)
|
||||||
RoomMemberListRestServlet(hs).register(http_server)
|
RoomMemberListRestServlet(hs).register(http_server)
|
||||||
|
JoinedRoomMemberListRestServlet(hs).register(http_server)
|
||||||
RoomMessageListRestServlet(hs).register(http_server)
|
RoomMessageListRestServlet(hs).register(http_server)
|
||||||
JoinRoomAliasServlet(hs).register(http_server)
|
JoinRoomAliasServlet(hs).register(http_server)
|
||||||
RoomForgetRestServlet(hs).register(http_server)
|
RoomForgetRestServlet(hs).register(http_server)
|
||||||
@ -738,4 +790,5 @@ def register_servlets(hs, http_server):
|
|||||||
RoomRedactEventRestServlet(hs).register(http_server)
|
RoomRedactEventRestServlet(hs).register(http_server)
|
||||||
RoomTypingRestServlet(hs).register(http_server)
|
RoomTypingRestServlet(hs).register(http_server)
|
||||||
SearchRestServlet(hs).register(http_server)
|
SearchRestServlet(hs).register(http_server)
|
||||||
|
JoinedRoomsRestServlet(hs).register(http_server)
|
||||||
RoomEventContext(hs).register(http_server)
|
RoomEventContext(hs).register(http_server)
|
||||||
|
@ -39,7 +39,7 @@ class DevicesRestServlet(servlet.RestServlet):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
devices = yield self.device_handler.get_devices_by_user(
|
devices = yield self.device_handler.get_devices_by_user(
|
||||||
requester.user.to_string()
|
requester.user.to_string()
|
||||||
)
|
)
|
||||||
@ -63,7 +63,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, device_id):
|
def on_GET(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
device = yield self.device_handler.get_device(
|
device = yield self.device_handler.get_device(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
device_id,
|
device_id,
|
||||||
@ -99,7 +99,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, device_id):
|
def on_PUT(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
body = servlet.parse_json_object_from_request(request)
|
body = servlet.parse_json_object_from_request(request)
|
||||||
yield self.device_handler.update_device(
|
yield self.device_handler.update_device(
|
||||||
|
@ -65,7 +65,7 @@ class KeyUploadServlet(RestServlet):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, device_id):
|
def on_POST(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
@ -94,10 +94,6 @@ class KeyUploadServlet(RestServlet):
|
|||||||
|
|
||||||
class KeyQueryServlet(RestServlet):
|
class KeyQueryServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
GET /keys/query/<user_id> HTTP/1.1
|
|
||||||
|
|
||||||
GET /keys/query/<user_id>/<device_id> HTTP/1.1
|
|
||||||
|
|
||||||
POST /keys/query HTTP/1.1
|
POST /keys/query HTTP/1.1
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
{
|
{
|
||||||
@ -131,11 +127,7 @@ class KeyQueryServlet(RestServlet):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATTERNS = client_v2_patterns(
|
PATTERNS = client_v2_patterns(
|
||||||
"/keys/query(?:"
|
"/keys/query$",
|
||||||
"/(?P<user_id>[^/]*)(?:"
|
|
||||||
"/(?P<device_id>[^/]*)"
|
|
||||||
")?"
|
|
||||||
")?",
|
|
||||||
releases=()
|
releases=()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -149,31 +141,16 @@ class KeyQueryServlet(RestServlet):
|
|||||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id):
|
def on_POST(self, request):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = yield self.e2e_keys_handler.query_devices(body, timeout)
|
result = yield self.e2e_keys_handler.query_devices(body, timeout)
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def on_GET(self, request, user_id, device_id):
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
|
||||||
auth_user_id = requester.user.to_string()
|
|
||||||
user_id = user_id if user_id else auth_user_id
|
|
||||||
device_ids = [device_id] if device_id else []
|
|
||||||
result = yield self.e2e_keys_handler.query_devices(
|
|
||||||
{"device_keys": {user_id: device_ids}},
|
|
||||||
timeout,
|
|
||||||
)
|
|
||||||
defer.returnValue((200, result))
|
|
||||||
|
|
||||||
|
|
||||||
class OneTimeKeyServlet(RestServlet):
|
class OneTimeKeyServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
|
|
||||||
|
|
||||||
POST /keys/claim HTTP/1.1
|
POST /keys/claim HTTP/1.1
|
||||||
{
|
{
|
||||||
"one_time_keys": {
|
"one_time_keys": {
|
||||||
@ -191,9 +168,7 @@ class OneTimeKeyServlet(RestServlet):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
PATTERNS = client_v2_patterns(
|
PATTERNS = client_v2_patterns(
|
||||||
"/keys/claim(?:/?|(?:/"
|
"/keys/claim$",
|
||||||
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
|
|
||||||
")?)",
|
|
||||||
releases=()
|
releases=()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -203,18 +178,8 @@ class OneTimeKeyServlet(RestServlet):
|
|||||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id, algorithm):
|
def on_POST(self, request):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
|
||||||
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
|
||||||
{"one_time_keys": {user_id: {device_id: algorithm}}},
|
|
||||||
timeout,
|
|
||||||
)
|
|
||||||
defer.returnValue((200, result))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def on_POST(self, request, user_id, device_id, algorithm):
|
|
||||||
yield self.auth.get_user_by_req(request)
|
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
||||||
|
@ -36,7 +36,7 @@ class ReceiptRestServlet(RestServlet):
|
|||||||
super(ReceiptRestServlet, self).__init__()
|
super(ReceiptRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.receipts_handler = hs.get_handlers().receipts_handler
|
self.receipts_handler = hs.get_receipts_handler()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse
|
||||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||||
@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet):
|
|||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
kind = "user"
|
kind = "user"
|
||||||
if "kind" in request.args:
|
if "kind" in request.args:
|
||||||
kind = request.args["kind"][0]
|
kind = request.args["kind"][0]
|
||||||
|
|
||||||
if kind == "guest":
|
if kind == "guest":
|
||||||
ret = yield self._do_guest_registration()
|
ret = yield self._do_guest_registration(body)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
return
|
return
|
||||||
elif kind != "user":
|
elif kind != "user":
|
||||||
@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet):
|
|||||||
"Do not understand membership kind: %s" % (kind,)
|
"Do not understand membership kind: %s" % (kind,)
|
||||||
)
|
)
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
# we do basic sanity checks here because the auth layer will store these
|
# we do basic sanity checks here because the auth layer will store these
|
||||||
# in sessions. Pull out the username/password provided to us.
|
# in sessions. Pull out the username/password provided to us.
|
||||||
desired_password = None
|
desired_password = None
|
||||||
@ -373,8 +374,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
def _create_registration_details(self, user_id, params):
|
def _create_registration_details(self, user_id, params):
|
||||||
"""Complete registration of newly-registered user
|
"""Complete registration of newly-registered user
|
||||||
|
|
||||||
Allocates device_id if one was not given; also creates access_token
|
Allocates device_id if one was not given; also creates access_token.
|
||||||
and refresh_token.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
(str) user_id: full canonical @user:id
|
(str) user_id: full canonical @user:id
|
||||||
@ -385,8 +385,8 @@ class RegisterRestServlet(RestServlet):
|
|||||||
"""
|
"""
|
||||||
device_id = yield self._register_device(user_id, params)
|
device_id = yield self._register_device(user_id, params)
|
||||||
|
|
||||||
access_token, refresh_token = (
|
access_token = (
|
||||||
yield self.auth_handler.get_login_tuple_for_user_id(
|
yield self.auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id=device_id,
|
user_id, device_id=device_id,
|
||||||
initial_display_name=params.get("initial_device_display_name")
|
initial_display_name=params.get("initial_device_display_name")
|
||||||
)
|
)
|
||||||
@ -396,7 +396,6 @@ class RegisterRestServlet(RestServlet):
|
|||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -421,20 +420,28 @@ class RegisterRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_guest_registration(self):
|
def _do_guest_registration(self, params):
|
||||||
if not self.hs.config.allow_guest_access:
|
if not self.hs.config.allow_guest_access:
|
||||||
defer.returnValue((403, "Guest access is disabled"))
|
defer.returnValue((403, "Guest access is disabled"))
|
||||||
user_id, _ = yield self.registration_handler.register(
|
user_id, _ = yield self.registration_handler.register(
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
make_guest=True
|
make_guest=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# we don't allow guests to specify their own device_id, because
|
||||||
|
# we have nowhere to store it.
|
||||||
|
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||||
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
|
self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
|
)
|
||||||
|
|
||||||
access_token = self.auth_handler.generate_access_token(
|
access_token = self.auth_handler.generate_access_token(
|
||||||
user_id, ["guest = true"]
|
user_id, ["guest = true"]
|
||||||
)
|
)
|
||||||
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
|
|
||||||
# so long as we don't return a refresh_token here.
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}))
|
}))
|
||||||
|
@ -50,7 +50,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _put(self, request, message_type, txn_id):
|
def _put(self, request, message_type, txn_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
joined = self.encode_joined(
|
joined = self.encode_joined(
|
||||||
sync_result.joined, time_now, requester.access_token_id
|
sync_result.joined, time_now, requester.access_token_id, filter.event_fields
|
||||||
)
|
)
|
||||||
|
|
||||||
invited = self.encode_invited(
|
invited = self.encode_invited(
|
||||||
@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
archived = self.encode_archived(
|
archived = self.encode_archived(
|
||||||
sync_result.archived, time_now, requester.access_token_id
|
sync_result.archived, time_now, requester.access_token_id, filter.event_fields
|
||||||
)
|
)
|
||||||
|
|
||||||
response_content = {
|
response_content = {
|
||||||
@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
formatted.append(event)
|
formatted.append(event)
|
||||||
return {"events": formatted}
|
return {"events": formatted}
|
||||||
|
|
||||||
def encode_joined(self, rooms, time_now, token_id):
|
def encode_joined(self, rooms, time_now, token_id, event_fields):
|
||||||
"""
|
"""
|
||||||
Encode the joined rooms in a sync result
|
Encode the joined rooms in a sync result
|
||||||
|
|
||||||
@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet):
|
|||||||
calculations
|
calculations
|
||||||
token_id(int): ID of the user's auth token - used for namespacing
|
token_id(int): ID of the user's auth token - used for namespacing
|
||||||
of transaction IDs
|
of transaction IDs
|
||||||
|
event_fields(list<str>): List of event fields to include. If empty,
|
||||||
|
all fields will be returned.
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, object]]: the joined rooms list, in our
|
dict[str, dict[str, object]]: the joined rooms list, in our
|
||||||
response format
|
response format
|
||||||
@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
joined = {}
|
joined = {}
|
||||||
for room in rooms:
|
for room in rooms:
|
||||||
joined[room.room_id] = self.encode_room(
|
joined[room.room_id] = self.encode_room(
|
||||||
room, time_now, token_id
|
room, time_now, token_id, only_fields=event_fields
|
||||||
)
|
)
|
||||||
|
|
||||||
return joined
|
return joined
|
||||||
@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
|
|
||||||
return invited
|
return invited
|
||||||
|
|
||||||
def encode_archived(self, rooms, time_now, token_id):
|
def encode_archived(self, rooms, time_now, token_id, event_fields):
|
||||||
"""
|
"""
|
||||||
Encode the archived rooms in a sync result
|
Encode the archived rooms in a sync result
|
||||||
|
|
||||||
@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet):
|
|||||||
calculations
|
calculations
|
||||||
token_id(int): ID of the user's auth token - used for namespacing
|
token_id(int): ID of the user's auth token - used for namespacing
|
||||||
of transaction IDs
|
of transaction IDs
|
||||||
|
event_fields(list<str>): List of event fields to include. If empty,
|
||||||
|
all fields will be returned.
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, object]]: The invited rooms list, in our
|
dict[str, dict[str, object]]: The invited rooms list, in our
|
||||||
response format
|
response format
|
||||||
@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet):
|
|||||||
joined = {}
|
joined = {}
|
||||||
for room in rooms:
|
for room in rooms:
|
||||||
joined[room.room_id] = self.encode_room(
|
joined[room.room_id] = self.encode_room(
|
||||||
room, time_now, token_id, joined=False
|
room, time_now, token_id, joined=False, only_fields=event_fields
|
||||||
)
|
)
|
||||||
|
|
||||||
return joined
|
return joined
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encode_room(room, time_now, token_id, joined=True):
|
def encode_room(room, time_now, token_id, joined=True, only_fields=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
|
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
|
||||||
@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
of transaction IDs
|
of transaction IDs
|
||||||
joined (bool): True if the user is joined to this room - will mean
|
joined (bool): True if the user is joined to this room - will mean
|
||||||
we handle ephemeral events
|
we handle ephemeral events
|
||||||
|
only_fields(list<str>): Optional. The list of event fields to include.
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, object]: the room, encoded in our response format
|
dict[str, object]: the room, encoded in our response format
|
||||||
"""
|
"""
|
||||||
@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
return serialize_event(
|
return serialize_event(
|
||||||
event, time_now, token_id=token_id,
|
event, time_now, token_id=token_id,
|
||||||
event_format=format_event_for_client_v2_without_room_id,
|
event_format=format_event_for_client_v2_without_room_id,
|
||||||
|
only_event_fields=only_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_dict = room.state
|
state_dict = room.state
|
||||||
|
@ -15,8 +15,8 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(TokenRefreshRestServlet, self).__init__()
|
super(TokenRefreshRestServlet, self).__init__()
|
||||||
self.hs = hs
|
|
||||||
self.store = hs.get_datastore()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
raise AuthError(403, "tokenrefresh is no longer supported.")
|
||||||
try:
|
|
||||||
old_refresh_token = body["refresh_token"]
|
|
||||||
auth_handler = self.hs.get_auth_handler()
|
|
||||||
refresh_result = yield self.store.exchange_refresh_token(
|
|
||||||
old_refresh_token, auth_handler.generate_refresh_token
|
|
||||||
)
|
|
||||||
(user_id, new_refresh_token, device_id) = refresh_result
|
|
||||||
new_access_token = yield auth_handler.issue_access_token(
|
|
||||||
user_id, device_id
|
|
||||||
)
|
|
||||||
defer.returnValue((200, {
|
|
||||||
"access_token": new_access_token,
|
|
||||||
"refresh_token": new_refresh_token,
|
|
||||||
}))
|
|
||||||
except KeyError:
|
|
||||||
raise SynapseError(400, "Missing required key 'refresh_token'.")
|
|
||||||
except StoreError:
|
|
||||||
raise AuthError(403, "Did not recognize refresh token")
|
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
@ -381,7 +381,10 @@ def _calc_og(tree, media_uri):
|
|||||||
if 'og:title' not in og:
|
if 'og:title' not in og:
|
||||||
# do some basic spidering of the HTML
|
# do some basic spidering of the HTML
|
||||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||||
og['og:title'] = title[0].text.strip() if title else None
|
if title and title[0].text is not None:
|
||||||
|
og['og:title'] = title[0].text.strip()
|
||||||
|
else:
|
||||||
|
og['og:title'] = None
|
||||||
|
|
||||||
if 'og:image' not in og:
|
if 'og:image' not in og:
|
||||||
# TODO: extract a favicon failing all else
|
# TODO: extract a favicon failing all else
|
||||||
@ -543,5 +546,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
|||||||
|
|
||||||
# We always add an ellipsis because at the very least
|
# We always add an ellipsis because at the very least
|
||||||
# we chopped mid paragraph.
|
# we chopped mid paragraph.
|
||||||
description = new_desc.strip() + "…"
|
description = new_desc.strip() + u"…"
|
||||||
return description if description else None
|
return description if description else None
|
||||||
|
@ -32,6 +32,9 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
|
|||||||
from synapse.crypto.keyring import Keyring
|
from synapse.crypto.keyring import Keyring
|
||||||
from synapse.events.builder import EventBuilderFactory
|
from synapse.events.builder import EventBuilderFactory
|
||||||
from synapse.federation import initialize_http_replication
|
from synapse.federation import initialize_http_replication
|
||||||
|
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||||
|
from synapse.federation.transport.client import TransportLayerClient
|
||||||
|
from synapse.federation.transaction_queue import TransactionQueue
|
||||||
from synapse.handlers import Handlers
|
from synapse.handlers import Handlers
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler
|
from synapse.handlers.auth import AuthHandler
|
||||||
@ -44,6 +47,7 @@ from synapse.handlers.sync import SyncHandler
|
|||||||
from synapse.handlers.typing import TypingHandler
|
from synapse.handlers.typing import TypingHandler
|
||||||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||||
|
from synapse.handlers.receipts import ReceiptsHandler
|
||||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.notifier import Notifier
|
from synapse.notifier import Notifier
|
||||||
@ -124,6 +128,9 @@ class HomeServer(object):
|
|||||||
'http_client_context_factory',
|
'http_client_context_factory',
|
||||||
'simple_http_client',
|
'simple_http_client',
|
||||||
'media_repository',
|
'media_repository',
|
||||||
|
'federation_transport_client',
|
||||||
|
'federation_sender',
|
||||||
|
'receipts_handler',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, **kwargs):
|
||||||
@ -265,9 +272,30 @@ class HomeServer(object):
|
|||||||
def build_media_repository(self):
|
def build_media_repository(self):
|
||||||
return MediaRepository(self)
|
return MediaRepository(self)
|
||||||
|
|
||||||
|
def build_federation_transport_client(self):
|
||||||
|
return TransportLayerClient(self)
|
||||||
|
|
||||||
|
def build_federation_sender(self):
|
||||||
|
if self.should_send_federation():
|
||||||
|
return TransactionQueue(self)
|
||||||
|
elif not self.config.worker_app:
|
||||||
|
return FederationRemoteSendQueue(self)
|
||||||
|
else:
|
||||||
|
raise Exception("Workers cannot send federation traffic")
|
||||||
|
|
||||||
|
def build_receipts_handler(self):
|
||||||
|
return ReceiptsHandler(self)
|
||||||
|
|
||||||
def remove_pusher(self, app_id, push_key, user_id):
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
|
def should_send_federation(self):
|
||||||
|
"Should this server be sending federation traffic directly?"
|
||||||
|
return self.config.send_federation and (
|
||||||
|
not self.config.worker_app
|
||||||
|
or self.config.worker_app == "synapse.app.federation_sender"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_dependency_method(depname):
|
def _make_dependency_method(depname):
|
||||||
def _get(hs):
|
def _get(hs):
|
||||||
|
@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||||
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
||||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||||
@ -223,6 +222,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||||
|
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||||
|
|
||||||
super(DataStore, self).__init__(hs)
|
super(DataStore, self).__init__(hs)
|
||||||
|
|
||||||
|
@ -561,12 +561,17 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||||
|
if keyvalues:
|
||||||
|
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
||||||
|
else:
|
||||||
|
where = ""
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
"SELECT %(retcol)s FROM %(table)s %(where)s"
|
||||||
) % {
|
) % {
|
||||||
"retcol": retcol,
|
"retcol": retcol,
|
||||||
"table": table,
|
"table": table,
|
||||||
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
|
"where": where,
|
||||||
}
|
}
|
||||||
|
|
||||||
txn.execute(sql, keyvalues.values())
|
txn.execute(sql, keyvalues.values())
|
||||||
@ -744,10 +749,15 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
if keyvalues:
|
||||||
|
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
||||||
|
else:
|
||||||
|
where = ""
|
||||||
|
|
||||||
|
update_sql = "UPDATE %s SET %s %s" % (
|
||||||
table,
|
table,
|
||||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
where,
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -39,6 +39,14 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||||||
def get_app_services(self):
|
def get_app_services(self):
|
||||||
return self.services_cache
|
return self.services_cache
|
||||||
|
|
||||||
|
def get_if_app_services_interested_in_user(self, user_id):
|
||||||
|
"""Check if the user is one associated with an app service
|
||||||
|
"""
|
||||||
|
for service in self.services_cache:
|
||||||
|
if service.is_interested_in_user(user_id):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_app_service_by_user_id(self, user_id):
|
def get_app_service_by_user_id(self, user_id):
|
||||||
"""Retrieve an application service from their user ID.
|
"""Retrieve an application service from their user ID.
|
||||||
|
|
||||||
|
@ -242,7 +242,7 @@ class DeviceInboxStore(SQLBaseStore):
|
|||||||
device_id(str): The recipient device_id.
|
device_id(str): The recipient device_id.
|
||||||
up_to_stream_id(int): Where to delete messages up to.
|
up_to_stream_id(int): Where to delete messages up to.
|
||||||
Returns:
|
Returns:
|
||||||
A deferred that resolves when the messages have been deleted.
|
A deferred that resolves to the number of messages deleted.
|
||||||
"""
|
"""
|
||||||
def delete_messages_for_device_txn(txn):
|
def delete_messages_for_device_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
@ -251,6 +251,7 @@ class DeviceInboxStore(SQLBaseStore):
|
|||||||
" AND stream_id <= ?"
|
" AND stream_id <= ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
||||||
|
return txn.rowcount
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"delete_messages_for_device", delete_messages_for_device_txn
|
"delete_messages_for_device", delete_messages_for_device_txn
|
||||||
@ -269,27 +270,29 @@ class DeviceInboxStore(SQLBaseStore):
|
|||||||
return defer.succeed([])
|
return defer.succeed([])
|
||||||
|
|
||||||
def get_all_new_device_messages_txn(txn):
|
def get_all_new_device_messages_txn(txn):
|
||||||
|
# We limit like this as we might have multiple rows per stream_id, and
|
||||||
|
# we want to make sure we always get all entries for any stream_id
|
||||||
|
# we return.
|
||||||
|
upper_pos = min(current_pos, last_pos + limit)
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_id FROM device_inbox"
|
"SELECT stream_id, user_id"
|
||||||
" WHERE ? < stream_id AND stream_id <= ?"
|
|
||||||
" GROUP BY stream_id"
|
|
||||||
" ORDER BY stream_id ASC"
|
|
||||||
" LIMIT ?"
|
|
||||||
)
|
|
||||||
txn.execute(sql, (last_pos, current_pos, limit))
|
|
||||||
stream_ids = txn.fetchall()
|
|
||||||
if not stream_ids:
|
|
||||||
return []
|
|
||||||
max_stream_id_in_limit = stream_ids[-1]
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT stream_id, user_id, device_id, message_json"
|
|
||||||
" FROM device_inbox"
|
" FROM device_inbox"
|
||||||
" WHERE ? < stream_id AND stream_id <= ?"
|
" WHERE ? < stream_id AND stream_id <= ?"
|
||||||
" ORDER BY stream_id ASC"
|
" ORDER BY stream_id ASC"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (last_pos, max_stream_id_in_limit))
|
txn.execute(sql, (last_pos, upper_pos))
|
||||||
return txn.fetchall()
|
rows = txn.fetchall()
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_id, destination"
|
||||||
|
" FROM device_federation_outbox"
|
||||||
|
" WHERE ? < stream_id AND stream_id <= ?"
|
||||||
|
" ORDER BY stream_id ASC"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (last_pos, upper_pos))
|
||||||
|
rows.extend(txn.fetchall())
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||||
|
@ -39,6 +39,14 @@ class EventPushActionsStore(SQLBaseStore):
|
|||||||
columns=["user_id", "stream_ordering"],
|
columns=["user_id", "stream_ordering"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"event_push_actions_highlights_index",
|
||||||
|
index_name="event_push_actions_highlights_index",
|
||||||
|
table="event_push_actions",
|
||||||
|
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
|
||||||
|
where_clause="highlight=1"
|
||||||
|
)
|
||||||
|
|
||||||
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
|
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -88,8 +96,11 @@ class EventPushActionsStore(SQLBaseStore):
|
|||||||
topological_ordering, stream_ordering
|
topological_ordering, stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# First get number of notifications.
|
||||||
|
# We don't need to put a notif=1 clause as all rows always have
|
||||||
|
# notif=1
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT sum(notif), sum(highlight)"
|
"SELECT count(*)"
|
||||||
" FROM event_push_actions ea"
|
" FROM event_push_actions ea"
|
||||||
" WHERE"
|
" WHERE"
|
||||||
" user_id = ?"
|
" user_id = ?"
|
||||||
@ -99,13 +110,27 @@ class EventPushActionsStore(SQLBaseStore):
|
|||||||
|
|
||||||
txn.execute(sql, (user_id, room_id))
|
txn.execute(sql, (user_id, room_id))
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
if row:
|
notify_count = row[0] if row else 0
|
||||||
|
|
||||||
|
# Now get the number of highlights
|
||||||
|
sql = (
|
||||||
|
"SELECT count(*)"
|
||||||
|
" FROM event_push_actions ea"
|
||||||
|
" WHERE"
|
||||||
|
" highlight = 1"
|
||||||
|
" AND user_id = ?"
|
||||||
|
" AND room_id = ?"
|
||||||
|
" AND %s"
|
||||||
|
) % (lower_bound(token, self.database_engine, inclusive=False),)
|
||||||
|
|
||||||
|
txn.execute(sql, (user_id, room_id))
|
||||||
|
row = txn.fetchone()
|
||||||
|
highlight_count = row[0] if row else 0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"notify_count": row[0] or 0,
|
"notify_count": notify_count,
|
||||||
"highlight_count": row[1] or 0,
|
"highlight_count": highlight_count,
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
return {"notify_count": 0, "highlight_count": 0}
|
|
||||||
|
|
||||||
ret = yield self.runInteraction(
|
ret = yield self.runInteraction(
|
||||||
"get_unread_event_push_actions_by_room",
|
"get_unread_event_push_actions_by_room",
|
||||||
|
@ -54,6 +54,7 @@ def encode_json(json_object):
|
|||||||
else:
|
else:
|
||||||
return json.dumps(json_object, ensure_ascii=False)
|
return json.dumps(json_object, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
# These values are used in the `enqueus_event` and `_do_fetch` methods to
|
# These values are used in the `enqueus_event` and `_do_fetch` methods to
|
||||||
# control how we batch/bulk fetch events from the database.
|
# control how we batch/bulk fetch events from the database.
|
||||||
# The values are plucked out of thing air to make initial sync run faster
|
# The values are plucked out of thing air to make initial sync run faster
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.api.errors import SynapseError, Codes
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
@ -24,6 +25,13 @@ import simplejson as json
|
|||||||
class FilteringStore(SQLBaseStore):
|
class FilteringStore(SQLBaseStore):
|
||||||
@cachedInlineCallbacks(num_args=2)
|
@cachedInlineCallbacks(num_args=2)
|
||||||
def get_user_filter(self, user_localpart, filter_id):
|
def get_user_filter(self, user_localpart, filter_id):
|
||||||
|
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||||
|
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||||
|
try:
|
||||||
|
int(filter_id)
|
||||||
|
except ValueError:
|
||||||
|
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
def_json = yield self._simple_select_one_onecol(
|
def_json = yield self._simple_select_one_onecol(
|
||||||
table="user_filters",
|
table="user_filters",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 38
|
SCHEMA_VERSION = 39
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
@ -37,6 +37,13 @@ class UserPresenceState(namedtuple("UserPresenceState",
|
|||||||
status_msg (str): User set status message.
|
status_msg (str): User set status message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return dict(self._asdict())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(d):
|
||||||
|
return UserPresenceState(**d)
|
||||||
|
|
||||||
def copy_and_replace(self, **kwargs):
|
def copy_and_replace(self, **kwargs):
|
||||||
return self._replace(**kwargs)
|
return self._replace(**kwargs)
|
||||||
|
|
||||||
|
@ -156,12 +156,20 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
|
||||||
local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u))
|
# We ignore app service users for now. This is so that we don't fill
|
||||||
|
# up the `get_if_users_have_pushers` cache with AS entries that we
|
||||||
|
# know don't have pushers, nor even read receipts.
|
||||||
|
local_users_in_room = set(
|
||||||
|
u for u in users_in_room
|
||||||
|
if self.hs.is_mine_id(u)
|
||||||
|
and not self.get_if_app_services_interested_in_user(u)
|
||||||
|
)
|
||||||
|
|
||||||
# users in the room who have pushers need to get push rules run because
|
# users in the room who have pushers need to get push rules run because
|
||||||
# that's how their pushers work
|
# that's how their pushers work
|
||||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||||
local_users_in_room, on_invalidate=cache_context.invalidate,
|
local_users_in_room,
|
||||||
|
on_invalidate=cache_context.invalidate,
|
||||||
)
|
)
|
||||||
user_ids = set(
|
user_ids = set(
|
||||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||||
|
@ -405,7 +405,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
room_id, receipt_type, user_id, event_ids, data
|
room_id, receipt_type, user_id, event_ids, data
|
||||||
)
|
)
|
||||||
|
|
||||||
max_persisted_id = self._stream_id_gen.get_current_token()
|
max_persisted_id = self._receipts_id_gen.get_current_token()
|
||||||
|
|
||||||
defer.returnValue((stream_id, max_persisted_id))
|
defer.returnValue((stream_id, max_persisted_id))
|
||||||
|
|
||||||
|
@ -68,31 +68,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
desc="add_access_token_to_user",
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def add_refresh_token_to_user(self, user_id, token, device_id=None):
|
|
||||||
"""Adds a refresh token for the given user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): The user ID.
|
|
||||||
token (str): The new refresh token to add.
|
|
||||||
device_id (str): ID of the device to associate with the access
|
|
||||||
token
|
|
||||||
Raises:
|
|
||||||
StoreError if there was a problem adding this.
|
|
||||||
"""
|
|
||||||
next_id = self._refresh_tokens_id_gen.get_next()
|
|
||||||
|
|
||||||
yield self._simple_insert(
|
|
||||||
"refresh_tokens",
|
|
||||||
{
|
|
||||||
"id": next_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"token": token,
|
|
||||||
"device_id": device_id,
|
|
||||||
},
|
|
||||||
desc="add_refresh_token_to_user",
|
|
||||||
)
|
|
||||||
|
|
||||||
def register(self, user_id, token=None, password_hash=None,
|
def register(self, user_id, token=None, password_hash=None,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_localpart=None, admin=False):
|
create_profile_with_localpart=None, admin=False):
|
||||||
@ -353,47 +328,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
token
|
token
|
||||||
)
|
)
|
||||||
|
|
||||||
def exchange_refresh_token(self, refresh_token, token_generator):
|
|
||||||
"""Exchange a refresh token for a new one.
|
|
||||||
|
|
||||||
Doing so invalidates the old refresh token - refresh tokens are single
|
|
||||||
use.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
refresh_token (str): The refresh token of a user.
|
|
||||||
token_generator (fn: str -> str): Function which, when given a
|
|
||||||
user ID, returns a unique refresh token for that user. This
|
|
||||||
function must never return the same value twice.
|
|
||||||
Returns:
|
|
||||||
tuple of (user_id, new_refresh_token, device_id)
|
|
||||||
Raises:
|
|
||||||
StoreError if no user was found with that refresh token.
|
|
||||||
"""
|
|
||||||
return self.runInteraction(
|
|
||||||
"exchange_refresh_token",
|
|
||||||
self._exchange_refresh_token,
|
|
||||||
refresh_token,
|
|
||||||
token_generator
|
|
||||||
)
|
|
||||||
|
|
||||||
def _exchange_refresh_token(self, txn, old_token, token_generator):
|
|
||||||
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
|
|
||||||
txn.execute(sql, (old_token,))
|
|
||||||
rows = self.cursor_to_dict(txn)
|
|
||||||
if not rows:
|
|
||||||
raise StoreError(403, "Did not recognize refresh token")
|
|
||||||
user_id = rows[0]["user_id"]
|
|
||||||
device_id = rows[0]["device_id"]
|
|
||||||
|
|
||||||
# TODO(danielwh): Maybe perform a validation on the macaroon that
|
|
||||||
# macaroon.user_id == user_id.
|
|
||||||
|
|
||||||
new_token = token_generator(user_id)
|
|
||||||
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
|
|
||||||
txn.execute(sql, (new_token, old_token,))
|
|
||||||
|
|
||||||
return user_id, new_token, device_id
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def is_server_admin(self, user):
|
def is_server_admin(self, user):
|
||||||
res = yield self._simple_select_one_onecol(
|
res = yield self._simple_select_one_onecol(
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from .engines import PostgresEngine, Sqlite3Engine
|
from .engines import PostgresEngine, Sqlite3Engine
|
||||||
@ -106,7 +107,11 @@ class RoomStore(SQLBaseStore):
|
|||||||
entries = self._simple_select_list_txn(
|
entries = self._simple_select_list_txn(
|
||||||
txn,
|
txn,
|
||||||
table="public_room_list_stream",
|
table="public_room_list_stream",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"appservice_id": None,
|
||||||
|
"network_id": None,
|
||||||
|
},
|
||||||
retcols=("stream_id", "visibility"),
|
retcols=("stream_id", "visibility"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -124,6 +129,8 @@ class RoomStore(SQLBaseStore):
|
|||||||
"stream_id": next_id,
|
"stream_id": next_id,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"visibility": is_public,
|
"visibility": is_public,
|
||||||
|
"appservice_id": None,
|
||||||
|
"network_id": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -132,6 +139,87 @@ class RoomStore(SQLBaseStore):
|
|||||||
"set_room_is_public",
|
"set_room_is_public",
|
||||||
set_room_is_public_txn, next_id,
|
set_room_is_public_txn, next_id,
|
||||||
)
|
)
|
||||||
|
self.hs.get_notifier().on_new_replication_data()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
|
||||||
|
is_public):
|
||||||
|
"""Edit the appservice/network specific public room list.
|
||||||
|
|
||||||
|
Each appservice can have a number of published room lists associated
|
||||||
|
with them, keyed off of an appservice defined `network_id`, which
|
||||||
|
basically represents a single instance of a bridge to a third party
|
||||||
|
network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
appservice_id (str)
|
||||||
|
network_id (str)
|
||||||
|
is_public (bool): Whether to publish or unpublish the room from the
|
||||||
|
list.
|
||||||
|
"""
|
||||||
|
def set_room_is_public_appservice_txn(txn, next_id):
|
||||||
|
if is_public:
|
||||||
|
try:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="appservice_room_list",
|
||||||
|
values={
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": network_id,
|
||||||
|
"room_id": room_id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except self.database_engine.module.IntegrityError:
|
||||||
|
# We've already inserted, nothing to do.
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="appservice_room_list",
|
||||||
|
keyvalues={
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": network_id,
|
||||||
|
"room_id": room_id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
entries = self._simple_select_list_txn(
|
||||||
|
txn,
|
||||||
|
table="public_room_list_stream",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": network_id,
|
||||||
|
},
|
||||||
|
retcols=("stream_id", "visibility"),
|
||||||
|
)
|
||||||
|
|
||||||
|
entries.sort(key=lambda r: r["stream_id"])
|
||||||
|
|
||||||
|
add_to_stream = True
|
||||||
|
if entries:
|
||||||
|
add_to_stream = bool(entries[-1]["visibility"]) != is_public
|
||||||
|
|
||||||
|
if add_to_stream:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="public_room_list_stream",
|
||||||
|
values={
|
||||||
|
"stream_id": next_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"visibility": is_public,
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": network_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._public_room_id_gen.get_next() as next_id:
|
||||||
|
yield self.runInteraction(
|
||||||
|
"set_room_is_public_appservice",
|
||||||
|
set_room_is_public_appservice_txn, next_id,
|
||||||
|
)
|
||||||
|
self.hs.get_notifier().on_new_replication_data()
|
||||||
|
|
||||||
def get_public_room_ids(self):
|
def get_public_room_ids(self):
|
||||||
return self._simple_select_onecol(
|
return self._simple_select_onecol(
|
||||||
@ -259,38 +347,96 @@ class RoomStore(SQLBaseStore):
|
|||||||
def get_current_public_room_stream_id(self):
|
def get_current_public_room_stream_id(self):
|
||||||
return self._public_room_id_gen.get_current_token()
|
return self._public_room_id_gen.get_current_token()
|
||||||
|
|
||||||
def get_public_room_ids_at_stream_id(self, stream_id):
|
@cached(num_args=2, max_entries=100)
|
||||||
|
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
|
||||||
|
"""Get pulbic rooms for a particular list, or across all lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id (int)
|
||||||
|
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
|
||||||
|
means the main list, None means all lsits.
|
||||||
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_public_room_ids_at_stream_id",
|
"get_public_room_ids_at_stream_id",
|
||||||
self.get_public_room_ids_at_stream_id_txn, stream_id
|
self.get_public_room_ids_at_stream_id_txn,
|
||||||
|
stream_id, network_tuple=network_tuple
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id):
|
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
|
||||||
|
network_tuple):
|
||||||
return {
|
return {
|
||||||
rm
|
rm
|
||||||
for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items()
|
for rm, vis in self.get_published_at_stream_id_txn(
|
||||||
|
txn, stream_id, network_tuple=network_tuple
|
||||||
|
).items()
|
||||||
if vis
|
if vis
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_published_at_stream_id_txn(self, txn, stream_id):
|
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
|
||||||
|
if network_tuple:
|
||||||
|
# We want to get from a particular list. No aggregation required.
|
||||||
|
|
||||||
sql = ("""
|
sql = ("""
|
||||||
SELECT room_id, visibility FROM public_room_list_stream
|
SELECT room_id, visibility FROM public_room_list_stream
|
||||||
INNER JOIN (
|
INNER JOIN (
|
||||||
SELECT room_id, max(stream_id) AS stream_id
|
SELECT room_id, max(stream_id) AS stream_id
|
||||||
FROM public_room_list_stream
|
FROM public_room_list_stream
|
||||||
WHERE stream_id <= ?
|
WHERE stream_id <= ? %s
|
||||||
GROUP BY room_id
|
GROUP BY room_id
|
||||||
) grouped USING (room_id, stream_id)
|
) grouped USING (room_id, stream_id)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
txn.execute(sql, (stream_id,))
|
if network_tuple.appservice_id is not None:
|
||||||
|
txn.execute(
|
||||||
|
sql % ("AND appservice_id = ? AND network_id = ?",),
|
||||||
|
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
txn.execute(
|
||||||
|
sql % ("AND appservice_id IS NULL",),
|
||||||
|
(stream_id,)
|
||||||
|
)
|
||||||
return dict(txn.fetchall())
|
return dict(txn.fetchall())
|
||||||
|
else:
|
||||||
|
# We want to get from all lists, so we need to aggregate the results
|
||||||
|
|
||||||
def get_public_room_changes(self, prev_stream_id, new_stream_id):
|
logger.info("Executing full list")
|
||||||
|
|
||||||
|
sql = ("""
|
||||||
|
SELECT room_id, visibility
|
||||||
|
FROM public_room_list_stream
|
||||||
|
INNER JOIN (
|
||||||
|
SELECT
|
||||||
|
room_id, max(stream_id) AS stream_id, appservice_id,
|
||||||
|
network_id
|
||||||
|
FROM public_room_list_stream
|
||||||
|
WHERE stream_id <= ?
|
||||||
|
GROUP BY room_id, appservice_id, network_id
|
||||||
|
) grouped USING (room_id, stream_id)
|
||||||
|
""")
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(stream_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
# A room is visible if its visible on any list.
|
||||||
|
for room_id, visibility in txn.fetchall():
|
||||||
|
results[room_id] = bool(visibility) or results.get(room_id, False)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_public_room_changes(self, prev_stream_id, new_stream_id,
|
||||||
|
network_tuple):
|
||||||
def get_public_room_changes_txn(txn):
|
def get_public_room_changes_txn(txn):
|
||||||
then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id)
|
then_rooms = self.get_public_room_ids_at_stream_id_txn(
|
||||||
|
txn, prev_stream_id, network_tuple
|
||||||
|
)
|
||||||
|
|
||||||
now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id)
|
now_rooms_dict = self.get_published_at_stream_id_txn(
|
||||||
|
txn, new_stream_id, network_tuple
|
||||||
|
)
|
||||||
|
|
||||||
now_rooms_visible = set(
|
now_rooms_visible = set(
|
||||||
rm for rm, vis in now_rooms_dict.items() if vis
|
rm for rm, vis in now_rooms_dict.items() if vis
|
||||||
@ -311,7 +457,8 @@ class RoomStore(SQLBaseStore):
|
|||||||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||||
def get_all_new_public_rooms(txn):
|
def get_all_new_public_rooms(txn):
|
||||||
sql = ("""
|
sql = ("""
|
||||||
SELECT stream_id, room_id, visibility FROM public_room_list_stream
|
SELECT stream_id, room_id, visibility, appservice_id, network_id
|
||||||
|
FROM public_room_list_stream
|
||||||
WHERE stream_id > ? AND stream_id <= ?
|
WHERE stream_id > ? AND stream_id <= ?
|
||||||
ORDER BY stream_id ASC
|
ORDER BY stream_id ASC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
|
@ -24,6 +24,7 @@ from synapse.api.constants import Membership, EventTypes
|
|||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -34,7 +35,15 @@ RoomsForUser = namedtuple(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberStore(SQLBaseStore):
|
class RoomMemberStore(SQLBaseStore):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomMemberStore, self).__init__(hs)
|
||||||
|
self.register_background_update_handler(
|
||||||
|
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
||||||
|
)
|
||||||
|
|
||||||
def _store_room_members_txn(self, txn, events, backfilled):
|
def _store_room_members_txn(self, txn, events, backfilled):
|
||||||
"""Store a room member in the database.
|
"""Store a room member in the database.
|
||||||
@ -49,6 +58,8 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
"sender": event.user_id,
|
"sender": event.user_id,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"membership": event.membership,
|
"membership": event.membership,
|
||||||
|
"display_name": event.content.get("displayname", None),
|
||||||
|
"avatar_url": event.content.get("avatar_url", None),
|
||||||
}
|
}
|
||||||
for event in events
|
for event in events
|
||||||
]
|
]
|
||||||
@ -398,7 +409,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
table="room_memberships",
|
table="room_memberships",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=member_event_ids,
|
iterable=member_event_ids,
|
||||||
retcols=['user_id'],
|
retcols=['user_id', 'display_name', 'avatar_url'],
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"membership": Membership.JOIN,
|
"membership": Membership.JOIN,
|
||||||
},
|
},
|
||||||
@ -406,11 +417,21 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
desc="_get_joined_users_from_context",
|
desc="_get_joined_users_from_context",
|
||||||
)
|
)
|
||||||
|
|
||||||
users_in_room = set(row["user_id"] for row in rows)
|
users_in_room = {
|
||||||
|
row["user_id"]: {
|
||||||
|
"display_name": row["display_name"],
|
||||||
|
"avatar_url": row["avatar_url"],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
}
|
||||||
|
|
||||||
if event is not None and event.type == EventTypes.Member:
|
if event is not None and event.type == EventTypes.Member:
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if event.event_id in member_event_ids:
|
if event.event_id in member_event_ids:
|
||||||
users_in_room.add(event.state_key)
|
users_in_room[event.state_key] = {
|
||||||
|
"display_name": event.content.get("displayname", None),
|
||||||
|
"avatar_url": event.content.get("avatar_url", None),
|
||||||
|
}
|
||||||
|
|
||||||
defer.returnValue(users_in_room)
|
defer.returnValue(users_in_room)
|
||||||
|
|
||||||
@ -448,3 +469,78 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
|
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _background_add_membership_profile(self, progress, batch_size):
|
||||||
|
target_min_stream_id = progress.get(
|
||||||
|
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
||||||
|
)
|
||||||
|
max_stream_id = progress.get(
|
||||||
|
"max_stream_id_exclusive", self._stream_order_on_start + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
INSERT_CLUMP_SIZE = 1000
|
||||||
|
|
||||||
|
def add_membership_profile_txn(txn):
|
||||||
|
sql = ("""
|
||||||
|
SELECT stream_ordering, event_id, events.room_id, content
|
||||||
|
FROM events
|
||||||
|
INNER JOIN room_memberships USING (event_id)
|
||||||
|
WHERE ? <= stream_ordering AND stream_ordering < ?
|
||||||
|
AND type = 'm.room.member'
|
||||||
|
ORDER BY stream_ordering DESC
|
||||||
|
LIMIT ?
|
||||||
|
""")
|
||||||
|
|
||||||
|
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
|
||||||
|
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
if not rows:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
min_stream_id = rows[-1]["stream_ordering"]
|
||||||
|
|
||||||
|
to_update = []
|
||||||
|
for row in rows:
|
||||||
|
event_id = row["event_id"]
|
||||||
|
room_id = row["room_id"]
|
||||||
|
try:
|
||||||
|
content = json.loads(row["content"])
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
display_name = content.get("displayname", None)
|
||||||
|
avatar_url = content.get("avatar_url", None)
|
||||||
|
|
||||||
|
if display_name or avatar_url:
|
||||||
|
to_update.append((
|
||||||
|
display_name, avatar_url, event_id, room_id
|
||||||
|
))
|
||||||
|
|
||||||
|
to_update_sql = ("""
|
||||||
|
UPDATE room_memberships SET display_name = ?, avatar_url = ?
|
||||||
|
WHERE event_id = ? AND room_id = ?
|
||||||
|
""")
|
||||||
|
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
|
||||||
|
clump = to_update[index:index + INSERT_CLUMP_SIZE]
|
||||||
|
txn.executemany(to_update_sql, clump)
|
||||||
|
|
||||||
|
progress = {
|
||||||
|
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||||
|
"max_stream_id_exclusive": min_stream_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
self._background_update_progress_txn(
|
||||||
|
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(rows)
|
||||||
|
|
||||||
|
result = yield self.runInteraction(
|
||||||
|
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
29
synapse/storage/schema/delta/39/appservice_room_list.sql
Normal file
29
synapse/storage/schema/delta/39/appservice_room_list.sql
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE appservice_room_list(
|
||||||
|
appservice_id TEXT NOT NULL,
|
||||||
|
network_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Each appservice can have multiple published room lists associated with them,
|
||||||
|
-- keyed of a particular network_id
|
||||||
|
CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list(
|
||||||
|
appservice_id, network_id, room_id
|
||||||
|
);
|
||||||
|
|
||||||
|
ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT;
|
||||||
|
ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT;
|
@ -0,0 +1,16 @@
|
|||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id);
|
17
synapse/storage/schema/delta/39/event_push_index.sql
Normal file
17
synapse/storage/schema/delta/39/event_push_index.sql
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('event_push_actions_highlights_index', '{}');
|
22
synapse/storage/schema/delta/39/federation_out_position.sql
Normal file
22
synapse/storage/schema/delta/39/federation_out_position.sql
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE federation_stream_position(
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
stream_id INTEGER NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1);
|
||||||
|
INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events;
|
20
synapse/storage/schema/delta/39/membership_profile.sql
Normal file
20
synapse/storage/schema/delta/39/membership_profile.sql
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE room_memberships ADD COLUMN display_name TEXT;
|
||||||
|
ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT;
|
||||||
|
|
||||||
|
INSERT into background_updates (update_name, progress_json)
|
||||||
|
VALUES ('room_membership_profile_update', '{}');
|
@ -653,7 +653,10 @@ class StateStore(SQLBaseStore):
|
|||||||
else:
|
else:
|
||||||
state_dict = results[group]
|
state_dict = results[group]
|
||||||
|
|
||||||
state_dict.update(group_state_dict)
|
state_dict.update({
|
||||||
|
(intern_string(k[0]), intern_string(k[1])): v
|
||||||
|
for k, v in group_state_dict.items()
|
||||||
|
})
|
||||||
|
|
||||||
self._state_group_cache.update(
|
self._state_group_cache.update(
|
||||||
cache_seq_num,
|
cache_seq_num,
|
||||||
|
@ -541,6 +541,9 @@ class StreamStore(SQLBaseStore):
|
|||||||
def get_room_max_stream_ordering(self):
|
def get_room_max_stream_ordering(self):
|
||||||
return self._stream_id_gen.get_current_token()
|
return self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
|
def get_room_min_stream_ordering(self):
|
||||||
|
return self._backfill_id_gen.get_current_token()
|
||||||
|
|
||||||
def get_stream_token_for_event(self, event_id):
|
def get_stream_token_for_event(self, event_id):
|
||||||
"""The stream token for an event
|
"""The stream token for an event
|
||||||
Args:
|
Args:
|
||||||
@ -765,3 +768,50 @@ class StreamStore(SQLBaseStore):
|
|||||||
"token": end_token,
|
"token": end_token,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_all_new_events_stream(self, from_id, current_id, limit):
|
||||||
|
"""Get all new events"""
|
||||||
|
|
||||||
|
def get_all_new_events_stream_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT e.stream_ordering, e.event_id"
|
||||||
|
" FROM events AS e"
|
||||||
|
" WHERE"
|
||||||
|
" ? < e.stream_ordering AND e.stream_ordering <= ?"
|
||||||
|
" ORDER BY e.stream_ordering ASC"
|
||||||
|
" LIMIT ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (from_id, current_id, limit))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
|
||||||
|
upper_bound = current_id
|
||||||
|
if len(rows) == limit:
|
||||||
|
upper_bound = rows[-1][0]
|
||||||
|
|
||||||
|
return upper_bound, [row[1] for row in rows]
|
||||||
|
|
||||||
|
upper_bound, event_ids = yield self.runInteraction(
|
||||||
|
"get_all_new_events_stream", get_all_new_events_stream_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
events = yield self._get_events(event_ids)
|
||||||
|
|
||||||
|
defer.returnValue((upper_bound, events))
|
||||||
|
|
||||||
|
def get_federation_out_pos(self, typ):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="federation_stream_position",
|
||||||
|
retcol="stream_id",
|
||||||
|
keyvalues={"type": typ},
|
||||||
|
desc="get_federation_out_pos"
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_federation_out_pos(self, typ, stream_id):
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="federation_stream_position",
|
||||||
|
keyvalues={"type": typ},
|
||||||
|
updatevalues={"stream_id": stream_id},
|
||||||
|
desc="update_federation_out_pos",
|
||||||
|
)
|
||||||
|
@ -200,23 +200,46 @@ class TransactionStore(SQLBaseStore):
|
|||||||
|
|
||||||
def _set_destination_retry_timings(self, txn, destination,
|
def _set_destination_retry_timings(self, txn, destination,
|
||||||
retry_last_ts, retry_interval):
|
retry_last_ts, retry_interval):
|
||||||
txn.call_after(self.get_destination_retry_timings.invalidate, (destination,))
|
self.database_engine.lock_table(txn, "destinations")
|
||||||
|
|
||||||
self._simple_upsert_txn(
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_destination_retry_timings, (destination,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# We need to be careful here as the data may have changed from under us
|
||||||
|
# due to a worker setting the timings.
|
||||||
|
|
||||||
|
prev_row = self._simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
table="destinations",
|
||||||
|
keyvalues={
|
||||||
|
"destination": destination,
|
||||||
|
},
|
||||||
|
retcols=("retry_last_ts", "retry_interval"),
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not prev_row:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="destinations",
|
||||||
|
values={
|
||||||
|
"destination": destination,
|
||||||
|
"retry_last_ts": retry_last_ts,
|
||||||
|
"retry_interval": retry_interval,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
|
||||||
|
self._simple_update_one_txn(
|
||||||
txn,
|
txn,
|
||||||
"destinations",
|
"destinations",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"destination": destination,
|
"destination": destination,
|
||||||
},
|
},
|
||||||
values={
|
updatevalues={
|
||||||
"retry_last_ts": retry_last_ts,
|
"retry_last_ts": retry_last_ts,
|
||||||
"retry_interval": retry_interval,
|
"retry_interval": retry_interval,
|
||||||
},
|
},
|
||||||
insertion_values={
|
|
||||||
"destination": destination,
|
|
||||||
"retry_last_ts": retry_last_ts,
|
|
||||||
"retry_interval": retry_interval,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_destinations_needing_retry(self):
|
def get_destinations_needing_retry(self):
|
||||||
|
@ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||||||
return "t%d-%d" % (self.topological, self.stream)
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
else:
|
else:
|
||||||
return "s%d" % (self.stream,)
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
|
class ThirdPartyInstanceID(
|
||||||
|
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
|
||||||
|
):
|
||||||
|
# Deny iteration because it will bite you if you try to create a singleton
|
||||||
|
# set by:
|
||||||
|
# users = set(user)
|
||||||
|
def __iter__(self):
|
||||||
|
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
|
||||||
|
|
||||||
|
# Because this class is a namedtuple of strings, it is deeply immutable.
|
||||||
|
def __copy__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __deepcopy__(self, memo):
|
||||||
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, s):
|
||||||
|
bits = s.split("|", 2)
|
||||||
|
if len(bits) != 2:
|
||||||
|
raise SynapseError(400, "Invalid ID %r" % (s,))
|
||||||
|
|
||||||
|
return cls(appservice_id=bits[0], network_id=bits[1])
|
||||||
|
|
||||||
|
def to_string(self):
|
||||||
|
return "%s|%s" % (self.appservice_id, self.network_id,)
|
||||||
|
|
||||||
|
__str__ = to_string
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, appservice_id, network_id,):
|
||||||
|
return cls(appservice_id=appservice_id, network_id=network_id)
|
||||||
|
@ -24,6 +24,11 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeferredTimedOutError(SynapseError):
|
||||||
|
def __init__(self):
|
||||||
|
super(SynapseError).__init__(504, "Timed out")
|
||||||
|
|
||||||
|
|
||||||
def unwrapFirstError(failure):
|
def unwrapFirstError(failure):
|
||||||
# defer.gatherResults and DeferredLists wrap failures.
|
# defer.gatherResults and DeferredLists wrap failures.
|
||||||
failure.trap(defer.FirstError)
|
failure.trap(defer.FirstError)
|
||||||
@ -89,7 +94,7 @@ class Clock(object):
|
|||||||
|
|
||||||
def timed_out_fn():
|
def timed_out_fn():
|
||||||
try:
|
try:
|
||||||
ret_deferred.errback(SynapseError(504, "Timed out"))
|
ret_deferred.errback(DeferredTimedOutError())
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -197,6 +197,64 @@ class Linearizer(object):
|
|||||||
defer.returnValue(_ctx_manager())
|
defer.returnValue(_ctx_manager())
|
||||||
|
|
||||||
|
|
||||||
|
class Limiter(object):
|
||||||
|
"""Limits concurrent access to resources based on a key. Useful to ensure
|
||||||
|
only a few thing happen at a time on a given resource.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
with (yield limiter.queue("test_key")):
|
||||||
|
# do some work.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, max_count):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
max_count(int): The maximum number of concurrent access
|
||||||
|
"""
|
||||||
|
self.max_count = max_count
|
||||||
|
|
||||||
|
# key_to_defer is a map from the key to a 2 element list where
|
||||||
|
# the first element is the number of things executing
|
||||||
|
# the second element is a list of deferreds for the things blocked from
|
||||||
|
# executing.
|
||||||
|
self.key_to_defer = {}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def queue(self, key):
|
||||||
|
entry = self.key_to_defer.setdefault(key, [0, []])
|
||||||
|
|
||||||
|
# If the number of things executing is greater than the maximum
|
||||||
|
# then add a deferred to the list of blocked items
|
||||||
|
# When on of the things currently executing finishes it will callback
|
||||||
|
# this item so that it can continue executing.
|
||||||
|
if entry[0] >= self.max_count:
|
||||||
|
new_defer = defer.Deferred()
|
||||||
|
entry[1].append(new_defer)
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
yield new_defer
|
||||||
|
|
||||||
|
entry[0] += 1
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _ctx_manager():
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# We've finished executing so check if there are any things
|
||||||
|
# blocked waiting to execute and start one of them
|
||||||
|
entry[0] -= 1
|
||||||
|
try:
|
||||||
|
entry[1].pop(0).callback(None)
|
||||||
|
except IndexError:
|
||||||
|
# If nothing else is executing for this key then remove it
|
||||||
|
# from the map
|
||||||
|
if entry[0] == 0:
|
||||||
|
self.key_to_defer.pop(key, None)
|
||||||
|
|
||||||
|
defer.returnValue(_ctx_manager())
|
||||||
|
|
||||||
|
|
||||||
class ReadWriteLock(object):
|
class ReadWriteLock(object):
|
||||||
"""A deferred style read write lock.
|
"""A deferred style read write lock.
|
||||||
|
|
||||||
|
@ -76,15 +76,26 @@ class JsonEncodedObject(object):
|
|||||||
d.update(self.unrecognized_keys)
|
d.update(self.unrecognized_keys)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def get_internal_dict(self):
|
||||||
|
d = {
|
||||||
|
k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
|
||||||
|
if k in self.valid_keys
|
||||||
|
}
|
||||||
|
d.update(self.unrecognized_keys)
|
||||||
|
return d
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
|
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
|
||||||
|
|
||||||
|
|
||||||
def _encode(obj):
|
def _encode(obj, internal=False):
|
||||||
if type(obj) is list:
|
if type(obj) is list:
|
||||||
return [_encode(o) for o in obj]
|
return [_encode(o, internal=internal) for o in obj]
|
||||||
|
|
||||||
if isinstance(obj, JsonEncodedObject):
|
if isinstance(obj, JsonEncodedObject):
|
||||||
|
if internal:
|
||||||
|
return obj.get_internal_dict()
|
||||||
|
else:
|
||||||
return obj.get_dict()
|
return obj.get_dict()
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
@ -1,369 +0,0 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.config._base import ConfigError
|
|
||||||
from synapse.types import UserID
|
|
||||||
|
|
||||||
import ldap3
|
|
||||||
import ldap3.core.exceptions
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ldap3
|
|
||||||
import ldap3.core.exceptions
|
|
||||||
except ImportError:
|
|
||||||
ldap3 = None
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LDAPMode(object):
|
|
||||||
SIMPLE = "simple",
|
|
||||||
SEARCH = "search",
|
|
||||||
|
|
||||||
LIST = (SIMPLE, SEARCH)
|
|
||||||
|
|
||||||
|
|
||||||
class LdapAuthProvider(object):
|
|
||||||
__version__ = "0.1"
|
|
||||||
|
|
||||||
def __init__(self, config, account_handler):
|
|
||||||
self.account_handler = account_handler
|
|
||||||
|
|
||||||
if not ldap3:
|
|
||||||
raise RuntimeError(
|
|
||||||
'Missing ldap3 library. This is required for LDAP Authentication.'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ldap_mode = config.mode
|
|
||||||
self.ldap_uri = config.uri
|
|
||||||
self.ldap_start_tls = config.start_tls
|
|
||||||
self.ldap_base = config.base
|
|
||||||
self.ldap_attributes = config.attributes
|
|
||||||
if self.ldap_mode == LDAPMode.SEARCH:
|
|
||||||
self.ldap_bind_dn = config.bind_dn
|
|
||||||
self.ldap_bind_password = config.bind_password
|
|
||||||
self.ldap_filter = config.filter
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def check_password(self, user_id, password):
|
|
||||||
""" Attempt to authenticate a user against an LDAP Server
|
|
||||||
and register an account if none exists.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if authentication against LDAP was successful
|
|
||||||
"""
|
|
||||||
localpart = UserID.from_string(user_id).localpart
|
|
||||||
|
|
||||||
try:
|
|
||||||
server = ldap3.Server(self.ldap_uri)
|
|
||||||
logger.debug(
|
|
||||||
"Attempting LDAP connection with %s",
|
|
||||||
self.ldap_uri
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.ldap_mode == LDAPMode.SIMPLE:
|
|
||||||
result, conn = self._ldap_simple_bind(
|
|
||||||
server=server, localpart=localpart, password=password
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
'LDAP authentication method simple bind returned: %s (conn: %s)',
|
|
||||||
result,
|
|
||||||
conn
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
defer.returnValue(False)
|
|
||||||
elif self.ldap_mode == LDAPMode.SEARCH:
|
|
||||||
result, conn = self._ldap_authenticated_search(
|
|
||||||
server=server, localpart=localpart, password=password
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
'LDAP auth method authenticated search returned: %s (conn: %s)',
|
|
||||||
result,
|
|
||||||
conn
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
defer.returnValue(False)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
'Invalid LDAP mode specified: {mode}'.format(
|
|
||||||
mode=self.ldap_mode
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(
|
|
||||||
"User authenticated against LDAP server: %s",
|
|
||||||
conn
|
|
||||||
)
|
|
||||||
except NameError:
|
|
||||||
logger.warn(
|
|
||||||
"Authentication method yielded no LDAP connection, aborting!"
|
|
||||||
)
|
|
||||||
defer.returnValue(False)
|
|
||||||
|
|
||||||
# check if user with user_id exists
|
|
||||||
if (yield self.account_handler.check_user_exists(user_id)):
|
|
||||||
# exists, authentication complete
|
|
||||||
conn.unbind()
|
|
||||||
defer.returnValue(True)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# does not exist, fetch metadata for account creation from
|
|
||||||
# existing ldap connection
|
|
||||||
query = "({prop}={value})".format(
|
|
||||||
prop=self.ldap_attributes['uid'],
|
|
||||||
value=localpart
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
|
|
||||||
query = "(&{filter}{user_filter})".format(
|
|
||||||
filter=query,
|
|
||||||
user_filter=self.ldap_filter
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"ldap registration filter: %s",
|
|
||||||
query
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.search(
|
|
||||||
search_base=self.ldap_base,
|
|
||||||
search_filter=query,
|
|
||||||
attributes=[
|
|
||||||
self.ldap_attributes['name'],
|
|
||||||
self.ldap_attributes['mail']
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(conn.response) == 1:
|
|
||||||
attrs = conn.response[0]['attributes']
|
|
||||||
mail = attrs[self.ldap_attributes['mail']][0]
|
|
||||||
name = attrs[self.ldap_attributes['name']][0]
|
|
||||||
|
|
||||||
# create account
|
|
||||||
user_id, access_token = (
|
|
||||||
yield self.account_handler.register(localpart=localpart)
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: bind email, set displayname with data from ldap directory
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Registration based on LDAP data was successful: %d: %s (%s, %)",
|
|
||||||
user_id,
|
|
||||||
localpart,
|
|
||||||
name,
|
|
||||||
mail
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(True)
|
|
||||||
else:
|
|
||||||
if len(conn.response) == 0:
|
|
||||||
logger.warn("LDAP registration failed, no result.")
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
"LDAP registration failed, too many results (%s)",
|
|
||||||
len(conn.response)
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(False)
|
|
||||||
|
|
||||||
defer.returnValue(False)
|
|
||||||
|
|
||||||
except ldap3.core.exceptions.LDAPException as e:
|
|
||||||
logger.warn("Error during ldap authentication: %s", e)
|
|
||||||
defer.returnValue(False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_config(config):
|
|
||||||
class _LdapConfig(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
ldap_config = _LdapConfig()
|
|
||||||
|
|
||||||
ldap_config.enabled = config.get("enabled", False)
|
|
||||||
|
|
||||||
ldap_config.mode = LDAPMode.SIMPLE
|
|
||||||
|
|
||||||
# verify config sanity
|
|
||||||
_require_keys(config, [
|
|
||||||
"uri",
|
|
||||||
"base",
|
|
||||||
"attributes",
|
|
||||||
])
|
|
||||||
|
|
||||||
ldap_config.uri = config["uri"]
|
|
||||||
ldap_config.start_tls = config.get("start_tls", False)
|
|
||||||
ldap_config.base = config["base"]
|
|
||||||
ldap_config.attributes = config["attributes"]
|
|
||||||
|
|
||||||
if "bind_dn" in config:
|
|
||||||
ldap_config.mode = LDAPMode.SEARCH
|
|
||||||
_require_keys(config, [
|
|
||||||
"bind_dn",
|
|
||||||
"bind_password",
|
|
||||||
])
|
|
||||||
|
|
||||||
ldap_config.bind_dn = config["bind_dn"]
|
|
||||||
ldap_config.bind_password = config["bind_password"]
|
|
||||||
ldap_config.filter = config.get("filter", None)
|
|
||||||
|
|
||||||
# verify attribute lookup
|
|
||||||
_require_keys(config['attributes'], [
|
|
||||||
"uid",
|
|
||||||
"name",
|
|
||||||
"mail",
|
|
||||||
])
|
|
||||||
|
|
||||||
return ldap_config
|
|
||||||
|
|
||||||
def _ldap_simple_bind(self, server, localpart, password):
|
|
||||||
""" Attempt a simple bind with the credentials
|
|
||||||
given by the user against the LDAP server.
|
|
||||||
|
|
||||||
Returns True, LDAP3Connection
|
|
||||||
if the bind was successful
|
|
||||||
Returns False, None
|
|
||||||
if an error occured
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# bind with the the local users ldap credentials
|
|
||||||
bind_dn = "{prop}={value},{base}".format(
|
|
||||||
prop=self.ldap_attributes['uid'],
|
|
||||||
value=localpart,
|
|
||||||
base=self.ldap_base
|
|
||||||
)
|
|
||||||
conn = ldap3.Connection(server, bind_dn, password,
|
|
||||||
authentication=ldap3.AUTH_SIMPLE)
|
|
||||||
logger.debug(
|
|
||||||
"Established LDAP connection in simple bind mode: %s",
|
|
||||||
conn
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.ldap_start_tls:
|
|
||||||
conn.start_tls()
|
|
||||||
logger.debug(
|
|
||||||
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
|
|
||||||
conn
|
|
||||||
)
|
|
||||||
|
|
||||||
if conn.bind():
|
|
||||||
# GOOD: bind okay
|
|
||||||
logger.debug("LDAP Bind successful in simple bind mode.")
|
|
||||||
return True, conn
|
|
||||||
|
|
||||||
# BAD: bind failed
|
|
||||||
logger.info(
|
|
||||||
"Binding against LDAP failed for '%s' failed: %s",
|
|
||||||
localpart, conn.result['description']
|
|
||||||
)
|
|
||||||
conn.unbind()
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
except ldap3.core.exceptions.LDAPException as e:
|
|
||||||
logger.warn("Error during LDAP authentication: %s", e)
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
def _ldap_authenticated_search(self, server, localpart, password):
|
|
||||||
""" Attempt to login with the preconfigured bind_dn
|
|
||||||
and then continue searching and filtering within
|
|
||||||
the base_dn
|
|
||||||
|
|
||||||
Returns (True, LDAP3Connection)
|
|
||||||
if a single matching DN within the base was found
|
|
||||||
that matched the filter expression, and with which
|
|
||||||
a successful bind was achieved
|
|
||||||
|
|
||||||
The LDAP3Connection returned is the instance that was used to
|
|
||||||
verify the password not the one using the configured bind_dn.
|
|
||||||
Returns (False, None)
|
|
||||||
if an error occured
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn = ldap3.Connection(
|
|
||||||
server,
|
|
||||||
self.ldap_bind_dn,
|
|
||||||
self.ldap_bind_password
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Established LDAP connection in search mode: %s",
|
|
||||||
conn
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.ldap_start_tls:
|
|
||||||
conn.start_tls()
|
|
||||||
logger.debug(
|
|
||||||
"Upgraded LDAP connection in search mode through StartTLS: %s",
|
|
||||||
conn
|
|
||||||
)
|
|
||||||
|
|
||||||
if not conn.bind():
|
|
||||||
logger.warn(
|
|
||||||
"Binding against LDAP with `bind_dn` failed: %s",
|
|
||||||
conn.result['description']
|
|
||||||
)
|
|
||||||
conn.unbind()
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
# construct search_filter like (uid=localpart)
|
|
||||||
query = "({prop}={value})".format(
|
|
||||||
prop=self.ldap_attributes['uid'],
|
|
||||||
value=localpart
|
|
||||||
)
|
|
||||||
if self.ldap_filter:
|
|
||||||
# combine with the AND expression
|
|
||||||
query = "(&{query}{filter})".format(
|
|
||||||
query=query,
|
|
||||||
filter=self.ldap_filter
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"LDAP search filter: %s",
|
|
||||||
query
|
|
||||||
)
|
|
||||||
conn.search(
|
|
||||||
search_base=self.ldap_base,
|
|
||||||
search_filter=query
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(conn.response) == 1:
|
|
||||||
# GOOD: found exactly one result
|
|
||||||
user_dn = conn.response[0]['dn']
|
|
||||||
logger.debug('LDAP search found dn: %s', user_dn)
|
|
||||||
|
|
||||||
# unbind and simple bind with user_dn to verify the password
|
|
||||||
# Note: do not use rebind(), for some reason it did not verify
|
|
||||||
# the password for me!
|
|
||||||
conn.unbind()
|
|
||||||
return self._ldap_simple_bind(server, localpart, password)
|
|
||||||
else:
|
|
||||||
# BAD: found 0 or > 1 results, abort!
|
|
||||||
if len(conn.response) == 0:
|
|
||||||
logger.info(
|
|
||||||
"LDAP search returned no results for '%s'",
|
|
||||||
localpart
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"LDAP search returned too many (%s) results for '%s'",
|
|
||||||
len(conn.response), localpart
|
|
||||||
)
|
|
||||||
conn.unbind()
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
except ldap3.core.exceptions.LDAPException as e:
|
|
||||||
logger.warn("Error during LDAP authentication: %s", e)
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
|
|
||||||
def _require_keys(config, required):
|
|
||||||
missing = [key for key in required if key not in config]
|
|
||||||
if missing:
|
|
||||||
raise ConfigError(
|
|
||||||
"LDAP enabled but missing required config values: {}".format(
|
|
||||||
", ".join(missing)
|
|
||||||
)
|
|
||||||
)
|
|
@ -121,15 +121,9 @@ class RetryDestinationLimiter(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
def err(failure):
|
|
||||||
logger.exception(
|
|
||||||
"Failed to store set_destination_retry_timings",
|
|
||||||
failure.value
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_err_code = False
|
valid_err_code = False
|
||||||
if exc_type is not None and issubclass(exc_type, CodeMessageException):
|
if exc_type is not None and issubclass(exc_type, CodeMessageException):
|
||||||
valid_err_code = 0 <= exc_val.code < 500
|
valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500
|
||||||
|
|
||||||
if exc_type is None or valid_err_code:
|
if exc_type is None or valid_err_code:
|
||||||
# We connected successfully.
|
# We connected successfully.
|
||||||
@ -151,6 +145,15 @@ class RetryDestinationLimiter(object):
|
|||||||
|
|
||||||
retry_last_ts = int(self.clock.time_msec())
|
retry_last_ts = int(self.clock.time_msec())
|
||||||
|
|
||||||
self.store.set_destination_retry_timings(
|
@defer.inlineCallbacks
|
||||||
|
def store_retry_timings():
|
||||||
|
try:
|
||||||
|
yield self.store.set_destination_retry_timings(
|
||||||
self.destination, retry_last_ts, self.retry_interval
|
self.destination, retry_last_ts, self.retry_interval
|
||||||
).addErrback(err)
|
)
|
||||||
|
except:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to store set_destination_retry_timings",
|
||||||
|
)
|
||||||
|
|
||||||
|
store_retry_timings()
|
||||||
|
@ -12,17 +12,22 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from tests import unittest
|
|
||||||
|
import pymacaroons
|
||||||
|
from mock import Mock
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock
|
import synapse.handlers.auth
|
||||||
|
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
from tests import unittest
|
||||||
from tests.utils import setup_test_homeserver, mock_getRawHeaders
|
from tests.utils import setup_test_homeserver, mock_getRawHeaders
|
||||||
|
|
||||||
import pymacaroons
|
|
||||||
|
class TestHandlers(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
|
||||||
|
|
||||||
|
|
||||||
class AuthTestCase(unittest.TestCase):
|
class AuthTestCase(unittest.TestCase):
|
||||||
@ -34,14 +39,17 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.hs = yield setup_test_homeserver(handlers=None)
|
self.hs = yield setup_test_homeserver(handlers=None)
|
||||||
self.hs.get_datastore = Mock(return_value=self.store)
|
self.hs.get_datastore = Mock(return_value=self.store)
|
||||||
|
self.hs.handlers = TestHandlers(self.hs)
|
||||||
self.auth = Auth(self.hs)
|
self.auth = Auth(self.hs)
|
||||||
|
|
||||||
self.test_user = "@foo:bar"
|
self.test_user = "@foo:bar"
|
||||||
self.test_token = "_test_token_"
|
self.test_token = "_test_token_"
|
||||||
|
|
||||||
|
# this is overridden for the appservice tests
|
||||||
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_user_valid_token(self):
|
def test_get_user_by_req_user_valid_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
|
||||||
user_info = {
|
user_info = {
|
||||||
"name": self.test_user,
|
"name": self.test_user,
|
||||||
"token_id": "ditto",
|
"token_id": "ditto",
|
||||||
@ -56,7 +64,6 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_user_bad_token(self):
|
def test_get_user_by_req_user_bad_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
@ -66,7 +73,6 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
|
||||||
def test_get_user_by_req_user_missing_token(self):
|
def test_get_user_by_req_user_missing_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
|
||||||
user_info = {
|
user_info = {
|
||||||
"name": self.test_user,
|
"name": self.test_user,
|
||||||
"token_id": "ditto",
|
"token_id": "ditto",
|
||||||
@ -158,7 +164,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
self.assertEqual(UserID.from_string(user_id), user)
|
||||||
|
|
||||||
@ -168,6 +174,10 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_guest_user_from_macaroon(self):
|
def test_get_guest_user_from_macaroon(self):
|
||||||
|
self.store.get_user_by_id = Mock(return_value={
|
||||||
|
"is_guest": True,
|
||||||
|
})
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
@ -179,11 +189,12 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("guest = true")
|
macaroon.add_first_party_caveat("guest = true")
|
||||||
serialized = macaroon.serialize()
|
serialized = macaroon.serialize()
|
||||||
|
|
||||||
user_info = yield self.auth.get_user_from_macaroon(serialized)
|
user_info = yield self.auth.get_user_by_access_token(serialized)
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
is_guest = user_info["is_guest"]
|
is_guest = user_info["is_guest"]
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
self.assertEqual(UserID.from_string(user_id), user)
|
||||||
self.assertTrue(is_guest)
|
self.assertTrue(is_guest)
|
||||||
|
self.store.get_user_by_id.assert_called_with(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_from_macaroon_user_db_mismatch(self):
|
def test_get_user_from_macaroon_user_db_mismatch(self):
|
||||||
@ -200,7 +211,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
self.assertIn("User mismatch", cm.exception.msg)
|
self.assertIn("User mismatch", cm.exception.msg)
|
||||||
|
|
||||||
@ -220,7 +231,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
self.assertIn("No user caveat", cm.exception.msg)
|
self.assertIn("No user caveat", cm.exception.msg)
|
||||||
|
|
||||||
@ -242,7 +253,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||||
|
|
||||||
@ -265,7 +276,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("cunning > fox")
|
macaroon.add_first_party_caveat("cunning > fox")
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||||
|
|
||||||
@ -293,12 +304,12 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.hs.clock.now = 5000 # seconds
|
self.hs.clock.now = 5000 # seconds
|
||||||
self.hs.config.expire_access_token = True
|
self.hs.config.expire_access_token = True
|
||||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
# yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
# TODO(daniel): Turn on the check that we validate expiration, when we
|
# TODO(daniel): Turn on the check that we validate expiration, when we
|
||||||
# validate expiration (and remove the above line, which will start
|
# validate expiration (and remove the above line, which will start
|
||||||
# throwing).
|
# throwing).
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||||
|
|
||||||
@ -327,6 +338,58 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.hs.clock.now = 5000 # seconds
|
self.hs.clock.now = 5000 # seconds
|
||||||
self.hs.config.expire_access_token = True
|
self.hs.config.expire_access_token = True
|
||||||
|
|
||||||
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
self.assertEqual(UserID.from_string(user_id), user)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cannot_use_regular_token_as_guest(self):
|
||||||
|
USER_ID = "@percy:matrix.org"
|
||||||
|
self.store.add_access_token_to_user = Mock()
|
||||||
|
|
||||||
|
token = yield self.hs.handlers.auth_handler.issue_access_token(
|
||||||
|
USER_ID, "DEVICE"
|
||||||
|
)
|
||||||
|
self.store.add_access_token_to_user.assert_called_with(
|
||||||
|
USER_ID, token, "DEVICE"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_user(tok):
|
||||||
|
if token != tok:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"name": USER_ID,
|
||||||
|
"is_guest": False,
|
||||||
|
"token_id": 1234,
|
||||||
|
"device_id": "DEVICE",
|
||||||
|
}
|
||||||
|
self.store.get_user_by_access_token = get_user
|
||||||
|
self.store.get_user_by_id = Mock(return_value={
|
||||||
|
"is_guest": False,
|
||||||
|
})
|
||||||
|
|
||||||
|
# check the token works
|
||||||
|
request = Mock(args={})
|
||||||
|
request.args["access_token"] = [token]
|
||||||
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
||||||
|
self.assertFalse(requester.is_guest)
|
||||||
|
|
||||||
|
# add an is_guest caveat
|
||||||
|
mac = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
mac.add_first_party_caveat("guest = true")
|
||||||
|
guest_tok = mac.serialize()
|
||||||
|
|
||||||
|
# the token should *not* work now
|
||||||
|
request = Mock(args={})
|
||||||
|
request.args["access_token"] = [guest_tok]
|
||||||
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
|
||||||
|
with self.assertRaises(AuthError) as cm:
|
||||||
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
|
self.assertEqual(401, cm.exception.code)
|
||||||
|
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
|
||||||
|
|
||||||
|
self.store.get_user_by_id.assert_called_with(USER_ID)
|
||||||
|
@ -17,7 +17,11 @@
|
|||||||
from .. import unittest
|
from .. import unittest
|
||||||
|
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event, serialize_event
|
||||||
|
|
||||||
|
|
||||||
|
def MockEvent(**kwargs):
|
||||||
|
return FrozenEvent(kwargs)
|
||||||
|
|
||||||
|
|
||||||
class PruneEventTestCase(unittest.TestCase):
|
class PruneEventTestCase(unittest.TestCase):
|
||||||
@ -114,3 +118,167 @@ class PruneEventTestCase(unittest.TestCase):
|
|||||||
'unsigned': {},
|
'unsigned': {},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SerializeEventTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def serialize(self, ev, fields):
|
||||||
|
return serialize_event(ev, 1479807801915, only_event_fields=fields)
|
||||||
|
|
||||||
|
def test_event_fields_works_with_keys(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
sender="@alice:localhost",
|
||||||
|
room_id="!foo:bar"
|
||||||
|
),
|
||||||
|
["room_id"]
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"room_id": "!foo:bar",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_works_with_nested_keys(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
sender="@alice:localhost",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"body": "A message",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
["content.body"]
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"body": "A message",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_works_with_dot_keys(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
sender="@alice:localhost",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"key.with.dots": {},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
["content.key\.with\.dots"]
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"key.with.dots": {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_works_with_nested_dot_keys(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
sender="@alice:localhost",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"not_me": 1,
|
||||||
|
"nested.dot.key": {
|
||||||
|
"leaf.key": 42,
|
||||||
|
"not_me_either": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
["content.nested\.dot\.key.leaf\.key"]
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"nested.dot.key": {
|
||||||
|
"leaf.key": 42,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_nops_with_unknown_keys(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
sender="@alice:localhost",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
["content.foo", "content.notexists"]
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"foo": "bar",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_nops_with_non_dict_keys(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
sender="@alice:localhost",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"foo": ["I", "am", "an", "array"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
["content.foo.am"]
|
||||||
|
),
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_nops_with_array_keys(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
sender="@alice:localhost",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"foo": ["I", "am", "an", "array"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
["content.foo.1"]
|
||||||
|
),
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_all_fields_if_empty(self):
|
||||||
|
self.assertEquals(
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
[]
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"room_id": "!foo:bar",
|
||||||
|
"content": {
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
|
"unsigned": {}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_event_fields_fail_if_fields_not_str(self):
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
self.serialize(
|
||||||
|
MockEvent(
|
||||||
|
room_id="!foo:bar",
|
||||||
|
content={
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
["room_id", 4]
|
||||||
|
)
|
||||||
|
@ -61,14 +61,14 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
def verify_type(caveat):
|
def verify_type(caveat):
|
||||||
return caveat == "type = access"
|
return caveat == "type = access"
|
||||||
|
|
||||||
def verify_expiry(caveat):
|
def verify_nonce(caveat):
|
||||||
return caveat == "time < 8600000"
|
return caveat.startswith("nonce =")
|
||||||
|
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
v.satisfy_general(verify_gen)
|
v.satisfy_general(verify_gen)
|
||||||
v.satisfy_general(verify_user)
|
v.satisfy_general(verify_user)
|
||||||
v.satisfy_general(verify_type)
|
v.satisfy_general(verify_type)
|
||||||
v.satisfy_general(verify_expiry)
|
v.satisfy_general(verify_nonce)
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
|
@ -53,13 +53,12 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
duration_ms = 200
|
|
||||||
local_part = "someone"
|
local_part = "someone"
|
||||||
display_name = "someone"
|
display_name = "someone"
|
||||||
user_id = "@someone:test"
|
user_id = "@someone:test"
|
||||||
requester = create_requester("@as:test")
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
requester, local_part, display_name, duration_ms)
|
requester, local_part, display_name)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
@ -71,12 +70,11 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
user_id=frank.to_string(),
|
user_id=frank.to_string(),
|
||||||
token="jkv;g498752-43gj['eamb!-5",
|
token="jkv;g498752-43gj['eamb!-5",
|
||||||
password_hash=None)
|
password_hash=None)
|
||||||
duration_ms = 200
|
|
||||||
local_part = "frank"
|
local_part = "frank"
|
||||||
display_name = "Frank"
|
display_name = "Frank"
|
||||||
user_id = "@frank:test"
|
user_id = "@frank:test"
|
||||||
requester = create_requester("@as:test")
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
requester, local_part, display_name, duration_ms)
|
requester, local_part, display_name)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
@ -103,7 +103,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||||||
room_id = yield self.create_room()
|
room_id = yield self.create_room()
|
||||||
event_id = yield self.send_text_message(room_id, "Hello, World")
|
event_id = yield self.send_text_message(room_id, "Hello, World")
|
||||||
get = self.get(receipts="-1")
|
get = self.get(receipts="-1")
|
||||||
yield self.hs.get_handlers().receipts_handler.received_client_receipt(
|
yield self.hs.get_receipts_handler().received_client_receipt(
|
||||||
room_id, "m.read", self.user_id, event_id
|
room_id, "m.read", self.user_id, event_id
|
||||||
)
|
)
|
||||||
code, body = yield get
|
code, body = yield get
|
||||||
|
@ -67,8 +67,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
self.registration_handler.appservice_register = Mock(
|
self.registration_handler.appservice_register = Mock(
|
||||||
return_value=user_id
|
return_value=user_id
|
||||||
)
|
)
|
||||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||||
return_value=(token, "kermits_refresh_token")
|
return_value=token
|
||||||
)
|
)
|
||||||
|
|
||||||
(code, result) = yield self.servlet.on_POST(self.request)
|
(code, result) = yield self.servlet.on_POST(self.request)
|
||||||
@ -76,11 +76,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
"refresh_token": "kermits_refresh_token",
|
|
||||||
"home_server": self.hs.hostname
|
"home_server": self.hs.hostname
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, result)
|
self.assertDictContainsSubset(det_data, result)
|
||||||
self.assertIn("refresh_token", result)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_POST_appservice_registration_invalid(self):
|
def test_POST_appservice_registration_invalid(self):
|
||||||
@ -126,8 +124,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
}, None)
|
}, None)
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||||
return_value=(token, "kermits_refresh_token")
|
return_value=token
|
||||||
)
|
)
|
||||||
self.device_handler.check_device_registered = \
|
self.device_handler.check_device_registered = \
|
||||||
Mock(return_value=device_id)
|
Mock(return_value=device_id)
|
||||||
@ -137,12 +135,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
"refresh_token": "kermits_refresh_token",
|
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, result)
|
self.assertDictContainsSubset(det_data, result)
|
||||||
self.assertIn("refresh_token", result)
|
|
||||||
self.auth_handler.get_login_tuple_for_user_id(
|
self.auth_handler.get_login_tuple_for_user_id(
|
||||||
user_id, device_id=device_id, initial_device_display_name=None)
|
user_id, device_id=device_id, initial_device_display_name=None)
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||||||
event_cache_size=1,
|
event_cache_size=1,
|
||||||
password_providers=[],
|
password_providers=[],
|
||||||
)
|
)
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
||||||
|
|
||||||
self.as_token = "token1"
|
self.as_token = "token1"
|
||||||
self.as_url = "some_url"
|
self.as_url = "some_url"
|
||||||
@ -112,7 +112,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
event_cache_size=1,
|
event_cache_size=1,
|
||||||
password_providers=[],
|
password_providers=[],
|
||||||
)
|
)
|
||||||
hs = yield setup_test_homeserver(config=config)
|
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
||||||
self.db_pool = hs.get_db_pool()
|
self.db_pool = hs.get_db_pool()
|
||||||
|
|
||||||
self.as_list = [
|
self.as_list = [
|
||||||
@ -443,7 +443,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
app_service_config_files=[f1, f2], event_cache_size=1,
|
app_service_config_files=[f1, f2], event_cache_size=1,
|
||||||
password_providers=[]
|
password_providers=[]
|
||||||
)
|
)
|
||||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
hs = yield setup_test_homeserver(
|
||||||
|
config=config,
|
||||||
|
datastore=Mock(),
|
||||||
|
federation_sender=Mock()
|
||||||
|
)
|
||||||
|
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
|
|
||||||
@ -456,7 +460,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
app_service_config_files=[f1, f2], event_cache_size=1,
|
app_service_config_files=[f1, f2], event_cache_size=1,
|
||||||
password_providers=[]
|
password_providers=[]
|
||||||
)
|
)
|
||||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
hs = yield setup_test_homeserver(
|
||||||
|
config=config,
|
||||||
|
datastore=Mock(),
|
||||||
|
federation_sender=Mock()
|
||||||
|
)
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
@ -475,7 +483,11 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
app_service_config_files=[f1, f2], event_cache_size=1,
|
app_service_config_files=[f1, f2], event_cache_size=1,
|
||||||
password_providers=[]
|
password_providers=[]
|
||||||
)
|
)
|
||||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
hs = yield setup_test_homeserver(
|
||||||
|
config=config,
|
||||||
|
datastore=Mock(),
|
||||||
|
federation_sender=Mock()
|
||||||
|
)
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user