mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
N + 3
: Read from column full_user_id
rather than user_id
of tables profiles
and user_filters
(#15649)
This commit is contained in:
parent
e0f2429d13
commit
d0c4257f14
1
changelog.d/15649.misc
Normal file
1
changelog.d/15649.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Read from column `full_user_id` rather than `user_id` of tables `profiles` and `user_filters`.
|
@ -152,9 +152,9 @@ class Filtering:
|
|||||||
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
||||||
|
|
||||||
async def get_user_filter(
|
async def get_user_filter(
|
||||||
self, user_localpart: str, filter_id: Union[int, str]
|
self, user_id: UserID, filter_id: Union[int, str]
|
||||||
) -> "FilterCollection":
|
) -> "FilterCollection":
|
||||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
result = await self.store.get_user_filter(user_id, filter_id)
|
||||||
return FilterCollection(self._hs, result)
|
return FilterCollection(self._hs, result)
|
||||||
|
|
||||||
def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]:
|
def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]:
|
||||||
|
@ -164,7 +164,7 @@ class AccountValidityHandler:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
user_display_name = await self.store.get_profile_displayname(
|
user_display_name = await self.store.get_profile_displayname(
|
||||||
UserID.from_string(user_id).localpart
|
UserID.from_string(user_id)
|
||||||
)
|
)
|
||||||
if user_display_name is None:
|
if user_display_name is None:
|
||||||
user_display_name = user_id
|
user_display_name = user_id
|
||||||
|
@ -89,7 +89,7 @@ class AdminHandler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Add additional user metadata
|
# Add additional user metadata
|
||||||
profile = await self._store.get_profileinfo(user.localpart)
|
profile = await self._store.get_profileinfo(user)
|
||||||
threepids = await self._store.user_get_threepids(user.to_string())
|
threepids = await self._store.user_get_threepids(user.to_string())
|
||||||
external_ids = [
|
external_ids = [
|
||||||
({"auth_provider": auth_provider, "external_id": external_id})
|
({"auth_provider": auth_provider, "external_id": external_id})
|
||||||
|
@ -1759,7 +1759,7 @@ class AuthHandler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
user_profile_data = await self.store.get_profileinfo(
|
user_profile_data = await self.store.get_profileinfo(
|
||||||
UserID.from_string(registered_user_id).localpart
|
UserID.from_string(registered_user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store any extra attributes which will be passed in the login response.
|
# Store any extra attributes which will be passed in the login response.
|
||||||
|
@ -297,5 +297,5 @@ class DeactivateAccountHandler:
|
|||||||
# Add the user to the directory, if necessary. Note that
|
# Add the user to the directory, if necessary. Note that
|
||||||
# this must be done after the user is re-activated, because
|
# this must be done after the user is re-activated, because
|
||||||
# deactivated users are excluded from the user directory.
|
# deactivated users are excluded from the user directory.
|
||||||
profile = await self.store.get_profileinfo(user.localpart)
|
profile = await self.store.get_profileinfo(user)
|
||||||
await self.user_directory_handler.handle_local_profile_change(user_id, profile)
|
await self.user_directory_handler.handle_local_profile_change(user_id, profile)
|
||||||
|
@ -67,7 +67,7 @@ class ProfileHandler:
|
|||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
profileinfo = await self.store.get_profileinfo(target_user.localpart)
|
profileinfo = await self.store.get_profileinfo(target_user)
|
||||||
if profileinfo.display_name is None:
|
if profileinfo.display_name is None:
|
||||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||||
|
|
||||||
@ -99,9 +99,7 @@ class ProfileHandler:
|
|||||||
async def get_displayname(self, target_user: UserID) -> Optional[str]:
|
async def get_displayname(self, target_user: UserID) -> Optional[str]:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
displayname = await self.store.get_profile_displayname(
|
displayname = await self.store.get_profile_displayname(target_user)
|
||||||
target_user.localpart
|
|
||||||
)
|
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||||
@ -147,7 +145,7 @@ class ProfileHandler:
|
|||||||
raise AuthError(400, "Cannot set another user's displayname")
|
raise AuthError(400, "Cannot set another user's displayname")
|
||||||
|
|
||||||
if not by_admin and not self.hs.config.registration.enable_set_displayname:
|
if not by_admin and not self.hs.config.registration.enable_set_displayname:
|
||||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user)
|
||||||
if profile.display_name:
|
if profile.display_name:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
@ -180,7 +178,7 @@ class ProfileHandler:
|
|||||||
|
|
||||||
await self.store.set_profile_displayname(target_user, displayname_to_set)
|
await self.store.set_profile_displayname(target_user, displayname_to_set)
|
||||||
|
|
||||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user)
|
||||||
await self.user_directory_handler.handle_local_profile_change(
|
await self.user_directory_handler.handle_local_profile_change(
|
||||||
target_user.to_string(), profile
|
target_user.to_string(), profile
|
||||||
)
|
)
|
||||||
@ -194,9 +192,7 @@ class ProfileHandler:
|
|||||||
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
|
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
avatar_url = await self.store.get_profile_avatar_url(
|
avatar_url = await self.store.get_profile_avatar_url(target_user)
|
||||||
target_user.localpart
|
|
||||||
)
|
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||||
@ -241,7 +237,7 @@ class ProfileHandler:
|
|||||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||||
|
|
||||||
if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
|
if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
|
||||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user)
|
||||||
if profile.avatar_url:
|
if profile.avatar_url:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||||
@ -272,7 +268,7 @@ class ProfileHandler:
|
|||||||
|
|
||||||
await self.store.set_profile_avatar_url(target_user, avatar_url_to_set)
|
await self.store.set_profile_avatar_url(target_user, avatar_url_to_set)
|
||||||
|
|
||||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user)
|
||||||
await self.user_directory_handler.handle_local_profile_change(
|
await self.user_directory_handler.handle_local_profile_change(
|
||||||
target_user.to_string(), profile
|
target_user.to_string(), profile
|
||||||
)
|
)
|
||||||
@ -369,14 +365,10 @@ class ProfileHandler:
|
|||||||
response = {}
|
response = {}
|
||||||
try:
|
try:
|
||||||
if just_field is None or just_field == "displayname":
|
if just_field is None or just_field == "displayname":
|
||||||
response["displayname"] = await self.store.get_profile_displayname(
|
response["displayname"] = await self.store.get_profile_displayname(user)
|
||||||
user.localpart
|
|
||||||
)
|
|
||||||
|
|
||||||
if just_field is None or just_field == "avatar_url":
|
if just_field is None or just_field == "avatar_url":
|
||||||
response["avatar_url"] = await self.store.get_profile_avatar_url(
|
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
|
||||||
user.localpart
|
|
||||||
)
|
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||||
|
@ -315,7 +315,7 @@ class RegistrationHandler:
|
|||||||
approved=approved,
|
approved=approved,
|
||||||
)
|
)
|
||||||
|
|
||||||
profile = await self.store.get_profileinfo(localpart)
|
profile = await self.store.get_profileinfo(user)
|
||||||
await self.user_directory_handler.handle_local_profile_change(
|
await self.user_directory_handler.handle_local_profile_change(
|
||||||
user_id, profile
|
user_id, profile
|
||||||
)
|
)
|
||||||
|
@ -655,7 +655,9 @@ class ModuleApi:
|
|||||||
Returns:
|
Returns:
|
||||||
The profile information (i.e. display name and avatar URL).
|
The profile information (i.e. display name and avatar URL).
|
||||||
"""
|
"""
|
||||||
return await self._store.get_profileinfo(localpart)
|
server_name = self._hs.hostname
|
||||||
|
user_id = UserID.from_string(f"@{localpart}:{server_name}")
|
||||||
|
return await self._store.get_profileinfo(user_id)
|
||||||
|
|
||||||
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
|
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
|
||||||
"""Look up the threepids (email addresses and phone numbers) associated with the
|
"""Look up the threepids (email addresses and phone numbers) associated with the
|
||||||
|
@ -247,7 +247,7 @@ class Mailer:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
user_display_name = await self.store.get_profile_displayname(
|
user_display_name = await self.store.get_profile_displayname(
|
||||||
UserID.from_string(user_id).localpart
|
UserID.from_string(user_id)
|
||||||
)
|
)
|
||||||
if user_display_name is None:
|
if user_display_name is None:
|
||||||
user_display_name = user_id
|
user_display_name = user_id
|
||||||
|
@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
filter_collection = await self.filtering.get_user_filter(
|
filter_collection = await self.filtering.get_user_filter(
|
||||||
user_localpart=target_user.localpart, filter_id=filter_id_int
|
user_id=target_user, filter_id=filter_id_int
|
||||||
)
|
)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
if e.code != 404:
|
if e.code != 404:
|
||||||
|
@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
filter_collection = await self.filtering.get_user_filter(
|
filter_collection = await self.filtering.get_user_filter(
|
||||||
user.localpart, filter_id
|
user, filter_id
|
||||||
)
|
)
|
||||||
except StoreError as err:
|
except StoreError as err:
|
||||||
if err.code != 404:
|
if err.code != 404:
|
||||||
|
@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
async def get_user_filter(
|
async def get_user_filter(
|
||||||
self, user_localpart: str, filter_id: Union[int, str]
|
self, user_id: UserID, filter_id: Union[int, str]
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||||
@ -156,7 +156,7 @@ class FilteringWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
def_json = await self.db_pool.simple_select_one_onecol(
|
def_json = await self.db_pool.simple_select_one_onecol(
|
||||||
table="user_filters",
|
table="user_filters",
|
||||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id},
|
||||||
retcol="filter_json",
|
retcol="filter_json",
|
||||||
allow_none=False,
|
allow_none=False,
|
||||||
desc="get_user_filter",
|
desc="get_user_filter",
|
||||||
@ -172,15 +172,15 @@ class FilteringWorkerStore(SQLBaseStore):
|
|||||||
def _do_txn(txn: LoggingTransaction) -> int:
|
def _do_txn(txn: LoggingTransaction) -> int:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT filter_id FROM user_filters "
|
"SELECT filter_id FROM user_filters "
|
||||||
"WHERE user_id = ? AND filter_json = ?"
|
"WHERE full_user_id = ? AND filter_json = ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id.localpart, bytearray(def_json)))
|
txn.execute(sql, (user_id.to_string(), bytearray(def_json)))
|
||||||
filter_id_response = txn.fetchone()
|
filter_id_response = txn.fetchone()
|
||||||
if filter_id_response is not None:
|
if filter_id_response is not None:
|
||||||
return filter_id_response[0]
|
return filter_id_response[0]
|
||||||
|
|
||||||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
|
sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?"
|
||||||
txn.execute(sql, (user_id.localpart,))
|
txn.execute(sql, (user_id.to_string(),))
|
||||||
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
|
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
|
||||||
if max_id is None:
|
if max_id is None:
|
||||||
filter_id = 0
|
filter_id = 0
|
||||||
|
@ -137,11 +137,11 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return 50
|
return 50
|
||||||
|
|
||||||
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
|
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
|
||||||
try:
|
try:
|
||||||
profile = await self.db_pool.simple_select_one(
|
profile = await self.db_pool.simple_select_one(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"full_user_id": user_id.to_string()},
|
||||||
retcols=("displayname", "avatar_url"),
|
retcols=("displayname", "avatar_url"),
|
||||||
desc="get_profileinfo",
|
desc="get_profileinfo",
|
||||||
)
|
)
|
||||||
@ -156,18 +156,18 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
|
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"full_user_id": user_id.to_string()},
|
||||||
retcol="displayname",
|
retcol="displayname",
|
||||||
desc="get_profile_displayname",
|
desc="get_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
|
async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"full_user_id": user_id.to_string()},
|
||||||
retcol="avatar_url",
|
retcol="avatar_url",
|
||||||
desc="get_profile_avatar_url",
|
desc="get_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
SCHEMA_VERSION = 77 # remember to update the list below when updating
|
SCHEMA_VERSION = 78 # remember to update the list below when updating
|
||||||
"""Represents the expectations made by the codebase about the database schema
|
"""Represents the expectations made by the codebase about the database schema
|
||||||
|
|
||||||
This should be incremented whenever the codebase changes its requirements on the
|
This should be incremented whenever the codebase changes its requirements on the
|
||||||
@ -103,6 +103,9 @@ Changes in SCHEMA_VERSION = 76:
|
|||||||
|
|
||||||
Changes in SCHEMA_VERSION = 77
|
Changes in SCHEMA_VERSION = 77
|
||||||
- (Postgres) Add NOT VALID CHECK (full_user_id IS NOT NULL) to tables profiles and user_filters
|
- (Postgres) Add NOT VALID CHECK (full_user_id IS NOT NULL) to tables profiles and user_filters
|
||||||
|
|
||||||
|
Changes in SCHEMA_VERSION = 78
|
||||||
|
- Validate check (full_user_id IS NOT NULL) on tables profiles and user_filters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,92 @@
|
|||||||
|
# Copyright 2023 The Matrix.org Foundation C.I.C
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(
|
||||||
|
cur: LoggingTransaction,
|
||||||
|
database_engine: BaseDatabaseEngine,
|
||||||
|
config: HomeServerConfig,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Part 3 of a multi-step migration to drop the column `user_id` and replace it with
|
||||||
|
`full_user_id`. See the database schema docs for more information on the full
|
||||||
|
migration steps.
|
||||||
|
"""
|
||||||
|
hostname = config.server.server_name
|
||||||
|
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
# check if the constraint can be validated
|
||||||
|
check_sql = """
|
||||||
|
SELECT user_id from profiles WHERE full_user_id IS NULL
|
||||||
|
"""
|
||||||
|
cur.execute(check_sql)
|
||||||
|
res = cur.fetchall()
|
||||||
|
|
||||||
|
if res:
|
||||||
|
# there are rows the background job missed, finish them here before we validate the constraint
|
||||||
|
process_rows_sql = """
|
||||||
|
UPDATE profiles
|
||||||
|
SET full_user_id = '@' || user_id || ?
|
||||||
|
WHERE user_id IN (
|
||||||
|
SELECT user_id FROM profiles WHERE full_user_id IS NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
cur.execute(process_rows_sql, (f":{hostname}",))
|
||||||
|
|
||||||
|
# Now we can validate
|
||||||
|
validate_sql = """
|
||||||
|
ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
|
||||||
|
"""
|
||||||
|
cur.execute(validate_sql)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# in SQLite we need to rewrite the table to add the constraint.
|
||||||
|
# First drop any temporary table that might be here from a previous failed migration.
|
||||||
|
cur.execute("DROP TABLE IF EXISTS temp_profiles")
|
||||||
|
|
||||||
|
create_sql = """
|
||||||
|
CREATE TABLE temp_profiles (
|
||||||
|
full_user_id text NOT NULL,
|
||||||
|
user_id text,
|
||||||
|
displayname text,
|
||||||
|
avatar_url text,
|
||||||
|
UNIQUE (full_user_id),
|
||||||
|
UNIQUE (user_id)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
cur.execute(create_sql)
|
||||||
|
|
||||||
|
copy_sql = """
|
||||||
|
INSERT INTO temp_profiles (
|
||||||
|
user_id,
|
||||||
|
displayname,
|
||||||
|
avatar_url,
|
||||||
|
full_user_id)
|
||||||
|
SELECT user_id, displayname, avatar_url, '@' || user_id || ':' || ? FROM profiles
|
||||||
|
"""
|
||||||
|
cur.execute(copy_sql, (f"{hostname}",))
|
||||||
|
|
||||||
|
drop_sql = """
|
||||||
|
DROP TABLE profiles
|
||||||
|
"""
|
||||||
|
cur.execute(drop_sql)
|
||||||
|
|
||||||
|
rename_sql = """
|
||||||
|
ALTER TABLE temp_profiles RENAME to profiles
|
||||||
|
"""
|
||||||
|
cur.execute(rename_sql)
|
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright 2023 The Matrix.org Foundation C.I.C
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(
|
||||||
|
cur: LoggingTransaction,
|
||||||
|
database_engine: BaseDatabaseEngine,
|
||||||
|
config: HomeServerConfig,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Part 3 of a multi-step migration to drop the column `user_id` and replace it with
|
||||||
|
`full_user_id`. See the database schema docs for more information on the full
|
||||||
|
migration steps.
|
||||||
|
"""
|
||||||
|
hostname = config.server.server_name
|
||||||
|
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
# check if the constraint can be validated
|
||||||
|
check_sql = """
|
||||||
|
SELECT user_id from user_filters WHERE full_user_id IS NULL
|
||||||
|
"""
|
||||||
|
cur.execute(check_sql)
|
||||||
|
res = cur.fetchall()
|
||||||
|
|
||||||
|
if res:
|
||||||
|
# there are rows the background job missed, finish them here before we validate constraint
|
||||||
|
process_rows_sql = """
|
||||||
|
UPDATE user_filters
|
||||||
|
SET full_user_id = '@' || user_id || ?
|
||||||
|
WHERE user_id IN (
|
||||||
|
SELECT user_id FROM user_filters WHERE full_user_id IS NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
cur.execute(process_rows_sql, (f":{hostname}",))
|
||||||
|
|
||||||
|
# Now we can validate
|
||||||
|
validate_sql = """
|
||||||
|
ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null
|
||||||
|
"""
|
||||||
|
cur.execute(validate_sql)
|
||||||
|
|
||||||
|
else:
|
||||||
|
cur.execute("DROP TABLE IF EXISTS temp_user_filters")
|
||||||
|
create_sql = """
|
||||||
|
CREATE TABLE temp_user_filters (
|
||||||
|
full_user_id text NOT NULL,
|
||||||
|
user_id text NOT NULL,
|
||||||
|
filter_id bigint NOT NULL,
|
||||||
|
filter_json bytea NOT NULL,
|
||||||
|
UNIQUE (full_user_id),
|
||||||
|
UNIQUE (user_id)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
cur.execute(create_sql)
|
||||||
|
|
||||||
|
index_sql = """
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON
|
||||||
|
temp_user_filters (user_id, filter_id)
|
||||||
|
"""
|
||||||
|
cur.execute(index_sql)
|
||||||
|
|
||||||
|
copy_sql = """
|
||||||
|
INSERT INTO temp_user_filters (
|
||||||
|
user_id,
|
||||||
|
filter_id,
|
||||||
|
filter_json,
|
||||||
|
full_user_id)
|
||||||
|
SELECT user_id, filter_id, filter_json, '@' || user_id || ':' || ? FROM user_filters
|
||||||
|
"""
|
||||||
|
cur.execute(copy_sql, (f"{hostname}",))
|
||||||
|
|
||||||
|
drop_sql = """
|
||||||
|
DROP TABLE user_filters
|
||||||
|
"""
|
||||||
|
cur.execute(drop_sql)
|
||||||
|
|
||||||
|
rename_sql = """
|
||||||
|
ALTER TABLE temp_user_filters RENAME to user_filters
|
||||||
|
"""
|
||||||
|
cur.execute(rename_sql)
|
@ -35,7 +35,6 @@ from tests.events.test_utils import MockEvent
|
|||||||
|
|
||||||
user_id = UserID.from_string("@test_user:test")
|
user_id = UserID.from_string("@test_user:test")
|
||||||
user2_id = UserID.from_string("@test_user2:test")
|
user2_id = UserID.from_string("@test_user2:test")
|
||||||
user_localpart = "test_user"
|
|
||||||
|
|
||||||
|
|
||||||
class FilteringTestCase(unittest.HomeserverTestCase):
|
class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
@ -449,9 +448,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
user_filter = self.get_success(
|
user_filter = self.get_success(
|
||||||
self.filtering.get_user_filter(
|
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||||
@ -479,9 +476,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
user_filter = self.get_success(
|
user_filter = self.get_success(
|
||||||
self.filtering.get_user_filter(
|
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
|
||||||
user_localpart=user_localpart + "2", filter_id=filter_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||||
@ -498,9 +493,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = self.get_success(
|
user_filter = self.get_success(
|
||||||
self.filtering.get_user_filter(
|
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.get_success(user_filter.filter_room_state(events=events))
|
results = self.get_success(user_filter.filter_room_state(events=events))
|
||||||
@ -519,9 +512,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = self.get_success(
|
user_filter = self.get_success(
|
||||||
self.filtering.get_user_filter(
|
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.get_success(user_filter.filter_room_state(events))
|
results = self.get_success(user_filter.filter_room_state(events))
|
||||||
@ -603,9 +594,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||||||
user_filter_json,
|
user_filter_json,
|
||||||
(
|
(
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.datastore.get_user_filter(
|
self.datastore.get_user_filter(user_id=user_id, filter_id=0)
|
||||||
user_localpart=user_localpart, filter_id=0
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -620,9 +609,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
filter = self.get_success(
|
filter = self.get_success(
|
||||||
self.filtering.get_user_filter(
|
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(filter.get_filter_json(), user_filter_json)
|
self.assertEqual(filter.get_filter_json(), user_filter_json)
|
||||||
|
@ -80,11 +80,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(
|
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||||
self.get_success(
|
|
||||||
self.store.get_profile_displayname(self.frank.localpart)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"Frank Jr.",
|
"Frank Jr.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -96,11 +92,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(
|
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||||
self.get_success(
|
|
||||||
self.store.get_profile_displayname(self.frank.localpart)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"Frank",
|
"Frank",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -112,7 +104,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(
|
self.assertIsNone(
|
||||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
self.get_success(self.store.get_profile_displayname(self.frank))
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_name_if_disabled(self) -> None:
|
def test_set_my_name_if_disabled(self) -> None:
|
||||||
@ -122,11 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
|
self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(
|
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||||
self.get_success(
|
|
||||||
self.store.get_profile_displayname(self.frank.localpart)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"Frank",
|
"Frank",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -201,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||||
"http://my.server/pic.gif",
|
"http://my.server/pic.gif",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -215,7 +203,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||||
"http://my.server/me.png",
|
"http://my.server/me.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -229,7 +217,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(
|
self.assertIsNone(
|
||||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_avatar_if_disabled(self) -> None:
|
def test_set_my_avatar_if_disabled(self) -> None:
|
||||||
@ -241,7 +229,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||||
"http://my.server/me.png",
|
"http://my.server/me.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from synapse.module_api import ModuleApi
|
|||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, notifications, presence, profile, room
|
from synapse.rest.client import login, notifications, presence, profile, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, create_requester
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.events.test_presence_router import send_presence_update, sync_presence
|
from tests.events.test_presence_router import send_presence_update, sync_presence
|
||||||
@ -103,7 +103,9 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
|||||||
self.assertEqual(email["added_at"], 0)
|
self.assertEqual(email["added_at"], 0)
|
||||||
|
|
||||||
# Check that the displayname was assigned
|
# Check that the displayname was assigned
|
||||||
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
displayname = self.get_success(
|
||||||
|
self.store.get_profile_displayname(UserID.from_string("@bob:test"))
|
||||||
|
)
|
||||||
self.assertEqual(displayname, "Bobberino")
|
self.assertEqual(displayname, "Bobberino")
|
||||||
|
|
||||||
def test_can_register_admin_user(self) -> None:
|
def test_can_register_admin_user(self) -> None:
|
||||||
|
@ -46,7 +46,9 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||||
filter = self.get_success(
|
filter = self.get_success(
|
||||||
self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
self.store.get_user_filter(
|
||||||
|
user_id=UserID.from_string(FilterTestCase.user_id), filter_id=0
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.pump()
|
self.pump()
|
||||||
self.assertEqual(filter, self.EXAMPLE_FILTER)
|
self.assertEqual(filter, self.EXAMPLE_FILTER)
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -35,18 +36,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"Frank",
|
"Frank",
|
||||||
(
|
(self.get_success(self.store.get_profile_displayname(self.u_frank))),
|
||||||
self.get_success(
|
|
||||||
self.store.get_profile_displayname(self.u_frank.localpart)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# test set to None
|
# test set to None
|
||||||
self.get_success(self.store.set_profile_displayname(self.u_frank, None))
|
self.get_success(self.store.set_profile_displayname(self.u_frank, None))
|
||||||
|
|
||||||
self.assertIsNone(
|
self.assertIsNone(
|
||||||
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
|
self.get_success(self.store.get_profile_displayname(self.u_frank))
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_avatar_url(self) -> None:
|
def test_avatar_url(self) -> None:
|
||||||
@ -58,18 +55,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"http://my.site/here",
|
"http://my.site/here",
|
||||||
(
|
(self.get_success(self.store.get_profile_avatar_url(self.u_frank))),
|
||||||
self.get_success(
|
|
||||||
self.store.get_profile_avatar_url(self.u_frank.localpart)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# test set to None
|
# test set to None
|
||||||
self.get_success(self.store.set_profile_avatar_url(self.u_frank, None))
|
self.get_success(self.store.set_profile_avatar_url(self.u_frank, None))
|
||||||
|
|
||||||
self.assertIsNone(
|
self.assertIsNone(
|
||||||
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
|
self.get_success(self.store.get_profile_avatar_url(self.u_frank))
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_profiles_bg_migration(self) -> None:
|
def test_profiles_bg_migration(self) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user