Merge branch 'develop' of github.com:matrix-org/synapse into neilj/mau_sign_in_log_in_limits

Neil Johnson 2018-08-01 13:42:50 +01:00
commit 303f1c851f
27 changed files with 5194 additions and 199 deletions

Dockerfile
@@ -1,16 +1,32 @@
 FROM docker.io/python:2-alpine3.7
-RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev libxslt-dev
+RUN apk add --no-cache --virtual .nacl_deps \
+        build-base \
+        libffi-dev \
+        libjpeg-turbo-dev \
+        libressl-dev \
+        libxslt-dev \
+        linux-headers \
+        postgresql-dev \
+        su-exec \
+        zlib-dev
 
 COPY . /synapse
 
 # A wheel cache may be provided in ./cache for faster build
 RUN cd /synapse \
-    && pip install --upgrade pip setuptools psycopg2 lxml \
+    && pip install --upgrade \
+        lxml \
+        pip \
+        psycopg2 \
+        setuptools \
     && mkdir -p /synapse/cache \
     && pip install -f /synapse/cache --upgrade --process-dependency-links . \
     && mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
-    && rm -rf setup.py setup.cfg synapse
+    && rm -rf \
+        setup.cfg \
+        setup.py \
+        synapse
 
 VOLUME ["/data"]

changelog.d/3384.misc Normal file
@@ -0,0 +1 @@
+Rewrite cache list decorator

changelog.d/3543.misc Normal file
@@ -0,0 +1 @@
+Improve Dockerfile and docker-compose instructions

changelog.d/3612.misc Normal file
@@ -0,0 +1 @@
+Make EventStore inherit from EventFederationStore

changelog.d/3626.bugfix Normal file
@@ -0,0 +1 @@
+Only import secrets when available (fix for py < 3.6)

changelog.d/3628.misc Normal file
@@ -0,0 +1 @@
+Remove unused field "pdu_failures" from transactions.

changelog.d/3634.misc Normal file
@@ -0,0 +1 @@
+rename replication_layer to federation_client

contrib/docker/README.md
@@ -9,13 +9,7 @@ use that server.
 ## Build
 
-Build the docker image with the `docker build` command from the root of the synapse repository.
-
-```
-docker build -t docker.io/matrixdotorg/synapse .
-```
-
-The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
+Build the docker image with the `docker-compose build` command.
 
 You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.

contrib/docker/docker-compose.yml
@@ -6,6 +6,7 @@ version: '3'
 services:
   synapse:
+    build: ../..
     image: docker.io/matrixdotorg/synapse:latest
     # Since snyapse does not retry to connect to the database, restart upon
     # failure

contrib/grafana/README.md Normal file
@@ -0,0 +1,6 @@
+# Using the Synapse Grafana dashboard
+
+0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
+1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst
+2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
+3. Set up additional recording rules
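
As a quick sanity check for step 1 of the new readme, it can help to confirm the metrics endpoint responds before pointing Prometheus at it. A minimal sketch, assuming a metrics listener has been enabled as described in metrics-howto.rst; the host and port below are placeholders for your own configuration:

```python
# Sanity-check that Synapse's metrics endpoint is scrapeable.
# The URL is a placeholder: adjust host/port to your metrics listener.
from urllib.request import urlopen

METRICS_URL = "http://localhost:9092/_synapse/metrics"

body = urlopen(METRICS_URL, timeout=10).read().decode("utf-8")

# Prometheus text exposition format: "# HELP"/"# TYPE" comments plus
# one "name{labels} value" sample per line.
samples = [l for l in body.splitlines() if l and not l.startswith("#")]
print("got %d samples" % len(samples))
```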

contrib/grafana/synapse.json Normal file (4961 lines)
(File diff suppressed because it is too large)

synapse/federation/federation_server.py
@@ -207,10 +207,6 @@ class FederationServer(FederationBase):
                     edu.content
                 )
 
-        pdu_failures = getattr(transaction, "pdu_failures", [])
-        for fail in pdu_failures:
-            logger.info("Got failure %r", fail)
-
         response = {
             "pdus": pdu_results,
         }

synapse/federation/send_queue.py
@@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
         self.edus = SortedDict()  # stream position -> Edu
 
-        self.failures = SortedDict()  # stream position -> (destination, Failure)
-
         self.device_messages = SortedDict()  # stream position -> destination
 
         self.pos = 1
@@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
         for queue_name in [
             "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
-            "edus", "failures", "device_messages", "pos_time",
+            "edus", "device_messages", "pos_time",
         ]:
             register(queue_name, getattr(self, queue_name))
@@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
             for key in keys[:i]:
                 del self.edus[key]
 
-            # Delete things out of failure map
-            keys = self.failures.keys()
-            i = self.failures.bisect_left(position_to_delete)
-            for key in keys[:i]:
-                del self.failures[key]
-
             # Delete things out of device map
             keys = self.device_messages.keys()
             i = self.device_messages.bisect_left(position_to_delete)
@@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
         self.notifier.on_new_replication_data()
 
-    def send_failure(self, failure, destination):
-        """As per TransactionQueue"""
-        pos = self._next_pos()
-
-        self.failures[pos] = (destination, str(failure))
-        self.notifier.on_new_replication_data()
-
     def send_device_messages(self, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
@@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
         for (pos, edu) in edus:
             rows.append((pos, EduRow(edu)))
 
-        # Fetch changed failures
-        i = self.failures.bisect_right(from_token)
-        j = self.failures.bisect_right(to_token) + 1
-        failures = self.failures.items()[i:j]
-        for (pos, (destination, failure)) in failures:
-            rows.append((pos, FailureRow(
-                destination=destination,
-                failure=failure,
-            )))
-
         # Fetch changed device messages
         i = self.device_messages.bisect_right(from_token)
         j = self.device_messages.bisect_right(to_token) + 1
@@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
-class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
-    "destination",  # str
-    "failure",
-))):
-    """Streams failures to a remote server. Failures are issued when there was
-    something wrong with a transaction the remote sent us, e.g. it included
-    an event that was invalid.
-    """
-
-    TypeId = "f"
-
-    @staticmethod
-    def from_data(data):
-        return FailureRow(
-            destination=data["destination"],
-            failure=data["failure"],
-        )
-
-    def to_data(self):
-        return {
-            "destination": self.destination,
-            "failure": self.failure,
-        }
-
-    def add_to_buffer(self, buff):
-        buff.failures.setdefault(self.destination, []).append(self.failure)
-
-
 class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
     "destination",  # str
 ))):
@@ -471,7 +417,6 @@ TypeToRow = {
         PresenceRow,
         KeyedEduRow,
         EduRow,
-        FailureRow,
         DeviceRow,
     )
 }
@@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
     "presence",  # list(UserPresenceState)
     "keyed_edus",  # dict of destination -> { key -> Edu }
     "edus",  # dict of destination -> [Edu]
-    "failures",  # dict of destination -> [failures]
     "device_destinations",  # set of destinations
 ))
@@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
         presence=[],
         keyed_edus={},
         edus={},
-        failures={},
         device_destinations=set(),
     )
@@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
                 edu.destination, edu.edu_type, edu.content, key=None,
             )
 
-    for destination, failure_list in iteritems(buff.failures):
-        for failure in failure_list:
-            transaction_queue.send_failure(destination, failure)
-
     for destination in buff.device_destinations:
         transaction_queue.send_device_messages(destination)
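
For reference, the removed FailureRow follows the standard federation stream-row shape: a one-character TypeId plus symmetric from_data/to_data (de)serialisers. A self-contained sketch of that shape; DemoRow is invented for illustration and is not part of the tree:

```python
from collections import namedtuple


class DemoRow(namedtuple("DemoRow", ("destination", "payload"))):
    TypeId = "x"  # one-character type marker used on the wire

    @staticmethod
    def from_data(data):
        # inverse of to_data: rebuild the row from its JSON-ish dict
        return DemoRow(destination=data["destination"], payload=data["payload"])

    def to_data(self):
        return {"destination": self.destination, "payload": self.payload}


row = DemoRow(destination="remote.example.com", payload={"hello": 1})
assert DemoRow.from_data(row.to_data()) == row
```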

synapse/federation/transaction_queue.py
@@ -116,9 +116,6 @@ class TransactionQueue(object):
             ),
         )
 
-        # destination -> list of tuple(failure, deferred)
-        self.pending_failures_by_dest = {}
-
         # destination -> stream_id of last successfully sent to-device message.
         # NB: may be a long or an int.
         self.last_device_stream_id_by_dest = {}
@@ -382,19 +379,6 @@ class TransactionQueue(object):
         self._attempt_new_transaction(destination)
 
-    def send_failure(self, failure, destination):
-        if destination == self.server_name or destination == "localhost":
-            return
-
-        if not self.can_send_to(destination):
-            return
-
-        self.pending_failures_by_dest.setdefault(
-            destination, []
-        ).append(failure)
-
-        self._attempt_new_transaction(destination)
-
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
             return
@@ -469,7 +453,6 @@ class TransactionQueue(object):
             pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
             pending_edus = self.pending_edus_by_dest.pop(destination, [])
             pending_presence = self.pending_presence_by_dest.pop(destination, {})
-            pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
             pending_edus.extend(
                 self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@@ -497,7 +480,7 @@ class TransactionQueue(object):
             logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
                          destination, len(pending_pdus))
 
-            if not pending_pdus and not pending_edus and not pending_failures:
+            if not pending_pdus and not pending_edus:
                 logger.debug("TX [%s] Nothing to send", destination)
                 self.last_device_stream_id_by_dest[destination] = (
                     device_stream_id
@@ -507,7 +490,7 @@ class TransactionQueue(object):
             # END CRITICAL SECTION
 
             success = yield self._send_new_transaction(
-                destination, pending_pdus, pending_edus, pending_failures,
+                destination, pending_pdus, pending_edus,
             )
             if success:
                 sent_transactions_counter.inc()
@@ -584,14 +567,12 @@ class TransactionQueue(object):
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
-    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
+    def _send_new_transaction(self, destination, pending_pdus, pending_edus):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
         pdus = [x[0] for x in pending_pdus]
         edus = pending_edus
-        failures = [x.get_dict() for x in pending_failures]
 
         success = True
@@ -601,11 +582,10 @@ class TransactionQueue(object):
             logger.debug(
                 "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
+                " (pdus: %d, edus: %d)",
                 destination, txn_id,
                 len(pdus),
                 len(edus),
-                len(failures)
             )
 
             logger.debug("TX [%s] Persisting transaction...", destination)
@@ -617,7 +597,6 @@ class TransactionQueue(object):
                 destination=destination,
                 pdus=pdus,
                 edus=edus,
-                pdu_failures=failures,
             )
 
             self._next_txn_id += 1
@@ -627,12 +606,11 @@ class TransactionQueue(object):
             logger.debug("TX [%s] Persisted transaction", destination)
             logger.info(
                 "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
+                " (PDUs: %d, EDUs: %d)",
                 destination, txn_id,
                 transaction.transaction_id,
                 len(pdus),
                 len(edus),
-                len(failures),
             )
 
             # Actually send the transaction

synapse/federation/transport/server.py
@@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
             )
 
             logger.info(
-                "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
                 transaction_id, origin,
                 len(transaction_data.get("pdus", [])),
                 len(transaction_data.get("edus", [])),
-                len(transaction_data.get("failures", [])),
             )
 
             # We should ideally be getting this from the security layer.

synapse/federation/units.py
@@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
         "previous_ids",
         "pdus",
         "edus",
-        "pdu_failures",
     ]
 
     internal_keys = [

synapse/handlers/federation.py
@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
         self.hs = hs
 
         self.store = hs.get_datastore()
-        self.replication_layer = hs.get_federation_client()
+        self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
@@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
                 # know about
                 for p in prevs - seen:
                     state, got_auth_chain = (
-                        yield self.replication_layer.get_state_for_room(
+                        yield self.federation_client.get_state_for_room(
                             origin, pdu.room_id, p
                         )
                     )
@@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
                 #
                 # see https://github.com/matrix-org/synapse/pull/1744
 
-                missing_events = yield self.replication_layer.get_missing_events(
+                missing_events = yield self.federation_client.get_missing_events(
                     origin,
                     pdu.room_id,
                     earliest_events_ids=list(latest),
@@ -522,7 +522,7 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        events = yield self.replication_layer.backfill(
+        events = yield self.federation_client.backfill(
            dest,
            room_id,
            limit=limit,
@@ -570,7 +570,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.replication_layer.get_state_for_room(
+            state, auth = yield self.federation_client.get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id
@@ -612,7 +612,7 @@ class FederationHandler(BaseHandler):
         results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 logcontext.run_in_background(
-                    self.replication_layer.get_pdu,
+                    self.federation_client.get_pdu,
                     [dest],
                     event_id,
                     outlier=True,
@@ -893,7 +893,7 @@ class FederationHandler(BaseHandler):
         Invites must be signed by the invitee's server before distribution.
         """
 
-        pdu = yield self.replication_layer.send_invite(
+        pdu = yield self.federation_client.send_invite(
             destination=target_host,
             room_id=event.room_id,
             event_id=event.event_id,
@@ -955,7 +955,7 @@ class FederationHandler(BaseHandler):
             target_hosts.insert(0, origin)
         except ValueError:
             pass
 
-        ret = yield self.replication_layer.send_join(target_hosts, event)
+        ret = yield self.federation_client.send_join(target_hosts, event)
 
         origin = ret["origin"]
         state = ret["state"]
@@ -1211,7 +1211,7 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.replication_layer.send_leave(
+        yield self.federation_client.send_leave(
             target_hosts,
             event
         )
@@ -1234,7 +1234,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
                                content={},):
-        origin, pdu = yield self.replication_layer.make_membership_event(
+        origin, pdu = yield self.federation_client.make_membership_event(
            target_hosts,
            room_id,
            user_id,
@@ -1567,7 +1567,7 @@ class FederationHandler(BaseHandler):
                     missing_auth_events.add(e_id)
 
                 for e_id in missing_auth_events:
-                    m_ev = yield self.replication_layer.get_pdu(
+                    m_ev = yield self.federation_client.get_pdu(
                         [origin],
                         e_id,
                         outlier=True,
@@ -1777,7 +1777,7 @@ class FederationHandler(BaseHandler):
             logger.info("Missing auth: %s", missing_auth)
             # If we don't have all the auth events, we need to get them.
             try:
-                remote_auth_chain = yield self.replication_layer.get_event_auth(
+                remote_auth_chain = yield self.federation_client.get_event_auth(
                     origin, event.room_id, event.event_id
                 )
@@ -1893,7 +1893,7 @@ class FederationHandler(BaseHandler):
             try:
                 # 2. Get remote difference.
-                result = yield self.replication_layer.query_auth(
+                result = yield self.federation_client.query_auth(
                     origin,
                     event.room_id,
                     event.event_id,
@@ -2192,7 +2192,7 @@ class FederationHandler(BaseHandler):
             yield member_handler.send_membership_event(None, event, context)
         else:
             destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
-            yield self.replication_layer.forward_third_party_invite(
+            yield self.federation_client.forward_third_party_invite(
                 destinations,
                 room_id,
                 event_dict,

synapse/secrets.py
@@ -20,17 +20,16 @@ See https://docs.python.org/3/library/secrets.html#module-secrets for the API
 used in Python 3.6, and the API emulated in Python 2.7.
 """
 
-import six
+import sys
 
-if six.PY3:
+# secrets is available since python 3.6
+if sys.version_info[0:2] >= (3, 6):
     import secrets
 
     def Secrets():
         return secrets
 
 else:
     import os
     import binascii
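
The same version-gating pattern in standalone form, for illustration; the `token_hex` helper below is this sketch's own, not Synapse's API:

```python
import sys

if sys.version_info[0:2] >= (3, 6):
    import secrets

    def token_hex(nbytes=32):
        # delegate to the stdlib implementation on 3.6+
        return secrets.token_hex(nbytes)
else:
    import binascii
    import os

    def token_hex(nbytes=32):
        # emulate the stdlib helper with os.urandom()
        return binascii.hexlify(os.urandom(nbytes)).decode("ascii")

print(token_hex(16))  # 32 hex characters of randomness
```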

synapse/storage/__init__.py
@@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 PresenceStore, TransactionStore,
                 DirectoryStore, KeyStore, StateStore, SignatureStore,
                 ApplicationServiceStore,
+                EventsStore,
                 EventFederationStore,
                 MediaRepositoryStore,
                 RejectionsStore,
@@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
                 PusherStore,
                 PushRuleStore,
                 ApplicationServiceTransactionStore,
-                EventsStore,
                 ReceiptsStore,
                 EndToEndKeyStore,
                 SearchStore,

synapse/storage/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
 
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 
 from ._base import SQLBaseStore

synapse/storage/event_federation.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.storage.signatures import SignatureWorkerStore
 from synapse.util.caches.descriptors import cached

synapse/storage/events.py
@@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util.async import ObservableDeferred
@@ -193,7 +195,9 @@ def _retry_on_integrity_error(func):
     return f
 
 
-class EventsStore(EventsWorkerStore):
+# inherits from EventFederationStore so that we can call _update_backward_extremities
+# and _handle_mult_prev_events (though arguably those could both be moved in here)
+class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
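
The new base-class list relies on Python's C3 method resolution order to decide which implementation runs when bases overlap. A toy illustration of the linearisation; the class names here are invented, not Synapse's:

```python
class Worker(object):
    def desc(self):
        return "worker"


class Federation(Worker):
    def desc(self):
        # cooperatively extend the next class in the MRO
        return "federation+" + super(Federation, self).desc()


class Events(Federation, Worker):
    pass


# C3 linearisation: Events -> Federation -> Worker -> object
print([c.__name__ for c in Events.mro()])
print(Events().desc())  # "federation+worker"
```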

synapse/storage/roommember.py
@@ -24,7 +24,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import get_domain_from_id
 from synapse.util.async import Linearizer
 from synapse.util.caches import intern_string

synapse/storage/stream.py
@@ -43,7 +43,7 @@ from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background

synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            # If we're passed a cache_context then we'll want to call its invalidate()
-            # whenever we are invalidated
+            # If we're passed a cache_context then we'll want to call its
+            # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
             results = {}
-            cached_defers = {}
-            missing = []
+
+            def update_results_dict(res, arg):
+                results[arg] = res
+
+            # list of deferreds to wait for
+            cached_defers = []
+
+            missing = set()
 
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
-                def cache_get(arg):
-                    return cache.get(arg, callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    return arg
             else:
-                key = list(keyargs)
+                keylist = list(keyargs)
 
-                def cache_get(arg):
-                    key[self.list_pos] = arg
-                    return cache.get(tuple(key), callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    keylist[self.list_pos] = arg
+                    return tuple(keylist)
 
             for arg in list_args:
                 try:
-                    res = cache_get(arg)
+                    res = cache.get(arg_to_cache_key(arg),
+                                    callback=invalidate_callback)
 
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
                         res = res.observe()
-                        res.addCallback(lambda r, arg: (arg, r), arg)
-                        cached_defers[arg] = res
+                        res.addCallback(update_results_dict, arg)
+                        cached_defers.append(res)
                     else:
                         results[arg] = res.get_result()
                 except KeyError:
-                    missing.append(arg)
+                    missing.add(arg)
 
             if missing:
-                args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                # we need an observable deferred for each entry in the list,
+                # which we put in the cache. Each deferred resolves with the
+                # relevant result for that key.
+                deferreds_map = {}
+                for arg in missing:
+                    deferred = defer.Deferred()
+                    deferreds_map[arg] = deferred
+                    key = arg_to_cache_key(arg)
+                    observable = ObservableDeferred(deferred)
+                    cache.set(key, observable, callback=invalidate_callback)
 
-                ret_d = defer.maybeDeferred(
+                def complete_all(res):
+                    # the wrapped function has completed. It returns a
+                    # a dict. We can now resolve the observable deferreds in
+                    # the cache and update our own result map.
+                    for e in missing:
+                        val = res.get(e, None)
+                        deferreds_map[e].callback(val)
+                        results[e] = val
+
+                def errback(f):
+                    # the wrapped function has failed. Invalidate any cache
+                    # entries we're supposed to be populating, and fail
+                    # their deferreds.
+                    for e in missing:
+                        key = arg_to_cache_key(e)
+                        cache.invalidate(key)
+                        deferreds_map[e].errback(f)
+
+                    # return the failure, to propagate to our caller.
+                    return f
+
+                args_to_call = dict(arg_dict)
+                args_to_call[self.list_name] = list(missing)
+
+                cached_defers.append(defer.maybeDeferred(
                     logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    if num_args == 1:
-                        cache.set(
-                            arg, observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, arg)
-                    else:
-                        key = list(keyargs)
-                        key[self.list_pos] = arg
-                        cache.set(
-                            tuple(key), observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached_defers[arg] = res
+                ).addCallbacks(complete_all, errback))
 
             if cached_defers:
-                def update_results_dict(res):
-                    results.update(res)
-                    return results
-
-                return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    list(cached_defers.values()),
+                d = defer.gatherResults(
+                    cached_defers,
                     consumeErrors=True,
-                ).addCallback(update_results_dict).addErrback(
+                ).addCallbacks(
+                    lambda _: results,
                     unwrapFirstError
-                ))
+                )
+                return logcontext.make_deferred_yieldable(d)
             else:
                 return results
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
     cache.
 
     Args:
-        cache (Cache): The underlying cache to use.
+        cached_method_name (str): The name of the single-item lookup method.
+            This is only used to find the cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache
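
As orientation for the rewritten decorator, a typical `cached`/`cachedList` pairing looks like the sketch below, modelled on the test fixtures later in this diff; the store class and its return values are illustrative, not code from the tree:

```python
from twisted.internet import defer

from synapse.util.caches.descriptors import cached, cachedList


class ThingStore(object):  # illustrative store, not from the tree
    @cached()
    def get_thing(self, thing_id):
        # Single-item lookup. Its cache is what cachedList consults and
        # populates, so a direct implementation is optional here.
        raise NotImplementedError()

    @cachedList("get_thing", "thing_ids", inlineCallbacks=True)
    def get_things(self, thing_ids):
        # One batch fetch for every id that missed the cache; must return
        # a dict of thing_id -> value (ids absent from it are cached as None).
        yield defer.succeed(None)  # stand-in for the real async lookup
        defer.returnValue({tid: "value-%s" % (tid,) for tid in thing_ids})
```

On a repeat call, ids already in the cache are returned directly and only the remainder is passed to the batch function, as the new tests below verify.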

tests/handlers/test_typing.py
@@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
                 "content": content,
             }
         ],
-        "pdu_failures": [],
     }

tests/util/caches/test_descriptors.py
@@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 3)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def test_cache(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                defer.returnValue(self.mock(args1, arg2))
+
+        with logcontext.LoggingContext() as c1:
+            c1.request = "c1"
+            obj = Cls()
+            obj.mock.return_value = {10: 'fish', 20: 'chips'}
+            d1 = obj.list_fn([10, 20], 2)
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                logcontext.LoggingContext.sentinel,
+            )
+            r = yield d1
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                c1
+            )
+            obj.mock.assert_called_once_with([10, 20], 2)
+            self.assertEqual(r, {10: 'fish', 20: 'chips'})
+            obj.mock.reset_mock()
+
+            # a call with different params should call the mock again
+            obj.mock.return_value = {30: 'peas'}
+            r = yield obj.list_fn([20, 30], 2)
+            obj.mock.assert_called_once_with([30], 2)
+            self.assertEqual(r, {20: 'chips', 30: 'peas'})
+            obj.mock.reset_mock()
+
+            # all the values should now be cached
+            r = yield obj.fn(10, 2)
+            self.assertEqual(r, 'fish')
+            r = yield obj.fn(20, 2)
+            self.assertEqual(r, 'chips')
+            r = yield obj.fn(30, 2)
+            self.assertEqual(r, 'peas')
+            r = yield obj.list_fn([10, 20, 30], 2)
+            obj.mock.assert_not_called()
+            self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        """Make sure that invalidation callbacks are called."""
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                defer.returnValue(self.mock(args1, arg2))
+
+        obj = Cls()
+        invalidate0 = mock.Mock()
+        invalidate1 = mock.Mock()
+
+        # cache miss
+        obj.mock.return_value = {10: 'fish', 20: 'chips'}
+        r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+        obj.mock.assert_called_once_with([10, 20], 2)
+        self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+        obj.mock.reset_mock()
+
+        # cache hit
+        r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+        obj.mock.assert_not_called()
+        self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+        invalidate0.assert_not_called()
+        invalidate1.assert_not_called()
+
+        # now if we invalidate the keys, both invalidations should get called
+        obj.fn.invalidate((10, 2))
+        invalidate0.assert_called_once()
+        invalidate1.assert_called_once()