mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add basic impl for room history ACL on GET /messages client API
This commit is contained in:
parent
6825eef955
commit
1a60545626
@ -75,6 +75,8 @@ class EventTypes(object):
|
||||
Redaction = "m.room.redaction"
|
||||
Feedback = "m.room.message.feedback"
|
||||
|
||||
RoomHistoryVisibility = "m.room.history_visibility"
|
||||
|
||||
# These are used for validation
|
||||
Message = "m.room.message"
|
||||
Topic = "m.room.topic"
|
||||
|
@ -113,11 +113,42 @@ class MessageHandler(BaseHandler):
|
||||
"room_key", next_key
|
||||
)
|
||||
|
||||
if not events:
|
||||
defer.returnValue({
|
||||
"chunk": [],
|
||||
"start": pagin_config.from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
})
|
||||
|
||||
states = yield self.store.get_state_for_events(
|
||||
room_id, [e.event_id for e in events],
|
||||
)
|
||||
|
||||
events_and_states = zip(events, states)
|
||||
|
||||
def allowed(event_and_state):
|
||||
_, state = event_and_state
|
||||
|
||||
membership = state.get((EventTypes.Member, user_id), None)
|
||||
if membership and membership.membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history and history.content["visibility"] == "after_join":
|
||||
return False
|
||||
|
||||
events_and_states = filter(allowed, events_and_states)
|
||||
events = [
|
||||
ev
|
||||
for ev, _ in events_and_states
|
||||
]
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
chunk = {
|
||||
"chunk": [
|
||||
serialize_event(e, time_now, as_client_event) for e in events
|
||||
serialize_event(e, time_now, as_client_event)
|
||||
for e in events
|
||||
],
|
||||
"start": pagin_config.from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
|
@ -92,11 +92,11 @@ class StateStore(SQLBaseStore):
|
||||
defer.returnValue(dict(state_list))
|
||||
|
||||
@cached(num_args=1)
|
||||
def _fetch_events_for_group(self, state_group, events):
|
||||
def _fetch_events_for_group(self, key, events):
|
||||
return self._get_events(
|
||||
events, get_prev_content=False
|
||||
).addCallback(
|
||||
lambda evs: (state_group, evs)
|
||||
lambda evs: (key, evs)
|
||||
)
|
||||
|
||||
def _store_state_groups_txn(self, txn, event, context):
|
||||
@ -194,6 +194,65 @@ class StateStore(SQLBaseStore):
|
||||
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||
defer.returnValue(events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_events(self, room_id, event_ids):
|
||||
def f(txn):
|
||||
groups = set()
|
||||
event_to_group = {}
|
||||
for event_id in event_ids:
|
||||
# TODO: Remove this loop.
|
||||
group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
if group:
|
||||
event_to_group[event_id] = group
|
||||
groups.add(group)
|
||||
|
||||
group_to_state_ids = {}
|
||||
for group in groups:
|
||||
state_ids = self._simple_select_onecol_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
keyvalues={"state_group": group},
|
||||
retcol="event_id",
|
||||
)
|
||||
|
||||
group_to_state_ids[group] = state_ids
|
||||
|
||||
return event_to_group, group_to_state_ids
|
||||
|
||||
res = yield self.runInteraction(
|
||||
"annotate_events_with_state_groups",
|
||||
f,
|
||||
)
|
||||
|
||||
event_to_group, group_to_state_ids = res
|
||||
|
||||
state_list = yield defer.gatherResults(
|
||||
[
|
||||
self._fetch_events_for_group(group, vals)
|
||||
for group, vals in group_to_state_ids.items()
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
state_dict = {
|
||||
group: {
|
||||
(ev.type, ev.state_key): ev
|
||||
for ev in state
|
||||
}
|
||||
for group, state in state_list
|
||||
}
|
||||
|
||||
defer.returnValue([
|
||||
state_dict.get(event_to_group.get(event, None), None)
|
||||
for event in event_ids
|
||||
])
|
||||
|
||||
|
||||
def _make_group_id(clock):
|
||||
return str(int(clock.time_msec())) + random_string(5)
|
||||
|
Loading…
Reference in New Issue
Block a user