Bugbear: Add Mutable Parameter fixes (#9682)
Part of #9366. Adds fixes for B006 and B008, both relating to mutable parameter lint errors.

Signed-off-by: Jonathan de Jong <jonathan@automatia.nl>
parent 64f4f506c5
commit 2ca4e349e9
38 changed files with 224 additions and 113 deletions
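For context on the two lint rules this commit addresses: B006 flags mutable default argument values (a dict or list default is created once, when the function is defined, and is then shared by every call that omits the argument), while B008 flags function calls in defaults (evaluated once at definition time rather than on each call). The sketch below is illustrative only and is not part of the diff; the function names are made up, but the Optional[...] = None plus "value = value or ..." idiom mirrors the fix applied throughout the changed files.

from typing import Any, Dict, Optional


def upsert_shared_default(values: Dict[str, Any], insertion_values: Dict[str, Any] = {}) -> Dict[str, Any]:
    # B006: the {} default is a single shared dict, so state leaks between calls.
    insertion_values.update(values)
    return insertion_values


def upsert_fixed(
    values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    # The pattern used in this commit: default to None, build a fresh dict per call.
    insertion_values = insertion_values or {}
    insertion_values.update(values)
    return insertion_values


def default_filter() -> Dict[str, Any]:
    return {}


def lookup_call_in_default(state_filter: Dict[str, Any] = default_filter()) -> Dict[str, Any]:
    # B008: default_filter() runs once at definition time and its result is shared.
    return state_filter


def lookup_fixed(state_filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    return state_filter or default_filter()

One trade-off of the "or" idiom is that an explicitly passed empty dict is also replaced by a fresh default; for these helpers that is equivalent behaviour, and it keeps the body to one line compared with an "if value is None" check.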
@@ -900,7 +900,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         desc: str = "simple_upsert",
         lock: bool = True,
     ) -> Optional[bool]:
@@ -927,6 +927,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         attempts = 0
         while True:
             try:
@@ -964,7 +966,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> Optional[bool]:
         """
@@ -982,6 +984,8 @@ class DatabasePool:
             Native upserts always return None. Emulated upserts return True if a
             new entry was created, False if an existing one was updated.
         """
+        insertion_values = insertion_values or {}
+
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
@@ -1003,7 +1007,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
     ) -> bool:
         """
@@ -1017,6 +1021,8 @@ class DatabasePool:
             Returns True if a new entry was created, False if an existing
             one was updated.
         """
+        insertion_values = insertion_values or {}
+
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.engine.lock_table(txn, table)
@@ -1077,7 +1083,7 @@ class DatabasePool:
         table: str,
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
-        insertion_values: Dict[str, Any] = {},
+        insertion_values: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
@@ -1090,7 +1096,7 @@ class DatabasePool:
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
+        allvalues.update(insertion_values or {})
 
         if not values:
             latter = "NOTHING"
@@ -1513,7 +1519,7 @@ class DatabasePool:
         column: str,
         iterable: Iterable[Any],
         retcols: Iterable[str],
-        keyvalues: Dict[str, Any] = {},
+        keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
     ) -> List[Any]:
@@ -1531,6 +1537,8 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
             batch_size: the number of rows for each select query
         """
+        keyvalues = keyvalues or {}
+
         results = []  # type: List[Dict[str, Any]]
 
         if not iterable:
@@ -320,8 +320,8 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool,
-        state_delta_for_room: Dict[str, DeltaState] = {},
-        new_forward_extremeties: Dict[str, List[str]] = {},
+        state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
+        new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
         """Insert some number of room events into the necessary database tables.
 
@@ -342,6 +342,9 @@ class PersistEventsStore:
                 extremities.
 
         """
+        state_delta_for_room = state_delta_for_room or {}
+        new_forward_extremeties = new_forward_extremeties or {}
+
         all_events_and_contexts = events_and_contexts
 
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
@@ -1171,7 +1171,7 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         membership: str,
         is_admin: bool = False,
-        content: JsonDict = {},
+        content: Optional[JsonDict] = None,
         local_attestation: Optional[dict] = None,
         remote_attestation: Optional[dict] = None,
         is_publicised: bool = False,
@@ -1192,6 +1192,8 @@ class GroupServerStore(GroupServerWorkerStore):
             is_publicised: Whether this should be publicised.
         """
 
+        content = content or {}
+
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(
@@ -190,7 +190,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     # FIXME: how should this be cached?
     async def get_filtered_current_state_ids(
-        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+        self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table. This may not be as up-to-date as the result
@@ -205,7 +205,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Map from type/state_key to event ID.
         """
 
-        where_clause, where_args = state_filter.make_sql_filter_clause()
+        where_clause, where_args = (
+            state_filter or StateFilter.all()
+        ).make_sql_filter_clause()
 
         if not where_clause:
             # We delegate to the cached version
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         return count
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
+        self, txn, groups, state_filter: Optional[StateFilter] = None
     ):
+        state_filter = state_filter or StateFilter.all()
+
         results = {group: {} for group in groups}
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
@@ -15,7 +15,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
@@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types
 
     async def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
+        state_filter = state_filter or StateFilter.all()
 
         member_filter, non_member_filter = state_filter.get_member_split()
 
@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
         Returns:
             Dict of state group to state map.
         """
-        return self.stores.state._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(
+            groups, state_filter or StateFilter.all()
+        )
 
     async def store_state_group(
         self,
@@ -17,7 +17,7 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import attr
 
@@ -91,7 +91,14 @@ class StreamIdGenerator:
             # ... persist event ...
     """
 
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn,
+        table,
+        column,
+        extra_tables: Iterable[Tuple[str, str]] = (),
+        step=1,
+    ):
         assert step != 0
         self._lock = threading.Lock()
         self._step = step
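A side note on the StreamIdGenerator hunk above: instead of the Optional[...] = None pattern, extra_tables moves from a mutable list default ([]) to an immutable empty tuple (()). An immutable default is safe to share between calls because it can only be read, never mutated in place. A minimal sketch with hypothetical names:

from typing import Iterable, List, Tuple


def collect_tables(extra_tables: Iterable[Tuple[str, str]] = ()) -> List[Tuple[str, str]]:
    # The default tuple is shared across calls but only ever iterated, so no state leaks.
    return [("events", "stream_ordering"), *extra_tables]


print(collect_tables())
print(collect_tables([("receipts_linearized", "stream_id")]))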