Prevent local quarantined media from being claimed by media retention (#12972)

This commit is contained in:
Andrew Morgan 2022-06-07 11:53:47 +01:00 committed by GitHub
parent f7baffd8ec
commit a47636c570
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 185 additions and 29 deletions

View file

@ -251,12 +251,36 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
)
async def get_local_media_before(
async def get_local_media_ids(
self,
before_ts: int,
size_gt: int,
keep_profiles: bool,
include_quarantined_media: bool,
include_protected_media: bool,
) -> List[str]:
"""
Retrieve a list of media IDs from the local media store.
Args:
before_ts: Only retrieve IDs from media that was either last accessed
(or if never accessed, created) before the given UNIX timestamp in ms.
size_gt: Only retrieve IDs from media that has a size (in bytes) greater than
the given integer.
keep_profiles: If True, exclude media IDs from the results that are used in the
following situations:
* global profile user avatar
* per-room profile user avatar
* room avatar
* a user's avatar in the user directory
include_quarantined_media: If False, exclude media IDs from the results that have
been marked as quarantined.
include_protected_media: If False, exclude media IDs from the results that have
been marked as protected from quarantine.
Returns:
A list of local media IDs.
"""
# to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts`
@ -294,12 +318,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
sql += sql_keep
def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
if include_quarantined_media is False:
# Do not include media that has been quarantined
sql += """
AND quarantined_by IS NULL
"""
if include_protected_media is False:
# Do not include media that has been protected from quarantine
sql += """
AND safe_from_quarantine = false
"""
def _get_local_media_ids_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts, before_ts, size_gt))
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
"get_local_media_before", _get_local_media_before_txn
"get_local_media_ids", _get_local_media_ids_txn
)
async def store_local_media(
@ -599,15 +635,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]:
"""
Retrieve the origin server name, media ID and filesystem ID of media in the remote media cache.
Args:
before_ts: Only retrieve IDs from media that was either last accessed
(or if never accessed, created) before the given UNIX timestamp in ms.
include_quarantined_media: If False, exclude media IDs from the results that have
been marked as quarantined.
Returns:
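A list of dicts, one per matching row of `remote_media_cache`, with the keys:
* `media_origin`: the server name of the homeserver where the media originates from,
* `media_id`: the ID of the media,
* `filesystem_id`: the value of the row's `filesystem_id` column.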
"""
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
if include_quarantined_media is False:
# Only include media that has not been quarantined
sql += """
AND quarantined_by IS NULL
"""
return await self.db_pool.execute(
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None: