Start making more things use EventContext rather than event.*

This commit is contained in:
Erik Johnston 2014-12-05 16:20:48 +00:00
parent c5c32266d8
commit 6630e1b579
10 changed files with 212 additions and 134 deletions

View File

@ -351,27 +351,27 @@ class Auth(object):
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def get_auth_events(self, event, current_state):
if event.type == RoomCreateEvent.TYPE:
event.auth_events = []
def add_auth_events(self, builder, context):
if builder.type == RoomCreateEvent.TYPE:
builder.auth_events = []
return
auth_events = []
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = current_state.get(key)
power_level_event = context.current_state.get(key)
if power_level_event:
auth_events.append(power_level_event.event_id)
key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = current_state.get(key)
join_rule_event = context.current_state.get(key)
key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = current_state.get(key)
key = (RoomMemberEvent.TYPE, builder.user_id, )
member_event = context.current_state.get(key)
key = (RoomCreateEvent.TYPE, "", )
create_event = current_state.get(key)
create_event = context.current_state.get(key)
if create_event:
auth_events.append(create_event.event_id)
@ -381,8 +381,8 @@ class Auth(object):
else:
is_public = False
if event.type == RoomMemberEvent.TYPE:
e_type = event.content["membership"]
if builder.type == RoomMemberEvent.TYPE:
e_type = builder.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event:
auth_events.append(join_rule_event.event_id)
@ -393,11 +393,18 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN:
auth_events.append(member_event.event_id)
auth_events = yield self.store.add_event_hashes(
auth_events
auth_ids = [(a.event_id, h) for a, h in auth_events]
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
defer.returnValue(auth_events)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
@log_function
def _can_send_event(self, event, auth_events):

View File

@ -17,8 +17,8 @@ from frozendict import frozendict
def _freeze(o):
if isinstance(o, dict):
return frozendict({k: _freeze(v) for k,v in o.items()})
if isinstance(o, dict) or isinstance(o, frozendict):
return frozendict({k: _freeze(v) for k, v in o.items()})
if isinstance(o, basestring):
return o
@ -31,6 +31,21 @@ def _freeze(o):
return o
def _unfreeze(o):
if isinstance(o, frozendict) or isinstance(o, dict):
return dict({k: _unfreeze(v) for k, v in o.items()})
if isinstance(o, basestring):
return o
try:
return [_unfreeze(i) for i in o]
except TypeError:
pass
return o
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = internal_metadata_dict
@ -69,6 +84,7 @@ class EventBase(object):
)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
@ -81,6 +97,10 @@ class EventBase(object):
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
def membership(self):
return self.content["membership"]
def is_state(self):
return hasattr(self, "state_key")
@ -134,3 +154,14 @@ class FrozenEvent(EventBase):
e.internal_metadata = event.internal_metadata
return e
def get_dict(self):
# We need to unfreeze what we return
d = _unfreeze(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": self.unsigned,
})
return d

View File

@ -14,6 +14,7 @@
# limitations under the License.
from synapse.api.constants import EventTypes
from . import EventBase
def prune_event(event):
@ -80,3 +81,18 @@ def prune_event(event):
allowed_fields["content"] = new_content
return type(event)(allowed_fields)
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec())
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
return d

View File

@ -62,6 +62,8 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _create_new_client_event(self, builder):
context = EventContext()
latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id,
)
@ -69,34 +71,26 @@ class BaseHandler(object):
depth = max([d for _, _, d in latest_ret])
prev_events = [(e, h) for e, h, _ in latest_ret]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
if builder.is_state():
ret = yield state_handler.resolve_state_groups(
[e for e, _ in prev_events],
event_type=builder.event_type,
state_key=builder.state_key,
ret = yield state_handler.annotate_context_with_state(
builder,
context,
)
group, prev_state = ret
group, curr_state, prev_state = ret
if builder.is_state():
prev_state = yield self.store.add_event_hashes(
prev_state
)
builder.prev_state = prev_state
else:
group, curr_state, _ = yield state_handler.resolve_state_groups(
[e for e, _ in prev_events],
)
builder.internal_metadata.state_group = group
builder.prev_events = prev_events
builder.depth = depth
auth_events = yield self.auth.get_auth_events(builder, curr_state)
builder.update_event_key("auth_events", auth_events)
yield self.auth.add_auth_events(builder, context)
add_hashes_and_signatures(
builder, self.server_name, self.signing_key
@ -104,18 +98,6 @@ class BaseHandler(object):
event = builder.build()
auth_ids = zip(*auth_events)[0]
curr_auth_events = {
k: v
for k, v in curr_state.items()
if v.event_id in auth_ids
}
context = EventContext(
current_state=curr_state,
auth_events=curr_auth_events,
)
defer.returnValue(
(event, context,)
)
@ -128,7 +110,7 @@ class BaseHandler(object):
if not suppress_auth:
self.auth.check(event, auth_events=context.auth_events)
yield self.store.persist_event(event)
yield self.store.persist_event(event, context=context)
destinations = set(extra_destinations)
for k, s in context.current_state.items():
@ -152,63 +134,63 @@ class BaseHandler(object):
destinations=destinations,
)
@defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
extra_users=[], suppress_auth=False,
do_invite_host=None):
yield run_on_reactor()
snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event)
logger.debug("Signing event...")
add_hashes_and_signatures(
event, self.server_name, self.signing_key
)
logger.debug("Signed event.")
if not suppress_auth:
logger.debug("Authing...")
self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed")
else:
logger.debug("Suppressed auth.")
if do_invite_host:
federation_handler = self.hs.get_handlers().federation_handler
invite_event = yield federation_handler.send_invite(
do_invite_host,
event
)
# FIXME: We need to check if the remote changed anything else
event.signatures = invite_event.signatures
yield self.store.persist_event(event)
destinations = set(extra_destinations)
# Send a PDU to all hosts who have joined the room.
for k, s in event.state_events.items():
try:
if k[0] == RoomMemberEvent.TYPE:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
event.destinations = list(destinations)
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
federation_handler = self.hs.get_handlers().federation_handler
yield federation_handler.handle_new_event(event, snapshot)
# @defer.inlineCallbacks
# def _on_new_room_event(self, event, snapshot, extra_destinations=[],
# extra_users=[], suppress_auth=False,
# do_invite_host=None):
# yield run_on_reactor()
#
# snapshot.fill_out_prev_events(event)
#
# yield self.state_handler.annotate_event_with_state(event)
#
# yield self.auth.add_auth_events(event)
#
# logger.debug("Signing event...")
#
# add_hashes_and_signatures(
# event, self.server_name, self.signing_key
# )
#
# logger.debug("Signed event.")
#
# if not suppress_auth:
# logger.debug("Authing...")
# self.auth.check(event, auth_events=event.old_state_events)
# logger.debug("Authed")
# else:
# logger.debug("Suppressed auth.")
#
# if do_invite_host:
# federation_handler = self.hs.get_handlers().federation_handler
# invite_event = yield federation_handler.send_invite(
# do_invite_host,
# event
# )
#
# # FIXME: We need to check if the remote changed anything else
# event.signatures = invite_event.signatures
#
# yield self.store.persist_event(event)
#
# destinations = set(extra_destinations)
# # Send a PDU to all hosts who have joined the room.
#
# for k, s in event.state_events.items():
# try:
# if k[0] == RoomMemberEvent.TYPE:
# if s.content["membership"] == Membership.JOIN:
# destinations.add(
# self.hs.parse_userid(s.state_key).domain
# )
# except:
# logger.warn(
# "Failed to get destination from event %s", s.event_id
# )
#
# event.destinations = list(destinations)
#
# yield self.notifier.on_new_room_event(event, extra_users=extra_users)
#
# federation_handler = self.hs.get_handlers().federation_handler
# yield federation_handler.handle_new_event(event, snapshot)

View File

@ -17,7 +17,8 @@
from ._base import BaseHandler
from synapse.api.events.utils import prune_event
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError,
)
@ -416,7 +417,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_make_join_request(self, context, user_id):
def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or
process it until the other server has signed it and sent it back.
@ -424,7 +425,7 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new({
"type": RoomMemberEvent.TYPE,
"content": {"membership": Membership.JOIN},
"room_id": context,
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
@ -433,9 +434,7 @@ class FederationHandler(BaseHandler):
builder=builder,
)
yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event)
self.auth.check(event, auth_events=event.old_state_events)
self.auth.check(event, auth_events=context.auth_events)
pdu = event
@ -505,7 +504,9 @@ class FederationHandler(BaseHandler):
"""
event = pdu
event.outlier = True
context = EventContext()
event.internal_metadata.outlier = True
event.signatures.update(
compute_event_signature(
@ -515,10 +516,11 @@ class FederationHandler(BaseHandler):
)
)
yield self.state_handler.annotate_event_with_state(event)
yield self.state_handler.annotate_context_with_state(event, context)
yield self.store.persist_event(
event,
context=context,
backfilled=False,
)
@ -640,6 +642,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True):
context = EventContext()
is_new_state = yield self.state_handler.annotate_event_with_state(
event,
old_state=state

View File

@ -20,7 +20,7 @@
# Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication
from synapse.api.events import serialize_event
from synapse.events.utils import serialize_event
from synapse.api.events.factory import EventFactory
from synapse.api.events.validator import EventValidator
from synapse.notifier import Notifier

View File

@ -135,6 +135,39 @@ class StateHandler(object):
defer.returnValue(res[1].values())
@defer.inlineCallbacks
def annotate_context_with_state(self, event, context):
if event.is_state():
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
event_type=event.event_type,
state_key=event.state_key,
)
else:
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
)
group, curr_state, prev_state = ret
context.current_state = curr_state
prev_state = yield self.store.add_event_hashes(
prev_state
)
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
defer.returnValue(
(group, prev_state)
)
@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, event_ids, event_type=None, state_key=""):

View File

@ -21,6 +21,7 @@ from synapse.api.events.room import (
)
from synapse.util.logutils import log_function
from synapse.util.frozenutils import FrozenEncoder
from .directory import DirectoryStore
from .feedback import FeedbackStore
@ -93,8 +94,8 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
@log_function
def persist_event(self, event, backfilled=False, is_new_state=True,
current_state=None):
def persist_event(self, event, context, backfilled=False,
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
@ -107,6 +108,7 @@ class DataStore(RoomMemberStore, RoomStore,
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
@ -138,8 +140,9 @@ class DataStore(RoomMemberStore, RoomStore,
defer.returnValue(event[0])
@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True, current_state=None):
def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True,
current_state=None):
if event.type == RoomMemberEvent.TYPE:
self._store_room_member_txn(txn, event)
elif event.type == FeedbackEvent.TYPE:
@ -152,12 +155,12 @@ class DataStore(RoomMemberStore, RoomStore,
self._store_redaction(txn, event)
outlier = False
if hasattr(event, "outlier"):
outlier = event.outlier
if hasattr(event.internal_metadata, "outlier"):
outlier = event.internal_metadata.outlier
event_dict = {
k: v
for k, v in event.get_full_dict().items()
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
@ -179,7 +182,7 @@ class DataStore(RoomMemberStore, RoomStore,
"event_id": event.event_id,
"type": event.type,
"room_id": event.room_id,
"content": json.dumps(event.content),
"content": json.dumps(event.content, cls=FrozenEncoder),
"processed": True,
"outlier": outlier,
"depth": event.depth,
@ -190,7 +193,7 @@ class DataStore(RoomMemberStore, RoomStore,
unrec = {
k: v
for k, v in event.get_full_dict().items()
for k, v in event.get_dict().items()
if k not in vals.keys() and k not in [
"redacted",
"redacted_because",
@ -225,7 +228,7 @@ class DataStore(RoomMemberStore, RoomStore,
room_id=event.room_id,
)
self._store_state_groups_txn(txn, event)
self._store_state_groups_txn(txn, event, context)
if current_state:
txn.execute(

View File

@ -15,7 +15,8 @@
import logging
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from syutil.base64util import encode_base64
@ -497,10 +498,7 @@ class SQLBaseStore(object):
d = json.loads(js)
ev = self.event_factory.create_event(
etype=d["type"],
**d
)
ev = FrozenEvent(d)
if hasattr(ev, "redacted") and ev.redacted:
# Get the redaction event.

View File

@ -86,11 +86,16 @@ class StateStore(SQLBaseStore):
self._store_state_groups_txn, event
)
def _store_state_groups_txn(self, txn, event):
if event.state_events is None:
def _store_state_groups_txn(self, txn, event, context):
if context.current_state_events is None:
return
state_group = event.state_group
state_events = context.current_state_events
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._simple_insert_txn(
txn,
@ -102,7 +107,7 @@ class StateStore(SQLBaseStore):
or_ignore=True,
)
for state in event.state_events.values():
for state in context.state_events.values():
self._simple_insert_txn(
txn,
table="state_groups_state",