Use async with for ID gens (#8383)

This will allow us to hit the DB after we've finished using the generated stream ID.
Erik Johnston 2020-09-23 16:11:18 +01:00 committed by GitHub
parent 916bb9d0d1
commit cbabb312e0
15 changed files with 144 additions and 105 deletions
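
The caller-facing change is that `get_next()` and `get_next_mult()` on the ID generators are no longer coroutines returning a plain context manager; they are now ordinary methods that return an async context manager, so call sites switch from `with await ...` to `async with ...`. A minimal sketch of the call-site difference (the `persist_event` callback is illustrative, not taken from this diff):

    # Before: get_next() had to be awaited first, then entered.
    async def persist_event_old(stream_id_gen, persist_event):
        with await stream_id_gen.get_next() as stream_id:
            await persist_event(stream_id)

    # After: get_next() returns an async context manager directly; its async
    # __aexit__ is what makes it possible to hit the database after the
    # caller has finished using the generated stream ID.
    async def persist_event_new(stream_id_gen, persist_event):
        async with stream_id_gen.get_next() as stream_id:
            await persist_event(stream_id)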

synapse/storage/util/id_generators.py

@@ -12,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
 import heapq
 import logging
 import threading
 from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
 import attr
 from typing_extensions import Deque
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -86,7 +86,7 @@ class StreamIdGenerator:
             upwards, -1 to grow downwards.
     Usage:
-        with await stream_id_gen.get_next() as stream_id:
+        async with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
@@ -101,10 +101,10 @@
             )
         self._unfinished_ids = deque()  # type: Deque[int]
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -113,7 +113,7 @@
             self._unfinished_ids.append(next_id)
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_id
@@ -121,12 +121,12 @@
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
-    async def get_next_mult(self, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with await stream_id_gen.get_next(n) as stream_ids:
+            async with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -140,7 +140,7 @@
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_ids
@@ -149,7 +149,7 @@
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
@@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
     def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
-        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-        # Assert the fetched ID is actually greater than what we currently
-        # believe the ID to be. If not, then the sequence and table have got
-        # out of sync somehow.
-        with self._lock:
-            assert self._current_positions.get(self._instance_name, 0) < next_id
-            self._unfinished_ids.add(next_id)
-        @contextlib.contextmanager
-        def manager():
-            try:
-                # Multiply by the return factor so that the ID has correct sign.
-                yield self._return_factor * next_id
-            finally:
-                self._mark_id_as_finished(next_id)
-        return manager()
+        return _MultiWriterCtxManager(self)
-    async def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int):
         """
         Usage:
-            with await stream_id_gen.get_next_mult(5) as stream_ids:
+            async with stream_id_gen.get_next_mult(5) as stream_ids:
                 # ... persist events ...
         """
-        next_ids = await self._db.runInteraction(
-            "_load_next_mult_id", self._load_next_mult_id_txn, n
-        )
-        # Assert the fetched ID is actually greater than any ID we've already
-        # seen. If not, then the sequence and table have got out of sync
-        # somehow.
-        with self._lock:
-            assert max(self._current_positions.values(), default=0) < min(next_ids)
-            self._unfinished_ids.update(next_ids)
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield [self._return_factor * i for i in next_ids]
-            finally:
-                for i in next_ids:
-                    self._mark_id_as_finished(i)
-        return manager()
+        return _MultiWriterCtxManager(self, n)
     def get_next_txn(self, txn: LoggingTransaction):
         """
@@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
                 # There was a gap in seen positions, so there is nothing more to
                 # do.
                 break
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+    """Helper class to convert a plain context manager to an async one.
+
+    This is mainly useful if you have a plain context manager but the interface
+    requires an async one.
+    """
+
+    inner = attr.ib()
+
+    async def __aenter__(self):
+        return self.inner.__enter__()
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+    """Async context manager returned by MultiWriterIdGenerator
+    """
+
+    id_gen = attr.ib(type=MultiWriterIdGenerator)
+    multiple_ids = attr.ib(type=Optional[int], default=None)
+    stream_ids = attr.ib(type=List[int], factory=list)
+
+    async def __aenter__(self) -> Union[int, List[int]]:
+        self.stream_ids = await self.id_gen._db.runInteraction(
+            "_load_next_mult_id",
+            self.id_gen._load_next_mult_id_txn,
+            self.multiple_ids or 1,
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        with self.id_gen._lock:
+            assert max(self.id_gen._current_positions.values(), default=0) < min(
+                self.stream_ids
+            )
+
+            self.id_gen._unfinished_ids.update(self.stream_ids)
+
+        if self.multiple_ids is None:
+            return self.stream_ids[0] * self.id_gen._return_factor
+        else:
+            return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+    async def __aexit__(self, exc_type, exc, tb):
+        for i in self.stream_ids:
+            self.id_gen._mark_id_as_finished(i)
+
+        if exc_type is not None:
+            return False
+
+        return False
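
`_MultiWriterCtxManager` moves the `runInteraction` call that used to run inside `get_next`/`get_next_mult` into `__aenter__`, and because `__aexit__` is also a coroutine, the generator can now hit the database after the caller has finished using the generated stream ID, which is the motivation given in the commit message. A hedged sketch of a call site; `persist_batch` and `persist_events_txn` are illustrative placeholders, not part of this diff:

    def persist_events_txn(txn, events, stream_ids):
        # Placeholder transaction function: write each event at its stream ID.
        ...

    async def persist_batch(id_gen, db_pool, events):
        # __aenter__ runs the "_load_next_mult_id" interaction to reserve IDs
        # and asserts they are ahead of every position seen so far.
        async with id_gen.get_next_mult(len(events)) as stream_ids:
            await db_pool.runInteraction(
                "persist_events", persist_events_txn, events, stream_ids
            )
        # Leaving the block awaits __aexit__, which marks the IDs as finished
        # and, being async, is now free to run further database queries.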