Pass room_version into event_from_pdu_json

It's called from all over the shop, so this one's a bit messy.
This commit is contained in:
Richard van der Hoff 2020-01-31 16:50:13 +00:00
parent b0c8bdd49d
commit 928edef979
5 changed files with 51 additions and 60 deletions

1
changelog.d/6856.misc Normal file
View File

@ -0,0 +1 @@
Refactoring work in preparation for changing the event redaction algorithm.

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,9 +23,13 @@ from twisted.internet.defer import DeferredList
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import event_type_from_format_version from synapse.events import EventBase, event_type_from_format_version
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import ( from synapse.logging.context import (
@ -33,7 +38,7 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
preserve_fn, preserve_fn,
) )
from synapse.types import get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -342,16 +347,15 @@ def _is_invite_via_3pid(event):
) )
def event_from_pdu_json(pdu_json, event_format_version, outlier=False): def event_from_pdu_json(
"""Construct a FrozenEvent from an event json received over federation pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False
) -> EventBase:
"""Construct an EventBase from an event json received over federation
Args: Args:
pdu_json (object): pdu as received over federation pdu_json: pdu as received over federation
event_format_version (int): The event format version room_version: The version of the room this event belongs to
outlier (bool): True to mark this event as an outlier outlier: True to mark this event as an outlier
Returns:
FrozenEvent
Raises: Raises:
SynapseError: if the pdu is missing required fields or is otherwise SynapseError: if the pdu is missing required fields or is otherwise
@ -370,7 +374,7 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
elif depth > MAX_DEPTH: elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON) raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
event = event_type_from_format_version(event_format_version)(pdu_json) event = event_type_from_format_version(room_version.event_format)(pdu_json)
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier

View File

@ -49,7 +49,7 @@ from synapse.api.room_versions import (
RoomVersion, RoomVersion,
RoomVersions, RoomVersions,
) )
from synapse.events import EventBase, builder, room_version_to_event_format from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
@ -209,18 +209,18 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%r", transaction_data) logger.debug("backfill transaction_data=%r", transaction_data)
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [ pdus = [
event_from_pdu_json(p, format_ver, outlier=False) event_from_pdu_json(p, room_version, outlier=False)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable( pdus[:] = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True self._check_sigs_and_hashes(room_version.identifier, pdus),
consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
) )
@ -262,8 +262,6 @@ class FederationClient(FederationBase):
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
format_ver = room_version.event_format
signed_pdu = None signed_pdu = None
for destination in destinations: for destination in destinations:
now = self._clock.time_msec() now = self._clock.time_msec()
@ -284,7 +282,7 @@ class FederationClient(FederationBase):
) )
pdu_list = [ pdu_list = [
event_from_pdu_json(p, format_ver, outlier=outlier) event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
@ -350,15 +348,15 @@ class FederationClient(FederationBase):
async def get_event_auth(self, destination, room_id, event_id): async def get_event_auth(self, destination, room_id, event_id):
res = await self.transport_layer.get_event_auth(destination, room_id, event_id) res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [ auth_chain = [
event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] event_from_pdu_json(p, room_version, outlier=True)
for p in res["auth_chain"]
] ]
signed_auth = await self._check_sigs_and_hash_and_fetch( signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version destination, auth_chain, outlier=True, room_version=room_version.identifier
) )
signed_auth.sort(key=lambda e: e.depth) signed_auth.sort(key=lambda e: e.depth)
@ -547,12 +545,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
state = [ state = [
event_from_pdu_json(p, room_version.event_format, outlier=True) event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("state", []) for p in content.get("state", [])
] ]
auth_chain = [ auth_chain = [
event_from_pdu_json(p, room_version.event_format, outlier=True) event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
@ -677,7 +675,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict) logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = event_from_pdu_json(pdu_dict, room_version.event_format) pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct. # Check signatures are correct.
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
@ -865,15 +863,14 @@ class FederationClient(FederationBase):
timeout=timeout, timeout=timeout,
) )
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
events = [ events = [
event_from_pdu_json(e, format_ver) for e in content.get("events", []) event_from_pdu_json(e, room_version) for e in content.get("events", [])
] ]
signed_events = await self._check_sigs_and_hash_and_fetch( signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version destination, events, outlier=False, room_version=room_version.identifier
) )
except HttpResponseException as e: except HttpResponseException as e:
if not e.code == 400: if not e.code == 400:

View File

@ -38,7 +38,6 @@ from synapse.api.errors import (
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
@ -234,24 +233,17 @@ class FederationServer(FederationBase):
continue continue
try: try:
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
except NotFoundError: except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id) logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue continue
except UnsupportedRoomVersionError as e:
try:
format_ver = room_version_to_event_format(room_version)
except UnsupportedRoomVersionError:
# this can happen if support for a given room version is withdrawn, # this can happen if support for a given room version is withdrawn,
# so that we still get events for said room. # so that we still get events for said room.
logger.info( logger.info("Ignoring PDU: %s", e)
"Ignoring PDU for room %s with unknown version %s",
room_id,
room_version,
)
continue continue
event = event_from_pdu_json(p, format_ver) event = event_from_pdu_json(p, room_version)
pdus_by_room.setdefault(room_id, []).append(event) pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {} pdu_results = {}
@ -407,9 +399,7 @@ class FederationServer(FederationBase):
Codes.UNSUPPORTED_ROOM_VERSION, Codes.UNSUPPORTED_ROOM_VERSION,
) )
format_ver = room_version.event_format pdu = event_from_pdu_json(content, room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id) await self.check_server_matches_acl(origin_host, pdu.room_id)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
@ -420,16 +410,15 @@ class FederationServer(FederationBase):
async def on_send_join_request(self, origin, content, room_id): async def on_send_join_request(self, origin, content, room_id):
logger.debug("on_send_join_request: content: %s", content) logger.debug("on_send_join_request: content: %s", content)
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id) await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version, pdu) pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
res_pdus = await self.handler.on_send_join_request(origin, pdu) res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
@ -451,16 +440,15 @@ class FederationServer(FederationBase):
async def on_send_leave_request(self, origin, content, room_id): async def on_send_leave_request(self, origin, content, room_id):
logger.debug("on_send_leave_request: content: %s", content) logger.debug("on_send_leave_request: content: %s", content)
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id) await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version, pdu) pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
await self.handler.on_send_leave_request(origin, pdu) await self.handler.on_send_leave_request(origin, pdu)
return {} return {}
@ -498,15 +486,14 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [ auth_chain = [
event_from_pdu_json(e, format_ver) for e in content["auth_chain"] event_from_pdu_json(e, room_version) for e in content["auth_chain"]
] ]
signed_auth = await self._check_sigs_and_hash_and_fetch( signed_auth = await self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True, room_version=room_version origin, auth_chain, outlier=True, room_version=room_version.identifier
) )
ret = await self.handler.on_query_auth( ret = await self.handler.on_query_auth(

View File

@ -99,6 +99,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))
# pretend that another server has joined # pretend that another server has joined
join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
@ -120,7 +121,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"auth_events": [], "auth_events": [],
"origin_server_ts": self.clock.time_msec(), "origin_server_ts": self.clock.time_msec(),
}, },
join_event.format_version, room_version,
) )
with LoggingContext(request="send_rejected"): with LoggingContext(request="send_rejected"):
@ -149,6 +150,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))
# pretend that another server has joined # pretend that another server has joined
join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
@ -171,7 +173,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"auth_events": [], "auth_events": [],
"origin_server_ts": self.clock.time_msec(), "origin_server_ts": self.clock.time_msec(),
}, },
join_event.format_version, room_version,
) )
with LoggingContext(request="send_rejected"): with LoggingContext(request="send_rejected"):