mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Make token serializing/deserializing async (#8427)
The idea is that in future tokens will encode a mapping of instance to position. However, we don't want to include the full instance name in the string representation, so instead we'll have a mapping between instance name and an immutable integer ID in the DB that we can use instead. We'll then do the lookup when we serialize/deserialize the token (we could alternatively pass around an `Instance` type that includes both the name and ID, but that turns out to be a lot more invasive).
This commit is contained in:
parent
a0a1ba6973
commit
7941372ec8
1
changelog.d/8427.misc
Normal file
1
changelog.d/8427.misc
Normal file
@ -0,0 +1 @@
|
||||
Make stream token serializing/deserializing async.
|
@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler):
|
||||
|
||||
chunk = {
|
||||
"chunk": chunks,
|
||||
"start": tokens[0].to_string(),
|
||||
"end": tokens[1].to_string(),
|
||||
"start": await tokens[0].to_string(self.store),
|
||||
"end": await tokens[1].to_string(self.store),
|
||||
}
|
||||
|
||||
return chunk
|
||||
|
@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler):
|
||||
messages, time_now=time_now, as_client_event=as_client_event
|
||||
)
|
||||
),
|
||||
"start": start_token.to_string(),
|
||||
"end": end_token.to_string(),
|
||||
"start": await start_token.to_string(self.store),
|
||||
"end": await end_token.to_string(self.store),
|
||||
}
|
||||
|
||||
d["state"] = await self._event_serializer.serialize_events(
|
||||
@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler):
|
||||
],
|
||||
"account_data": account_data_events,
|
||||
"receipts": receipt,
|
||||
"end": now_token.to_string(),
|
||||
"end": await now_token.to_string(self.store),
|
||||
}
|
||||
|
||||
return ret
|
||||
@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler):
|
||||
"chunk": (
|
||||
await self._event_serializer.serialize_events(messages, time_now)
|
||||
),
|
||||
"start": start_token.to_string(),
|
||||
"end": end_token.to_string(),
|
||||
"start": await start_token.to_string(self.store),
|
||||
"end": await end_token.to_string(self.store),
|
||||
},
|
||||
"state": (
|
||||
await self._event_serializer.serialize_events(
|
||||
@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler):
|
||||
"chunk": (
|
||||
await self._event_serializer.serialize_events(messages, time_now)
|
||||
),
|
||||
"start": start_token.to_string(),
|
||||
"end": end_token.to_string(),
|
||||
"start": await start_token.to_string(self.store),
|
||||
"end": await end_token.to_string(self.store),
|
||||
},
|
||||
"state": state,
|
||||
"presence": presence,
|
||||
|
@ -413,8 +413,8 @@ class PaginationHandler:
|
||||
if not events:
|
||||
return {
|
||||
"chunk": [],
|
||||
"start": from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
"start": await from_token.to_string(self.store),
|
||||
"end": await next_token.to_string(self.store),
|
||||
}
|
||||
|
||||
state = None
|
||||
@ -442,8 +442,8 @@ class PaginationHandler:
|
||||
events, time_now, as_client_event=as_client_event
|
||||
)
|
||||
),
|
||||
"start": from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
"start": await from_token.to_string(self.store),
|
||||
"end": await next_token.to_string(self.store),
|
||||
}
|
||||
|
||||
if state:
|
||||
|
@ -1077,11 +1077,13 @@ class RoomContextHandler:
|
||||
# the token, which we replace.
|
||||
token = StreamToken.START
|
||||
|
||||
results["start"] = token.copy_and_replace(
|
||||
results["start"] = await token.copy_and_replace(
|
||||
"room_key", results["start"]
|
||||
).to_string()
|
||||
).to_string(self.store)
|
||||
|
||||
results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
|
||||
results["end"] = await token.copy_and_replace(
|
||||
"room_key", results["end"]
|
||||
).to_string(self.store)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -362,13 +362,13 @@ class SearchHandler(BaseHandler):
|
||||
self.storage, user.to_string(), res["events_after"]
|
||||
)
|
||||
|
||||
res["start"] = now_token.copy_and_replace(
|
||||
res["start"] = await now_token.copy_and_replace(
|
||||
"room_key", res["start"]
|
||||
).to_string()
|
||||
).to_string(self.store)
|
||||
|
||||
res["end"] = now_token.copy_and_replace(
|
||||
res["end"] = await now_token.copy_and_replace(
|
||||
"room_key", res["end"]
|
||||
).to_string()
|
||||
).to_string(self.store)
|
||||
|
||||
if include_profile:
|
||||
senders = {
|
||||
|
@ -110,7 +110,7 @@ class PurgeHistoryRestServlet(RestServlet):
|
||||
raise SynapseError(400, "Event is for wrong room.")
|
||||
|
||||
room_token = await self.store.get_topological_token_for_event(event_id)
|
||||
token = str(room_token)
|
||||
token = await room_token.to_string(self.store)
|
||||
|
||||
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
|
||||
elif "purge_up_to_ts" in body:
|
||||
|
@ -33,6 +33,7 @@ class EventStreamRestServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.event_stream_handler = hs.get_event_stream_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet):
|
||||
if b"room_id" in request.args:
|
||||
room_id = request.args[b"room_id"][0].decode("ascii")
|
||||
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
pagin_config = await PaginationConfig.from_request(self.store, request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if b"timeout" in request.args:
|
||||
try:
|
||||
|
@ -27,11 +27,12 @@ class InitialSyncRestServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.initial_sync_handler = hs.get_initial_sync_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
as_client_event = b"raw" not in request.args
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
pagination_config = await PaginationConfig.from_request(self.store, request)
|
||||
include_archived = parse_boolean(request, "archived", default=False)
|
||||
content = await self.initial_sync_handler.snapshot_all_rooms(
|
||||
user_id=requester.user.to_string(),
|
||||
|
@ -451,6 +451,7 @@ class RoomMemberListRestServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.message_handler = hs.get_message_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
|
||||
if at_token_string is None:
|
||||
at_token = None
|
||||
else:
|
||||
at_token = StreamToken.from_string(at_token_string)
|
||||
at_token = await StreamToken.from_string(self.store, at_token_string)
|
||||
|
||||
# let you filter down on particular memberships.
|
||||
# XXX: this may not be the best shape for this API - we could pass in a filter
|
||||
@ -521,10 +522,13 @@ class RoomMessageListRestServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.pagination_handler = hs.get_pagination_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(request, default_limit=10)
|
||||
pagination_config = await PaginationConfig.from_request(
|
||||
self.store, request, default_limit=10
|
||||
)
|
||||
as_client_event = b"raw" not in request.args
|
||||
filter_str = parse_string(request, b"filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
@ -580,10 +584,11 @@ class RoomInitialSyncRestServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.initial_sync_handler = hs.get_initial_sync_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
pagination_config = await PaginationConfig.from_request(self.store, request)
|
||||
content = await self.initial_sync_handler.room_initial_sync(
|
||||
room_id=room_id, requester=requester, pagin_config=pagination_config
|
||||
)
|
||||
|
@ -180,6 +180,7 @@ class KeyChangesServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet):
|
||||
# changes after the "to" as well as before.
|
||||
set_tag("to", parse_string(request, "to"))
|
||||
|
||||
from_token = StreamToken.from_string(from_token_string)
|
||||
from_token = await StreamToken.from_string(self.store, from_token_string)
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
|
@ -77,6 +77,7 @@ class SyncRestServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.sync_handler = hs.get_sync_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.filtering = hs.get_filtering()
|
||||
@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet):
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
since_token = None
|
||||
if since is not None:
|
||||
since_token = StreamToken.from_string(since)
|
||||
else:
|
||||
since_token = None
|
||||
since_token = await StreamToken.from_string(self.store, since)
|
||||
|
||||
# send any outstanding server notices to the user.
|
||||
await self._server_notices_sender.on_user_syncing(user.to_string())
|
||||
@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet):
|
||||
"leave": sync_result.groups.leave,
|
||||
},
|
||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||
"next_batch": sync_result.next_batch.to_string(),
|
||||
"next_batch": await sync_result.next_batch.to_string(self.store),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet):
|
||||
result = {
|
||||
"timeline": {
|
||||
"events": serialized_timeline,
|
||||
"prev_batch": room.timeline.prev_batch.to_string(),
|
||||
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
|
||||
"limited": room.timeline.limited,
|
||||
},
|
||||
"state": {"events": serialized_state},
|
||||
|
@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||
The set of state groups that are referenced by deleted events.
|
||||
"""
|
||||
|
||||
parsed_token = await RoomStreamToken.parse(self, token)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"purge_history",
|
||||
self._purge_history_txn,
|
||||
room_id,
|
||||
token,
|
||||
parsed_token,
|
||||
delete_local_events,
|
||||
)
|
||||
|
||||
def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
|
||||
token = RoomStreamToken.parse(token_str)
|
||||
|
||||
def _purge_history_txn(self, txn, room_id, token, delete_local_events):
|
||||
# Tables that should be pruned:
|
||||
# event_auth
|
||||
# event_backward_extremities
|
||||
|
@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@ -21,6 +20,7 @@ import attr
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import StreamToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -39,8 +39,9 @@ class PaginationConfig:
|
||||
limit = attr.ib(type=Optional[int])
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
async def from_request(
|
||||
cls,
|
||||
store: "DataStore",
|
||||
request: SynapseRequest,
|
||||
raise_invalid_params: bool = True,
|
||||
default_limit: Optional[int] = None,
|
||||
@ -54,13 +55,13 @@ class PaginationConfig:
|
||||
if from_tok == "END":
|
||||
from_tok = None # For backwards compat.
|
||||
elif from_tok:
|
||||
from_tok = StreamToken.from_string(from_tok)
|
||||
from_tok = await StreamToken.from_string(store, from_tok)
|
||||
except Exception:
|
||||
raise SynapseError(400, "'from' parameter is invalid")
|
||||
|
||||
try:
|
||||
if to_tok:
|
||||
to_tok = StreamToken.from_string(to_tok)
|
||||
to_tok = await StreamToken.from_string(store, to_tok)
|
||||
except Exception:
|
||||
raise SynapseError(400, "'to' parameter is invalid")
|
||||
|
||||
|
@ -18,7 +18,17 @@ import re
|
||||
import string
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import attr
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
# define a version of typing.Collection that works on python 3.5
|
||||
if sys.version_info[:3] >= (3, 6, 0):
|
||||
from typing import Collection
|
||||
@ -393,7 +406,7 @@ class RoomStreamToken:
|
||||
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string: str) -> "RoomStreamToken":
|
||||
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
|
||||
try:
|
||||
if string[0] == "s":
|
||||
return cls(topological=None, stream=int(string[1:]))
|
||||
@ -428,7 +441,7 @@ class RoomStreamToken:
|
||||
def as_tuple(self) -> Tuple[Optional[int], int]:
|
||||
return (self.topological, self.stream)
|
||||
|
||||
def __str__(self) -> str:
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
if self.topological is not None:
|
||||
return "t%d-%d" % (self.topological, self.stream)
|
||||
else:
|
||||
@ -453,18 +466,32 @@ class StreamToken:
|
||||
START = None # type: StreamToken
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string):
|
||||
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
|
||||
try:
|
||||
keys = string.split(cls._SEPARATOR)
|
||||
while len(keys) < len(attr.fields(cls)):
|
||||
# i.e. old token from before receipt_key
|
||||
keys.append("0")
|
||||
return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
|
||||
return cls(
|
||||
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
|
||||
)
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid Token")
|
||||
|
||||
def to_string(self):
|
||||
return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
return self._SEPARATOR.join(
|
||||
[
|
||||
await self.room_key.to_string(store),
|
||||
str(self.presence_key),
|
||||
str(self.typing_key),
|
||||
str(self.receipt_key),
|
||||
str(self.account_data_key),
|
||||
str(self.push_rules_key),
|
||||
str(self.to_device_key),
|
||||
str(self.device_list_key),
|
||||
str(self.groups_key),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def room_stream_id(self):
|
||||
@ -493,7 +520,7 @@ class StreamToken:
|
||||
return attr.evolve(self, **{key: new_value})
|
||||
|
||||
|
||||
StreamToken.START = StreamToken.from_string("s0_0")
|
||||
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
|
@ -902,16 +902,18 @@ class RoomMessageListTestCase(RoomBase):
|
||||
|
||||
# Send a first message in the room, which will be removed by the purge.
|
||||
first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
|
||||
first_token = str(
|
||||
self.get_success(store.get_topological_token_for_event(first_event_id))
|
||||
first_token = self.get_success(
|
||||
store.get_topological_token_for_event(first_event_id)
|
||||
)
|
||||
first_token_str = self.get_success(first_token.to_string(store))
|
||||
|
||||
# Send a second message in the room, which won't be removed, and which we'll
|
||||
# use as the marker to purge events before.
|
||||
second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
|
||||
second_token = str(
|
||||
self.get_success(store.get_topological_token_for_event(second_event_id))
|
||||
second_token = self.get_success(
|
||||
store.get_topological_token_for_event(second_event_id)
|
||||
)
|
||||
second_token_str = self.get_success(second_token.to_string(store))
|
||||
|
||||
# Send a third event in the room to ensure we don't fall under any edge case
|
||||
# due to our marker being the latest forward extremity in the room.
|
||||
@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
|
||||
% (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
|
||||
% (
|
||||
self.room_id,
|
||||
second_token_str,
|
||||
json.dumps({"types": [EventTypes.Message]}),
|
||||
),
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
pagination_handler._purge_history(
|
||||
purge_id=purge_id,
|
||||
room_id=self.room_id,
|
||||
token=second_token,
|
||||
token=second_token_str,
|
||||
delete_local_events=True,
|
||||
)
|
||||
)
|
||||
@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
|
||||
% (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
|
||||
% (
|
||||
self.room_id,
|
||||
second_token_str,
|
||||
json.dumps({"types": [EventTypes.Message]}),
|
||||
),
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
|
||||
% (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
|
||||
% (
|
||||
self.room_id,
|
||||
first_token_str,
|
||||
json.dumps({"types": [EventTypes.Message]}),
|
||||
),
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase):
|
||||
storage = self.hs.get_storage()
|
||||
|
||||
# Get the topological token
|
||||
event = str(
|
||||
self.get_success(store.get_topological_token_for_event(last["event_id"]))
|
||||
token = self.get_success(
|
||||
store.get_topological_token_for_event(last["event_id"])
|
||||
)
|
||||
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
|
||||
|
||||
# Purge everything before this topological token
|
||||
self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
|
||||
self.get_success(
|
||||
storage.purge_events.purge_history(self.room_id, token_str, True)
|
||||
)
|
||||
|
||||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||
# and last is not.
|
||||
|
Loading…
Reference in New Issue
Block a user