Convert all namedtuples to attrs. (#11665)

To improve type hints throughout the code.
This commit is contained in:
Patrick Cloke 2021-12-30 13:47:12 -05:00 committed by GitHub
parent 07a3b5daba
commit cbd82d0b2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 231 additions and 206 deletions

View file

@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
import attr
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomAliasMapping:
room_id: str
room_alias: str
servers: List[str]
class DirectoryWorkerStore(CacheInvalidationWorkerStore):

View file

@ -1976,14 +1976,17 @@ class PersistEventsStore:
txn, self.store.get_retention_policy_for_room, (event.room_id,)
)
def store_event_search_txn(self, txn, event, key, value):
def store_event_search_txn(
self, txn: LoggingTransaction, event: EventBase, key: str, value: str
) -> None:
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
txn: The database transaction.
event: The event being added to the search table.
key: A key describing the search value (one of "content.name",
"content.topic", or "content.body")
value: The value from the event's content.
"""
self.store.store_search_entries_txn(
txn,

View file

@ -13,11 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from abc import abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@ -43,9 +54,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RatelimitOverride:
messages_per_second: int
burst_count: int
class RoomSortOrder(Enum):
@ -207,6 +219,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """
@ -284,7 +297,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
where_clauses = []
query_args = []
query_args: List[Union[str, int]] = []
if network_tuple:
if network_tuple.appservice_id:
@ -293,6 +306,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """

View file

@ -14,9 +14,10 @@
import logging
import re
from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
import attr
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@ -33,10 +34,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
"SearchEntry",
["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SearchEntry:
key: str
value: str
event_id: str
room_id: str
stream_ordering: Optional[int]
origin_server_ts: int
def _clean_value_for_search(value: str) -> str:

View file

@ -14,7 +14,6 @@
# limitations under the License.
import collections.abc
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
@ -43,19 +42,6 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
class _GetStateGroupDelta(
namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
):
"""Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching
"""
__slots__ = []
def __len__(self):
return len(self.delta_ids) if self.delta_ids else 0
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""

View file

@ -36,9 +36,9 @@ what sort order was used:
"""
import abc
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
import attr
from frozendict import frozendict
from twisted.internet import defer
@ -74,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
_EventDictReturn = namedtuple(
"_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
event_id: str
topological_ordering: Optional[int]
stream_ordering: int
def generate_pagination_where_clause(
@ -825,7 +827,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
for event, row in zip(events, rows):
stream = row.stream_ordering
if topo_order and row.topological_ordering:
topo = row.topological_ordering
topo: Optional[int] = row.topological_ordering
else:
topo = None
internal = event.internal_metadata