mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-12 02:24:20 -05:00
Use async with
for ID gens (#8383)
This will allow us to hit the DB after we've finished using the generated stream ID.
This commit is contained in:
parent
916bb9d0d1
commit
cbabb312e0
1
changelog.d/8383.misc
Normal file
1
changelog.d/8383.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor ID generators to use `async with` syntax.
|
@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||||||
"""
|
"""
|
||||||
content_json = json_encoder.encode(content)
|
content_json = json_encoder.encode(content)
|
||||||
|
|
||||||
with await self._account_data_id_gen.get_next() as next_id:
|
async with self._account_data_id_gen.get_next() as next_id:
|
||||||
# no need to lock here as room_account_data has a unique constraint
|
# no need to lock here as room_account_data has a unique constraint
|
||||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||||
# retry if there is a conflict.
|
# retry if there is a conflict.
|
||||||
@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||||||
"""
|
"""
|
||||||
content_json = json_encoder.encode(content)
|
content_json = json_encoder.encode(content)
|
||||||
|
|
||||||
with await self._account_data_id_gen.get_next() as next_id:
|
async with self._account_data_id_gen.get_next() as next_id:
|
||||||
# no need to lock here as account_data has a unique constraint on
|
# no need to lock here as account_data has a unique constraint on
|
||||||
# (user_id, account_data_type) so simple_upsert will retry if
|
# (user_id, account_data_type) so simple_upsert will retry if
|
||||||
# there is a conflict.
|
# there is a conflict.
|
||||||
|
@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||||||
rows.append((destination, stream_id, now_ms, edu_json))
|
rows.append((destination, stream_id, now_ms, edu_json))
|
||||||
txn.executemany(sql, rows)
|
txn.executemany(sql, rows)
|
||||||
|
|
||||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self.clock.time_msec()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||||
@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||||||
txn, stream_id, local_messages_by_user_then_device
|
txn, stream_id, local_messages_by_user_then_device
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self.clock.time_msec()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_from_remote_to_device_inbox",
|
"add_messages_from_remote_to_device_inbox",
|
||||||
|
@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
THe new stream ID.
|
THe new stream ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with await self._device_list_id_gen.get_next() as stream_id:
|
async with self._device_list_id_gen.get_next() as stream_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_user_sig_change_to_streams",
|
"add_user_sig_change_to_streams",
|
||||||
self._add_user_signature_change_txn,
|
self._add_user_signature_change_txn,
|
||||||
@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
if not device_ids:
|
if not device_ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
with await self._device_list_id_gen.get_next_mult(
|
async with self._device_list_id_gen.get_next_mult(
|
||||||
len(device_ids)
|
len(device_ids)
|
||||||
) as stream_ids:
|
) as stream_ids:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
return stream_ids[-1]
|
return stream_ids[-1]
|
||||||
|
|
||||||
context = get_active_span_text_map()
|
context = get_active_span_text_map()
|
||||||
with await self._device_list_id_gen.get_next_mult(
|
async with self._device_list_id_gen.get_next_mult(
|
||||||
len(hosts) * len(device_ids)
|
len(hosts) * len(device_ids)
|
||||||
) as stream_ids:
|
) as stream_ids:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
key (dict): the key data
|
key (dict): the key data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with await self._cross_signing_id_gen.get_next() as stream_id:
|
async with self._cross_signing_id_gen.get_next() as stream_id:
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"add_e2e_cross_signing_key",
|
"add_e2e_cross_signing_key",
|
||||||
self._set_e2e_cross_signing_key_txn,
|
self._set_e2e_cross_signing_key_txn,
|
||||||
|
@ -156,15 +156,15 @@ class PersistEventsStore:
|
|||||||
# Note: Multiple instances of this function cannot be in flight at
|
# Note: Multiple instances of this function cannot be in flight at
|
||||||
# the same time for the same room.
|
# the same time for the same room.
|
||||||
if backfilled:
|
if backfilled:
|
||||||
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
|
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||||
len(events_and_contexts)
|
len(events_and_contexts)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
|
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||||
len(events_and_contexts)
|
len(events_and_contexts)
|
||||||
)
|
)
|
||||||
|
|
||||||
with stream_ordering_manager as stream_orderings:
|
async with stream_ordering_manager as stream_orderings:
|
||||||
for (event, context), stream in zip(events_and_contexts, stream_orderings):
|
for (event, context), stream in zip(events_and_contexts, stream_orderings):
|
||||||
event.internal_metadata.stream_ordering = stream
|
event.internal_metadata.stream_ordering = stream
|
||||||
|
|
||||||
|
@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||||||
|
|
||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
with await self._group_updates_id_gen.get_next() as next_id:
|
async with self._group_updates_id_gen.get_next() as next_id:
|
||||||
res = await self.db_pool.runInteraction(
|
res = await self.db_pool.runInteraction(
|
||||||
"register_user_group_membership",
|
"register_user_group_membership",
|
||||||
_register_user_group_membership_txn,
|
_register_user_group_membership_txn,
|
||||||
|
@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
|
|||||||
|
|
||||||
class PresenceStore(SQLBaseStore):
|
class PresenceStore(SQLBaseStore):
|
||||||
async def update_presence(self, presence_states):
|
async def update_presence(self, presence_states):
|
||||||
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
|
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||||
len(presence_states)
|
len(presence_states)
|
||||||
)
|
)
|
||||||
|
|
||||||
with stream_ordering_manager as stream_orderings:
|
async with stream_ordering_manager as stream_orderings:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"update_presence",
|
"update_presence",
|
||||||
self._update_presence_txn,
|
self._update_presence_txn,
|
||||||
|
@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||||||
) -> None:
|
) -> None:
|
||||||
conditions_json = json_encoder.encode(conditions)
|
conditions_json = json_encoder.encode(conditions)
|
||||||
actions_json = json_encoder.encode(actions)
|
actions_json = json_encoder.encode(actions)
|
||||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
if before or after:
|
if before or after:
|
||||||
@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||||||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||||||
Raises:
|
Raises:
|
||||||
NotFoundError if the rule does not exist.
|
NotFoundError if the rule does not exist.
|
||||||
"""
|
"""
|
||||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"_set_push_rule_enabled_txn",
|
"_set_push_rule_enabled_txn",
|
||||||
@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||||||
data={"actions": actions_json},
|
data={"actions": actions_json},
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
|
|||||||
last_stream_ordering,
|
last_stream_ordering,
|
||||||
profile_tag="",
|
profile_tag="",
|
||||||
) -> None:
|
) -> None:
|
||||||
with await self._pushers_id_gen.get_next() as stream_id:
|
async with self._pushers_id_gen.get_next() as stream_id:
|
||||||
# no need to lock because `pushers` has a unique key on
|
# no need to lock because `pushers` has a unique key on
|
||||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||||
await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._pushers_id_gen.get_next() as stream_id:
|
async with self._pushers_id_gen.get_next() as stream_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_pusher", delete_pusher_txn, stream_id
|
"delete_pusher", delete_pusher_txn, stream_id
|
||||||
)
|
)
|
||||||
|
@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
|||||||
"insert_receipt_conv", graph_to_linear
|
"insert_receipt_conv", graph_to_linear
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._receipts_id_gen.get_next() as stream_id:
|
async with self._receipts_id_gen.get_next() as stream_id:
|
||||||
event_ts = await self.db_pool.runInteraction(
|
event_ts = await self.db_pool.runInteraction(
|
||||||
"insert_linearized_receipt",
|
"insert_linearized_receipt",
|
||||||
self.insert_linearized_receipt_txn,
|
self.insert_linearized_receipt_txn,
|
||||||
|
@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._public_room_id_gen.get_next() as next_id:
|
async with self._public_room_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"store_room_txn", store_room_txn, next_id
|
"store_room_txn", store_room_txn, next_id
|
||||||
)
|
)
|
||||||
@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._public_room_id_gen.get_next() as next_id:
|
async with self._public_room_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_room_is_public", set_room_is_public_txn, next_id
|
"set_room_is_public", set_room_is_public_txn, next_id
|
||||||
)
|
)
|
||||||
@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with await self._public_room_id_gen.get_next() as next_id:
|
async with self._public_room_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_room_is_public_appservice",
|
"set_room_is_public_appservice",
|
||||||
set_room_is_public_appservice_txn,
|
set_room_is_public_appservice_txn,
|
||||||
|
@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
|
|||||||
)
|
)
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with await self._account_data_id_gen.get_next() as next_id:
|
async with self._account_data_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
|
|||||||
txn.execute(sql, (user_id, room_id, tag))
|
txn.execute(sql, (user_id, room_id, tag))
|
||||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||||
|
|
||||||
with await self._account_data_id_gen.get_next() as next_id:
|
async with self._account_data_id_gen.get_next() as next_id:
|
||||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||||
|
|
||||||
self.get_tags_for_user.invalidate((user_id,))
|
self.get_tags_for_user.invalidate((user_id,))
|
||||||
|
@ -12,14 +12,14 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import heapq
|
import heapq
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Dict, List, Set
|
from contextlib import contextmanager
|
||||||
|
from typing import Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
|
import attr
|
||||||
from typing_extensions import Deque
|
from typing_extensions import Deque
|
||||||
|
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
@ -86,7 +86,7 @@ class StreamIdGenerator:
|
|||||||
upwards, -1 to grow downwards.
|
upwards, -1 to grow downwards.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
with await stream_id_gen.get_next() as stream_id:
|
async with stream_id_gen.get_next() as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -101,10 +101,10 @@ class StreamIdGenerator:
|
|||||||
)
|
)
|
||||||
self._unfinished_ids = deque() # type: Deque[int]
|
self._unfinished_ids = deque() # type: Deque[int]
|
||||||
|
|
||||||
async def get_next(self):
|
def get_next(self):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with await stream_id_gen.get_next() as stream_id:
|
async with stream_id_gen.get_next() as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@ -113,7 +113,7 @@ class StreamIdGenerator:
|
|||||||
|
|
||||||
self._unfinished_ids.append(next_id)
|
self._unfinished_ids.append(next_id)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextmanager
|
||||||
def manager():
|
def manager():
|
||||||
try:
|
try:
|
||||||
yield next_id
|
yield next_id
|
||||||
@ -121,12 +121,12 @@ class StreamIdGenerator:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._unfinished_ids.remove(next_id)
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
return manager()
|
return _AsyncCtxManagerWrapper(manager())
|
||||||
|
|
||||||
async def get_next_mult(self, n):
|
def get_next_mult(self, n):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with await stream_id_gen.get_next(n) as stream_ids:
|
async with stream_id_gen.get_next(n) as stream_ids:
|
||||||
# ... persist events ...
|
# ... persist events ...
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@ -140,7 +140,7 @@ class StreamIdGenerator:
|
|||||||
for next_id in next_ids:
|
for next_id in next_ids:
|
||||||
self._unfinished_ids.append(next_id)
|
self._unfinished_ids.append(next_id)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextmanager
|
||||||
def manager():
|
def manager():
|
||||||
try:
|
try:
|
||||||
yield next_ids
|
yield next_ids
|
||||||
@ -149,7 +149,7 @@ class StreamIdGenerator:
|
|||||||
for next_id in next_ids:
|
for next_id in next_ids:
|
||||||
self._unfinished_ids.remove(next_id)
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
return manager()
|
return _AsyncCtxManagerWrapper(manager())
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self):
|
||||||
"""Returns the maximum stream id such that all stream ids less than or
|
"""Returns the maximum stream id such that all stream ids less than or
|
||||||
@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
|
|||||||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||||
|
|
||||||
async def get_next(self):
|
def get_next(self):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with await stream_id_gen.get_next() as stream_id:
|
async with stream_id_gen.get_next() as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
|
|
||||||
|
|
||||||
# Assert the fetched ID is actually greater than what we currently
|
return _MultiWriterCtxManager(self)
|
||||||
# believe the ID to be. If not, then the sequence and table have got
|
|
||||||
# out of sync somehow.
|
|
||||||
with self._lock:
|
|
||||||
assert self._current_positions.get(self._instance_name, 0) < next_id
|
|
||||||
|
|
||||||
self._unfinished_ids.add(next_id)
|
def get_next_mult(self, n: int):
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def manager():
|
|
||||||
try:
|
|
||||||
# Multiply by the return factor so that the ID has correct sign.
|
|
||||||
yield self._return_factor * next_id
|
|
||||||
finally:
|
|
||||||
self._mark_id_as_finished(next_id)
|
|
||||||
|
|
||||||
return manager()
|
|
||||||
|
|
||||||
async def get_next_mult(self, n: int):
|
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
async with stream_id_gen.get_next_mult(5) as stream_ids:
|
||||||
# ... persist events ...
|
# ... persist events ...
|
||||||
"""
|
"""
|
||||||
next_ids = await self._db.runInteraction(
|
|
||||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert the fetched ID is actually greater than any ID we've already
|
return _MultiWriterCtxManager(self, n)
|
||||||
# seen. If not, then the sequence and table have got out of sync
|
|
||||||
# somehow.
|
|
||||||
with self._lock:
|
|
||||||
assert max(self._current_positions.values(), default=0) < min(next_ids)
|
|
||||||
|
|
||||||
self._unfinished_ids.update(next_ids)
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def manager():
|
|
||||||
try:
|
|
||||||
yield [self._return_factor * i for i in next_ids]
|
|
||||||
finally:
|
|
||||||
for i in next_ids:
|
|
||||||
self._mark_id_as_finished(i)
|
|
||||||
|
|
||||||
return manager()
|
|
||||||
|
|
||||||
def get_next_txn(self, txn: LoggingTransaction):
|
def get_next_txn(self, txn: LoggingTransaction):
|
||||||
"""
|
"""
|
||||||
@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
|
|||||||
# There was a gap in seen positions, so there is nothing more to
|
# There was a gap in seen positions, so there is nothing more to
|
||||||
# do.
|
# do.
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class _AsyncCtxManagerWrapper:
|
||||||
|
"""Helper class to convert a plain context manager to an async one.
|
||||||
|
|
||||||
|
This is mainly useful if you have a plain context manager but the interface
|
||||||
|
requires an async one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
inner = attr.ib()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self.inner.__enter__()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return self.inner.__exit__(exc_type, exc, tb)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class _MultiWriterCtxManager:
|
||||||
|
"""Async context manager returned by MultiWriterIdGenerator
|
||||||
|
"""
|
||||||
|
|
||||||
|
id_gen = attr.ib(type=MultiWriterIdGenerator)
|
||||||
|
multiple_ids = attr.ib(type=Optional[int], default=None)
|
||||||
|
stream_ids = attr.ib(type=List[int], factory=list)
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Union[int, List[int]]:
|
||||||
|
self.stream_ids = await self.id_gen._db.runInteraction(
|
||||||
|
"_load_next_mult_id",
|
||||||
|
self.id_gen._load_next_mult_id_txn,
|
||||||
|
self.multiple_ids or 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the fetched ID is actually greater than any ID we've already
|
||||||
|
# seen. If not, then the sequence and table have got out of sync
|
||||||
|
# somehow.
|
||||||
|
with self.id_gen._lock:
|
||||||
|
assert max(self.id_gen._current_positions.values(), default=0) < min(
|
||||||
|
self.stream_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
self.id_gen._unfinished_ids.update(self.stream_ids)
|
||||||
|
|
||||||
|
if self.multiple_ids is None:
|
||||||
|
return self.stream_ids[0] * self.id_gen._return_factor
|
||||||
|
else:
|
||||||
|
return [i * self.id_gen._return_factor for i in self.stream_ids]
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
for i in self.stream_ids:
|
||||||
|
self.id_gen._mark_id_as_finished(i)
|
||||||
|
|
||||||
|
if exc_type is not None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
# advanced after we leave the context manager.
|
# advanced after we leave the context manager.
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async():
|
||||||
with await id_gen.get_next() as stream_id:
|
async with id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 8)
|
self.assertEqual(stream_id, 8)
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||||
@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
ctx3 = self.get_success(id_gen.get_next())
|
ctx3 = self.get_success(id_gen.get_next())
|
||||||
ctx4 = self.get_success(id_gen.get_next())
|
ctx4 = self.get_success(id_gen.get_next())
|
||||||
|
|
||||||
s1 = ctx1.__enter__()
|
s1 = self.get_success(ctx1.__aenter__())
|
||||||
s2 = ctx2.__enter__()
|
s2 = self.get_success(ctx2.__aenter__())
|
||||||
s3 = ctx3.__enter__()
|
s3 = self.get_success(ctx3.__aenter__())
|
||||||
s4 = ctx4.__enter__()
|
s4 = self.get_success(ctx4.__aenter__())
|
||||||
|
|
||||||
self.assertEqual(s1, 8)
|
self.assertEqual(s1, 8)
|
||||||
self.assertEqual(s2, 9)
|
self.assertEqual(s2, 9)
|
||||||
@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||||
|
|
||||||
ctx2.__exit__(None, None, None)
|
self.get_success(ctx2.__aexit__(None, None, None))
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||||
|
|
||||||
ctx1.__exit__(None, None, None)
|
self.get_success(ctx1.__aexit__(None, None, None))
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
||||||
|
|
||||||
ctx4.__exit__(None, None, None)
|
self.get_success(ctx4.__aexit__(None, None, None))
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
||||||
|
|
||||||
ctx3.__exit__(None, None, None)
|
self.get_success(ctx3.__aexit__(None, None, None))
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
||||||
@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
# advanced after we leave the context manager.
|
# advanced after we leave the context manager.
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async():
|
||||||
with await first_id_gen.get_next() as stream_id:
|
async with first_id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 8)
|
self.assertEqual(stream_id, 8)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
# stream ID
|
# stream ID
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async():
|
||||||
with await second_id_gen.get_next() as stream_id:
|
async with second_id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 9)
|
self.assertEqual(stream_id, 9)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -305,10 +305,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||||
with self.get_success(id_gen.get_next()) as stream_id:
|
|
||||||
|
async def _get_next_async():
|
||||||
|
async with id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 6)
|
self.assertEqual(stream_id, 6)
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||||
|
|
||||||
|
self.get_success(_get_next_async())
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||||
|
|
||||||
# We assume that so long as `get_next` does correctly advance the
|
# We assume that so long as `get_next` does correctly advance the
|
||||||
@ -373,17 +377,23 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
id_gen = self._create_id_generator()
|
id_gen = self._create_id_generator()
|
||||||
|
|
||||||
with self.get_success(id_gen.get_next()) as stream_id:
|
async def _get_next_async():
|
||||||
|
async with id_gen.get_next() as stream_id:
|
||||||
self._insert_row("master", stream_id)
|
self._insert_row("master", stream_id)
|
||||||
|
|
||||||
|
self.get_success(_get_next_async())
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": -1})
|
self.assertEqual(id_gen.get_positions(), {"master": -1})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
||||||
|
|
||||||
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
|
async def _get_next_async2():
|
||||||
|
async with id_gen.get_next_mult(3) as stream_ids:
|
||||||
for stream_id in stream_ids:
|
for stream_id in stream_ids:
|
||||||
self._insert_row("master", stream_id)
|
self._insert_row("master", stream_id)
|
||||||
|
|
||||||
|
self.get_success(_get_next_async2())
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": -4})
|
self.assertEqual(id_gen.get_positions(), {"master": -4})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), -4)
|
self.assertEqual(id_gen.get_persisted_upto_position(), -4)
|
||||||
@ -402,19 +412,25 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
id_gen_1 = self._create_id_generator("first")
|
id_gen_1 = self._create_id_generator("first")
|
||||||
id_gen_2 = self._create_id_generator("second")
|
id_gen_2 = self._create_id_generator("second")
|
||||||
|
|
||||||
with self.get_success(id_gen_1.get_next()) as stream_id:
|
async def _get_next_async():
|
||||||
|
async with id_gen_1.get_next() as stream_id:
|
||||||
self._insert_row("first", stream_id)
|
self._insert_row("first", stream_id)
|
||||||
id_gen_2.advance("first", stream_id)
|
id_gen_2.advance("first", stream_id)
|
||||||
|
|
||||||
|
self.get_success(_get_next_async())
|
||||||
|
|
||||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
|
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
|
||||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
|
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
|
||||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
||||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
||||||
|
|
||||||
with self.get_success(id_gen_2.get_next()) as stream_id:
|
async def _get_next_async2():
|
||||||
|
async with id_gen_2.get_next() as stream_id:
|
||||||
self._insert_row("second", stream_id)
|
self._insert_row("second", stream_id)
|
||||||
id_gen_1.advance("second", stream_id)
|
id_gen_1.advance("second", stream_id)
|
||||||
|
|
||||||
|
self.get_success(_get_next_async2())
|
||||||
|
|
||||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
|
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
|
||||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
||||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
|
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
|
||||||
|
Loading…
Reference in New Issue
Block a user