Merge branch 'develop' of github.com:matrix-org/synapse into erikj/e2e_one_time_upsert

This commit is contained in:
Erik Johnston 2017-03-29 10:57:19 +01:00
commit 4ad613f6be
45 changed files with 831 additions and 596 deletions

View File

@ -146,6 +146,7 @@ To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse
source ~/.synapse/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools
pip install https://github.com/matrix-org/synapse/tarball/master
@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::
New user localpart: erikj
Password:
Confirm password:
Make admin [no]:
Success!
This process uses a setting ``registration_shared_secret`` in

View File

@ -15,10 +15,172 @@
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID
from twisted.internet import defer
import ujson as json
import jsonschema
from jsonschema import FormatChecker
FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
ROOM_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"ephemeral": {
"$ref": "#/definitions/room_event_filter"
},
"include_leave": {
"type": "boolean"
},
"state": {
"$ref": "#/definitions/room_event_filter"
},
"timeline": {
"$ref": "#/definitions/room_event_filter"
},
"account_data": {
"$ref": "#/definitions/room_event_filter"
},
}
}
ROOM_EVENT_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"contains_url": {
"type": "boolean"
}
}
}
USER_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_user_id"
}
}
ROOM_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_room_id"
}
}
USER_FILTER_SCHEMA = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "schema for a Sync filter",
"type": "object",
"definitions": {
"room_id_array": ROOM_ID_ARRAY_SCHEMA,
"user_id_array": USER_ID_ARRAY_SCHEMA,
"filter": FILTER_SCHEMA,
"room_filter": ROOM_FILTER_SCHEMA,
"room_event_filter": ROOM_EVENT_FILTER_SCHEMA
},
"properties": {
"presence": {
"$ref": "#/definitions/filter"
},
"account_data": {
"$ref": "#/definitions/filter"
},
"room": {
"$ref": "#/definitions/room_filter"
},
"event_format": {
"type": "string",
"enum": ["client", "federation"]
},
"event_fields": {
"type": "array",
"items": {
"type": "string",
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
"pattern": "^((?!\\\).)*$"
}
}
},
"additionalProperties": False
}
@FormatChecker.cls_checks('matrix_room_id')
def matrix_room_id_validator(room_id_str):
return RoomID.from_string(room_id_str)
@FormatChecker.cls_checks('matrix_user_id')
def matrix_user_id_validator(user_id_str):
return UserID.from_string(user_id_str)
class Filtering(object):
@ -53,98 +215,11 @@ class Filtering(object):
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
top_level_definitions = [
"presence", "account_data"
]
room_level_definitions = [
"state", "timeline", "ephemeral", "account_data"
]
for key in top_level_definitions:
if key in user_filter_json:
self._check_definition(user_filter_json[key])
if "room" in user_filter_json:
self._check_definition_room_lists(user_filter_json["room"])
for key in room_level_definitions:
if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key])
if "event_fields" in user_filter_json:
if type(user_filter_json["event_fields"]) != list:
raise SynapseError(400, "event_fields must be a list of strings")
for field in user_filter_json["event_fields"]:
if not isinstance(field, basestring):
raise SynapseError(400, "Event field must be a string")
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
if r'\\' in field:
raise SynapseError(
400, r'The escape character \ cannot itself be escaped'
)
def _check_definition_room_lists(self, definition):
"""Check that "rooms" and "not_rooms" are lists of room ids if they
are present
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
self._check_definition_room_lists(definition)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
try:
jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
format_checker=FormatChecker())
except jsonschema.ValidationError as e:
raise SynapseError(400, e.message)
class FilterCollection(object):

View File

@ -29,6 +29,7 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
TransactionStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):

View File

@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
TransactionStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,

View File

@ -15,7 +15,6 @@
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
@ -382,30 +381,24 @@ class Keyring(object):
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logger.info(
"Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logger.info(
"Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
keys = {server_name: keys}
keys = {server_name: keys}
defer.returnValue(keys)

View File

@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
from synapse.util.retryutils import NotRetryingDestination
import copy
import itertools
@ -88,7 +88,7 @@ class FederationClient(FederationBase):
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False):
retry_on_dns_fail=False, ignore_backoff=False):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
@ -98,6 +98,8 @@ class FederationClient(FederationBase):
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
a Deferred which will eventually yield a JSON object from the
@ -106,7 +108,8 @@ class FederationClient(FederationBase):
sent_queries_counter.inc(query_type)
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@log_function
@ -234,31 +237,24 @@ class FederationClient(FederationBase):
continue
try:
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
)
with limiter:
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
)
logger.debug("transaction_data %r", transaction_data)
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
# Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
# Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
break
pdu_attempts[destination] = now

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from twisted.internet import defer
@ -22,9 +22,7 @@ from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state
@ -311,14 +309,8 @@ class TransactionQueue(object):
# XXX: what's this for?
yield run_on_reactor()
pending_pdus = []
while True:
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
backoff_on_404=True, # If we get a 404 the other side has gone
)
device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination)
)
@ -374,7 +366,6 @@ class TransactionQueue(object):
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter,
)
if success:
# Remove the acknowledged device messages from the database
@ -392,12 +383,24 @@ class TransactionQueue(object):
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break
except NotRetryingDestination:
except NotRetryingDestination as e:
logger.debug(
"TX [%s] not ready for retry yet - "
"TX [%s] not ready for retry yet (next retry at %s) - "
"dropping transaction for now",
destination,
datetime.datetime.fromtimestamp(
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
destination,
e,
)
for p in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id,
destination)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
@ -437,7 +440,7 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, limiter):
pending_failures):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
@ -447,132 +450,104 @@ class TransactionQueue(object):
success = True
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pdus),
len(edus),
len(failures)
)
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
len(failures),
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pdus),
len(edus),
len(failures)
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
len(failures),
)
with limiter:
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
raise e
except HttpResponseException as e:
code = e.code
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
raise e
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
logger.debug("TX [%s] Marked as delivered", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
if code != 200:
for p in pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
success = False
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
success = False
logger.debug("TX [%s] Marked as delivered", destination)
if code != 200:
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
success = False
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
defer.returnValue(success)

View File

@ -163,6 +163,7 @@ class TransportLayerClient(object):
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
backoff_on_404=True, # If we get a 404 the other side has gone
)
logger.debug(
@ -174,7 +175,8 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail):
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
path = PREFIX + "/query/%s" % query_type
content = yield self.client.get_json(
@ -183,6 +185,7 @@ class TransportLayerClient(object):
args=args,
retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
ignore_backoff=ignore_backoff,
)
defer.returnValue(content)
@ -242,6 +245,7 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
defer.returnValue(response)
@ -269,6 +273,7 @@ class TransportLayerClient(object):
destination=remote_server,
path=path,
args=args,
ignore_backoff=True,
)
defer.returnValue(response)

View File

@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
"room_alias": room_alias.to_string(),
},
retry_on_dns_fail=False,
ignore_backoff=True,
)
except CodeMessageException as e:
logging.warn("Error retrieving alias")

View File

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@ -121,15 +121,11 @@ class E2eKeysHandler(object):
def do_remote_query(destination):
destination_query = remote_queries_not_in_cache[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
with limiter:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
@ -239,18 +235,14 @@ class E2eKeysHandler(object):
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
with limiter:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message

View File

@ -575,8 +575,7 @@ class PresenceHandler(object):
if not local_states:
continue
users = yield self.store.get_users_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
hosts = yield self.store.get_hosts_in_room(room_id)
for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states)

View File

@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
args={
"user_id": target_user.to_string(),
"field": "displayname",
}
},
ignore_backoff=True,
)
except CodeMessageException as e:
if e.code != 404:
@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
args={
"user_id": target_user.to_string(),
"field": "avatar_url",
}
},
ignore_backoff=True,
)
except CodeMessageException as e:
if e.code != 404:

View File

@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.util.retryutils
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent
@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util import logcontext
import synapse.metrics
from canonicaljson import encode_canonical_json
@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string = hs.version_string
self._next_id = 1
@ -103,129 +103,152 @@ class MatrixFederationHttpClient(object):
)
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True,
timeout=None, long_retries=False):
""" Creates and sends a request to the given url
def _request(self, destination, method, path,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True,
timeout=None, long_retries=False,
ignore_backoff=False,
backoff_on_404=False):
""" Creates and sends a request to the given server
Args:
destination (str): The remote server to send the HTTP request to.
method (str): HTTP method
path (str): The HTTP path
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404
Returns:
Deferred: resolves with the http response object on success.
Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300.
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
"""
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
url_bytes = self._create_url(
destination, path_bytes, param_bytes, query_bytes
limiter = yield synapse.util.retryutils.get_retry_limiter(
destination,
self.clock,
self._store,
backoff_on_404=backoff_on_404,
ignore_backoff=ignore_backoff,
)
txn_id = "%s-O-%s" % (method, self._next_id)
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
destination = destination.encode("ascii")
path_bytes = path.encode("ascii")
with limiter:
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
outbound_logger.info(
"{%s} [%s] Sending request: %s %s",
txn_id, destination, method, url_bytes
)
url_bytes = self._create_url(
destination, path_bytes, param_bytes, query_bytes
)
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
if long_retries:
retries_left = MAX_LONG_RETRIES
else:
retries_left = MAX_SHORT_RETRIES
txn_id = "%s-O-%s" % (method, self._next_id)
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
)
outbound_logger.info(
"{%s} [%s] Sending request: %s %s",
txn_id, destination, method, url_bytes
)
log_result = None
try:
while True:
producer = None
if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict)
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
if long_retries:
retries_left = MAX_LONG_RETRIES
else:
retries_left = MAX_SHORT_RETRIES
try:
def send_request():
request_deferred = preserve_context_over_fn(
self.agent.request,
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
)
log_result = None
try:
while True:
producer = None
if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict)
try:
def send_request():
request_deferred = self.agent.request(
method,
url_bytes,
Headers(headers_dict),
producer
)
return self.clock.time_bound_deferred(
request_deferred,
time_out=timeout / 1000. if timeout else 60,
)
with logcontext.PreserveLoggingContext():
response = yield send_request()
log_result = "%d %s" % (response.code, response.phrase,)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
log_result = "DNS Lookup failed to %s with %s" % (
destination, e
)
raise
logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s",
txn_id,
destination,
method,
url_bytes,
Headers(headers_dict),
producer
type(e).__name__,
_flatten_response_never_received(e),
)
return self.clock.time_bound_deferred(
request_deferred,
time_out=timeout / 1000. if timeout else 60,
log_result = "%s - %s" % (
type(e).__name__, _flatten_response_never_received(e),
)
response = yield preserve_context_over_fn(send_request)
if retries_left and not timeout:
if long_retries:
delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
delay = min(delay, 60)
delay *= random.uniform(0.8, 1.4)
else:
delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
log_result = "%d %s" % (response.code, response.phrase,)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
log_result = "DNS Lookup failed to %s with %s" % (
destination, e
)
raise
logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s",
txn_id,
destination,
method,
url_bytes,
type(e).__name__,
_flatten_response_never_received(e),
)
log_result = "%s - %s" % (
type(e).__name__, _flatten_response_never_received(e),
)
if retries_left and not timeout:
if long_retries:
delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
delay = min(delay, 60)
delay *= random.uniform(0.8, 1.4)
yield sleep(delay)
retries_left -= 1
else:
delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
raise
finally:
outbound_logger.info(
"{%s} [%s] Result: %s",
txn_id,
destination,
log_result,
)
yield sleep(delay)
retries_left -= 1
else:
raise
finally:
outbound_logger.info(
"{%s} [%s] Result: %s",
txn_id,
destination,
log_result,
)
if 200 <= response.code < 300:
pass
else:
# :'(
# Update transactions table?
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
raise HttpResponseException(
response.code, response.phrase, body
)
if 200 <= response.code < 300:
pass
else:
# :'(
# Update transactions table?
body = yield preserve_context_over_fn(readBody, response)
raise HttpResponseException(
response.code, response.phrase, body
)
defer.returnValue(response)
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
@ -254,7 +277,9 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None,
long_retries=False, timeout=None):
long_retries=False, timeout=None,
ignore_backoff=False,
backoff_on_404=False):
""" Sends the specifed json data using PUT
Args:
@ -269,11 +294,19 @@ class MatrixFederationHttpClient(object):
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests)
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
"""
if not json_data_callback:
@ -288,26 +321,29 @@ class MatrixFederationHttpClient(object):
producer = _JsonProducer(json_data)
return producer
response = yield self._create_request(
destination.encode("ascii"),
response = yield self._request(
destination,
"PUT",
path.encode("ascii"),
path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
backoff_on_404=backoff_on_404,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
timeout=None):
timeout=None, ignore_backoff=False):
""" Sends the specifed json data using POST
Args:
@ -320,11 +356,15 @@ class MatrixFederationHttpClient(object):
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
"""
def body_callback(method, url_bytes, headers_dict):
@ -333,27 +373,29 @@ class MatrixFederationHttpClient(object):
)
return _JsonProducer(data)
response = yield self._create_request(
destination.encode("ascii"),
response = yield self._request(
destination,
"POST",
path.encode("ascii"),
path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
timeout=None):
timeout=None, ignore_backoff=False):
""" GETs some json from the given host homeserver and path
Args:
@ -365,11 +407,16 @@ class MatrixFederationHttpClient(object):
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
"""
logger.debug("get_json args: %s", args)
@ -386,39 +433,47 @@ class MatrixFederationHttpClient(object):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(
destination.encode("ascii"),
response = yield self._request(
destination,
"GET",
path.encode("ascii"),
path,
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
ignore_backoff=ignore_backoff,
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None):
retry_on_dns_fail=True, max_size=None,
ignore_backoff=False):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET.
output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
Deferred: resolves with an (int,dict) tuple of the file length and
a dict of the response headers.
Fails with ``HTTPRequestException`` if we get an HTTP response code
>= 300
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
"""
encoded_args = {}
@ -434,22 +489,23 @@ class MatrixFederationHttpClient(object):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(
destination.encode("ascii"),
response = yield self._request(
destination,
"GET",
path.encode("ascii"),
path,
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
headers = dict(response.headers.getAllRawHeaders())
try:
length = yield preserve_context_over_fn(
_readBodyToFile,
response, output_stream, max_size
)
with logcontext.PreserveLoggingContext():
length = yield _readBodyToFile(
response, output_stream, max_size
)
except:
logger.exception("Failed to download body")
raise

View File

@ -19,6 +19,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],

View File

@ -167,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
_get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__

View File

@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
self.presence_stream_cache.entity_has_changed(
user_id, position
)
self._get_presence_for_user.invalidate((user_id,))
return super(SlavedPresenceStore, self).process_replication(result)

View File

@ -268,7 +268,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if existingUid is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
ret = yield self.identity_handler.requestMsisdnToken(**body)
defer.returnValue((200, ret))

View File

@ -537,7 +537,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
self.device_handler.check_device_registered(
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)

View File

@ -73,6 +73,9 @@ class LoggingTransaction(object):
def __setattr__(self, name, value):
setattr(self.txn, name, value)
def __iter__(self):
return self.txn.__iter__()
def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args)
@ -132,7 +135,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
for name, (count, cum_time) in self.current_counters.items():
for name, (count, cum_time) in self.current_counters.iteritems():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@ -357,7 +360,7 @@ class SQLBaseStore(object):
"""
col_headers = list(intern(column[0]) for column in cursor.description)
results = list(
dict(zip(col_headers, row)) for row in cursor.fetchall()
dict(zip(col_headers, row)) for row in cursor
)
return results
@ -565,7 +568,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
@ -579,7 +582,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values())
return [r[0] for r in txn.fetchall()]
return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"):
@ -712,7 +715,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
for key, value in keyvalues.items():
for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,))
values.append(value)
@ -753,7 +756,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
@ -870,7 +873,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
for key, value in keyvalues.items():
for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,))
values.append(value)
@ -901,16 +904,16 @@ class SQLBaseStore(object):
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
for row in rows
for row in txn
}
txn.close()
if cache:
min_val = min(cache.values())
min_val = min(cache.itervalues())
else:
min_val = max_value

View File

@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
global_account_data = {
row[0]: json.loads(row[1]) for row in txn.fetchall()
row[0]: json.loads(row[1]) for row in txn
}
sql = (
@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
account_data_by_room = {}
for row in txn.fetchall():
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])

View File

@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"])
for row in txn.fetchall():
for row in txn:
# Add the message for all devices for this user on this
# server.
device = row[0]
@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + devices)
for row in txn.fetchall():
for row in txn:
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn.fetchall():
for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" ORDER BY stream_id ASC"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn.fetchall())
rows.extend(txn)
return rows
@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn.fetchall():
for row in txn:
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:

View File

@ -329,17 +329,20 @@ class DeviceStore(SQLBaseStore):
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows}
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)

View File

@ -193,7 +193,7 @@ class EndToEndKeyStore(SQLBaseStore):
)
txn.execute(sql, (user_id, device_id))
result = {}
for algorithm, key_count in txn.fetchall():
for algorithm, key_count in txn:
result[algorithm] = key_count
return result
return self.runInteraction(
@ -214,7 +214,7 @@ class EndToEndKeyStore(SQLBaseStore):
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn.fetchall():
for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
sql = (

View File

@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
new_front.update([r[0] for r in txn.fetchall()])
new_front.update([r[0] for r in txn])
new_front -= results
@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, False,))
return dict(txn.fetchall())
return dict(txn)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, ))
results = []
for event_id, depth in txn.fetchall():
for event_id, depth in txn:
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
rows = txn.fetchall()
return [event_id for event_id, in rows]
return [event_id for event_id, in txn]
return self.runInteraction(
"get_forward_extremeties_for_room",
@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
for row in txn.fetchall():
for row in txn:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results))
)
for e_id, in txn.fetchall():
for e_id, in txn:
new_front.add(e_id)
new_front -= earliest_events

View File

@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
" stream_ordering >= ? AND stream_ordering <= ?"
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn.fetchall()]
return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret)

View File

@ -217,14 +217,14 @@ class EventsStore(SQLBaseStore):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
for room_id, evs_ctxs in partitioned.items():
for room_id, evs_ctxs in partitioned.iteritems():
d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs,
backfilled=backfilled,
)
deferreds.append(d)
for room_id in partitioned.keys():
for room_id in partitioned:
self._maybe_start_persisting(room_id)
return preserve_context_over_deferred(
@ -323,7 +323,7 @@ class EventsStore(SQLBaseStore):
(event, context)
)
for room_id, ev_ctx_rm in events_by_room.items():
for room_id, ev_ctx_rm in events_by_room.iteritems():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
@ -428,6 +428,7 @@ class EventsStore(SQLBaseStore):
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
state_groups = set()
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
@ -437,9 +438,17 @@ class EventsStore(SQLBaseStore):
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
# If we've already seen the state group don't bother adding
# it to the state sets again
if ctx.state_group not in state_groups:
state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
if ctx.state_group:
# Add this as a seen state group (if it has a state
# group)
state_groups.add(ctx.state_group)
break
else:
# If we couldn't find it, then we'll need to pull
@ -453,45 +462,51 @@ class EventsStore(SQLBaseStore):
missing_event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
groups = set(event_to_groups.itervalues()) - state_groups
state_sets.extend(group_to_state.values())
if groups:
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids:
current_state = {}
elif was_updated:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.
events_map = {ev.event_id: ev for ev, _ in events_context}
if len(state_sets) == 1:
# If there is only one state set, then we know what the current
# state is.
current_state = state_sets[0]
else:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.
events_map = {ev.event_id: ev for ev, _ in events_context}
@defer.inlineCallbacks
def get_events(ev_ids):
# We get the events by first looking at the list of events we
# are trying to persist, and then fetching the rest from the DB.
db = []
to_return = {}
for ev_id in ev_ids:
ev = events_map.get(ev_id, None)
if ev:
to_return[ev_id] = ev
else:
db.append(ev_id)
@defer.inlineCallbacks
def get_events(ev_ids):
# We get the events by first looking at the list of events we
# are trying to persist, and then fetching the rest from the DB.
db = []
to_return = {}
for ev_id in ev_ids:
ev = events_map.get(ev_id, None)
if ev:
to_return[ev_id] = ev
else:
db.append(ev_id)
if db:
evs = yield self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
to_return.update(evs)
defer.returnValue(to_return)
if db:
evs = yield self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
to_return.update(evs)
defer.returnValue(to_return)
current_state = yield resolve_events(
state_sets,
state_map_factory=get_events,
)
current_state = yield resolve_events(
state_sets,
state_map_factory=get_events,
)
else:
return
@ -718,7 +733,7 @@ class EventsStore(SQLBaseStore):
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order):
for room_id, new_extrem in new_forward_extremities.items():
for room_id, new_extrem in new_forward_extremities.iteritems():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
@ -736,7 +751,7 @@ class EventsStore(SQLBaseStore):
"event_id": ev_id,
"room_id": room_id,
}
for room_id, new_extrem in new_forward_extremities.items()
for room_id, new_extrem in new_forward_extremities.iteritems()
for ev_id in new_extrem
],
)
@ -753,7 +768,7 @@ class EventsStore(SQLBaseStore):
"event_id": event_id,
"stream_ordering": max_stream_order,
}
for room_id, new_extrem in new_forward_extremities.items()
for room_id, new_extrem in new_forward_extremities.iteritems()
for event_id in new_extrem
]
)
@ -807,7 +822,7 @@ class EventsStore(SQLBaseStore):
event.depth, depth_updates.get(event.room_id, event.depth)
)
for room_id, depth in depth_updates.items():
for room_id, depth in depth_updates.iteritems():
self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts):
@ -834,7 +849,7 @@ class EventsStore(SQLBaseStore):
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
for event_id, outlier in txn
}
to_remove = set()
@ -958,14 +973,10 @@ class EventsStore(SQLBaseStore):
return
def event_dict(event):
return {
k: v
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
]
}
d = event.get_dict()
d.pop("redacted", None)
d.pop("redacted_because", None)
return d
self._simple_insert_many_txn(
txn,
@ -1998,7 +2009,7 @@ class EventsStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in curr_state.items()
for key, state_id in curr_state.iteritems()
],
)

View File

@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
),
(current_version,)
)
applied_deltas = [d for d, in txn.fetchall()]
applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded
return None

View File

@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
self._invalidate_cache_and_stream(
txn, self._get_presence_for_user, (state.user_id,)
txn.call_after(
self._get_presence_for_user.invalidate, (state.user_id,)
)
# Actually insert new rows

View File

@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
)
txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall()
if results and topological_ordering:
for to, so, _ in results:
if topological_ordering:
for to, so, _ in txn:
if int(to) > topological_ordering:
return False
elif int(to) == topological_ordering and int(so) >= stream_ordering:

View File

@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn.fetchall())
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)

View File

@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
return dict(txn.fetchall())
return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
results = {}
# A room is visible if its visible on any list.
for room_id, visibility in txn.fetchall():
for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False)
return results

View File

@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
"""
user_ids = yield self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
defer.returnValue(hosts)
@cached(max_entries=500000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
rows = self._get_members_rows_txn(
txn,
room_id=room_id,
membership=Membership.JOIN,
sql = (
"SELECT m.user_id FROM room_memberships as m"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
)
return [r["user_id"] for r in rows]
txn.execute(sql, (room_id, Membership.JOIN,))
return [r[0] for r in txn]
return self.runInteraction("get_users_in_room", f)
@cached()
@ -246,34 +259,6 @@ class RoomMemberStore(SQLBaseStore):
return results
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"
where_values = [room_id]
if membership:
where_clause += " AND m.membership = ?"
where_values.append(membership)
if user_id:
where_clause += " AND m.user_id = ?"
where_values.append(user_id)
sql = (
"SELECT m.* FROM room_memberships as m"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND %(where)s"
) % {
"where": where_clause,
}
txn.execute(sql, where_values)
rows = self.cursor_to_dict(txn)
return rows
@cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id):
"""Returns a set of room_ids the user is currently joined to

View File

@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()}
return {k: v for k, v in txn}
def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU

View File

@ -90,7 +90,7 @@ class StateStore(SQLBaseStore):
event_ids,
)
groups = set(event_to_groups.values())
groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@ -108,17 +108,18 @@ class StateStore(SQLBaseStore):
state_event_map = yield self.get_events(
[
ev_id for group_ids in group_to_ids.values()
for ev_id in group_ids.values()
ev_id for group_ids in group_to_ids.itervalues()
for ev_id in group_ids.itervalues()
],
get_prev_content=False
)
defer.returnValue({
group: [
state_event_map[v] for v in event_id_map.values() if v in state_event_map
state_event_map[v] for v in event_id_map.itervalues()
if v in state_event_map
]
for group, event_id_map in group_to_ids.items()
for group, event_id_map in group_to_ids.iteritems()
})
def _have_persisted_state_group_txn(self, txn, state_group):
@ -190,7 +191,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.items()
for key, state_id in context.delta_ids.iteritems()
],
)
else:
@ -205,7 +206,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.items()
for key, state_id in context.current_state_ids.iteritems()
],
)
@ -217,7 +218,7 @@ class StateStore(SQLBaseStore):
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.items()
for event_id, state_group_id in state_groups.iteritems()
],
)
@ -341,10 +342,10 @@ class StateStore(SQLBaseStore):
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
if types is not None:
where_clause = "AND (%s)" % (
@ -373,12 +374,11 @@ class StateStore(SQLBaseStore):
" WHERE state_group = ? %s" % (where_clause,),
args
)
rows = txn.fetchall()
results[group].update({
(typ, state_key): event_id
for typ, state_key, event_id in rows
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
})
)
# If the lengths match then we must have all the types,
# so no need to go walk further down the tree.
@ -415,21 +415,21 @@ class StateStore(SQLBaseStore):
event_ids,
)
groups = set(event_to_groups.values())
groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
[ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in group_to_state[group].items()
for k, v in group_to_state[group].iteritems()
if v in state_event_map
}
for event_id, group in event_to_groups.items()
for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@ -452,12 +452,12 @@ class StateStore(SQLBaseStore):
event_ids,
)
groups = set(event_to_groups.values())
groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
for event_id, group in event_to_groups.iteritems()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@ -569,7 +569,7 @@ class StateStore(SQLBaseStore):
got_all = not (missing_types or types is None)
return {
k: v for k, v in state_dict_ids.items()
k: v for k, v in state_dict_ids.iteritems()
if include(k[0], k[1])
}, missing_types, got_all
@ -628,7 +628,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in group_to_state_dict.items():
for group, group_state_dict in group_to_state_dict.iteritems():
if types:
# We delibrately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've
@ -643,10 +643,10 @@ class StateStore(SQLBaseStore):
else:
state_dict = results[group]
state_dict.update({
(intern_string(k[0]), intern_string(k[1])): v
for k, v in group_state_dict.items()
})
state_dict.update(
((intern_string(k[0]), intern_string(k[1])), v)
for k, v in group_state_dict.iteritems()
)
self._state_group_cache.update(
cache_seq_num,
@ -657,10 +657,10 @@ class StateStore(SQLBaseStore):
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
for group, state_dict in results.iteritems():
results[group] = {
key: event_id
for key, event_id in state_dict.items()
for key, event_id in state_dict.iteritems()
if event_id
}
@ -749,7 +749,7 @@ class StateStore(SQLBaseStore):
# of keys
delta_state = {
key: value for key, value in curr_state.items()
key: value for key, value in curr_state.iteritems()
if prev_state.get(key, None) != value
}
@ -789,7 +789,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in delta_state.items()
for key, state_id in delta_state.iteritems()
],
)

View File

@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn.fetchall():
for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json))
@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()]
room_ids = [row[0] for row in txn]
return room_ids
changed = self._account_data_stream_cache.has_entity_changed(

View File

@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class DeferredTimedOutError(SynapseError):
def __init__(self):
super(SynapseError).__init__(504, "Timed out")
super(SynapseError, self).__init__(504, "Timed out")
def unwrapFirstError(failure):
@ -93,8 +93,10 @@ class Clock(object):
ret_deferred = defer.Deferred()
def timed_out_fn():
e = DeferredTimedOutError()
try:
ret_deferred.errback(DeferredTimedOutError())
ret_deferred.errback(e)
except:
pass
@ -114,7 +116,7 @@ class Clock(object):
ret_deferred.addBoth(cancel)
def sucess(res):
def success(res):
try:
ret_deferred.callback(res)
except:
@ -128,7 +130,7 @@ class Clock(object):
except:
pass
given_deferred.addCallbacks(callback=sucess, errback=err)
given_deferred.addCallbacks(callback=success, errback=err)
timer = self.call_later(time_out, timed_out_fn)

View File

@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, **kwargs):
def get_retry_limiter(destination, clock, store, ignore_backoff=False,
**kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
that will mark the destination as down if an exception is thrown (excluding
CodeMessageException with code < 500)
Args:
destination (str): name of homeserver
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. We will still update the next
retry_interval on success/failure.
Example usage:
try:
@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
now = int(clock.time_msec())
if retry_last_ts + retry_interval > now:
if not ignore_backoff and retry_last_ts + retry_interval > now:
raise NotRetryingDestination(
retry_last_ts=retry_last_ts,
retry_interval=retry_interval,
@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException):
if exc_type is None:
valid_err_code = True
elif not issubclass(exc_type, Exception):
# avoid treating exceptions which don't derive from Exception as
# failures; this is mostly so as not to catch defer._DefGen.
valid_err_code = True
elif issubclass(exc_type, CodeMessageException):
# Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS
@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
else:
valid_err_code = False
if exc_type is None or valid_err_code:
if valid_err_code:
# We connected successfully.
if not self.retry_interval:
return
logger.debug("Connection to %s was successful; clearing backoff",
self.destination)
retry_last_ts = 0
self.retry_interval = 0
else:
@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
else:
self.retry_interval = self.min_retry_interval
logger.debug(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
self.destination, exc_type, exc_val, self.retry_interval
)
retry_last_ts = int(self.clock.time_msec())
@defer.inlineCallbacks

View File

@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave"
# Always allow the user to see their own leave events, otherwise
# they won't see the room disappear if they reject the invite
if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite"
):
return True
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority:

View File

@ -23,6 +23,9 @@ from tests.utils import (
from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
from synapse.api.errors import SynapseError
import jsonschema
user_localpart = "test_user"
@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase):
self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
{"event_fields": ["\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
{"presence": {"senders": ["@bar;pik.test.com"]}}
]
for filter in invalid_filters:
with self.assertRaises(SynapseError) as check_filter_error:
self.filtering.check_valid_filter(filter)
self.assertIsInstance(check_filter_error.exception, SynapseError)
def test_valid_filters(self):
valid_filters = [
{
"room": {
"timeline": {"limit": 20},
"state": {"not_types": ["m.room.member"]},
"ephemeral": {"limit": 0, "not_types": ["*"]},
"include_leave": False,
"rooms": ["!dee:pik-test"],
"not_rooms": ["!gee:pik-test"],
"account_data": {"limit": 0, "types": ["*"]}
}
},
{
"room": {
"state": {
"types": ["m.room.*"],
"not_rooms": ["!726s6s6q:example.com"]
},
"timeline": {
"limit": 10,
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
}
},
"presence": {
"types": ["m.presence"],
"not_senders": ["@alice:example.com"]
},
"event_format": "client",
"event_fields": ["type", "content", "sender"]
}
]
for filter in valid_filters:
try:
self.filtering.check_valid_filter(filter)
except jsonschema.ValidationError as e:
self.fail(e)
def test_limits_are_applied(self):
# TODO
pass
def test_definition_types_works_with_literals(self):
definition = {
"types": ["m.room.message", "org.matrix.foo.bar"]

View File

@ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase):
"room_alias": "#another:remote",
},
retry_on_dns_fail=False,
ignore_backoff=True,
)
@defer.inlineCallbacks

View File

@ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase):
self.mock_federation.make_query.assert_called_with(
destination="remote",
query_type="profile",
args={"user_id": "@alice:remote", "field": "displayname"}
args={"user_id": "@alice:remote", "field": "displayname"},
ignore_backoff=True,
)
@defer.inlineCallbacks

View File

@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
),
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
),
defer.succeed((200, "OK"))
)
@ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
),
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
),
defer.succeed((200, "OK"))
)

View File

@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test"
EXAMPLE_FILTER = {"type": ["m.*"]}
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter]
@defer.inlineCallbacks

View File

@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_select_one_1col(self):
self.mock_txn.rowcount = 1
self.mock_txn.fetchall.return_value = [("Value",)]
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore._simple_select_one_onecol(
table="tablename",
@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_select_list(self):
self.mock_txn.rowcount = 3
self.mock_txn.fetchall.return_value = ((1,), (2,), (3,))
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (
("colA", None, None, None, None, None, None),
)

33
tests/util/test_clock.py Normal file
View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse import util
from twisted.internet import defer
from tests import unittest
class ClockTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_time_bound_deferred(self):
# just a deferred which never resolves
slow_deferred = defer.Deferred()
clock = util.Clock()
time_bound = clock.time_bound_deferred(slow_deferred, 0.001)
try:
yield time_bound
self.fail("Expected timedout error, but got nothing")
except util.DeferredTimedOutError:
pass