Mirror of https://mau.dev/maunium/synapse.git (synced 2024-10-01 01:36:05 -04:00)

Commit 405f8c4796: Merge branch 'release-v0.9.2'
CHANGES.rst

@@ -1,3 +1,26 @@
+Changes in synapse v0.9.2 (2015-06-12)
+======================================
+
+General:
+
+* Use ultrajson for json (de)serialisation when a canonical encoding is not
+  required. Ultrajson is significantly faster than simplejson in certain
+  circumstances.
+* Use connection pools for outgoing HTTP connections.
+* Process thumbnails on separate threads.
+
+Configuration:
+
+* Add option, ``gzip_responses``, to disable HTTP response compression.
+
+Federation:
+
+* Improve resilience of backfill by ensuring we fetch any missing auth events.
+* Improve performance of backfill and joining remote rooms by removing
+  unnecessary computations. This included handling events we'd previously
+  handled as well as attempting to compute the current state for outliers.
+
+
 Changes in synapse v0.9.1 (2015-05-26)
 ======================================
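A note on the first changelog bullet: a canonical JSON encoding (sorted keys, no whitespace) is needed wherever the bytes are signed or hashed, while responses that are merely parsed by a client can use any valid encoding, which is where the faster ujson pays off. A minimal illustration of the distinction (not code from this commit; the event value is made up):

    import json
    import ujson

    event = {"type": "m.room.message", "content": {"body": "hi"}}

    # Canonical form: deterministic bytes, required for signatures/hashes.
    canonical = json.dumps(event, sort_keys=True, separators=(',', ':'))

    # Fast form: any valid JSON will do, e.g. for client HTTP responses.
    fast = ujson.dumps(event)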
@@ -21,3 +21,5 @@ handlers:
 root:
     level: INFO
     handlers: [journal]
+
+disable_existing_loggers: False
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.9.1"
+__version__ = "0.9.2"
@@ -54,6 +54,8 @@ from synapse.rest.client.v1 import ClientV1RestResource
 from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 
+from synapse import events
+
 from daemonize import Daemonize
 import twisted.manhole.telnet
 
@@ -85,10 +87,16 @@ class SynapseHomeServer(HomeServer):
         return MatrixFederationHttpClient(self)
 
     def build_resource_for_client(self):
-        return gz_wrap(ClientV1RestResource(self))
+        res = ClientV1RestResource(self)
+        if self.config.gzip_responses:
+            res = gz_wrap(res)
+        return res
 
     def build_resource_for_client_v2_alpha(self):
-        return gz_wrap(ClientV2AlphaRestResource(self))
+        res = ClientV2AlphaRestResource(self)
+        if self.config.gzip_responses:
+            res = gz_wrap(res)
+        return res
 
     def build_resource_for_federation(self):
         return JsonResource(self)
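The gz_wrap helper used above is not part of this diff. Assuming it behaves like Twisted 14's stock gzip support, an equivalent sketch would be:

    from twisted.web.resource import EncodingResourceWrapper
    from twisted.web.server import GzipEncoderFactory

    def gz_wrap(resource):
        # Compresses responses for clients that send Accept-Encoding: gzip.
        return EncodingResourceWrapper(resource, [GzipEncoderFactory()])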
@@ -415,6 +423,8 @@ def setup(config_options):
     logger.info("Server hostname: %s", config.server_name)
     logger.info("Server version: %s", version_string)
 
+    events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
     if re.search(":[0-9]+$", config.server_name):
         domain_with_port = config.server_name
     else:
@@ -26,6 +26,7 @@ class CaptchaConfig(Config):
             config["captcha_ip_origin_is_x_forwarded"]
         )
         self.captcha_bypass_secret = config.get("captcha_bypass_secret")
+        self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
 
     def default_config(self, config_dir_path, server_name):
         return """\
@@ -48,4 +49,7 @@ class CaptchaConfig(Config):
 
         # A secret key used to bypass the captcha test entirely.
         #captcha_bypass_secret: "YOUR_SECRET_HERE"
+
+        # The API endpoint to use for verifying m.login.recaptcha responses.
+        recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
         """
@@ -39,7 +39,7 @@ class RegistrationConfig(Config):
         ## Registration ##
 
         # Enable registration for new users.
-        enable_registration: True
+        enable_registration: False
 
         # If set, allows registration by anyone who also has the shared
         # secret, even if registration is otherwise disabled.
@@ -28,6 +28,8 @@ class ServerConfig(Config):
         self.web_client = config["web_client"]
         self.soft_file_limit = config["soft_file_limit"]
         self.daemonize = config.get("daemonize")
+        self.use_frozen_dicts = config.get("use_frozen_dicts", True)
+        self.gzip_responses = config["gzip_responses"]
 
         # Attempt to guess the content_addr for the v0 content repostitory
         content_addr = config.get("content_addr")
@@ -85,6 +87,11 @@ class ServerConfig(Config):
         # Turn on the twisted telnet manhole service on localhost on the given
         # port.
         #manhole: 9000
+
+        # Should synapse compress HTTP responses to clients that support it?
+        # This should be disabled if running synapse behind a load balancer
+        # that can do automatic compression.
+        gzip_responses: True
         """ % locals()
 
     def read_arguments(self, args):
@@ -16,6 +16,12 @@
 from synapse.util.frozenutils import freeze
 
 
+# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
+# bugs where we accidentally share e.g. signature dicts. However, converting
+# a dict to frozen_dicts is expensive.
+USE_FROZEN_DICTS = True
+
+
 class _EventInternalMetadata(object):
     def __init__(self, internal_metadata_dict):
         self.__dict__ = dict(internal_metadata_dict)
@@ -122,7 +128,10 @@ class FrozenEvent(EventBase):
 
         unsigned = dict(event_dict.pop("unsigned", {}))
 
-        frozen_dict = freeze(event_dict)
+        if USE_FROZEN_DICTS:
+            frozen_dict = freeze(event_dict)
+        else:
+            frozen_dict = event_dict
 
         super(FrozenEvent, self).__init__(
             frozen_dict,
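The trade-off USE_FROZEN_DICTS toggles: freezing makes accidental mutation of shared event state fail loudly, at the cost of converting every dict. A small illustration using the frozendict package the codebase already depends on (the values are made up):

    from frozendict import frozendict

    signatures = frozendict({"example.org": {"ed25519:1": "sig"}})
    try:
        signatures["example.org"] = {}
    except TypeError:
        # Frozen: a sharing bug surfaces immediately instead of silently
        # corrupting another event's data.
        pass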
@@ -18,8 +18,6 @@ from twisted.internet import defer
 
 from synapse.events.utils import prune_event
 
-from syutil.jsonutil import encode_canonical_json
-
 from synapse.crypto.event_signing import check_event_content_hash
 
 from synapse.api.errors import SynapseError
@@ -120,16 +118,15 @@ class FederationBase(object):
                     )
                 except SynapseError:
                     logger.warn(
-                        "Signature check failed for %s redacted to %s",
-                        encode_canonical_json(pdu.get_pdu_json()),
-                        encode_canonical_json(redacted_pdu_json),
+                        "Signature check failed for %s",
+                        pdu.event_id,
                     )
                     raise
 
             if not check_event_content_hash(pdu):
                 logger.warn(
-                    "Event content has been tampered, redacting %s, %s",
-                    pdu.event_id, encode_canonical_json(pdu.get_dict())
+                    "Event content has been tampered, redacting.",
+                    pdu.event_id,
                 )
                 defer.returnValue(redacted_event)
 
@@ -93,6 +93,8 @@ class TransportLayerServer(object):
 
         yield self.keyring.verify_json_for_server(origin, json_request)
 
+        logger.info("Request from %s", origin)
+
         defer.returnValue((origin, content))
 
     @log_function
@@ -78,7 +78,9 @@ class BaseHandler(object):
         context = yield state_handler.compute_event_context(builder)
 
         if builder.is_state():
-            builder.prev_state = context.prev_state_events
+            builder.prev_state = yield self.store.add_event_hashes(
+                context.prev_state_events
+            )
 
         yield self.auth.add_auth_events(builder, context)
 
@@ -187,8 +187,8 @@ class AuthHandler(BaseHandler):
         # each request
         try:
             client = SimpleHttpClient(self.hs)
-            data = yield client.post_urlencoded_get_json(
-                "https://www.google.com/recaptcha/api/siteverify",
+            resp_body = yield client.post_urlencoded_get_json(
+                self.hs.config.recaptcha_siteverify_api,
                 args={
                     'secret': self.hs.config.recaptcha_private_key,
                     'response': user_response,
|
|||||||
except PartialDownloadError as pde:
|
except PartialDownloadError as pde:
|
||||||
# Twisted is silly
|
# Twisted is silly
|
||||||
data = pde.response
|
data = pde.response
|
||||||
resp_body = simplejson.loads(data)
|
resp_body = simplejson.loads(data)
|
||||||
|
|
||||||
if 'success' in resp_body and resp_body['success']:
|
if 'success' in resp_body and resp_body['success']:
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||||
|
@@ -247,9 +247,15 @@ class FederationHandler(BaseHandler):
             if set(e_id for e_id, _ in ev.prev_events) - event_ids
         ]
 
+        logger.info(
+            "backfill: Got %d events with %d edges",
+            len(events), len(edges),
+        )
+
         # For each edge get the current state.
 
         auth_events = {}
+        state_events = {}
         events_to_state = {}
         for e_id in edges:
             state, auth = yield self.replication_layer.get_state_for_room(
@@ -258,12 +264,46 @@ class FederationHandler(BaseHandler):
                 event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
+            auth_events.update({s.event_id: s for s in state})
+            state_events.update({s.event_id: s for s in state})
             events_to_state[e_id] = state
 
+        seen_events = yield self.store.have_events(
+            set(auth_events.keys()) | set(state_events.keys())
+        )
+
+        all_events = events + state_events.values() + auth_events.values()
+        required_auth = set(
+            a_id for event in all_events for a_id, _ in event.auth_events
+        )
+
+        missing_auth = required_auth - set(auth_events)
+        results = yield defer.gatherResults(
+            [
+                self.replication_layer.get_pdu(
+                    [dest],
+                    event_id,
+                    outlier=True,
+                    timeout=10000,
+                )
+                for event_id in missing_auth
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)
+        auth_events.update({a.event_id: a for a in results})
+
         yield defer.gatherResults(
             [
-                self._handle_new_event(dest, a)
+                self._handle_new_event(
+                    dest, a,
+                    auth_events={
+                        (auth_events[a_id].type, auth_events[a_id].state_key):
+                        auth_events[a_id]
+                        for a_id, _ in a.auth_events
+                    },
+                )
                 for a in auth_events.values()
+                if a.event_id not in seen_events
             ],
             consumeErrors=True,
         ).addErrback(unwrapFirstError)
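The fetch of missing auth events uses a Twisted idiom that recurs throughout this commit: run deferreds in parallel with gatherResults, consume errors so one failure collapses the batch cleanly, then unwrap the FirstError wrapper. synapse's unwrapFirstError helper is not shown in this diff; a self-contained sketch of the same pattern:

    from twisted.internet import defer

    def unwrap_first_error(failure):
        # gatherResults wraps the first failure in a FirstError; re-raise
        # the underlying failure so callers see the real exception.
        failure.trap(defer.FirstError)
        return failure.value.subFailure

    def fetch_all(fetch, ids):
        return defer.gatherResults(
            [fetch(i) for i in ids],
            consumeErrors=True,
        ).addErrback(unwrap_first_error)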
@@ -274,6 +314,11 @@ class FederationHandler(BaseHandler):
                     dest, event_map[e_id],
                     state=events_to_state[e_id],
                     backfilled=True,
+                    auth_events={
+                        (auth_events[a_id].type, auth_events[a_id].state_key):
+                        auth_events[a_id]
+                        for a_id, _ in event_map[e_id].auth_events
+                    },
                 )
                 for e_id in events_to_state
             ],
@@ -900,8 +945,10 @@ class FederationHandler(BaseHandler):
             event.event_id, event.signatures,
         )
 
+        outlier = event.internal_metadata.is_outlier()
+
         context = yield self.state_handler.compute_event_context(
-            event, old_state=state
+            event, old_state=state, outlier=outlier,
         )
 
         if not auth_events:
@@ -912,7 +959,7 @@ class FederationHandler(BaseHandler):
             event.event_id, auth_events,
         )
 
-        is_new_state = not event.internal_metadata.is_outlier()
+        is_new_state = not outlier
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -20,7 +20,8 @@ import synapse.metrics
 
 from twisted.internet import defer, reactor
 from twisted.web.client import (
-    Agent, readBody, FileBodyProducer, PartialDownloadError
+    Agent, readBody, FileBodyProducer, PartialDownloadError,
+    HTTPConnectionPool,
 )
 from twisted.web.http_headers import Headers
 
@@ -55,7 +56,9 @@ class SimpleHttpClient(object):
         # The default context factory in Twisted 14.0.0 (which we require) is
         # BrowserLikePolicyForHTTPS which will do regular cert validation
         # 'like a browser'
-        self.agent = Agent(reactor)
+        pool = HTTPConnectionPool(reactor)
+        pool.maxPersistentPerHost = 10
+        self.agent = Agent(reactor, pool=pool)
         self.version_string = hs.version_string
 
     def request(self, method, *args, **kwargs):
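The pooling change is plain Twisted API; the same setup in isolation:

    from twisted.internet import reactor
    from twisted.web.client import Agent, HTTPConnectionPool

    # A persistent pool keeps up to maxPersistentPerHost idle connections
    # per host, so repeated requests skip the TCP (and TLS) handshake.
    pool = HTTPConnectionPool(reactor)
    pool.maxPersistentPerHost = 10
    agent = Agent(reactor, pool=pool)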
@@ -16,7 +16,7 @@
 
 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
-from twisted.web.client import readBody, _AgentBase, _URI
+from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
 from twisted.web.http_headers import Headers
 from twisted.web._newclient import ResponseDone
 
|
|||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.signing_key = hs.config.signing_key[0]
|
self.signing_key = hs.config.signing_key[0]
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.agent = MatrixFederationHttpAgent(reactor)
|
pool = HTTPConnectionPool(reactor)
|
||||||
|
pool.maxPersistentPerHost = 10
|
||||||
|
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
|
|
||||||
|
@@ -19,9 +19,10 @@ from synapse.api.errors import (
 )
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 import synapse.metrics
+import synapse.events
 
 from syutil.jsonutil import (
-    encode_canonical_json, encode_pretty_printed_json
+    encode_canonical_json, encode_pretty_printed_json, encode_json
 )
 
 from twisted.internet import defer
@@ -168,9 +169,10 @@ class JsonResource(HttpServer, resource.Resource):
 
     _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
 
-    def __init__(self, hs):
+    def __init__(self, hs, canonical_json=True):
         resource.Resource.__init__(self)
 
+        self.canonical_json = canonical_json
         self.clock = hs.get_clock()
         self.path_regexs = {}
         self.version_string = hs.version_string
|
|||||||
response_code_message=response_code_message,
|
response_code_message=response_code_message,
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
pretty_print=_request_user_agent_is_curl(request),
|
||||||
version_string=self.version_string,
|
version_string=self.version_string,
|
||||||
|
canonical_json=self.canonical_json,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -277,11 +280,16 @@ class RootRedirect(resource.Resource):
 
 def respond_with_json(request, code, json_object, send_cors=False,
                       response_code_message=None, pretty_print=False,
-                      version_string=""):
+                      version_string="", canonical_json=True):
     if pretty_print:
         json_bytes = encode_pretty_printed_json(json_object) + "\n"
     else:
-        json_bytes = encode_canonical_json(json_object)
+        if canonical_json:
+            json_bytes = encode_canonical_json(json_object)
+        else:
+            json_bytes = encode_json(
+                json_object, using_frozen_dicts=synapse.events.USE_FROZEN_DICTS
+            )
 
     return respond_with_json_bytes(
         request, code, json_bytes,
@@ -24,6 +24,7 @@ import baserules
 import logging
 import simplejson as json
 import re
+import random
 
 logger = logging.getLogger(__name__)
 
@@ -256,134 +257,154 @@ class Pusher(object):
         logger.info("Pusher %s for user %s starting from token %s",
                     self.pushkey, self.user_name, self.last_token)
 
+        wait = 0
         while self.alive:
-            from_tok = StreamToken.from_string(self.last_token)
-            config = PaginationConfig(from_token=from_tok, limit='1')
-            chunk = yield self.evStreamHandler.get_stream(
-                self.user_name, config,
-                timeout=100*365*24*60*60*1000, affect_presence=False
-            )
+            try:
+                if wait > 0:
+                    yield synapse.util.async.sleep(wait)
+                yield self.get_and_dispatch()
+                wait = 0
+            except:
+                if wait == 0:
+                    wait = 1
+                else:
+                    wait = min(wait * 2, 1800)
+                logger.exception(
+                    "Exception in pusher loop for pushkey %s. Pausing for %ds",
+                    self.pushkey, wait
+                )
 
-            # limiting to 1 may get 1 event plus 1 presence event, so
-            # pick out the actual event
-            single_event = None
-            for c in chunk['chunk']:
-                if 'event_id' in c:  # Hmmm...
-                    single_event = c
-                    break
-            if not single_event:
-                self.last_token = chunk['end']
-                continue
-
-            if not self.alive:
-                continue
-
-            processed = False
-            actions = yield self._actions_for_event(single_event)
-            tweaks = _tweaks_for_actions(actions)
-
-            if len(actions) == 0:
-                logger.warn("Empty actions! Using default action.")
-                actions = Pusher.DEFAULT_ACTIONS
-
-            if 'notify' not in actions and 'dont_notify' not in actions:
-                logger.warn("Neither notify nor dont_notify in actions: adding default")
-                actions.extend(Pusher.DEFAULT_ACTIONS)
-
-            if 'dont_notify' in actions:
-                logger.debug(
-                    "%s for %s: dont_notify",
-                    single_event['event_id'], self.user_name
-                )
+    @defer.inlineCallbacks
+    def get_and_dispatch(self):
+        from_tok = StreamToken.from_string(self.last_token)
+        config = PaginationConfig(from_token=from_tok, limit='1')
+        timeout = (300 + random.randint(-60, 60)) * 1000
+        chunk = yield self.evStreamHandler.get_stream(
+            self.user_name, config,
+            timeout=timeout, affect_presence=False
+        )
+
+        # limiting to 1 may get 1 event plus 1 presence event, so
+        # pick out the actual event
+        single_event = None
+        for c in chunk['chunk']:
+            if 'event_id' in c:  # Hmmm...
+                single_event = c
+                break
+        if not single_event:
+            self.last_token = chunk['end']
+            logger.debug("Event stream timeout for pushkey %s", self.pushkey)
+            return
+
+        if not self.alive:
+            return
+
+        processed = False
+        actions = yield self._actions_for_event(single_event)
+        tweaks = _tweaks_for_actions(actions)
+
+        if len(actions) == 0:
+            logger.warn("Empty actions! Using default action.")
+            actions = Pusher.DEFAULT_ACTIONS
+
+        if 'notify' not in actions and 'dont_notify' not in actions:
+            logger.warn("Neither notify nor dont_notify in actions: adding default")
+            actions.extend(Pusher.DEFAULT_ACTIONS)
+
+        if 'dont_notify' in actions:
+            logger.debug(
+                "%s for %s: dont_notify",
+                single_event['event_id'], self.user_name
+            )
+            processed = True
+        else:
+            rejected = yield self.dispatch_push(single_event, tweaks)
+            self.has_unread = True
+            if isinstance(rejected, list) or isinstance(rejected, tuple):
                 processed = True
-            else:
-                rejected = yield self.dispatch_push(single_event, tweaks)
-                self.has_unread = True
-                if isinstance(rejected, list) or isinstance(rejected, tuple):
-                    processed = True
-                    for pk in rejected:
-                        if pk != self.pushkey:
-                            # for sanity, we only remove the pushkey if it
-                            # was the one we actually sent...
-                            logger.warn(
-                                ("Ignoring rejected pushkey %s because we"
-                                 " didn't send it"), pk
-                            )
-                        else:
-                            logger.info(
-                                "Pushkey %s was rejected: removing",
-                                pk
-                            )
-                            yield self.hs.get_pusherpool().remove_pusher(
-                                self.app_id, pk, self.user_name
-                            )
+                for pk in rejected:
+                    if pk != self.pushkey:
+                        # for sanity, we only remove the pushkey if it
+                        # was the one we actually sent...
+                        logger.warn(
+                            ("Ignoring rejected pushkey %s because we"
+                             " didn't send it"), pk
+                        )
+                    else:
+                        logger.info(
+                            "Pushkey %s was rejected: removing",
+                            pk
+                        )
+                        yield self.hs.get_pusherpool().remove_pusher(
+                            self.app_id, pk, self.user_name
+                        )
 
-            if not self.alive:
-                continue
+        if not self.alive:
+            return
 
-            if processed:
+        if processed:
+            self.backoff_delay = Pusher.INITIAL_BACKOFF
+            self.last_token = chunk['end']
+            self.store.update_pusher_last_token_and_success(
+                self.app_id,
+                self.pushkey,
+                self.user_name,
+                self.last_token,
+                self.clock.time_msec()
+            )
+            if self.failing_since:
+                self.failing_since = None
+                self.store.update_pusher_failing_since(
+                    self.app_id,
+                    self.pushkey,
+                    self.user_name,
+                    self.failing_since)
+        else:
+            if not self.failing_since:
+                self.failing_since = self.clock.time_msec()
+                self.store.update_pusher_failing_since(
+                    self.app_id,
+                    self.pushkey,
+                    self.user_name,
+                    self.failing_since
+                )
+
+            if (self.failing_since and
+               self.failing_since <
+               self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
+                # we really only give up so that if the URL gets
+                # fixed, we don't suddenly deliver a load
+                # of old notifications.
+                logger.warn("Giving up on a notification to user %s, "
+                            "pushkey %s",
+                            self.user_name, self.pushkey)
                 self.backoff_delay = Pusher.INITIAL_BACKOFF
                 self.last_token = chunk['end']
-                self.store.update_pusher_last_token_and_success(
+                self.store.update_pusher_last_token(
                     self.app_id,
                     self.pushkey,
                     self.user_name,
-                    self.last_token,
-                    self.clock.time_msec()
+                    self.last_token
                 )
-                if self.failing_since:
-                    self.failing_since = None
-                    self.store.update_pusher_failing_since(
-                        self.app_id,
-                        self.pushkey,
-                        self.user_name,
-                        self.failing_since)
-            else:
-                if not self.failing_since:
-                    self.failing_since = self.clock.time_msec()
-                    self.store.update_pusher_failing_since(
-                        self.app_id,
-                        self.pushkey,
-                        self.user_name,
-                        self.failing_since
-                    )
 
-                if (self.failing_since and
-                   self.failing_since <
-                   self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
-                    # we really only give up so that if the URL gets
-                    # fixed, we don't suddenly deliver a load
-                    # of old notifications.
-                    logger.warn("Giving up on a notification to user %s, "
-                                "pushkey %s",
-                                self.user_name, self.pushkey)
-                    self.backoff_delay = Pusher.INITIAL_BACKOFF
-                    self.last_token = chunk['end']
-                    self.store.update_pusher_last_token(
-                        self.app_id,
-                        self.pushkey,
-                        self.user_name,
-                        self.last_token
-                    )
-
-                    self.failing_since = None
-                    self.store.update_pusher_failing_since(
-                        self.app_id,
-                        self.pushkey,
-                        self.user_name,
-                        self.failing_since
-                    )
-                else:
-                    logger.warn("Failed to dispatch push for user %s "
-                                "(failing for %dms)."
-                                "Trying again in %dms",
-                                self.user_name,
-                                self.clock.time_msec() - self.failing_since,
-                                self.backoff_delay)
-                    yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
-                    self.backoff_delay *= 2
-                    if self.backoff_delay > Pusher.MAX_BACKOFF:
-                        self.backoff_delay = Pusher.MAX_BACKOFF
+                self.failing_since = None
+                self.store.update_pusher_failing_since(
+                    self.app_id,
+                    self.pushkey,
+                    self.user_name,
+                    self.failing_since
+                )
+            else:
+                logger.warn("Failed to dispatch push for user %s "
+                            "(failing for %dms)."
+                            "Trying again in %dms",
+                            self.user_name,
+                            self.clock.time_msec() - self.failing_since,
+                            self.backoff_delay)
+                yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
+                self.backoff_delay *= 2
+                if self.backoff_delay > Pusher.MAX_BACKOFF:
+                    self.backoff_delay = Pusher.MAX_BACKOFF
 
     def stop(self):
         self.alive = False
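The restructured loop is a capped exponential backoff around a unit of work. Stripped of the pusher specifics (a sketch, synchronous for brevity):

    import time

    def run_forever(step, max_wait=1800):
        wait = 0
        while True:
            try:
                if wait > 0:
                    time.sleep(wait)
                step()
                wait = 0  # success resets the delay
            except Exception:
                # 1s, 2s, 4s, ... capped at max_wait
                wait = 1 if wait == 0 else min(wait * 2, max_wait)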
@@ -18,7 +18,7 @@ from distutils.version import LooseVersion
 logger = logging.getLogger(__name__)
 
 REQUIREMENTS = {
-    "syutil>=0.0.6": ["syutil>=0.0.6"],
+    "syutil>=0.0.7": ["syutil>=0.0.7"],
     "Twisted==14.0.2": ["twisted==14.0.2"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@@ -30,6 +30,7 @@ REQUIREMENTS = {
     "frozendict>=0.4": ["frozendict"],
     "pillow": ["PIL"],
     "pydenticon": ["pydenticon"],
+    "ujson": ["ujson"],
 }
 CONDITIONAL_REQUIREMENTS = {
     "web_client": {
@@ -52,8 +53,8 @@ def github_link(project, version, egg):
 DEPENDENCY_LINKS = [
     github_link(
         project="matrix-org/syutil",
-        version="v0.0.6",
-        egg="syutil-0.0.6",
+        version="v0.0.7",
+        egg="syutil-0.0.7",
     ),
     github_link(
         project="matrix-org/matrix-angular-sdk",
@@ -25,7 +25,7 @@ class ClientV1RestResource(JsonResource):
     """A resource for version 1 of the matrix client API."""
 
     def __init__(self, hs):
-        JsonResource.__init__(self, hs)
+        JsonResource.__init__(self, hs, canonical_json=False)
        self.register_servlets(self, hs)
 
     @staticmethod
@@ -28,7 +28,7 @@ class ClientV2AlphaRestResource(JsonResource):
     """A resource for version 2 alpha of the matrix client API."""
 
     def __init__(self, hs):
-        JsonResource.__init__(self, hs)
+        JsonResource.__init__(self, hs, canonical_json=False)
        self.register_servlets(self, hs)
 
     @staticmethod
@@ -15,13 +15,14 @@
 
 from .thumbnailer import Thumbnailer
 
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.http.server import respond_with_json
 from synapse.util.stringutils import random_string
 from synapse.api.errors import (
     cs_error, Codes, SynapseError
 )
 
-from twisted.internet import defer
+from twisted.internet import defer, threads
 from twisted.web.resource import Resource
 from twisted.protocols.basic import FileSender
 
@@ -52,7 +53,7 @@ class BaseMediaResource(Resource):
     def __init__(self, hs, filepaths):
         Resource.__init__(self)
         self.auth = hs.get_auth()
-        self.client = hs.get_http_client()
+        self.client = MatrixFederationHttpClient(hs)
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
@@ -273,57 +274,65 @@ class BaseMediaResource(Resource):
         if not requirements:
             return
 
+        remote_thumbnails = []
+
         input_path = self.filepaths.remote_media_filepath(server_name, file_id)
         thumbnailer = Thumbnailer(input_path)
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
-        if m_width * m_height >= self.max_image_pixels:
-            logger.info(
-                "Image too large to thumbnail %r x %r > %r",
-                m_width, m_height, self.max_image_pixels
-            )
-            return
-
-        scales = set()
-        crops = set()
-        for r_width, r_height, r_method, r_type in requirements:
-            if r_method == "scale":
-                t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                scales.add((
-                    min(m_width, t_width), min(m_height, t_height), r_type,
-                ))
-            elif r_method == "crop":
-                crops.add((r_width, r_height, r_type))
-
-        for t_width, t_height, t_type in scales:
-            t_method = "scale"
-            t_path = self.filepaths.remote_media_thumbnail(
-                server_name, file_id, t_width, t_height, t_type, t_method
-            )
-            self._makedirs(t_path)
-            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-            yield self.store.store_remote_media_thumbnail(
-                server_name, media_id, file_id,
-                t_width, t_height, t_type, t_method, t_len
-            )
-
-        for t_width, t_height, t_type in crops:
-            if (t_width, t_height, t_type) in scales:
-                # If the aspect ratio of the cropped thumbnail matches a purely
-                # scaled one then there is no point in calculating a separate
-                # thumbnail.
-                continue
-            t_method = "crop"
-            t_path = self.filepaths.remote_media_thumbnail(
-                server_name, file_id, t_width, t_height, t_type, t_method
-            )
-            self._makedirs(t_path)
-            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-            yield self.store.store_remote_media_thumbnail(
-                server_name, media_id, file_id,
-                t_width, t_height, t_type, t_method, t_len
-            )
+        def generate_thumbnails():
+            if m_width * m_height >= self.max_image_pixels:
+                logger.info(
+                    "Image too large to thumbnail %r x %r > %r",
+                    m_width, m_height, self.max_image_pixels
+                )
+                return
+
+            scales = set()
+            crops = set()
+            for r_width, r_height, r_method, r_type in requirements:
+                if r_method == "scale":
+                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                    scales.add((
+                        min(m_width, t_width), min(m_height, t_height), r_type,
+                    ))
+                elif r_method == "crop":
+                    crops.add((r_width, r_height, r_type))
+
+            for t_width, t_height, t_type in scales:
+                t_method = "scale"
+                t_path = self.filepaths.remote_media_thumbnail(
+                    server_name, file_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+                remote_thumbnails.append([
+                    server_name, media_id, file_id,
+                    t_width, t_height, t_type, t_method, t_len
+                ])
+
+            for t_width, t_height, t_type in crops:
+                if (t_width, t_height, t_type) in scales:
+                    # If the aspect ratio of the cropped thumbnail matches a purely
+                    # scaled one then there is no point in calculating a separate
+                    # thumbnail.
+                    continue
+                t_method = "crop"
+                t_path = self.filepaths.remote_media_thumbnail(
+                    server_name, file_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+                remote_thumbnails.append([
+                    server_name, media_id, file_id,
+                    t_width, t_height, t_type, t_method, t_len
+                ])
+
+        yield threads.deferToThread(generate_thumbnails)
+
+        for r in remote_thumbnails:
+            yield self.store.store_remote_media_thumbnail(*r)
 
         defer.returnValue({
             "width": m_width,
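deferToThread is what moves the CPU-bound PIL work off the reactor thread: it runs a blocking callable on Twisted's thread pool and returns a Deferred that fires with the result. A minimal sketch (resize_one, store_results and paths are hypothetical names):

    from twisted.internet import threads

    def resize_all(paths):
        # Blocking, CPU-bound; must not run on the reactor thread.
        return [resize_one(p) for p in paths]

    d = threads.deferToThread(resize_all, paths)
    d.addCallback(store_results)  # callback runs back on the reactor thread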
@@ -106,7 +106,7 @@ class StateHandler(object):
         defer.returnValue(state)
 
     @defer.inlineCallbacks
-    def compute_event_context(self, event, old_state=None):
+    def compute_event_context(self, event, old_state=None, outlier=False):
         """ Fills out the context with the `current state` of the graph. The
         `current state` here is defined to be the state of the event graph
         just before the event - i.e. it never includes `event`
@@ -119,9 +119,23 @@ class StateHandler(object):
         Returns:
             an EventContext
         """
+        yield run_on_reactor()
+
         context = EventContext()
 
-        yield run_on_reactor()
+        if outlier:
+            # If this is an outlier, then we know it shouldn't have any current
+            # state. Certainly store.get_current_state won't return any, and
+            # persisting the event won't store the state group.
+            if old_state:
+                context.current_state = {
+                    (s.type, s.state_key): s for s in old_state
+                }
+            else:
+                context.current_state = {}
+            context.prev_state_events = []
+            context.state_group = None
+            defer.returnValue(context)
 
         if old_state:
             context.current_state = {
@@ -155,10 +169,6 @@ class StateHandler(object):
         context.current_state = curr_state
         context.state_group = group if not event.is_state() else None
 
-        prev_state = yield self.store.add_event_hashes(
-            prev_state
-        )
-
         if event.is_state():
             key = (event.type, event.state_key)
             if key in context.current_state:
@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 19
+SCHEMA_VERSION = 20
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -348,7 +348,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
                 module_name, absolute_path, python_file
             )
             logger.debug("Running script %s", relative_path)
-            module.run_upgrade(cur)
+            module.run_upgrade(cur, database_engine)
         elif ext == ".sql":
             # A plain old .sql file, just read and execute it
             logger.debug("Applying schema %s", relative_path)
@@ -127,7 +127,7 @@ class Cache(object):
         self.cache.clear()
 
 
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
     The function is presumed to take zero or more arguments, which are used in

@@ -141,25 +141,32 @@
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def wrap(orig):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+        self.orig = orig
+
+        self.max_entries = max_entries
+        self.num_args = num_args
+        self.lru = lru
+
+    def __get__(self, obj, objtype=None):
         cache = Cache(
-            name=orig.__name__,
-            max_entries=max_entries,
-            keylen=num_args,
-            lru=lru,
+            name=self.orig.__name__,
+            max_entries=self.max_entries,
+            keylen=self.num_args,
+            lru=self.lru,
         )
 
-        @functools.wraps(orig)
+        @functools.wraps(self.orig)
         @defer.inlineCallbacks
-        def wrapped(self, *keyargs):
+        def wrapped(*keyargs):
             try:
-                cached_result = cache.get(*keyargs)
+                cached_result = cache.get(*keyargs[:self.num_args])
                 if DEBUG_CACHES:
-                    actual_result = yield orig(self, *keyargs)
+                    actual_result = yield self.orig(obj, *keyargs)
                     if actual_result != cached_result:
                         logger.error(
                             "Stale cache entry %s%r: cached: %r, actual %r",
-                            orig.__name__, keyargs,
+                            self.orig.__name__, keyargs,
                             cached_result, actual_result,
                         )
                         raise ValueError("Stale cache entry")

@@ -170,18 +177,28 @@
             # while the SELECT is executing (SYN-369)
             sequence = cache.sequence
 
-            ret = yield orig(self, *keyargs)
+            ret = yield self.orig(obj, *keyargs)
 
-            cache.update(sequence, *keyargs + (ret,))
+            cache.update(sequence, *keyargs[:self.num_args] + (ret,))
 
             defer.returnValue(ret)
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
         wrapped.prefill = cache.prefill
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
         return wrapped
 
-    return wrap
+
+def cached(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru
+    )
 
 
 class LoggingTransaction(object):
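The rewrite turns @cached into a non-data descriptor: __get__ builds the wrapper on first attribute access and stashes it in the instance __dict__, so subsequent lookups bypass the descriptor entirely and the cache is bound to the object it was first accessed on. The trick in isolation, independent of synapse's Cache class (a sketch):

    class memoized(object):
        # Non-data descriptor: it defines no __set__, so an entry in
        # obj.__dict__ shadows it after the first access.
        def __init__(self, orig):
            self.orig = orig

        def __get__(self, obj, objtype=None):
            cache = {}

            def wrapped(*args):
                if args not in cache:
                    cache[args] = self.orig(obj, *args)
                return cache[args]

            obj.__dict__[self.orig.__name__] = wrapped
            return wrapped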
@@ -17,7 +17,7 @@ from _base import SQLBaseStore, _RollbackButIsFineException
 
 from twisted.internet import defer, reactor
 
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, USE_FROZEN_DICTS
 from synapse.events.utils import prune_event
 
 from synapse.util.logcontext import preserve_context_over_deferred
@@ -26,11 +26,11 @@ from synapse.api.constants import EventTypes
 from synapse.crypto.event_signing import compute_event_reference_hash
 
 from syutil.base64util import decode_base64
-from syutil.jsonutil import encode_canonical_json
+from syutil.jsonutil import encode_json
 from contextlib import contextmanager
 
 import logging
-import simplejson as json
+import ujson as json
 
 logger = logging.getLogger(__name__)
 
@@ -166,8 +166,9 @@ class EventsStore(SQLBaseStore):
             allow_none=True,
         )
 
-        metadata_json = encode_canonical_json(
-            event.internal_metadata.get_dict()
+        metadata_json = encode_json(
+            event.internal_metadata.get_dict(),
+            using_frozen_dicts=USE_FROZEN_DICTS
         ).decode("UTF-8")
 
         # If we have already persisted this event, we don't need to do any
@@ -235,12 +236,14 @@ class EventsStore(SQLBaseStore):
                 "event_id": event.event_id,
                 "room_id": event.room_id,
                 "internal_metadata": metadata_json,
-                "json": encode_canonical_json(event_dict).decode("UTF-8"),
+                "json": encode_json(
+                    event_dict, using_frozen_dicts=USE_FROZEN_DICTS
+                ).decode("UTF-8"),
             },
         )
 
-        content = encode_canonical_json(
-            event.content
+        content = encode_json(
+            event.content, using_frozen_dicts=USE_FROZEN_DICTS
         ).decode("UTF-8")
 
         vals = {
@@ -266,8 +269,8 @@ class EventsStore(SQLBaseStore):
             ]
         }
 
-        vals["unrecognized_keys"] = encode_canonical_json(
-            unrec
+        vals["unrecognized_keys"] = encode_json(
+            unrec, using_frozen_dicts=USE_FROZEN_DICTS
         ).decode("UTF-8")
 
         sql = (
@@ -733,7 +736,8 @@ class EventsStore(SQLBaseStore):
 
         because = yield self.get_event(
             redaction_id,
-            check_redacted=False
+            check_redacted=False,
+            allow_none=True,
         )
 
         if because:
|
|||||||
prev = yield self.get_event(
|
prev = yield self.get_event(
|
||||||
ev.unsigned["replaces_state"],
|
ev.unsigned["replaces_state"],
|
||||||
get_prev_content=False,
|
get_prev_content=False,
|
||||||
|
allow_none=True,
|
||||||
)
|
)
|
||||||
if prev:
|
if prev:
|
||||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||||
|
@@ -18,7 +18,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-def run_upgrade(cur):
+def run_upgrade(cur, *args, **kwargs):
     cur.execute("SELECT id, regex FROM application_services_regex")
     for row in cur.fetchall():
         try:
synapse/storage/schema/delta/20/pushers.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Main purpose of this upgrade is to change the unique key on the
+pushers table again (it was missed when the v16 full schema was
+made) but this also changes the pushkey and data columns to text.
+When selecting a bytea column into a text column, postgres inserts
+the hex encoded data, and there's no portable way of getting the
+UTF-8 bytes, so we have to do it in Python.
+"""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    logger.info("Porting pushers table...")
+    cur.execute("""
+        CREATE TABLE IF NOT EXISTS pushers2 (
+          id BIGINT PRIMARY KEY,
+          user_name TEXT NOT NULL,
+          access_token BIGINT DEFAULT NULL,
+          profile_tag VARCHAR(32) NOT NULL,
+          kind VARCHAR(8) NOT NULL,
+          app_id VARCHAR(64) NOT NULL,
+          app_display_name VARCHAR(64) NOT NULL,
+          device_display_name VARCHAR(128) NOT NULL,
+          pushkey TEXT NOT NULL,
+          ts BIGINT NOT NULL,
+          lang VARCHAR(8),
+          data TEXT,
+          last_token TEXT,
+          last_success BIGINT,
+          failing_since BIGINT,
+          UNIQUE (app_id, pushkey, user_name)
+        )
+    """)
+    cur.execute("""SELECT
+        id, user_name, access_token, profile_tag, kind,
+        app_id, app_display_name, device_display_name,
+        pushkey, ts, lang, data, last_token, last_success,
+        failing_since
+        FROM pushers
+    """)
+    count = 0
+    for row in cur.fetchall():
+        row = list(row)
+        row[8] = bytes(row[8]).decode("utf-8")
+        row[11] = bytes(row[11]).decode("utf-8")
+        cur.execute(database_engine.convert_param_style("""
+            INSERT into pushers2 (
+            id, user_name, access_token, profile_tag, kind,
+            app_id, app_display_name, device_display_name,
+            pushkey, ts, lang, data, last_token, last_success,
+            failing_since
+            ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
+            row
+        )
+        count += 1
+    cur.execute("DROP TABLE pushers")
+    cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
+    logger.info("Moved %d pushers to new table", count)
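On the docstring's bytea point: psycopg2 returns bytea columns as buffer/memoryview objects, and decoding them in Python recovers the original text without relying on postgres' hex cast. Illustration with a made-up value:

    raw = memoryview(b"com.example.app")   # what a bytea column yields
    text = bytes(raw).decode("utf-8")      # back to "com.example.app"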
@@ -81,19 +81,23 @@
                 f,
             )
 
-        @defer.inlineCallbacks
-        def c(vals):
-            vals[:] = yield self._get_events(vals, get_prev_content=False)
-
-        yield defer.gatherResults(
+        state_list = yield defer.gatherResults(
             [
-                c(vals)
-                for vals in states.values()
+                self._fetch_events_for_group(group, vals)
+                for group, vals in states.items()
             ],
             consumeErrors=True,
         )
 
-        defer.returnValue(states)
+        defer.returnValue(dict(state_list))
+
+    @cached(num_args=1)
+    def _fetch_events_for_group(self, state_group, events):
+        return self._get_events(
+            events, get_prev_content=False
+        ).addCallback(
+            lambda evs: (state_group, evs)
+        )
 
     def _store_state_groups_txn(self, txn, event, context):
         if context.current_state is None:

@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-
 
 class JsonEncodedObject(object):
     """ A common base class for defining protocol units that are represented

@@ -76,15 +74,7 @@ class JsonEncodedObject(object):
             if k in self.valid_keys and k not in self.internal_keys
         }
         d.update(self.unrecognized_keys)
-        return copy.deepcopy(d)
-
-    def get_full_dict(self):
-        d = {
-            k: _encode(v) for (k, v) in self.__dict__.items()
-            if k in self.valid_keys or k in self.internal_keys
-        }
-        d.update(self.unrecognized_keys)
-        return copy.deepcopy(d)
+        return d
 
     def __str__(self):
         return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))

@@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
             return defer.succeed({})
         self.datastore.have_events.side_effect = have_events
 
-        def annotate(ev, old_state=None):
+        def annotate(ev, old_state=None, outlier=False):
             context = Mock()
             context.current_state = {}
             context.auth_events = {}

@@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
         )
 
         self.state_handler.compute_event_context.assert_called_once_with(
-            ANY, old_state=None,
+            ANY, old_state=None, outlier=False
         )
 
         self.auth.check.assert_called_once_with(ANY, auth_events={})

@@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
                 "get_room",
                 "store_room",
                 "get_latest_events_in_room",
+                "add_event_hashes",
             ]),
             resource_for_federation=NonCallableMock(),
             http_client=NonCallableMock(spec_set=[]),

@@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         self.ratelimiter.send_message.return_value = (True, 0)
 
         self.datastore.persist_event.return_value = (1,1)
+        self.datastore.add_event_hashes.return_value = []
 
     @defer.inlineCallbacks
     def test_invite(self):

@@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_passthrough(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
+
+        a = A()
 
-        self.assertEquals((yield func(self, "foo")), "foo")
-        self.assertEquals((yield func(self, "bar")), "bar")
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
 
     @defer.inlineCallbacks
     def test_hit(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
 
-        yield func(self, "foo")
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        self.assertEquals((yield func(self, "foo")), "foo")
+        self.assertEquals((yield a.func("foo")), "foo")
         self.assertEquals(callcount[0], 1)
 
     @defer.inlineCallbacks
     def test_invalidate(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
 
-        yield func(self, "foo")
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        func.invalidate("foo")
+        a.func.invalidate("foo")
 
-        yield func(self, "foo")
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 2)
 
     def test_invalidate_missing(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
 
-        func.invalidate("what")
+        A().func.invalidate("what")
 
     @defer.inlineCallbacks
     def test_max_entries(self):
         callcount = [0]
 
-        @cached(max_entries=10)
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        for k in range(0,12):
-            yield func(self, k)
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
 
         self.assertEquals(callcount[0], 12)
 
         # There must have been at least 2 evictions, meaning if we calculate
         # all 12 values again, we must get called at least 2 more times
         for k in range(0,12):
-            yield func(self, k)
+            yield a.func(k)
 
         self.assertTrue(callcount[0] >= 14,
             msg="Expected callcount >= 14, got %d" % (callcount[0]))

@@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
     def test_prefill(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
 
-        func.prefill("foo", 123)
+        a.func.prefill("foo", 123)
 
-        self.assertEquals((yield func(self, "foo")), 123)
+        self.assertEquals((yield a.func("foo")), 123)
         self.assertEquals(callcount[0], 0)

@@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
             (yield self.store.get_user_by_id(self.user_id))
         )
 
-        result = yield self.store.get_user_by_token(self.tokens[1])
+        result = yield self.store.get_user_by_token(self.tokens[0])
 
         self.assertDictContainsSubset(
             {