Refactor event building into EventBuilder

This is so that everything is done in one place, making it easier to
change the event format based on room version
This commit is contained in:
Erik Johnston 2019-01-25 17:19:31 +00:00
parent 554ca58ea1
commit be47cfa9c9
5 changed files with 257 additions and 115 deletions

View File

@ -13,79 +13,156 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import attr
from synapse.api.constants import RoomVersions from twisted.internet import defer
from synapse.api.constants import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
MAX_DEPTH,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.types import EventID from synapse.types import EventID
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from . import EventBase, FrozenEvent, _event_dict_property from . import (
_EventInternalMetadata,
event_type_from_format_version,
room_version_to_event_format,
)
def get_event_builder(room_version, key_values={}, internal_metadata_dict={}): @attr.s(slots=True, cmp=False, frozen=True)
"""Generate an event builder appropriate for the given room version class EventBuilder(object):
"""A format independent event builder used to build up the event content
before signing the event.
Args: (Note that while objects of this class are frozen, the
room_version (str): Version of the room that we're creating an content/unsigned/internal_metadata fields are still mutable)
event builder for
key_values (dict): Fields used as the basis of the new event
internal_metadata_dict (dict): Used to create the `_EventInternalMetadata`
object.
Returns: Attributes:
EventBuilder format_version (int): Event format version
room_id (str)
type (str)
sender (str)
content (dict)
unsigned (dict)
internal_metadata (_EventInternalMetadata)
_state (StateHandler)
_auth (synapse.api.Auth)
_store (DataStore)
_clock (Clock)
_hostname (str): The hostname of the server creating the event
_signing_key: The signing key to use to sign the event as the server
""" """
if room_version in {
RoomVersions.V1, _state = attr.ib()
RoomVersions.V2, _auth = attr.ib()
RoomVersions.VDH_TEST, _store = attr.ib()
RoomVersions.STATE_V2_TEST, _clock = attr.ib()
}: _hostname = attr.ib()
return EventBuilder(key_values, internal_metadata_dict) _signing_key = attr.ib()
else:
raise Exception( format_version = attr.ib()
"No event format defined for version %r" % (room_version,)
room_id = attr.ib()
type = attr.ib()
sender = attr.ib()
content = attr.ib(default=attr.Factory(dict))
unsigned = attr.ib(default=attr.Factory(dict))
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@property
def state_key(self):
if self._state_key is not None:
return self._state_key
raise AttributeError("state_key")
def is_state(self):
return self._state_key is not None
@defer.inlineCallbacks
def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
Returns:
Deferred[FrozenEvent]
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids,
)
auth_ids = yield self._auth.compute_auth_events(
self, state_ids,
) )
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
class EventBuilder(EventBase): old_depth = yield self._store.get_max_depth_of(
def __init__(self, key_values={}, internal_metadata_dict={}): prev_event_ids,
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
) )
depth = old_depth + 1
event_id = _event_dict_property("event_id") # we cap depth of generated events, to ensure that they are not
state_key = _event_dict_property("state_key") # rejected by other servers (and so that they can be persisted in
type = _event_dict_property("type") # the db)
depth = min(depth, MAX_DEPTH)
def build(self): event_dict = {
return FrozenEvent.from_event(self) "auth_events": auth_events,
"prev_events": prev_events,
"type": self.type,
"room_id": self.room_id,
"sender": self.sender,
"content": self.content,
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
}
if self.is_state():
event_dict["state_key"] = self._state_key
if self._redacts is not None:
event_dict["redacts"] = self._redacts
defer.returnValue(
create_local_event_from_event_dict(
clock=self._clock,
hostname=self._hostname,
signing_key=self._signing_key,
format_version=self.format_version,
event_dict=event_dict,
internal_metadata_dict=self.internal_metadata.get_dict(),
)
)
class EventBuilderFactory(object): class EventBuilderFactory(object):
def __init__(self, clock, hostname): def __init__(self, hs):
self.clock = clock self.clock = hs.get_clock()
self.hostname = hostname self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self.event_id_count = 0 self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
def create_event_id(self): def new(self, room_version, key_values):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID(local_part, self.hostname)
return e_id.to_string()
def new(self, room_version, key_values={}):
"""Generate an event builder appropriate for the given room version """Generate an event builder appropriate for the given room version
Args: Args:
@ -98,27 +175,104 @@ class EventBuilderFactory(object):
""" """
# There's currently only the one event version defined # There's currently only the one event version defined
if room_version not in { if room_version not in KNOWN_ROOM_VERSIONS:
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}:
raise Exception( raise Exception(
"No event format defined for version %r" % (room_version,) "No event format defined for version %r" % (room_version,)
) )
key_values["event_id"] = self.create_event_id() key_values["event_id"] = _create_event_id(self.clock, self.hostname)
time_now = int(self.clock.time_msec()) return EventBuilder(
store=self.store,
state=self.state,
auth=self.auth,
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
format_version=room_version_to_event_format(room_version),
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
sender=key_values["sender"],
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
)
key_values.setdefault("origin", self.hostname)
key_values.setdefault("origin_server_ts", time_now)
key_values.setdefault("unsigned", {}) def create_local_event_from_event_dict(clock, hostname, signing_key,
age = key_values["unsigned"].pop("age", 0) format_version, event_dict,
key_values["unsigned"].setdefault("age_ts", time_now - age) internal_metadata_dict=None):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
key_values["signatures"] = {} Args:
clock (Clock)
hostname (str)
signing_key
format_version (int)
event_dict (dict)
internal_metadata_dict (dict|None)
return EventBuilder(key_values=key_values,) Returns:
FrozenEvent
"""
# There's currently only the one event version defined
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
)
if internal_metadata_dict is None:
internal_metadata_dict = {}
time_now = int(clock.time_msec())
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
event_dict["origin_server_ts"] = time_now
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)
event_dict["unsigned"].setdefault("age_ts", time_now - age)
event_dict.setdefault("signatures", {})
add_hashes_and_signatures(
event_dict,
hostname,
signing_key,
)
return event_type_from_format_version(format_version)(
event_dict, internal_metadata_dict=internal_metadata_dict,
)
# A counter used when generating new event IDs
_event_id_counter = 0
def _create_event_id(clock, hostname):
"""Create a new event ID
Args:
clock (Clock)
hostname (str): The server name for the event ID
Returns:
str
"""
global _event_id_counter
i = str(_event_id_counter)
_event_id_counter += 1
local_part = str(int(clock.time())) + i + random_string(5)
e_id = EventID(local_part, hostname)
return e_id.to_string()

View File

@ -37,8 +37,7 @@ from synapse.api.errors import (
HttpResponseException, HttpResponseException,
SynapseError, SynapseError,
) )
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import builder, room_version_to_event_format
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -72,7 +71,8 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
self.event_builder_factory = hs.get_event_builder_factory() self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
@ -608,18 +608,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict: if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = [] pdu_dict["prev_state"] = []
# Strip off the fields that we want to clobber. ev = builder.create_local_event_from_event_dict(
pdu_dict.pop("origin", None) self._clock, self.hostname, self.signing_key,
pdu_dict.pop("origin_server_ts", None) format_version=event_format, event_dict=pdu_dict,
pdu_dict.pop("unsigned", None)
builder = self.event_builder_factory.new(room_version, pdu_dict)
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0]
) )
ev = builder.build()
defer.returnValue( defer.returnValue(
(destination, ev, event_format) (destination, ev, event_format)

View File

@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -31,7 +31,6 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@ -545,40 +544,17 @@ class EventCreationHandler(object):
prev_events_and_hashes = \ prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id) yield self.store.get_prev_events_for_room(builder.room_id)
if prev_events_and_hashes:
depth = max([d for _, _, d in prev_events_and_hashes]) + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
else:
depth = 1
prev_events = [ prev_events = [
(event_id, prev_hashes) (event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes for event_id, prev_hashes, _ in prev_events_and_hashes
] ]
builder.prev_events = prev_events event = yield builder.build(
builder.depth = depth prev_event_ids=[p for p, _ in prev_events],
context = yield self.state.compute_event_context(builder)
if requester:
context.app_service = requester.app_service
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
signing_key = self.hs.config.signing_key[0]
add_hashes_and_signatures(
builder, self.server_name, signing_key
) )
context = yield self.state.compute_event_context(event)
event = builder.build() self.validator.validate_new(event)
logger.debug( logger.debug(
"Created event %s", "Created event %s",

View File

@ -355,10 +355,7 @@ class HomeServer(object):
return Keyring(self) return Keyring(self)
def build_event_builder_factory(self): def build_event_builder_factory(self):
return EventBuilderFactory( return EventBuilderFactory(self)
clock=self.get_clock(),
hostname=self.hostname,
)
def build_filtering(self): def build_filtering(self):
return Filtering(self) return Filtering(self)

View File

@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return dict(txn) return dict(txn)
@defer.inlineCallbacks
def get_max_depth_of(self, event_ids):
"""Returns the max depth of a set of event IDs
Args:
event_ids (list[str])
Returns
Deferred[int]
"""
rows = yield self._simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=("depth",),
desc="get_max_depth_of",
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(max(row["depth"] for row in rows))
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
txn, txn,