Convert stats and related calls to async/await (#8192)

This commit is contained in:
Patrick Cloke 2020-08-27 17:24:37 -04:00 committed by GitHub
parent b71d4a094c
commit b49a5b9307
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 78 additions and 77 deletions

1
changelog.d/8192.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List from typing import Dict, List
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.database import DatabasePool, make_in_list_sql_clause
@ -33,11 +33,11 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self.hs = hs self.hs = hs
@cached(num_args=0) @cached(num_args=0)
def get_monthly_active_count(self): async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users """Generates current count of monthly active users
Returns: Returns:
Defered[int]: Number of current monthly active users Number of current monthly active users
""" """
def _count_users(txn): def _count_users(txn):
@ -46,10 +46,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
return self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0) @cached(num_args=0)
def get_monthly_active_count_by_service(self): async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
"""Generates current count of monthly active users broken down by service. """Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users. A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table Since the `monthly_active_users` table is populated from the `user_ips` table
@ -57,8 +57,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
method to return anything other than native matrix users. method to return anything other than native matrix users.
Returns: Returns:
Deferred[dict]: dict that includes a mapping between app_service_id A mapping between app_service_id and the number of occurrences.
and the number of occurrences.
""" """
@ -74,7 +73,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
result = txn.fetchall() result = txn.fetchall()
return dict(result) return dict(result)
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"count_users_by_service", _count_users_by_service "count_users_by_service", _count_users_by_service
) )

View File

@ -15,8 +15,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import Counter
from itertools import chain from itertools import chain
from typing import Any, Dict, Tuple from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.defer import DeferredLock from twisted.internet.defer import DeferredLock
@ -251,21 +252,23 @@ class StatsStore(StateDeltasStore):
desc="update_room_state", desc="update_room_state",
) )
def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): async def get_statistics_for_subject(
self, stats_type: str, stats_id: str, start: str, size: int = 100
) -> List[dict]:
""" """
Get statistics for a given subject. Get statistics for a given subject.
Args: Args:
stats_type (str): The type of subject stats_type: The type of subject
stats_id (str): The ID of the subject (e.g. room_id or user_id) stats_id: The ID of the subject (e.g. room_id or user_id)
start (int): Pagination start. Number of entries, not timestamp. start: Pagination start. Number of entries, not timestamp.
size (int): How many entries to return. size: How many entries to return.
Returns: Returns:
Deferred[list[dict]], where the dict has the keys of A list of dicts, where the dict has the keys of
ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_statistics_for_subject", "get_statistics_for_subject",
self._get_statistics_for_subject_txn, self._get_statistics_for_subject_txn,
stats_type, stats_type,
@ -319,18 +322,17 @@ class StatsStore(StateDeltasStore):
allow_none=True, allow_none=True,
) )
def bulk_update_stats_delta(self, ts, updates, stream_id): async def bulk_update_stats_delta(
self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats """Bulk update stats tables for a given stream_id and updates the stats
incremental position. incremental position.
Args: Args:
ts (int): Current timestamp in ms ts: Current timestamp in ms
updates(dict[str, dict[str, dict[str, Counter]]]): The updates to updates: The updates to commit as a mapping of
commit as a mapping stats_type -> stats_id -> field -> delta. stats_type -> stats_id -> field -> delta.
stream_id (int): Current position. stream_id: Current position.
Returns:
Deferred
""" """
def _bulk_update_stats_delta_txn(txn): def _bulk_update_stats_delta_txn(txn):
@ -355,38 +357,37 @@ class StatsStore(StateDeltasStore):
updatevalues={"stream_id": stream_id}, updatevalues={"stream_id": stream_id},
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"bulk_update_stats_delta", _bulk_update_stats_delta_txn "bulk_update_stats_delta", _bulk_update_stats_delta_txn
) )
def update_stats_delta( async def update_stats_delta(
self, self,
ts, ts: int,
stats_type, stats_type: str,
stats_id, stats_id: str,
fields, fields: Dict[str, int],
complete_with_stream_id, complete_with_stream_id: Optional[int],
absolute_field_overrides=None, absolute_field_overrides: Optional[Dict[str, int]] = None,
): ) -> None:
""" """
Updates the statistics for a subject, with a delta (difference/relative Updates the statistics for a subject, with a delta (difference/relative
change). change).
Args: Args:
ts (int): timestamp of the change ts: timestamp of the change
stats_type (str): "room" or "user" the kind of subject stats_type: "room" or "user" the kind of subject
stats_id (str): the subject's ID (room ID or user ID) stats_id: the subject's ID (room ID or user ID)
fields (dict[str, int]): Deltas of stats values. fields: Deltas of stats values.
complete_with_stream_id (int, optional): complete_with_stream_id:
If supplied, converts an incomplete row into a complete row, If supplied, converts an incomplete row into a complete row,
with the supplied stream_id marked as the stream_id where the with the supplied stream_id marked as the stream_id where the
row was completed. row was completed.
absolute_field_overrides (dict[str, int]): Current stats values absolute_field_overrides: Current stats values (i.e. not deltas) of
(i.e. not deltas) of absolute fields. absolute fields. Does not work with per-slice fields.
Does not work with per-slice fields.
""" """
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"update_stats_delta", "update_stats_delta",
self._update_stats_delta_txn, self._update_stats_delta_txn,
ts, ts,
@ -646,19 +647,20 @@ class StatsStore(StateDeltasStore):
txn, into_table, all_dest_keyvalues, src_row txn, into_table, all_dest_keyvalues, src_row
) )
def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): async def get_changes_room_total_events_and_bytes(
self, min_pos: int, max_pos: int
) -> Dict[str, Dict[str, int]]:
"""Fetches the counts of events in the given range of stream IDs. """Fetches the counts of events in the given range of stream IDs.
Args: Args:
min_pos (int) min_pos
max_pos (int) max_pos
Returns: Returns:
Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field Mapping of room ID to field changes.
changes.
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"stats_incremental_total_events_and_bytes", "stats_incremental_total_events_and_bytes",
self.get_changes_room_total_events_and_bytes_txn, self.get_changes_room_total_events_and_bytes_txn,
min_pos, min_pos,

View File

@ -24,6 +24,7 @@ from synapse.api.errors import ResourceLimitError
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self): def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) side_effect=lambda: make_awaitable(self.large_number_of_users)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) side_effect=lambda: make_awaitable(self.large_number_of_users)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort # If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.auth_blocking._max_mau_value) side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.auth_blocking._max_mau_value) side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
) )
# If in monthly active cohort # If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec()) side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.auth_blocking._max_mau_value) side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
) )
) )
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec()) side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.auth_blocking._max_mau_value) side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) side_effect=lambda: make_awaitable(self.small_number_of_users)
) )
# Ensure does not raise exception # Ensure does not raise exception
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) side_effect=lambda: make_awaitable(self.small_number_of_users)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(

View File

@ -15,8 +15,6 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.api.errors import Codes, ResourceLimitError, SynapseError
@ -102,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock( self.store.count_monthly_users = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value - 1) side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User")) self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@ -110,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self): def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) side_effect=lambda: make_awaitable(self.lots_of_users)
) )
self.get_failure( self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
@ -118,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
) )
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
) )
self.get_failure( self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
@ -128,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self): def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) side_effect=lambda: make_awaitable(self.lots_of_users)
) )
self.get_failure( self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError self.handler.register_user(localpart="local_part"), ResourceLimitError
) )
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
) )
self.get_failure( self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError self.handler.register_user(localpart="local_part"), ResourceLimitError

View File

@ -20,8 +20,6 @@ import urllib.parse
from mock import Mock from mock import Mock
from twisted.internet import defer
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError from synapse.api.errors import HttpResponseException, ResourceLimitError
@ -29,6 +27,7 @@ from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -338,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
store.get_monthly_active_count = Mock( store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
@ -592,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
@ -632,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit

View File

@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
raise Exception("Failed to find reference to ResourceLimitsServerNotices") raise Exception("Failed to find reference to ResourceLimitsServerNotices")
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000) side_effect=lambda user_id: make_awaitable(1000)
) )
self._rlsn._server_notices_manager.send_notice = Mock( self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock()) return_value=defer.succeed(Mock())
@ -158,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None) side_effect=lambda user_id: make_awaitable(None)
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -261,10 +261,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(1000)
)
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000) side_effect=lambda user_id: make_awaitable(1000)
) )
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once

View File

@ -16,13 +16,12 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
import synapse.rest.admin import synapse.rest.admin
from synapse.http.site import XForwardedForRequest from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server" user_id = "@user:server"
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users) side_effect=lambda: make_awaitable(lots_of_users)
) )
self.get_success( self.get_success(
self.store.insert_client_ip( self.store.insert_client_ip(