Add ability to run multiple pusher instances (#7855)

This reuses the same scheme as federation sender sharding
This commit is contained in:
Erik Johnston 2020-07-16 14:06:28 +01:00 committed by GitHub
parent a827838706
commit 649a7ead5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 293 additions and 82 deletions

1
changelog.d/7855.feature Normal file
View File

@ -0,0 +1 @@
Add experimental support for running multiple pusher workers.

View File

@ -19,9 +19,11 @@ import argparse
import errno import errno
import os import os
from collections import OrderedDict from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent from textwrap import dedent
from typing import Any, MutableMapping, Optional from typing import Any, List, MutableMapping, Optional
import attr
import yaml import yaml
@ -717,4 +719,36 @@ def find_config_files(search_paths):
return config_files return config_files
__all__ = ["Config", "RootConfig"] @attr.s
@attr.s
class ShardedWorkerHandlingConfig:
    """Deterministic assignment of sharded work to a worker instance.

    Each unit of work is identified by a string `key` (for example, the
    federation senders use the destination server name as the key).  The key
    is hashed and mapped onto one of the configured instances, so every
    worker independently agrees on who owns which key.
    """

    # Names of the worker instances participating in the shard.
    instances = attr.ib(type=List[str])

    def should_handle(self, instance_name: str, key: str) -> bool:
        """Return True if `instance_name` is the owner of `key`.
        """
        # With zero or one configured instances there is nothing to shard:
        # whichever instance asks is responsible.
        if len(self.instances) <= 1:
            return True

        # Hash the key and reduce it modulo the instance count to pick an
        # owner.  The modulo of a 256-bit hash is not perfectly uniform, but
        # the bias is negligible at this size.
        digest = sha256(key.encode("utf8")).digest()
        owner_index = int.from_bytes(digest, byteorder="little") % len(self.instances)
        return self.instances[owner_index] == instance_name
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]

View File

@ -137,3 +137,8 @@ class Config:
def read_config_files(config_files: List[str]): ... def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ... def find_config_files(search_paths: List[str]): ...
# Stub for the sharded-work config: picks which worker instance is
# responsible for handling a given key (see the implementation module).
class ShardedWorkerHandlingConfig:
    # Names of the worker instances participating in the shard.
    instances: List[str]
    def __init__(self, instances: List[str]) -> None: ...
    def should_handle(self, instance_name: str, key: str) -> bool: ...

View File

@ -13,42 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from hashlib import sha256 from typing import Optional
from typing import List, Optional
import attr
from netaddr import IPSet from netaddr import IPSet
from ._base import Config, ConfigError from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
@attr.s
class ShardedFederationSendingConfig:
"""Algorithm for choosing which federation sender instance is responsible
for which destination host.
"""
instances = attr.ib(type=List[str])
def should_send_to(self, instance_name: str, destination: str) -> bool:
"""Whether this instance is responsible for sending transcations for
the given host.
"""
# If multiple federation senders are not defined we always return true.
if not self.instances or len(self.instances) == 1:
return True
# We shard by taking the hash, modulo it by the number of federation
# senders and then checking whether this instance matches the instance
# at that index.
#
# (Technically this introduces some bias and is not entirely uniform, but
# since the hash is so large the bias is ridiculously small).
dest_hash = sha256(destination.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
class FederationConfig(Config): class FederationConfig(Config):
@ -61,7 +30,7 @@ class FederationConfig(Config):
self.send_federation = config.get("send_federation", True) self.send_federation = config.get("send_federation", True)
federation_sender_instances = config.get("federation_sender_instances") or [] federation_sender_instances = config.get("federation_sender_instances") or []
self.federation_shard_config = ShardedFederationSendingConfig( self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances federation_sender_instances
) )

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ShardedWorkerHandlingConfig
class PushConfig(Config): class PushConfig(Config):
@ -24,6 +24,9 @@ class PushConfig(Config):
push_config = config.get("push", {}) push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True) self.push_include_content = push_config.get("include_content", True)
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# There was a a 'redact_content' setting but mistakenly read from the # There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log, # 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade. # but do not honour it to avoid nasty surprises when people upgrade.

View File

@ -197,7 +197,7 @@ class FederationSender(object):
destinations = { destinations = {
d d
for d in destinations for d in destinations
if self._federation_shard_config.should_send_to( if self._federation_shard_config.should_handle(
self._instance_name, d self._instance_name, d
) )
} }
@ -335,7 +335,7 @@ class FederationSender(object):
d d
for d in domains for d in domains
if d != self.server_name if d != self.server_name
and self._federation_shard_config.should_send_to(self._instance_name, d) and self._federation_shard_config.should_handle(self._instance_name, d)
] ]
if not domains: if not domains:
return return
@ -441,7 +441,7 @@ class FederationSender(object):
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
if not self._federation_shard_config.should_send_to( if not self._federation_shard_config.should_handle(
self._instance_name, destination self._instance_name, destination
): ):
continue continue
@ -460,7 +460,7 @@ class FederationSender(object):
if destination == self.server_name: if destination == self.server_name:
continue continue
if not self._federation_shard_config.should_send_to( if not self._federation_shard_config.should_handle(
self._instance_name, destination self._instance_name, destination
): ):
continue continue
@ -486,7 +486,7 @@ class FederationSender(object):
logger.info("Not sending EDU to ourselves") logger.info("Not sending EDU to ourselves")
return return
if not self._federation_shard_config.should_send_to( if not self._federation_shard_config.should_handle(
self._instance_name, destination self._instance_name, destination
): ):
return return
@ -507,7 +507,7 @@ class FederationSender(object):
edu: edu to send edu: edu to send
key: clobbering key for this edu key: clobbering key for this edu
""" """
if not self._federation_shard_config.should_send_to( if not self._federation_shard_config.should_handle(
self._instance_name, edu.destination self._instance_name, edu.destination
): ):
return return
@ -523,7 +523,7 @@ class FederationSender(object):
logger.warning("Not sending device update to ourselves") logger.warning("Not sending device update to ourselves")
return return
if not self._federation_shard_config.should_send_to( if not self._federation_shard_config.should_handle(
self._instance_name, destination self._instance_name, destination
): ):
return return
@ -541,7 +541,7 @@ class FederationSender(object):
logger.warning("Not waking up ourselves") logger.warning("Not waking up ourselves")
return return
if not self._federation_shard_config.should_send_to( if not self._federation_shard_config.should_handle(
self._instance_name, destination self._instance_name, destination
): ):
return return

View File

@ -78,7 +78,7 @@ class PerDestinationQueue(object):
self._federation_shard_config = hs.config.federation.federation_shard_config self._federation_shard_config = hs.config.federation.federation_shard_config
self._should_send_on_this_instance = True self._should_send_on_this_instance = True
if not self._federation_shard_config.should_send_to( if not self._federation_shard_config.should_handle(
self._instance_name, destination self._instance_name, destination
): ):
# We don't raise an exception here to avoid taking out any other # We don't raise an exception here to avoid taking out any other

View File

@ -15,13 +15,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import defaultdict from typing import TYPE_CHECKING, Dict, Union
from threading import Lock
from typing import Dict, Tuple, Union from prometheus_client import Gauge
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher from synapse.push.emailpusher import EmailPusher
@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
synapse_pushers = Gauge(
"synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
)
class PusherPool: class PusherPool:
""" """
The pusher pool. This is responsible for dispatching notifications of new events to The pusher pool. This is responsible for dispatching notifications of new events to
@ -47,36 +55,20 @@ class PusherPool:
Pusher.on_new_receipts are not expected to return deferreds. Pusher.on_new_receipts are not expected to return deferreds.
""" """
def __init__(self, _hs): def __init__(self, hs: "HomeServer"):
self.hs = _hs self.hs = hs
self.pusher_factory = PusherFactory(_hs) self.pusher_factory = PusherFactory(hs)
self._should_start_pushers = _hs.config.start_pushers self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
# map from user id to app_id:pushkey to pusher # map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
# a lock for the pushers dict, since `count_pushers` is called from an different
# and we otherwise get concurrent modification errors
self._pushers_lock = Lock()
def count_pushers():
results = defaultdict(int) # type: Dict[Tuple[str, str], int]
with self._pushers_lock:
for pushers in self.pushers.values():
for pusher in pushers.values():
k = (type(pusher).__name__, pusher.app_id)
results[k] += 1
return results
LaterGauge(
name="synapse_pushers",
desc="the number of active pushers",
labels=["kind", "app_id"],
caller=count_pushers,
)
def start(self): def start(self):
"""Starts the pushers off in a background process. """Starts the pushers off in a background process.
""" """
@ -104,6 +96,7 @@ class PusherPool:
Returns: Returns:
Deferred[EmailPusher|HttpPusher] Deferred[EmailPusher|HttpPusher]
""" """
time_now_msec = self.clock.time_msec() time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
@ -176,6 +169,9 @@ class PusherPool:
access_tokens (Iterable[int]): access token *ids* to remove pushers access_tokens (Iterable[int]): access token *ids* to remove pushers
for for
""" """
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens) tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)): for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p["access_token"] in tokens: if p["access_token"] in tokens:
@ -237,6 +233,9 @@ class PusherPool:
if not self._should_start_pushers: if not self._should_start_pushers:
return return
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None pusher_dict = None
@ -275,6 +274,11 @@ class PusherPool:
Returns: Returns:
Deferred[EmailPusher|HttpPusher] Deferred[EmailPusher|HttpPusher]
""" """
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
):
return
try: try:
p = self.pusher_factory.create_pusher(pusherdict) p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e: except PusherConfigException as e:
@ -298,11 +302,12 @@ class PusherPool:
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
with self._pushers_lock: byuser = self.pushers.setdefault(pusherdict["user_name"], {})
byuser = self.pushers.setdefault(pusherdict["user_name"], {}) if appid_pushkey in byuser:
if appid_pushkey in byuser: byuser[appid_pushkey].on_stop()
byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p
byuser[appid_pushkey] = p
synapse_pushers.labels(type(p).__name__, p.app_id).inc()
# Check if there *may* be push to process. We do this as this check is a # Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to # lot cheaper to do than actually fetching the exact rows we need to
@ -330,9 +335,10 @@ class PusherPool:
if appid_pushkey in byuser: if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
byuser[appid_pushkey].on_stop() pusher = byuser.pop(appid_pushkey)
with self._pushers_lock: pusher.on_stop()
del byuser[appid_pushkey]
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id( yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id app_id, pushkey, user_id

View File

@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from mock import Mock
from twisted.internet import defer
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests.replication._base import BaseMultiWorkerStreamTestCase
logger = logging.getLogger(__name__)
class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
    """Checks pusher sharding works.

    Spins up one or more `synapse.app.pusher` workers with a mocked HTTP
    client and verifies that push notifications for a user are sent by
    exactly the worker instance that the shard config assigns to that user.
    """

    servlets = [
        admin.register_servlets_for_client_rest_resource,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        # Register a user who sends a message that we'll get notified about
        self.other_user_id = self.register_user("otheruser", "pass")
        self.other_access_token = self.login("otheruser", "pass")

    def default_config(self):
        # Pushers are started on the worker apps under test, not on the
        # main process.
        conf = super().default_config()
        conf["start_pushers"] = False
        return conf

    def _create_pusher_and_send_msg(self, localpart):
        """Register `localpart` with an HTTP pusher, then have the other
        user send them a message.

        Returns the event ID of the sent message, so callers can check it
        shows up in the push notification.
        """
        # Create a user that will get push notifications
        user_id = self.register_user(localpart, "pass")
        access_token = self.login(localpart, "pass")

        # Register a pusher
        user_dict = self.get_success(
            self.hs.get_datastore().get_user_by_access_token(access_token)
        )
        token_id = user_dict["token_id"]

        self.get_success(
            self.hs.get_pusherpool().add_pusher(
                user_id=user_id,
                access_token=token_id,
                kind="http",
                app_id="m.http",
                app_display_name="HTTP Push Notifications",
                device_display_name="pushy push",
                pushkey="a@example.com",
                lang=None,
                data={"url": "https://push.example.com/push"},
            )
        )

        self.pump()

        # Create a room
        room = self.helper.create_room_as(user_id, tok=access_token)

        # The other user joins
        self.helper.join(
            room=room, user=self.other_user_id, tok=self.other_access_token
        )

        # The other user sends some messages
        response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
        event_id = response["event_id"]

        return event_id

    def test_send_push_single_worker(self):
        """Test that push notifications are sent when using a single pusher
        worker.
        """
        http_client_mock = Mock(spec_set=["post_json_get_json"])
        http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
            {}
        )

        self.make_worker_hs(
            "synapse.app.pusher",
            {"start_pushers": True},
            proxied_http_client=http_client_mock,
        )

        event_id = self._create_pusher_and_send_msg("user")

        # Advance time a bit, so the pusher will register something has happened
        self.pump()

        # The single worker should have made exactly one push request, for
        # the registered URL, carrying the sent event.
        http_client_mock.post_json_get_json.assert_called_once()
        self.assertEqual(
            http_client_mock.post_json_get_json.call_args[0][0],
            "https://push.example.com/push",
        )
        self.assertEqual(
            event_id,
            http_client_mock.post_json_get_json.call_args[0][1]["notification"][
                "event_id"
            ],
        )

    def test_send_push_multiple_workers(self):
        """Test that push notifications are sent by the correct instance when
        using sharded pusher workers.
        """
        http_client_mock1 = Mock(spec_set=["post_json_get_json"])
        http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
            {}
        )

        self.make_worker_hs(
            "synapse.app.pusher",
            {
                "start_pushers": True,
                "worker_name": "pusher1",
                "pusher_instances": ["pusher1", "pusher2"],
            },
            proxied_http_client=http_client_mock1,
        )

        http_client_mock2 = Mock(spec_set=["post_json_get_json"])
        http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
            {}
        )

        self.make_worker_hs(
            "synapse.app.pusher",
            {
                "start_pushers": True,
                "worker_name": "pusher2",
                "pusher_instances": ["pusher1", "pusher2"],
            },
            proxied_http_client=http_client_mock2,
        )

        # We choose a user name that we know should go to pusher1.
        # NOTE(review): the user-to-instance mapping depends on the sharding
        # hash; "user2" was picked because it hashes to pusher1.
        event_id = self._create_pusher_and_send_msg("user2")

        # Advance time a bit, so the pusher will register something has happened
        self.pump()

        # Only pusher1 should have handled the push for this user.
        http_client_mock1.post_json_get_json.assert_called_once()
        http_client_mock2.post_json_get_json.assert_not_called()
        self.assertEqual(
            http_client_mock1.post_json_get_json.call_args[0][0],
            "https://push.example.com/push",
        )
        self.assertEqual(
            event_id,
            http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
                "event_id"
            ],
        )

        http_client_mock1.post_json_get_json.reset_mock()
        http_client_mock2.post_json_get_json.reset_mock()

        # Now we choose a user name that we know should go to pusher2.
        event_id = self._create_pusher_and_send_msg("user4")

        # Advance time a bit, so the pusher will register something has happened
        self.pump()

        # Only pusher2 should have handled the push for this user.
        http_client_mock1.post_json_get_json.assert_not_called()
        http_client_mock2.post_json_get_json.assert_called_once()
        self.assertEqual(
            http_client_mock2.post_json_get_json.call_args[0][0],
            "https://push.example.com/push",
        )
        self.assertEqual(
            event_id,
            http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
                "event_id"
            ],
        )