mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-06 16:34:09 -04:00
Use async with
for ID gens (#8383)
This will allow us to hit the DB after we've finished using the generated stream ID.
This commit is contained in:
parent
916bb9d0d1
commit
cbabb312e0
15 changed files with 144 additions and 105 deletions
|
@ -12,14 +12,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, List, Set
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
import attr
|
||||
from typing_extensions import Deque
|
||||
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
|
@ -86,7 +86,7 @@ class StreamIdGenerator:
|
|||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
|
@ -101,10 +101,10 @@ class StreamIdGenerator:
|
|||
)
|
||||
self._unfinished_ids = deque() # type: Deque[int]
|
||||
|
||||
async def get_next(self):
|
||||
def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
@ -113,7 +113,7 @@ class StreamIdGenerator:
|
|||
|
||||
self._unfinished_ids.append(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_id
|
||||
|
@ -121,12 +121,12 @@ class StreamIdGenerator:
|
|||
with self._lock:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
return manager()
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
async def get_next_mult(self, n):
|
||||
def get_next_mult(self, n):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next(n) as stream_ids:
|
||||
async with stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
@ -140,7 +140,7 @@ class StreamIdGenerator:
|
|||
for next_id in next_ids:
|
||||
self._unfinished_ids.append(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
|
@ -149,7 +149,7 @@ class StreamIdGenerator:
|
|||
for next_id in next_ids:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
return manager()
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_current_token(self):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
|
@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
|
|||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||
|
||||
async def get_next(self):
|
||||
def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
|
||||
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
with self._lock:
|
||||
assert self._current_positions.get(self._instance_name, 0) < next_id
|
||||
return _MultiWriterCtxManager(self)
|
||||
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
# Multiply by the return factor so that the ID has correct sign.
|
||||
yield self._return_factor * next_id
|
||||
finally:
|
||||
self._mark_id_as_finished(next_id)
|
||||
|
||||
return manager()
|
||||
|
||||
async def get_next_mult(self, n: int):
|
||||
def get_next_mult(self, n: int):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
async with stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
next_ids = await self._db.runInteraction(
|
||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
with self._lock:
|
||||
assert max(self._current_positions.values(), default=0) < min(next_ids)
|
||||
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield [self._return_factor * i for i in next_ids]
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
||||
return manager()
|
||||
return _MultiWriterCtxManager(self, n)
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
|
@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
|
|||
# There was a gap in seen positions, so there is nothing more to
|
||||
# do.
|
||||
break
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _AsyncCtxManagerWrapper:
|
||||
"""Helper class to convert a plain context manager to an async one.
|
||||
|
||||
This is mainly useful if you have a plain context manager but the interface
|
||||
requires an async one.
|
||||
"""
|
||||
|
||||
inner = attr.ib()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.inner.__enter__()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return self.inner.__exit__(exc_type, exc, tb)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _MultiWriterCtxManager:
|
||||
"""Async context manager returned by MultiWriterIdGenerator
|
||||
"""
|
||||
|
||||
id_gen = attr.ib(type=MultiWriterIdGenerator)
|
||||
multiple_ids = attr.ib(type=Optional[int], default=None)
|
||||
stream_ids = attr.ib(type=List[int], factory=list)
|
||||
|
||||
async def __aenter__(self) -> Union[int, List[int]]:
|
||||
self.stream_ids = await self.id_gen._db.runInteraction(
|
||||
"_load_next_mult_id",
|
||||
self.id_gen._load_next_mult_id_txn,
|
||||
self.multiple_ids or 1,
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
with self.id_gen._lock:
|
||||
assert max(self.id_gen._current_positions.values(), default=0) < min(
|
||||
self.stream_ids
|
||||
)
|
||||
|
||||
self.id_gen._unfinished_ids.update(self.stream_ids)
|
||||
|
||||
if self.multiple_ids is None:
|
||||
return self.stream_ids[0] * self.id_gen._return_factor
|
||||
else:
|
||||
return [i * self.id_gen._return_factor for i in self.stream_ids]
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
for i in self.stream_ids:
|
||||
self.id_gen._mark_id_as_finished(i)
|
||||
|
||||
if exc_type is not None:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue