Add type hints to media repository storage module (#11311)

Patrick Cloke, 2021-11-12 11:05:26 -05:00 (committed by GitHub)
parent 6f8f3d4bc5
commit 9b90b9454b
4 changed files with 89 additions and 62 deletions

changelog.d/11311.misc (new file)

@@ -0,0 +1 @@
+Add type hints to storage classes.

mypy.ini

@@ -36,7 +36,6 @@ exclude = (?x)
   |synapse/storage/databases/main/events_bg_updates.py
   |synapse/storage/databases/main/events_worker.py
   |synapse/storage/databases/main/group_server.py
-  |synapse/storage/databases/main/media_repository.py
   |synapse/storage/databases/main/metrics.py
   |synapse/storage/databases/main/monthly_active_users.py
   |synapse/storage/databases/main/presence.py
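With synapse/storage/databases/main/media_repository.py dropped from the exclusion list, mypy now type-checks that module like any other. As a hypothetical illustration (not part of this commit) of what the new annotations catch, a mistyped argument to the newly annotated store_local_media below becomes a static error:

    # Hypothetical call site; `store`, `now` and `requester_id` are assumed.
    await store.store_local_media(
        media_id="abc123",
        media_type="image/png",
        time_now_ms=now,
        upload_name=None,
        media_length="1024",  # mypy: incompatible type "str"; expected "int"
        user_id=requester_id,
    )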

synapse/rest/media/v1/preview_url_resource.py

@@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -231,7 +231,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         og = await make_deferred_yieldable(observable.observe())
         respond_with_json_bytes(request, 200, og, send_cors=True)

-    async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
+    async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes:
         """Check the db, and download the URL and build a preview

         Args:
@@ -360,7 +360,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         return jsonog.encode("utf8")

-    async def _download_url(self, url: str, user: str) -> MediaInfo:
+    async def _download_url(self, url: str, user: UserID) -> MediaInfo:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         )

     async def _precache_image_url(
-        self, user: str, media_info: MediaInfo, og: JsonDict
+        self, user: UserID, media_info: MediaInfo, og: JsonDict
     ) -> None:
         """
         Pre-cache the image (if one exists) for posterity
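The three PreviewUrlResource methods above previously took the requesting user as a bare str; they now take synapse.types.UserID, the structured Matrix user identifier. A minimal sketch of the difference (the mxid is illustrative):

    from synapse.types import UserID

    user = UserID.from_string("@alice:example.com")
    user.localpart    # "alice"
    user.domain       # "example.com"
    user.to_string()  # "@alice:example.com", the form stored in the database

Using UserID lets mypy distinguish user identifiers from the other strings these methods juggle, such as URLs and media IDs.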

synapse/storage/databases/main/media_repository.py

@@ -13,10 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)

 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.types import JsonDict, UserID

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -46,7 +61,12 @@ class MediaSortOrder(Enum):


 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.db_pool.updates.register_background_index_update(
@@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             self._drop_media_index_without_method,
         )

-    async def _drop_media_index_without_method(self, progress, batch_size):
+    async def _drop_media_index_without_method(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """background update handler which removes the old constraints.

         Note that this is only run on postgres.
         """

-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             txn.execute(
                 "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
             )
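This hunk shows the pattern repeated throughout the file: the nested transaction function gains a LoggingTransaction parameter (Synapse's logging wrapper around a DB-API cursor) plus an explicit return type, which runInteraction then propagates to the caller. A rough sketch of the shape, using a hypothetical helper not taken from this commit:

    def _count_user_media_txn(txn: LoggingTransaction) -> int:
        # LoggingTransaction exposes the usual cursor methods
        # (execute, executemany, fetchone, ...) with query logging on top.
        txn.execute(
            "SELECT COUNT(*) FROM local_media_repository WHERE user_id = ?",
            (user_id,),
        )
        row = txn.fetchone()
        assert row is not None  # COUNT(*) always returns a row
        return row[0]

    count = await self.db_pool.runInteraction(
        "count_user_media", _count_user_media_txn
    )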
@@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""

-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.server_name = hs.hostname
@@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             plus the total count of all the user's media
         """

-        def get_local_media_by_user_paginate_txn(txn):
+        def get_local_media_by_user_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:

             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
@@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             else:
                 order = "ASC"

-            args = [user_id]
+            args: List[Union[str, int]] = [user_id]
             sql = """
                 SELECT COUNT(*) as total_media
                 FROM local_media_repository
                 WHERE user_id = ?
             """
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = txn.fetchone()[0]  # type: ignore[index]

             sql = """
                 SELECT
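Two things happen in this hunk. args is annotated List[Union[str, int]] because it starts with the str user_id but evidently also holds int values (pagination bounds) later in the function, outside this hunk. And since DB-API cursors type fetchone() as returning an Optional row, indexing the result directly is a type error; the ignore[index] comment suppresses it on the grounds that COUNT(*) always yields one row. The more explicit alternative (a sketch, not what this diff does) narrows the type with an assert:

    row = txn.fetchone()
    assert row is not None  # COUNT(*) is guaranteed to return one row
    count = row[0]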
@@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             )
             sql += sql_keep

-        def _get_local_media_before_txn(txn):
+        def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (before_ts, before_ts, size_gt))
             return [row[0] for row in txn]
@@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

     async def store_local_media(
         self,
-        media_id,
-        media_type,
-        time_now_ms,
-        upload_name,
-        media_length,
-        user_id,
-        url_cache=None,
+        media_id: str,
+        media_type: str,
+        time_now_ms: int,
+        upload_name: Optional[str],
+        media_length: int,
+        user_id: UserID,
+        url_cache: Optional[str] = None,
     ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository",
@@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             None if the URL isn't cached.
         """

-        def get_url_cache_txn(txn):
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
             # get the most recently cached result (relative to the given ts)
             sql = (
                 "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
@@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

     async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
-    ):
+    ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository_url_cache",
             {
@@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

     async def store_local_thumbnail(
         self,
-        media_id,
-        thumbnail_width,
-        thumbnail_height,
-        thumbnail_type,
-        thumbnail_method,
-        thumbnail_length,
-    ):
+        media_id: str,
+        thumbnail_width: int,
+        thumbnail_height: int,
+        thumbnail_type: str,
+        thumbnail_method: str,
+        thumbnail_length: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             table="local_media_repository_thumbnails",
             keyvalues={
@@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

     async def store_cached_remote_media(
         self,
-        origin,
-        media_id,
-        media_type,
-        media_length,
-        time_now_ms,
-        upload_name,
-        filesystem_id,
-    ):
+        origin: str,
+        media_id: str,
+        media_type: str,
+        media_length: int,
+        time_now_ms: int,
+        upload_name: Optional[str],
+        filesystem_id: str,
+    ) -> None:
         await self.db_pool.simple_insert(
             "remote_media_cache",
             {
@@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         local_media: Iterable[str],
         remote_media: Iterable[Tuple[str, str]],
         time_ms: int,
-    ):
+    ) -> None:
         """Updates the last access time of the given media

         Args:
@@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             time_ms: Current time in milliseconds
         """

-        def update_cache_txn(txn):
+        def update_cache_txn(txn: LoggingTransaction) -> None:
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
                 " WHERE media_origin = ? AND media_id = ?"
@@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))

-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
         )
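Several methods in this diff also drop the return from return await self.db_pool.runInteraction(...). runInteraction passes through whatever the transaction callback returns, so once update_cache_txn is annotated -> None the awaited value is None and returning it from a method itself declared -> None is redundant:

    # Sketch of the relevant generic signature (simplified; an assumption
    # about synapse.storage.database.DatabasePool, not copied verbatim):
    # async def runInteraction(
    #     self, desc: str, func: Callable[..., R], *args, **kwargs
    # ) -> R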
@@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

     async def store_remote_media_thumbnail(
         self,
-        origin,
-        media_id,
-        filesystem_id,
-        thumbnail_width,
-        thumbnail_height,
-        thumbnail_type,
-        thumbnail_method,
-        thumbnail_length,
-    ):
+        origin: str,
+        media_id: str,
+        filesystem_id: str,
+        thumbnail_width: int,
+        thumbnail_height: int,
+        thumbnail_type: str,
+        thumbnail_method: str,
+        thumbnail_length: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             table="remote_media_cache_thumbnails",
             keyvalues={
@@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_remote_media_thumbnail",
         )

-    async def get_remote_media_before(self, before_ts):
+    async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
         sql = (
             "SELECT media_origin, media_id, filesystem_id"
             " FROM remote_media_cache"
@@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " LIMIT 500"
         )

-        def _get_expired_url_cache_txn(txn):
+        def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
@@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_expired_url_cache", _get_expired_url_cache_txn
         )

-    async def delete_url_cache(self, media_ids):
+    async def delete_url_cache(self, media_ids: Collection[str]) -> None:
         if len(media_ids) == 0:
             return

         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"

-        def _delete_url_cache_txn(txn):
+        def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

-        return await self.db_pool.runInteraction(
-            "delete_url_cache", _delete_url_cache_txn
-        )
+        await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)

     async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
         sql = (
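delete_url_cache (and delete_url_cache_media below) types media_ids as Collection[str]: the standard typing ABC combining __len__, __iter__ and __contains__, which covers both the len(media_ids) == 0 guard and the batch comprehension while accepting any concrete container. Illustrative call sites (not from this commit):

    await store.delete_url_cache(["abc123", "def456"])  # a list
    await store.delete_url_cache({"abc123"})            # a set works too
    await store.delete_url_cache(())                    # empty: returns immediately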
@@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " LIMIT 500"
         )

-        def _get_url_cache_media_before_txn(txn):
+        def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
@@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )

-    async def delete_url_cache_media(self, media_ids):
+    async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
         if len(media_ids) == 0:
             return

-        def _delete_url_cache_media_txn(txn):
+        def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
@@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
         )