Start making more things use EventContext rather than event.*

This commit is contained in:
Erik Johnston 2014-12-05 16:20:48 +00:00
parent c5c32266d8
commit 6630e1b579
10 changed files with 212 additions and 134 deletions

View File

@ -351,27 +351,27 @@ class Auth(object):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_auth_events(self, event, current_state): def add_auth_events(self, builder, context):
if event.type == RoomCreateEvent.TYPE: if builder.type == RoomCreateEvent.TYPE:
event.auth_events = [] builder.auth_events = []
return return
auth_events = [] auth_events = []
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = current_state.get(key) power_level_event = context.current_state.get(key)
if power_level_event: if power_level_event:
auth_events.append(power_level_event.event_id) auth_events.append(power_level_event.event_id)
key = (RoomJoinRulesEvent.TYPE, "", ) key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = current_state.get(key) join_rule_event = context.current_state.get(key)
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, builder.user_id, )
member_event = current_state.get(key) member_event = context.current_state.get(key)
key = (RoomCreateEvent.TYPE, "", ) key = (RoomCreateEvent.TYPE, "", )
create_event = current_state.get(key) create_event = context.current_state.get(key)
if create_event: if create_event:
auth_events.append(create_event.event_id) auth_events.append(create_event.event_id)
@ -381,8 +381,8 @@ class Auth(object):
else: else:
is_public = False is_public = False
if event.type == RoomMemberEvent.TYPE: if builder.type == RoomMemberEvent.TYPE:
e_type = event.content["membership"] e_type = builder.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]: if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event: if join_rule_event:
auth_events.append(join_rule_event.event_id) auth_events.append(join_rule_event.event_id)
@ -393,11 +393,18 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_events.append(member_event.event_id) auth_events.append(member_event.event_id)
auth_events = yield self.store.add_event_hashes( auth_ids = [(a.event_id, h) for a, h in auth_events]
auth_events auth_events_entries = yield self.store.add_event_hashes(
auth_ids
) )
defer.returnValue(auth_events) builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
@log_function @log_function
def _can_send_event(self, event, auth_events): def _can_send_event(self, event, auth_events):

View File

@ -17,8 +17,8 @@ from frozendict import frozendict
def _freeze(o): def _freeze(o):
if isinstance(o, dict): if isinstance(o, dict) or isinstance(o, frozendict):
return frozendict({k: _freeze(v) for k,v in o.items()}) return frozendict({k: _freeze(v) for k, v in o.items()})
if isinstance(o, basestring): if isinstance(o, basestring):
return o return o
@ -31,6 +31,21 @@ def _freeze(o):
return o return o
def _unfreeze(o):
if isinstance(o, frozendict) or isinstance(o, dict):
return dict({k: _unfreeze(v) for k, v in o.items()})
if isinstance(o, basestring):
return o
try:
return [_unfreeze(i) for i in o]
except TypeError:
pass
return o
class _EventInternalMetadata(object): class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict): def __init__(self, internal_metadata_dict):
self.__dict__ = internal_metadata_dict self.__dict__ = internal_metadata_dict
@ -69,6 +84,7 @@ class EventBase(object):
) )
auth_events = _event_dict_property("auth_events") auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content") content = _event_dict_property("content")
event_id = _event_dict_property("event_id") event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes") hashes = _event_dict_property("hashes")
@ -81,6 +97,10 @@ class EventBase(object):
type = _event_dict_property("type") type = _event_dict_property("type")
user_id = _event_dict_property("sender") user_id = _event_dict_property("sender")
@property
def membership(self):
return self.content["membership"]
def is_state(self): def is_state(self):
return hasattr(self, "state_key") return hasattr(self, "state_key")
@ -134,3 +154,14 @@ class FrozenEvent(EventBase):
e.internal_metadata = event.internal_metadata e.internal_metadata = event.internal_metadata
return e return e
def get_dict(self):
# We need to unfreeze what we return
d = _unfreeze(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": self.unsigned,
})
return d

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from . import EventBase
def prune_event(event): def prune_event(event):
@ -80,3 +81,18 @@ def prune_event(event):
allowed_fields["content"] = new_content allowed_fields["content"] = new_content
return type(event)(allowed_fields) return type(event)(allowed_fields)
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec())
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
return d

View File

@ -62,6 +62,8 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder):
context = EventContext()
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id, builder.room_id,
) )
@ -69,34 +71,26 @@ class BaseHandler(object):
depth = max([d for _, _, d in latest_ret]) depth = max([d for _, _, d in latest_ret])
prev_events = [(e, h) for e, h, _ in latest_ret] prev_events = [(e, h) for e, h, _ in latest_ret]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler state_handler = self.state_handler
if builder.is_state(): ret = yield state_handler.annotate_context_with_state(
ret = yield state_handler.resolve_state_groups( builder,
[e for e, _ in prev_events], context,
event_type=builder.event_type,
state_key=builder.state_key,
) )
group, prev_state = ret
group, curr_state, prev_state = ret if builder.is_state():
prev_state = yield self.store.add_event_hashes( prev_state = yield self.store.add_event_hashes(
prev_state prev_state
) )
builder.prev_state = prev_state builder.prev_state = prev_state
else:
group, curr_state, _ = yield state_handler.resolve_state_groups(
[e for e, _ in prev_events],
)
builder.internal_metadata.state_group = group builder.internal_metadata.state_group = group
builder.prev_events = prev_events yield self.auth.add_auth_events(builder, context)
builder.depth = depth
auth_events = yield self.auth.get_auth_events(builder, curr_state)
builder.update_event_key("auth_events", auth_events)
add_hashes_and_signatures( add_hashes_and_signatures(
builder, self.server_name, self.signing_key builder, self.server_name, self.signing_key
@ -104,18 +98,6 @@ class BaseHandler(object):
event = builder.build() event = builder.build()
auth_ids = zip(*auth_events)[0]
curr_auth_events = {
k: v
for k, v in curr_state.items()
if v.event_id in auth_ids
}
context = EventContext(
current_state=curr_state,
auth_events=curr_auth_events,
)
defer.returnValue( defer.returnValue(
(event, context,) (event, context,)
) )
@ -128,7 +110,7 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
self.auth.check(event, auth_events=context.auth_events) self.auth.check(event, auth_events=context.auth_events)
yield self.store.persist_event(event) yield self.store.persist_event(event, context=context)
destinations = set(extra_destinations) destinations = set(extra_destinations)
for k, s in context.current_state.items(): for k, s in context.current_state.items():
@ -152,63 +134,63 @@ class BaseHandler(object):
destinations=destinations, destinations=destinations,
) )
@defer.inlineCallbacks # @defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[], # def _on_new_room_event(self, event, snapshot, extra_destinations=[],
extra_users=[], suppress_auth=False, # extra_users=[], suppress_auth=False,
do_invite_host=None): # do_invite_host=None):
yield run_on_reactor() # yield run_on_reactor()
#
snapshot.fill_out_prev_events(event) # snapshot.fill_out_prev_events(event)
#
yield self.state_handler.annotate_event_with_state(event) # yield self.state_handler.annotate_event_with_state(event)
#
yield self.auth.add_auth_events(event) # yield self.auth.add_auth_events(event)
#
logger.debug("Signing event...") # logger.debug("Signing event...")
#
add_hashes_and_signatures( # add_hashes_and_signatures(
event, self.server_name, self.signing_key # event, self.server_name, self.signing_key
) # )
#
logger.debug("Signed event.") # logger.debug("Signed event.")
#
if not suppress_auth: # if not suppress_auth:
logger.debug("Authing...") # logger.debug("Authing...")
self.auth.check(event, auth_events=event.old_state_events) # self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed") # logger.debug("Authed")
else: # else:
logger.debug("Suppressed auth.") # logger.debug("Suppressed auth.")
#
if do_invite_host: # if do_invite_host:
federation_handler = self.hs.get_handlers().federation_handler # federation_handler = self.hs.get_handlers().federation_handler
invite_event = yield federation_handler.send_invite( # invite_event = yield federation_handler.send_invite(
do_invite_host, # do_invite_host,
event # event
) # )
#
# FIXME: We need to check if the remote changed anything else # # FIXME: We need to check if the remote changed anything else
event.signatures = invite_event.signatures # event.signatures = invite_event.signatures
#
yield self.store.persist_event(event) # yield self.store.persist_event(event)
#
destinations = set(extra_destinations) # destinations = set(extra_destinations)
# Send a PDU to all hosts who have joined the room. # # Send a PDU to all hosts who have joined the room.
#
for k, s in event.state_events.items(): # for k, s in event.state_events.items():
try: # try:
if k[0] == RoomMemberEvent.TYPE: # if k[0] == RoomMemberEvent.TYPE:
if s.content["membership"] == Membership.JOIN: # if s.content["membership"] == Membership.JOIN:
destinations.add( # destinations.add(
self.hs.parse_userid(s.state_key).domain # self.hs.parse_userid(s.state_key).domain
) # )
except: # except:
logger.warn( # logger.warn(
"Failed to get destination from event %s", s.event_id # "Failed to get destination from event %s", s.event_id
) # )
#
event.destinations = list(destinations) # event.destinations = list(destinations)
#
yield self.notifier.on_new_room_event(event, extra_users=extra_users) # yield self.notifier.on_new_room_event(event, extra_users=extra_users)
#
federation_handler = self.hs.get_handlers().federation_handler # federation_handler = self.hs.get_handlers().federation_handler
yield federation_handler.handle_new_event(event, snapshot) # yield federation_handler.handle_new_event(event, snapshot)

View File

@ -17,7 +17,8 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.events.utils import prune_event from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError, AuthError, FederationError, SynapseError, StoreError,
) )
@ -416,7 +417,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_join_request(self, context, user_id): def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or join event for the room and return that. We don *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
@ -424,7 +425,7 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new({
"type": RoomMemberEvent.TYPE, "type": RoomMemberEvent.TYPE,
"content": {"membership": Membership.JOIN}, "content": {"membership": Membership.JOIN},
"room_id": context, "room_id": room_id,
"sender": user_id, "sender": user_id,
"state_key": user_id, "state_key": user_id,
}) })
@ -433,9 +434,7 @@ class FederationHandler(BaseHandler):
builder=builder, builder=builder,
) )
yield self.state_handler.annotate_event_with_state(event) self.auth.check(event, auth_events=context.auth_events)
yield self.auth.add_auth_events(event)
self.auth.check(event, auth_events=event.old_state_events)
pdu = event pdu = event
@ -505,7 +504,9 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
event.outlier = True context = EventContext()
event.internal_metadata.outlier = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
@ -515,10 +516,11 @@ class FederationHandler(BaseHandler):
) )
) )
yield self.state_handler.annotate_event_with_state(event) yield self.state_handler.annotate_context_with_state(event, context)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
context=context,
backfilled=False, backfilled=False,
) )
@ -640,6 +642,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True): current_state=None, fetch_missing=True):
context = EventContext()
is_new_state = yield self.state_handler.annotate_event_with_state( is_new_state = yield self.state_handler.annotate_event_with_state(
event, event,
old_state=state old_state=state

View File

@ -20,7 +20,7 @@
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.api.events import serialize_event from synapse.events.utils import serialize_event
from synapse.api.events.factory import EventFactory from synapse.api.events.factory import EventFactory
from synapse.api.events.validator import EventValidator from synapse.api.events.validator import EventValidator
from synapse.notifier import Notifier from synapse.notifier import Notifier

View File

@ -135,6 +135,39 @@ class StateHandler(object):
defer.returnValue(res[1].values()) defer.returnValue(res[1].values())
@defer.inlineCallbacks
def annotate_context_with_state(self, event, context):
if event.is_state():
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
event_type=event.event_type,
state_key=event.state_key,
)
else:
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
)
group, curr_state, prev_state = ret
context.current_state = curr_state
prev_state = yield self.store.add_event_hashes(
prev_state
)
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
defer.returnValue(
(group, prev_state)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups(self, event_ids, event_type=None, state_key=""): def resolve_state_groups(self, event_ids, event_type=None, state_key=""):

View File

@ -21,6 +21,7 @@ from synapse.api.events.room import (
) )
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.frozenutils import FrozenEncoder
from .directory import DirectoryStore from .directory import DirectoryStore
from .feedback import FeedbackStore from .feedback import FeedbackStore
@ -93,8 +94,8 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, backfilled=False, is_new_state=True, def persist_event(self, event, context, backfilled=False,
current_state=None): is_new_state=True, current_state=None):
stream_ordering = None stream_ordering = None
if backfilled: if backfilled:
if not self.min_token_deferred.called: if not self.min_token_deferred.called:
@ -107,6 +108,7 @@ class DataStore(RoomMemberStore, RoomStore,
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context,
backfilled=backfilled, backfilled=backfilled,
stream_ordering=stream_ordering, stream_ordering=stream_ordering,
is_new_state=is_new_state, is_new_state=is_new_state,
@ -138,8 +140,9 @@ class DataStore(RoomMemberStore, RoomStore,
defer.returnValue(event[0]) defer.returnValue(event[0])
@log_function @log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, def _persist_event_txn(self, txn, event, context, backfilled,
is_new_state=True, current_state=None): stream_ordering=None, is_new_state=True,
current_state=None):
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
self._store_room_member_txn(txn, event) self._store_room_member_txn(txn, event)
elif event.type == FeedbackEvent.TYPE: elif event.type == FeedbackEvent.TYPE:
@ -152,12 +155,12 @@ class DataStore(RoomMemberStore, RoomStore,
self._store_redaction(txn, event) self._store_redaction(txn, event)
outlier = False outlier = False
if hasattr(event, "outlier"): if hasattr(event.internal_metadata, "outlier"):
outlier = event.outlier outlier = event.internal_metadata.outlier
event_dict = { event_dict = {
k: v k: v
for k, v in event.get_full_dict().items() for k, v in event.get_dict().items()
if k not in [ if k not in [
"redacted", "redacted",
"redacted_because", "redacted_because",
@ -179,7 +182,7 @@ class DataStore(RoomMemberStore, RoomStore,
"event_id": event.event_id, "event_id": event.event_id,
"type": event.type, "type": event.type,
"room_id": event.room_id, "room_id": event.room_id,
"content": json.dumps(event.content), "content": json.dumps(event.content, cls=FrozenEncoder),
"processed": True, "processed": True,
"outlier": outlier, "outlier": outlier,
"depth": event.depth, "depth": event.depth,
@ -190,7 +193,7 @@ class DataStore(RoomMemberStore, RoomStore,
unrec = { unrec = {
k: v k: v
for k, v in event.get_full_dict().items() for k, v in event.get_dict().items()
if k not in vals.keys() and k not in [ if k not in vals.keys() and k not in [
"redacted", "redacted",
"redacted_because", "redacted_because",
@ -225,7 +228,7 @@ class DataStore(RoomMemberStore, RoomStore,
room_id=event.room_id, room_id=event.room_id,
) )
self._store_state_groups_txn(txn, event) self._store_state_groups_txn(txn, event, context)
if current_state: if current_state:
txn.execute( txn.execute(

View File

@ -15,7 +15,8 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
@ -497,10 +498,7 @@ class SQLBaseStore(object):
d = json.loads(js) d = json.loads(js)
ev = self.event_factory.create_event( ev = FrozenEvent(d)
etype=d["type"],
**d
)
if hasattr(ev, "redacted") and ev.redacted: if hasattr(ev, "redacted") and ev.redacted:
# Get the redaction event. # Get the redaction event.

View File

@ -86,11 +86,16 @@ class StateStore(SQLBaseStore):
self._store_state_groups_txn, event self._store_state_groups_txn, event
) )
def _store_state_groups_txn(self, txn, event): def _store_state_groups_txn(self, txn, event, context):
if event.state_events is None: if context.current_state_events is None:
return return
state_group = event.state_group state_events = context.current_state_events
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group: if not state_group:
state_group = self._simple_insert_txn( state_group = self._simple_insert_txn(
txn, txn,
@ -102,7 +107,7 @@ class StateStore(SQLBaseStore):
or_ignore=True, or_ignore=True,
) )
for state in event.state_events.values(): for state in context.state_events.values():
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups_state", table="state_groups_state",