mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
parent
2a1470cd05
commit
864f144543
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,6 +10,7 @@
|
|||||||
*.tac
|
*.tac
|
||||||
_trial_temp/
|
_trial_temp/
|
||||||
_trial_temp*/
|
_trial_temp*/
|
||||||
|
/out
|
||||||
|
|
||||||
# stuff that is likely to exist when you run a server locally
|
# stuff that is likely to exist when you run a server locally
|
||||||
/*.db
|
/*.db
|
||||||
|
1
changelog.d/6150.misc
Normal file
1
changelog.d/6150.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Expand type-checking on modules imported by synapse.config.
|
@ -17,6 +17,7 @@
|
|||||||
"""Contains exceptions and error codes."""
|
"""Contains exceptions and error codes."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
from six.moves import http_client
|
from six.moves import http_client
|
||||||
@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
|
|||||||
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
|
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
|
||||||
super(ProxiedRequestError, self).__init__(code, msg, errcode)
|
super(ProxiedRequestError, self).__init__(code, msg, errcode)
|
||||||
if additional_fields is None:
|
if additional_fields is None:
|
||||||
self._additional_fields = {}
|
self._additional_fields = {} # type: Dict
|
||||||
else:
|
else:
|
||||||
self._additional_fields = dict(additional_fields)
|
self._additional_fields = dict(additional_fields)
|
||||||
|
|
||||||
|
@ -12,6 +12,9 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
|
||||||
@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
|
|||||||
RoomVersions.V4,
|
RoomVersions.V4,
|
||||||
RoomVersions.V5,
|
RoomVersions.V5,
|
||||||
)
|
)
|
||||||
} # type: dict[str, RoomVersion]
|
} # type: Dict[str, RoomVersion]
|
||||||
|
@ -263,7 +263,9 @@ def start(hs, listeners=None):
|
|||||||
refresh_certificate(hs)
|
refresh_certificate(hs)
|
||||||
|
|
||||||
# Start the tracer
|
# Start the tracer
|
||||||
synapse.logging.opentracing.init_tracer(hs.config)
|
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
|
||||||
|
hs.config
|
||||||
|
)
|
||||||
|
|
||||||
# It is now safe to start your Synapse.
|
# It is now safe to start your Synapse.
|
||||||
hs.start_listening(listeners)
|
hs.start_listening(listeners)
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from six import string_types
|
from six import string_types
|
||||||
from six.moves.urllib import parse as urlparse
|
from six.moves.urllib import parse as urlparse
|
||||||
@ -56,8 +57,8 @@ def load_appservices(hostname, config_files):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# Dicts of value -> filename
|
# Dicts of value -> filename
|
||||||
seen_as_tokens = {}
|
seen_as_tokens = {} # type: Dict[str, str]
|
||||||
seen_ids = {}
|
seen_ids = {} # type: Dict[str, str]
|
||||||
|
|
||||||
appservices = []
|
appservices = []
|
||||||
|
|
||||||
|
@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\
|
|||||||
|
|
||||||
|
|
||||||
class ConsentConfig(Config):
|
class ConsentConfig(Config):
|
||||||
def __init__(self):
|
def __init__(self, *args):
|
||||||
super(ConsentConfig, self).__init__()
|
super(ConsentConfig, self).__init__(*args)
|
||||||
|
|
||||||
self.user_consent_version = None
|
self.user_consent_version = None
|
||||||
self.user_consent_template_dir = None
|
self.user_consent_template_dir = None
|
||||||
|
@ -13,6 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
from synapse.util.module_loader import load_module
|
from synapse.util.module_loader import load_module
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
|
|||||||
|
|
||||||
class PasswordAuthProviderConfig(Config):
|
class PasswordAuthProviderConfig(Config):
|
||||||
def read_config(self, config, **kwargs):
|
def read_config(self, config, **kwargs):
|
||||||
self.password_providers = []
|
self.password_providers = [] # type: List[Any]
|
||||||
providers = []
|
providers = []
|
||||||
|
|
||||||
# We want to be backwards compatible with the old `ldap_config`
|
# We want to be backwards compatible with the old `ldap_config`
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
from synapse.python_dependencies import DependencyException, check_requirements
|
from synapse.python_dependencies import DependencyException, check_requirements
|
||||||
from synapse.util.module_loader import load_module
|
from synapse.util.module_loader import load_module
|
||||||
@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
|
|||||||
Dictionary mapping from media type string to list of
|
Dictionary mapping from media type string to list of
|
||||||
ThumbnailRequirement tuples.
|
ThumbnailRequirement tuples.
|
||||||
"""
|
"""
|
||||||
requirements = {}
|
requirements = {} # type: Dict[str, List]
|
||||||
for size in thumbnail_sizes:
|
for size in thumbnail_sizes:
|
||||||
width = size["width"]
|
width = size["width"]
|
||||||
height = size["height"]
|
height = size["height"]
|
||||||
@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config):
|
|||||||
#
|
#
|
||||||
# We don't create the storage providers here as not all workers need
|
# We don't create the storage providers here as not all workers need
|
||||||
# them to be started.
|
# them to be started.
|
||||||
self.media_storage_providers = []
|
self.media_storage_providers = [] # type: List[tuple]
|
||||||
|
|
||||||
for provider_config in storage_providers:
|
for provider_config in storage_providers:
|
||||||
# We special case the module "file_system" so as not to need to
|
# We special case the module "file_system" so as not to need to
|
||||||
|
@ -19,6 +19,7 @@ import logging
|
|||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
from textwrap import indent
|
from textwrap import indent
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import yaml
|
import yaml
|
||||||
@ -243,7 +244,7 @@ class ServerConfig(Config):
|
|||||||
# events with profile information that differ from the target's global profile.
|
# events with profile information that differ from the target's global profile.
|
||||||
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
|
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
|
||||||
|
|
||||||
self.listeners = []
|
self.listeners = [] # type: List[dict]
|
||||||
for listener in config.get("listeners", []):
|
for listener in config.get("listeners", []):
|
||||||
if not isinstance(listener.get("port", None), int):
|
if not isinstance(listener.get("port", None), int):
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
@ -287,7 +288,10 @@ class ServerConfig(Config):
|
|||||||
validator=attr.validators.instance_of(bool), default=False
|
validator=attr.validators.instance_of(bool), default=False
|
||||||
)
|
)
|
||||||
complexity = attr.ib(
|
complexity = attr.ib(
|
||||||
validator=attr.validators.instance_of((int, float)), default=1.0
|
validator=attr.validators.instance_of(
|
||||||
|
(float, int) # type: ignore[arg-type] # noqa
|
||||||
|
),
|
||||||
|
default=1.0,
|
||||||
)
|
)
|
||||||
complexity_error = attr.ib(
|
complexity_error = attr.ib(
|
||||||
validator=attr.validators.instance_of(str),
|
validator=attr.validators.instance_of(str),
|
||||||
@ -366,7 +370,7 @@ class ServerConfig(Config):
|
|||||||
"cleanup_extremities_with_dummy_events", True
|
"cleanup_extremities_with_dummy_events", True
|
||||||
)
|
)
|
||||||
|
|
||||||
def has_tls_listener(self):
|
def has_tls_listener(self) -> bool:
|
||||||
return any(l["tls"] for l in self.listeners)
|
return any(l["tls"] for l in self.listeners)
|
||||||
|
|
||||||
def generate_config_section(
|
def generate_config_section(
|
||||||
|
@ -59,8 +59,8 @@ class ServerNoticesConfig(Config):
|
|||||||
None if server notices are not enabled.
|
None if server notices are not enabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, *args):
|
||||||
super(ServerNoticesConfig, self).__init__()
|
super(ServerNoticesConfig, self).__init__(*args)
|
||||||
self.server_notices_mxid = None
|
self.server_notices_mxid = None
|
||||||
self.server_notices_mxid_display_name = None
|
self.server_notices_mxid_display_name = None
|
||||||
self.server_notices_mxid_avatar_url = None
|
self.server_notices_mxid_avatar_url = None
|
||||||
|
@ -170,6 +170,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
|
|||||||
return
|
return
|
||||||
|
|
||||||
span = opentracing.tracer.active_span
|
span = opentracing.tracer.active_span
|
||||||
carrier = {}
|
carrier = {} # type: Dict[str, str]
|
||||||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
||||||
|
|
||||||
for key, value in carrier.items():
|
for key, value in carrier.items():
|
||||||
@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
|
|||||||
|
|
||||||
span = opentracing.tracer.active_span
|
span = opentracing.tracer.active_span
|
||||||
|
|
||||||
carrier = {}
|
carrier = {} # type: Dict[str, str]
|
||||||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
||||||
|
|
||||||
for key, value in carrier.items():
|
for key, value in carrier.items():
|
||||||
@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
|
|||||||
if destination and not whitelisted_homeserver(destination):
|
if destination and not whitelisted_homeserver(destination):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
carrier = {}
|
carrier = {} # type: Dict[str, str]
|
||||||
opentracing.tracer.inject(
|
opentracing.tracer.inject(
|
||||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
||||||
)
|
)
|
||||||
@ -653,7 +654,7 @@ def active_span_context_as_string():
|
|||||||
Returns:
|
Returns:
|
||||||
The active span context encoded as a string.
|
The active span context encoded as a string.
|
||||||
"""
|
"""
|
||||||
carrier = {}
|
carrier = {} # type: Dict[str, str]
|
||||||
if opentracing:
|
if opentracing:
|
||||||
opentracing.tracer.inject(
|
opentracing.tracer.inject(
|
||||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
||||||
|
@ -119,7 +119,11 @@ def trace_function(f):
|
|||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
|
|
||||||
s = inspect.currentframe().f_back
|
frame = inspect.currentframe()
|
||||||
|
if frame is None:
|
||||||
|
raise Exception("Can't get current frame!")
|
||||||
|
|
||||||
|
s = frame.f_back
|
||||||
|
|
||||||
to_print = [
|
to_print = [
|
||||||
"\t%s:%s %s. Args: args=%s, kwargs=%s"
|
"\t%s:%s %s. Args: args=%s, kwargs=%s"
|
||||||
@ -144,7 +148,7 @@ def trace_function(f):
|
|||||||
pathname=pathname,
|
pathname=pathname,
|
||||||
lineno=lineno,
|
lineno=lineno,
|
||||||
msg=msg,
|
msg=msg,
|
||||||
args=None,
|
args=tuple(),
|
||||||
exc_info=None,
|
exc_info=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -157,7 +161,12 @@ def trace_function(f):
|
|||||||
|
|
||||||
|
|
||||||
def get_previous_frames():
|
def get_previous_frames():
|
||||||
s = inspect.currentframe().f_back.f_back
|
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
if frame is None:
|
||||||
|
raise Exception("Can't get current frame!")
|
||||||
|
|
||||||
|
s = frame.f_back.f_back
|
||||||
to_return = []
|
to_return = []
|
||||||
while s:
|
while s:
|
||||||
if s.f_globals["__name__"].startswith("synapse"):
|
if s.f_globals["__name__"].startswith("synapse"):
|
||||||
@ -174,7 +183,10 @@ def get_previous_frames():
|
|||||||
|
|
||||||
|
|
||||||
def get_previous_frame(ignore=[]):
|
def get_previous_frame(ignore=[]):
|
||||||
s = inspect.currentframe().f_back.f_back
|
frame = inspect.currentframe()
|
||||||
|
if frame is None:
|
||||||
|
raise Exception("Can't get current frame!")
|
||||||
|
s = frame.f_back.f_back
|
||||||
|
|
||||||
while s:
|
while s:
|
||||||
if s.f_globals["__name__"].startswith("synapse"):
|
if s.f_globals["__name__"].startswith("synapse"):
|
||||||
|
@ -125,7 +125,7 @@ class InFlightGauge(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Counts number of in flight blocks for a given set of label values
|
# Counts number of in flight blocks for a given set of label values
|
||||||
self._registrations = {}
|
self._registrations = {} # type: Dict
|
||||||
|
|
||||||
# Protects access to _registrations
|
# Protects access to _registrations
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
@ -226,7 +226,7 @@ class BucketCollector(object):
|
|||||||
# Fetch the data -- this must be synchronous!
|
# Fetch the data -- this must be synchronous!
|
||||||
data = self.data_collector()
|
data = self.data_collector()
|
||||||
|
|
||||||
buckets = {}
|
buckets = {} # type: Dict[float, int]
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for x in data.keys():
|
for x in data.keys():
|
||||||
|
@ -36,9 +36,9 @@ from twisted.web.resource import Resource
|
|||||||
try:
|
try:
|
||||||
from prometheus_client.samples import Sample
|
from prometheus_client.samples import Sample
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Sample = namedtuple(
|
Sample = namedtuple( # type: ignore[no-redef] # noqa
|
||||||
"Sample", ["name", "labels", "value", "timestamp", "exemplar"]
|
"Sample", ["name", "labels", "value", "timestamp", "exemplar"]
|
||||||
) # type: ignore
|
)
|
||||||
|
|
||||||
|
|
||||||
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
|
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Set
|
from typing import List, Set
|
||||||
|
|
||||||
from pkg_resources import (
|
from pkg_resources import (
|
||||||
DistributionNotFound,
|
DistributionNotFound,
|
||||||
@ -73,6 +73,7 @@ REQUIREMENTS = [
|
|||||||
"netaddr>=0.7.18",
|
"netaddr>=0.7.18",
|
||||||
"Jinja2>=2.9",
|
"Jinja2>=2.9",
|
||||||
"bleach>=1.4.3",
|
"bleach>=1.4.3",
|
||||||
|
"typing-extensions>=3.7.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
CONDITIONAL_REQUIREMENTS = {
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
|
|||||||
deps_needed.append(dependency)
|
deps_needed.append(dependency)
|
||||||
errors.append(
|
errors.append(
|
||||||
"Needed %s, got %s==%s"
|
"Needed %s, got %s==%s"
|
||||||
% (dependency, e.dist.project_name, e.dist.version)
|
% (
|
||||||
|
dependency,
|
||||||
|
e.dist.project_name, # type: ignore[attr-defined] # noqa
|
||||||
|
e.dist.version, # type: ignore[attr-defined] # noqa
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except DistributionNotFound:
|
except DistributionNotFound:
|
||||||
deps_needed.append(dependency)
|
deps_needed.append(dependency)
|
||||||
@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
|
|||||||
if not for_feature:
|
if not for_feature:
|
||||||
# Check the optional dependencies are up to date. We allow them to not be
|
# Check the optional dependencies are up to date. We allow them to not be
|
||||||
# installed.
|
# installed.
|
||||||
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
|
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
|
||||||
|
|
||||||
for dependency in OPTS:
|
for dependency in OPTS:
|
||||||
try:
|
try:
|
||||||
@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
|
|||||||
deps_needed.append(dependency)
|
deps_needed.append(dependency)
|
||||||
errors.append(
|
errors.append(
|
||||||
"Needed optional %s, got %s==%s"
|
"Needed optional %s, got %s==%s"
|
||||||
% (dependency, e.dist.project_name, e.dist.version)
|
% (
|
||||||
|
dependency,
|
||||||
|
e.dist.project_name, # type: ignore[attr-defined] # noqa
|
||||||
|
e.dist.version, # type: ignore[attr-defined] # noqa
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except DistributionNotFound:
|
except DistributionNotFound:
|
||||||
# If it's not found, we don't care
|
# If it's not found, we don't care
|
||||||
|
@ -318,6 +318,7 @@ class StreamToken(
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
_SEPARATOR = "_"
|
_SEPARATOR = "_"
|
||||||
|
START = None # type: StreamToken
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, string):
|
def from_string(cls, string):
|
||||||
@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||||||
followed by the "stream_ordering" id of the event it comes after.
|
followed by the "stream_ordering" id of the event it comes after.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = []
|
__slots__ = [] # type: list
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse(cls, string):
|
def parse(cls, string):
|
||||||
|
@ -13,9 +13,11 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Dict, Sequence, Set, Union
|
||||||
|
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
@ -213,7 +215,9 @@ class Linearizer(object):
|
|||||||
# the first element is the number of things executing, and
|
# the first element is the number of things executing, and
|
||||||
# the second element is an OrderedDict, where the keys are deferreds for the
|
# the second element is an OrderedDict, where the keys are deferreds for the
|
||||||
# things blocked from executing.
|
# things blocked from executing.
|
||||||
self.key_to_defer = {}
|
self.key_to_defer = (
|
||||||
|
{}
|
||||||
|
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
|
||||||
|
|
||||||
def queue(self, key):
|
def queue(self, key):
|
||||||
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
|
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
|
||||||
@ -340,10 +344,10 @@ class ReadWriteLock(object):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Latest readers queued
|
# Latest readers queued
|
||||||
self.key_to_current_readers = {}
|
self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
|
||||||
|
|
||||||
# Latest writer queued
|
# Latest writer queued
|
||||||
self.key_to_current_writer = {}
|
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def read(self, key):
|
def read(self, key):
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six.moves import intern
|
from six.moves import intern
|
||||||
@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
|
|||||||
|
|
||||||
|
|
||||||
caches_by_name = {}
|
caches_by_name = {}
|
||||||
collectors_by_name = {}
|
collectors_by_name = {} # type: Dict
|
||||||
|
|
||||||
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
|
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
|
||||||
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
|
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
|
||||||
|
@ -18,10 +18,12 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from six import itervalues
|
from six import itervalues
|
||||||
|
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -37,6 +39,18 @@ from . import register_cache
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _CachedFunction(Protocol):
|
||||||
|
invalidate = None # type: Any
|
||||||
|
invalidate_all = None # type: Any
|
||||||
|
invalidate_many = None # type: Any
|
||||||
|
prefill = None # type: Any
|
||||||
|
cache = None # type: Any
|
||||||
|
num_args = None # type: Any
|
||||||
|
|
||||||
|
def __name__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
cache_pending_metric = Gauge(
|
cache_pending_metric = Gauge(
|
||||||
"synapse_util_caches_cache_pending",
|
"synapse_util_caches_cache_pending",
|
||||||
"Number of lookups currently pending for this cache",
|
"Number of lookups currently pending for this cache",
|
||||||
@ -245,7 +259,9 @@ class Cache(object):
|
|||||||
|
|
||||||
|
|
||||||
class _CacheDescriptorBase(object):
|
class _CacheDescriptorBase(object):
|
||||||
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
|
def __init__(
|
||||||
|
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
|
||||||
|
):
|
||||||
self.orig = orig
|
self.orig = orig
|
||||||
|
|
||||||
if inlineCallbacks:
|
if inlineCallbacks:
|
||||||
@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||||||
return tuple(get_cache_key_gen(args, kwargs))
|
return tuple(get_cache_key_gen(args, kwargs))
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def _wrapped(*args, **kwargs):
|
||||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||||
# whenever we are invalidated
|
# whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||||||
|
|
||||||
return make_deferred_yieldable(observer)
|
return make_deferred_yieldable(observer)
|
||||||
|
|
||||||
|
wrapped = cast(_CachedFunction, _wrapped)
|
||||||
|
|
||||||
if self.num_args == 1:
|
if self.num_args == 1:
|
||||||
wrapped.invalidate = lambda key: cache.invalidate(key[0])
|
wrapped.invalidate = lambda key: cache.invalidate(key[0])
|
||||||
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
from six import itervalues
|
from six import itervalues
|
||||||
|
|
||||||
SENTINEL = object()
|
SENTINEL = object()
|
||||||
@ -12,7 +14,7 @@ class TreeCache(object):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.size = 0
|
self.size = 0
|
||||||
self.root = {}
|
self.root = {} # type: Dict
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
return self.set(key, value)
|
return self.set(key, value)
|
||||||
|
@ -54,5 +54,5 @@ def load_python_module(location: str):
|
|||||||
if spec is None:
|
if spec is None:
|
||||||
raise Exception("Unable to load module at %s" % (location,))
|
raise Exception("Unable to load module at %s" % (location,))
|
||||||
mod = importlib.util.module_from_spec(spec)
|
mod = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(mod)
|
spec.loader.exec_module(mod) # type: ignore
|
||||||
return mod
|
return mod
|
||||||
|
Loading…
Reference in New Issue
Block a user