Run Black. (#5482)

Amber Brown 2019-06-20 19:32:02 +10:00 committed by GitHub
parent 7dcf984075
commit 32e7c9e7f2
376 changed files with 9142 additions and 10388 deletions
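
Black is an opinionated formatter, so the hunks below are dominated by a handful of mechanical rewrites: string quotes normalised to double quotes, call sites collapsed onto one line when they fit within Black's default 88-column limit (and exploded one argument per line when they don't), and a blank line inserted after class docstrings. As a minimal sketch, the first file's rewrites can be reproduced through Black's Python API; this is an illustration only, assuming the black package is installed, and is not how the commit was generated (that would have been the black command-line tool run over the tree):

import black  # assumed installed, e.g. via "pip install black"

# Pre-Black fragments modelled on the first file below; format_str only
# parses the source, so the undefined names are harmless.
before = (
    "res = ''\n"
    "d.addErrback(\n"
    "    log_failure, 'Looping call died', consumeErrors=False\n"
    ")\n"
)

after = black.format_str(before, mode=black.FileMode())
print(after)
# res = ""
# d.addErrback(log_failure, "Looping call died", consumeErrors=False)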

View file

@@ -40,6 +40,7 @@ class Clock(object):
Args:
reactor: The Twisted reactor to use.
"""
_reactor = attr.ib()
@defer.inlineCallbacks
@@ -70,9 +71,7 @@ class Clock(object):
call = task.LoopingCall(f)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
d.addErrback(
log_failure, "Looping call died", consumeErrors=False,
)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
def call_later(self, delay, callback, *args, **kwargs):
@@ -84,6 +83,7 @@ class Clock(object):
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext():
callback(*args, **kwargs)
@@ -129,12 +129,7 @@ def log_failure(failure, msg, consumeErrors=True):
"""
logger.error(
msg,
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject()
)
msg, exc_info=(failure.type, failure.value, failure.getTracebackObject())
)
if not consumeErrors:
@@ -152,12 +147,12 @@ def glob_to_regex(glob):
Returns:
re.RegexObject
"""
res = ''
res = ""
for c in glob:
if c == '*':
res = res + '.*'
elif c == '?':
res = res + '.'
if c == "*":
res = res + ".*"
elif c == "?":
res = res + "."
else:
res = res + re.escape(c)

View file

@@ -95,6 +95,7 @@ class ObservableDeferred(object):
def remove(r):
self._observers.discard(d)
return r
d.addBoth(remove)
self._observers.add(d)
@@ -123,7 +124,9 @@ class ObservableDeferred(object):
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
id(self),
self._result,
self._deferred,
)
@@ -150,10 +153,12 @@ def concurrently_execute(func, args, limit):
except StopIteration:
pass
return logcontext.make_deferred_yieldable(defer.gatherResults([
run_in_background(_concurrently_execute_inner)
for _ in range(limit)
], consumeErrors=True)).addErrback(unwrapFirstError)
return logcontext.make_deferred_yieldable(
defer.gatherResults(
[run_in_background(_concurrently_execute_inner) for _ in range(limit)],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
def yieldable_gather_results(func, iter, *args, **kwargs):
@@ -169,10 +174,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
"""
return logcontext.make_deferred_yieldable(defer.gatherResults([
run_in_background(func, item, *args, **kwargs)
for item in iter
], consumeErrors=True)).addErrback(unwrapFirstError)
return logcontext.make_deferred_yieldable(
defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
class Linearizer(object):
@@ -185,6 +192,7 @@ class Linearizer(object):
# do some work.
"""
def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
@@ -197,6 +205,7 @@ class Linearizer(object):
if not clock:
from twisted.internet import reactor
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count
@@ -221,7 +230,7 @@ class Linearizer(object):
res = self._await_lock(key)
else:
logger.debug(
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
"Acquired uncontended linearizer lock %r for key %r", self.name, key
)
entry[0] += 1
res = defer.succeed(None)
@@ -266,9 +275,7 @@ class Linearizer(object):
"""
entry = self.key_to_defer[key]
logger.debug(
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
)
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred())
entry[1][new_defer] = 1
@@ -293,14 +300,14 @@ class Linearizer(object):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
"Cancelling wait for linearizer lock %r for key %r", self.name, key
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name, key,
self.name,
key,
)
# we just have to take ourselves back out of the queue.
@@ -438,7 +445,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
try:
deferred.cancel()
except: # noqa: E722, if we throw any exception it'll break time outs
except: # noqa: E722, if we throw any exception it'll break time outs
logger.exception("Canceller failed during timeout")
if not new_d.called:

View file

@@ -104,8 +104,8 @@ def register_cache(cache_type, cache_name, cache):
KNOWN_KEYS = {
key: key for key in
(
key: key
for key in (
"auth_events",
"content",
"depth",
@@ -150,7 +150,7 @@ def intern_dict(dictionary):
def _intern_known_values(key, value):
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys:
return intern_string(value)

View file

@@ -40,9 +40,7 @@ _CacheSentinel = object()
class CacheEntry(object):
__slots__ = [
"deferred", "callbacks", "invalidated"
]
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
self.deferred = deferred
@@ -73,7 +71,9 @@ class Cache(object):
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
)
@@ -133,10 +133,7 @@ class Cache(object):
def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread()
entry = CacheEntry(
deferred=value,
callbacks=callbacks,
)
entry = CacheEntry(deferred=value, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -191,9 +188,7 @@ class Cache(object):
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
@@ -244,29 +239,25 @@ class _CacheDescriptorBase(object):
raise Exception(
"Not enough explicit positional arguments to key off for %r: "
"got %i args, but wanted %i. (@cached cannot key off *args or "
"**kwargs)"
% (orig.__name__, len(all_args), num_args)
"**kwargs)" % (orig.__name__, len(all_args), num_args)
)
self.num_args = num_args
# list of the names of the args used as the cache key
self.arg_names = all_args[1:num_args + 1]
self.arg_names = all_args[1 : num_args + 1]
# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
self.arg_defaults = dict(zip(
all_args[-len(arg_spec.defaults):],
arg_spec.defaults
))
self.arg_defaults = dict(
zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults)
)
else:
self.arg_defaults = {}
if "cache_context" in self.arg_names:
raise Exception(
"cache_context arg cannot be included among the cache keys"
)
raise Exception("cache_context arg cannot be included among the cache keys")
self.add_cache_context = cache_context
@@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase):
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
inlineCallbacks=False, cache_context=False, iterable=False):
def __init__(
self,
orig,
max_entries=1000,
num_args=None,
tree=False,
inlineCallbacks=False,
cache_context=False,
iterable=False,
):
super(CacheDescriptor, self).__init__(
orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
cache_context=cache_context)
orig,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
cache_context=cache_context,
)
max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
@@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase):
return args[0]
else:
return self.arg_defaults[nm]
else:
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
@@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase):
except KeyError:
ret = defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call),
obj, *args, **kwargs
logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs
)
def onErr(f):
@@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
results.
"""
def __init__(self, orig, cached_method_name, list_name, num_args=None,
inlineCallbacks=False):
def __init__(
self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
):
"""
Args:
orig (function)
@@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
be wrapped by defer.inlineCallbacks
"""
super(CacheListDescriptor, self).__init__(
orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
orig, num_args=num_args, inlineCallbacks=inlineCallbacks
)
self.list_name = list_name
@@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cached_method_name,)
% (self.list_name, cached_method_name)
)
def __get__(self, obj, objtype=None):
@@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase):
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
def arg_to_cache_key(arg):
return arg
else:
keylist = list(keyargs)
@@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg),
callback=invalidate_callback)
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
@@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase):
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = list(missing)
cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call),
**args_to_call
).addCallbacks(complete_all, errback))
cached_defers.append(
defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call), **args_to_call
).addCallbacks(complete_all, errback)
)
if cached_defers:
d = defer.gatherResults(
cached_defers,
consumeErrors=True,
).addCallbacks(
lambda _: results,
unwrapFirstError
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
lambda _: results, unwrapFirstError
)
return logcontext.make_deferred_yieldable(d)
else:
@@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
iterable=False):
def cached(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
)
def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
cache_context=False, iterable=False):
def cachedInlineCallbacks(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,

View file

@@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
there.
value (dict): The full or partial dict value
"""
def __len__(self):
return len(self.value)
@@ -84,13 +85,15 @@ class DictionaryCache(object):
self.metrics.inc_hits()
if dict_keys is None:
return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)
)
else:
return DictionaryEntry(entry.full, entry.known_absent, {
k: entry.value[k]
for k in dict_keys
if k in entry.value
})
return DictionaryEntry(
entry.full,
entry.known_absent,
{k: entry.value[k] for k in dict_keys if k in entry.value},
)
self.metrics.inc_misses()
return DictionaryEntry(False, set(), {})

View file

@@ -28,8 +28,15 @@ SENTINEL = object()
class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False, iterable=False):
def __init__(
self,
cache_name,
clock,
max_len=0,
expiry_ms=0,
reset_expiry_on_get=False,
iterable=False,
):
"""
Args:
cache_name (str): Name of this cache, used for logging.
@@ -67,8 +74,7 @@ class ExpiringCache(object):
def f():
return run_as_background_process(
"prune_cache_%s" % self._cache_name,
self._prune_cache,
"prune_cache_%s" % self._cache_name, self._prune_cache
)
self._clock.looping_call(f, self._expiry_ms / 2)
@@ -153,7 +159,9 @@ class ExpiringCache(object):
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self)
self._cache_name,
begin_length,
len(self),
)
def __len__(self):

View file

@@ -49,8 +49,15 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
evicted_callback=None):
def __init__(
self,
max_size,
keylen=1,
cache_type=dict,
size_callback=None,
evicted_callback=None,
):
"""
Args:
max_size (int):
@@ -93,9 +100,12 @@ class LruCache(object):
cached_cache_len = [0]
if size_callback is not None:
def cache_len():
return cached_cache_len[0]
else:
def cache_len():
return len(cache)

View file

@@ -35,12 +35,10 @@ class ResponseCache(object):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.
self.timeout_sec = timeout_ms / 1000.0
self._name = name
self._metrics = register_cache(
"response_cache", name, self
)
self._metrics = register_cache("response_cache", name, self)
def size(self):
return len(self.pending_result_cache)
@@ -100,8 +98,7 @@ class ResponseCache(object):
def remove(r):
if self.timeout_sec:
self.clock.call_later(
self.timeout_sec,
self.pending_result_cache.pop, key, None,
self.timeout_sec, self.pending_result_cache.pop, key, None
)
else:
self.pending_result_cache.pop(key, None)
@@ -147,14 +144,15 @@ class ResponseCache(object):
"""
result = self.get(key)
if not result:
logger.info("[%s]: no cached result for [%s], calculating new one",
self._name, key)
logger.info(
"[%s]: no cached result for [%s], calculating new one", self._name, key
)
d = run_in_background(callback, *args, **kwargs)
result = self.set(key, d)
elif not isinstance(result, defer.Deferred) or result.called:
logger.info("[%s]: using completed cached result for [%s]",
self._name, key)
logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
logger.info("[%s]: using incomplete cached result for [%s]",
self._name, key)
logger.info(
"[%s]: using incomplete cached result for [%s]", self._name, key
)
return make_deferred_yieldable(result)

View file

@@ -77,9 +77,8 @@ class StreamChangeCache(object):
if stream_pos >= self._earliest_known_stream_pos:
changed_entities = {
self._cache[k] for k in self._cache.islice(
start=self._cache.bisect_right(stream_pos),
)
self._cache[k]
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
}
result = changed_entities.intersection(entities)
@@ -114,8 +113,10 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
return [self._cache[k] for k in self._cache.islice(
start=self._cache.bisect_right(stream_pos))]
return [
self._cache[k]
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
]
else:
return None
@@ -136,7 +137,7 @@ class StreamChangeCache(object):
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(
k, self._earliest_known_stream_pos,
k, self._earliest_known_stream_pos
)
self._entity_to_key.pop(r, None)

View file

@@ -9,6 +9,7 @@ class TreeCache(object):
efficiently.
Keys must be tuples.
"""
def __init__(self):
self.size = 0
self.root = {}

View file

@@ -155,6 +155,7 @@ class TTLCache(object):
@attr.s(frozen=True, slots=True)
class _CacheEntry(object):
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib()
key = attr.ib()

View file

@@ -51,9 +51,7 @@ class Distributor(object):
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
self.signals[name] = Signal(
name,
)
self.signals[name] = Signal(name)
if name in self.pre_registration:
signal = self.signals[name]
@@ -78,11 +76,7 @@ class Distributor(object):
if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name))
run_as_background_process(
name,
self.signals[name].fire,
*args, **kwargs
)
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
class Signal(object):
@@ -118,22 +112,23 @@ class Signal(object):
def eb(failure):
logger.warning(
"%s signal observer %s failed: %r",
self.name, observer, failure,
self.name,
observer,
failure,
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject()))
failure.getTracebackObject(),
),
)
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
deferreds = [
run_in_background(do, o)
for o in self.observers
]
deferreds = [run_in_background(do, o) for o in self.observers]
return make_deferred_yieldable(defer.gatherResults(
deferreds, consumeErrors=True,
))
return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
def __repr__(self):
return "<Signal name=%r>" % (self.name,)

View file

@@ -60,11 +60,10 @@ def _handle_frozendict(obj):
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
raise TypeError('Object of type %s is not JSON serializable' %
obj.__class__.__name__)
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
# A JSONEncoder which is capable of encoding frozendics without barfing
frozendict_json_encoder = json.JSONEncoder(
default=_handle_frozendict,
)
frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)

View file

@@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource):
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split(b'/')[1:-1]:
for path_seg in full_path.split(b"/")[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = NoResource()
@@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split(b'/')[-1]
last_path_seg = full_path.split(b"/")[-1]
# if there is already a resource here, thieve its children and
# replace it
@@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource):
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = _resource_id(
existing_dummy_resource, child_name
)
child_res_id = _resource_id(existing_dummy_resource, child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)

View file

@@ -70,7 +70,8 @@ class JsonEncodedObject(object):
dict
"""
d = {
k: _encode(v) for (k, v) in self.__dict__.items()
k: _encode(v)
for (k, v) in self.__dict__.items()
if k in self.valid_keys and k not in self.internal_keys
}
d.update(self.unrecognized_keys)
@@ -78,7 +79,8 @@
def get_internal_dict(self):
d = {
k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
k: _encode(v, internal=True)
for (k, v) in self.__dict__.items()
if k in self.valid_keys
}
d.update(self.unrecognized_keys)

View file

@@ -42,6 +42,8 @@ try:
def get_thread_resource_usage():
return resource.getrusage(RUSAGE_THREAD)
except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage by returning None.
@@ -64,8 +66,11 @@ class ContextResourceUsage(object):
"""
__slots__ = [
"ru_stime", "ru_utime",
"db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
"ru_stime",
"ru_utime",
"db_txn_count",
"db_txn_duration_sec",
"db_sched_duration_sec",
"evt_db_fetch_count",
]
@@ -91,8 +96,8 @@ class ContextResourceUsage(object):
return ContextResourceUsage(copy_from=self)
def reset(self):
self.ru_stime = 0.
self.ru_utime = 0.
self.ru_stime = 0.0
self.ru_utime = 0.0
self.db_txn_count = 0
self.db_txn_duration_sec = 0
@@ -100,15 +105,18 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count = 0
def __repr__(self):
return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
"db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
self.ru_stime,
self.ru_utime,
self.db_txn_count,
self.db_txn_duration_sec,
self.db_sched_duration_sec,
self.evt_db_fetch_count,)
return (
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
"db_sched_duration_sec='%r', evt_db_fetch_count='%r'>"
) % (
self.ru_stime,
self.ru_utime,
self.db_txn_count,
self.db_txn_duration_sec,
self.db_sched_duration_sec,
self.evt_db_fetch_count,
)
def __iadd__(self, other):
"""Add another ContextResourceUsage's stats to this one's.
@@ -159,11 +167,15 @@ class LoggingContext(object):
"""
__slots__ = [
"previous_context", "name", "parent_context",
"previous_context",
"name",
"parent_context",
"_resource_usage",
"usage_start",
"main_thread", "alive",
"request", "tag",
"main_thread",
"alive",
"request",
"tag",
]
thread_local = threading.local()
@@ -196,6 +208,7 @@ class LoggingContext(object):
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
sentinel = Sentinel()
@@ -261,7 +274,8 @@ class LoggingContext(object):
if self.previous_context != old_context:
logger.warn(
"Expected previous context %r, found %r",
self.previous_context, old_context
self.previous_context,
old_context,
)
self.alive = True
@@ -285,9 +299,8 @@ class LoggingContext(object):
self.alive = False
# if we have a parent, pass our CPU usage stats on
if (
self.parent_context is not None
and hasattr(self.parent_context, '_resource_usage')
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
@@ -320,9 +333,7 @@ class LoggingContext(object):
# When we stop, let's record the cpu used since we started
if not self.usage_start:
logger.warning(
"Called stop on logcontext %s without calling start", self,
)
logger.warning("Called stop on logcontext %s without calling start", self)
return
usage_end = get_thread_resource_usage()
@@ -381,6 +392,7 @@ class LoggingContextFilter(logging.Filter):
**defaults: Default values to avoid formatters complaining about
missing fields
"""
def __init__(self, **defaults):
self.defaults = defaults
@@ -416,17 +428,12 @@ class PreserveLoggingContext(object):
def __enter__(self):
"""Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(
self.new_context
)
self.current_context = LoggingContext.set_current_context(self.new_context)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
logger.debug(
"Entering dead context: %s",
self.current_context,
)
logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
@@ -444,10 +451,7 @@ class PreserveLoggingContext(object):
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:
logger.debug(
"Restoring dead context: %s",
self.current_context,
)
logger.debug("Restoring dead context: %s", self.current_context)
def nested_logging_context(suffix, parent_context=None):
@@ -474,15 +478,16 @@ def nested_logging_context(suffix, parent_context=None):
if parent_context is None:
parent_context = LoggingContext.current_context()
return LoggingContext(
parent_context=parent_context,
request=parent_context.request + "-" + suffix,
parent_context=parent_context, request=parent_context.request + "-" + suffix
)
def preserve_fn(f):
"""Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs):
return run_in_background(f, *args, **kwargs)
return g
@@ -502,7 +507,7 @@ def run_in_background(f, *args, **kwargs):
current = LoggingContext.current_context()
try:
res = f(*args, **kwargs)
except: # noqa: E722
except: # noqa: E722
# the assumption here is that the caller doesn't want to be disturbed
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
@@ -639,6 +644,4 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
with LoggingContext(parent_context=logcontext):
return f(*args, **kwargs)
return make_deferred_yieldable(
threads.deferToThreadPool(reactor, threadpool, g)
)
return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))

View file

@@ -29,6 +29,7 @@ class LogFormatter(logging.Formatter):
(Normally only stack frames between the point the exception was raised and
where it was caught are logged).
"""
def __init__(self, *args, **kwargs):
super(LogFormatter, self).__init__(*args, **kwargs)
@@ -40,7 +41,7 @@ class LogFormatter(logging.Formatter):
# check that we actually have an f_back attribute to work around
# https://twistedmatrix.com/trac/ticket/9305
if tb and hasattr(tb.tb_frame, 'f_back'):
if tb and hasattr(tb.tb_frame, "f_back"):
sio.write("Capture point (most recent call last):\n")
traceback.print_stack(tb.tb_frame.f_back, None, sio)

View file

@@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args):
lineno=lineno,
msg=msg,
args=msg_args,
exc_info=None
exc_info=None,
)
logger.handle(record)
@@ -70,20 +70,11 @@ def log_function(f):
r = r[:50] + "..."
return r
func_args = [
"%s=%s" % (k, format(v)) for k, v in bound_args.items()
]
func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()]
msg_args = {
"func_name": func_name,
"args": ", ".join(func_args)
}
msg_args = {"func_name": func_name, "args": ", ".join(func_args)}
_log_debug_as_f(
f,
"Invoked '%(func_name)s' with args: %(args)s",
msg_args
)
_log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args)
return f(*args, **kwargs)
@@ -103,19 +94,13 @@ def time_function(f):
start = time.clock()
try:
_log_debug_as_f(
f,
"[FUNC START] {%s-%d}",
(func_name, id),
)
_log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id))
r = f(*args, **kwargs)
finally:
end = time.clock()
_log_debug_as_f(
f,
"[FUNC END] {%s-%d} %.3f sec",
(func_name, id, end - start,),
f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start)
)
return r
@@ -137,9 +122,8 @@ def trace_function(f):
s = inspect.currentframe().f_back
to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s" % (
pathname, linenum, func_name, args, kwargs
)
"\t%s:%s %s. Args: args=%s, kwargs=%s"
% (pathname, linenum, func_name, args, kwargs)
]
while s:
if True or s.f_globals["__name__"].startswith("synapse"):
@@ -147,9 +131,7 @@ def trace_function(f):
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_print.append(
"\t%s:%d %s. Args: %s" % (
filename, lineno, function, args_string
)
"\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string)
)
s = s.f_back
@@ -163,7 +145,7 @@ def trace_function(f):
lineno=lineno,
msg=msg,
args=None,
exc_info=None
exc_info=None,
)
logger.handle(record)
@@ -182,13 +164,13 @@ def get_previous_frames():
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_return.append("{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
))
to_return.append(
"{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string)
)
s = s.f_back
return ", ". join(to_return)
return ", ".join(to_return)
def get_previous_frame(ignore=[]):
@@ -201,7 +183,10 @@ def get_previous_frame(ignore=[]):
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
return "{{ %s:%d %s - Args: %s }}" % (
filename, lineno, function, args_string
filename,
lineno,
function,
args_string,
)
s = s.f_back

View file

@@ -74,27 +74,25 @@ def manhole(username, password, globals):
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
"""
if not isinstance(password, bytes):
password = password.encode('ascii')
password = password.encode("ascii")
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
**{username: password}
)
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
SynapseManhole,
dict(globals, __name__="__console__")
SynapseManhole, dict(globals, __name__="__console__")
)
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY)
factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY)
factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY)
factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY)
return factory
class SynapseManhole(ColoredManhole):
"""Overrides connectionMade to create our own ManholeInterpreter"""
def connectionMade(self):
super(SynapseManhole, self).connectionMade()
@@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_value = value
lines = traceback.format_exception_only(type, value)
self.write(''.join(lines))
self.write("".join(lines))
def showtraceback(self):
"""Display the exception that just occurred.
@@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter):
try:
# We remove the first stack item because it is our own code.
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
self.write(''.join(lines))
self.write("".join(lines))
finally:
last_tb = ei = None

View file

@@ -30,25 +30,31 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
block_ru_utime = Counter(
"synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
"synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]
)
block_ru_stime = Counter(
"synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
"synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]
)
block_db_txn_count = Counter(
"synapse_util_metrics_block_db_txn_count", "", ["block_name"])
"synapse_util_metrics_block_db_txn_count", "", ["block_name"]
)
# seconds spent waiting for db txns, excluding scheduling time, in this block
block_db_txn_duration = Counter(
"synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
"synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]
)
# seconds spent waiting for a db connection, in this block
block_db_sched_duration = Counter(
"synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
"synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
)
# Tracks the number of blocks currently active
in_flight = InFlightGauge(
"synapse_util_metrics_block_in_flight", "",
"synapse_util_metrics_block_in_flight",
"",
labels=["block_name"],
sub_metrics=["real_time_max", "real_time_sum"],
)
@@ -62,13 +68,18 @@ def measure_func(name):
with Measure(self.clock, name):
r = yield func(self, *args, **kwargs)
defer.returnValue(r)
return measured_func
return wrapper
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start",
"clock",
"name",
"start_context",
"start",
"created_context",
"start_usage",
]
@@ -108,7 +119,9 @@ class Measure(object):
if context != self.start_context:
logger.warn(
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
self.start_context, context, self.name
self.start_context,
context,
self.name,
)
return
@@ -126,8 +139,7 @@ class Measure(object):
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
logger.warn(
"Failed to save metrics! OLD: %r, NEW: %r",
self.start_usage, current
"Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
)
if self.created_context:

View file

@@ -28,15 +28,13 @@ def load_module(provider):
"""
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module, clz = provider["module"].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config

View file

@@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number):
phoneNumber = phonenumbers.parse(number, country)
except phonenumbers.NumberParseException:
raise SynapseError(400, "Unable to parse phone number")
return phonenumbers.format_number(
phoneNumber, phonenumbers.PhoneNumberFormat.E164
)[1:]
return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[
1:
]

View file

@@ -56,11 +56,7 @@ class FederationRateLimiter(object):
_PerHostRatelimiter
"""
return self.ratelimiters.setdefault(
host,
_PerHostRatelimiter(
clock=self.clock,
config=self._config,
)
host, _PerHostRatelimiter(clock=self.clock, config=self._config)
).ratelimit()
@@ -112,8 +108,7 @@ class _PerHostRatelimiter(object):
# remove any entries from request_times which aren't within the window
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
r for r in self.request_times if time_now - r < self.window_size
]
# reject the request if we already have too many queued up (either
@@ -121,9 +116,7 @@ class _PerHostRatelimiter(object):
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
retry_after_ms=int(
self.window_size / self.sleep_limit
),
retry_after_ms=int(self.window_size / self.sleep_limit)
)
self.request_times.append(time_now)
@@ -143,22 +136,18 @@ class _PerHostRatelimiter(object):
logger.debug(
"Ratelimit [%s]: len(self.request_times)=%d",
id(request_id), len(self.request_times),
id(request_id),
len(self.request_times),
)
if len(self.request_times) > self.sleep_limit:
logger.debug(
"Ratelimiter: sleeping request for %f sec", self.sleep_sec,
)
logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
logger.debug(
"Ratelimit [%s]: Finished sleeping",
id(request_id),
)
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -168,10 +157,7 @@ class _PerHostRatelimiter(object):
ret_defer = queue_request()
def on_start(r):
logger.debug(
"Ratelimit [%s]: Processing req",
id(request_id),
)
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
@@ -193,10 +179,7 @@ class _PerHostRatelimiter(object):
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
logger.debug(
"Ratelimit [%s]: Processed req",
id(request_id),
)
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.

View file

@@ -20,9 +20,7 @@ import six
from six import PY2, PY3
from six.moves import range
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
)
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
@@ -31,13 +29,11 @@ rand = random.SystemRandom()
def random_string(length):
return ''.join(rand.choice(string.ascii_letters) for _ in range(length))
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
return ''.join(
rand.choice(_string_with_symbols) for _ in range(length)
)
return "".join(rand.choice(_string_with_symbols) for _ in range(length))
def is_ascii(s):
@@ -45,7 +41,7 @@ def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
s.decode('ascii').encode('ascii')
s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
@@ -104,12 +100,12 @@ def exception_to_unicode(e):
# and instead look at what is in the args member.
if len(e.args) == 0:
return u""
return ""
elif len(e.args) > 1:
return six.text_type(repr(e.args))
msg = e.args[0]
if isinstance(msg, bytes):
return msg.decode('utf-8', errors='replace')
return msg.decode("utf-8", errors="replace")
else:
return msg

View file

@@ -35,11 +35,13 @@ def check_3pid_allowed(hs, medium, address):
for constraint in hs.config.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
address, medium, constraint['pattern'], constraint['medium'],
address,
medium,
constraint["pattern"],
constraint["medium"],
)
if (
medium == constraint['medium'] and
re.match(constraint['pattern'], address)
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
return True
else:

View file

@@ -23,44 +23,53 @@ logger = logging.getLogger(__name__)
def get_version_string(module):
try:
null = open(os.devnull, 'w')
null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
try:
git_branch = subprocess.check_output(
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
).strip().decode('ascii')
git_branch = (
subprocess.check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
)
.strip()
.decode("ascii")
)
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
try:
git_tag = subprocess.check_output(
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
).strip().decode('ascii')
git_tag = (
subprocess.check_output(
["git", "describe", "--exact-match"], stderr=null, cwd=cwd
)
.strip()
.decode("ascii")
)
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
).strip().decode('ascii')
git_commit = (
subprocess.check_output(
["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
)
.strip()
.decode("ascii")
)
except subprocess.CalledProcessError:
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
is_dirty = subprocess.check_output(
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
).strip().decode('ascii').endswith(dirty_string)
is_dirty = (
subprocess.check_output(
["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
)
.strip()
.decode("ascii")
.endswith(dirty_string)
)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
@@ -68,16 +77,10 @@ def get_version_string(module):
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
s for s in
(git_branch, git_tag, git_commit, git_dirty,)
if s
s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
return (
"%s (%s)" % (
module.__version__, git_version,
)
)
return "%s (%s)" % (module.__version__, git_version)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)

View file

@@ -69,9 +69,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
_Entry(key) for key in range(last_key, then_key + 1)
)
self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
self.entries[-1].queue.append(obj)