mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-27 19:06:59 -05:00
Merge branch 'release-v0.25.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
552f123bea
58
CHANGES.rst
58
CHANGES.rst
@ -1,3 +1,61 @@
|
|||||||
|
Changes in synapse v0.25.0 (2017-11-15)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix port script (PR #2673)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.25.0-rc1 (2017-11-14)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add is_public to groups table to allow for private groups (PR #2582)
|
||||||
|
* Add a route for determining who you are (PR #2668) Thanks to @turt2live!
|
||||||
|
* Add more features to the password providers (PR #2608, #2610, #2620, #2622,
|
||||||
|
#2623, #2624, #2626, #2628, #2629)
|
||||||
|
* Add a hook for custom rest endpoints (PR #2627)
|
||||||
|
* Add API to update group room visibility (PR #2651)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Ignore <noscript> tags when generating URL preview descriptions (PR #2576)
|
||||||
|
Thanks to @maximevaillancourt!
|
||||||
|
* Register some /unstable endpoints in /r0 as well (PR #2579) Thanks to
|
||||||
|
@krombel!
|
||||||
|
* Support /keys/upload on /r0 as well as /unstable (PR #2585)
|
||||||
|
* Front-end proxy: pass through auth header (PR #2586)
|
||||||
|
* Allow ASes to deactivate their own users (PR #2589)
|
||||||
|
* Remove refresh tokens (PR #2613)
|
||||||
|
* Automatically set default displayname on register (PR #2617)
|
||||||
|
* Log login requests (PR #2618)
|
||||||
|
* Always return `is_public` in the `/groups/:group_id/rooms` API (PR #2630)
|
||||||
|
* Avoid no-op media deletes (PR #2637) Thanks to @spantaleev!
|
||||||
|
* Fix various embarrassing typos around user_directory and add some doc. (PR
|
||||||
|
#2643)
|
||||||
|
* Return whether a user is an admin within a group (PR #2647)
|
||||||
|
* Namespace visibility options for groups (PR #2657)
|
||||||
|
* Downcase UserIDs on registration (PR #2662)
|
||||||
|
* Cache failures when fetching URL previews (PR #2669)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix port script (PR #2577)
|
||||||
|
* Fix error when running synapse with no logfile (PR #2581)
|
||||||
|
* Fix UI auth when deleting devices (PR #2591)
|
||||||
|
* Fix typo when checking if user is invited to group (PR #2599)
|
||||||
|
* Fix the port script to drop NUL values in all tables (PR #2611)
|
||||||
|
* Fix appservices being backlogged and not receiving new events due to a bug in
|
||||||
|
notify_interested_services (PR #2631) Thanks to @xyzz!
|
||||||
|
* Fix updating rooms avatar/display name when modified by admin (PR #2636)
|
||||||
|
Thanks to @farialima!
|
||||||
|
* Fix bug in state group storage (PR #2649)
|
||||||
|
* Fix 500 on invalid utf-8 in request (PR #2663)
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.24.1 (2017-10-24)
|
Changes in synapse v0.24.1 (2017-10-24)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
@ -823,7 +823,9 @@ spidering 'internal' URLs on your network. At the very least we recommend that
|
|||||||
your loopback and RFC1918 IP addresses are blacklisted.
|
your loopback and RFC1918 IP addresses are blacklisted.
|
||||||
|
|
||||||
This also requires the optional lxml and netaddr python dependencies to be
|
This also requires the optional lxml and netaddr python dependencies to be
|
||||||
installed.
|
installed. This in turn requires the libxml2 library to be available - on
|
||||||
|
Debian/Ubuntu this means ``apt-get install libxml2-dev``, or equivalent for
|
||||||
|
your OS.
|
||||||
|
|
||||||
|
|
||||||
Password reset
|
Password reset
|
||||||
|
@ -1,26 +1,13 @@
|
|||||||
Basically, PEP8
|
- Everything should comply with PEP8. Code should pass
|
||||||
|
``pep8 --max-line-length=100`` without any warnings.
|
||||||
|
|
||||||
- NEVER tabs. 4 spaces to indent.
|
- **Indenting**:
|
||||||
- Max line width: 79 chars (with flexibility to overflow by a "few chars" if
|
|
||||||
the overflowing content is not semantically significant and avoids an
|
- NEVER tabs. 4 spaces to indent.
|
||||||
explosion of vertical whitespace).
|
|
||||||
- Use camel case for class and type names
|
- follow PEP8; either hanging indent or multiline-visual indent depending
|
||||||
- Use underscores for functions and variables.
|
on the size and shape of the arguments and what makes more sense to the
|
||||||
- Use double quotes.
|
author. In other words, both this::
|
||||||
- Use parentheses instead of '\\' for line continuation where ever possible
|
|
||||||
(which is pretty much everywhere)
|
|
||||||
- There should be max a single new line between:
|
|
||||||
- statements
|
|
||||||
- functions in a class
|
|
||||||
- There should be two new lines between:
|
|
||||||
- definitions in a module (e.g., between different classes)
|
|
||||||
- There should be spaces where spaces should be and not where there shouldn't be:
|
|
||||||
- a single space after a comma
|
|
||||||
- a single space before and after for '=' when used as assignment
|
|
||||||
- no spaces before and after for '=' for default values and keyword arguments.
|
|
||||||
- Indenting must follow PEP8; either hanging indent or multiline-visual indent
|
|
||||||
depending on the size and shape of the arguments and what makes more sense to
|
|
||||||
the author. In other words, both this::
|
|
||||||
|
|
||||||
print("I am a fish %s" % "moo")
|
print("I am a fish %s" % "moo")
|
||||||
|
|
||||||
@ -33,20 +20,100 @@ Basically, PEP8
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
"I am a fish %s" %
|
"I am a fish %s" %
|
||||||
"moo"
|
"moo",
|
||||||
)
|
)
|
||||||
|
|
||||||
...are valid, although given each one takes up 2x more vertical space than
|
...are valid, although given each one takes up 2x more vertical space than
|
||||||
the previous, it's up to the author's discretion as to which layout makes most
|
the previous, it's up to the author's discretion as to which layout makes
|
||||||
sense for their function invocation. (e.g. if they want to add comments
|
most sense for their function invocation. (e.g. if they want to add
|
||||||
per-argument, or put expressions in the arguments, or group related arguments
|
comments per-argument, or put expressions in the arguments, or group
|
||||||
together, or want to deliberately extend or preserve vertical/horizontal
|
related arguments together, or want to deliberately extend or preserve
|
||||||
space)
|
vertical/horizontal space)
|
||||||
|
|
||||||
Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
- **Line length**:
|
||||||
This is so that we can generate documentation with
|
|
||||||
`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
|
||||||
`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
|
||||||
in the sphinx documentation.
|
|
||||||
|
|
||||||
Code should pass pep8 --max-line-length=100 without any warnings.
|
Max line length is 79 chars (with flexibility to overflow by a "few chars" if
|
||||||
|
the overflowing content is not semantically significant and avoids an
|
||||||
|
explosion of vertical whitespace).
|
||||||
|
|
||||||
|
Use parentheses instead of ``\`` for line continuation where ever possible
|
||||||
|
(which is pretty much everywhere).
|
||||||
|
|
||||||
|
- **Naming**:
|
||||||
|
|
||||||
|
- Use camel case for class and type names
|
||||||
|
- Use underscores for functions and variables.
|
||||||
|
|
||||||
|
- Use double quotes ``"foo"`` rather than single quotes ``'foo'``.
|
||||||
|
|
||||||
|
- **Blank lines**:
|
||||||
|
|
||||||
|
- There should be max a single new line between:
|
||||||
|
|
||||||
|
- statements
|
||||||
|
- functions in a class
|
||||||
|
|
||||||
|
- There should be two new lines between:
|
||||||
|
|
||||||
|
- definitions in a module (e.g., between different classes)
|
||||||
|
|
||||||
|
- **Whitespace**:
|
||||||
|
|
||||||
|
There should be spaces where spaces should be and not where there shouldn't
|
||||||
|
be:
|
||||||
|
|
||||||
|
- a single space after a comma
|
||||||
|
- a single space before and after for '=' when used as assignment
|
||||||
|
- no spaces before and after for '=' for default values and keyword arguments.
|
||||||
|
|
||||||
|
- **Comments**: should follow the `google code style
|
||||||
|
<http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
||||||
|
This is so that we can generate documentation with `sphinx
|
||||||
|
<http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
||||||
|
`examples
|
||||||
|
<http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
||||||
|
in the sphinx documentation.
|
||||||
|
|
||||||
|
- **Imports**:
|
||||||
|
|
||||||
|
- Prefer to import classes and functions than packages or modules.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
from synapse.types import UserID
|
||||||
|
...
|
||||||
|
user_id = UserID(local, server)
|
||||||
|
|
||||||
|
is preferred over::
|
||||||
|
|
||||||
|
from synapse import types
|
||||||
|
...
|
||||||
|
user_id = types.UserID(local, server)
|
||||||
|
|
||||||
|
(or any other variant).
|
||||||
|
|
||||||
|
This goes against the advice in the Google style guide, but it means that
|
||||||
|
errors in the name are caught early (at import time).
|
||||||
|
|
||||||
|
- Multiple imports from the same package can be combined onto one line::
|
||||||
|
|
||||||
|
from synapse.types import GroupID, RoomID, UserID
|
||||||
|
|
||||||
|
An effort should be made to keep the individual imports in alphabetical
|
||||||
|
order.
|
||||||
|
|
||||||
|
If the list becomes long, wrap it with parentheses and split it over
|
||||||
|
multiple lines.
|
||||||
|
|
||||||
|
- As per `PEP-8 <https://www.python.org/dev/peps/pep-0008/#imports>`_,
|
||||||
|
imports should be grouped in the following order, with a blank line between
|
||||||
|
each group:
|
||||||
|
|
||||||
|
1. standard library imports
|
||||||
|
2. related third party imports
|
||||||
|
3. local application/library specific imports
|
||||||
|
|
||||||
|
- Imports within each group should be sorted alphabetically by module name.
|
||||||
|
|
||||||
|
- Avoid wildcard imports (``from synapse.types import *``) and relative
|
||||||
|
imports (``from .types import UserID``).
|
||||||
|
99
docs/password_auth_providers.rst
Normal file
99
docs/password_auth_providers.rst
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
Password auth provider modules
|
||||||
|
==============================
|
||||||
|
|
||||||
|
Password auth providers offer a way for server administrators to integrate
|
||||||
|
their Synapse installation with an existing authentication system.
|
||||||
|
|
||||||
|
A password auth provider is a Python class which is dynamically loaded into
|
||||||
|
Synapse, and provides a number of methods by which it can integrate with the
|
||||||
|
authentication system.
|
||||||
|
|
||||||
|
This document serves as a reference for those looking to implement their own
|
||||||
|
password auth providers.
|
||||||
|
|
||||||
|
Required methods
|
||||||
|
----------------
|
||||||
|
|
||||||
|
Password auth provider classes must provide the following methods:
|
||||||
|
|
||||||
|
*class* ``SomeProvider.parse_config``\(*config*)
|
||||||
|
|
||||||
|
This method is passed the ``config`` object for this module from the
|
||||||
|
homeserver configuration file.
|
||||||
|
|
||||||
|
It should perform any appropriate sanity checks on the provided
|
||||||
|
configuration, and return an object which is then passed into ``__init__``.
|
||||||
|
|
||||||
|
*class* ``SomeProvider``\(*config*, *account_handler*)
|
||||||
|
|
||||||
|
The constructor is passed the config object returned by ``parse_config``,
|
||||||
|
and a ``synapse.module_api.ModuleApi`` object which allows the
|
||||||
|
password provider to check if accounts exist and/or create new ones.
|
||||||
|
|
||||||
|
Optional methods
|
||||||
|
----------------
|
||||||
|
|
||||||
|
Password auth provider classes may optionally provide the following methods.
|
||||||
|
|
||||||
|
*class* ``SomeProvider.get_db_schema_files``\()
|
||||||
|
|
||||||
|
This method, if implemented, should return an Iterable of ``(name,
|
||||||
|
stream)`` pairs of database schema files. Each file is applied in turn at
|
||||||
|
initialisation, and a record is then made in the database so that it is
|
||||||
|
not re-applied on the next start.
|
||||||
|
|
||||||
|
``someprovider.get_supported_login_types``\()
|
||||||
|
|
||||||
|
This method, if implemented, should return a ``dict`` mapping from a login
|
||||||
|
type identifier (such as ``m.login.password``) to an iterable giving the
|
||||||
|
fields which must be provided by the user in the submission to the
|
||||||
|
``/login`` api. These fields are passed in the ``login_dict`` dictionary
|
||||||
|
to ``check_auth``.
|
||||||
|
|
||||||
|
For example, if a password auth provider wants to implement a custom login
|
||||||
|
type of ``com.example.custom_login``, where the client is expected to pass
|
||||||
|
the fields ``secret1`` and ``secret2``, the provider should implement this
|
||||||
|
method and return the following dict::
|
||||||
|
|
||||||
|
{"com.example.custom_login": ("secret1", "secret2")}
|
||||||
|
|
||||||
|
``someprovider.check_auth``\(*username*, *login_type*, *login_dict*)
|
||||||
|
|
||||||
|
This method is the one that does the real work. If implemented, it will be
|
||||||
|
called for each login attempt where the login type matches one of the keys
|
||||||
|
returned by ``get_supported_login_types``.
|
||||||
|
|
||||||
|
It is passed the (possibly UNqualified) ``user`` provided by the client,
|
||||||
|
the login type, and a dictionary of login secrets passed by the client.
|
||||||
|
|
||||||
|
The method should return a Twisted ``Deferred`` object, which resolves to
|
||||||
|
the canonical ``@localpart:domain`` user id if authentication is successful,
|
||||||
|
and ``None`` if not.
|
||||||
|
|
||||||
|
Alternatively, the ``Deferred`` can resolve to a ``(str, func)`` tuple, in
|
||||||
|
which case the second field is a callback which will be called with the
|
||||||
|
result from the ``/login`` call (including ``access_token``, ``device_id``,
|
||||||
|
etc.)
|
||||||
|
|
||||||
|
``someprovider.check_password``\(*user_id*, *password*)
|
||||||
|
|
||||||
|
This method provides a simpler interface than ``get_supported_login_types``
|
||||||
|
and ``check_auth`` for password auth providers that just want to provide a
|
||||||
|
mechanism for validating ``m.login.password`` logins.
|
||||||
|
|
||||||
|
Iif implemented, it will be called to check logins with an
|
||||||
|
``m.login.password`` login type. It is passed a qualified
|
||||||
|
``@localpart:domain`` user id, and the password provided by the user.
|
||||||
|
|
||||||
|
The method should return a Twisted ``Deferred`` object, which resolves to
|
||||||
|
``True`` if authentication is successful, and ``False`` if not.
|
||||||
|
|
||||||
|
``someprovider.on_logged_out``\(*user_id*, *device_id*, *access_token*)
|
||||||
|
|
||||||
|
This method, if implemented, is called when a user logs out. It is passed
|
||||||
|
the qualified user ID, the ID of the deactivated device (if any: access
|
||||||
|
tokens are occasionally created without an associated device ID), and the
|
||||||
|
(now deactivated) access token.
|
||||||
|
|
||||||
|
It may return a Twisted ``Deferred`` object; the logout request will wait
|
||||||
|
for the deferred to complete but the result is ignored.
|
@ -56,6 +56,7 @@ As a first cut, let's do #2 and have the receiver hit the API to calculate its o
|
|||||||
API
|
API
|
||||||
---
|
---
|
||||||
|
|
||||||
|
```
|
||||||
GET /_matrix/media/r0/preview_url?url=http://wherever.com
|
GET /_matrix/media/r0/preview_url?url=http://wherever.com
|
||||||
200 OK
|
200 OK
|
||||||
{
|
{
|
||||||
@ -66,6 +67,7 @@ GET /_matrix/media/r0/preview_url?url=http://wherever.com
|
|||||||
"og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
|
"og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
|
||||||
"og:site_name" : "Twitter"
|
"og:site_name" : "Twitter"
|
||||||
}
|
}
|
||||||
|
```
|
||||||
|
|
||||||
* Downloads the URL
|
* Downloads the URL
|
||||||
* If HTML, just stores it in RAM and parses it for OG meta tags
|
* If HTML, just stores it in RAM and parses it for OG meta tags
|
17
docs/user_directory.md
Normal file
17
docs/user_directory.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
User Directory API Implementation
|
||||||
|
=================================
|
||||||
|
|
||||||
|
The user directory is currently maintained based on the 'visible' users
|
||||||
|
on this particular server - i.e. ones which your account shares a room with, or
|
||||||
|
who are present in a publicly viewable room present on the server.
|
||||||
|
|
||||||
|
The directory info is stored in various tables, which can (typically after
|
||||||
|
DB corruption) get stale or out of sync. If this happens, for now the
|
||||||
|
quickest solution to fix it is:
|
||||||
|
|
||||||
|
```
|
||||||
|
UPDATE user_directory_stream_pos SET stream_id = NULL;
|
||||||
|
```
|
||||||
|
|
||||||
|
and restart the synapse, which should then start a background task to
|
||||||
|
flush the current tables and regenerate the directory.
|
@ -42,6 +42,14 @@ BOOLEAN_COLUMNS = {
|
|||||||
"public_room_list_stream": ["visibility"],
|
"public_room_list_stream": ["visibility"],
|
||||||
"device_lists_outbound_pokes": ["sent"],
|
"device_lists_outbound_pokes": ["sent"],
|
||||||
"users_who_share_rooms": ["share_private"],
|
"users_who_share_rooms": ["share_private"],
|
||||||
|
"groups": ["is_public"],
|
||||||
|
"group_rooms": ["is_public"],
|
||||||
|
"group_users": ["is_public", "is_admin"],
|
||||||
|
"group_summary_rooms": ["is_public"],
|
||||||
|
"group_room_categories": ["is_public"],
|
||||||
|
"group_summary_users": ["is_public"],
|
||||||
|
"group_roles": ["is_public"],
|
||||||
|
"local_group_membership": ["is_publicised", "is_admin"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -112,6 +120,7 @@ class Store(object):
|
|||||||
|
|
||||||
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
||||||
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
||||||
|
_simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
|
||||||
|
|
||||||
def runInteraction(self, desc, func, *args, **kwargs):
|
def runInteraction(self, desc, func, *args, **kwargs):
|
||||||
def r(conn):
|
def r(conn):
|
||||||
@ -318,7 +327,7 @@ class Porter(object):
|
|||||||
backward_chunk = min(row[0] for row in brows) - 1
|
backward_chunk = min(row[0] for row in brows) - 1
|
||||||
|
|
||||||
rows = frows + brows
|
rows = frows + brows
|
||||||
self._convert_rows(table, headers, rows)
|
rows = self._convert_rows(table, headers, rows)
|
||||||
|
|
||||||
def insert(txn):
|
def insert(txn):
|
||||||
self.postgres_store.insert_many_txn(
|
self.postgres_store.insert_many_txn(
|
||||||
@ -554,17 +563,29 @@ class Porter(object):
|
|||||||
i for i, h in enumerate(headers) if h in bool_col_names
|
i for i, h in enumerate(headers) if h in bool_col_names
|
||||||
]
|
]
|
||||||
|
|
||||||
|
class BadValueException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
def conv(j, col):
|
def conv(j, col):
|
||||||
if j in bool_cols:
|
if j in bool_cols:
|
||||||
return bool(col)
|
return bool(col)
|
||||||
|
elif isinstance(col, basestring) and "\0" in col:
|
||||||
|
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
|
||||||
|
raise BadValueException();
|
||||||
return col
|
return col
|
||||||
|
|
||||||
|
outrows = []
|
||||||
for i, row in enumerate(rows):
|
for i, row in enumerate(rows):
|
||||||
rows[i] = tuple(
|
try:
|
||||||
|
outrows.append(tuple(
|
||||||
conv(j, col)
|
conv(j, col)
|
||||||
for j, col in enumerate(row)
|
for j, col in enumerate(row)
|
||||||
if j > 0
|
if j > 0
|
||||||
)
|
))
|
||||||
|
except BadValueException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return outrows
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _setup_sent_transactions(self):
|
def _setup_sent_transactions(self):
|
||||||
@ -592,7 +613,7 @@ class Porter(object):
|
|||||||
"select", r,
|
"select", r,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._convert_rows("sent_transactions", headers, rows)
|
rows = self._convert_rows("sent_transactions", headers, rows)
|
||||||
|
|
||||||
inserted_rows = len(rows)
|
inserted_rows = len(rows)
|
||||||
if inserted_rows:
|
if inserted_rows:
|
||||||
|
@ -16,4 +16,4 @@
|
|||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.24.1"
|
__version__ = "0.25.0"
|
||||||
|
@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy")
|
|||||||
|
|
||||||
|
|
||||||
class KeyUploadServlet(RestServlet):
|
class KeyUploadServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
|
||||||
releases=())
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
"""
|
"""
|
||||||
@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet):
|
|||||||
|
|
||||||
if body:
|
if body:
|
||||||
# They're actually trying to upload something, proxy to main synapse.
|
# They're actually trying to upload something, proxy to main synapse.
|
||||||
|
# Pass through the auth headers, if any, in case the access token
|
||||||
|
# is there.
|
||||||
|
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
|
||||||
|
headers = {
|
||||||
|
"Authorization": auth_headers,
|
||||||
|
}
|
||||||
result = yield self.http_client.post_json_get_json(
|
result = yield self.http_client.post_json_get_json(
|
||||||
self.main_uri + request.uri,
|
self.main_uri + request.uri,
|
||||||
body,
|
body,
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
@ -30,6 +30,8 @@ from synapse.config._base import ConfigError
|
|||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.crypto import context_factory
|
from synapse.crypto import context_factory
|
||||||
from synapse.federation.transport.server import TransportLayerServer
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
|
from synapse.http.additional_resource import AdditionalResource
|
||||||
from synapse.http.server import RootRedirect
|
from synapse.http.server import RootRedirect
|
||||||
from synapse.http.site import SynapseSite
|
from synapse.http.site import SynapseSite
|
||||||
from synapse.metrics import register_memory_metrics
|
from synapse.metrics import register_memory_metrics
|
||||||
@ -49,6 +51,7 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
|
|||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.manhole import manhole
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.module_loader import load_module
|
||||||
from synapse.util.rlimit import change_resource_limit
|
from synapse.util.rlimit import change_resource_limit
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
@ -107,9 +110,68 @@ class SynapseHomeServer(HomeServer):
|
|||||||
resources = {}
|
resources = {}
|
||||||
for res in listener_config["resources"]:
|
for res in listener_config["resources"]:
|
||||||
for name in res["names"]:
|
for name in res["names"]:
|
||||||
|
resources.update(self._configure_named_resource(
|
||||||
|
name, res.get("compress", False),
|
||||||
|
))
|
||||||
|
|
||||||
|
additional_resources = listener_config.get("additional_resources", {})
|
||||||
|
logger.debug("Configuring additional resources: %r",
|
||||||
|
additional_resources)
|
||||||
|
module_api = ModuleApi(self, self.get_auth_handler())
|
||||||
|
for path, resmodule in additional_resources.items():
|
||||||
|
handler_cls, config = load_module(resmodule)
|
||||||
|
handler = handler_cls(config, module_api)
|
||||||
|
resources[path] = AdditionalResource(self, handler.handle_request)
|
||||||
|
|
||||||
|
if WEB_CLIENT_PREFIX in resources:
|
||||||
|
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||||
|
else:
|
||||||
|
root_resource = Resource()
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, root_resource)
|
||||||
|
|
||||||
|
if tls:
|
||||||
|
for address in bind_addresses:
|
||||||
|
reactor.listenSSL(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.https.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
self.tls_server_context_factory,
|
||||||
|
interface=address
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for address in bind_addresses:
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=address
|
||||||
|
)
|
||||||
|
logger.info("Synapse now listening on port %d", port)
|
||||||
|
|
||||||
|
def _configure_named_resource(self, name, compress=False):
|
||||||
|
"""Build a resource map for a named resource
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): named resource: one of "client", "federation", etc
|
||||||
|
compress (bool): whether to enable gzip compression for this
|
||||||
|
resource
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Resource]: map from path to HTTP resource
|
||||||
|
"""
|
||||||
|
resources = {}
|
||||||
if name == "client":
|
if name == "client":
|
||||||
client_resource = ClientRestResource(self)
|
client_resource = ClientRestResource(self)
|
||||||
if res["compress"]:
|
if compress:
|
||||||
client_resource = gz_wrap(client_resource)
|
client_resource = gz_wrap(client_resource)
|
||||||
|
|
||||||
resources.update({
|
resources.update({
|
||||||
@ -154,39 +216,7 @@ class SynapseHomeServer(HomeServer):
|
|||||||
if name == "metrics" and self.get_config().enable_metrics:
|
if name == "metrics" and self.get_config().enable_metrics:
|
||||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
if WEB_CLIENT_PREFIX in resources:
|
return resources
|
||||||
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
|
||||||
else:
|
|
||||||
root_resource = Resource()
|
|
||||||
|
|
||||||
root_resource = create_resource_tree(resources, root_resource)
|
|
||||||
|
|
||||||
if tls:
|
|
||||||
for address in bind_addresses:
|
|
||||||
reactor.listenSSL(
|
|
||||||
port,
|
|
||||||
SynapseSite(
|
|
||||||
"synapse.access.https.%s" % (site_tag,),
|
|
||||||
site_tag,
|
|
||||||
listener_config,
|
|
||||||
root_resource,
|
|
||||||
),
|
|
||||||
self.tls_server_context_factory,
|
|
||||||
interface=address
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for address in bind_addresses:
|
|
||||||
reactor.listenTCP(
|
|
||||||
port,
|
|
||||||
SynapseSite(
|
|
||||||
"synapse.access.http.%s" % (site_tag,),
|
|
||||||
site_tag,
|
|
||||||
listener_config,
|
|
||||||
root_resource,
|
|
||||||
),
|
|
||||||
interface=address
|
|
||||||
)
|
|
||||||
logger.info("Synapse now listening on port %d", port)
|
|
||||||
|
|
||||||
def start_listening(self):
|
def start_listening(self):
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
|
@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind
|
|||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
|
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.types import ThirdPartyInstanceID
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
key = (service.id, protocol)
|
key = (service.id, protocol)
|
||||||
return self.protocol_meta_cache.get(key) or (
|
result = self.protocol_meta_cache.get(key)
|
||||||
self.protocol_meta_cache.set(key, _get())
|
if not result:
|
||||||
|
result = self.protocol_meta_cache.set(
|
||||||
|
key, preserve_fn(_get)()
|
||||||
)
|
)
|
||||||
|
return make_deferred_yieldable(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def push_bulk(self, service, events, txn_id=None):
|
def push_bulk(self, service, events, txn_id=None):
|
||||||
|
@ -41,7 +41,7 @@ class CasConfig(Config):
|
|||||||
#cas_config:
|
#cas_config:
|
||||||
# enabled: true
|
# enabled: true
|
||||||
# server_url: "https://cas-server.com"
|
# server_url: "https://cas-server.com"
|
||||||
# service_url: "https://homesever.domain.com:8448"
|
# service_url: "https://homeserver.domain.com:8448"
|
||||||
# #required_attributes:
|
# #required_attributes:
|
||||||
# # name: value
|
# # name: value
|
||||||
"""
|
"""
|
||||||
|
@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False):
|
|||||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||||
" - %(message)s"
|
" - %(message)s"
|
||||||
)
|
)
|
||||||
if log_config is None:
|
|
||||||
|
|
||||||
|
if log_config is None:
|
||||||
level = logging.INFO
|
level = logging.INFO
|
||||||
level_for_storage = logging.INFO
|
level_for_storage = logging.INFO
|
||||||
if config.verbosity:
|
if config.verbosity:
|
||||||
@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False):
|
|||||||
logger.info("Opened new log file due to SIGHUP")
|
logger.info("Opened new log file due to SIGHUP")
|
||||||
else:
|
else:
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
|
|
||||||
|
def sighup(signum, stack):
|
||||||
|
pass
|
||||||
|
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
handler.addFilter(LoggingContextFilter(request=""))
|
handler.addFilter(LoggingContextFilter(request=""))
|
||||||
|
@ -13,41 +13,40 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config
|
||||||
|
|
||||||
from synapse.util.module_loader import load_module
|
from synapse.util.module_loader import load_module
|
||||||
|
|
||||||
|
LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
|
||||||
|
|
||||||
|
|
||||||
class PasswordAuthProviderConfig(Config):
|
class PasswordAuthProviderConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.password_providers = []
|
self.password_providers = []
|
||||||
|
providers = []
|
||||||
provider_config = None
|
|
||||||
|
|
||||||
# We want to be backwards compatible with the old `ldap_config`
|
# We want to be backwards compatible with the old `ldap_config`
|
||||||
# param.
|
# param.
|
||||||
ldap_config = config.get("ldap_config", {})
|
ldap_config = config.get("ldap_config", {})
|
||||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
if ldap_config.get("enabled", False):
|
||||||
if self.ldap_enabled:
|
providers.append[{
|
||||||
from ldap_auth_provider import LdapAuthProvider
|
'module': LDAP_PROVIDER,
|
||||||
parsed_config = LdapAuthProvider.parse_config(ldap_config)
|
'config': ldap_config,
|
||||||
self.password_providers.append((LdapAuthProvider, parsed_config))
|
}]
|
||||||
|
|
||||||
providers = config.get("password_providers", [])
|
providers.extend(config.get("password_providers", []))
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
|
mod_name = provider['module']
|
||||||
|
|
||||||
# This is for backwards compat when the ldap auth provider resided
|
# This is for backwards compat when the ldap auth provider resided
|
||||||
# in this package.
|
# in this package.
|
||||||
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||||
from ldap_auth_provider import LdapAuthProvider
|
mod_name = LDAP_PROVIDER
|
||||||
provider_class = LdapAuthProvider
|
|
||||||
try:
|
(provider_class, provider_config) = load_module({
|
||||||
provider_config = provider_class.parse_config(provider["config"])
|
"module": mod_name,
|
||||||
except Exception as e:
|
"config": provider['config'],
|
||||||
raise ConfigError(
|
})
|
||||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
(provider_class, provider_config) = load_module(provider)
|
|
||||||
|
|
||||||
self.password_providers.append((provider_class, provider_config))
|
self.password_providers.append((provider_class, provider_config))
|
||||||
|
|
||||||
|
@ -247,6 +247,13 @@ class ServerConfig(Config):
|
|||||||
- names: [federation] # Federation APIs
|
- names: [federation] # Federation APIs
|
||||||
compress: false
|
compress: false
|
||||||
|
|
||||||
|
# optional list of additional endpoints which can be loaded via
|
||||||
|
# dynamic modules
|
||||||
|
# additional_resources:
|
||||||
|
# "/_matrix/my/custom/endpoint":
|
||||||
|
# module: my_module.CustomRequestHandler
|
||||||
|
# config: {}
|
||||||
|
|
||||||
# Unsecure HTTP listener,
|
# Unsecure HTTP listener,
|
||||||
# For when matrix traffic passes through loadbalancer that unwraps TLS.
|
# For when matrix traffic passes through loadbalancer that unwraps TLS.
|
||||||
- port: %(unsecure_port)s
|
- port: %(unsecure_port)s
|
||||||
|
@ -109,6 +109,12 @@ class TlsConfig(Config):
|
|||||||
# key. It may be necessary to publish the fingerprints of a new
|
# key. It may be necessary to publish the fingerprints of a new
|
||||||
# certificate and wait until the "valid_until_ts" of the previous key
|
# certificate and wait until the "valid_until_ts" of the previous key
|
||||||
# responses have passed before deploying it.
|
# responses have passed before deploying it.
|
||||||
|
#
|
||||||
|
# You can calculate a fingerprint from a given TLS listener via:
|
||||||
|
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
|
||||||
|
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
|
||||||
|
# or by checking matrix.org/federationtester/api/report?server_name=$host
|
||||||
|
#
|
||||||
tls_fingerprints: []
|
tls_fingerprints: []
|
||||||
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
||||||
""" % locals()
|
""" % locals()
|
||||||
|
@ -18,6 +18,7 @@ from .federation_base import FederationBase
|
|||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.util import async
|
from synapse.util import async
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
@ -253,12 +254,13 @@ class FederationServer(FederationBase):
|
|||||||
result = self._state_resp_cache.get((room_id, event_id))
|
result = self._state_resp_cache.get((room_id, event_id))
|
||||||
if not result:
|
if not result:
|
||||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||||
resp = yield self._state_resp_cache.set(
|
d = self._state_resp_cache.set(
|
||||||
(room_id, event_id),
|
(room_id, event_id),
|
||||||
self._on_context_state_request_compute(room_id, event_id)
|
preserve_fn(self._on_context_state_request_compute)(room_id, event_id)
|
||||||
)
|
)
|
||||||
|
resp = yield make_deferred_yieldable(d)
|
||||||
else:
|
else:
|
||||||
resp = yield result
|
resp = yield make_deferred_yieldable(result)
|
||||||
|
|
||||||
defer.returnValue((200, resp))
|
defer.returnValue((200, resp))
|
||||||
|
|
||||||
|
@ -545,6 +545,20 @@ class TransportLayerClient(object):
|
|||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
|
||||||
|
config_key, content):
|
||||||
|
"""Update room in group
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,)
|
||||||
|
|
||||||
|
return self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
args={"requester_user_id": requester_user_id},
|
||||||
|
data=content,
|
||||||
|
ignore_backoff=True,
|
||||||
|
)
|
||||||
|
|
||||||
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
||||||
"""Remove a room from a group
|
"""Remove a room from a group
|
||||||
"""
|
"""
|
||||||
|
@ -676,7 +676,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
|
|||||||
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
||||||
"""Add/remove room from group
|
"""Add/remove room from group
|
||||||
"""
|
"""
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
|
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, origin, content, query, group_id, room_id):
|
def on_POST(self, origin, content, query, group_id, room_id):
|
||||||
@ -703,6 +703,27 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
|||||||
defer.returnValue((200, new_content))
|
defer.returnValue((200, new_content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
|
||||||
|
"""Update room config in group
|
||||||
|
"""
|
||||||
|
PATH = (
|
||||||
|
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
|
||||||
|
"/config/(?P<config_key>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query, group_id, room_id, config_key):
|
||||||
|
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||||
|
if get_domain_from_id(requester_user_id) != origin:
|
||||||
|
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||||
|
|
||||||
|
result = yield self.groups_handler.update_room_in_group(
|
||||||
|
group_id, requester_user_id, room_id, config_key, content,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
class FederationGroupsUsersServlet(BaseFederationServlet):
|
class FederationGroupsUsersServlet(BaseFederationServlet):
|
||||||
"""Get the users in a group on behalf of a user
|
"""Get the users in a group on behalf of a user
|
||||||
"""
|
"""
|
||||||
@ -1142,6 +1163,8 @@ GROUP_SERVER_SERVLET_CLASSES = (
|
|||||||
FederationGroupsRolesServlet,
|
FederationGroupsRolesServlet,
|
||||||
FederationGroupsRoleServlet,
|
FederationGroupsRoleServlet,
|
||||||
FederationGroupsSummaryUsersServlet,
|
FederationGroupsSummaryUsersServlet,
|
||||||
|
FederationGroupsAddRoomsServlet,
|
||||||
|
FederationGroupsAddRoomsConfigServlet,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,6 +13,31 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Attestations ensure that users and groups can't lie about their memberships.
|
||||||
|
|
||||||
|
When a user joins a group the HS and GS swap attestations, which allow them
|
||||||
|
both to independently prove to third parties their membership.These
|
||||||
|
attestations have a validity period so need to be periodically renewed.
|
||||||
|
|
||||||
|
If a user leaves (or gets kicked out of) a group, either side can still use
|
||||||
|
their attestation to "prove" their membership, until the attestation expires.
|
||||||
|
Therefore attestations shouldn't be relied on to prove membership in important
|
||||||
|
cases, but can for less important situtations, e.g. showing a users membership
|
||||||
|
of groups on their profile, showing flairs, etc.abs
|
||||||
|
|
||||||
|
An attestsation is a signed blob of json that looks like:
|
||||||
|
|
||||||
|
{
|
||||||
|
"user_id": "@foo:a.example.com",
|
||||||
|
"group_id": "+bar:b.example.com",
|
||||||
|
"valid_until_ms": 1507994728530,
|
||||||
|
"signatures":{"matrix.org":{"ed25519:auto":"..."}}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
@ -22,9 +47,17 @@ from synapse.util.logcontext import preserve_fn
|
|||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Default validity duration for new attestations we create
|
# Default validity duration for new attestations we create
|
||||||
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
|
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
|
# We add some jitter to the validity duration of attestations so that if we
|
||||||
|
# add lots of users at once we don't need to renew them all at once.
|
||||||
|
# The jitter is a multiplier picked randomly between the first and second number
|
||||||
|
DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
|
||||||
|
|
||||||
# Start trying to update our attestations when they come this close to expiring
|
# Start trying to update our attestations when they come this close to expiring
|
||||||
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
|
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
@ -73,10 +106,14 @@ class GroupAttestationSigning(object):
|
|||||||
"""Create an attestation for the group_id and user_id with default
|
"""Create an attestation for the group_id and user_id with default
|
||||||
validity length.
|
validity length.
|
||||||
"""
|
"""
|
||||||
|
validity_period = DEFAULT_ATTESTATION_LENGTH_MS
|
||||||
|
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
|
||||||
|
valid_until_ms = int(self.clock.time_msec() + validity_period)
|
||||||
|
|
||||||
return sign_json({
|
return sign_json({
|
||||||
"group_id": group_id,
|
"group_id": group_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
|
"valid_until_ms": valid_until_ms,
|
||||||
}, self.server_name, self.signing_key)
|
}, self.server_name, self.signing_key)
|
||||||
|
|
||||||
|
|
||||||
@ -128,12 +165,19 @@ class GroupAttestionRenewer(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _renew_attestation(group_id, user_id):
|
def _renew_attestation(group_id, user_id):
|
||||||
attestation = self.attestations.create_attestation(group_id, user_id)
|
if not self.is_mine_id(group_id):
|
||||||
|
destination = get_domain_from_id(group_id)
|
||||||
if self.is_mine_id(group_id):
|
elif not self.is_mine_id(user_id):
|
||||||
destination = get_domain_from_id(user_id)
|
destination = get_domain_from_id(user_id)
|
||||||
else:
|
else:
|
||||||
destination = get_domain_from_id(group_id)
|
logger.warn(
|
||||||
|
"Incorrectly trying to do attestations for user: %r in %r",
|
||||||
|
user_id, group_id,
|
||||||
|
)
|
||||||
|
yield self.store.remove_attestation_renewal(group_id, user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
|
|
||||||
yield self.transport_client.renew_group_attestation(
|
yield self.transport_client.renew_group_attestation(
|
||||||
destination, group_id, user_id,
|
destination, group_id, user_id,
|
||||||
|
@ -49,7 +49,8 @@ class GroupsServerHandler(object):
|
|||||||
hs.get_groups_attestation_renewer()
|
hs.get_groups_attestation_renewer()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
|
def check_group_is_ours(self, group_id, requester_user_id,
|
||||||
|
and_exists=False, and_is_admin=None):
|
||||||
"""Check that the group is ours, and optionally if it exists.
|
"""Check that the group is ours, and optionally if it exists.
|
||||||
|
|
||||||
If group does exist then return group.
|
If group does exist then return group.
|
||||||
@ -67,6 +68,10 @@ class GroupsServerHandler(object):
|
|||||||
if and_exists and not group:
|
if and_exists and not group:
|
||||||
raise SynapseError(404, "Unknown group")
|
raise SynapseError(404, "Unknown group")
|
||||||
|
|
||||||
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
if group and not is_user_in_group and not group["is_public"]:
|
||||||
|
raise SynapseError(404, "Unknown group")
|
||||||
|
|
||||||
if and_is_admin:
|
if and_is_admin:
|
||||||
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
|
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
@ -84,7 +89,7 @@ class GroupsServerHandler(object):
|
|||||||
|
|
||||||
A user/room may appear in multiple roles/categories.
|
A user/room may appear in multiple roles/categories.
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
@ -153,10 +158,16 @@ class GroupsServerHandler(object):
|
|||||||
})
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
|
def update_group_summary_room(self, group_id, requester_user_id,
|
||||||
|
room_id, category_id, content):
|
||||||
"""Add/update a room to the group summary
|
"""Add/update a room to the group summary
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
yield self.check_group_is_ours(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
and_exists=True,
|
||||||
|
and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
RoomID.from_string(room_id) # Ensure valid room id
|
RoomID.from_string(room_id) # Ensure valid room id
|
||||||
|
|
||||||
@ -175,10 +186,16 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
|
def delete_group_summary_room(self, group_id, requester_user_id,
|
||||||
|
room_id, category_id):
|
||||||
"""Remove a room from the summary
|
"""Remove a room from the summary
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
yield self.check_group_is_ours(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
and_exists=True,
|
||||||
|
and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
yield self.store.remove_room_from_summary(
|
yield self.store.remove_room_from_summary(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@ -189,10 +206,10 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_group_categories(self, group_id, user_id):
|
def get_group_categories(self, group_id, requester_user_id):
|
||||||
"""Get all categories in a group (as seen by user)
|
"""Get all categories in a group (as seen by user)
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
categories = yield self.store.get_group_categories(
|
categories = yield self.store.get_group_categories(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@ -200,10 +217,10 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({"categories": categories})
|
defer.returnValue({"categories": categories})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_group_category(self, group_id, user_id, category_id):
|
def get_group_category(self, group_id, requester_user_id, category_id):
|
||||||
"""Get a specific category in a group (as seen by user)
|
"""Get a specific category in a group (as seen by user)
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
res = yield self.store.get_group_category(
|
res = yield self.store.get_group_category(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@ -213,10 +230,15 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue(res)
|
defer.returnValue(res)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_group_category(self, group_id, user_id, category_id, content):
|
def update_group_category(self, group_id, requester_user_id, category_id, content):
|
||||||
"""Add/Update a group category
|
"""Add/Update a group category
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
yield self.check_group_is_ours(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
and_exists=True,
|
||||||
|
and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
is_public = _parse_visibility_from_contents(content)
|
is_public = _parse_visibility_from_contents(content)
|
||||||
profile = content.get("profile")
|
profile = content.get("profile")
|
||||||
@ -231,10 +253,15 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_group_category(self, group_id, user_id, category_id):
|
def delete_group_category(self, group_id, requester_user_id, category_id):
|
||||||
"""Delete a group category
|
"""Delete a group category
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
yield self.check_group_is_ours(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
and_exists=True,
|
||||||
|
and_is_admin=requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
yield self.store.remove_group_category(
|
yield self.store.remove_group_category(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@ -244,10 +271,10 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_group_roles(self, group_id, user_id):
|
def get_group_roles(self, group_id, requester_user_id):
|
||||||
"""Get all roles in a group (as seen by user)
|
"""Get all roles in a group (as seen by user)
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
roles = yield self.store.get_group_roles(
|
roles = yield self.store.get_group_roles(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@ -255,10 +282,10 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({"roles": roles})
|
defer.returnValue({"roles": roles})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_group_role(self, group_id, user_id, role_id):
|
def get_group_role(self, group_id, requester_user_id, role_id):
|
||||||
"""Get a specific role in a group (as seen by user)
|
"""Get a specific role in a group (as seen by user)
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
res = yield self.store.get_group_role(
|
res = yield self.store.get_group_role(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@ -267,10 +294,15 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue(res)
|
defer.returnValue(res)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_group_role(self, group_id, user_id, role_id, content):
|
def update_group_role(self, group_id, requester_user_id, role_id, content):
|
||||||
"""Add/update a role in a group
|
"""Add/update a role in a group
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
yield self.check_group_is_ours(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
and_exists=True,
|
||||||
|
and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
is_public = _parse_visibility_from_contents(content)
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
|
||||||
@ -286,10 +318,15 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_group_role(self, group_id, user_id, role_id):
|
def delete_group_role(self, group_id, requester_user_id, role_id):
|
||||||
"""Remove role from group
|
"""Remove role from group
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
yield self.check_group_is_ours(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
and_exists=True,
|
||||||
|
and_is_admin=requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
yield self.store.remove_group_role(
|
yield self.store.remove_group_role(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@ -304,7 +341,7 @@ class GroupsServerHandler(object):
|
|||||||
"""Add/update a users entry in the group summary
|
"""Add/update a users entry in the group summary
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(
|
yield self.check_group_is_ours(
|
||||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
order = content.get("order", None)
|
order = content.get("order", None)
|
||||||
@ -326,7 +363,7 @@ class GroupsServerHandler(object):
|
|||||||
"""Remove a user from the group summary
|
"""Remove a user from the group summary
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(
|
yield self.check_group_is_ours(
|
||||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.store.remove_user_from_summary(
|
yield self.store.remove_user_from_summary(
|
||||||
@ -342,7 +379,7 @@ class GroupsServerHandler(object):
|
|||||||
"""Get the group profile as seen by requester_user_id
|
"""Get the group profile as seen by requester_user_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id)
|
yield self.check_group_is_ours(group_id, requester_user_id)
|
||||||
|
|
||||||
group_description = yield self.store.get_group(group_id)
|
group_description = yield self.store.get_group(group_id)
|
||||||
|
|
||||||
@ -356,7 +393,7 @@ class GroupsServerHandler(object):
|
|||||||
"""Update the group profile
|
"""Update the group profile
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(
|
yield self.check_group_is_ours(
|
||||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
profile = {}
|
profile = {}
|
||||||
@ -377,7 +414,7 @@ class GroupsServerHandler(object):
|
|||||||
The ordering is arbitrary at the moment
|
The ordering is arbitrary at the moment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
@ -389,14 +426,15 @@ class GroupsServerHandler(object):
|
|||||||
for user_result in user_results:
|
for user_result in user_results:
|
||||||
g_user_id = user_result["user_id"]
|
g_user_id = user_result["user_id"]
|
||||||
is_public = user_result["is_public"]
|
is_public = user_result["is_public"]
|
||||||
|
is_privileged = user_result["is_admin"]
|
||||||
|
|
||||||
entry = {"user_id": g_user_id}
|
entry = {"user_id": g_user_id}
|
||||||
|
|
||||||
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
|
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
|
||||||
entry.update(profile)
|
entry.update(profile)
|
||||||
|
|
||||||
if not is_public:
|
entry["is_public"] = bool(is_public)
|
||||||
entry["is_public"] = False
|
entry["is_privileged"] = bool(is_privileged)
|
||||||
|
|
||||||
if not self.is_mine_id(g_user_id):
|
if not self.is_mine_id(g_user_id):
|
||||||
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
|
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
|
||||||
@ -425,7 +463,7 @@ class GroupsServerHandler(object):
|
|||||||
The ordering is arbitrary at the moment
|
The ordering is arbitrary at the moment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
@ -459,7 +497,7 @@ class GroupsServerHandler(object):
|
|||||||
This returns rooms in order of decreasing number of joined users
|
This returns rooms in order of decreasing number of joined users
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||||
|
|
||||||
@ -470,7 +508,6 @@ class GroupsServerHandler(object):
|
|||||||
chunk = []
|
chunk = []
|
||||||
for room_result in room_results:
|
for room_result in room_results:
|
||||||
room_id = room_result["room_id"]
|
room_id = room_result["room_id"]
|
||||||
is_public = room_result["is_public"]
|
|
||||||
|
|
||||||
joined_users = yield self.store.get_users_in_room(room_id)
|
joined_users = yield self.store.get_users_in_room(room_id)
|
||||||
entry = yield self.room_list_handler.generate_room_entry(
|
entry = yield self.room_list_handler.generate_room_entry(
|
||||||
@ -481,8 +518,7 @@ class GroupsServerHandler(object):
|
|||||||
if not entry:
|
if not entry:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not is_public:
|
entry["is_public"] = bool(room_result["is_public"])
|
||||||
entry["is_public"] = False
|
|
||||||
|
|
||||||
chunk.append(entry)
|
chunk.append(entry)
|
||||||
|
|
||||||
@ -500,7 +536,7 @@ class GroupsServerHandler(object):
|
|||||||
RoomID.from_string(room_id) # Ensure valid room id
|
RoomID.from_string(room_id) # Ensure valid room id
|
||||||
|
|
||||||
yield self.check_group_is_ours(
|
yield self.check_group_is_ours(
|
||||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
is_public = _parse_visibility_from_contents(content)
|
is_public = _parse_visibility_from_contents(content)
|
||||||
@ -509,12 +545,35 @@ class GroupsServerHandler(object):
|
|||||||
|
|
||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
|
||||||
|
content):
|
||||||
|
"""Update room in group
|
||||||
|
"""
|
||||||
|
RoomID.from_string(room_id) # Ensure valid room id
|
||||||
|
|
||||||
|
yield self.check_group_is_ours(
|
||||||
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if config_key == "m.visibility":
|
||||||
|
is_public = _parse_visibility_dict(content)
|
||||||
|
|
||||||
|
yield self.store.update_room_in_group_visibility(
|
||||||
|
group_id, room_id,
|
||||||
|
is_public=is_public,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise SynapseError(400, "Uknown config option")
|
||||||
|
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
||||||
"""Remove room from group
|
"""Remove room from group
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(
|
yield self.check_group_is_ours(
|
||||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.store.remove_room_from_group(group_id, room_id)
|
yield self.store.remove_room_from_group(group_id, room_id)
|
||||||
@ -527,7 +586,7 @@ class GroupsServerHandler(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
group = yield self.check_group_is_ours(
|
group = yield self.check_group_is_ours(
|
||||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Check if user knocked
|
# TODO: Check if user knocked
|
||||||
@ -596,35 +655,40 @@ class GroupsServerHandler(object):
|
|||||||
raise SynapseError(502, "Unknown state returned by HS")
|
raise SynapseError(502, "Unknown state returned by HS")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def accept_invite(self, group_id, user_id, content):
|
def accept_invite(self, group_id, requester_user_id, content):
|
||||||
"""User tries to accept an invite to the group.
|
"""User tries to accept an invite to the group.
|
||||||
|
|
||||||
This is different from them asking to join, and so should error if no
|
This is different from them asking to join, and so should error if no
|
||||||
invite exists (and they're not a member of the group)
|
invite exists (and they're not a member of the group)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
if not self.store.is_user_invited_to_local_group(group_id, user_id):
|
is_invited = yield self.store.is_user_invited_to_local_group(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
)
|
||||||
|
if not is_invited:
|
||||||
raise SynapseError(403, "User not invited to group")
|
raise SynapseError(403, "User not invited to group")
|
||||||
|
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(requester_user_id):
|
||||||
|
local_attestation = self.attestations.create_attestation(
|
||||||
|
group_id, requester_user_id,
|
||||||
|
)
|
||||||
remote_attestation = content["attestation"]
|
remote_attestation = content["attestation"]
|
||||||
|
|
||||||
yield self.attestations.verify_attestation(
|
yield self.attestations.verify_attestation(
|
||||||
remote_attestation,
|
remote_attestation,
|
||||||
user_id=user_id,
|
user_id=requester_user_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
local_attestation = None
|
||||||
remote_attestation = None
|
remote_attestation = None
|
||||||
|
|
||||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
|
||||||
|
|
||||||
is_public = _parse_visibility_from_contents(content)
|
is_public = _parse_visibility_from_contents(content)
|
||||||
|
|
||||||
yield self.store.add_user_to_group(
|
yield self.store.add_user_to_group(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
is_admin=False,
|
is_admin=False,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
local_attestation=local_attestation,
|
local_attestation=local_attestation,
|
||||||
@ -637,31 +701,31 @@ class GroupsServerHandler(object):
|
|||||||
})
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def knock(self, group_id, user_id, content):
|
def knock(self, group_id, requester_user_id, content):
|
||||||
"""A user requests becoming a member of the group
|
"""A user requests becoming a member of the group
|
||||||
"""
|
"""
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def accept_knock(self, group_id, user_id, content):
|
def accept_knock(self, group_id, requester_user_id, content):
|
||||||
"""Accept a users knock to the room.
|
"""Accept a users knock to the room.
|
||||||
|
|
||||||
Errors if the user hasn't knocked, rather than inviting them.
|
Errors if the user hasn't knocked, rather than inviting them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||||
"""Remove a user from the group; either a user is leaving or and admin
|
"""Remove a user from the group; either a user is leaving or an admin
|
||||||
kicked htem.
|
kicked them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
is_kick = False
|
is_kick = False
|
||||||
if requester_user_id != user_id:
|
if requester_user_id != user_id:
|
||||||
@ -692,8 +756,8 @@ class GroupsServerHandler(object):
|
|||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_group(self, group_id, user_id, content):
|
def create_group(self, group_id, requester_user_id, content):
|
||||||
group = yield self.check_group_is_ours(group_id)
|
group = yield self.check_group_is_ours(group_id, requester_user_id)
|
||||||
|
|
||||||
logger.info("Attempting to create group with ID: %r", group_id)
|
logger.info("Attempting to create group with ID: %r", group_id)
|
||||||
|
|
||||||
@ -703,11 +767,11 @@ class GroupsServerHandler(object):
|
|||||||
if group:
|
if group:
|
||||||
raise SynapseError(400, "Group already exists")
|
raise SynapseError(400, "Group already exists")
|
||||||
|
|
||||||
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
|
is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
if not self.hs.config.enable_group_creation:
|
if not self.hs.config.enable_group_creation:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403, "Only server admin can create group on this server",
|
403, "Only a server admin can create groups on this server",
|
||||||
)
|
)
|
||||||
localpart = group_id_obj.localpart
|
localpart = group_id_obj.localpart
|
||||||
if not localpart.startswith(self.hs.config.group_creation_prefix):
|
if not localpart.startswith(self.hs.config.group_creation_prefix):
|
||||||
@ -727,38 +791,41 @@ class GroupsServerHandler(object):
|
|||||||
|
|
||||||
yield self.store.create_group(
|
yield self.store.create_group(
|
||||||
group_id,
|
group_id,
|
||||||
user_id,
|
requester_user_id,
|
||||||
name=name,
|
name=name,
|
||||||
avatar_url=avatar_url,
|
avatar_url=avatar_url,
|
||||||
short_description=short_description,
|
short_description=short_description,
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(requester_user_id):
|
||||||
remote_attestation = content["attestation"]
|
remote_attestation = content["attestation"]
|
||||||
|
|
||||||
yield self.attestations.verify_attestation(
|
yield self.attestations.verify_attestation(
|
||||||
remote_attestation,
|
remote_attestation,
|
||||||
user_id=user_id,
|
user_id=requester_user_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
local_attestation = self.attestations.create_attestation(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
local_attestation = None
|
local_attestation = None
|
||||||
remote_attestation = None
|
remote_attestation = None
|
||||||
|
|
||||||
yield self.store.add_user_to_group(
|
yield self.store.add_user_to_group(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
is_admin=True,
|
is_admin=True,
|
||||||
is_public=True, # TODO
|
is_public=True, # TODO
|
||||||
local_attestation=local_attestation,
|
local_attestation=local_attestation,
|
||||||
remote_attestation=remote_attestation,
|
remote_attestation=remote_attestation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(requester_user_id):
|
||||||
yield self.store.add_remote_profile_cache(
|
yield self.store.add_remote_profile_cache(
|
||||||
user_id,
|
requester_user_id,
|
||||||
displayname=user_profile.get("displayname"),
|
displayname=user_profile.get("displayname"),
|
||||||
avatar_url=user_profile.get("avatar_url"),
|
avatar_url=user_profile.get("avatar_url"),
|
||||||
)
|
)
|
||||||
@ -773,15 +840,25 @@ def _parse_visibility_from_contents(content):
|
|||||||
public or not
|
public or not
|
||||||
"""
|
"""
|
||||||
|
|
||||||
visibility = content.get("visibility")
|
visibility = content.get("m.visibility")
|
||||||
if visibility:
|
if visibility:
|
||||||
vis_type = visibility["type"]
|
return _parse_visibility_dict(visibility)
|
||||||
if vis_type not in ("public", "private"):
|
|
||||||
raise SynapseError(
|
|
||||||
400, "Synapse only supports 'public'/'private' visibility"
|
|
||||||
)
|
|
||||||
is_public = vis_type == "public"
|
|
||||||
else:
|
else:
|
||||||
is_public = True
|
is_public = True
|
||||||
|
|
||||||
return is_public
|
return is_public
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_visibility_dict(visibility):
|
||||||
|
"""Given a dict for the "m.visibility" config return if the entity should
|
||||||
|
be public or not
|
||||||
|
"""
|
||||||
|
vis_type = visibility.get("type")
|
||||||
|
if not vis_type:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if vis_type not in ("public", "private"):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Synapse only supports 'public'/'private' visibility"
|
||||||
|
)
|
||||||
|
return vis_type == "public"
|
||||||
|
@ -70,11 +70,10 @@ class ApplicationServicesHandler(object):
|
|||||||
with Measure(self.clock, "notify_interested_services"):
|
with Measure(self.clock, "notify_interested_services"):
|
||||||
self.is_processing = True
|
self.is_processing = True
|
||||||
try:
|
try:
|
||||||
upper_bound = self.current_max
|
|
||||||
limit = 100
|
limit = 100
|
||||||
while True:
|
while True:
|
||||||
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
||||||
upper_bound, limit
|
self.current_max, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
if not events:
|
if not events:
|
||||||
@ -105,9 +104,6 @@ class ApplicationServicesHandler(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
yield self.store.set_appservice_last_pos(upper_bound)
|
yield self.store.set_appservice_last_pos(upper_bound)
|
||||||
|
|
||||||
if len(events) < limit:
|
|
||||||
break
|
|
||||||
finally:
|
finally:
|
||||||
self.is_processing = False
|
self.is_processing = False
|
||||||
|
|
||||||
|
@ -13,13 +13,13 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.types import UserID
|
|
||||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
|
||||||
@ -63,10 +63,7 @@ class AuthHandler(BaseHandler):
|
|||||||
reset_expiry_on_get=True,
|
reset_expiry_on_get=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
account_handler = _AccountHandler(
|
account_handler = ModuleApi(hs, self)
|
||||||
hs, check_user_exists=self.check_user_exists
|
|
||||||
)
|
|
||||||
|
|
||||||
self.password_providers = [
|
self.password_providers = [
|
||||||
module(config=config, account_handler=account_handler)
|
module(config=config, account_handler=account_handler)
|
||||||
for module, config in hs.config.password_providers
|
for module, config in hs.config.password_providers
|
||||||
@ -75,14 +72,24 @@ class AuthHandler(BaseHandler):
|
|||||||
logger.info("Extra password_providers: %r", self.password_providers)
|
logger.info("Extra password_providers: %r", self.password_providers)
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.device_handler = hs.get_device_handler()
|
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
self._password_enabled = hs.config.password_enabled
|
||||||
|
|
||||||
|
login_types = set()
|
||||||
|
if self._password_enabled:
|
||||||
|
login_types.add(LoginType.PASSWORD)
|
||||||
|
for provider in self.password_providers:
|
||||||
|
if hasattr(provider, "get_supported_login_types"):
|
||||||
|
login_types.update(
|
||||||
|
provider.get_supported_login_types().keys()
|
||||||
|
)
|
||||||
|
self._supported_login_types = frozenset(login_types)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
"""
|
"""
|
||||||
Takes a dictionary sent by the client in the login / registration
|
Takes a dictionary sent by the client in the login / registration
|
||||||
protocol and handles the login flow.
|
protocol and handles the User-Interactive Auth flow.
|
||||||
|
|
||||||
As a side effect, this function fills in the 'creds' key on the user's
|
As a side effect, this function fills in the 'creds' key on the user's
|
||||||
session with a map, which maps each auth-type (str) to the relevant
|
session with a map, which maps each auth-type (str) to the relevant
|
||||||
@ -260,16 +267,19 @@ class AuthHandler(BaseHandler):
|
|||||||
sess = self._get_session_info(session_id)
|
sess = self._get_session_info(session_id)
|
||||||
return sess.setdefault('serverdict', {}).get(key, default)
|
return sess.setdefault('serverdict', {}).get(key, default)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def _check_password_auth(self, authdict, _):
|
def _check_password_auth(self, authdict, _):
|
||||||
if "user" not in authdict or "password" not in authdict:
|
if "user" not in authdict or "password" not in authdict:
|
||||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
user_id = authdict["user"]
|
user_id = authdict["user"]
|
||||||
password = authdict["password"]
|
password = authdict["password"]
|
||||||
if not user_id.startswith('@'):
|
|
||||||
user_id = UserID(user_id, self.hs.hostname).to_string()
|
|
||||||
|
|
||||||
return self._check_password(user_id, password)
|
(canonical_id, callback) = yield self.validate_login(user_id, {
|
||||||
|
"type": LoginType.PASSWORD,
|
||||||
|
"password": password,
|
||||||
|
})
|
||||||
|
defer.returnValue(canonical_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_recaptcha(self, authdict, clientip):
|
def _check_recaptcha(self, authdict, clientip):
|
||||||
@ -398,26 +408,8 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
return self.sessions[session_id]
|
return self.sessions[session_id]
|
||||||
|
|
||||||
def validate_password_login(self, user_id, password):
|
|
||||||
"""
|
|
||||||
Authenticates the user with their username and password.
|
|
||||||
|
|
||||||
Used only by the v1 login API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): complete @user:id
|
|
||||||
password (str): Password
|
|
||||||
Returns:
|
|
||||||
defer.Deferred: (str) canonical user id
|
|
||||||
Raises:
|
|
||||||
StoreError if there was a problem accessing the database
|
|
||||||
LoginError if there was an authentication problem.
|
|
||||||
"""
|
|
||||||
return self._check_password(user_id, password)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_access_token_for_user_id(self, user_id, device_id=None,
|
def get_access_token_for_user_id(self, user_id, device_id=None):
|
||||||
initial_display_name=None):
|
|
||||||
"""
|
"""
|
||||||
Creates a new access token for the user with the given user ID.
|
Creates a new access token for the user with the given user ID.
|
||||||
|
|
||||||
@ -431,13 +423,10 @@ class AuthHandler(BaseHandler):
|
|||||||
device_id (str|None): the device ID to associate with the tokens.
|
device_id (str|None): the device ID to associate with the tokens.
|
||||||
None to leave the tokens unassociated with a device (deprecated:
|
None to leave the tokens unassociated with a device (deprecated:
|
||||||
we should always have a device ID)
|
we should always have a device ID)
|
||||||
initial_display_name (str): display name to associate with the
|
|
||||||
device if it needs re-registering
|
|
||||||
Returns:
|
Returns:
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem storing the token.
|
||||||
LoginError if there was an authentication problem.
|
|
||||||
"""
|
"""
|
||||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||||
access_token = yield self.issue_access_token(user_id, device_id)
|
access_token = yield self.issue_access_token(user_id, device_id)
|
||||||
@ -447,9 +436,11 @@ class AuthHandler(BaseHandler):
|
|||||||
# really don't want is active access_tokens without a record of the
|
# really don't want is active access_tokens without a record of the
|
||||||
# device, so we double-check it here.
|
# device, so we double-check it here.
|
||||||
if device_id is not None:
|
if device_id is not None:
|
||||||
yield self.device_handler.check_device_registered(
|
try:
|
||||||
user_id, device_id, initial_display_name
|
yield self.store.get_device(user_id, device_id)
|
||||||
)
|
except StoreError:
|
||||||
|
yield self.store.delete_access_token(access_token)
|
||||||
|
raise StoreError(400, "Login raced against device deletion")
|
||||||
|
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
@ -501,29 +492,115 @@ class AuthHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def get_supported_login_types(self):
|
||||||
def _check_password(self, user_id, password):
|
"""Get a the login types supported for the /login API
|
||||||
"""Authenticate a user against the LDAP and local databases.
|
|
||||||
|
|
||||||
user_id is checked case insensitively against the local database, but
|
By default this is just 'm.login.password' (unless password_enabled is
|
||||||
will throw if there are multiple inexact matches.
|
False in the config file), but password auth providers can provide
|
||||||
|
other login types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[str]: login types
|
||||||
|
"""
|
||||||
|
return self._supported_login_types
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def validate_login(self, username, login_submission):
|
||||||
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
|
Also used by the user-interactive auth flow to validate
|
||||||
|
m.login.password auth types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): complete @user:id
|
username (str): username supplied by the user
|
||||||
|
login_submission (dict): the whole of the login submission
|
||||||
|
(including 'type' and other relevant fields)
|
||||||
Returns:
|
Returns:
|
||||||
(str) the canonical_user_id
|
Deferred[str, func]: canonical user id, and optional callback
|
||||||
|
to be called once the access token and device id are issued
|
||||||
Raises:
|
Raises:
|
||||||
LoginError if login fails
|
StoreError if there was a problem accessing the database
|
||||||
|
SynapseError if there was a problem with the request
|
||||||
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
for provider in self.password_providers:
|
|
||||||
is_valid = yield provider.check_password(user_id, password)
|
|
||||||
if is_valid:
|
|
||||||
defer.returnValue(user_id)
|
|
||||||
|
|
||||||
canonical_user_id = yield self._check_local_password(user_id, password)
|
if username.startswith('@'):
|
||||||
|
qualified_user_id = username
|
||||||
|
else:
|
||||||
|
qualified_user_id = UserID(
|
||||||
|
username, self.hs.hostname
|
||||||
|
).to_string()
|
||||||
|
|
||||||
|
login_type = login_submission.get("type")
|
||||||
|
known_login_type = False
|
||||||
|
|
||||||
|
# special case to check for "password" for the check_password interface
|
||||||
|
# for the auth providers
|
||||||
|
password = login_submission.get("password")
|
||||||
|
if login_type == LoginType.PASSWORD:
|
||||||
|
if not self._password_enabled:
|
||||||
|
raise SynapseError(400, "Password login has been disabled.")
|
||||||
|
if not password:
|
||||||
|
raise SynapseError(400, "Missing parameter: password")
|
||||||
|
|
||||||
|
for provider in self.password_providers:
|
||||||
|
if (hasattr(provider, "check_password")
|
||||||
|
and login_type == LoginType.PASSWORD):
|
||||||
|
known_login_type = True
|
||||||
|
is_valid = yield provider.check_password(
|
||||||
|
qualified_user_id, password,
|
||||||
|
)
|
||||||
|
if is_valid:
|
||||||
|
defer.returnValue(qualified_user_id)
|
||||||
|
|
||||||
|
if (not hasattr(provider, "get_supported_login_types")
|
||||||
|
or not hasattr(provider, "check_auth")):
|
||||||
|
# this password provider doesn't understand custom login types
|
||||||
|
continue
|
||||||
|
|
||||||
|
supported_login_types = provider.get_supported_login_types()
|
||||||
|
if login_type not in supported_login_types:
|
||||||
|
# this password provider doesn't understand this login type
|
||||||
|
continue
|
||||||
|
|
||||||
|
known_login_type = True
|
||||||
|
login_fields = supported_login_types[login_type]
|
||||||
|
|
||||||
|
missing_fields = []
|
||||||
|
login_dict = {}
|
||||||
|
for f in login_fields:
|
||||||
|
if f not in login_submission:
|
||||||
|
missing_fields.append(f)
|
||||||
|
else:
|
||||||
|
login_dict[f] = login_submission[f]
|
||||||
|
if missing_fields:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Missing parameters for login type %s: %s" % (
|
||||||
|
login_type,
|
||||||
|
missing_fields,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = yield provider.check_auth(
|
||||||
|
username, login_type, login_dict,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
if isinstance(result, str):
|
||||||
|
result = (result, None)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
if login_type == LoginType.PASSWORD:
|
||||||
|
known_login_type = True
|
||||||
|
|
||||||
|
canonical_user_id = yield self._check_local_password(
|
||||||
|
qualified_user_id, password,
|
||||||
|
)
|
||||||
|
|
||||||
if canonical_user_id:
|
if canonical_user_id:
|
||||||
defer.returnValue(canonical_user_id)
|
defer.returnValue((canonical_user_id, None))
|
||||||
|
|
||||||
|
if not known_login_type:
|
||||||
|
raise SynapseError(400, "Unknown login type %s" % login_type)
|
||||||
|
|
||||||
# unknown username or invalid password. We raise a 403 here, but note
|
# unknown username or invalid password. We raise a 403 here, but note
|
||||||
# that if we're doing user-interactive login, it turns all LoginErrors
|
# that if we're doing user-interactive login, it turns all LoginErrors
|
||||||
@ -584,13 +661,80 @@ class AuthHandler(BaseHandler):
|
|||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||||
raise e
|
raise e
|
||||||
yield self.store.user_delete_access_tokens(
|
yield self.delete_access_tokens_for_user(
|
||||||
user_id, except_access_token_id
|
user_id, except_token_id=except_access_token_id,
|
||||||
)
|
)
|
||||||
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
||||||
user_id, except_access_token_id
|
user_id, except_access_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def deactivate_account(self, user_id):
|
||||||
|
"""Deactivate a user's account
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of user to be deactivated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
# FIXME: Theoretically there is a race here wherein user resets
|
||||||
|
# password using threepid.
|
||||||
|
yield self.delete_access_tokens_for_user(user_id)
|
||||||
|
yield self.store.user_delete_threepids(user_id)
|
||||||
|
yield self.store.user_set_password_hash(user_id, None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_access_token(self, access_token):
|
||||||
|
"""Invalidate a single access token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token (str): access token to be deleted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
user_info = yield self.auth.get_user_by_access_token(access_token)
|
||||||
|
yield self.store.delete_access_token(access_token)
|
||||||
|
|
||||||
|
# see if any of our auth providers want to know about this
|
||||||
|
for provider in self.password_providers:
|
||||||
|
if hasattr(provider, "on_logged_out"):
|
||||||
|
yield provider.on_logged_out(
|
||||||
|
user_id=str(user_info["user"]),
|
||||||
|
device_id=user_info["device_id"],
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
|
||||||
|
device_id=None):
|
||||||
|
"""Invalidate access tokens belonging to a user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of user the tokens belong to
|
||||||
|
except_token_id (str|None): access_token ID which should *not* be
|
||||||
|
deleted
|
||||||
|
device_id (str|None): ID of device the tokens are associated with.
|
||||||
|
If None, tokens associated with any device (or no device) will
|
||||||
|
be deleted
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
tokens_and_devices = yield self.store.user_delete_access_tokens(
|
||||||
|
user_id, except_token_id=except_token_id, device_id=device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# see if any of our auth providers want to know about this
|
||||||
|
for provider in self.password_providers:
|
||||||
|
if hasattr(provider, "on_logged_out"):
|
||||||
|
for token, device_id in tokens_and_devices:
|
||||||
|
yield provider.on_logged_out(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
access_token=token,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_threepid(self, user_id, medium, address, validated_at):
|
def add_threepid(self, user_id, medium, address, validated_at):
|
||||||
# 'Canonicalise' email addresses down to lower case.
|
# 'Canonicalise' email addresses down to lower case.
|
||||||
@ -696,30 +840,3 @@ class MacaroonGeneartor(object):
|
|||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
||||||
|
|
||||||
class _AccountHandler(object):
|
|
||||||
"""A proxy object that gets passed to password auth providers so they
|
|
||||||
can register new users etc if necessary.
|
|
||||||
"""
|
|
||||||
def __init__(self, hs, check_user_exists):
|
|
||||||
self.hs = hs
|
|
||||||
|
|
||||||
self._check_user_exists = check_user_exists
|
|
||||||
|
|
||||||
def check_user_exists(self, user_id):
|
|
||||||
"""Check if user exissts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred(bool)
|
|
||||||
"""
|
|
||||||
return self._check_user_exists(user_id)
|
|
||||||
|
|
||||||
def register(self, localpart):
|
|
||||||
"""Registers a new user with given localpart
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred: a 2-tuple of (user_id, access_token)
|
|
||||||
"""
|
|
||||||
reg = self.hs.get_handlers().registration_handler
|
|
||||||
return reg.register(localpart=localpart)
|
|
||||||
|
@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler):
|
|||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.federation_sender = hs.get_federation_sender()
|
self.federation_sender = hs.get_federation_sender()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
|
|
||||||
@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
yield self.store.user_delete_access_tokens(
|
yield self._auth_handler.delete_access_tokens_for_user(
|
||||||
user_id, device_id=device_id,
|
user_id, device_id=device_id,
|
||||||
delete_refresh_tokens=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.store.delete_e2e_keys_by_device(
|
yield self.store.delete_e2e_keys_by_device(
|
||||||
@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler):
|
|||||||
# Delete access tokens and e2e keys for each device. Not optimised as it is not
|
# Delete access tokens and e2e keys for each device. Not optimised as it is not
|
||||||
# considered as part of a critical path.
|
# considered as part of a critical path.
|
||||||
for device_id in device_ids:
|
for device_id in device_ids:
|
||||||
yield self.store.user_delete_access_tokens(
|
yield self._auth_handler.delete_access_tokens_for_user(
|
||||||
user_id, device_id=device_id,
|
user_id, device_id=device_id,
|
||||||
delete_refresh_tokens=True,
|
|
||||||
)
|
)
|
||||||
yield self.store.delete_e2e_keys_by_device(
|
yield self.store.delete_e2e_keys_by_device(
|
||||||
user_id=user_id, device_id=device_id
|
user_id=user_id, device_id=device_id
|
||||||
|
@ -1706,6 +1706,17 @@ class FederationHandler(BaseHandler):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def do_auth(self, origin, event, context, auth_events):
|
def do_auth(self, origin, event, context, auth_events):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin (str):
|
||||||
|
event (synapse.events.FrozenEvent):
|
||||||
|
context (synapse.events.snapshot.EventContext):
|
||||||
|
auth_events (dict[(str, str)->str]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred[None]
|
||||||
|
"""
|
||||||
# Check if we have all the auth events.
|
# Check if we have all the auth events.
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||||
@ -1817,16 +1828,9 @@ class FederationHandler(BaseHandler):
|
|||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
different_auth = event_auth_events - current_state
|
different_auth = event_auth_events - current_state
|
||||||
|
|
||||||
context.current_state_ids = dict(context.current_state_ids)
|
self._update_context_for_auth_events(
|
||||||
context.current_state_ids.update({
|
context, auth_events, event_key,
|
||||||
k: a.event_id for k, a in auth_events.items()
|
)
|
||||||
if k != event_key
|
|
||||||
})
|
|
||||||
context.prev_state_ids = dict(context.prev_state_ids)
|
|
||||||
context.prev_state_ids.update({
|
|
||||||
k: a.event_id for k, a in auth_events.items()
|
|
||||||
})
|
|
||||||
context.state_group = self.store.get_next_state_group()
|
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
if different_auth and not event.internal_metadata.is_outlier():
|
||||||
logger.info("Different auth after resolution: %s", different_auth)
|
logger.info("Different auth after resolution: %s", different_auth)
|
||||||
@ -1906,16 +1910,9 @@ class FederationHandler(BaseHandler):
|
|||||||
# 4. Look at rejects and their proofs.
|
# 4. Look at rejects and their proofs.
|
||||||
# TODO.
|
# TODO.
|
||||||
|
|
||||||
context.current_state_ids = dict(context.current_state_ids)
|
self._update_context_for_auth_events(
|
||||||
context.current_state_ids.update({
|
context, auth_events, event_key,
|
||||||
k: a.event_id for k, a in auth_events.items()
|
)
|
||||||
if k != event_key
|
|
||||||
})
|
|
||||||
context.prev_state_ids = dict(context.prev_state_ids)
|
|
||||||
context.prev_state_ids.update({
|
|
||||||
k: a.event_id for k, a in auth_events.items()
|
|
||||||
})
|
|
||||||
context.state_group = self.store.get_next_state_group()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=auth_events)
|
self.auth.check(event, auth_events=auth_events)
|
||||||
@ -1923,6 +1920,35 @@ class FederationHandler(BaseHandler):
|
|||||||
logger.warn("Failed auth resolution for %r because %s", event, e)
|
logger.warn("Failed auth resolution for %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _update_context_for_auth_events(self, context, auth_events,
|
||||||
|
event_key):
|
||||||
|
"""Update the state_ids in an event context after auth event resolution
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (synapse.events.snapshot.EventContext): event context
|
||||||
|
to be updated
|
||||||
|
|
||||||
|
auth_events (dict[(str, str)->str]): Events to update in the event
|
||||||
|
context.
|
||||||
|
|
||||||
|
event_key ((str, str)): (type, state_key) for the current event.
|
||||||
|
this will not be included in the current_state in the context.
|
||||||
|
"""
|
||||||
|
state_updates = {
|
||||||
|
k: a.event_id for k, a in auth_events.iteritems()
|
||||||
|
if k != event_key
|
||||||
|
}
|
||||||
|
context.current_state_ids = dict(context.current_state_ids)
|
||||||
|
context.current_state_ids.update(state_updates)
|
||||||
|
if context.delta_ids is not None:
|
||||||
|
context.delta_ids = dict(context.delta_ids)
|
||||||
|
context.delta_ids.update(state_updates)
|
||||||
|
context.prev_state_ids = dict(context.prev_state_ids)
|
||||||
|
context.prev_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.iteritems()
|
||||||
|
})
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def construct_auth_difference(self, local_auth, remote_auth):
|
def construct_auth_difference(self, local_auth, remote_auth):
|
||||||
""" Given a local and remote auth chain, find the differences. This
|
""" Given a local and remote auth chain, find the differences. This
|
||||||
|
@ -71,6 +71,7 @@ class GroupsLocalHandler(object):
|
|||||||
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
|
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
|
||||||
|
|
||||||
add_room_to_group = _create_rerouter("add_room_to_group")
|
add_room_to_group = _create_rerouter("add_room_to_group")
|
||||||
|
update_room_in_group = _create_rerouter("update_room_in_group")
|
||||||
remove_room_from_group = _create_rerouter("remove_room_from_group")
|
remove_room_from_group = _create_rerouter("remove_room_from_group")
|
||||||
|
|
||||||
update_group_summary_room = _create_rerouter("update_group_summary_room")
|
update_group_summary_room = _create_rerouter("update_group_summary_room")
|
||||||
|
@ -17,7 +17,6 @@ import logging
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.types
|
|
||||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||||
from synapse.types import UserID, get_domain_from_id
|
from synapse.types import UserID, get_domain_from_id
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
@ -140,7 +139,7 @@ class ProfileHandler(BaseHandler):
|
|||||||
target_user.localpart, new_displayname
|
target_user.localpart, new_displayname
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._update_join_states(requester)
|
yield self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_avatar_url(self, target_user):
|
def get_avatar_url(self, target_user):
|
||||||
@ -184,7 +183,7 @@ class ProfileHandler(BaseHandler):
|
|||||||
target_user.localpart, new_avatar_url
|
target_user.localpart, new_avatar_url
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._update_join_states(requester)
|
yield self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_profile_query(self, args):
|
def on_profile_query(self, args):
|
||||||
@ -209,28 +208,24 @@ class ProfileHandler(BaseHandler):
|
|||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_join_states(self, requester):
|
def _update_join_states(self, requester, target_user):
|
||||||
user = requester.user
|
if not self.hs.is_mine(target_user):
|
||||||
if not self.hs.is_mine(user):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self.ratelimit(requester)
|
yield self.ratelimit(requester)
|
||||||
|
|
||||||
room_ids = yield self.store.get_rooms_for_user(
|
room_ids = yield self.store.get_rooms_for_user(
|
||||||
user.to_string(),
|
target_user.to_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
handler = self.hs.get_handlers().room_member_handler
|
||||||
try:
|
try:
|
||||||
# Assume the user isn't a guest because we don't let guests set
|
# Assume the target_user isn't a guest,
|
||||||
# profile or avatar data.
|
# because we don't let guests set profile or avatar data.
|
||||||
# XXX why are we recreating `requester` here for each room?
|
|
||||||
# what was wrong with the `requester` we were passed?
|
|
||||||
requester = synapse.types.create_requester(user)
|
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
user,
|
target_user,
|
||||||
room_id,
|
room_id,
|
||||||
"join", # We treat a profile update like a join.
|
"join", # We treat a profile update like a join.
|
||||||
ratelimit=False, # Try to hide that these events aren't atomic.
|
ratelimit=False, # Try to hide that these events aren't atomic.
|
||||||
|
@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
super(RegistrationHandler, self).__init__(hs)
|
super(RegistrationHandler, self).__init__(hs)
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||||
|
|
||||||
@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
create_profile_with_localpart=user.localpart,
|
create_profile_with_localpart=user.localpart,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield self.store.user_delete_access_tokens(user_id=user_id)
|
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||||
|
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
|
@ -20,6 +20,7 @@ from ._base import BaseHandler
|
|||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventTypes, JoinRules,
|
EventTypes, JoinRules,
|
||||||
)
|
)
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler):
|
|||||||
if search_filter:
|
if search_filter:
|
||||||
# We explicitly don't bother caching searches or requests for
|
# We explicitly don't bother caching searches or requests for
|
||||||
# appservice specific lists.
|
# appservice specific lists.
|
||||||
|
logger.info("Bypassing cache as search request.")
|
||||||
return self._get_public_room_list(
|
return self._get_public_room_list(
|
||||||
limit, since_token, search_filter, network_tuple=network_tuple,
|
limit, since_token, search_filter, network_tuple=network_tuple,
|
||||||
)
|
)
|
||||||
@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler):
|
|||||||
key = (limit, since_token, network_tuple)
|
key = (limit, since_token, network_tuple)
|
||||||
result = self.response_cache.get(key)
|
result = self.response_cache.get(key)
|
||||||
if not result:
|
if not result:
|
||||||
|
logger.info("No cached result, calculating one.")
|
||||||
result = self.response_cache.set(
|
result = self.response_cache.set(
|
||||||
key,
|
key,
|
||||||
self._get_public_room_list(
|
preserve_fn(self._get_public_room_list)(
|
||||||
limit, since_token, network_tuple=network_tuple
|
limit, since_token, network_tuple=network_tuple
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return result
|
else:
|
||||||
|
logger.info("Using cached deferred result.")
|
||||||
|
return make_deferred_yieldable(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_public_room_list(self, limit=None, since_token=None,
|
def _get_public_room_list(self, limit=None, since_token=None,
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util.metrics import Measure, measure_func
|
from synapse.util.metrics import Measure, measure_func
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.push.clientformat import format_push_rules_for_user
|
from synapse.push.clientformat import format_push_rules_for_user
|
||||||
@ -184,11 +184,11 @@ class SyncHandler(object):
|
|||||||
if not result:
|
if not result:
|
||||||
result = self.response_cache.set(
|
result = self.response_cache.set(
|
||||||
sync_config.request_key,
|
sync_config.request_key,
|
||||||
self._wait_for_sync_for_user(
|
preserve_fn(self._wait_for_sync_for_user)(
|
||||||
sync_config, since_token, timeout, full_state
|
sync_config, since_token, timeout, full_state
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return result
|
return make_deferred_yieldable(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
|
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
|
||||||
|
@ -152,7 +152,7 @@ class UserDirectoyHandler(object):
|
|||||||
|
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
|
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
|
||||||
yield self._handle_intial_room(room_id)
|
yield self._handle_initial_room(room_id)
|
||||||
num_processed_rooms += 1
|
num_processed_rooms += 1
|
||||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
@ -166,7 +166,7 @@ class UserDirectoyHandler(object):
|
|||||||
yield self.store.update_user_directory_stream_pos(new_pos)
|
yield self.store.update_user_directory_stream_pos(new_pos)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _handle_intial_room(self, room_id):
|
def _handle_initial_room(self, room_id):
|
||||||
"""Called when we initially fill out user_directory one room at a time
|
"""Called when we initially fill out user_directory one room at a time
|
||||||
"""
|
"""
|
||||||
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
|
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
|
||||||
|
55
synapse/http/additional_resource.py
Normal file
55
synapse/http/additional_resource.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.http.server import wrap_request_handler
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
|
||||||
|
|
||||||
|
class AdditionalResource(Resource):
|
||||||
|
"""Resource wrapper for additional_resources
|
||||||
|
|
||||||
|
If the user has configured additional_resources, we need to wrap the
|
||||||
|
handler class with a Resource so that we can map it into the resource tree.
|
||||||
|
|
||||||
|
This class is also where we wrap the request handler with logging, metrics,
|
||||||
|
and exception handling.
|
||||||
|
"""
|
||||||
|
def __init__(self, hs, handler):
|
||||||
|
"""Initialise AdditionalResource
|
||||||
|
|
||||||
|
The ``handler`` should return a deferred which completes when it has
|
||||||
|
done handling the request. It should write a response with
|
||||||
|
``request.write()``, and call ``request.finish()``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): homeserver
|
||||||
|
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
||||||
|
function to be called to handle the request.
|
||||||
|
"""
|
||||||
|
Resource.__init__(self)
|
||||||
|
self._handler = handler
|
||||||
|
|
||||||
|
# these are required by the request_handler wrapper
|
||||||
|
self.version_string = hs.version_string
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
def render(self, request):
|
||||||
|
self._async_render(request)
|
||||||
|
return NOT_DONE_YET
|
||||||
|
|
||||||
|
@wrap_request_handler
|
||||||
|
def _async_render(self, request):
|
||||||
|
return self._handler(request)
|
@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
|
|||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
||||||
)
|
)
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
from synapse.util import logcontext
|
from synapse.util import logcontext
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
from synapse.http.endpoint import SpiderEndpoint
|
from synapse.http.endpoint import SpiderEndpoint
|
||||||
@ -114,43 +114,73 @@ class SimpleHttpClient(object):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def post_urlencoded_get_json(self, uri, args={}):
|
def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
uri (str):
|
||||||
|
args (dict[str, str|List[str]]): query params
|
||||||
|
headers (dict[str, List[str]]|None): If not None, a map from
|
||||||
|
header name to a list of values for that header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[object]: parsed json
|
||||||
|
"""
|
||||||
|
|
||||||
# TODO: Do we ever want to log message contents?
|
# TODO: Do we ever want to log message contents?
|
||||||
logger.debug("post_urlencoded_get_json args: %s", args)
|
logger.debug("post_urlencoded_get_json args: %s", args)
|
||||||
|
|
||||||
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
|
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
|
||||||
|
|
||||||
|
actual_headers = {
|
||||||
|
b"Content-Type": [b"application/x-www-form-urlencoded"],
|
||||||
|
b"User-Agent": [self.user_agent],
|
||||||
|
}
|
||||||
|
if headers:
|
||||||
|
actual_headers.update(headers)
|
||||||
|
|
||||||
response = yield self.request(
|
response = yield self.request(
|
||||||
"POST",
|
"POST",
|
||||||
uri.encode("ascii"),
|
uri.encode("ascii"),
|
||||||
headers=Headers({
|
headers=Headers(actual_headers),
|
||||||
b"Content-Type": [b"application/x-www-form-urlencoded"],
|
|
||||||
b"User-Agent": [self.user_agent],
|
|
||||||
}),
|
|
||||||
bodyProducer=FileBodyProducer(StringIO(query_bytes))
|
bodyProducer=FileBodyProducer(StringIO(query_bytes))
|
||||||
)
|
)
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield make_deferred_yieldable(readBody(response))
|
||||||
|
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def post_json_get_json(self, uri, post_json):
|
def post_json_get_json(self, uri, post_json, headers=None):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri (str):
|
||||||
|
post_json (object):
|
||||||
|
headers (dict[str, List[str]]|None): If not None, a map from
|
||||||
|
header name to a list of values for that header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[object]: parsed json
|
||||||
|
"""
|
||||||
json_str = encode_canonical_json(post_json)
|
json_str = encode_canonical_json(post_json)
|
||||||
|
|
||||||
logger.debug("HTTP POST %s -> %s", json_str, uri)
|
logger.debug("HTTP POST %s -> %s", json_str, uri)
|
||||||
|
|
||||||
|
actual_headers = {
|
||||||
|
b"Content-Type": [b"application/json"],
|
||||||
|
b"User-Agent": [self.user_agent],
|
||||||
|
}
|
||||||
|
if headers:
|
||||||
|
actual_headers.update(headers)
|
||||||
|
|
||||||
response = yield self.request(
|
response = yield self.request(
|
||||||
"POST",
|
"POST",
|
||||||
uri.encode("ascii"),
|
uri.encode("ascii"),
|
||||||
headers=Headers({
|
headers=Headers(actual_headers),
|
||||||
b"Content-Type": [b"application/json"],
|
|
||||||
b"User-Agent": [self.user_agent],
|
|
||||||
}),
|
|
||||||
bodyProducer=FileBodyProducer(StringIO(json_str))
|
bodyProducer=FileBodyProducer(StringIO(json_str))
|
||||||
)
|
)
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield make_deferred_yieldable(readBody(response))
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
@ -160,7 +190,7 @@ class SimpleHttpClient(object):
|
|||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_json(self, uri, args={}):
|
def get_json(self, uri, args={}, headers=None):
|
||||||
""" Gets some json from the given URI.
|
""" Gets some json from the given URI.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -169,6 +199,8 @@ class SimpleHttpClient(object):
|
|||||||
None.
|
None.
|
||||||
**Note**: The value of each key is assumed to be an iterable
|
**Note**: The value of each key is assumed to be an iterable
|
||||||
and *not* a string.
|
and *not* a string.
|
||||||
|
headers (dict[str, List[str]]|None): If not None, a map from
|
||||||
|
header name to a list of values for that header
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||||
HTTP body as JSON.
|
HTTP body as JSON.
|
||||||
@ -177,13 +209,13 @@ class SimpleHttpClient(object):
|
|||||||
error message.
|
error message.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
body = yield self.get_raw(uri, args)
|
body = yield self.get_raw(uri, args, headers=headers)
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
raise self._exceptionFromFailedRequest(e.code, e.msg)
|
raise self._exceptionFromFailedRequest(e.code, e.msg)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def put_json(self, uri, json_body, args={}):
|
def put_json(self, uri, json_body, args={}, headers=None):
|
||||||
""" Puts some json to the given URI.
|
""" Puts some json to the given URI.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -193,6 +225,8 @@ class SimpleHttpClient(object):
|
|||||||
None.
|
None.
|
||||||
**Note**: The value of each key is assumed to be an iterable
|
**Note**: The value of each key is assumed to be an iterable
|
||||||
and *not* a string.
|
and *not* a string.
|
||||||
|
headers (dict[str, List[str]]|None): If not None, a map from
|
||||||
|
header name to a list of values for that header
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||||
HTTP body as JSON.
|
HTTP body as JSON.
|
||||||
@ -205,17 +239,21 @@ class SimpleHttpClient(object):
|
|||||||
|
|
||||||
json_str = encode_canonical_json(json_body)
|
json_str = encode_canonical_json(json_body)
|
||||||
|
|
||||||
|
actual_headers = {
|
||||||
|
b"Content-Type": [b"application/json"],
|
||||||
|
b"User-Agent": [self.user_agent],
|
||||||
|
}
|
||||||
|
if headers:
|
||||||
|
actual_headers.update(headers)
|
||||||
|
|
||||||
response = yield self.request(
|
response = yield self.request(
|
||||||
"PUT",
|
"PUT",
|
||||||
uri.encode("ascii"),
|
uri.encode("ascii"),
|
||||||
headers=Headers({
|
headers=Headers(actual_headers),
|
||||||
b"User-Agent": [self.user_agent],
|
|
||||||
"Content-Type": ["application/json"]
|
|
||||||
}),
|
|
||||||
bodyProducer=FileBodyProducer(StringIO(json_str))
|
bodyProducer=FileBodyProducer(StringIO(json_str))
|
||||||
)
|
)
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield make_deferred_yieldable(readBody(response))
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
@ -226,7 +264,7 @@ class SimpleHttpClient(object):
|
|||||||
raise CodeMessageException(response.code, body)
|
raise CodeMessageException(response.code, body)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_raw(self, uri, args={}):
|
def get_raw(self, uri, args={}, headers=None):
|
||||||
""" Gets raw text from the given URI.
|
""" Gets raw text from the given URI.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -235,6 +273,8 @@ class SimpleHttpClient(object):
|
|||||||
None.
|
None.
|
||||||
**Note**: The value of each key is assumed to be an iterable
|
**Note**: The value of each key is assumed to be an iterable
|
||||||
and *not* a string.
|
and *not* a string.
|
||||||
|
headers (dict[str, List[str]]|None): If not None, a map from
|
||||||
|
header name to a list of values for that header
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||||
HTTP body at text.
|
HTTP body at text.
|
||||||
@ -246,15 +286,19 @@ class SimpleHttpClient(object):
|
|||||||
query_bytes = urllib.urlencode(args, True)
|
query_bytes = urllib.urlencode(args, True)
|
||||||
uri = "%s?%s" % (uri, query_bytes)
|
uri = "%s?%s" % (uri, query_bytes)
|
||||||
|
|
||||||
|
actual_headers = {
|
||||||
|
b"User-Agent": [self.user_agent],
|
||||||
|
}
|
||||||
|
if headers:
|
||||||
|
actual_headers.update(headers)
|
||||||
|
|
||||||
response = yield self.request(
|
response = yield self.request(
|
||||||
"GET",
|
"GET",
|
||||||
uri.encode("ascii"),
|
uri.encode("ascii"),
|
||||||
headers=Headers({
|
headers=Headers(actual_headers),
|
||||||
b"User-Agent": [self.user_agent],
|
|
||||||
})
|
|
||||||
)
|
)
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield make_deferred_yieldable(readBody(response))
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
defer.returnValue(body)
|
defer.returnValue(body)
|
||||||
@ -274,27 +318,33 @@ class SimpleHttpClient(object):
|
|||||||
# The two should be factored out.
|
# The two should be factored out.
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_file(self, url, output_stream, max_size=None):
|
def get_file(self, url, output_stream, max_size=None, headers=None):
|
||||||
"""GETs a file from a given URL
|
"""GETs a file from a given URL
|
||||||
Args:
|
Args:
|
||||||
url (str): The URL to GET
|
url (str): The URL to GET
|
||||||
output_stream (file): File to write the response body to.
|
output_stream (file): File to write the response body to.
|
||||||
|
headers (dict[str, List[str]]|None): If not None, a map from
|
||||||
|
header name to a list of values for that header
|
||||||
Returns:
|
Returns:
|
||||||
A (int,dict,string,int) tuple of the file length, dict of the response
|
A (int,dict,string,int) tuple of the file length, dict of the response
|
||||||
headers, absolute URI of the response and HTTP response code.
|
headers, absolute URI of the response and HTTP response code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
actual_headers = {
|
||||||
|
b"User-Agent": [self.user_agent],
|
||||||
|
}
|
||||||
|
if headers:
|
||||||
|
actual_headers.update(headers)
|
||||||
|
|
||||||
response = yield self.request(
|
response = yield self.request(
|
||||||
"GET",
|
"GET",
|
||||||
url.encode("ascii"),
|
url.encode("ascii"),
|
||||||
headers=Headers({
|
headers=Headers(actual_headers),
|
||||||
b"User-Agent": [self.user_agent],
|
|
||||||
})
|
|
||||||
)
|
)
|
||||||
|
|
||||||
headers = dict(response.headers.getAllRawHeaders())
|
resp_headers = dict(response.headers.getAllRawHeaders())
|
||||||
|
|
||||||
if 'Content-Length' in headers and headers['Content-Length'] > max_size:
|
if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
|
||||||
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
|
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
502,
|
502,
|
||||||
@ -315,10 +365,9 @@ class SimpleHttpClient(object):
|
|||||||
# straight back in again
|
# straight back in again
|
||||||
|
|
||||||
try:
|
try:
|
||||||
length = yield preserve_context_over_fn(
|
length = yield make_deferred_yieldable(_readBodyToFile(
|
||||||
_readBodyToFile,
|
response, output_stream, max_size,
|
||||||
response, output_stream, max_size
|
))
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to download body")
|
logger.exception("Failed to download body")
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
@ -327,7 +376,9 @@ class SimpleHttpClient(object):
|
|||||||
Codes.UNKNOWN,
|
Codes.UNKNOWN,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((length, headers, response.request.absoluteURI, response.code))
|
defer.returnValue(
|
||||||
|
(length, resp_headers, response.request.absoluteURI, response.code),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||||
@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield make_deferred_yieldable(readBody(response))
|
||||||
defer.returnValue(body)
|
defer.returnValue(body)
|
||||||
except PartialDownloadError as e:
|
except PartialDownloadError as e:
|
||||||
# twisted dislikes google's response, no content length.
|
# twisted dislikes google's response, no content length.
|
||||||
|
@ -167,7 +167,8 @@ def parse_json_value_from_request(request):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
content = simplejson.loads(content_bytes)
|
content = simplejson.loads(content_bytes)
|
||||||
except simplejson.JSONDecodeError:
|
except Exception as e:
|
||||||
|
logger.warn("Unable to parse JSON: %s", e)
|
||||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
113
synapse/module_api/__init__.py
Normal file
113
synapse/module_api/__init__.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleApi(object):
|
||||||
|
"""A proxy object that gets passed to password auth providers so they
|
||||||
|
can register new users etc if necessary.
|
||||||
|
"""
|
||||||
|
def __init__(self, hs, auth_handler):
|
||||||
|
self.hs = hs
|
||||||
|
|
||||||
|
self._store = hs.get_datastore()
|
||||||
|
self._auth = hs.get_auth()
|
||||||
|
self._auth_handler = auth_handler
|
||||||
|
|
||||||
|
def get_user_by_req(self, req, allow_guest=False):
|
||||||
|
"""Check the access_token provided for a request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
req (twisted.web.server.Request): Incoming HTTP request
|
||||||
|
allow_guest (bool): True if guest users should be allowed. If this
|
||||||
|
is False, and the access token is for a guest user, an
|
||||||
|
AuthError will be thrown
|
||||||
|
Returns:
|
||||||
|
twisted.internet.defer.Deferred[synapse.types.Requester]:
|
||||||
|
the requester for this request
|
||||||
|
Raises:
|
||||||
|
synapse.api.errors.AuthError: if no user by that token exists,
|
||||||
|
or the token is invalid.
|
||||||
|
"""
|
||||||
|
return self._auth.get_user_by_req(req, allow_guest)
|
||||||
|
|
||||||
|
def get_qualified_user_id(self, username):
|
||||||
|
"""Qualify a user id, if necessary
|
||||||
|
|
||||||
|
Takes a user id provided by the user and adds the @ and :domain to
|
||||||
|
qualify it, if necessary
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username (str): provided user id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: qualified @user:id
|
||||||
|
"""
|
||||||
|
if username.startswith('@'):
|
||||||
|
return username
|
||||||
|
return UserID(username, self.hs.hostname).to_string()
|
||||||
|
|
||||||
|
def check_user_exists(self, user_id):
|
||||||
|
"""Check if user exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): Complete @user:id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[str|None]: Canonical (case-corrected) user_id, or None
|
||||||
|
if the user is not registered.
|
||||||
|
"""
|
||||||
|
return self._auth_handler.check_user_exists(user_id)
|
||||||
|
|
||||||
|
def register(self, localpart):
|
||||||
|
"""Registers a new user with given localpart
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: a 2-tuple of (user_id, access_token)
|
||||||
|
"""
|
||||||
|
reg = self.hs.get_handlers().registration_handler
|
||||||
|
return reg.register(localpart=localpart)
|
||||||
|
|
||||||
|
def invalidate_access_token(self, access_token):
|
||||||
|
"""Invalidate an access token for a user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token(str): access token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
twisted.internet.defer.Deferred - resolves once the access token
|
||||||
|
has been removed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
synapse.api.errors.AuthError: the access token is invalid
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._auth_handler.delete_access_token(access_token)
|
||||||
|
|
||||||
|
def run_db_interaction(self, desc, func, *args, **kwargs):
|
||||||
|
"""Run a function with a database connection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
desc (str): description for the transaction, for metrics etc
|
||||||
|
func (func): function to be run. Passed a database cursor object
|
||||||
|
as well as *args and **kwargs
|
||||||
|
*args: positional args to be passed to func
|
||||||
|
**kwargs: named args to be passed to func
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[object]: result of func
|
||||||
|
"""
|
||||||
|
return self._store.runInteraction(desc, func, *args, **kwargs)
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class BaseSlavedStore(SQLBaseStore):
|
class BaseSlavedStore(SQLBaseStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(BaseSlavedStore, self).__init__(hs)
|
super(BaseSlavedStore, self).__init__(db_conn, hs)
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
self._cache_id_gen = SlavedIdTracker(
|
self._cache_id_gen = SlavedIdTracker(
|
||||||
db_conn, "cache_invalidation_stream", "stream_id",
|
db_conn, "cache_invalidation_stream", "stream_id",
|
||||||
|
@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||||||
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
super(DeactivateAccountRestServlet, self).__init__(hs)
|
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||||||
if not is_admin:
|
if not is_admin:
|
||||||
raise AuthError(403, "You are not a server admin")
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
# FIXME: Theoretically there is a race here wherein user resets password
|
yield self._auth_handler.deactivate_account(target_user_id)
|
||||||
# using threepid.
|
|
||||||
yield self.store.user_delete_access_tokens(target_user_id)
|
|
||||||
yield self.store.user_delete_threepids(target_user_id)
|
|
||||||
yield self.store.user_set_password_hash(target_user_id, None)
|
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
|
|||||||
|
|
||||||
class LoginRestServlet(ClientV1RestServlet):
|
class LoginRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login$")
|
PATTERNS = client_path_patterns("/login$")
|
||||||
PASS_TYPE = "m.login.password"
|
|
||||||
SAML2_TYPE = "m.login.saml2"
|
SAML2_TYPE = "m.login.saml2"
|
||||||
CAS_TYPE = "m.login.cas"
|
CAS_TYPE = "m.login.cas"
|
||||||
TOKEN_TYPE = "m.login.token"
|
TOKEN_TYPE = "m.login.token"
|
||||||
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(LoginRestServlet, self).__init__(hs)
|
super(LoginRestServlet, self).__init__(hs)
|
||||||
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||||
self.password_enabled = hs.config.password_enabled
|
|
||||||
self.saml2_enabled = hs.config.saml2_enabled
|
self.saml2_enabled = hs.config.saml2_enabled
|
||||||
self.jwt_enabled = hs.config.jwt_enabled
|
self.jwt_enabled = hs.config.jwt_enabled
|
||||||
self.jwt_secret = hs.config.jwt_secret
|
self.jwt_secret = hs.config.jwt_secret
|
||||||
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
# fall back to the fallback API if they don't understand one of the
|
# fall back to the fallback API if they don't understand one of the
|
||||||
# login flow types returned.
|
# login flow types returned.
|
||||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||||
if self.password_enabled:
|
|
||||||
flows.append({"type": LoginRestServlet.PASS_TYPE})
|
flows.extend((
|
||||||
|
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||||
|
))
|
||||||
|
|
||||||
return (200, {"flows": flows})
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
@ -133,13 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
login_submission = parse_json_object_from_request(request)
|
login_submission = parse_json_object_from_request(request)
|
||||||
try:
|
try:
|
||||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
if self.saml2_enabled and (login_submission["type"] ==
|
||||||
if not self.password_enabled:
|
|
||||||
raise SynapseError(400, "Password login has been disabled.")
|
|
||||||
|
|
||||||
result = yield self.do_password_login(login_submission)
|
|
||||||
defer.returnValue(result)
|
|
||||||
elif self.saml2_enabled and (login_submission["type"] ==
|
|
||||||
LoginRestServlet.SAML2_TYPE):
|
LoginRestServlet.SAML2_TYPE):
|
||||||
relay_state = ""
|
relay_state = ""
|
||||||
if "relay_state" in login_submission:
|
if "relay_state" in login_submission:
|
||||||
@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
result = yield self.do_token_login(login_submission)
|
result = yield self.do_token_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "Bad login type.")
|
result = yield self._do_other_login(login_submission)
|
||||||
|
defer.returnValue(result)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_password_login(self, login_submission):
|
def _do_other_login(self, login_submission):
|
||||||
if "password" not in login_submission:
|
"""Handle non-token/saml/jwt logins
|
||||||
raise SynapseError(400, "Missing parameter: password")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
login_submission:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(int, object): HTTP code/response
|
||||||
|
"""
|
||||||
|
# Log the request we got, but only certain fields to minimise the chance of
|
||||||
|
# logging someone's password (even if they accidentally put it in the wrong
|
||||||
|
# field)
|
||||||
|
logger.info(
|
||||||
|
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
|
||||||
|
login_submission.get('identifier'),
|
||||||
|
login_submission.get('medium'),
|
||||||
|
login_submission.get('address'),
|
||||||
|
login_submission.get('user'),
|
||||||
|
)
|
||||||
login_submission_legacy_convert(login_submission)
|
login_submission_legacy_convert(login_submission)
|
||||||
|
|
||||||
if "identifier" not in login_submission:
|
if "identifier" not in login_submission:
|
||||||
@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
if "user" not in identifier:
|
if "user" not in identifier:
|
||||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||||
|
|
||||||
user_id = identifier["user"]
|
|
||||||
|
|
||||||
if not user_id.startswith('@'):
|
|
||||||
user_id = UserID(
|
|
||||||
user_id, self.hs.hostname
|
|
||||||
).to_string()
|
|
||||||
|
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = yield auth_handler.validate_password_login(
|
canonical_user_id, callback = yield auth_handler.validate_login(
|
||||||
user_id=user_id,
|
identifier["user"],
|
||||||
password=login_submission["password"],
|
login_submission,
|
||||||
|
)
|
||||||
|
|
||||||
|
device_id = yield self._register_device(
|
||||||
|
canonical_user_id, login_submission,
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id,
|
canonical_user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": canonical_user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
yield callback(result)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
device_id = yield self._register_device(user_id, login_submission)
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id,
|
user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name"),
|
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
)
|
)
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
registered_user_id, device_id,
|
registered_user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
|
@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(LogoutRestServlet, self).__init__(hs)
|
super(LogoutRestServlet, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return (200, {})
|
return (200, {})
|
||||||
@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
access_token = get_access_token_from_request(request)
|
access_token = get_access_token_from_request(request)
|
||||||
yield self.store.delete_access_token(access_token)
|
yield self._auth_handler.delete_access_token(access_token)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(LogoutAllRestServlet, self).__init__(hs)
|
super(LogoutAllRestServlet, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return (200, {})
|
return (200, {})
|
||||||
@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
|||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
yield self.store.user_delete_access_tokens(user_id)
|
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
@ -359,7 +359,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||||||
if compare_digest(want_mac, got_mac):
|
if compare_digest(want_mac, got_mac):
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
user_id, token = yield handler.register(
|
user_id, token = yield handler.register(
|
||||||
localpart=user,
|
localpart=user.lower(),
|
||||||
password=password,
|
password=password,
|
||||||
admin=bool(admin),
|
admin=bool(admin),
|
||||||
)
|
)
|
||||||
|
@ -13,22 +13,21 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.auth import has_access_token
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import LoginError, SynapseError, Codes
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet, parse_json_object_from_request, assert_params_in_request
|
RestServlet, assert_params_in_request,
|
||||||
|
parse_json_object_from_request,
|
||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -163,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
super(DeactivateAccountRestServlet, self).__init__()
|
super(DeactivateAccountRestServlet, self).__init__()
|
||||||
@ -172,6 +170,20 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
# if the caller provides an access token, it ought to be valid.
|
||||||
|
requester = None
|
||||||
|
if has_access_token(request):
|
||||||
|
requester = yield self.auth.get_user_by_req(
|
||||||
|
request,
|
||||||
|
) # type: synapse.types.Requester
|
||||||
|
|
||||||
|
# allow ASes to dectivate their own users
|
||||||
|
if requester and requester.app_service:
|
||||||
|
yield self.auth_handler.deactivate_account(
|
||||||
|
requester.user.to_string()
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||||
[LoginType.PASSWORD],
|
[LoginType.PASSWORD],
|
||||||
], body, self.hs.get_ip_from_request(request))
|
], body, self.hs.get_ip_from_request(request))
|
||||||
@ -179,25 +191,22 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||||||
if not authed:
|
if not authed:
|
||||||
defer.returnValue((401, result))
|
defer.returnValue((401, result))
|
||||||
|
|
||||||
user_id = None
|
|
||||||
requester = None
|
|
||||||
|
|
||||||
if LoginType.PASSWORD in result:
|
if LoginType.PASSWORD in result:
|
||||||
|
user_id = result[LoginType.PASSWORD]
|
||||||
# if using password, they should also be logged in
|
# if using password, they should also be logged in
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
if requester is None:
|
||||||
user_id = requester.user.to_string()
|
raise SynapseError(
|
||||||
if user_id != result[LoginType.PASSWORD]:
|
400,
|
||||||
|
"Deactivate account requires an access_token",
|
||||||
|
errcode=Codes.MISSING_TOKEN
|
||||||
|
)
|
||||||
|
if requester.user.to_string() != user_id:
|
||||||
raise LoginError(400, "", Codes.UNKNOWN)
|
raise LoginError(400, "", Codes.UNKNOWN)
|
||||||
else:
|
else:
|
||||||
logger.error("Auth succeeded but no known type!", result.keys())
|
logger.error("Auth succeeded but no known type!", result.keys())
|
||||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||||
|
|
||||||
# FIXME: Theoretically there is a race here wherein user resets password
|
yield self.auth_handler.deactivate_account(user_id)
|
||||||
# using threepid.
|
|
||||||
yield self.store.user_delete_access_tokens(user_id)
|
|
||||||
yield self.store.user_delete_threepids(user_id)
|
|
||||||
yield self.store.user_set_password_hash(user_id, None)
|
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
@ -373,6 +382,20 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class WhoamiRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/account/whoami$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(WhoamiRestServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
defer.returnValue((200, {'user_id': requester.user.to_string()}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
|
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||||
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
|
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||||
@ -382,3 +405,4 @@ def register_servlets(hs, http_server):
|
|||||||
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
|
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||||
ThreepidRestServlet(hs).register(http_server)
|
ThreepidRestServlet(hs).register(http_server)
|
||||||
ThreepidDeleteRestServlet(hs).register(http_server)
|
ThreepidDeleteRestServlet(hs).register(http_server)
|
||||||
|
WhoamiRestServlet(hs).register(http_server)
|
||||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DevicesRestServlet(servlet.RestServlet):
|
class DevicesRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
"""
|
"""
|
||||||
@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||||||
API for bulk deletion of devices. Accepts a JSON object with a devices
|
API for bulk deletion of devices. Accepts a JSON object with a devices
|
||||||
key which lists the device_ids to delete. Requires user interactive auth.
|
key which lists the device_ids to delete. Requires user interactive auth.
|
||||||
"""
|
"""
|
||||||
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
|
PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(DeleteDevicesRestServlet, self).__init__()
|
super(DeleteDevicesRestServlet, self).__init__()
|
||||||
@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||||||
|
|
||||||
|
|
||||||
class DeviceRestServlet(servlet.RestServlet):
|
class DeviceRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
|
||||||
releases=[], v2_alpha=False)
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
"""
|
"""
|
||||||
@ -118,6 +117,8 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, device_id):
|
def on_DELETE(self, request, device_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body = servlet.parse_json_object_from_request(request)
|
body = servlet.parse_json_object_from_request(request)
|
||||||
|
|
||||||
@ -136,11 +137,12 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||||||
if not authed:
|
if not authed:
|
||||||
defer.returnValue((401, result))
|
defer.returnValue((401, result))
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
# check that the UI auth matched the access token
|
||||||
yield self.device_handler.delete_device(
|
user_id = result[constants.LoginType.PASSWORD]
|
||||||
requester.user.to_string(),
|
if user_id != requester.user.to_string():
|
||||||
device_id,
|
raise errors.AuthError(403, "Invalid auth")
|
||||||
)
|
|
||||||
|
yield self.device_handler.delete_device(user_id, device_id)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -39,20 +39,23 @@ class GroupServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
|
group_description = yield self.groups_handler.get_group_profile(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, group_description))
|
defer.returnValue((200, group_description))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, group_id):
|
def on_POST(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
yield self.groups_handler.update_group_profile(
|
yield self.groups_handler.update_group_profile(
|
||||||
group_id, user_id, content,
|
group_id, requester_user_id, content,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
@ -72,9 +75,12 @@ class GroupSummaryServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
|
get_group_summary = yield self.groups_handler.get_group_summary(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, get_group_summary))
|
defer.returnValue((200, get_group_summary))
|
||||||
|
|
||||||
@ -101,11 +107,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, group_id, category_id, room_id):
|
def on_PUT(self, request, group_id, category_id, room_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
resp = yield self.groups_handler.update_group_summary_room(
|
resp = yield self.groups_handler.update_group_summary_room(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
category_id=category_id,
|
category_id=category_id,
|
||||||
content=content,
|
content=content,
|
||||||
@ -116,10 +122,10 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, group_id, category_id, room_id):
|
def on_DELETE(self, request, group_id, category_id, room_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
resp = yield self.groups_handler.delete_group_summary_room(
|
resp = yield self.groups_handler.delete_group_summary_room(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
category_id=category_id,
|
category_id=category_id,
|
||||||
)
|
)
|
||||||
@ -143,10 +149,10 @@ class GroupCategoryServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id, category_id):
|
def on_GET(self, request, group_id, category_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_category(
|
category = yield self.groups_handler.get_group_category(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
category_id=category_id,
|
category_id=category_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -155,11 +161,11 @@ class GroupCategoryServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, group_id, category_id):
|
def on_PUT(self, request, group_id, category_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
resp = yield self.groups_handler.update_group_category(
|
resp = yield self.groups_handler.update_group_category(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
category_id=category_id,
|
category_id=category_id,
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
@ -169,10 +175,10 @@ class GroupCategoryServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, group_id, category_id):
|
def on_DELETE(self, request, group_id, category_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
resp = yield self.groups_handler.delete_group_category(
|
resp = yield self.groups_handler.delete_group_category(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
category_id=category_id,
|
category_id=category_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -195,10 +201,10 @@ class GroupCategoriesServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_categories(
|
category = yield self.groups_handler.get_group_categories(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, category))
|
defer.returnValue((200, category))
|
||||||
@ -220,10 +226,10 @@ class GroupRoleServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id, role_id):
|
def on_GET(self, request, group_id, role_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_role(
|
category = yield self.groups_handler.get_group_role(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
role_id=role_id,
|
role_id=role_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -232,11 +238,11 @@ class GroupRoleServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, group_id, role_id):
|
def on_PUT(self, request, group_id, role_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
resp = yield self.groups_handler.update_group_role(
|
resp = yield self.groups_handler.update_group_role(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
role_id=role_id,
|
role_id=role_id,
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
@ -246,10 +252,10 @@ class GroupRoleServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, group_id, role_id):
|
def on_DELETE(self, request, group_id, role_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
resp = yield self.groups_handler.delete_group_role(
|
resp = yield self.groups_handler.delete_group_role(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
role_id=role_id,
|
role_id=role_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -272,10 +278,10 @@ class GroupRolesServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_roles(
|
category = yield self.groups_handler.get_group_roles(
|
||||||
group_id, user_id,
|
group_id, requester_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, category))
|
defer.returnValue((200, category))
|
||||||
@ -343,9 +349,9 @@ class GroupRoomServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
|
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@ -364,9 +370,9 @@ class GroupUsersServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.get_users_in_group(group_id, user_id)
|
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@ -385,9 +391,12 @@ class GroupInvitedUsersServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
|
result = yield self.groups_handler.get_invited_users_in_group(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@ -407,14 +416,18 @@ class GroupCreateServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
# TODO: Create group on remote server
|
# TODO: Create group on remote server
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
localpart = content.pop("localpart")
|
localpart = content.pop("localpart")
|
||||||
group_id = GroupID(localpart, self.server_name).to_string()
|
group_id = GroupID(localpart, self.server_name).to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.create_group(group_id, user_id, content)
|
result = yield self.groups_handler.create_group(
|
||||||
|
group_id,
|
||||||
|
requester_user_id,
|
||||||
|
content,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@ -435,11 +448,11 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, group_id, room_id):
|
def on_PUT(self, request, group_id, room_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
result = yield self.groups_handler.add_room_to_group(
|
result = yield self.groups_handler.add_room_to_group(
|
||||||
group_id, user_id, room_id, content,
|
group_id, requester_user_id, room_id, content,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
@ -447,10 +460,37 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, group_id, room_id):
|
def on_DELETE(self, request, group_id, room_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.remove_room_from_group(
|
result = yield self.groups_handler.remove_room_from_group(
|
||||||
group_id, user_id, room_id,
|
group_id, requester_user_id, room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
|
||||||
|
class GroupAdminRoomsConfigServlet(RestServlet):
|
||||||
|
"""Update the config of a room in a group
|
||||||
|
"""
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
|
||||||
|
"/config/(?P<config_key>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(GroupAdminRoomsConfigServlet, self).__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, group_id, room_id, config_key):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
result = yield self.groups_handler.update_room_in_group(
|
||||||
|
group_id, requester_user_id, room_id, config_key, content,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
@ -685,9 +725,9 @@ class GroupsForUserServlet(RestServlet):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.get_joined_groups(user_id)
|
result = yield self.groups_handler.get_joined_groups(requester_user_id)
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@ -700,6 +740,7 @@ def register_servlets(hs, http_server):
|
|||||||
GroupRoomServlet(hs).register(http_server)
|
GroupRoomServlet(hs).register(http_server)
|
||||||
GroupCreateServlet(hs).register(http_server)
|
GroupCreateServlet(hs).register(http_server)
|
||||||
GroupAdminRoomsServlet(hs).register(http_server)
|
GroupAdminRoomsServlet(hs).register(http_server)
|
||||||
|
GroupAdminRoomsConfigServlet(hs).register(http_server)
|
||||||
GroupAdminUsersInviteServlet(hs).register(http_server)
|
GroupAdminUsersInviteServlet(hs).register(http_server)
|
||||||
GroupAdminUsersKickServlet(hs).register(http_server)
|
GroupAdminUsersKickServlet(hs).register(http_server)
|
||||||
GroupSelfLeaveServlet(hs).register(http_server)
|
GroupSelfLeaveServlet(hs).register(http_server)
|
||||||
|
@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
|
||||||
releases=())
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
"""
|
"""
|
||||||
@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
|
|||||||
} } } } } }
|
} } } } } }
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATTERNS = client_v2_patterns(
|
PATTERNS = client_v2_patterns("/keys/query$")
|
||||||
"/keys/query$",
|
|
||||||
releases=()
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
"""
|
"""
|
||||||
@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
|
|||||||
200 OK
|
200 OK
|
||||||
{ "changed": ["@foo:example.com"] }
|
{ "changed": ["@foo:example.com"] }
|
||||||
"""
|
"""
|
||||||
PATTERNS = client_v2_patterns(
|
PATTERNS = client_v2_patterns("/keys/changes$")
|
||||||
"/keys/changes$",
|
|
||||||
releases=()
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
"""
|
"""
|
||||||
@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
|
|||||||
} } } }
|
} } } }
|
||||||
|
|
||||||
"""
|
"""
|
||||||
PATTERNS = client_v2_patterns(
|
PATTERNS = client_v2_patterns("/keys/claim$")
|
||||||
"/keys/claim$",
|
|
||||||
releases=()
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(OneTimeKeyServlet, self).__init__()
|
super(OneTimeKeyServlet, self).__init__()
|
||||||
|
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class NotificationsServlet(RestServlet):
|
class NotificationsServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/notifications$", releases=())
|
PATTERNS = client_v2_patterns("/notifications$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(NotificationsServlet, self).__init__()
|
super(NotificationsServlet, self).__init__()
|
||||||
|
@ -224,6 +224,12 @@ class RegisterRestServlet(RestServlet):
|
|||||||
# 'user' key not 'username'). Since this is a new addition, we'll
|
# 'user' key not 'username'). Since this is a new addition, we'll
|
||||||
# fallback to 'username' if they gave one.
|
# fallback to 'username' if they gave one.
|
||||||
desired_username = body.get("user", desired_username)
|
desired_username = body.get("user", desired_username)
|
||||||
|
|
||||||
|
# XXX we should check that desired_username is valid. Currently
|
||||||
|
# we give appservices carte blanche for any insanity in mxids,
|
||||||
|
# because the IRC bridges rely on being able to register stupid
|
||||||
|
# IDs.
|
||||||
|
|
||||||
access_token = get_access_token_from_request(request)
|
access_token = get_access_token_from_request(request)
|
||||||
|
|
||||||
if isinstance(desired_username, basestring):
|
if isinstance(desired_username, basestring):
|
||||||
@ -233,6 +239,15 @@ class RegisterRestServlet(RestServlet):
|
|||||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# for either shared secret or regular registration, downcase the
|
||||||
|
# provided username before attempting to register it. This should mean
|
||||||
|
# that people who try to register with upper-case in their usernames
|
||||||
|
# don't get a nasty surprise. (Note that we treat username
|
||||||
|
# case-insenstively in login, so they are free to carry on imagining
|
||||||
|
# that their username is CrAzYh4cKeR if that keeps them happy)
|
||||||
|
if desired_username is not None:
|
||||||
|
desired_username = desired_username.lower()
|
||||||
|
|
||||||
# == Shared Secret Registration == (e.g. create new user scripts)
|
# == Shared Secret Registration == (e.g. create new user scripts)
|
||||||
if 'mac' in body:
|
if 'mac' in body:
|
||||||
# FIXME: Should we really be determining if this is shared secret
|
# FIXME: Should we really be determining if this is shared secret
|
||||||
@ -336,6 +351,9 @@ class RegisterRestServlet(RestServlet):
|
|||||||
new_password = params.get("password", None)
|
new_password = params.get("password", None)
|
||||||
guest_access_token = params.get("guest_access_token", None)
|
guest_access_token = params.get("guest_access_token", None)
|
||||||
|
|
||||||
|
if desired_username is not None:
|
||||||
|
desired_username = desired_username.lower()
|
||||||
|
|
||||||
(registered_user_id, _) = yield self.registration_handler.register(
|
(registered_user_id, _) = yield self.registration_handler.register(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
password=new_password,
|
password=new_password,
|
||||||
@ -417,13 +435,22 @@ class RegisterRestServlet(RestServlet):
|
|||||||
def _do_shared_secret_registration(self, username, password, body):
|
def _do_shared_secret_registration(self, username, password, body):
|
||||||
if not self.hs.config.registration_shared_secret:
|
if not self.hs.config.registration_shared_secret:
|
||||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||||
|
if not username:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||||
|
)
|
||||||
|
|
||||||
user = username.encode("utf-8")
|
# use the username from the original request rather than the
|
||||||
|
# downcased one in `username` for the mac calculation
|
||||||
|
user = body["username"].encode("utf-8")
|
||||||
|
|
||||||
# str() because otherwise hmac complains that 'unicode' does not
|
# str() because otherwise hmac complains that 'unicode' does not
|
||||||
# have the buffer interface
|
# have the buffer interface
|
||||||
got_mac = str(body["mac"])
|
got_mac = str(body["mac"])
|
||||||
|
|
||||||
|
# FIXME this is different to the /v1/register endpoint, which
|
||||||
|
# includes the password and admin flag in the hashed text. Why are
|
||||||
|
# these different?
|
||||||
want_mac = hmac.new(
|
want_mac = hmac.new(
|
||||||
key=self.hs.config.registration_shared_secret,
|
key=self.hs.config.registration_shared_secret,
|
||||||
msg=user,
|
msg=user,
|
||||||
@ -557,25 +584,28 @@ class RegisterRestServlet(RestServlet):
|
|||||||
Args:
|
Args:
|
||||||
(str) user_id: full canonical @user:id
|
(str) user_id: full canonical @user:id
|
||||||
(object) params: registration parameters, from which we pull
|
(object) params: registration parameters, from which we pull
|
||||||
device_id and initial_device_name
|
device_id, initial_device_name and inhibit_login
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: (object) dictionary for response from /register
|
defer.Deferred: (object) dictionary for response from /register
|
||||||
"""
|
"""
|
||||||
|
result = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
}
|
||||||
|
if not params.get("inhibit_login", False):
|
||||||
device_id = yield self._register_device(user_id, params)
|
device_id = yield self._register_device(user_id, params)
|
||||||
|
|
||||||
access_token = (
|
access_token = (
|
||||||
yield self.auth_handler.get_access_token_for_user_id(
|
yield self.auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id=device_id,
|
user_id, device_id=device_id,
|
||||||
initial_display_name=params.get("initial_device_display_name")
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
result.update({
|
||||||
"user_id": user_id,
|
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
})
|
})
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _register_device(self, user_id, params):
|
def _register_device(self, user_id, params):
|
||||||
"""Register a device for a user.
|
"""Register a device for a user.
|
||||||
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class SendToDeviceRestServlet(servlet.RestServlet):
|
class SendToDeviceRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns(
|
PATTERNS = client_v2_patterns(
|
||||||
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
|
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
|
||||||
releases=[], v2_alpha=False
|
v2_alpha=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ThirdPartyProtocolsServlet(RestServlet):
|
class ThirdPartyProtocolsServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
|
PATTERNS = client_v2_patterns("/thirdparty/protocols")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ThirdPartyProtocolsServlet, self).__init__()
|
super(ThirdPartyProtocolsServlet, self).__init__()
|
||||||
@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
|
|||||||
|
|
||||||
|
|
||||||
class ThirdPartyProtocolServlet(RestServlet):
|
class ThirdPartyProtocolServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
|
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
|
||||||
releases=())
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ThirdPartyProtocolServlet, self).__init__()
|
super(ThirdPartyProtocolServlet, self).__init__()
|
||||||
@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
|
|||||||
|
|
||||||
|
|
||||||
class ThirdPartyUserServlet(RestServlet):
|
class ThirdPartyUserServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
|
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
|
||||||
releases=())
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ThirdPartyUserServlet, self).__init__()
|
super(ThirdPartyUserServlet, self).__init__()
|
||||||
@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
|
|||||||
|
|
||||||
|
|
||||||
class ThirdPartyLocationServlet(RestServlet):
|
class ThirdPartyLocationServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
|
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
|
||||||
releases=())
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ThirdPartyLocationServlet, self).__init__()
|
super(ThirdPartyLocationServlet, self).__init__()
|
||||||
|
@ -20,6 +20,7 @@ from twisted.web.resource import Resource
|
|||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
SynapseError, Codes,
|
SynapseError, Codes,
|
||||||
)
|
)
|
||||||
|
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.http.client import SpiderHttpClient
|
from synapse.http.client import SpiderHttpClient
|
||||||
@ -63,16 +64,15 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||||
|
|
||||||
# simple memory cache mapping urls to OG metadata
|
# memory cache mapping urls to an ObservableDeferred returning
|
||||||
self.cache = ExpiringCache(
|
# JSON-encoded OG metadata
|
||||||
|
self._cache = ExpiringCache(
|
||||||
cache_name="url_previews",
|
cache_name="url_previews",
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
# don't spider URLs more often than once an hour
|
# don't spider URLs more often than once an hour
|
||||||
expiry_ms=60 * 60 * 1000,
|
expiry_ms=60 * 60 * 1000,
|
||||||
)
|
)
|
||||||
self.cache.start()
|
self._cache.start()
|
||||||
|
|
||||||
self.downloads = {}
|
|
||||||
|
|
||||||
self._cleaner_loop = self.clock.looping_call(
|
self._cleaner_loop = self.clock.looping_call(
|
||||||
self._expire_url_cache_data, 10 * 1000
|
self._expire_url_cache_data, 10 * 1000
|
||||||
@ -94,6 +94,7 @@ class PreviewUrlResource(Resource):
|
|||||||
else:
|
else:
|
||||||
ts = self.clock.time_msec()
|
ts = self.clock.time_msec()
|
||||||
|
|
||||||
|
# XXX: we could move this into _do_preview if we wanted.
|
||||||
url_tuple = urlparse.urlsplit(url)
|
url_tuple = urlparse.urlsplit(url)
|
||||||
for entry in self.url_preview_url_blacklist:
|
for entry in self.url_preview_url_blacklist:
|
||||||
match = True
|
match = True
|
||||||
@ -126,14 +127,42 @@ class PreviewUrlResource(Resource):
|
|||||||
Codes.UNKNOWN
|
Codes.UNKNOWN
|
||||||
)
|
)
|
||||||
|
|
||||||
# first check the memory cache - good to handle all the clients on this
|
# the in-memory cache:
|
||||||
# HS thundering away to preview the same URL at the same time.
|
# * ensures that only one request is active at a time
|
||||||
og = self.cache.get(url)
|
# * takes load off the DB for the thundering herds
|
||||||
if og:
|
# * also caches any failures (unlike the DB) so we don't keep
|
||||||
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
# requesting the same endpoint
|
||||||
return
|
|
||||||
|
|
||||||
# then check the URL cache in the DB (which will also provide us with
|
observable = self._cache.get(url)
|
||||||
|
|
||||||
|
if not observable:
|
||||||
|
download = preserve_fn(self._do_preview)(
|
||||||
|
url, requester.user, ts,
|
||||||
|
)
|
||||||
|
observable = ObservableDeferred(
|
||||||
|
download,
|
||||||
|
consumeErrors=True
|
||||||
|
)
|
||||||
|
self._cache[url] = observable
|
||||||
|
else:
|
||||||
|
logger.info("Returning cached response")
|
||||||
|
|
||||||
|
og = yield make_deferred_yieldable(observable.observe())
|
||||||
|
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_preview(self, url, user, ts):
|
||||||
|
"""Check the db, and download the URL and build a preview
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str):
|
||||||
|
user (str):
|
||||||
|
ts (int):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[str]: json-encoded og data
|
||||||
|
"""
|
||||||
|
# check the URL cache in the DB (which will also provide us with
|
||||||
# historical previews, if we have any)
|
# historical previews, if we have any)
|
||||||
cache_result = yield self.store.get_url_cache(url, ts)
|
cache_result = yield self.store.get_url_cache(url, ts)
|
||||||
if (
|
if (
|
||||||
@ -141,32 +170,10 @@ class PreviewUrlResource(Resource):
|
|||||||
cache_result["expires_ts"] > ts and
|
cache_result["expires_ts"] > ts and
|
||||||
cache_result["response_code"] / 100 == 2
|
cache_result["response_code"] / 100 == 2
|
||||||
):
|
):
|
||||||
respond_with_json_bytes(
|
defer.returnValue(cache_result["og"])
|
||||||
request, 200, cache_result["og"].encode('utf-8'),
|
|
||||||
send_cors=True
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Ensure only one download for a given URL is active at a time
|
media_info = yield self._download_url(url, user)
|
||||||
download = self.downloads.get(url)
|
|
||||||
if download is None:
|
|
||||||
download = self._download_url(url, requester.user)
|
|
||||||
download = ObservableDeferred(
|
|
||||||
download,
|
|
||||||
consumeErrors=True
|
|
||||||
)
|
|
||||||
self.downloads[url] = download
|
|
||||||
|
|
||||||
@download.addBoth
|
|
||||||
def callback(media_info):
|
|
||||||
del self.downloads[url]
|
|
||||||
return media_info
|
|
||||||
media_info = yield download.observe()
|
|
||||||
|
|
||||||
# FIXME: we should probably update our cache now anyway, so that
|
|
||||||
# even if the OG calculation raises, we don't keep hammering on the
|
|
||||||
# remote server. For now, leave it uncached to aid debugging OG
|
|
||||||
# calculation problems
|
|
||||||
|
|
||||||
logger.debug("got media_info of '%s'" % media_info)
|
logger.debug("got media_info of '%s'" % media_info)
|
||||||
|
|
||||||
@ -212,7 +219,7 @@ class PreviewUrlResource(Resource):
|
|||||||
# just rely on the caching on the master request to speed things up.
|
# just rely on the caching on the master request to speed things up.
|
||||||
if 'og:image' in og and og['og:image']:
|
if 'og:image' in og and og['og:image']:
|
||||||
image_info = yield self._download_url(
|
image_info = yield self._download_url(
|
||||||
_rebase_url(og['og:image'], media_info['uri']), requester.user
|
_rebase_url(og['og:image'], media_info['uri']), user
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_media(image_info['media_type']):
|
if _is_media(image_info['media_type']):
|
||||||
@ -239,8 +246,7 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
logger.debug("Calculated OG for %s as %s" % (url, og))
|
logger.debug("Calculated OG for %s as %s" % (url, og))
|
||||||
|
|
||||||
# store OG in ephemeral in-memory cache
|
jsonog = json.dumps(og)
|
||||||
self.cache[url] = og
|
|
||||||
|
|
||||||
# store OG in history-aware DB cache
|
# store OG in history-aware DB cache
|
||||||
yield self.store.store_url_cache(
|
yield self.store.store_url_cache(
|
||||||
@ -248,12 +254,12 @@ class PreviewUrlResource(Resource):
|
|||||||
media_info["response_code"],
|
media_info["response_code"],
|
||||||
media_info["etag"],
|
media_info["etag"],
|
||||||
media_info["expires"] + media_info["created_ts"],
|
media_info["expires"] + media_info["created_ts"],
|
||||||
json.dumps(og),
|
jsonog,
|
||||||
media_info["filesystem_id"],
|
media_info["filesystem_id"],
|
||||||
media_info["created_ts"],
|
media_info["created_ts"],
|
||||||
)
|
)
|
||||||
|
|
||||||
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
|
defer.returnValue(jsonog)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _download_url(self, url, user):
|
def _download_url(self, url, user):
|
||||||
@ -520,7 +526,14 @@ def _calc_og(tree, media_uri):
|
|||||||
from lxml import etree
|
from lxml import etree
|
||||||
|
|
||||||
TAGS_TO_REMOVE = (
|
TAGS_TO_REMOVE = (
|
||||||
"header", "nav", "aside", "footer", "script", "style", etree.Comment
|
"header",
|
||||||
|
"nav",
|
||||||
|
"aside",
|
||||||
|
"footer",
|
||||||
|
"script",
|
||||||
|
"noscript",
|
||||||
|
"style",
|
||||||
|
etree.Comment
|
||||||
)
|
)
|
||||||
|
|
||||||
# Split all the text nodes into paragraphs (by splitting on new
|
# Split all the text nodes into paragraphs (by splitting on new
|
||||||
|
@ -268,7 +268,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||||
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||||
|
|
||||||
super(DataStore, self).__init__(hs)
|
super(DataStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
def take_presence_startup_info(self):
|
def take_presence_startup_info(self):
|
||||||
active_on_startup = self._presence_on_startup
|
active_on_startup = self._presence_on_startup
|
||||||
|
@ -162,7 +162,7 @@ class PerformanceCounters(object):
|
|||||||
class SQLBaseStore(object):
|
class SQLBaseStore(object):
|
||||||
_TXN_ID = 0
|
_TXN_ID = 0
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self._db_pool = hs.get_db_pool()
|
self._db_pool = hs.get_db_pool()
|
||||||
|
@ -63,7 +63,7 @@ class AccountDataStore(SQLBaseStore):
|
|||||||
"get_account_data_for_user", get_account_data_for_user_txn
|
"get_account_data_for_user", get_account_data_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2)
|
@cachedInlineCallbacks(num_args=2, max_entries=5000)
|
||||||
def get_global_account_data_by_type_for_user(self, data_type, user_id):
|
def get_global_account_data_by_type_for_user(self, data_type, user_id):
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -48,8 +48,8 @@ def _make_exclusive_regex(services_cache):
|
|||||||
|
|
||||||
class ApplicationServiceStore(SQLBaseStore):
|
class ApplicationServiceStore(SQLBaseStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(ApplicationServiceStore, self).__init__(hs)
|
super(ApplicationServiceStore, self).__init__(db_conn, hs)
|
||||||
self.hostname = hs.hostname
|
self.hostname = hs.hostname
|
||||||
self.services_cache = load_appservices(
|
self.services_cache = load_appservices(
|
||||||
hs.hostname,
|
hs.hostname,
|
||||||
@ -173,8 +173,8 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||||||
|
|
||||||
class ApplicationServiceTransactionStore(SQLBaseStore):
|
class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(ApplicationServiceTransactionStore, self).__init__(hs)
|
super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_appservices_by_state(self, state):
|
def get_appservices_by_state(self, state):
|
||||||
|
@ -80,8 +80,8 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||||||
BACKGROUND_UPDATE_INTERVAL_MS = 1000
|
BACKGROUND_UPDATE_INTERVAL_MS = 1000
|
||||||
BACKGROUND_UPDATE_DURATION_MS = 100
|
BACKGROUND_UPDATE_DURATION_MS = 100
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(BackgroundUpdateStore, self).__init__(hs)
|
super(BackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
self._background_update_performance = {}
|
self._background_update_performance = {}
|
||||||
self._background_update_queue = []
|
self._background_update_queue = []
|
||||||
self._background_update_handlers = {}
|
self._background_update_handlers = {}
|
||||||
|
@ -32,14 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
|
|||||||
|
|
||||||
|
|
||||||
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
self.client_ip_last_seen = Cache(
|
self.client_ip_last_seen = Cache(
|
||||||
name="client_ip_last_seen",
|
name="client_ip_last_seen",
|
||||||
keylen=4,
|
keylen=4,
|
||||||
max_entries=50000 * CACHE_SIZE_FACTOR,
|
max_entries=50000 * CACHE_SIZE_FACTOR,
|
||||||
)
|
)
|
||||||
|
|
||||||
super(ClientIpStore, self).__init__(hs)
|
super(ClientIpStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
"user_ips_device_index",
|
"user_ips_device_index",
|
||||||
|
@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class DeviceInboxStore(BackgroundUpdateStore):
|
class DeviceInboxStore(BackgroundUpdateStore):
|
||||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(DeviceInboxStore, self).__init__(hs)
|
super(DeviceInboxStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
"device_inbox_stream_index",
|
"device_inbox_stream_index",
|
||||||
|
@ -26,8 +26,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DeviceStore(SQLBaseStore):
|
class DeviceStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(DeviceStore, self).__init__(hs)
|
super(DeviceStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||||
# the device exists.
|
# the device exists.
|
||||||
|
@ -39,8 +39,8 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
|
|
||||||
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
|
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(EventFederationStore, self).__init__(hs)
|
super(EventFederationStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.EVENT_AUTH_STATE_ONLY,
|
self.EVENT_AUTH_STATE_ONLY,
|
||||||
|
@ -65,8 +65,8 @@ def _deserialize_action(actions, is_highlight):
|
|||||||
class EventPushActionsStore(SQLBaseStore):
|
class EventPushActionsStore(SQLBaseStore):
|
||||||
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
|
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(EventPushActionsStore, self).__init__(hs)
|
super(EventPushActionsStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
self.EPA_HIGHLIGHT_INDEX,
|
self.EPA_HIGHLIGHT_INDEX,
|
||||||
|
@ -197,8 +197,8 @@ class EventsStore(SQLBaseStore):
|
|||||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(EventsStore, self).__init__(hs)
|
super(EventsStore, self).__init__(db_conn, hs)
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
||||||
|
@ -35,7 +35,9 @@ class GroupServerStore(SQLBaseStore):
|
|||||||
keyvalues={
|
keyvalues={
|
||||||
"group_id": group_id,
|
"group_id": group_id,
|
||||||
},
|
},
|
||||||
retcols=("name", "short_description", "long_description", "avatar_url",),
|
retcols=(
|
||||||
|
"name", "short_description", "long_description", "avatar_url", "is_public"
|
||||||
|
),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="is_user_in_group",
|
desc="is_user_in_group",
|
||||||
)
|
)
|
||||||
@ -52,7 +54,7 @@ class GroupServerStore(SQLBaseStore):
|
|||||||
return self._simple_select_list(
|
return self._simple_select_list(
|
||||||
table="group_users",
|
table="group_users",
|
||||||
keyvalues=keyvalues,
|
keyvalues=keyvalues,
|
||||||
retcols=("user_id", "is_public",),
|
retcols=("user_id", "is_public", "is_admin",),
|
||||||
desc="get_users_in_group",
|
desc="get_users_in_group",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -855,6 +857,19 @@ class GroupServerStore(SQLBaseStore):
|
|||||||
desc="add_room_to_group",
|
desc="add_room_to_group",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_room_in_group_visibility(self, group_id, room_id, is_public):
|
||||||
|
return self._simple_update(
|
||||||
|
table="group_rooms",
|
||||||
|
keyvalues={
|
||||||
|
"group_id": group_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
updatevalues={
|
||||||
|
"is_public": is_public,
|
||||||
|
},
|
||||||
|
desc="update_room_in_group_visibility",
|
||||||
|
)
|
||||||
|
|
||||||
def remove_room_from_group(self, group_id, room_id):
|
def remove_room_from_group(self, group_id, room_id):
|
||||||
def _remove_room_from_group_txn(txn):
|
def _remove_room_from_group_txn(txn):
|
||||||
self._simple_delete_txn(
|
self._simple_delete_txn(
|
||||||
@ -1026,6 +1041,7 @@ class GroupServerStore(SQLBaseStore):
|
|||||||
"avatar_url": avatar_url,
|
"avatar_url": avatar_url,
|
||||||
"short_description": short_description,
|
"short_description": short_description,
|
||||||
"long_description": long_description,
|
"long_description": long_description,
|
||||||
|
"is_public": True,
|
||||||
},
|
},
|
||||||
desc="create_group",
|
desc="create_group",
|
||||||
)
|
)
|
||||||
@ -1086,6 +1102,24 @@ class GroupServerStore(SQLBaseStore):
|
|||||||
desc="update_remote_attestion",
|
desc="update_remote_attestion",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def remove_attestation_renewal(self, group_id, user_id):
|
||||||
|
"""Remove an attestation that we thought we should renew, but actually
|
||||||
|
shouldn't. Ideally this would never get called as we would never
|
||||||
|
incorrectly try and do attestations for local users on local groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id (str)
|
||||||
|
user_id (str)
|
||||||
|
"""
|
||||||
|
return self._simple_delete(
|
||||||
|
table="group_attestations_renewals",
|
||||||
|
keyvalues={
|
||||||
|
"group_id": group_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
desc="remove_attestation_renewal",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_remote_attestation(self, group_id, user_id):
|
def get_remote_attestation(self, group_id, user_id):
|
||||||
"""Get the attestation that proves the remote agrees that the user is
|
"""Get the attestation that proves the remote agrees that the user is
|
||||||
|
@ -254,6 +254,9 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
|
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
|
||||||
|
|
||||||
def delete_url_cache(self, media_ids):
|
def delete_url_cache(self, media_ids):
|
||||||
|
if len(media_ids) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"DELETE FROM local_media_repository_url_cache"
|
"DELETE FROM local_media_repository_url_cache"
|
||||||
" WHERE media_id = ?"
|
" WHERE media_id = ?"
|
||||||
@ -281,6 +284,9 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def delete_url_cache_media(self, media_ids):
|
def delete_url_cache_media(self, media_ids):
|
||||||
|
if len(media_ids) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
def _delete_url_cache_media_txn(txn):
|
def _delete_url_cache_media_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"DELETE FROM local_media_repository"
|
"DELETE FROM local_media_repository"
|
||||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 45
|
SCHEMA_VERSION = 46
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
|
|||||||
|
|
||||||
If `config` is None then prepare_database will assert that no upgrade is
|
If `config` is None then prepare_database will assert that no upgrade is
|
||||||
necessary, *or* will create a fresh database if the database is empty.
|
necessary, *or* will create a fresh database if the database is empty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_conn:
|
||||||
|
database_engine:
|
||||||
|
config (synapse.config.homeserver.HomeServerConfig|None):
|
||||||
|
application config, or None if we are connecting to an existing
|
||||||
|
database which we expect to be configured already
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
|
|||||||
else:
|
else:
|
||||||
_setup_new_database(cur, database_engine)
|
_setup_new_database(cur, database_engine)
|
||||||
|
|
||||||
|
# check if any of our configured dynamic modules want a database
|
||||||
|
if config is not None:
|
||||||
|
_apply_module_schemas(cur, database_engine, config)
|
||||||
|
|
||||||
cur.close()
|
cur.close()
|
||||||
db_conn.commit()
|
db_conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_module_schemas(txn, database_engine, config):
|
||||||
|
"""Apply the module schemas for the dynamic modules, if any
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cur: database cursor
|
||||||
|
database_engine: synapse database engine class
|
||||||
|
config (synapse.config.homeserver.HomeServerConfig):
|
||||||
|
application config
|
||||||
|
"""
|
||||||
|
for (mod, _config) in config.password_providers:
|
||||||
|
if not hasattr(mod, 'get_db_schema_files'):
|
||||||
|
continue
|
||||||
|
modname = ".".join((mod.__module__, mod.__name__))
|
||||||
|
_apply_module_schema_files(
|
||||||
|
txn, database_engine, modname, mod.get_db_schema_files(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
|
||||||
|
"""Apply the module schemas for a single module
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cur: database cursor
|
||||||
|
database_engine: synapse database engine class
|
||||||
|
modname (str): fully qualified name of the module
|
||||||
|
names_and_streams (Iterable[(str, file)]): the names and streams of
|
||||||
|
schemas to be applied
|
||||||
|
"""
|
||||||
|
cur.execute(
|
||||||
|
database_engine.convert_param_style(
|
||||||
|
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
|
||||||
|
),
|
||||||
|
(modname,)
|
||||||
|
)
|
||||||
|
applied_deltas = set(d for d, in cur)
|
||||||
|
for (name, stream) in names_and_streams:
|
||||||
|
if name in applied_deltas:
|
||||||
|
continue
|
||||||
|
|
||||||
|
root_name, ext = os.path.splitext(name)
|
||||||
|
if ext != '.sql':
|
||||||
|
raise PrepareDatabaseException(
|
||||||
|
"only .sql files are currently supported for module schemas",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("applying schema %s for %s", name, modname)
|
||||||
|
for statement in get_statements(stream):
|
||||||
|
cur.execute(statement)
|
||||||
|
|
||||||
|
# Mark as done.
|
||||||
|
cur.execute(
|
||||||
|
database_engine.convert_param_style(
|
||||||
|
"INSERT INTO applied_module_schemas (module_name, file)"
|
||||||
|
" VALUES (?,?)",
|
||||||
|
),
|
||||||
|
(modname, name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_statements(f):
|
def get_statements(f):
|
||||||
statement_buffer = ""
|
statement_buffer = ""
|
||||||
in_comment = False # If we're in a /* ... */ style comment
|
in_comment = False # If we're in a /* ... */ style comment
|
||||||
|
@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ReceiptsStore(SQLBaseStore):
|
class ReceiptsStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(ReceiptsStore, self).__init__(hs)
|
super(ReceiptsStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self._receipts_stream_cache = StreamChangeCache(
|
self._receipts_stream_cache = StreamChangeCache(
|
||||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
||||||
|
@ -24,8 +24,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
|||||||
|
|
||||||
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(RegistrationStore, self).__init__(hs)
|
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@ -36,12 +36,15 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
columns=["user_id", "device_id"],
|
columns=["user_id", "device_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.register_background_index_update(
|
# we no longer use refresh tokens, but it's possible that some people
|
||||||
"refresh_tokens_device_index",
|
# might have a background update queued to build this index. Just
|
||||||
index_name="refresh_tokens_device_id",
|
# clear the background update.
|
||||||
table="refresh_tokens",
|
@defer.inlineCallbacks
|
||||||
columns=["user_id", "device_id"],
|
def noop_update(progress, batch_size):
|
||||||
)
|
yield self._end_background_update("refresh_tokens_device_index")
|
||||||
|
defer.returnValue(1)
|
||||||
|
self.register_background_update_handler(
|
||||||
|
"refresh_tokens_device_index", noop_update)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_access_token_to_user(self, user_id, token, device_id=None):
|
def add_access_token_to_user(self, user_id, token, device_id=None):
|
||||||
@ -177,9 +180,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if create_profile_with_localpart:
|
if create_profile_with_localpart:
|
||||||
|
# set a default displayname serverside to avoid ugly race
|
||||||
|
# between auto-joins and clients trying to set displaynames
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO profiles(user_id) VALUES (?)",
|
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
|
||||||
(create_profile_with_localpart,)
|
(create_profile_with_localpart, create_profile_with_localpart)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
@ -236,12 +241,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
"user_set_password_hash", user_set_password_hash_txn
|
"user_set_password_hash", user_set_password_hash_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
||||||
device_id=None,
|
device_id=None):
|
||||||
delete_refresh_tokens=False):
|
|
||||||
"""
|
"""
|
||||||
Invalidate access/refresh tokens belonging to a user
|
Invalidate access tokens belonging to a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): ID of user the tokens belong to
|
user_id (str): ID of user the tokens belong to
|
||||||
@ -250,10 +253,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
device_id (str|None): ID of device the tokens are associated with.
|
device_id (str|None): ID of device the tokens are associated with.
|
||||||
If None, tokens associated with any device (or no device) will
|
If None, tokens associated with any device (or no device) will
|
||||||
be deleted
|
be deleted
|
||||||
delete_refresh_tokens (bool): True to delete refresh tokens as
|
|
||||||
well as access tokens.
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred:
|
defer.Deferred[list[str, str|None]]: a list of the deleted tokens
|
||||||
|
and device IDs
|
||||||
"""
|
"""
|
||||||
def f(txn):
|
def f(txn):
|
||||||
keyvalues = {
|
keyvalues = {
|
||||||
@ -262,13 +264,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
if device_id is not None:
|
if device_id is not None:
|
||||||
keyvalues["device_id"] = device_id
|
keyvalues["device_id"] = device_id
|
||||||
|
|
||||||
if delete_refresh_tokens:
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn,
|
|
||||||
table="refresh_tokens",
|
|
||||||
keyvalues=keyvalues,
|
|
||||||
)
|
|
||||||
|
|
||||||
items = keyvalues.items()
|
items = keyvalues.items()
|
||||||
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||||
values = [v for _, v in items]
|
values = [v for _, v in items]
|
||||||
@ -277,14 +272,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
values.append(except_token_id)
|
values.append(except_token_id)
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT token FROM access_tokens WHERE %s" % where_clause,
|
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
|
||||||
values
|
values
|
||||||
)
|
)
|
||||||
rows = self.cursor_to_dict(txn)
|
tokens_and_devices = [(r[0], r[1]) for r in txn]
|
||||||
|
|
||||||
for row in rows:
|
for token, _ in tokens_and_devices:
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.get_user_by_access_token, (row["token"],)
|
txn, self.get_user_by_access_token, (token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
@ -292,7 +287,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||||||
values
|
values
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.runInteraction(
|
return tokens_and_devices
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
"user_delete_access_tokens", f,
|
"user_delete_access_tokens", f,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,8 +49,8 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
|
|||||||
|
|
||||||
|
|
||||||
class RoomMemberStore(SQLBaseStore):
|
class RoomMemberStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(RoomMemberStore, self).__init__(hs)
|
super(RoomMemberStore, self).__init__(db_conn, hs)
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
||||||
)
|
)
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
/* Copyright 2016 OpenMarket Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
|
||||||
('refresh_tokens_device_index', '{}');
|
|
@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
|
|||||||
CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
|
CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
|
||||||
|
|
||||||
|
|
||||||
-- Make sure that we popualte the table initially
|
-- Make sure that we populate the table initially
|
||||||
UPDATE user_directory_stream_pos SET stream_id = NULL;
|
UPDATE user_directory_stream_pos SET stream_id = NULL;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2016 OpenMarket Ltd
|
/* Copyright 2017 New Vector Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -13,4 +13,5 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
|
/* we no longer use (or create) the refresh_tokens table */
|
||||||
|
DROP TABLE IF EXISTS refresh_tokens;
|
32
synapse/storage/schema/delta/46/group_server.sql
Normal file
32
synapse/storage/schema/delta/46/group_server.sql
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/* Copyright 2017 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE groups_new (
|
||||||
|
group_id TEXT NOT NULL,
|
||||||
|
name TEXT, -- the display name of the room
|
||||||
|
avatar_url TEXT,
|
||||||
|
short_description TEXT,
|
||||||
|
long_description TEXT,
|
||||||
|
is_public BOOL NOT NULL -- whether non-members can access group APIs
|
||||||
|
);
|
||||||
|
|
||||||
|
-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
|
||||||
|
INSERT INTO groups_new
|
||||||
|
SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
|
||||||
|
|
||||||
|
DROP TABLE groups;
|
||||||
|
ALTER TABLE groups_new RENAME TO groups;
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX groups_idx ON groups(group_id);
|
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2015, 2016 OpenMarket Ltd
|
/* Copyright 2017 New Vector Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -13,9 +13,12 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS refresh_tokens(
|
-- this is just embarassing :|
|
||||||
id INTEGER PRIMARY KEY,
|
ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
|
||||||
token TEXT NOT NULL,
|
|
||||||
user_id TEXT NOT NULL,
|
-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
|
||||||
UNIQUE (token)
|
-- so is hopefully not going to block anyone else for that long...
|
||||||
);
|
CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
|
||||||
|
CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
|
||||||
|
DROP INDEX users_in_pubic_room_room_idx;
|
||||||
|
DROP INDEX users_in_pubic_room_user_idx;
|
@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
|
|||||||
file TEXT NOT NULL,
|
file TEXT NOT NULL,
|
||||||
UNIQUE(version, file)
|
UNIQUE(version, file)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- a list of schema files we have loaded on behalf of dynamic modules
|
||||||
|
CREATE TABLE IF NOT EXISTS applied_module_schemas(
|
||||||
|
module_name TEXT NOT NULL,
|
||||||
|
file TEXT NOT NULL,
|
||||||
|
UNIQUE(module_name, file)
|
||||||
|
);
|
||||||
|
@ -33,8 +33,8 @@ class SearchStore(BackgroundUpdateStore):
|
|||||||
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
||||||
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
|
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SearchStore, self).__init__(hs)
|
super(SearchStore, self).__init__(db_conn, hs)
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
|
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
|
||||||
)
|
)
|
||||||
|
@ -63,8 +63,8 @@ class StateStore(SQLBaseStore):
|
|||||||
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
||||||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(StateStore, self).__init__(hs)
|
super(StateStore, self).__init__(db_conn, hs)
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
||||||
self._background_deduplicate_state,
|
self._background_deduplicate_state,
|
||||||
|
@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore):
|
|||||||
"""A collection of queries for handling PDUs.
|
"""A collection of queries for handling PDUs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(TransactionStore, self).__init__(hs)
|
super(TransactionStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
|
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
user_ids (list(str)): Users to add
|
user_ids (list(str)): Users to add
|
||||||
"""
|
"""
|
||||||
yield self._simple_insert_many(
|
yield self._simple_insert_many(
|
||||||
table="users_in_pubic_room",
|
table="users_in_public_rooms",
|
||||||
values=[
|
values=[
|
||||||
{
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@ -219,7 +219,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_user_in_public_user_list(self, user_id, room_id):
|
def update_user_in_public_user_list(self, user_id, room_id):
|
||||||
yield self._simple_update_one(
|
yield self._simple_update_one(
|
||||||
table="users_in_pubic_room",
|
table="users_in_public_rooms",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
updatevalues={"room_id": room_id},
|
updatevalues={"room_id": room_id},
|
||||||
desc="update_user_in_public_user_list",
|
desc="update_user_in_public_user_list",
|
||||||
@ -240,7 +240,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
self._simple_delete_txn(
|
self._simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
table="users_in_pubic_room",
|
table="users_in_public_rooms",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
)
|
)
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
@ -256,7 +256,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_from_user_in_public_room(self, user_id):
|
def remove_from_user_in_public_room(self, user_id):
|
||||||
yield self._simple_delete(
|
yield self._simple_delete(
|
||||||
table="users_in_pubic_room",
|
table="users_in_public_rooms",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
desc="remove_from_user_in_public_room",
|
desc="remove_from_user_in_public_room",
|
||||||
)
|
)
|
||||||
@ -267,7 +267,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
in the given room_id
|
in the given room_id
|
||||||
"""
|
"""
|
||||||
return self._simple_select_onecol(
|
return self._simple_select_onecol(
|
||||||
table="users_in_pubic_room",
|
table="users_in_public_rooms",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
retcol="user_id",
|
retcol="user_id",
|
||||||
desc="get_users_in_public_due_to_room",
|
desc="get_users_in_public_due_to_room",
|
||||||
@ -286,7 +286,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
user_ids_pub = yield self._simple_select_onecol(
|
user_ids_pub = yield self._simple_select_onecol(
|
||||||
table="users_in_pubic_room",
|
table="users_in_public_rooms",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
retcol="user_id",
|
retcol="user_id",
|
||||||
desc="get_users_in_dir_due_to_room",
|
desc="get_users_in_dir_due_to_room",
|
||||||
@ -514,7 +514,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
def _delete_all_from_user_dir_txn(txn):
|
def _delete_all_from_user_dir_txn(txn):
|
||||||
txn.execute("DELETE FROM user_directory")
|
txn.execute("DELETE FROM user_directory")
|
||||||
txn.execute("DELETE FROM user_directory_search")
|
txn.execute("DELETE FROM user_directory_search")
|
||||||
txn.execute("DELETE FROM users_in_pubic_room")
|
txn.execute("DELETE FROM users_in_public_rooms")
|
||||||
txn.execute("DELETE FROM users_who_share_rooms")
|
txn.execute("DELETE FROM users_who_share_rooms")
|
||||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||||
txn.call_after(self.get_user_in_public_room.invalidate_all)
|
txn.call_after(self.get_user_in_public_room.invalidate_all)
|
||||||
@ -537,7 +537,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
@cached()
|
@cached()
|
||||||
def get_user_in_public_room(self, user_id):
|
def get_user_in_public_room(self, user_id):
|
||||||
return self._simple_select_one(
|
return self._simple_select_one(
|
||||||
table="users_in_pubic_room",
|
table="users_in_public_rooms",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcols=("room_id",),
|
retcols=("room_id",),
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
@ -641,7 +641,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
SELECT d.user_id, display_name, avatar_url
|
SELECT d.user_id, display_name, avatar_url
|
||||||
FROM user_directory_search
|
FROM user_directory_search
|
||||||
INNER JOIN user_directory AS d USING (user_id)
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
LEFT JOIN users_in_pubic_room AS p USING (user_id)
|
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||||
WHERE user_id = ? AND share_private
|
WHERE user_id = ? AND share_private
|
||||||
@ -680,7 +680,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
SELECT d.user_id, display_name, avatar_url
|
SELECT d.user_id, display_name, avatar_url
|
||||||
FROM user_directory_search
|
FROM user_directory_search
|
||||||
INNER JOIN user_directory AS d USING (user_id)
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
LEFT JOIN users_in_pubic_room AS p USING (user_id)
|
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||||
WHERE user_id = ? AND share_private
|
WHERE user_id = ? AND share_private
|
||||||
|
@ -278,8 +278,13 @@ class Limiter(object):
|
|||||||
if entry[0] >= self.max_count:
|
if entry[0] >= self.max_count:
|
||||||
new_defer = defer.Deferred()
|
new_defer = defer.Deferred()
|
||||||
entry[1].append(new_defer)
|
entry[1].append(new_defer)
|
||||||
|
|
||||||
|
logger.info("Waiting to acquire limiter lock for key %r", key)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
yield new_defer
|
yield new_defer
|
||||||
|
logger.info("Acquired limiter lock for key %r", key)
|
||||||
|
else:
|
||||||
|
logger.info("Acquired uncontended limiter lock for key %r", key)
|
||||||
|
|
||||||
entry[0] += 1
|
entry[0] += 1
|
||||||
|
|
||||||
@ -288,16 +293,21 @@ class Limiter(object):
|
|||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
logger.info("Releasing limiter lock for key %r", key)
|
||||||
|
|
||||||
# We've finished executing so check if there are any things
|
# We've finished executing so check if there are any things
|
||||||
# blocked waiting to execute and start one of them
|
# blocked waiting to execute and start one of them
|
||||||
entry[0] -= 1
|
entry[0] -= 1
|
||||||
try:
|
|
||||||
entry[1].pop(0).callback(None)
|
if entry[1]:
|
||||||
except IndexError:
|
next_def = entry[1].pop(0)
|
||||||
# If nothing else is executing for this key then remove it
|
|
||||||
# from the map
|
with PreserveLoggingContext():
|
||||||
if entry[0] == 0:
|
next_def.callback(None)
|
||||||
self.key_to_defer.pop(key, None)
|
elif entry[0] == 0:
|
||||||
|
# We were the last thing for this key: remove it from the
|
||||||
|
# map.
|
||||||
|
del self.key_to_defer[key]
|
||||||
|
|
||||||
defer.returnValue(_ctx_manager())
|
defer.returnValue(_ctx_manager())
|
||||||
|
|
||||||
|
@ -53,7 +53,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!foo:bar"
|
room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||||
|
(0, [event]),
|
||||||
|
(0, [])
|
||||||
|
]
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
yield self.handler.notify_interested_services(0)
|
yield self.handler.notify_interested_services(0)
|
||||||
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
|
||||||
@ -75,7 +78,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
self.mock_as_api.query_user = Mock()
|
self.mock_as_api.query_user = Mock()
|
||||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||||
|
(0, [event]),
|
||||||
|
(0, [])
|
||||||
|
]
|
||||||
yield self.handler.notify_interested_services(0)
|
yield self.handler.notify_interested_services(0)
|
||||||
self.mock_as_api.query_user.assert_called_once_with(
|
self.mock_as_api.query_user.assert_called_once_with(
|
||||||
services[0], user_id
|
services[0], user_id
|
||||||
@ -98,7 +104,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.mock_as_api.push = Mock()
|
self.mock_as_api.push = Mock()
|
||||||
self.mock_as_api.query_user = Mock()
|
self.mock_as_api.query_user = Mock()
|
||||||
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
|
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||||
|
(0, [event]),
|
||||||
|
(0, [])
|
||||||
|
]
|
||||||
yield self.handler.notify_interested_services(0)
|
yield self.handler.notify_interested_services(0)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
self.mock_as_api.query_user.called,
|
self.mock_as_api.query_user.called,
|
||||||
|
@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||||||
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
|
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
|
||||||
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
|
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
|
||||||
# must be done after inserts
|
# must be done after inserts
|
||||||
self.store = ApplicationServiceStore(hs)
|
self.store = ApplicationServiceStore(None, hs)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# TODO: suboptimal that we need to create files for tests!
|
# TODO: suboptimal that we need to create files for tests!
|
||||||
@ -150,7 +150,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.as_yaml_files = []
|
self.as_yaml_files = []
|
||||||
|
|
||||||
self.store = TestTransactionStore(hs)
|
self.store = TestTransactionStore(None, hs)
|
||||||
|
|
||||||
def _add_service(self, url, as_token, id):
|
def _add_service(self, url, as_token, id):
|
||||||
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
|
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
|
||||||
@ -420,8 +420,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
class TestTransactionStore(ApplicationServiceTransactionStore,
|
class TestTransactionStore(ApplicationServiceTransactionStore,
|
||||||
ApplicationServiceStore):
|
ApplicationServiceStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(TestTransactionStore, self).__init__(hs)
|
super(TestTransactionStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||||
@ -458,7 +458,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
replication_layer=Mock(),
|
replication_layer=Mock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(None, hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_duplicate_ids(self):
|
def test_duplicate_ids(self):
|
||||||
@ -477,7 +477,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(None, hs)
|
||||||
|
|
||||||
e = cm.exception
|
e = cm.exception
|
||||||
self.assertIn(f1, e.message)
|
self.assertIn(f1, e.message)
|
||||||
@ -501,7 +501,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(None, hs)
|
||||||
|
|
||||||
e = cm.exception
|
e = cm.exception
|
||||||
self.assertIn(f1, e.message)
|
self.assertIn(f1, e.message)
|
||||||
|
@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||||||
database_engine=create_engine(config.database_config),
|
database_engine=create_engine(config.database_config),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore = SQLBaseStore(hs)
|
self.datastore = SQLBaseStore(None, hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_insert_1col(self):
|
def test_insert_1col(self):
|
||||||
|
@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield setup_test_homeserver()
|
hs = yield setup_test_homeserver()
|
||||||
|
|
||||||
self.store = DirectoryStore(hs)
|
self.store = DirectoryStore(None, hs)
|
||||||
|
|
||||||
self.room = RoomID.from_string("!abcde:test")
|
self.room = RoomID.from_string("!abcde:test")
|
||||||
self.alias = RoomAlias.from_string("#my-room:test")
|
self.alias = RoomAlias.from_string("#my-room:test")
|
||||||
|
@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield setup_test_homeserver(clock=MockClock())
|
hs = yield setup_test_homeserver(clock=MockClock())
|
||||||
|
|
||||||
self.store = PresenceStore(hs)
|
self.store = PresenceStore(None, hs)
|
||||||
|
|
||||||
self.u_apple = UserID.from_string("@apple:test")
|
self.u_apple = UserID.from_string("@apple:test")
|
||||||
self.u_banana = UserID.from_string("@banana:test")
|
self.u_banana = UserID.from_string("@banana:test")
|
||||||
|
@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield setup_test_homeserver()
|
hs = yield setup_test_homeserver()
|
||||||
|
|
||||||
self.store = ProfileStore(hs)
|
self.store = ProfileStore(None, hs)
|
||||||
|
|
||||||
self.u_frank = UserID.from_string("@frank:test")
|
self.u_frank = UserID.from_string("@frank:test")
|
||||||
|
|
||||||
|
@ -86,7 +86,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
# now delete some
|
# now delete some
|
||||||
yield self.store.user_delete_access_tokens(
|
yield self.store.user_delete_access_tokens(
|
||||||
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
|
self.user_id, device_id=self.device_id,
|
||||||
|
)
|
||||||
|
|
||||||
# check they were deleted
|
# check they were deleted
|
||||||
user = yield self.store.get_user_by_access_token(self.tokens[1])
|
user = yield self.store.get_user_by_access_token(self.tokens[1])
|
||||||
@ -97,8 +98,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(self.user_id, user["name"])
|
self.assertEqual(self.user_id, user["name"])
|
||||||
|
|
||||||
# now delete the rest
|
# now delete the rest
|
||||||
yield self.store.user_delete_access_tokens(
|
yield self.store.user_delete_access_tokens(self.user_id)
|
||||||
self.user_id, delete_refresh_tokens=True)
|
|
||||||
|
|
||||||
user = yield self.store.get_user_by_access_token(self.tokens[0])
|
user = yield self.store.get_user_by_access_token(self.tokens[0])
|
||||||
self.assertIsNone(user,
|
self.assertIsNone(user,
|
||||||
|
@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.config = Mock()
|
self.config = Mock()
|
||||||
|
self.config.password_providers = []
|
||||||
self.config.database_config = {"name": "sqlite3"}
|
self.config.database_config = {"name": "sqlite3"}
|
||||||
|
|
||||||
def prepare(self):
|
def prepare(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user