mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
N + 3
: Read from column full_user_id
rather than user_id
of tables profiles
and user_filters
(#15649)
This commit is contained in:
parent
e0f2429d13
commit
d0c4257f14
1
changelog.d/15649.misc
Normal file
1
changelog.d/15649.misc
Normal file
@ -0,0 +1 @@
|
||||
Read from column `full_user_id` rather than `user_id` of tables `profiles` and `user_filters`.
|
@ -152,9 +152,9 @@ class Filtering:
|
||||
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
||||
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: UserID, filter_id: Union[int, str]
|
||||
) -> "FilterCollection":
|
||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||
result = await self.store.get_user_filter(user_id, filter_id)
|
||||
return FilterCollection(self._hs, result)
|
||||
|
||||
def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]:
|
||||
|
@ -164,7 +164,7 @@ class AccountValidityHandler:
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
UserID.from_string(user_id)
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
|
@ -89,7 +89,7 @@ class AdminHandler:
|
||||
}
|
||||
|
||||
# Add additional user metadata
|
||||
profile = await self._store.get_profileinfo(user.localpart)
|
||||
profile = await self._store.get_profileinfo(user)
|
||||
threepids = await self._store.user_get_threepids(user.to_string())
|
||||
external_ids = [
|
||||
({"auth_provider": auth_provider, "external_id": external_id})
|
||||
|
@ -1759,7 +1759,7 @@ class AuthHandler:
|
||||
return
|
||||
|
||||
user_profile_data = await self.store.get_profileinfo(
|
||||
UserID.from_string(registered_user_id).localpart
|
||||
UserID.from_string(registered_user_id)
|
||||
)
|
||||
|
||||
# Store any extra attributes which will be passed in the login response.
|
||||
|
@ -297,5 +297,5 @@ class DeactivateAccountHandler:
|
||||
# Add the user to the directory, if necessary. Note that
|
||||
# this must be done after the user is re-activated, because
|
||||
# deactivated users are excluded from the user directory.
|
||||
profile = await self.store.get_profileinfo(user.localpart)
|
||||
profile = await self.store.get_profileinfo(user)
|
||||
await self.user_directory_handler.handle_local_profile_change(user_id, profile)
|
||||
|
@ -67,7 +67,7 @@ class ProfileHandler:
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
||||
if self.hs.is_mine(target_user):
|
||||
profileinfo = await self.store.get_profileinfo(target_user.localpart)
|
||||
profileinfo = await self.store.get_profileinfo(target_user)
|
||||
if profileinfo.display_name is None:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
||||
@ -99,9 +99,7 @@ class ProfileHandler:
|
||||
async def get_displayname(self, target_user: UserID) -> Optional[str]:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
displayname = await self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
displayname = await self.store.get_profile_displayname(target_user)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
@ -147,7 +145,7 @@ class ProfileHandler:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
if not by_admin and not self.hs.config.registration.enable_set_displayname:
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
if profile.display_name:
|
||||
raise SynapseError(
|
||||
400,
|
||||
@ -180,7 +178,7 @@ class ProfileHandler:
|
||||
|
||||
await self.store.set_profile_displayname(target_user, displayname_to_set)
|
||||
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
@ -194,9 +192,7 @@ class ProfileHandler:
|
||||
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
avatar_url = await self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
avatar_url = await self.store.get_profile_avatar_url(target_user)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
@ -241,7 +237,7 @@ class ProfileHandler:
|
||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||
|
||||
if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
if profile.avatar_url:
|
||||
raise SynapseError(
|
||||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||
@ -272,7 +268,7 @@ class ProfileHandler:
|
||||
|
||||
await self.store.set_profile_avatar_url(target_user, avatar_url_to_set)
|
||||
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
@ -369,14 +365,10 @@ class ProfileHandler:
|
||||
response = {}
|
||||
try:
|
||||
if just_field is None or just_field == "displayname":
|
||||
response["displayname"] = await self.store.get_profile_displayname(
|
||||
user.localpart
|
||||
)
|
||||
response["displayname"] = await self.store.get_profile_displayname(user)
|
||||
|
||||
if just_field is None or just_field == "avatar_url":
|
||||
response["avatar_url"] = await self.store.get_profile_avatar_url(
|
||||
user.localpart
|
||||
)
|
||||
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
@ -315,7 +315,7 @@ class RegistrationHandler:
|
||||
approved=approved,
|
||||
)
|
||||
|
||||
profile = await self.store.get_profileinfo(localpart)
|
||||
profile = await self.store.get_profileinfo(user)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
user_id, profile
|
||||
)
|
||||
|
@ -655,7 +655,9 @@ class ModuleApi:
|
||||
Returns:
|
||||
The profile information (i.e. display name and avatar URL).
|
||||
"""
|
||||
return await self._store.get_profileinfo(localpart)
|
||||
server_name = self._hs.hostname
|
||||
user_id = UserID.from_string(f"@{localpart}:{server_name}")
|
||||
return await self._store.get_profileinfo(user_id)
|
||||
|
||||
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
|
||||
"""Look up the threepids (email addresses and phone numbers) associated with the
|
||||
|
@ -247,7 +247,7 @@ class Mailer:
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
UserID.from_string(user_id)
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
|
@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user_localpart=target_user.localpart, filter_id=filter_id_int
|
||||
user_id=target_user, filter_id=filter_id_int
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code != 404:
|
||||
|
@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
|
||||
else:
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user.localpart, filter_id
|
||||
user, filter_id
|
||||
)
|
||||
except StoreError as err:
|
||||
if err.code != 404:
|
||||
|
@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
|
||||
@cached(num_args=2)
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: UserID, filter_id: Union[int, str]
|
||||
) -> JsonDict:
|
||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||
@ -156,7 +156,7 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||
keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
@ -172,15 +172,15 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
def _do_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT filter_id FROM user_filters "
|
||||
"WHERE user_id = ? AND filter_json = ?"
|
||||
"WHERE full_user_id = ? AND filter_json = ?"
|
||||
)
|
||||
txn.execute(sql, (user_id.localpart, bytearray(def_json)))
|
||||
txn.execute(sql, (user_id.to_string(), bytearray(def_json)))
|
||||
filter_id_response = txn.fetchone()
|
||||
if filter_id_response is not None:
|
||||
return filter_id_response[0]
|
||||
|
||||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
|
||||
txn.execute(sql, (user_id.localpart,))
|
||||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?"
|
||||
txn.execute(sql, (user_id.to_string(),))
|
||||
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
|
||||
if max_id is None:
|
||||
filter_id = 0
|
||||
|
@ -137,11 +137,11 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
|
||||
return 50
|
||||
|
||||
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
|
||||
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
|
||||
try:
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
desc="get_profileinfo",
|
||||
)
|
||||
@ -156,18 +156,18 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
)
|
||||
|
||||
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
|
||||
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
|
||||
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
|
||||
async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 77 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 78 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
@ -103,6 +103,9 @@ Changes in SCHEMA_VERSION = 76:
|
||||
|
||||
Changes in SCHEMA_VERSION = 77
|
||||
- (Postgres) Add NOT VALID CHECK (full_user_id IS NOT NULL) to tables profiles and user_filters
|
||||
|
||||
Changes in SCHEMA_VERSION = 78
|
||||
- Validate check (full_user_id IS NOT NULL) on tables profiles and user_filters
|
||||
"""
|
||||
|
||||
|
||||
|
@ -0,0 +1,92 @@
|
||||
# Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
|
||||
|
||||
def run_upgrade(
|
||||
cur: LoggingTransaction,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
config: HomeServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Part 3 of a multi-step migration to drop the column `user_id` and replace it with
|
||||
`full_user_id`. See the database schema docs for more information on the full
|
||||
migration steps.
|
||||
"""
|
||||
hostname = config.server.server_name
|
||||
|
||||
if isinstance(database_engine, PostgresEngine):
|
||||
# check if the constraint can be validated
|
||||
check_sql = """
|
||||
SELECT user_id from profiles WHERE full_user_id IS NULL
|
||||
"""
|
||||
cur.execute(check_sql)
|
||||
res = cur.fetchall()
|
||||
|
||||
if res:
|
||||
# there are rows the background job missed, finish them here before we validate the constraint
|
||||
process_rows_sql = """
|
||||
UPDATE profiles
|
||||
SET full_user_id = '@' || user_id || ?
|
||||
WHERE user_id IN (
|
||||
SELECT user_id FROM profiles WHERE full_user_id IS NULL
|
||||
)
|
||||
"""
|
||||
cur.execute(process_rows_sql, (f":{hostname}",))
|
||||
|
||||
# Now we can validate
|
||||
validate_sql = """
|
||||
ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
|
||||
"""
|
||||
cur.execute(validate_sql)
|
||||
|
||||
else:
|
||||
# in SQLite we need to rewrite the table to add the constraint.
|
||||
# First drop any temporary table that might be here from a previous failed migration.
|
||||
cur.execute("DROP TABLE IF EXISTS temp_profiles")
|
||||
|
||||
create_sql = """
|
||||
CREATE TABLE temp_profiles (
|
||||
full_user_id text NOT NULL,
|
||||
user_id text,
|
||||
displayname text,
|
||||
avatar_url text,
|
||||
UNIQUE (full_user_id),
|
||||
UNIQUE (user_id)
|
||||
)
|
||||
"""
|
||||
cur.execute(create_sql)
|
||||
|
||||
copy_sql = """
|
||||
INSERT INTO temp_profiles (
|
||||
user_id,
|
||||
displayname,
|
||||
avatar_url,
|
||||
full_user_id)
|
||||
SELECT user_id, displayname, avatar_url, '@' || user_id || ':' || ? FROM profiles
|
||||
"""
|
||||
cur.execute(copy_sql, (f"{hostname}",))
|
||||
|
||||
drop_sql = """
|
||||
DROP TABLE profiles
|
||||
"""
|
||||
cur.execute(drop_sql)
|
||||
|
||||
rename_sql = """
|
||||
ALTER TABLE temp_profiles RENAME to profiles
|
||||
"""
|
||||
cur.execute(rename_sql)
|
@ -0,0 +1,95 @@
|
||||
# Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
|
||||
|
||||
def run_upgrade(
|
||||
cur: LoggingTransaction,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
config: HomeServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Part 3 of a multi-step migration to drop the column `user_id` and replace it with
|
||||
`full_user_id`. See the database schema docs for more information on the full
|
||||
migration steps.
|
||||
"""
|
||||
hostname = config.server.server_name
|
||||
|
||||
if isinstance(database_engine, PostgresEngine):
|
||||
# check if the constraint can be validated
|
||||
check_sql = """
|
||||
SELECT user_id from user_filters WHERE full_user_id IS NULL
|
||||
"""
|
||||
cur.execute(check_sql)
|
||||
res = cur.fetchall()
|
||||
|
||||
if res:
|
||||
# there are rows the background job missed, finish them here before we validate constraint
|
||||
process_rows_sql = """
|
||||
UPDATE user_filters
|
||||
SET full_user_id = '@' || user_id || ?
|
||||
WHERE user_id IN (
|
||||
SELECT user_id FROM user_filters WHERE full_user_id IS NULL
|
||||
)
|
||||
"""
|
||||
cur.execute(process_rows_sql, (f":{hostname}",))
|
||||
|
||||
# Now we can validate
|
||||
validate_sql = """
|
||||
ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null
|
||||
"""
|
||||
cur.execute(validate_sql)
|
||||
|
||||
else:
|
||||
cur.execute("DROP TABLE IF EXISTS temp_user_filters")
|
||||
create_sql = """
|
||||
CREATE TABLE temp_user_filters (
|
||||
full_user_id text NOT NULL,
|
||||
user_id text NOT NULL,
|
||||
filter_id bigint NOT NULL,
|
||||
filter_json bytea NOT NULL,
|
||||
UNIQUE (full_user_id),
|
||||
UNIQUE (user_id)
|
||||
)
|
||||
"""
|
||||
cur.execute(create_sql)
|
||||
|
||||
index_sql = """
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON
|
||||
temp_user_filters (user_id, filter_id)
|
||||
"""
|
||||
cur.execute(index_sql)
|
||||
|
||||
copy_sql = """
|
||||
INSERT INTO temp_user_filters (
|
||||
user_id,
|
||||
filter_id,
|
||||
filter_json,
|
||||
full_user_id)
|
||||
SELECT user_id, filter_id, filter_json, '@' || user_id || ':' || ? FROM user_filters
|
||||
"""
|
||||
cur.execute(copy_sql, (f"{hostname}",))
|
||||
|
||||
drop_sql = """
|
||||
DROP TABLE user_filters
|
||||
"""
|
||||
cur.execute(drop_sql)
|
||||
|
||||
rename_sql = """
|
||||
ALTER TABLE temp_user_filters RENAME to user_filters
|
||||
"""
|
||||
cur.execute(rename_sql)
|
@ -35,7 +35,6 @@ from tests.events.test_utils import MockEvent
|
||||
|
||||
user_id = UserID.from_string("@test_user:test")
|
||||
user2_id = UserID.from_string("@test_user2:test")
|
||||
user_localpart = "test_user"
|
||||
|
||||
|
||||
class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
@ -449,9 +448,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
@ -479,9 +476,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart + "2", filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
@ -498,9 +493,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events=events))
|
||||
@ -519,9 +512,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events))
|
||||
@ -603,9 +594,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
user_filter_json,
|
||||
(
|
||||
self.get_success(
|
||||
self.datastore.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=0
|
||||
)
|
||||
self.datastore.get_user_filter(user_id=user_id, filter_id=0)
|
||||
)
|
||||
),
|
||||
)
|
||||
@ -620,9 +609,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
self.assertEqual(filter.get_filter_json(), user_filter_json)
|
||||
|
@ -80,11 +80,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||
"Frank Jr.",
|
||||
)
|
||||
|
||||
@ -96,11 +92,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||
"Frank",
|
||||
)
|
||||
|
||||
@ -112,7 +104,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||
self.get_success(self.store.get_profile_displayname(self.frank))
|
||||
)
|
||||
|
||||
def test_set_my_name_if_disabled(self) -> None:
|
||||
@ -122,11 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
|
||||
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||
"Frank",
|
||||
)
|
||||
|
||||
@ -201,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
|
||||
@ -215,7 +203,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
@ -229,7 +217,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
)
|
||||
|
||||
def test_set_my_avatar_if_disabled(self) -> None:
|
||||
@ -241,7 +229,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
|
@ -28,7 +28,7 @@ from synapse.module_api import ModuleApi
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, notifications, presence, profile, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, create_requester
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.events.test_presence_router import send_presence_update, sync_presence
|
||||
@ -103,7 +103,9 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||
self.assertEqual(email["added_at"], 0)
|
||||
|
||||
# Check that the displayname was assigned
|
||||
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
||||
displayname = self.get_success(
|
||||
self.store.get_profile_displayname(UserID.from_string("@bob:test"))
|
||||
)
|
||||
self.assertEqual(displayname, "Bobberino")
|
||||
|
||||
def test_can_register_admin_user(self) -> None:
|
||||
|
@ -46,7 +46,9 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||
filter = self.get_success(
|
||||
self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
self.store.get_user_filter(
|
||||
user_id=UserID.from_string(FilterTestCase.user_id), filter_id=0
|
||||
)
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(filter, self.EXAMPLE_FILTER)
|
||||
|
@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.server import HomeServer
|
||||
@ -35,18 +36,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(
|
||||
"Frank",
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.u_frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_displayname(self.u_frank))),
|
||||
)
|
||||
|
||||
# test set to None
|
||||
self.get_success(self.store.set_profile_displayname(self.u_frank, None))
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
|
||||
self.get_success(self.store.get_profile_displayname(self.u_frank))
|
||||
)
|
||||
|
||||
def test_avatar_url(self) -> None:
|
||||
@ -58,18 +55,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(
|
||||
"http://my.site/here",
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_avatar_url(self.u_frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.u_frank))),
|
||||
)
|
||||
|
||||
# test set to None
|
||||
self.get_success(self.store.set_profile_avatar_url(self.u_frank, None))
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
|
||||
self.get_success(self.store.get_profile_avatar_url(self.u_frank))
|
||||
)
|
||||
|
||||
def test_profiles_bg_migration(self) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user