Simplify the way the HomeServer object caches its internal attributes. (#8565)

Changes `@cache_in_self` to use underscore-prefixed attributes.
This commit is contained in:
Jonathan de Jong 2020-11-30 19:28:44 +01:00 committed by GitHub
parent a090b86209
commit ca60822b34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 40 additions and 36 deletions

1
changelog.d/8565.misc Normal file
View File

@ -0,0 +1 @@
Simplify the way the `HomeServer` object caches its internal attributes.

View File

@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email") raise SynapseError(500, "An error was encountered when sending the email")
token_expires = ( token_expires = (
self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime self.hs.get_clock().time_msec()
+ self.hs.config.email_validation_token_lifetime
) )
await self.store.start_or_continue_validation_session( await self.store.start_or_continue_validation_session(

View File

@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors. # comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it # Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something. # look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10) await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)} return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors. # comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it # Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something. # look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10) await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)} return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors. # comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it # Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something. # look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10) await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)} return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)

View File

@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors. # comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it # Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something. # look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10) await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)} return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors. # comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it # Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something. # look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10) await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)} return 200, {"sid": random_string(16)}
raise SynapseError( raise SynapseError(

View File

@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs): def __init__(self, hs):
self.config = hs.config self.config = hs.config
self.clock = hs.clock self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec()) self.update_response_body(self.clock.time_msec())
Resource.__init__(self) Resource.__init__(self)

View File

@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T:
"@cache_in_self can only be used on functions starting with `get_`" "@cache_in_self can only be used on functions starting with `get_`"
) )
depname = builder.__name__[len("get_") :] # get_attr -> _attr
depname = builder.__name__[len("get") :]
building = [False] building = [False]
@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
self._instance_id = random_string(5) self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master" self._instance_name = config.worker_name or "master"
self.clock = Clock(reactor)
self.distributor = Distributor()
self.registration_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=config.rc_registration.per_second,
burst_count=config.rc_registration.burst_count,
)
self.version_string = version_string self.version_string = version_string
self.datastores = None # type: Optional[Databases] self.datastores = None # type: Optional[Databases]
@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def is_mine_id(self, string: str) -> bool: def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname return string.split(":", 1)[1] == self.hostname
@cache_in_self
def get_clock(self) -> Clock: def get_clock(self) -> Clock:
return self.clock return Clock(self._reactor)
def get_datastore(self) -> DataStore: def get_datastore(self) -> DataStore:
if not self.datastores: if not self.datastores:
@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_config(self) -> HomeServerConfig: def get_config(self) -> HomeServerConfig:
return self.config return self.config
@cache_in_self
def get_distributor(self) -> Distributor: def get_distributor(self) -> Distributor:
return self.distributor return Distributor()
@cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter: def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter return Ratelimiter(
clock=self.get_clock(),
rate_hz=self.config.rc_registration.per_second,
burst_count=self.config.rc_registration.burst_count,
)
@cache_in_self @cache_in_self
def get_federation_client(self) -> FederationClient: def get_federation_client(self) -> FederationClient:
@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self @cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter: def get_federation_ratelimiter(self) -> FederationRateLimiter:
return FederationRateLimiter(self.clock, config=self.config.rc_federation) return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
@cache_in_self @cache_in_self
def get_module_api(self) -> ModuleApi: def get_module_api(self) -> ModuleApi:

View File

@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect()) self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self): def test_macaroon_caveats(self):
self.hs.clock.now = 5000 self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user") token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred( user_id = yield defer.ensureDeferred(
@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id) self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
self.hs.clock.now = 6000 self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token) self.auth_handler.validate_short_term_login_token_and_get_user_id(token)

View File

@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler() self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler self.worker_hs._replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs) repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol( self.client = ClientReplicationStreamProtocol(

View File

@ -33,12 +33,15 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver( presence_handler = Mock()
"red", http_client=None, federation_client=Mock() presence_handler.set_state.return_value = defer.succeed(None)
)
hs.presence_handler = Mock() hs = self.setup_test_homeserver(
hs.presence_handler.set_state.return_value = defer.succeed(None) "red",
http_client=None,
federation_client=Mock(),
presence_handler=presence_handler,
)
return hs return hs
@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(self.hs.presence_handler.set_state.call_count, 1) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
def test_put_presence_disabled(self): def test_put_presence_disabled(self):
""" """
@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(self.hs.presence_handler.set_state.call_count, 0) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)

View File

@ -569,7 +569,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do # We need to manually add an email address otherwise the handler will do
# nothing. # nothing.
now = self.hs.clock.time_msec() now = self.hs.get_clock().time_msec()
self.get_success( self.get_success(
self.store.user_add_threepid( self.store.user_add_threepid(
user_id=user_id, user_id=user_id,
@ -587,7 +587,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# We need to manually add an email address otherwise the handler will do # We need to manually add an email address otherwise the handler will do
# nothing. # nothing.
now = self.hs.clock.time_msec() now = self.hs.get_clock().time_msec()
self.get_success( self.get_success(
self.store.user_add_threepid( self.store.user_add_threepid(
user_id=user_id, user_id=user_id,
@ -646,7 +646,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.hs.config.account_validity.startup_job_max_delta = self.max_delta self.hs.config.account_validity.startup_job_max_delta = self.max_delta
now_ms = self.hs.clock.time_msec() now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing()) self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) res = self.get_success(self.store.get_expiration_ts_for_user(user_id))

View File

@ -271,7 +271,7 @@ def setup_test_homeserver(
# Install @cache_in_self attributes # Install @cache_in_self attributes
for key, val in kwargs.items(): for key, val in kwargs.items():
setattr(hs, key, val) setattr(hs, "_" + key, val)
# Mock TLS # Mock TLS
hs.tls_server_context_factory = Mock() hs.tls_server_context_factory = Mock()