mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
23740eaa3d
During the migration the automated script to update the copyright headers accidentally got rid of some of the existing copyright lines. Reinstate them.
228 lines
7.5 KiB
Python
228 lines
7.5 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright 2014-2016 OpenMarket Ltd
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from synapse.storage._base import SQLBaseStore
|
|
from synapse.storage.database import (
|
|
DatabasePool,
|
|
LoggingDatabaseConnection,
|
|
LoggingTransaction,
|
|
)
|
|
from synapse.storage.databases.main.roommember import ProfileInfo
|
|
from synapse.storage.engines import PostgresEngine
|
|
from synapse.types import JsonDict, UserID
|
|
|
|
if TYPE_CHECKING:
|
|
from synapse.server import HomeServer
|
|
|
|
|
|
class ProfileWorkerStore(SQLBaseStore):
    """
    Storage methods for the `profiles` table, together with the background
    update that fills in its `full_user_id` column.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Used to build `@localpart:server_name` IDs in the background update.
        self.server_name: str = hs.hostname
        self.database_engine = database.engine

        background_updates = self.db_pool.updates

        # Unique index over `profiles.full_user_id`.
        background_updates.register_background_index_update(
            "profiles_full_user_id_key_idx",
            index_name="profiles_full_user_id_key",
            table="profiles",
            columns=["full_user_id"],
            unique=True,
        )

        background_updates.register_background_update_handler(
            "populate_full_user_id_profiles", self.populate_full_user_id_profiles
        )

    async def populate_full_user_id_profiles(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """
        Background update that fills in `profiles.full_user_id` for rows where it
        is still NULL, building the value as `'@' || user_id || ':<server_name>'`.

        Rows are walked in `user_id` order. Each call looks for the `user_id`
        reached by skipping 1000 rows past the `lower_bound_id` saved in
        `progress`. If such a row exists, everything in the range
        `(lower_bound_id, that user_id]` is updated and the bound is saved back
        into `progress`. If it does not, all remaining rows are updated in a final
        batch, the constraint `full_user_id_not_null` is validated on Postgres,
        and the background update is ended.

        Args:
            progress: Saved state of the update. `lower_bound_id` is read from it
                (defaulting to the empty string) and overwritten in place after a
                regular batch.
            batch_size: Not used by this implementation; the step is fixed by the
                `OFFSET 1000` in the bound lookup.

        Returns:
            50 after a regular batch, 1 after the final batch.
        """
        update_name = "populate_full_user_id_profiles"
        lower_bound_id = progress.get("lower_bound_id", "")

        def _find_upper_bound(txn: LoggingTransaction) -> Optional[str]:
            # Reads `lower_bound_id` from the enclosing scope.
            txn.execute(
                """
                SELECT user_id FROM profiles
                WHERE user_id > ?
                ORDER BY user_id
                LIMIT 1 OFFSET 1000
                """,
                (lower_bound_id,),
            )
            row = txn.fetchone()
            return row[0] if row else None

        def _fill_range(txn: LoggingTransaction, lower: str, upper: str) -> None:
            txn.execute(
                """
                UPDATE profiles
                SET full_user_id = '@' || user_id || ?
                WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
                """,
                (f":{self.server_name}", lower, upper),
            )

        def _fill_remainder(txn: LoggingTransaction, lower: str) -> None:
            txn.execute(
                """
                UPDATE profiles
                SET full_user_id = '@' || user_id || ?
                WHERE ? < user_id AND full_user_id IS NULL
                """,
                (f":{self.server_name}", lower),
            )

            # The constraint is only validated when running on Postgres.
            if isinstance(self.database_engine, PostgresEngine):
                txn.execute(
                    """
                    ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
                    """
                )

        upper_bound_id = await self.db_pool.runInteraction(
            update_name, _find_upper_bound
        )

        if upper_bound_id is not None:
            # Regular batch: fill the bounded range, then persist the new bound.
            await self.db_pool.runInteraction(
                update_name, _fill_range, lower_bound_id, upper_bound_id
            )

            progress["lower_bound_id"] = upper_bound_id

            await self.db_pool.runInteraction(
                update_name,
                self.db_pool.updates._background_update_progress_txn,
                update_name,
                progress,
            )
            return 50

        # No row lies 1000 places past the lower bound: finish off what is left.
        await self.db_pool.runInteraction(
            update_name, _fill_remainder, lower_bound_id
        )
        await self.db_pool.updates._end_background_update(update_name)
        return 1

    async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
        """
        Look up the display name and avatar URL stored for a user.

        Args:
            user_id: The user whose profile row is selected by `full_user_id`.

        Returns:
            A `ProfileInfo` holding the stored values, or one with both fields
            set to None when no row matches.
        """
        row = await self.db_pool.simple_select_one(
            table="profiles",
            keyvalues={"full_user_id": user_id.to_string()},
            retcols=("displayname", "avatar_url"),
            desc="get_profileinfo",
            allow_none=True,
        )

        if row is None:
            # no match
            return ProfileInfo(avatar_url=None, display_name=None)

        # Columns come back in `retcols` order.
        display_name, avatar_url = row
        return ProfileInfo(avatar_url=avatar_url, display_name=display_name)

    async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
        """
        Return the `displayname` column of the profile row whose `full_user_id`
        matches the given user.
        """
        full_user_id = user_id.to_string()
        return await self.db_pool.simple_select_one_onecol(
            table="profiles",
            keyvalues={"full_user_id": full_user_id},
            retcol="displayname",
            desc="get_profile_displayname",
        )

    async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
        """
        Return the `avatar_url` column of the profile row whose `full_user_id`
        matches the given user.
        """
        full_user_id = user_id.to_string()
        return await self.db_pool.simple_select_one_onecol(
            table="profiles",
            keyvalues={"full_user_id": full_user_id},
            retcol="avatar_url",
            desc="get_profile_avatar_url",
        )

    async def create_profile(self, user_id: UserID) -> None:
        """
        Insert a row into `profiles` for the user, storing the localpart in
        `user_id` and the full ID in `full_user_id`.
        """
        new_row = {
            "user_id": user_id.localpart,
            "full_user_id": user_id.to_string(),
        }
        await self.db_pool.simple_insert(
            table="profiles",
            values=new_row,
            desc="create_profile",
        )

    async def set_profile_displayname(
        self, user_id: UserID, new_displayname: Optional[str]
    ) -> None:
        """
        Set the display name of a user.

        The row is upserted keyed on the user's localpart, and `full_user_id` is
        written alongside the display name.

        Args:
            user_id: The user's ID.
            new_displayname: The new display name. If this is None, the user's display
                name is removed.
        """
        await self.db_pool.simple_upsert(
            table="profiles",
            keyvalues={"user_id": user_id.localpart},
            values={
                "displayname": new_displayname,
                "full_user_id": user_id.to_string(),
            },
            desc="set_profile_displayname",
        )

    async def set_profile_avatar_url(
        self, user_id: UserID, new_avatar_url: Optional[str]
    ) -> None:
        """
        Set the avatar of a user.

        The row is upserted keyed on the user's localpart, and `full_user_id` is
        written alongside the avatar URL.

        Args:
            user_id: The user's ID.
            new_avatar_url: The new avatar URL. If this is None, the user's avatar is
                removed.
        """
        await self.db_pool.simple_upsert(
            table="profiles",
            keyvalues={"user_id": user_id.localpart},
            values={
                "avatar_url": new_avatar_url,
                "full_user_id": user_id.to_string(),
            },
            desc="set_profile_avatar_url",
        )
|
|
|
|
|
|
class ProfileStore(ProfileWorkerStore):
    """Profile store that adds nothing on top of `ProfileWorkerStore`."""
|