mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-04 22:10:52 -05:00
When joining a remote room limit the number of events we concurrently check signatures/hashes for (#10117)
If we do hundreds of thousands at once the memory overhead can easily reach 500+ MB.
This commit is contained in:
parent
a0101fc021
commit
c842c581ed
1
changelog.d/10117.feature
Normal file
1
changelog.d/10117.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Significantly reduce memory usage of joining large remote rooms.
|
@ -233,41 +233,19 @@ class Keyring:
|
|||||||
for server_name, json_object, validity_time in server_and_json
|
for server_name, json_object, validity_time in server_and_json
|
||||||
]
|
]
|
||||||
|
|
||||||
def verify_events_for_server(
|
async def verify_event_for_server(
|
||||||
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
|
self,
|
||||||
) -> List[defer.Deferred]:
|
server_name: str,
|
||||||
"""Bulk verification of signatures on events.
|
event: EventBase,
|
||||||
|
validity_time: int,
|
||||||
Args:
|
) -> None:
|
||||||
server_and_events:
|
await self.process_request(
|
||||||
Iterable of `(server_name, event, validity_time)` tuples.
|
VerifyJsonRequest.from_event(
|
||||||
|
server_name,
|
||||||
`server_name` is which server we are verifying the signature for
|
event,
|
||||||
on the event.
|
validity_time,
|
||||||
|
|
||||||
`event` is the event that we'll verify the signatures of for
|
|
||||||
the given `server_name`.
|
|
||||||
|
|
||||||
`validity_time` is a timestamp at which the signing key must be
|
|
||||||
valid.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
|
||||||
or failure to verify each event's signature for the given
|
|
||||||
server_name. The deferreds run their callbacks in the sentinel
|
|
||||||
logcontext.
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
run_in_background(
|
|
||||||
self.process_request,
|
|
||||||
VerifyJsonRequest.from_event(
|
|
||||||
server_name,
|
|
||||||
event,
|
|
||||||
validity_time,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for server_name, event, validity_time in server_and_events
|
)
|
||||||
]
|
|
||||||
|
|
||||||
async def process_request(self, verify_request: VerifyJsonRequest) -> None:
|
async def process_request(self, verify_request: VerifyJsonRequest) -> None:
|
||||||
"""Processes the `VerifyJsonRequest`. Raises if the object is not signed
|
"""Processes the `VerifyJsonRequest`. Raises if the object is not signed
|
||||||
|
@ -14,11 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Iterable, List
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.defer import Deferred, DeferredList
|
|
||||||
from twisted.python.failure import Failure
|
|
||||||
|
|
||||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
@ -28,11 +23,6 @@ from synapse.crypto.keyring import Keyring
|
|||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.utils import prune_event, validate_canonicaljson
|
from synapse.events.utils import prune_event, validate_canonicaljson
|
||||||
from synapse.http.servlet import assert_params_in_dict
|
from synapse.http.servlet import assert_params_in_dict
|
||||||
from synapse.logging.context import (
|
|
||||||
PreserveLoggingContext,
|
|
||||||
current_context,
|
|
||||||
make_deferred_yieldable,
|
|
||||||
)
|
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -48,112 +38,82 @@ class FederationBase:
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
|
|
||||||
def _check_sigs_and_hash(
|
async def _check_sigs_and_hash(
|
||||||
self, room_version: RoomVersion, pdu: EventBase
|
self, room_version: RoomVersion, pdu: EventBase
|
||||||
) -> Deferred:
|
) -> EventBase:
|
||||||
return make_deferred_yieldable(
|
"""Checks that event is correctly signed by the sending server.
|
||||||
self._check_sigs_and_hashes(room_version, [pdu])[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_sigs_and_hashes(
|
|
||||||
self, room_version: RoomVersion, pdus: List[EventBase]
|
|
||||||
) -> List[Deferred]:
|
|
||||||
"""Checks that each of the received events is correctly signed by the
|
|
||||||
sending server.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_version: The room version of the PDUs
|
room_version: The room version of the PDU
|
||||||
pdus: the events to be checked
|
pdu: the event to be checked
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
For each input event, a deferred which:
|
* the original event if the checks pass
|
||||||
* returns the original event if the checks pass
|
* a redacted version of the event (if the signature
|
||||||
* returns a redacted version of the event (if the signature
|
|
||||||
matched but the hash did not)
|
matched but the hash did not)
|
||||||
* throws a SynapseError if the signature check failed.
|
* throws a SynapseError if the signature check failed."""
|
||||||
The deferreds run their callbacks in the sentinel
|
try:
|
||||||
"""
|
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
|
||||||
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
|
except SynapseError as e:
|
||||||
|
logger.warning(
|
||||||
ctx = current_context()
|
"Signature check failed for %s: %s",
|
||||||
|
pdu.event_id,
|
||||||
@defer.inlineCallbacks
|
e,
|
||||||
def callback(_, pdu: EventBase):
|
|
||||||
with PreserveLoggingContext(ctx):
|
|
||||||
if not check_event_content_hash(pdu):
|
|
||||||
# let's try to distinguish between failures because the event was
|
|
||||||
# redacted (which are somewhat expected) vs actual ball-tampering
|
|
||||||
# incidents.
|
|
||||||
#
|
|
||||||
# This is just a heuristic, so we just assume that if the keys are
|
|
||||||
# about the same between the redacted and received events, then the
|
|
||||||
# received event was probably a redacted copy (but we then use our
|
|
||||||
# *actual* redacted copy to be on the safe side.)
|
|
||||||
redacted_event = prune_event(pdu)
|
|
||||||
if set(redacted_event.keys()) == set(pdu.keys()) and set(
|
|
||||||
redacted_event.content.keys()
|
|
||||||
) == set(pdu.content.keys()):
|
|
||||||
logger.info(
|
|
||||||
"Event %s seems to have been redacted; using our redacted "
|
|
||||||
"copy",
|
|
||||||
pdu.event_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Event %s content has been tampered, redacting",
|
|
||||||
pdu.event_id,
|
|
||||||
)
|
|
||||||
return redacted_event
|
|
||||||
|
|
||||||
result = yield defer.ensureDeferred(
|
|
||||||
self.spam_checker.check_event_for_spam(pdu)
|
|
||||||
)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
logger.warning(
|
|
||||||
"Event contains spam, redacting %s: %s",
|
|
||||||
pdu.event_id,
|
|
||||||
pdu.get_pdu_json(),
|
|
||||||
)
|
|
||||||
return prune_event(pdu)
|
|
||||||
|
|
||||||
return pdu
|
|
||||||
|
|
||||||
def errback(failure: Failure, pdu: EventBase):
|
|
||||||
failure.trap(SynapseError)
|
|
||||||
with PreserveLoggingContext(ctx):
|
|
||||||
logger.warning(
|
|
||||||
"Signature check failed for %s: %s",
|
|
||||||
pdu.event_id,
|
|
||||||
failure.getErrorMessage(),
|
|
||||||
)
|
|
||||||
return failure
|
|
||||||
|
|
||||||
for deferred, pdu in zip(deferreds, pdus):
|
|
||||||
deferred.addCallbacks(
|
|
||||||
callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
|
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return deferreds
|
if not check_event_content_hash(pdu):
|
||||||
|
# let's try to distinguish between failures because the event was
|
||||||
|
# redacted (which are somewhat expected) vs actual ball-tampering
|
||||||
|
# incidents.
|
||||||
|
#
|
||||||
|
# This is just a heuristic, so we just assume that if the keys are
|
||||||
|
# about the same between the redacted and received events, then the
|
||||||
|
# received event was probably a redacted copy (but we then use our
|
||||||
|
# *actual* redacted copy to be on the safe side.)
|
||||||
|
redacted_event = prune_event(pdu)
|
||||||
|
if set(redacted_event.keys()) == set(pdu.keys()) and set(
|
||||||
|
redacted_event.content.keys()
|
||||||
|
) == set(pdu.content.keys()):
|
||||||
|
logger.info(
|
||||||
|
"Event %s seems to have been redacted; using our redacted copy",
|
||||||
|
pdu.event_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Event %s content has been tampered, redacting",
|
||||||
|
pdu.event_id,
|
||||||
|
)
|
||||||
|
return redacted_event
|
||||||
|
|
||||||
|
result = await self.spam_checker.check_event_for_spam(pdu)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
logger.warning(
|
||||||
|
"Event contains spam, redacting %s: %s",
|
||||||
|
pdu.event_id,
|
||||||
|
pdu.get_pdu_json(),
|
||||||
|
)
|
||||||
|
return prune_event(pdu)
|
||||||
|
|
||||||
|
return pdu
|
||||||
|
|
||||||
|
|
||||||
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
|
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _check_sigs_on_pdus(
|
async def _check_sigs_on_pdu(
|
||||||
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
|
keyring: Keyring, room_version: RoomVersion, pdu: EventBase
|
||||||
) -> List[Deferred]:
|
) -> None:
|
||||||
"""Check that the given events are correctly signed
|
"""Check that the given events are correctly signed
|
||||||
|
|
||||||
|
Raise a SynapseError if the event wasn't correctly signed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
keyring: keyring object to do the checks
|
keyring: keyring object to do the checks
|
||||||
room_version: the room version of the PDUs
|
room_version: the room version of the PDUs
|
||||||
pdus: the events to be checked
|
pdus: the events to be checked
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Deferred for each event in pdus, which will either succeed if
|
|
||||||
the signatures are valid, or fail (with a SynapseError) if not.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# we want to check that the event is signed by:
|
# we want to check that the event is signed by:
|
||||||
@ -177,90 +137,47 @@ def _check_sigs_on_pdus(
|
|||||||
# let's start by getting the domain for each pdu, and flattening the event back
|
# let's start by getting the domain for each pdu, and flattening the event back
|
||||||
# to JSON.
|
# to JSON.
|
||||||
|
|
||||||
pdus_to_check = [
|
|
||||||
PduToCheckSig(
|
|
||||||
pdu=p,
|
|
||||||
sender_domain=get_domain_from_id(p.sender),
|
|
||||||
deferreds=[],
|
|
||||||
)
|
|
||||||
for p in pdus
|
|
||||||
]
|
|
||||||
|
|
||||||
# First we check that the sender event is signed by the sender's domain
|
# First we check that the sender event is signed by the sender's domain
|
||||||
# (except if its a 3pid invite, in which case it may be sent by any server)
|
# (except if its a 3pid invite, in which case it may be sent by any server)
|
||||||
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
|
if not _is_invite_via_3pid(pdu):
|
||||||
|
try:
|
||||||
more_deferreds = keyring.verify_events_for_server(
|
await keyring.verify_event_for_server(
|
||||||
[
|
get_domain_from_id(pdu.sender),
|
||||||
(
|
pdu,
|
||||||
p.sender_domain,
|
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||||
p.pdu,
|
|
||||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
|
||||||
)
|
)
|
||||||
for p in pdus_to_check_sender
|
except Exception as e:
|
||||||
]
|
errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
|
||||||
)
|
pdu.event_id,
|
||||||
|
get_domain_from_id(pdu.sender),
|
||||||
def sender_err(e, pdu_to_check):
|
e,
|
||||||
errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
|
)
|
||||||
pdu_to_check.pdu.event_id,
|
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
||||||
pdu_to_check.sender_domain,
|
|
||||||
e.getErrorMessage(),
|
|
||||||
)
|
|
||||||
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
for p, d in zip(pdus_to_check_sender, more_deferreds):
|
|
||||||
d.addErrback(sender_err, p)
|
|
||||||
p.deferreds.append(d)
|
|
||||||
|
|
||||||
# now let's look for events where the sender's domain is different to the
|
# now let's look for events where the sender's domain is different to the
|
||||||
# event id's domain (normally only the case for joins/leaves), and add additional
|
# event id's domain (normally only the case for joins/leaves), and add additional
|
||||||
# checks. Only do this if the room version has a concept of event ID domain
|
# checks. Only do this if the room version has a concept of event ID domain
|
||||||
# (ie, the room version uses old-style non-hash event IDs).
|
# (ie, the room version uses old-style non-hash event IDs).
|
||||||
if room_version.event_format == EventFormatVersions.V1:
|
if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
|
||||||
pdus_to_check_event_id = [
|
pdu.event_id
|
||||||
p
|
) != get_domain_from_id(pdu.sender):
|
||||||
for p in pdus_to_check
|
try:
|
||||||
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
|
await keyring.verify_event_for_server(
|
||||||
]
|
get_domain_from_id(pdu.event_id),
|
||||||
|
pdu,
|
||||||
more_deferreds = keyring.verify_events_for_server(
|
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||||
[
|
)
|
||||||
(
|
except Exception as e:
|
||||||
get_domain_from_id(p.pdu.event_id),
|
|
||||||
p.pdu,
|
|
||||||
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
|
||||||
)
|
|
||||||
for p in pdus_to_check_event_id
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def event_err(e, pdu_to_check):
|
|
||||||
errmsg = (
|
errmsg = (
|
||||||
"event id %s: unable to verify signature for event id domain: %s"
|
"event id %s: unable to verify signature for event id domain %s: %s"
|
||||||
% (pdu_to_check.pdu.event_id, e.getErrorMessage())
|
% (
|
||||||
|
pdu.event_id,
|
||||||
|
get_domain_from_id(pdu.event_id),
|
||||||
|
e,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
||||||
|
|
||||||
for p, d in zip(pdus_to_check_event_id, more_deferreds):
|
|
||||||
d.addErrback(event_err, p)
|
|
||||||
p.deferreds.append(d)
|
|
||||||
|
|
||||||
# replace lists of deferreds with single Deferreds
|
|
||||||
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
|
|
||||||
|
|
||||||
|
|
||||||
def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
|
|
||||||
"""Given a list of deferreds, either return the single deferred,
|
|
||||||
combine into a DeferredList, or return an already resolved deferred.
|
|
||||||
"""
|
|
||||||
if len(deferreds) > 1:
|
|
||||||
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
|
|
||||||
elif len(deferreds) == 1:
|
|
||||||
return deferreds[0]
|
|
||||||
else:
|
|
||||||
return defer.succeed(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_invite_via_3pid(event: EventBase) -> bool:
|
def _is_invite_via_3pid(event: EventBase) -> bool:
|
||||||
return (
|
return (
|
||||||
|
@ -21,6 +21,7 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
@ -35,9 +36,6 @@ from typing import (
|
|||||||
import attr
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.defer import Deferred
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
@ -56,10 +54,9 @@ from synapse.api.room_versions import (
|
|||||||
from synapse.events import EventBase, builder
|
from synapse.events import EventBase, builder
|
||||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||||
from synapse.federation.transport.client import SendJoinResponse
|
from synapse.federation.transport.client import SendJoinResponse
|
||||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
|
||||||
from synapse.logging.utils import log_function
|
from synapse.logging.utils import log_function
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
@ -360,10 +357,9 @@ class FederationClient(FederationBase):
|
|||||||
async def _check_sigs_and_hash_and_fetch(
|
async def _check_sigs_and_hash_and_fetch(
|
||||||
self,
|
self,
|
||||||
origin: str,
|
origin: str,
|
||||||
pdus: List[EventBase],
|
pdus: Collection[EventBase],
|
||||||
room_version: RoomVersion,
|
room_version: RoomVersion,
|
||||||
outlier: bool = False,
|
outlier: bool = False,
|
||||||
include_none: bool = False,
|
|
||||||
) -> List[EventBase]:
|
) -> List[EventBase]:
|
||||||
"""Takes a list of PDUs and checks the signatures and hashes of each
|
"""Takes a list of PDUs and checks the signatures and hashes of each
|
||||||
one. If a PDU fails its signature check then we check if we have it in
|
one. If a PDU fails its signature check then we check if we have it in
|
||||||
@ -380,57 +376,87 @@ class FederationClient(FederationBase):
|
|||||||
pdu
|
pdu
|
||||||
room_version
|
room_version
|
||||||
outlier: Whether the events are outliers or not
|
outlier: Whether the events are outliers or not
|
||||||
include_none: Whether to include None in the returned list
|
|
||||||
for events that have failed their checks
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of PDUs that have valid signatures and hashes.
|
A list of PDUs that have valid signatures and hashes.
|
||||||
"""
|
"""
|
||||||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
|
||||||
|
|
||||||
async def handle_check_result(pdu: EventBase, deferred: Deferred):
|
# We limit how many PDUs we check at once, as if we try to do hundreds
|
||||||
|
# of thousands of PDUs at once we see large memory spikes.
|
||||||
|
|
||||||
|
valid_pdus = []
|
||||||
|
|
||||||
|
async def _execute(pdu: EventBase) -> None:
|
||||||
|
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
|
||||||
|
pdu=pdu,
|
||||||
|
origin=origin,
|
||||||
|
outlier=outlier,
|
||||||
|
room_version=room_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_pdu:
|
||||||
|
valid_pdus.append(valid_pdu)
|
||||||
|
|
||||||
|
await concurrently_execute(_execute, pdus, 10000)
|
||||||
|
|
||||||
|
return valid_pdus
|
||||||
|
|
||||||
|
async def _check_sigs_and_hash_and_fetch_one(
|
||||||
|
self,
|
||||||
|
pdu: EventBase,
|
||||||
|
origin: str,
|
||||||
|
room_version: RoomVersion,
|
||||||
|
outlier: bool = False,
|
||||||
|
) -> Optional[EventBase]:
|
||||||
|
"""Takes a PDU and checks its signatures and hashes. If the PDU fails
|
||||||
|
its signature check then we check if we have it in the database and if
|
||||||
|
not then request if from the originating server of that PDU.
|
||||||
|
|
||||||
|
If then PDU fails its content hash check then it is redacted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin
|
||||||
|
pdu
|
||||||
|
room_version
|
||||||
|
outlier: Whether the events are outliers or not
|
||||||
|
include_none: Whether to include None in the returned list
|
||||||
|
for events that have failed their checks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The PDU (possibly redacted) if it has valid signatures and hashes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = None
|
||||||
|
try:
|
||||||
|
res = await self._check_sigs_and_hash(room_version, pdu)
|
||||||
|
except SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not res:
|
||||||
|
# Check local db.
|
||||||
|
res = await self.store.get_event(
|
||||||
|
pdu.event_id, allow_rejected=True, allow_none=True
|
||||||
|
)
|
||||||
|
|
||||||
|
pdu_origin = get_domain_from_id(pdu.sender)
|
||||||
|
if not res and pdu_origin != origin:
|
||||||
try:
|
try:
|
||||||
res = await make_deferred_yieldable(deferred)
|
res = await self.get_pdu(
|
||||||
|
destinations=[pdu_origin],
|
||||||
|
event_id=pdu.event_id,
|
||||||
|
room_version=room_version,
|
||||||
|
outlier=outlier,
|
||||||
|
timeout=10000,
|
||||||
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
res = None
|
pass
|
||||||
|
|
||||||
if not res:
|
if not res:
|
||||||
# Check local db.
|
logger.warning(
|
||||||
res = await self.store.get_event(
|
"Failed to find copy of %s with valid signature", pdu.event_id
|
||||||
pdu.event_id, allow_rejected=True, allow_none=True
|
)
|
||||||
)
|
|
||||||
|
|
||||||
pdu_origin = get_domain_from_id(pdu.sender)
|
return res
|
||||||
if not res and pdu_origin != origin:
|
|
||||||
try:
|
|
||||||
res = await self.get_pdu(
|
|
||||||
destinations=[pdu_origin],
|
|
||||||
event_id=pdu.event_id,
|
|
||||||
room_version=room_version,
|
|
||||||
outlier=outlier,
|
|
||||||
timeout=10000,
|
|
||||||
)
|
|
||||||
except SynapseError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not res:
|
|
||||||
logger.warning(
|
|
||||||
"Failed to find copy of %s with valid signature", pdu.event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
handle = preserve_fn(handle_check_result)
|
|
||||||
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
|
|
||||||
|
|
||||||
valid_pdus = await make_deferred_yieldable(
|
|
||||||
defer.gatherResults(deferreds2, consumeErrors=True)
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
if include_none:
|
|
||||||
return valid_pdus
|
|
||||||
else:
|
|
||||||
return [p for p in valid_pdus if p]
|
|
||||||
|
|
||||||
async def get_event_auth(
|
async def get_event_auth(
|
||||||
self, destination: str, room_id: str, event_id: str
|
self, destination: str, room_id: str, event_id: str
|
||||||
@ -671,8 +697,6 @@ class FederationClient(FederationBase):
|
|||||||
state = response.state
|
state = response.state
|
||||||
auth_chain = response.auth_events
|
auth_chain = response.auth_events
|
||||||
|
|
||||||
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
|
|
||||||
|
|
||||||
create_event = None
|
create_event = None
|
||||||
for e in state:
|
for e in state:
|
||||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||||
@ -696,14 +720,29 @@ class FederationClient(FederationBase):
|
|||||||
% (create_room_version,)
|
% (create_room_version,)
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_pdus = await self._check_sigs_and_hash_and_fetch(
|
logger.info(
|
||||||
destination,
|
"Processing from send_join %d events", len(state) + len(auth_chain)
|
||||||
list(pdus.values()),
|
|
||||||
outlier=True,
|
|
||||||
room_version=room_version,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_pdus_map = {p.event_id: p for p in valid_pdus}
|
# We now go and check the signatures and hashes for the event. Note
|
||||||
|
# that we limit how many events we process at a time to keep the
|
||||||
|
# memory overhead from exploding.
|
||||||
|
valid_pdus_map: Dict[str, EventBase] = {}
|
||||||
|
|
||||||
|
async def _execute(pdu: EventBase) -> None:
|
||||||
|
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
|
||||||
|
pdu=pdu,
|
||||||
|
origin=destination,
|
||||||
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_pdu:
|
||||||
|
valid_pdus_map[valid_pdu.event_id] = valid_pdu
|
||||||
|
|
||||||
|
await concurrently_execute(
|
||||||
|
_execute, itertools.chain(state, auth_chain), 10000
|
||||||
|
)
|
||||||
|
|
||||||
# NB: We *need* to copy to ensure that we don't have multiple
|
# NB: We *need* to copy to ensure that we don't have multiple
|
||||||
# references being passed on, as that causes... issues.
|
# references being passed on, as that causes... issues.
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -160,8 +161,11 @@ class ObservableDeferred:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def concurrently_execute(
|
def concurrently_execute(
|
||||||
func: Callable, args: Iterable[Any], limit: int
|
func: Callable[[T], Any], args: Iterable[T], limit: int
|
||||||
) -> defer.Deferred:
|
) -> defer.Deferred:
|
||||||
"""Executes the function with each argument concurrently while limiting
|
"""Executes the function with each argument concurrently while limiting
|
||||||
the number of concurrent executions.
|
the number of concurrent executions.
|
||||||
@ -173,20 +177,27 @@ def concurrently_execute(
|
|||||||
limit: Maximum number of conccurent executions.
|
limit: Maximum number of conccurent executions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list]: Resolved when all function invocations have finished.
|
Deferred: Resolved when all function invocations have finished.
|
||||||
"""
|
"""
|
||||||
it = iter(args)
|
it = iter(args)
|
||||||
|
|
||||||
async def _concurrently_execute_inner():
|
async def _concurrently_execute_inner(value: T) -> None:
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
await maybe_awaitable(func(next(it)))
|
await maybe_awaitable(func(value))
|
||||||
|
value = next(it)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# We use `itertools.islice` to handle the case where the number of args is
|
||||||
|
# less than the limit, avoiding needlessly spawning unnecessary background
|
||||||
|
# tasks.
|
||||||
return make_deferred_yieldable(
|
return make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[run_in_background(_concurrently_execute_inner) for _ in range(limit)],
|
[
|
||||||
|
run_in_background(_concurrently_execute_inner, value)
|
||||||
|
for value in itertools.islice(it, limit)
|
||||||
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
)
|
)
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
Loading…
Reference in New Issue
Block a user