Merge branch 'develop' of github.com:matrix-org/synapse into erikj/patch_inner

This commit is contained in:
Erik Johnston 2019-10-09 16:52:21 +01:00
commit 5c1f886c75
99 changed files with 1804 additions and 1115 deletions

View File

@ -7,7 +7,7 @@ about: Create a report to help us improve
<!-- <!--
**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**: **IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**:
You will likely get better support more quickly if you ask in ** #matrix:matrix.org ** ;) You will likely get better support more quickly if you ask in ** #synapse:matrix.org ** ;)
This is a bug report template. By following the instructions below and This is a bug report template. By following the instructions below and
@ -50,9 +50,13 @@ If not matrix.org:
<!-- <!--
What version of Synapse is running? What version of Synapse is running?
You can find the Synapse version by inspecting the server headers (replace matrix.org with
your own homeserver domain): You can find the Synapse version with this command:
$ curl -v https://matrix.org/_matrix/client/versions 2>&1 | grep "Server:"
$ curl http://localhost:8008/_synapse/admin/v1/server_version
(You may need to replace `localhost:8008` if Synapse is not configured to
listen on that port.)
--> -->
- **Version**: - **Version**:

1
.gitignore vendored
View File

@ -10,6 +10,7 @@
*.tac *.tac
_trial_temp/ _trial_temp/
_trial_temp*/ _trial_temp*/
/out
# stuff that is likely to exist when you run a server locally # stuff that is likely to exist when you run a server locally
/*.db /*.db

View File

@ -1,8 +1,35 @@
Synapse 1.4.0 (2019-10-03)
==========================
Bugfixes
--------
- Redact `client_secret` in server logs. ([\#6158](https://github.com/matrix-org/synapse/issues/6158))
Synapse 1.4.0rc2 (2019-10-02)
=============================
Bugfixes
--------
- Fix bug in background update that adds last seen information to the `devices` table, and improve its performance on Postgres. ([\#6135](https://github.com/matrix-org/synapse/issues/6135))
- Fix bad performance of censoring redactions background task. ([\#6141](https://github.com/matrix-org/synapse/issues/6141))
- Fix fetching censored redactions from DB, which caused APIs like initial sync to fail if it tried to include the censored redaction. ([\#6145](https://github.com/matrix-org/synapse/issues/6145))
- Fix exceptions when storing large retry intervals for down remote servers. ([\#6146](https://github.com/matrix-org/synapse/issues/6146))
Internal Changes
----------------
- Fix up sample config entry for `redaction_retention_period` option. ([\#6117](https://github.com/matrix-org/synapse/issues/6117))
Synapse 1.4.0rc1 (2019-09-26) Synapse 1.4.0rc1 (2019-09-26)
============================= =============================
Note that this release includes significant changes around 3pid Note that this release includes significant changes around 3pid
verification. Administrators are reminded to review the [upgrade notes](UPGRADE.rst##upgrading-to-v140). verification. Administrators are reminded to review the [upgrade notes](UPGRADE.rst#upgrading-to-v140).
Features Features
-------- --------
@ -48,7 +75,7 @@ Features
- Let synctl accept a directory of config files. ([\#5904](https://github.com/matrix-org/synapse/issues/5904)) - Let synctl accept a directory of config files. ([\#5904](https://github.com/matrix-org/synapse/issues/5904))
- Increase max display name size to 256. ([\#5906](https://github.com/matrix-org/synapse/issues/5906)) - Increase max display name size to 256. ([\#5906](https://github.com/matrix-org/synapse/issues/5906))
- Add admin API endpoint for getting whether or not a user is a server administrator. ([\#5914](https://github.com/matrix-org/synapse/issues/5914)) - Add admin API endpoint for getting whether or not a user is a server administrator. ([\#5914](https://github.com/matrix-org/synapse/issues/5914))
- Redact events in the database that have been redacted for a month. ([\#5934](https://github.com/matrix-org/synapse/issues/5934)) - Redact events in the database that have been redacted for a week. ([\#5934](https://github.com/matrix-org/synapse/issues/5934))
- New prometheus metrics: - New prometheus metrics:
- `synapse_federation_known_servers`: represents the total number of servers your server knows about (i.e. is in rooms with), including itself. Enable by setting `metrics_flags.known_servers` to True in the configuration.([\#5981](https://github.com/matrix-org/synapse/issues/5981)) - `synapse_federation_known_servers`: represents the total number of servers your server knows about (i.e. is in rooms with), including itself. Enable by setting `metrics_flags.known_servers` to True in the configuration.([\#5981](https://github.com/matrix-org/synapse/issues/5981))
- `synapse_build_info`: exposes the Python version, OS version, and Synapse version of the running server. ([\#6005](https://github.com/matrix-org/synapse/issues/6005)) - `synapse_build_info`: exposes the Python version, OS version, and Synapse version of the running server. ([\#6005](https://github.com/matrix-org/synapse/issues/6005))

View File

@ -349,6 +349,13 @@ sudo pip uninstall py-bcrypt
sudo pip install py-bcrypt sudo pip install py-bcrypt
``` ```
### Void Linux
Synapse can be found in the void repositories as 'synapse':
xbps-install -Su
xbps-install -S synapse
### FreeBSD ### FreeBSD
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from: Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:

View File

@ -381,3 +381,16 @@ indicate that your server is also issuing far more outgoing federation
requests than can be accounted for by your users' activity, this is a requests than can be accounted for by your users' activity, this is a
likely cause. The misbehavior can be worked around by setting likely cause. The misbehavior can be worked around by setting
``use_presence: false`` in the Synapse config file. ``use_presence: false`` in the Synapse config file.
People can't accept room invitations from me
--------------------------------------------
The typical failure mode here is that you send an invitation to someone
to join a room or direct chat, but when they go to accept it, they get an
error (typically along the lines of "Invalid signature"). They might see
something like the following in their logs::
2019-09-11 19:32:04,271 - synapse.federation.transport.server - 288 - WARNING - GET-11752 - authenticate_request failed: 401: Invalid signature for server <server> with key ed25519:a_EqML: Unable to verify signature for <server>
This is normally caused by a misconfiguration in your reverse-proxy. See
`<docs/reverse_proxy.rst>`_ and double-check that your settings are correct.

1
changelog.d/1172.misc Normal file
View File

@ -0,0 +1 @@
Update `user_filters` table to have a unique index, and non-null columns. Thanks to @pik for contributing this.

1
changelog.d/2142.feature Normal file
View File

@ -0,0 +1 @@
Improve quality of thumbnails for 1-bit/8-bit color palette images.

1
changelog.d/6019.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of the public room list directory.

1
changelog.d/6077.misc Normal file
View File

@ -0,0 +1 @@
Edit header dicts docstrings in SimpleHttpClient to note that `str` or `bytes` can be passed as header keys.

1
changelog.d/6109.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug when uploading a large file: Synapse responds with `M_UNKNOWN` while it should be `M_TOO_LARGE` according to spec. Contributed by Anshul Angaria.

1
changelog.d/6115.misc Normal file
View File

@ -0,0 +1 @@
Drop some unused database tables.

1
changelog.d/6125.feature Normal file
View File

@ -0,0 +1 @@
Reject all pending invites for a user during deactivation.

1
changelog.d/6139.misc Normal file
View File

@ -0,0 +1 @@
Log responder when responding to media request.

1
changelog.d/6144.bugfix Normal file
View File

@ -0,0 +1 @@
Prevent user push rules being deleted from a room when it is upgraded.

1
changelog.d/6147.bugfix Normal file
View File

@ -0,0 +1 @@
Don't 500 when trying to exchange a revoked 3PID invite.

1
changelog.d/6148.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of `find_next_generated_user_id` DB query.

1
changelog.d/6150.misc Normal file
View File

@ -0,0 +1 @@
Expand type-checking on modules imported by synapse.config.

1
changelog.d/6152.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of the public room list directory.

1
changelog.d/6153.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of the public room list directory.

1
changelog.d/6154.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of the public room list directory.

1
changelog.d/6159.misc Normal file
View File

@ -0,0 +1 @@
Add more caching to `_get_joined_users_from_context` DB query.

1
changelog.d/6160.misc Normal file
View File

@ -0,0 +1 @@
Add some metrics on the federation sender.

1
changelog.d/6161.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where guest account registration can wedge after restart.

1
changelog.d/6167.misc Normal file
View File

@ -0,0 +1 @@
Add some logging to the rooms stats updates, to try to track down a flaky test.

1
changelog.d/6170.bugfix Normal file
View File

@ -0,0 +1 @@
Fix /federation/v1/state endpoint for recent room versions.

1
changelog.d/6175.misc Normal file
View File

@ -0,0 +1 @@
Update `user_filters` table to have a unique index, and non-null columns. Thanks to @pik for contributing this.

1
changelog.d/6178.bugfix Normal file
View File

@ -0,0 +1 @@
Make the `synapse_port_db` script create the right indexes on a new PostgreSQL database.

1
changelog.d/6179.misc Normal file
View File

@ -0,0 +1 @@
Remove unused `timeout` parameter from `_get_public_room_list`.

1
changelog.d/6185.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where redacted events were sometimes incorrectly censored in the database, breaking APIs that attempted to fetch such events.

6
debian/changelog vendored
View File

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.4.0) stable; urgency=medium
* New synapse release 1.4.0.
-- Synapse Packaging team <packages@matrix.org> Thu, 03 Oct 2019 13:22:25 +0100
matrix-synapse-py3 (1.3.1) stable; urgency=medium matrix-synapse-py3 (1.3.1) stable; urgency=medium
* New synapse release 1.3.1. * New synapse release 1.3.1.

View File

@ -10,3 +10,15 @@ server admin by updating the database directly, e.g.:
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'`` ``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
Restarting may be required for the changes to register. Restarting may be required for the changes to register.
Using an admin access_token
###########################
Many of the API calls listed in the documentation here will require to include an admin `access_token`.
Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings.
Once you have your `access_token`, to include it in a request, the best option is to add the token to a request header:
``curl --header "Authorization: Bearer <access_token>" <the_rest_of_your_API_request>``
Fore more details, please refer to the complete `matrix spec documentation <https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens>`_.

View File

@ -314,7 +314,7 @@ listeners:
# #
# Defaults to `7d`. Set to `null` to disable. # Defaults to `7d`. Set to `null` to disable.
# #
redaction_retention_period: 7d #redaction_retention_period: 28d
# How long to track users' last seen time and IPs in the database. # How long to track users' last seen time and IPs in the database.
# #

View File

@ -38,7 +38,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.4.0rc1" __version__ = "1.4.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
do_patch() do_patch()

View File

@ -17,6 +17,7 @@
"""Contains exceptions and error codes.""" """Contains exceptions and error codes."""
import logging import logging
from typing import Dict
from six import iteritems from six import iteritems
from six.moves import http_client from six.moves import http_client
@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(code, msg, errcode) super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None: if additional_fields is None:
self._additional_fields = {} self._additional_fields = {} # type: Dict
else: else:
self._additional_fields = dict(additional_fields) self._additional_fields = dict(additional_fields)

View File

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict
import attr import attr
@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4, RoomVersions.V4,
RoomVersions.V5, RoomVersions.V5,
) )
} # type: dict[str, RoomVersion] } # type: Dict[str, RoomVersion]

View File

@ -263,7 +263,9 @@ def start(hs, listeners=None):
refresh_certificate(hs) refresh_certificate(hs)
# Start the tracer # Start the tracer
synapse.logging.opentracing.init_tracer(hs.config) synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs.config
)
# It is now safe to start your Synapse. # It is now safe to start your Synapse.
hs.start_listening(listeners) hs.start_listening(listeners)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict
from six import string_types from six import string_types
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
@ -56,8 +57,8 @@ def load_appservices(hostname, config_files):
return [] return []
# Dicts of value -> filename # Dicts of value -> filename
seen_as_tokens = {} seen_as_tokens = {} # type: Dict[str, str]
seen_ids = {} seen_ids = {} # type: Dict[str, str]
appservices = [] appservices = []

View File

@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config): class ConsentConfig(Config):
def __init__(self): def __init__(self, *args):
super(ConsentConfig, self).__init__() super(ConsentConfig, self).__init__(*args)
self.user_consent_version = None self.user_consent_version = None
self.user_consent_template_dir = None self.user_consent_template_dir = None

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from ._base import Config from ._base import Config
@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config): class PasswordAuthProviderConfig(Config):
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.password_providers = [] self.password_providers = [] # type: List[Any]
providers = [] providers = []
# We want to be backwards compatible with the old `ldap_config` # We want to be backwards compatible with the old `ldap_config`

View File

@ -15,6 +15,7 @@
import os import os
from collections import namedtuple from collections import namedtuple
from typing import Dict, List
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of Dictionary mapping from media type string to list of
ThumbnailRequirement tuples. ThumbnailRequirement tuples.
""" """
requirements = {} requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes: for size in thumbnail_sizes:
width = size["width"] width = size["width"]
height = size["height"] height = size["height"]
@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config):
# #
# We don't create the storage providers here as not all workers need # We don't create the storage providers here as not all workers need
# them to be started. # them to be started.
self.media_storage_providers = [] self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers: for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to # We special case the module "file_system" so as not to need to

View File

@ -19,6 +19,7 @@ import logging
import os.path import os.path
import re import re
from textwrap import indent from textwrap import indent
from typing import List
import attr import attr
import yaml import yaml
@ -243,7 +244,7 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile. # events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
self.listeners = [] self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []): for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int): if not isinstance(listener.get("port", None), int):
raise ConfigError( raise ConfigError(
@ -287,7 +288,10 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(bool), default=False validator=attr.validators.instance_of(bool), default=False
) )
complexity = attr.ib( complexity = attr.ib(
validator=attr.validators.instance_of((int, float)), default=1.0 validator=attr.validators.instance_of(
(float, int) # type: ignore[arg-type] # noqa
),
default=1.0,
) )
complexity_error = attr.ib( complexity_error = attr.ib(
validator=attr.validators.instance_of(str), validator=attr.validators.instance_of(str),
@ -366,7 +370,7 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True "cleanup_extremities_with_dummy_events", True
) )
def has_tls_listener(self): def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners) return any(l["tls"] for l in self.listeners)
def generate_config_section( def generate_config_section(
@ -742,7 +746,7 @@ class ServerConfig(Config):
# #
# Defaults to `7d`. Set to `null` to disable. # Defaults to `7d`. Set to `null` to disable.
# #
redaction_retention_period: 7d #redaction_retention_period: 28d
# How long to track users' last seen time and IPs in the database. # How long to track users' last seen time and IPs in the database.
# #

View File

@ -59,8 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled. None if server notices are not enabled.
""" """
def __init__(self): def __init__(self, *args):
super(ServerNoticesConfig, self).__init__() super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None self.server_notices_mxid = None
self.server_notices_mxid_display_name = None self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None self.server_notices_mxid_avatar_url = None

View File

@ -36,7 +36,6 @@ from synapse.api.errors import (
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import room_version_to_event_format from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
@ -322,18 +321,6 @@ class FederationServer(FederationBase):
pdus = yield self.handler.get_state_for_pdu(room_id, event_id) pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
for event in auth_chain:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0],
)
)
return { return {
"pdus": [pdu.get_pdu_json() for pdu in pdus], "pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],

View File

@ -38,7 +38,7 @@ from synapse.metrics import (
events_processed_counter, events_processed_counter,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import measure_func from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -183,8 +183,8 @@ class FederationSender(object):
# Otherwise if the last member on a server in a room is # Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't # banned then it won't receive the event because it won't
# be in the room after the ban. # be in the room after the ban.
destinations = yield self.state.get_current_hosts_in_room( destinations = yield self.state.get_hosts_in_room_at_events(
event.room_id, latest_event_ids=event.prev_event_ids() event.room_id, event_ids=event.prev_event_ids()
) )
except Exception: except Exception:
logger.exception( logger.exception(
@ -207,6 +207,7 @@ class FederationSender(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_room_events(events): def handle_room_events(events):
with Measure(self.clock, "handle_room_events"):
for event in events: for event in events:
yield handle_event(event) yield handle_event(event)

View File

@ -765,6 +765,10 @@ class PublicRoomList(BaseFederationServlet):
else: else:
network_tuple = ThirdPartyInstanceID(None, None) network_tuple = ThirdPartyInstanceID(None, None)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
data = await maybeDeferred( data = await maybeDeferred(
self.handler.get_local_public_room_list, self.handler.get_local_public_room_list,
limit, limit,
@ -800,6 +804,10 @@ class PublicRoomList(BaseFederationServlet):
if search_filter is None: if search_filter is None:
logger.warning("Nonefilter") logger.warning("Nonefilter")
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
data = await self.handler.get_local_public_room_list( data = await self.handler.get_local_public_room_list(
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,

View File

@ -120,6 +120,10 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running) # parts users from rooms (if it isn't already running)
self._start_user_parting() self._start_user_parting()
# Reject all pending invites for the user, so that the user doesn't show up in the
# "invited" section of rooms' members list.
yield self._reject_pending_invites_for_user(user_id)
# Remove all information on the user from the account_validity table. # Remove all information on the user from the account_validity table.
if self._account_validity_enabled: if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id) yield self.store.delete_account_validity_for_user(user_id)
@ -129,6 +133,39 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding return identity_server_supports_unbinding
@defer.inlineCallbacks
def _reject_pending_invites_for_user(self, user_id):
"""Reject pending invites addressed to a given user ID.
Args:
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
for room in pending_invites:
try:
yield self._room_member_handler.update_membership(
create_requester(user),
user,
room.room_id,
"leave",
ratelimit=False,
require_consent=False,
)
logger.info(
"Rejected invite for deactivated user %r in room %r",
user_id,
room.room_id,
)
except Exception:
logger.exception(
"Failed to reject invite for user %r in room %r:"
" ignoring and continuing",
user_id,
room.room_id,
)
def _start_user_parting(self): def _start_user_parting(self):
""" """
Start the process that goes through the table of users Start the process that goes through the table of users

View File

@ -2570,7 +2570,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check_from_context(room_version, event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e) logger.warn("Denying third party invite %r because %s", event, e)
raise e raise e
@ -2599,7 +2599,12 @@ class FederationHandler(BaseHandler):
original_invite_id, allow_none=True original_invite_id, allow_none=True
) )
if original_invite: if original_invite:
display_name = original_invite.content["display_name"] # If the m.room.third_party_invite event's content is empty, it means the
# invite has been revoked. In this case, we don't have to raise an error here
# because the auth check will fail on the invite (because it's not able to
# fetch public keys from the m.room.third_party_invite event's content, which
# is empty).
display_name = original_invite.content.get("display_name")
event_dict["content"]["third_party_invite"]["display_name"] = display_name event_dict["content"]["third_party_invite"]["display_name"] = display_name
else: else:
logger.info( logger.info(

View File

@ -217,10 +217,9 @@ class RegistrationHandler(BaseHandler):
else: else:
# autogen a sequential user ID # autogen a sequential user ID
attempts = 0
user = None user = None
while not user: while not user:
localpart = yield self._generate_user_id(attempts > 0) localpart = yield self._generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
@ -238,7 +237,6 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another # if user id is taken, just generate another
user = None user = None
user_id = None user_id = None
attempts += 1
if not self.hs.config.user_consent_at_registration: if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id) yield self._auto_join_rooms(user_id)
@ -379,10 +377,10 @@ class RegistrationHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_user_id(self, reseed=False): def _generate_user_id(self):
if reseed or self._next_generated_user_id is None: if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())): with (yield self._generate_user_id_linearizer.queue(())):
if reseed or self._next_generated_user_id is None: if self._next_generated_user_id is None:
self._next_generated_user_id = ( self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart() yield self.store.find_next_generated_user_id_localpart()
) )

View File

@ -16,8 +16,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import PY3, iteritems from six import iteritems
from six.moves import range
import msgpack import msgpack
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
@ -27,7 +26,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -37,7 +35,6 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list. # This is used to indicate we should only return rooms published to the main list.
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
@ -72,6 +69,8 @@ class RoomListHandler(BaseHandler):
This can be (None, None) to indicate the main list, or a particular This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one. appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists. Setting to None returns all public rooms across all lists.
from_federation (bool): true iff the request comes from the federation
API
""" """
if not self.enable_room_list_search: if not self.enable_room_list_search:
return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
@ -89,16 +88,8 @@ class RoomListHandler(BaseHandler):
# appservice specific lists. # appservice specific lists.
logger.info("Bypassing cache as search request.") logger.info("Bypassing cache as search request.")
# XXX: Quick hack to stop room directory queries taking too long.
# Timeout request after 60s. Probably want a more fundamental
# solution at some point
timeout = self.clock.time() + 60
return self._get_public_room_list( return self._get_public_room_list(
limit, limit, since_token, search_filter, network_tuple=network_tuple
since_token,
search_filter,
network_tuple=network_tuple,
timeout=timeout,
) )
key = (limit, since_token, network_tuple) key = (limit, since_token, network_tuple)
@ -119,7 +110,6 @@ class RoomListHandler(BaseHandler):
search_filter=None, search_filter=None,
network_tuple=EMPTY_THIRD_PARTY_ID, network_tuple=EMPTY_THIRD_PARTY_ID,
from_federation=False, from_federation=False,
timeout=None,
): ):
"""Generate a public room list. """Generate a public room list.
Args: Args:
@ -132,240 +122,116 @@ class RoomListHandler(BaseHandler):
Setting to None returns all public rooms across all lists. Setting to None returns all public rooms across all lists.
from_federation (bool): Whether this request originated from a from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering. federating server or a client. Used for room filtering.
timeout (int|None): Amount of seconds to wait for a response before
timing out.
""" """
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
since_token = None
rooms_to_order_value = {} # Pagination tokens work by storing the room ID sent in the last batch,
rooms_to_num_joined = {} # plus the direction (forwards or backwards). Next batch tokens always
# go forwards, prev batch tokens always go backwards.
newly_visible = []
newly_unpublished = []
if since_token:
stream_token = since_token.stream_ordering
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id, network_tuple=network_tuple
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id, network_tuple=network_tuple
)
# We want to return rooms in a particular order: the number of joined
# users. We then arbitrarily use the room_id as a tie breaker.
@defer.inlineCallbacks
def get_order_for_room(room_id):
# Most of the rooms won't have changed between the since token and
# now (especially if the since token is "now"). So, we can ask what
# the current users are in a room (that will hit a cache) and then
# check if the room has changed since the since token. (We have to
# do it in that order to avoid races).
# If things have changed then fall back to getting the current state
# at the since token.
joined_users = yield self.store.get_users_in_room(room_id)
if self.store.has_room_changed_since(room_id, stream_token):
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
if not latest_event_ids:
return
joined_users = yield self.state_handler.get_current_users_in_room(
room_id, latest_event_ids
)
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
if num_joined_users == 0:
return
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
logger.info(
"Getting ordering for %i rooms since %s", len(room_ids), stream_token
)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]
# `sorted_rooms` should now be a list of all public room ids that is
# stable across pagination. Therefore, we can use indices into this
# list as our pagination tokens.
# Filter out rooms that we don't want to return
rooms_to_scan = [
r
for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
total_room_count = len(rooms_to_scan)
if since_token: if since_token:
# Filter out rooms we've already returned previously batch_token = RoomListNextBatch.from_token(since_token)
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it. bounds = (batch_token.last_joined_members, batch_token.last_room_id)
if since_token.direction_is_forward: forwards = batch_token.direction_is_forward
rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :]
else: else:
rooms_to_scan = rooms_to_scan[: since_token.current_limit] batch_token = None
rooms_to_scan.reverse() bounds = None
logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) forwards = True
# _append_room_entry_to_chunk will append to chunk but will stop if # we request one more than wanted to see if there are more pages to come
# len(chunk) > limit probing_limit = limit + 1 if limit is not None else None
#
# Normally we will generate enough results on the first iteration here,
# but if there is a search filter, _append_room_entry_to_chunk may
# filter some results out, in which case we loop again.
#
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
#
# XXX if there is no limit, we may end up DoSing the server with
# calls to get_current_state_ids for every single room on the
# server. Surely we should cap this somehow?
#
if limit:
step = limit + 1
else:
# step cannot be zero
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = [] results = yield self.store.get_largest_public_rooms(
for i in range(0, len(rooms_to_scan), step): network_tuple,
if timeout and self.clock.time() > timeout:
raise Exception("Timed out searching room directory")
batch = rooms_to_scan[i : i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r,
rooms_to_num_joined[r],
chunk,
limit,
search_filter, search_filter,
from_federation=from_federation, probing_limit,
), bounds=bounds,
batch, forwards=forwards,
5, ignore_non_federatable=from_federation,
) )
logger.info("Now %i rooms in result", len(chunk))
if len(chunk) >= limit + 1:
break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) def build_room_entry(room):
entry = {
"room_id": room["room_id"],
"name": room["name"],
"topic": room["topic"],
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
"world_readable": room["history_visibility"] == "world_readable",
"guest_can_join": room["guest_access"] == "can_join",
}
# Work out the new limit of the batch for pagination, or None if we # Filter out Nones rather omit the field altogether
# know there are no more results that would be returned. return {k: v for k, v in entry.items() if v is not None}
# i.e., [since_token.current_limit..new_limit] is the batch of rooms
# we've returned (or the reverse if we paginated backwards)
# We tried to pull out limit + 1 rooms above, so if we have <= limit
# then we know there are no more results to return
new_limit = None
if chunk and (not limit or len(chunk) > limit):
if not since_token or since_token.direction_is_forward: results = [build_room_entry(r) for r in results]
if limit:
chunk = chunk[:limit] response = {}
last_room_id = chunk[-1]["room_id"] num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
# Depending on direction we trim either the front or back.
if forwards:
results = results[:limit]
else: else:
if limit: results = results[-limit:]
chunk = chunk[-limit:]
last_room_id = chunk[0]["room_id"]
new_limit = sorted_rooms.index(last_room_id)
results = {"chunk": chunk, "total_room_count_estimate": total_room_count}
if since_token:
results["new_rooms"] = bool(newly_visible)
if not since_token or since_token.direction_is_forward:
if new_limit is not None:
results["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=True,
).to_token()
if since_token:
results["prev_batch"] = since_token.copy_and_replace(
direction_is_forward=False,
current_limit=since_token.current_limit + 1,
).to_token()
else: else:
if new_limit is not None: more_to_come = False
results["prev_batch"] = RoomListNextBatch(
stream_ordering=stream_token, if num_results > 0:
public_room_stream_id=public_room_stream_id, final_entry = results[-1]
current_limit=new_limit, initial_entry = results[0]
if forwards:
if batch_token:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"],
last_room_id=initial_entry["room_id"],
direction_is_forward=False, direction_is_forward=False,
).to_token() ).to_token()
if since_token: if more_to_come:
results["next_batch"] = since_token.copy_and_replace( response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
direction_is_forward=True,
).to_token()
else:
if batch_token:
response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"],
direction_is_forward=True, direction_is_forward=True,
current_limit=since_token.current_limit - 1,
).to_token() ).to_token()
return results if more_to_come:
response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"],
last_room_id=initial_entry["room_id"],
direction_is_forward=False,
).to_token()
@defer.inlineCallbacks for room in results:
def _append_room_entry_to_chunk( # populate search result entries with additional fields, namely
self, # 'aliases'
room_id, room_id = room["room_id"]
num_joined_users,
chunk,
limit,
search_filter,
from_federation=False,
):
"""Generate the entry for a room in the public room list and append it
to the `chunk` if it matches the search filter
Args: aliases = yield self.store.get_aliases_for_room(room_id)
room_id (str): The ID of the room. if aliases:
num_joined_users (int): The number of joined users in the room. room["aliases"] = aliases
chunk (list)
limit (int|None): Maximum amount of rooms to display. Function will
return if length of chunk is greater than limit + 1.
search_filter (dict|None)
from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering.
"""
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
result = yield self.generate_room_entry(room_id, num_joined_users) response["chunk"] = results
if not result:
return
if from_federation and not result.get("m.federate", True): response["total_room_count_estimate"] = yield self.store.count_public_rooms(
# This is a room that other servers cannot join. Do not show them network_tuple, ignore_non_federatable=from_federation
# this room. )
return
if _matches_room_entry(result, search_filter): return response
chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True) @cachedInlineCallbacks(num_args=1, cache_context=True)
def generate_room_entry( def generate_room_entry(
@ -580,18 +446,15 @@ class RoomListNextBatch(
namedtuple( namedtuple(
"RoomListNextBatch", "RoomListNextBatch",
( (
"stream_ordering", # stream_ordering of the first public room list "last_joined_members", # The count to get rooms after/before
"public_room_stream_id", # public room stream id for first public room list "last_room_id", # The room_id to get rooms after/before
"current_limit", # The number of previous rooms returned
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
), ),
) )
): ):
KEY_DICT = { KEY_DICT = {
"stream_ordering": "s", "last_joined_members": "m",
"public_room_stream_id": "p", "last_room_id": "r",
"current_limit": "n",
"direction_is_forward": "d", "direction_is_forward": "d",
} }
@ -599,13 +462,7 @@ class RoomListNextBatch(
@classmethod @classmethod
def from_token(cls, token): def from_token(cls, token):
if PY3:
# The argument raw=False is only available on new versions of
# msgpack, and only really needed on Python 3. Gate it behind
# a PY3 check to avoid causing issues on Debian-packaged versions.
decoded = msgpack.loads(decode_base64(token), raw=False) decoded = msgpack.loads(decode_base64(token), raw=False)
else:
decoded = msgpack.loads(decode_base64(token))
return RoomListNextBatch( return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
) )

View File

@ -216,8 +216,8 @@ class RoomMemberHandler(object):
yield self.copy_room_tags_and_direct_to_room( yield self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], room_id, user_id predecessor["room_id"], room_id, user_id
) )
# Move over old push rules # Copy over push rules
yield self.store.move_push_rules_from_room_to_room_for_user( yield self.store.copy_push_rules_from_room_to_room_for_user(
predecessor["room_id"], room_id, user_id predecessor["room_id"], room_id, user_id
) )
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:

View File

@ -293,6 +293,7 @@ class StatsHandler(StateDeltasHandler):
room_state["guest_access"] = event_content.get("guest_access") room_state["guest_access"] = event_content.get("guest_access")
for room_id, state in room_to_state_updates.items(): for room_id, state in room_to_state_updates.items():
logger.info("Updating room_stats_state for %s: %s", room_id, state)
yield self.store.update_room_state(room_id, state) yield self.store.update_room_state(room_id, state)
return room_to_stats_deltas, user_to_stats_deltas return room_to_stats_deltas, user_to_stats_deltas

View File

@ -42,11 +42,13 @@ def cancelled_to_request_timed_out_error(value, timeout):
ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
def redact_uri(uri): def redact_uri(uri):
"""Strips access tokens from the uri replaces with <redacted>""" """Strips sensitive information from the uri replaces with <redacted>"""
return ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri) uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
class QuieterFileBodyProducer(FileBodyProducer): class QuieterFileBodyProducer(FileBodyProducer):

View File

@ -327,7 +327,7 @@ class SimpleHttpClient(object):
Args: Args:
uri (str): uri (str):
args (dict[str, str|List[str]]): query params args (dict[str, str|List[str]]): query params
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
@ -371,7 +371,7 @@ class SimpleHttpClient(object):
Args: Args:
uri (str): uri (str):
post_json (object): post_json (object):
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
@ -414,7 +414,7 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -438,7 +438,7 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -482,7 +482,7 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -516,7 +516,7 @@ class SimpleHttpClient(object):
Args: Args:
url (str): The URL to GET url (str): The URL to GET
output_stream (file): File to write the response body to. output_stream (file): File to write the response body to.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
A (int,dict,string,int) tuple of the file length, dict of the response A (int,dict,string,int) tuple of the file length, dict of the response

View File

@ -170,6 +170,7 @@ import inspect
import logging import logging
import re import re
from functools import wraps from functools import wraps
from typing import Dict
from canonicaljson import json from canonicaljson import json
@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return return
span = opentracing.tracer.active_span span = opentracing.tracer.active_span
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items(): for key, value in carrier.items():
@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span span = opentracing.tracer.active_span
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items(): for key, value in carrier.items():
@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination): if destination and not whitelisted_homeserver(destination):
return {} return {}
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
) )
@ -653,7 +654,7 @@ def active_span_context_as_string():
Returns: Returns:
The active span context encoded as a string. The active span context encoded as a string.
""" """
carrier = {} carrier = {} # type: Dict[str, str]
if opentracing: if opentracing:
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier

View File

@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name) logger = logging.getLogger(name)
level = logging.DEBUG level = logging.DEBUG
s = inspect.currentframe().f_back frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back
to_print = [ to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s" "\t%s:%s %s. Args: args=%s, kwargs=%s"
@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname, pathname=pathname,
lineno=lineno, lineno=lineno,
msg=msg, msg=msg,
args=None, args=tuple(),
exc_info=None, exc_info=None,
) )
@ -157,7 +161,12 @@ def trace_function(f):
def get_previous_frames(): def get_previous_frames():
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
to_return = [] to_return = []
while s: while s:
if s.f_globals["__name__"].startswith("synapse"): if s.f_globals["__name__"].startswith("synapse"):
@ -174,7 +183,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]): def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
while s: while s:
if s.f_globals["__name__"].startswith("synapse"): if s.f_globals["__name__"].startswith("synapse"):

View File

@ -125,7 +125,7 @@ class InFlightGauge(object):
) )
# Counts number of in flight blocks for a given set of label values # Counts number of in flight blocks for a given set of label values
self._registrations = {} self._registrations = {} # type: Dict
# Protects access to _registrations # Protects access to _registrations
self._lock = threading.Lock() self._lock = threading.Lock()
@ -226,7 +226,7 @@ class BucketCollector(object):
# Fetch the data -- this must be synchronous! # Fetch the data -- this must be synchronous!
data = self.data_collector() data = self.data_collector()
buckets = {} buckets = {} # type: Dict[float, int]
res = [] res = []
for x in data.keys(): for x in data.keys():

View File

@ -36,9 +36,9 @@ from twisted.web.resource import Resource
try: try:
from prometheus_client.samples import Sample from prometheus_client.samples import Sample
except ImportError: except ImportError:
Sample = namedtuple( Sample = namedtuple( # type: ignore[no-redef] # noqa
"Sample", ["name", "labels", "value", "timestamp", "exemplar"] "Sample", ["name", "labels", "value", "timestamp", "exemplar"]
) # type: ignore )
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Set from typing import List, Set
from pkg_resources import ( from pkg_resources import (
DistributionNotFound, DistributionNotFound,
@ -73,6 +73,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18", "netaddr>=0.7.18",
"Jinja2>=2.9", "Jinja2>=2.9",
"bleach>=1.4.3", "bleach>=1.4.3",
"typing-extensions>=3.7.4",
] ]
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency) deps_needed.append(dependency)
errors.append( errors.append(
"Needed %s, got %s==%s" "Needed %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version) % (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
) )
except DistributionNotFound: except DistributionNotFound:
deps_needed.append(dependency) deps_needed.append(dependency)
@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
if not for_feature: if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be # Check the optional dependencies are up to date. We allow them to not be
# installed. # installed.
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
for dependency in OPTS: for dependency in OPTS:
try: try:
@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency) deps_needed.append(dependency)
errors.append( errors.append(
"Needed optional %s, got %s==%s" "Needed optional %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version) % (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
) )
except DistributionNotFound: except DistributionNotFound:
# If it's not found, we don't care # If it's not found, we don't care

View File

@ -361,6 +361,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = parse_integer(request, "limit", 0) limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None) since_token = parse_string(request, "since", None)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = yield handler.get_remote_public_room_list(
@ -398,6 +402,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
else: else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = yield handler.get_remote_public_room_list(

View File

@ -195,7 +195,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request) respond_404(request)
return return
logger.debug("Responding to media request with responder %s") logger.debug("Responding to media request with responder %s", responder)
add_file_headers(request, media_type, file_size, upload_name) add_file_headers(request, media_type, file_size, upload_name)
try: try:
with responder: with responder:

View File

@ -82,13 +82,21 @@ class Thumbnailer(object):
else: else:
return (max_height * self.width) // self.height, max_height return (max_height * self.width) // self.height, max_height
def _resize(self, width, height):
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
if self.image.mode in ["1", "P"]:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width, height, output_type): def scale(self, width, height, output_type):
"""Rescales the image to the given dimensions. """Rescales the image to the given dimensions.
Returns: Returns:
BytesIO: the bytes of the encoded image ready to be written to disk BytesIO: the bytes of the encoded image ready to be written to disk
""" """
scaled = self.image.resize((width, height), Image.ANTIALIAS) scaled = self._resize(width, height)
return self._encode_image(scaled, output_type) return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type): def crop(self, width, height, output_type):
@ -107,13 +115,13 @@ class Thumbnailer(object):
""" """
if width * self.height > height * self.width: if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width scaled_height = (width * self.height) // self.width
scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS) scaled_image = self._resize(width, scaled_height)
crop_top = (scaled_height - height) // 2 crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else: else:
scaled_width = (height * self.width) // self.height scaled_width = (height * self.width) // self.height
scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS) scaled_image = self._resize(scaled_width, height)
crop_left = (scaled_width - width) // 2 crop_left = (scaled_width - width) // 2
crop_right = width + crop_left crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height)) cropped = scaled_image.crop((crop_left, 0, crop_right, height))

View File

@ -17,7 +17,7 @@ import logging
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import ( from synapse.http.server import (
DirectServeResource, DirectServeResource,
respond_with_json, respond_with_json,
@ -56,7 +56,11 @@ class UploadResource(DirectServeResource):
if content_length is None: if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400) raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size: if int(content_length) > self.max_upload_size:
raise SynapseError(msg="Upload request body is too large", code=413) raise SynapseError(
msg="Upload request body is too large",
code=413,
errcode=Codes.TOO_LARGE,
)
upload_name = parse_string(request, b"filename", encoding=None) upload_name = parse_string(request, b"filename", encoding=None)
if upload_name: if upload_name:

View File

@ -33,7 +33,7 @@ from synapse.state import v1, v2
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -191,11 +191,22 @@ class StateHandler(object):
return joined_users return joined_users
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_hosts_in_room(self, room_id, latest_event_ids=None): def get_current_hosts_in_room(self, room_id):
if not latest_event_ids: event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) @defer.inlineCallbacks
def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids
Args:
room_id (str):
event_ids (list[str]):
Returns:
Deferred[list[str]]: the hosts in the room at the given events
"""
entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry) joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
return joined_hosts return joined_hosts
@ -344,6 +355,7 @@ class StateHandler(object):
return context return context
@measure_func()
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids): def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each

View File

@ -33,16 +33,9 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000 LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore): class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
self.register_background_index_update( self.register_background_index_update(
"user_ips_device_index", "user_ips_device_index",
@ -92,19 +85,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"devices_last_seen", self._devices_last_seen_update "devices_last_seen", self._devices_last_seen_update
) )
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks @defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size): def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn): def f(conn):
@ -303,6 +283,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return batch_size return batch_size
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
if self.database_engine.supports_tuple_comparison:
where_clause = "(user_id, device_id) > (?, ?)"
where_args = [last_user_id, last_device_id]
else:
# We explicitly do a `user_id >= ? AND (...)` here to ensure
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
# makes it hard for query optimiser to tell that it can use the
# index on user_id
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
where_args = [last_user_id, last_user_id, last_device_id]
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks @defer.inlineCallbacks
def insert_client_ip( def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None self, user_id, access_token, ip, user_agent, device_id, now=None
@ -454,53 +538,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
for (access_token, ip), (user_agent, last_seen) in iteritems(results) for (access_token, ip), (user_agent, last_seen) in iteritems(results)
) )
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
sql = """
SELECT u.last_seen, u.ip, u.user_agent, user_id, device_id FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE user_id > ? OR (user_id = ? AND device_id > ?)
ORDER BY user_id ASC, device_id ASC
LIMIT ?
"""
txn.execute(sql, (last_user_id, last_user_id, last_device_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
@wrap_as_background_process("prune_old_user_ips") @wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self): async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period. """Removes entries in user IPs older than the configured period.

View File

@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs) super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update( self.register_background_index_update(
"device_inbox_stream_index", "device_inbox_stream_index",
@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
) )
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been # Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions. # deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache( self._last_device_delete_cache = ExpiringCache(
@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
return self.runInteraction( return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn "get_all_new_device_messages", get_all_new_device_messages_txn
) )
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1

View File

@ -512,17 +512,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results return results
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs) super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
self.register_background_index_update( self.register_background_index_update(
"device_lists_stream_idx", "device_lists_stream_idx",
@ -555,6 +547,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._drop_device_list_streams_non_unique_indexes, self._drop_device_list_streams_non_unique_indexes,
) )
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name): def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
@ -910,15 +927,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"_prune_old_outbound_device_pokes", "_prune_old_outbound_device_pokes",
_prune_txn, _prune_txn,
) )
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1

View File

@ -72,6 +72,13 @@ class PostgresEngine(object):
""" """
return True return True
@property
def supports_tuple_comparison(self):
"""
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
"""
return True
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html # https://www.postgresql.org/docs/current/static/errcodes-appendix.html

View File

@ -38,6 +38,14 @@ class Sqlite3Engine(object):
""" """
return self.module.sqlite_version_info >= (3, 24, 0) return self.module.sqlite_version_info >= (3, 24, 0)
@property
def supports_tuple_comparison(self):
"""
Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
SQLite 3.15+.
"""
return self.module.sqlite_version_info >= (3, 15, 0)
def check_database(self, txn): def check_database(self, txn):
pass pass

View File

@ -23,7 +23,7 @@ from functools import wraps
from six import iteritems, text_type from six import iteritems, text_type
from six.moves import range from six.moves import range
from canonicaljson import encode_canonical_json, json from canonicaljson import json
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
from twisted.internet import defer from twisted.internet import defer
@ -1389,6 +1389,18 @@ class EventsStore(
], ],
) )
for event, _ in events_and_contexts:
if not event.internal_metadata.is_redacted():
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.
self._simple_update_txn(
txn,
table="redactions",
keyvalues={"redacts": event.event_id},
updatevalues={"have_censored": False},
)
def _store_rejected_events_txn(self, txn, events_and_contexts): def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were """Add rows to the 'rejections' table for received events which were
rejected rejected
@ -1552,9 +1564,15 @@ class EventsStore(
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event # invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts) txn.call_after(self._invalidate_get_event_cache, event.redacts)
txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)", self._simple_insert_txn(
(event.event_id, event.redacts), txn,
table="redactions",
values={
"event_id": event.event_id,
"redacts": event.redacts,
"received_ts": self._clock.time_msec(),
},
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1571,36 +1589,29 @@ class EventsStore(
if self.hs.config.redaction_retention_period is None: if self.hs.config.redaction_retention_period is None:
return return
max_pos = yield self.find_first_stream_ordering_after_ts( before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
self._clock.time_msec() - self.hs.config.redaction_retention_period
)
# We fetch all redactions that: # We fetch all redactions that:
# 1. point to an event we have, # 1. point to an event we have,
# 2. has a stream ordering from before the cut off, and # 2. has a received_ts from before the cut off, and
# 3. we haven't yet censored. # 3. we haven't yet censored.
# #
# This is limited to 100 events to ensure that we don't try and do too # This is limited to 100 events to ensure that we don't try and do too
# much at once. We'll get called again so this should eventually catch # much at once. We'll get called again so this should eventually catch
# up. # up.
#
# We use the range [-max_pos, max_pos] to handle backfilled events,
# which are given negative stream ordering.
sql = """ sql = """
SELECT redact_event.event_id, redacts FROM redactions SELECT redactions.event_id, redacts FROM redactions
INNER JOIN events AS redact_event USING (event_id) LEFT JOIN events AS original_event ON (
INNER JOIN events AS original_event ON ( redacts = original_event.event_id
redact_event.room_id = original_event.room_id
AND redacts = original_event.event_id
) )
WHERE NOT have_censored WHERE NOT have_censored
AND ? <= redact_event.stream_ordering AND redact_event.stream_ordering <= ? AND redactions.received_ts <= ?
ORDER BY redact_event.stream_ordering ASC ORDER BY redactions.received_ts ASC
LIMIT ? LIMIT ?
""" """
rows = yield self._execute( rows = yield self._execute(
"_censor_redactions_fetch", None, sql, -max_pos, max_pos, 100 "_censor_redactions_fetch", None, sql, before_ts, 100
) )
updates = [] updates = []
@ -1621,9 +1632,7 @@ class EventsStore(
and original_event.internal_metadata.is_redacted() and original_event.internal_metadata.is_redacted()
): ):
# Redaction was allowed # Redaction was allowed
pruned_json = encode_canonical_json( pruned_json = encode_json(prune_event_dict(original_event.get_dict()))
prune_event_dict(original_event.get_dict())
)
else: else:
# Redaction wasn't allowed # Redaction wasn't allowed
pruned_json = None pruned_json = None

View File

@ -67,6 +67,10 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
) )
self.register_background_update_handler(
"redactions_received_ts", self._redactions_received_ts
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size): def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
@ -397,3 +401,60 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
) )
return num_handled return num_handled
@defer.inlineCallbacks
def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
def _redactions_received_ts_txn(txn):
# Fetch the set of event IDs that we want to update
sql = """
SELECT event_id FROM redactions
WHERE event_id > ?
ORDER BY event_id ASC
LIMIT ?
"""
txn.execute(sql, (last_event_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
upper_event_id, = rows[-1]
# Update the redactions with the received_ts.
#
# Note: Not all events have an associated received_ts, so we
# fallback to using origin_server_ts. If we for some reason don't
# have an origin_server_ts, lets just use the current timestamp.
#
# We don't want to leave it null, as then we'll never try and
# censor those redactions.
sql = """
UPDATE redactions
SET received_ts = (
SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
WHERE events.event_id = redactions.event_id
)
WHERE ? <= event_id AND event_id <= ?
"""
txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
self._background_update_progress_txn(
txn, "redactions_received_ts", {"last_event_id": upper_event_id}
)
return len(rows)
count = yield self.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
yield self._end_background_update("redactions_received_ts")
return count

View File

@ -238,6 +238,20 @@ class EventsWorkerStore(SQLBaseStore):
# we have to recheck auth now. # we have to recheck auth now.
if not allow_rejected and entry.event.type == EventTypes.Redaction: if not allow_rejected and entry.event.type == EventTypes.Redaction:
if not hasattr(entry.event, "redacts"):
# A redacted redaction doesn't have a `redacts` key, in
# which case lets just withhold the event.
#
# Note: Most of the time if the redactions has been
# redacted we still have the un-redacted event in the DB
# and so we'll still see the `redacts` key. However, this
# isn't always true e.g. if we have censored the event.
logger.debug(
"Withholding redaction event %s as we don't have redacts key",
event_id,
)
continue
redacted_event_id = entry.event.redacts redacted_event_id = entry.event.redacts
event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id) original_event_entry = event_map.get(redacted_event_id)

View File

@ -15,11 +15,9 @@
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
class MediaRepositoryStore(BackgroundUpdateStore): class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs) super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update( self.register_background_index_update(
update_name="local_media_repository_url_idx", update_name="local_media_repository_url_idx",
@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause="url_cache IS NOT NULL", where_clause="url_cache IS NOT NULL",
) )
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
def get_local_media(self, media_id): def get_local_media(self, media_id):
"""Get the metadata for a local piece of media """Get the metadata for a local piece of media
Returns: Returns:

View File

@ -183,8 +183,8 @@ class PushRulesWorkerStore(
return results return results
@defer.inlineCallbacks @defer.inlineCallbacks
def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule): def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Move a single push rule from one room to another for a specific user. """Copy a single push rule from one room to another for a specific user.
Args: Args:
new_room_id (str): ID of the new room. new_room_id (str): ID of the new room.
@ -209,14 +209,11 @@ class PushRulesWorkerStore(
actions=rule["actions"], actions=rule["actions"],
) )
# Delete push rule for the old room
yield self.delete_push_rule(user_id, rule["rule_id"])
@defer.inlineCallbacks @defer.inlineCallbacks
def move_push_rules_from_room_to_room_for_user( def copy_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id self, old_room_id, new_room_id, user_id
): ):
"""Move all of the push rules from one room to another for a specific """Copy all of the push rules from one room to another for a specific
user. user.
Args: Args:
@ -227,15 +224,14 @@ class PushRulesWorkerStore(
# Retrieve push rules for this user # Retrieve push rules for this user
user_push_rules = yield self.get_push_rules_for_user(user_id) user_push_rules = yield self.get_push_rules_for_user(user_id)
# Get rules relating to the old room, move them to the new room, then # Get rules relating to the old room and copy them to the new room
# delete them from the old room
for rule in user_push_rules: for rule in user_push_rules:
conditions = rule.get("conditions", []) conditions = rule.get("conditions", [])
if any( if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id) (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions for c in conditions
): ):
yield self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context): def bulk_get_push_rules_for_room(self, event, context):

View File

@ -493,7 +493,9 @@ class RegistrationWorkerStore(SQLBaseStore):
""" """
def _find_next_generated_user_id(txn): def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users") # We bound between '@1' and '@a' to avoid pulling the entire table
# out.
txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):") regex = re.compile(r"^@(\d+):")
@ -785,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
class RegistrationStore( class RegistrationBackgroundUpdateStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore RegistrationWorkerStore, background_updates.BackgroundUpdateStore
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs) super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.config = hs.config
self.register_background_index_update( self.register_background_index_update(
"access_tokens_device_index", "access_tokens_device_index",
@ -807,8 +810,6 @@ class RegistrationStore(
columns=["creation_ts"], columns=["creation_ts"],
) )
self._account_validity = hs.config.account_validity
# we no longer use refresh tokens, but it's possible that some people # we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just # might have a background update queued to build this index. Just
# clear the background update. # clear the background update.
@ -822,17 +823,6 @@ class RegistrationStore(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag "users_set_deactivated_flag", self._background_update_set_deactivated_flag
) )
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size): def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@ -894,6 +884,54 @@ class RegistrationStore(
return nb_processed return nb_processed
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
self._account_validity = hs.config.account_validity
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user. """Adds an access token for the given user.
@ -1242,36 +1280,6 @@ class RegistrationStore(
desc="get_users_pending_deactivation", desc="get_users_pending_deactivation",
) )
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
def validate_threepid_session(self, session_id, client_secret, token, current_ts): def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token """Attempt to validate a threepid session using a token
@ -1463,17 +1471,6 @@ class RegistrationStore(
self.clock.time_msec(), self.clock.time_msec(),
) )
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_user_deactivated_status(self, user_id, deactivated): def set_user_deactivated_status(self, user_id, deactivated):
"""Set the `deactivated` property for the provided user to the provided value. """Set the `deactivated` property for the provided user to the provided value.
@ -1489,3 +1486,14 @@ class RegistrationStore(
user_id, user_id,
deactivated, deactivated,
) )
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,6 +17,7 @@
import collections import collections
import logging import logging
import re import re
from typing import Optional, Tuple
from canonicaljson import json from canonicaljson import json
@ -24,6 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore from synapse.storage.search import SearchStore
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,103 +66,196 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids", desc="get_public_room_ids",
) )
@cached(num_args=2, max_entries=100) def count_public_rooms(self, network_tuple, ignore_non_federatable):
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): """Counts the number of public rooms as tracked in the room_stats_current
"""Get pulbic rooms for a particular list, or across all lists. and room_stats_state table.
Args: Args:
stream_id (int) network_tuple (ThirdPartyInstanceID|None)
network_tuple (ThirdPartyInstanceID): The list to use (None, None) ignore_non_federatable (bool): If true filters out non-federatable rooms
means the main list, None means all lsits.
""" """
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
stream_id,
network_tuple=network_tuple,
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple): def _count_public_rooms_txn(txn):
return { query_args = []
rm
for rm, vis in self.get_published_at_stream_id_txn( if network_tuple:
txn, stream_id, network_tuple=network_tuple if network_tuple.appservice_id:
).items() published_sql = """
if vis SELECT room_id from appservice_room_list
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
query_args.append(network_tuple.network_id)
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
"""
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
UNION SELECT room_id from appservice_room_list
"""
sql = """
SELECT
COALESCE(COUNT(*), 0)
FROM (
%(published_sql)s
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
)
AND joined_members > 0
""" % {
"published_sql": published_sql
} }
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): txn.execute(sql, query_args)
if network_tuple: return txn.fetchone()[0]
# We want to get from a particular list. No aggregation required.
return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks
def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
):
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
Args:
network_tuple
search_filter
limit: Maxmimum number of rows to return, unlimited otherwise.
bounds: An uppoer or lower bound to apply to result set if given,
consists of a joined member count and room_id (these are
excluded from result set).
forwards: true iff going forwards, going backwards otherwise
ignore_non_federatable: If true filters out non-federatable rooms.
Returns:
Rooms in order: biggest number of joined users first.
We then arbitrarily use the room_id as a tie breaker.
sql = """
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
""" """
if network_tuple.appservice_id is not None: where_clauses = []
txn.execute( query_args = []
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id), if network_tuple:
if network_tuple.appservice_id:
published_sql = """
SELECT room_id from appservice_room_list
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
query_args.append(network_tuple.network_id)
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
"""
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
UNION SELECT room_id from appservice_room_list
"""
# Work out the bounds if we're given them, these bounds look slightly
# odd, but are designed to help query planner use indices by pulling
# out a common bound.
if bounds:
last_joined_members, last_room_id = bounds
if forwards:
where_clauses.append(
"""
joined_members <= ? AND (
joined_members < ? OR room_id < ?
)
"""
) )
else: else:
txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,)) where_clauses.append(
return dict(txn) """
else: joined_members >= ? AND (
# We want to get from all lists, so we need to aggregate the results joined_members > ? OR room_id > ?
)
"""
)
logger.info("Executing full list") query_args += [last_joined_members, last_joined_members, last_room_id]
if ignore_non_federatable:
where_clauses.append("is_federatable")
if search_filter and search_filter.get("generic_search_term", None):
search_term = "%" + search_filter["generic_search_term"] + "%"
where_clauses.append(
"""
(
name LIKE ?
OR topic LIKE ?
OR canonical_alias LIKE ?
)
"""
)
query_args += [search_term, search_term, search_term]
where_clause = ""
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
sql = """ sql = """
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
SELECT SELECT
room_id, max(stream_id) AS stream_id, appservice_id, room_id, name, topic, canonical_alias, joined_members,
network_id avatar, history_visibility, joined_members, guest_access
FROM public_room_list_stream FROM (
WHERE stream_id <= ? %(published_sql)s
GROUP BY room_id, appservice_id, network_id ) published
) grouped USING (room_id, stream_id) INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
)
AND joined_members > 0
%(where_clause)s
ORDER BY joined_members %(dir)s, room_id %(dir)s
""" % {
"published_sql": published_sql,
"where_clause": where_clause,
"dir": "DESC" if forwards else "ASC",
}
if limit is not None:
query_args.append(limit)
sql += """
LIMIT ?
""" """
txn.execute(sql, (stream_id,)) def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args)
results = {} results = self.cursor_to_dict(txn)
# A room is visible if its visible on any list.
for room_id, visibility in txn: if not forwards:
results[room_id] = bool(visibility) or results.get(room_id, False) results.reverse()
return results return results
def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple): ret_val = yield self.runInteraction(
def get_public_room_changes_txn(txn): "get_largest_public_rooms", _get_largest_public_rooms_txn
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
)
now_rooms_dict = self.get_published_at_stream_id_txn(
txn, new_stream_id, network_tuple
)
now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
newly_visible = now_rooms_visible - then_rooms
newly_unpublished = now_rooms_not_visible & then_rooms
return newly_visible, newly_unpublished
return self.runInteraction(
"get_public_room_changes", get_public_room_changes_txn
) )
defer.returnValue(ret_val)
@cached(max_entries=10000) @cached(max_entries=10000)
def is_room_blocked(self, room_id): def is_room_blocked(self, room_id):

View File

@ -27,12 +27,14 @@ from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction from synapse.storage._base import LoggingTransaction
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import Sqlite3Engine from synapse.storage.engines import Sqlite3Engine
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
from synapse.util.stringutils import to_ascii from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -483,6 +485,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
return result return result
@defer.inlineCallbacks
def get_joined_users_from_state(self, room_id, state_entry): def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group state_group = state_entry.state_group
if not state_group: if not state_group:
@ -492,9 +495,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object() # To do this we set the state_group to a new object as object() != object()
state_group = object() state_group = object()
return self._get_joined_users_from_context( with Measure(self._clock, "get_joined_users_from_state"):
return (
yield self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry room_id, state_group, state_entry.state, context=state_entry
) )
)
@cachedInlineCallbacks( @cachedInlineCallbacks(
num_args=2, cache_context=True, iterable=True, max_entries=100000 num_args=2, cache_context=True, iterable=True, max_entries=100000
@ -567,25 +573,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id) missing_member_event_ids.append(event_id)
if missing_member_event_ids: if missing_member_event_ids:
rows = yield self._simple_select_many_batch( event_to_memberships = yield self._get_joined_profiles_from_event_ids(
table="room_memberships", missing_member_event_ids
column="event_id",
iterable=missing_member_event_ids,
retcols=("user_id", "display_name", "avatar_url"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_users_from_context",
)
users_in_room.update(
{
to_ascii(row["user_id"]): ProfileInfo(
avatar_url=to_ascii(row["avatar_url"]),
display_name=to_ascii(row["display_name"]),
)
for row in rows
}
) )
users_in_room.update((row for row in event_to_memberships.values() if row))
if event is not None and event.type == EventTypes.Member: if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
@ -597,6 +588,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room return users_in_room
@cached(max_entries=10000)
def _get_joined_profile_from_event_id(self, event_id):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
)
def _get_joined_profiles_from_event_ids(self, event_ids):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
event_ids (Iterable[str]): The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = yield self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
)
return {
row["event_id"]: (
row["user_id"],
ProfileInfo(
avatar_url=row["avatar_url"], display_name=row["display_name"]
),
)
for row in rows
}
@cachedInlineCallbacks(max_entries=10000) @cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host): def is_host_joined(self, room_id, host):
if "%" in host or "_" in host: if "%" in host or "_" in host:
@ -669,6 +701,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True return True
@defer.inlineCallbacks
def get_joined_hosts(self, room_id, state_entry): def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group state_group = state_entry.state_group
if not state_group: if not state_group:
@ -678,9 +711,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object() # To do this we set the state_group to a new object as object() != object()
state_group = object() state_group = object()
return self._get_joined_hosts( with Measure(self._clock, "get_joined_hosts"):
return (
yield self._get_joined_hosts(
room_id, state_group, state_entry.state, state_entry=state_entry room_id, state_group, state_entry.state, state_entry=state_entry
) )
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
# @defer.inlineCallbacks # @defer.inlineCallbacks
@ -785,9 +821,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids) return set(room_ids)
class RoomMemberStore(RoomMemberWorkerStore): class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs) super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler( self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
) )
@ -803,112 +839,6 @@ class RoomMemberStore(RoomMemberWorkerStore):
where_clause="forgotten = 1", where_clause="forgotten = 1",
) )
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size): def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get( target_min_stream_id = progress.get(
@ -1043,6 +973,117 @@ class RoomMemberStore(RoomMemberWorkerStore):
return row_count return row_count
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
class _JoinedHostsCache(object): class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates """Cache for joined hosts in a room that is optimised to handle updates
via state deltas. via state deltas.

View File

@ -0,0 +1,18 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We want to store large retry intervals so we upgrade the column from INT
-- to BIGINT. We don't need to do this on SQLite.
ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT;

View File

@ -0,0 +1,20 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- these tables are never used.
DROP TABLE IF EXISTS room_names;
DROP TABLE IF EXISTS topics;
DROP TABLE IF EXISTS history_visibility;
DROP TABLE IF EXISTS guest_access;

View File

@ -0,0 +1,16 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id);

View File

@ -0,0 +1,20 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
INSERT INTO background_updates (update_name, progress_json) VALUES
('redactions_received_ts', '{}');

View File

@ -0,0 +1,26 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- There was a bug where we may have updated censored redactions as bytes,
-- which can (somehow) cause json to be inserted hex encoded. This goes and
-- undoes any such hex encoded JSON.
UPDATE event_json SET json = convert_from(json::bytea, 'utf8')
WHERE event_id IN (
SELECT event_json.event_id
FROM event_json
INNER JOIN redactions ON (event_json.event_id = redacts)
WHERE have_censored AND json NOT LIKE '{%'
);

View File

@ -0,0 +1,46 @@
import logging
from synapse.storage.engines import PostgresEngine
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
select_clause = """
CREATE TEMPORARY TABLE user_filters_migration AS
SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
FROM user_filters;
"""
else:
select_clause = """
CREATE TEMPORARY TABLE user_filters_migration AS
SELECT * FROM user_filters GROUP BY user_id, filter_id;
"""
sql = (
"""
BEGIN;
%s
DROP INDEX user_filters_by_user_id_filter_id;
DELETE FROM user_filters;
ALTER TABLE user_filters
ALTER COLUMN user_id SET NOT NULL,
ALTER COLUMN filter_id SET NOT NULL,
ALTER COLUMN filter_json SET NOT NULL;
INSERT INTO user_filters(user_id, filter_id, filter_json)
SELECT * FROM user_filters_migration;
DROP TABLE user_filters_migration;
CREATE UNIQUE INDEX user_filters_by_user_id_filter_id_unique
ON user_filters(user_id, filter_id);
END;
"""
% select_clause
)
if isinstance(database_engine, PostgresEngine):
cur.execute(sql)
else:
cur.executescript(sql)
def run_create(cur, database_engine, *args, **kwargs):
pass

View File

@ -36,7 +36,7 @@ SearchEntry = namedtuple(
) )
class SearchStore(BackgroundUpdateStore): class SearchBackgroundUpdateStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@ -44,7 +44,7 @@ class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs) super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
if not hs.config.enable_search: if not hs.config.enable_search:
return return
@ -289,29 +289,6 @@ class SearchStore(BackgroundUpdateStore):
return num_rows return num_rows
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
def store_search_entries_txn(self, txn, entries): def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table """Add entries to the search table
@ -358,6 +335,34 @@ class SearchStore(BackgroundUpdateStore):
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys): def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.

View File

@ -353,8 +353,158 @@ class StateFilter(object):
return member_filter, non_member_filter return member_filter, non_member_filter
class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""Defines functions related to state groups needed to run the state backgroud
updates.
"""
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
# this inherits from EventsWorkerStore because it calls self.get_events # this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class StateGroupWorkerStore(
EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
):
"""The parts of StateGroupStore that can be called from workers. """The parts of StateGroupStore that can be called from workers.
""" """
@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results return results
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return self.runInteraction("store_state_group", _store_state_group_txn) return self.runInteraction("store_state_group", _store_state_group_txn)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long. class StateBackgroundUpdateStore(
""" StateGroupBackgroundUpdateStore, BackgroundUpdateStore
if isinstance(self.database_engine, PostgresEngine): ):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs) super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler( self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state, self._background_deduplicate_state,
@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
columns=["state_group"], columns=["state_group"],
) )
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size): def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding """This background update will slowly deduplicate state by reencoding
@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
return 1 return 1
class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)

View File

@ -332,6 +332,9 @@ class StatsStore(StateDeltasStore):
def _bulk_update_stats_delta_txn(txn): def _bulk_update_stats_delta_txn(txn):
for stats_type, stats_updates in updates.items(): for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items(): for stats_id, fields in stats_updates.items():
logger.info(
"Updating %s stats for %s: %s", stats_type, stats_id, fields
)
self._update_stats_delta_txn( self._update_stats_delta_txn(
txn, txn,
ts=ts, ts=ts,

View File

@ -32,14 +32,14 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory" TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
# How many records do we calculate before sending it to # How many records do we calculate before sending it to
# add_users_who_share_private_rooms? # add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500 SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(UserDirectoryStore, self).__init__(db_conn, hs) super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname
@ -452,55 +452,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"update_profile_in_user_dir", _update_profile_in_user_dir_txn "update_profile_in_user_dir", _update_profile_in_user_dir_txn
) )
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_share_pub = yield self._simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share_priv = yield self._simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
return user_ids
def add_users_who_share_private_room(self, room_id, user_id_tuples): def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first """Insert entries into the users_who_share_private_rooms table. The first
user should be a local user. user should be a local user.
@ -551,6 +502,98 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"add_users_in_public_rooms", _add_users_in_public_rooms_txn "add_users_in_public_rooms", _add_users_in_public_rooms_txn
) )
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, db_conn, hs):
super(UserDirectoryStore, self).__init__(db_conn, hs)
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_share_pub = yield self._simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share_priv = yield self._simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
return user_ids
def remove_user_who_share_room(self, user_id, room_id): def remove_user_who_share_room(self, user_id, room_id):
""" """
Deletes entries in the users_who_share_*_rooms table. The first Deletes entries in the users_who_share_*_rooms table. The first
@ -637,31 +680,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return [room_id for room_id, in rows] return [room_id for room_id, in rows]
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
def get_user_directory_stream_pos(self): def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="user_directory_stream_pos", table="user_directory_stream_pos",
@ -670,14 +688,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
desc="get_user_directory_stream_pos", desc="get_user_directory_stream_pos",
) )
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit): def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory """Searches for users in directory

View File

@ -318,6 +318,7 @@ class StreamToken(
) )
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
START = None # type: StreamToken
@classmethod @classmethod
def from_string(cls, string): def from_string(cls, string):
@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after. followed by the "stream_ordering" id of the event it comes after.
""" """
__slots__ = [] __slots__ = [] # type: list
@classmethod @classmethod
def parse(cls, string): def parse(cls, string):

View File

@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
from six.moves import range from six.moves import range
@ -213,7 +215,9 @@ class Linearizer(object):
# the first element is the number of things executing, and # the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the # the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing. # things blocked from executing.
self.key_to_defer = {} self.key_to_defer = (
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key): def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@ -340,10 +344,10 @@ class ReadWriteLock(object):
def __init__(self): def __init__(self):
# Latest readers queued # Latest readers queued
self.key_to_current_readers = {} self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued # Latest writer queued
self.key_to_current_writer = {} self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks @defer.inlineCallbacks
def read(self, key): def read(self, key):

View File

@ -16,6 +16,7 @@
import logging import logging
import os import os
from typing import Dict
import six import six
from six.moves import intern from six.moves import intern
@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {} caches_by_name = {}
collectors_by_name = {} collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

View File

@ -18,10 +18,12 @@ import inspect
import logging import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
from typing import Any, cast
from six import itervalues from six import itervalues
from prometheus_client import Gauge from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer from twisted.internet import defer
@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _CachedFunction(Protocol):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any
def __name__(self):
...
cache_pending_metric = Gauge( cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending", "synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache", "Number of lookups currently pending for this cache",
@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object): class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): def __init__(
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
):
self.orig = orig self.orig = orig
if inlineCallbacks: if inlineCallbacks:
@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs)) return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer) return make_deferred_yieldable(observer)
wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1: if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val) wrapped.prefill = lambda key, val: cache.prefill(key[0], val)

View File

@ -1,3 +1,5 @@
from typing import Dict
from six import itervalues from six import itervalues
SENTINEL = object() SENTINEL = object()
@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self): def __init__(self):
self.size = 0 self.size = 0
self.root = {} self.root = {} # type: Dict
def __setitem__(self, key, value): def __setitem__(self, key, value):
return self.set(key, value) return self.set(key, value)

View File

@ -60,12 +60,14 @@ in_flight = InFlightGauge(
) )
def measure_func(name): def measure_func(name=None):
def wrapper(func): def wrapper(func):
block_name = func.__name__ if name is None else name
@wraps(func) @wraps(func)
@defer.inlineCallbacks @defer.inlineCallbacks
def measured_func(self, *args, **kwargs): def measured_func(self, *args, **kwargs):
with Measure(self.clock, name): with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs) r = yield func(self, *args, **kwargs)
return r return r

View File

@ -54,5 +54,5 @@ def load_python_module(location: str):
if spec is None: if spec is None:
raise Exception("Unable to load module at %s" % (location,)) raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod) # type: ignore
return mod return mod

View File

@ -29,7 +29,7 @@ MIN_RETRY_INTERVAL = 10 * 60 * 1000
RETRY_MULTIPLIER = 5 RETRY_MULTIPLIER = 5
# a cap on the backoff. (Essentially none) # a cap on the backoff. (Essentially none)
MAX_RETRY_INTERVAL = 2 ** 63 MAX_RETRY_INTERVAL = 2 ** 62
class NotRetryingDestination(Exception): class NotRetryingDestination(Exception):

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
class FederationTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(http_client=None)
self.handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
return hs
def test_exchange_revoked_invite(self):
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
# Send a 3PID invite event with an empty body so it's considered as a revoked one.
invite_token = "sometoken"
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.ThirdPartyInvite,
state_key=invite_token,
body={},
tok=tok,
)
d = self.handler.on_exchange_third_party_invite_request(
room_id=room_id,
event_dict={
"type": EventTypes.Member,
"room_id": room_id,
"sender": user_id,
"state_key": "@someone:example.org",
"content": {
"membership": "invite",
"third_party_invite": {
"display_name": "alice",
"signed": {
"mxid": "@alice:localhost",
"token": invite_token,
"signatures": {
"magic.forest": {
"ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
}
},
},
},
},
},
)
failure = self.get_failure(d, AuthError).value
self.assertEqual(failure.code, 403, failure)
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")

View File

@ -1,39 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.handlers.room_list import RoomListNextBatch
import tests.unittest
import tests.utils
class RoomListTestCase(tests.unittest.TestCase):
""" Tests RoomList's RoomListNextBatch. """
def setUp(self):
pass
def test_check_read_batch_tokens(self):
batch_token = RoomListNextBatch(
stream_ordering="abcdef",
public_room_stream_id="123",
current_limit=20,
direction_is_forward=True,
).to_token()
next_batch = RoomListNextBatch.from_token(batch_token)
self.assertEquals(next_batch.stream_ordering, "abcdef")
self.assertEquals(next_batch.public_room_stream_id, "123")
self.assertEquals(next_batch.current_limit, 20)
self.assertEquals(next_batch.direction_is_forward, True)

View File

@ -23,8 +23,8 @@ from email.parser import Parser
import pkg_resources import pkg_resources
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import LoginType from synapse.api.constants import LoginType, Membership
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register from synapse.rest.client.v2_alpha import account, register
from tests import unittest from tests import unittest
@ -244,16 +244,66 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets, login.register_servlets,
account.register_servlets, account.register_servlets,
room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
return hs return self.hs
def test_deactivate_account(self): def test_deactivate_account(self):
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
self.deactivate(user_id, tok)
store = self.hs.get_datastore()
# Check that the user has been marked as deactivated.
self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
# Check that this access token has been invalidated.
request, channel = self.make_request("GET", "account/whoami")
self.render(request)
self.assertEqual(request.code, 401)
@unittest.INFO
def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastore()
inviter_id = self.register_user("inviter", "test")
inviter_tok = self.login("inviter", "test")
invitee_id = self.register_user("invitee", "test")
invitee_tok = self.login("invitee", "test")
# Make @inviter:test invite @invitee:test in a new room.
room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
self.helper.invite(
room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
)
# Make sure the invite is here.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
# Deactivate @invitee:test.
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id, tok):
request_data = json.dumps( request_data = json.dumps(
{ {
"auth": { "auth": {
@ -269,13 +319,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
) )
self.render(request) self.render(request)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
store = self.hs.get_datastore()
# Check that the user has been marked as deactivated.
self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
# Check that this access token has been invalidated.
request, channel = self.make_request("GET", "account/whoami")
self.render(request)
self.assertEqual(request.code, 401)

View File

@ -118,6 +118,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.persist_event(event, context)) self.get_success(self.store.persist_event(event, context))
return event
def test_redact(self): def test_redact(self):
self.get_success( self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@ -361,3 +363,37 @@ class RedactionTestCase(unittest.HomeserverTestCase):
) )
self.assert_dict({"content": {}}, json.loads(event_json)) self.assert_dict({"content": {}}, json.loads(event_json))
def test_redact_redaction(self):
"""Tests that we can redact a redaction and can fetch it again.
"""
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
first_redact_event = self.get_success(
self.inject_redaction(
self.room1, msg_event.event_id, self.u_alice, "Redacting message"
)
)
self.get_success(
self.inject_redaction(
self.room1,
first_redact_event.event_id,
self.u_alice,
"Redacting redaction",
)
)
# Now lets jump to the future where we have censored the redaction event
# in the DB.
self.reactor.advance(60 * 60 * 24 * 31)
# We just want to check that fetching the event doesn't raise an exception.
self.get_success(
self.store.get_event(first_redact_event.event_id, allow_none=True)
)

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.retryutils import MAX_RETRY_INTERVAL
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -45,3 +47,12 @@ class TransactionStoreTestCase(HomeserverTestCase):
""" """
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d) self.get_success(d)
def test_large_destination_retry(self):
d = self.store.set_destination_retry_timings(
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
)
self.get_success(d)
d = self.store.get_destination_retry_timings("example.com")
self.get_success(d)