Factor out an is_mine_server_name method (#15542)

Add an `is_mine_server_name` method, similar to `is_mine_id`.

Ideally we would use this consistently, instead of sometimes comparing
against `hs.hostname` and other times reaching into
`hs.config.server.server_name`.

Also fix a bug in the tests where `hs.hostname` would sometimes differ
from `hs.config.server.server_name`.

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2023-05-05 15:06:22 +01:00 committed by GitHub
parent 83e7fa5eee
commit e46d5f3586
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 64 additions and 36 deletions

1
changelog.d/15542.misc Normal file
View File

@ -0,0 +1 @@
Factor out an `is_mine_server_name` method.

View File

@ -39,7 +39,7 @@ class AuthBlocking:
self._mau_limits_reserved_threepids = ( self._mau_limits_reserved_threepids = (
hs.config.server.mau_limits_reserved_threepids hs.config.server.mau_limits_reserved_threepids
) )
self._server_name = hs.hostname self._is_mine_server_name = hs.is_mine_server_name
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
async def check_auth_blocking( async def check_auth_blocking(
@ -77,7 +77,7 @@ class AuthBlocking:
if requester: if requester:
if requester.authenticated_entity.startswith("@"): if requester.authenticated_entity.startswith("@"):
user_id = requester.authenticated_entity user_id = requester.authenticated_entity
elif requester.authenticated_entity == self._server_name: elif self._is_mine_server_name(requester.authenticated_entity):
# We never block the server from doing actions on behalf of # We never block the server from doing actions on behalf of
# users. # users.
return return

View File

@ -173,7 +173,7 @@ class Keyring:
process_batch_callback=self._inner_fetch_key_requests, process_batch_callback=self._inner_fetch_key_requests,
) )
self._hostname = hs.hostname self._is_mine_server_name = hs.is_mine_server_name
# build a FetchKeyResult for each of our own keys, to shortcircuit the # build a FetchKeyResult for each of our own keys, to shortcircuit the
# fetcher. # fetcher.
@ -277,7 +277,7 @@ class Keyring:
# If we are the originating server, short-circuit the key-fetch for any keys # If we are the originating server, short-circuit the key-fetch for any keys
# we already have # we already have
if verify_request.server_name == self._hostname: if self._is_mine_server_name(verify_request.server_name):
for key_id in verify_request.key_ids: for key_id in verify_request.key_ids:
if key_id in self._local_verify_keys: if key_id in self._local_verify_keys:
found_keys[key_id] = self._local_verify_keys[key_id] found_keys[key_id] = self._local_verify_keys[key_id]

View File

@ -49,7 +49,7 @@ class FederationBase:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.server_name = hs.hostname self._is_mine_server_name = hs.is_mine_server_name
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.store = hs.get_datastores().main self.store = hs.get_datastores().main

View File

@ -854,7 +854,7 @@ class FederationClient(FederationBase):
for destination in destinations: for destination in destinations:
# We don't want to ask our own server for information we don't have # We don't want to ask our own server for information we don't have
if destination == self.server_name: if self._is_mine_server_name(destination):
continue continue
try: try:
@ -1536,7 +1536,7 @@ class FederationClient(FederationBase):
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None: ) -> None:
for destination in destinations: for destination in destinations:
if destination == self.server_name: if self._is_mine_server_name(destination):
continue continue
try: try:

View File

@ -129,6 +129,7 @@ class FederationServer(FederationBase):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.server_name = hs.hostname
self.handler = hs.get_federation_handler() self.handler = hs.get_federation_handler()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._federation_event_handler = hs.get_federation_event_handler() self._federation_event_handler = hs.get_federation_event_handler()
@ -942,7 +943,7 @@ class FederationServer(FederationBase):
authorising_server = get_domain_from_id( authorising_server = get_domain_from_id(
event.content[EventContentFields.AUTHORISING_USER] event.content[EventContentFields.AUTHORISING_USER]
) )
if authorising_server != self.server_name: if not self._is_mine_server_name(authorising_server):
raise SynapseError( raise SynapseError(
400, 400,
f"Cannot authorise request from resident server: {authorising_server}", f"Cannot authorise request from resident server: {authorising_server}",

View File

@ -68,6 +68,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
# We may have multiple federation sender instances, so we need to track # We may have multiple federation sender instances, so we need to track
# their positions separately. # their positions separately.
@ -198,7 +199,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
key: Optional[Hashable] = None, key: Optional[Hashable] = None,
) -> None: ) -> None:
"""As per FederationSender""" """As per FederationSender"""
if destination == self.server_name: if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves") logger.info("Not sending EDU to ourselves")
return return

View File

@ -362,6 +362,7 @@ class FederationSender(AbstractFederationSender):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
self._presence_router: Optional["PresenceRouter"] = None self._presence_router: Optional["PresenceRouter"] = None
self._transaction_manager = TransactionManager(hs) self._transaction_manager = TransactionManager(hs)
@ -766,7 +767,7 @@ class FederationSender(AbstractFederationSender):
domains = [ domains = [
d d
for d in domains_set for d in domains_set
if d != self.server_name if not self.is_mine_server_name(d)
and self._federation_shard_config.should_handle(self._instance_name, d) and self._federation_shard_config.should_handle(self._instance_name, d)
] ]
if not domains: if not domains:
@ -832,7 +833,7 @@ class FederationSender(AbstractFederationSender):
assert self.is_mine_id(state.user_id) assert self.is_mine_id(state.user_id)
for destination in destinations: for destination in destinations:
if destination == self.server_name: if self.is_mine_server_name(destination):
continue continue
if not self._federation_shard_config.should_handle( if not self._federation_shard_config.should_handle(
self._instance_name, destination self._instance_name, destination
@ -860,7 +861,7 @@ class FederationSender(AbstractFederationSender):
content: content of EDU content: content of EDU
key: clobbering key for this edu key: clobbering key for this edu
""" """
if destination == self.server_name: if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves") logger.info("Not sending EDU to ourselves")
return return
@ -897,7 +898,7 @@ class FederationSender(AbstractFederationSender):
queue.send_edu(edu) queue.send_edu(edu)
def send_device_messages(self, destination: str, immediate: bool = True) -> None: def send_device_messages(self, destination: str, immediate: bool = True) -> None:
if destination == self.server_name: if self.is_mine_server_name(destination):
logger.warning("Not sending device update to ourselves") logger.warning("Not sending device update to ourselves")
return return
@ -919,7 +920,7 @@ class FederationSender(AbstractFederationSender):
might have come back. might have come back.
""" """
if destination == self.server_name: if self.is_mine_server_name(destination):
logger.warning("Not waking up ourselves") logger.warning("Not waking up ourselves")
return return

View File

@ -58,9 +58,9 @@ class TransportLayerClient:
"""Sends federation HTTP requests to other servers""" """Sends federation HTTP requests to other servers"""
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client() self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
self._is_mine_server_name = hs.is_mine_server_name
async def get_room_state_ids( async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str self, destination: str, room_id: str, event_id: str
@ -235,7 +235,7 @@ class TransportLayerClient:
transaction.transaction_id, transaction.transaction_id,
) )
if transaction.destination == self.server_name: if self._is_mine_server_name(transaction.destination):
raise RuntimeError("Transport layer cannot send to itself!") raise RuntimeError("Transport layer cannot send to itself!")
# FIXME: This is only used by the tests. The actual json sent is # FIXME: This is only used by the tests. The actual json sent is

View File

@ -57,6 +57,7 @@ class Authenticator:
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.server_name = hs.hostname self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.federation_domain_whitelist = ( self.federation_domain_whitelist = (
hs.config.federation.federation_domain_whitelist hs.config.federation.federation_domain_whitelist
@ -100,7 +101,9 @@ class Authenticator:
json_request["signatures"].setdefault(origin, {})[key] = sig json_request["signatures"].setdefault(origin, {})[key] = sig
# if the origin_server sent a destination along it needs to match our own server_name # if the origin_server sent a destination along it needs to match our own server_name
if destination is not None and destination != self.server_name: if destination is not None and not self._is_mine_server_name(
destination
):
raise AuthenticationError( raise AuthenticationError(
HTTPStatus.UNAUTHORIZED, HTTPStatus.UNAUTHORIZED,
"Destination mismatch in auth header", "Destination mismatch in auth header",

View File

@ -29,7 +29,7 @@ from synapse.event_auth import (
) )
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.types import StateMap, StrCollection, get_domain_from_id from synapse.types import StateMap, StrCollection
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -47,6 +47,7 @@ class EventAuthHandler:
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self._state_storage_controller = hs.get_storage_controllers().state self._state_storage_controller = hs.get_storage_controllers().state
self._server_name = hs.hostname self._server_name = hs.hostname
self._is_mine_id = hs.is_mine_id
async def check_auth_rules_from_context( async def check_auth_rules_from_context(
self, self,
@ -247,7 +248,7 @@ class EventAuthHandler:
if not await self.is_user_in_rooms(allowed_rooms, user_id): if not await self.is_user_in_rooms(allowed_rooms, user_id):
# If this is a remote request, the user might be in an allowed room # If this is a remote request, the user might be in an allowed room
# that we do not know about. # that we do not know about.
if get_domain_from_id(user_id) != self._server_name: if not self._is_mine_id(user_id):
for room_id in allowed_rooms: for room_id in allowed_rooms:
if not await self._store.is_host_joined(room_id, self._server_name): if not await self._store.is_host_joined(room_id, self._server_name):
raise SynapseError( raise SynapseError(

View File

@ -141,6 +141,7 @@ class FederationHandler:
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@ -453,7 +454,7 @@ class FederationHandler:
for dom in domains: for dom in domains:
# We don't want to ask our own server for information we don't have # We don't want to ask our own server for information we don't have
if dom == self.server_name: if self.is_mine_server_name(dom):
continue continue
try: try:

View File

@ -163,6 +163,7 @@ class FederationEventHandler:
self._notifier = hs.get_notifier() self._notifier = hs.get_notifier()
self._is_mine_id = hs.is_mine_id self._is_mine_id = hs.is_mine_id
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname self._server_name = hs.hostname
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -688,7 +689,7 @@ class FederationEventHandler:
server from invalid events (there is probably no point in trying to server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.) re-fetch invalid events from every other HS in the room.)
""" """
if dest == self._server_name: if self._is_mine_server_name(dest):
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
events = await self._federation_client.backfill( events = await self._federation_client.backfill(

View File

@ -59,7 +59,7 @@ class ProfileHandler:
self.max_avatar_size = hs.config.server.max_avatar_size self.max_avatar_size = hs.config.server.max_avatar_size
self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
self.server_name = hs.config.server.server_name self._is_mine_server_name = hs.is_mine_server_name
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
@ -309,7 +309,7 @@ class ProfileHandler:
else: else:
server_name = host server_name = host
if server_name == self.server_name: if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id) media_info = await self.store.get_local_media(media_id)
else: else:
media_info = await self.store.get_cached_remote_media(server_name, media_id) media_info = await self.store.get_cached_remote_media(server_name, media_id)

View File

@ -194,6 +194,7 @@ class SsoHandler:
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self._server_name = hs.hostname self._server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
@ -802,7 +803,7 @@ class SsoHandler:
if profile["avatar_url"] is not None: if profile["avatar_url"] is not None:
server_name = profile["avatar_url"].split("/")[-2] server_name = profile["avatar_url"].split("/")[-2]
media_id = profile["avatar_url"].split("/")[-1] media_id = profile["avatar_url"].split("/")[-1]
if server_name == self._server_name: if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id) media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]: if media is not None and upload_name == media["upload_name"]:
logger.info("skipping saving the user avatar") logger.info("skipping saving the user avatar")

View File

@ -68,6 +68,7 @@ class FollowerTypingHandler:
self.server_name = hs.config.server.server_name self.server_name = hs.config.server.server_name
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
self.federation = None self.federation = None
if hs.should_send_federation(): if hs.should_send_federation():
@ -153,7 +154,7 @@ class FollowerTypingHandler:
member.room_id member.room_id
) )
for domain in hosts: for domain in hosts:
if domain != self.server_name: if not self.is_mine_server_name(domain):
logger.debug("sending typing update to %s", domain) logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu( self.federation.build_and_send_edu(
destination=domain, destination=domain,

View File

@ -258,7 +258,7 @@ class DeleteMediaByID(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.server_name = hs.hostname self._is_mine_server_name = hs.is_mine_server_name
self.media_repository = hs.get_media_repository() self.media_repository = hs.get_media_repository()
async def on_DELETE( async def on_DELETE(
@ -266,7 +266,7 @@ class DeleteMediaByID(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name: if not self._is_mine_server_name(server_name):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
if await self.store.get_local_media(media_id) is None: if await self.store.get_local_media(media_id) is None:

View File

@ -501,7 +501,7 @@ class PublicRoomListRestServlet(RestServlet):
limit = None limit = None
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server.server_name: if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid. # Ensure the server is valid.
try: try:
parse_and_validate_server_name(server) parse_and_validate_server_name(server)
@ -551,7 +551,7 @@ class PublicRoomListRestServlet(RestServlet):
limit = None limit = None
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server.server_name: if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid. # Ensure the server is valid.
try: try:
parse_and_validate_server_name(server) parse_and_validate_server_name(server)

View File

@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__() super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self._is_mine_server_name = hs.is_mine_server_name
async def _async_render_GET(self, request: SynapseRequest) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request) set_cors_headers(request)
@ -59,7 +59,7 @@ class DownloadResource(DirectServeJsonResource):
b"no-referrer", b"no-referrer",
) )
server_name, media_id, name = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(request, media_id, name) await self.media_repo.get_local_media(request, media_id, name)
else: else:
allow_remote = parse_boolean(request, "allow_remote", default=True) allow_remote = parse_boolean(request, "allow_remote", default=True)

View File

@ -59,7 +59,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo = media_repo self.media_repo = media_repo
self.media_storage = media_storage self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.server_name = hs.hostname self._is_mine_server_name = hs.is_mine_server_name
async def _async_render_GET(self, request: SynapseRequest) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request) set_cors_headers(request)
@ -71,7 +71,7 @@ class ThumbnailResource(DirectServeJsonResource):
# TODO Parse the Accept header to get an prioritised list of thumbnail types. # TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png" m_type = "image/png"
if server_name == self.server_name: if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails: if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail( await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type request, media_id, width, height, method, m_type

View File

@ -377,6 +377,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return False return False
return localpart_hostname[1] == self.hostname return localpart_hostname[1] == self.hostname
def is_mine_server_name(self, server_name: str) -> bool:
"""Determines whether a server name refers to this homeserver."""
return server_name == self.hostname
@cache_in_self @cache_in_self
def get_clock(self) -> Clock: def get_clock(self) -> Clock:
return Clock(self._reactor) return Clock(self._reactor)

View File

@ -996,7 +996,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
If it is `None` media will be removed from quarantine If it is `None` media will be removed from quarantine
""" """
logger.info("Quarantining media: %s/%s", server_name, media_id) logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server.server_name is_local = self.hs.is_mine_server_name(server_name)
def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int: def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
local_mxcs = [media_id] if is_local else [] local_mxcs = [media_id] if is_local else []

View File

@ -566,7 +566,9 @@ class HomeserverTestCase(TestCase):
client_ip, client_ip,
) )
def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: def setup_test_homeserver(
self, name: Optional[str] = None, **kwargs: Any
) -> HomeServer:
""" """
Set up the test homeserver, meant to be called by the overridable Set up the test homeserver, meant to be called by the overridable
make_homeserver. It automatically passes through the test class's make_homeserver. It automatically passes through the test class's
@ -585,15 +587,25 @@ class HomeserverTestCase(TestCase):
else: else:
config = kwargs["config"] config = kwargs["config"]
# The server name can be specified using either the `name` argument or a config
# override. The `name` argument takes precedence over any config overrides.
if name is not None:
config["server_name"] = name
# Parse the config from a config dict into a HomeServerConfig # Parse the config from a config dict into a HomeServerConfig
config_obj = make_homeserver_config_obj(config) config_obj = make_homeserver_config_obj(config)
kwargs["config"] = config_obj kwargs["config"] = config_obj
# The server name in the config is now `name`, if provided, or the `server_name`
# from a config override, or the default of "test". Whichever it is, we
# construct a homeserver with a matching name.
kwargs["name"] = config_obj.server.server_name
async def run_bg_updates() -> None: async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"): with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False)) self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) hs = setup_test_homeserver(self.addCleanup, **kwargs)
stor = hs.get_datastores().main stor = hs.get_datastores().main
# Run the database background updates, when running against "master". # Run the database background updates, when running against "master".