Add additional type hints to the storage module. (#8980)

This commit is contained in:
Patrick Cloke 2020-12-30 08:09:53 -05:00 committed by GitHub
parent b8591899ab
commit 637282bb50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 224 additions and 148 deletions

View file

@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
from typing import (
TYPE_CHECKING,
Awaitable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)
import attr
@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.storage.databases import Databases
logger = logging.getLogger(__name__)
# Used for generic functions below
@ -330,10 +343,12 @@ class StateGroupStorage:
"""High level interface to fetching state for event.
"""
def __init__(self, hs, stores):
def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
async def get_state_group_delta(self, state_group: int):
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
@ -341,8 +356,8 @@ class StateGroupStorage:
state_group: The state group used to retrieve state deltas.
Returns:
Tuple[Optional[int], Optional[StateMap[str]]]:
(prev_group, delta_ids)
A tuple of the previous group and a state map of the event IDs which
make up the delta between the old and new state groups.
"""
return await self.stores.state.get_state_group_delta(state_group)
@ -436,7 +451,7 @@ class StateGroupStorage:
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@ -472,7 +487,7 @@ class StateGroupStorage:
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
@ -500,7 +515,7 @@ class StateGroupStorage:
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@ -516,7 +531,7 @@ class StateGroupStorage:
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event