Use stable identifiers for faster joins (#14832)

* Use new query param when requesting a partial join

* Read new query param when serving partial join

* Provide new field names when serving partial joins

* Read new field names from partial join response

* Changelog
This commit is contained in:
David Robertson 2023-01-13 17:58:53 +00:00 committed by GitHub
parent 73ff493dfb
commit 52ae80dd1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 87 additions and 22 deletions

1
changelog.d/14832.misc Normal file
View File

@ -0,0 +1 @@
Faster joins: use stable identifiers from [MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706).

View File

@ -725,10 +725,12 @@ class FederationServer(FederationBase):
"state": [p.get_pdu_json(time_now) for p in state_events], "state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state, "org.matrix.msc3706.partial_state": caller_supports_partial_state,
"members_omitted": caller_supports_partial_state,
} }
if servers_in_room is not None: if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room) resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
resp["servers_in_room"] = list(servers_in_room)
return resp return resp

View File

@ -357,6 +357,7 @@ class TransportLayerClient:
if self._faster_joins_enabled: if self._faster_joins_enabled:
# lazy-load state on join # lazy-load state on join
query_params["org.matrix.msc3706.partial_state"] = "true" query_params["org.matrix.msc3706.partial_state"] = "true"
query_params["omit_members"] = "true"
return await self.client.put_json( return await self.client.put_json(
destination=destination, destination=destination,
@ -909,6 +910,14 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
use_float="True", use_float="True",
) )
) )
# The stable field name comes last, so it "wins" if the fields disagree
self._coros.append(
ijson.items_coro(
_partial_state_parser(self._response),
"members_omitted",
use_float="True",
)
)
self._coros.append( self._coros.append(
ijson.items_coro( ijson.items_coro(
@ -918,6 +927,15 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
) )
) )
# Again, stable field name comes last
self._coros.append(
ijson.items_coro(
_servers_in_room_parser(self._response),
"servers_in_room",
use_float="True",
)
)
def write(self, data: bytes) -> int: def write(self, data: bytes) -> int:
for c in self._coros: for c in self._coros:
c.send(data) c.send(data)

View File

@ -437,9 +437,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
partial_state = False partial_state = False
if self._msc3706_enabled: if self._msc3706_enabled:
# The stable query parameter wins, if it disagrees with the unstable
# parameter for some reason.
stable_param = parse_boolean_from_args(query, "omit_members", default=None)
if stable_param is not None:
partial_state = stable_param
else:
partial_state = parse_boolean_from_args( partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False query, "org.matrix.msc3706.partial_state", default=False
) )
result = await self.handler.on_send_join_request( result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state origin, content, room_id, caller_supports_partial_state=partial_state
) )

View File

@ -224,7 +224,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
) )
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"PUT", "PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", f"/_matrix/federation/v2/send_join/{self._room_id}/x?omit_members=true",
content=join_event_dict, content=join_event_dict,
) )
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

View File

@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
import json import json
from typing import List, Optional
from unittest.mock import Mock from unittest.mock import Mock
import ijson.common import ijson.common
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.federation.transport.client import SendJoinParser from synapse.federation.transport.client import SendJoinParser
from synapse.types import JsonDict
from synapse.util import ExceptionBundle from synapse.util import ExceptionBundle
from tests.unittest import TestCase from tests.unittest import TestCase
@ -71,11 +73,9 @@ class SendJoinParserTestCase(TestCase):
def test_partial_state(self) -> None: def test_partial_state(self) -> None:
"""Check that the partial_state flag is correctly parsed""" """Check that the partial_state flag is correctly parsed"""
parser = SendJoinParser(RoomVersions.V1, False)
response = {
"org.matrix.msc3706.partial_state": True,
}
def parse(response: JsonDict) -> bool:
parser = SendJoinParser(RoomVersions.V1, False)
serialised_response = json.dumps(response).encode() serialised_response = json.dumps(response).encode()
# Send data to the parser # Send data to the parser
@ -83,13 +83,27 @@ class SendJoinParserTestCase(TestCase):
# Retrieve and check the parsed SendJoinResponse # Retrieve and check the parsed SendJoinResponse
parsed_response = parser.finish() parsed_response = parser.finish()
self.assertTrue(parsed_response.partial_state) return parsed_response.partial_state
self.assertTrue(parse({"members_omitted": True}))
self.assertTrue(parse({"org.matrix.msc3706.partial_state": True}))
self.assertFalse(parse({"members_omitted": False}))
self.assertFalse(parse({"org.matrix.msc3706.partial_state": False}))
# If there's a conflict, the stable field wins.
self.assertTrue(
parse({"members_omitted": True, "org.matrix.msc3706.partial_state": False})
)
self.assertFalse(
parse({"members_omitted": False, "org.matrix.msc3706.partial_state": True})
)
def test_servers_in_room(self) -> None: def test_servers_in_room(self) -> None:
"""Check that the servers_in_room field is correctly parsed""" """Check that the servers_in_room field is correctly parsed"""
parser = SendJoinParser(RoomVersions.V1, False)
response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
def parse(response: JsonDict) -> Optional[List[str]]:
parser = SendJoinParser(RoomVersions.V1, False)
serialised_response = json.dumps(response).encode() serialised_response = json.dumps(response).encode()
# Send data to the parser # Send data to the parser
@ -97,7 +111,30 @@ class SendJoinParserTestCase(TestCase):
# Retrieve and check the parsed SendJoinResponse # Retrieve and check the parsed SendJoinResponse
parsed_response = parser.finish() parsed_response = parser.finish()
self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) return parsed_response.servers_in_room
self.assertEqual(
parse({"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}),
["hs1", "hs2"],
)
self.assertEqual(parse({"servers_in_room": ["example.com"]}), ["example.com"])
# If both are provided, the stable identifier should win
self.assertEqual(
parse(
{
"org.matrix.msc3706.servers_in_room": ["old"],
"servers_in_room": ["new"],
}
),
["new"],
)
# And lastly, we should be able to tell if neither field was present.
self.assertEqual(
parse({}),
None,
)
def test_errors_closing_coroutines(self) -> None: def test_errors_closing_coroutines(self) -> None:
"""Check we close all coroutines, even if closing the first raises an Exception. """Check we close all coroutines, even if closing the first raises an Exception.