# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import threading
from typing import Callable, List, Optional

from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor


class SequenceGenerator(metaclass=abc.ABCMeta):
    """A class which generates a unique sequence of integers"""

    @abc.abstractmethod
    def get_next_id_txn(self, txn: Cursor) -> int:
        """Gets the next ID in the sequence"""
        ...

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Gets the next `n` IDs in the sequence

        This default implementation simply calls `get_next_id_txn` `n` times;
        subclasses may override it with something more efficient.

        Args:
            txn: the database cursor to use
            n: the number of IDs to allocate. If this is zero or negative, an
                empty list is returned.

        Returns:
            The newly-allocated IDs, in the order they were allocated.
        """
        return [self.get_next_id_txn(txn) for _ in range(n)]


class PostgresSequenceGenerator(SequenceGenerator):
    """An implementation of SequenceGenerator which uses a postgres sequence"""

    def __init__(self, sequence_name: str):
        """
        Args:
            sequence_name: the name of the postgres sequence to draw IDs from
        """
        self._sequence_name = sequence_name

    def get_next_id_txn(self, txn: Cursor) -> int:
        """Gets the next ID by calling `nextval` on the sequence"""
        txn.execute("SELECT nextval(?)", (self._sequence_name,))
        return txn.fetchone()[0]

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Gets the next `n` IDs from the sequence in a single query"""
        txn.execute(
            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
        )
        return [i for (i,) in txn]


GetFirstCallbackType = Callable[[Cursor], int]


class LocalSequenceGenerator(SequenceGenerator):
    """An implementation of SequenceGenerator which uses local locking

    This only works reliably if there are no other worker processes generating IDs at
    the same time.
    """

    def __init__(self, get_first_callback: GetFirstCallbackType):
        """
        Args:
            get_first_callback: a callback which is called on the first request
                for an ID; should return the current maximum id
        """
        # the callback. this is cleared after it is called, so that it can be GCed.
        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]

        # The current max value, or None if we haven't looked in the DB yet.
        self._current_max_id = None  # type: Optional[int]
        self._lock = threading.Lock()

    def _load_current_max_id(self, txn: Cursor) -> int:
        """Returns the current max ID, asking the callback for it if necessary

        Must be called with `self._lock` held.
        """
        if self._current_max_id is None:
            assert self._callback is not None
            self._current_max_id = self._callback(txn)
            self._callback = None
        return self._current_max_id

    def get_next_id_txn(self, txn: Cursor) -> int:
        """Gets the next ID in the sequence"""
        # We do application locking here since if we're using sqlite then
        # we are a single process synapse.
        with self._lock:
            self._current_max_id = self._load_current_max_id(txn) + 1
            return self._current_max_id

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Gets the next `n` IDs in the sequence as one contiguous range

        Args:
            txn: the database cursor, passed to the first-ID callback if the
                current maximum has not been loaded yet
            n: the number of IDs to allocate. If this is zero or negative, an
                empty list is returned and no IDs are consumed.
        """
        if n <= 0:
            # Never move the counter backwards; hand out nothing instead.
            return []

        # Take the lock once so that the whole range is allocated atomically.
        with self._lock:
            first_id = self._load_current_max_id(txn) + 1
            self._current_max_id = first_id + n - 1
            return list(range(first_id, first_id + n))


def build_sequence_generator(
    database_engine: BaseDatabaseEngine,
    get_first_callback: GetFirstCallbackType,
    sequence_name: str,
) -> SequenceGenerator:
    """Get the best impl of SequenceGenerator available

    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
    sqlite.

    Args:
        database_engine: the database engine we are connected to
        get_first_callback: a callback which returns the current maximum ID;
            the first ID handed out is one greater than this. Only used if
            we're not on postgres (i.e. on sqlite).
        sequence_name: the name of a postgres sequence to use. Only used if
            we're on postgres.
    """
    if isinstance(database_engine, PostgresEngine):
        return PostgresSequenceGenerator(sequence_name)
    else:
        return LocalSequenceGenerator(get_first_callback)