Fixup slave stores

This commit is contained in:
Erik Johnston 2019-03-04 18:03:29 +00:00
parent aba5eeabd5
commit a84b8d56c2
6 changed files with 572 additions and 577 deletions

View file

@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage import DataStore
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceInboxStore(BaseSlavedStore):
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token)
get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device)
get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote)
delete_messages_for_device = __func__(DataStore.delete_messages_for_device)
delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote)
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()

View file

@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.devices import DeviceWorkerStore
from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceStore(BaseSlavedStore):
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore):
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = __func__(DataStore.get_device_stream_token)
get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
_get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
_get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
_mark_as_sent_devices_by_remote_txn = (
__func__(DataStore._mark_as_sent_devices_by_remote_txn)
)
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore):
if stream_name == "device_lists":
self._device_list_id_gen.advance(token)
for row in rows:
self._device_list_stream_cache.entity_has_changed(
row.user_id, token
self._invalidate_caches_for_devices(
token, row.user_id, row.destination,
)
if row.destination:
self._device_list_federation_stream_cache.entity_has_changed(
row.destination, token
)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(
user_id, token
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
self._get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))

View file

@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def __init__(self, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",