diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py new file mode 100755 index 000000000..f0d0cd8a3 --- /dev/null +++ b/scripts-dev/definitions.py @@ -0,0 +1,142 @@ +#! /usr/bin/python + +import ast +import yaml + +class DefinitionVisitor(ast.NodeVisitor): + def __init__(self): + super(DefinitionVisitor, self).__init__() + self.functions = {} + self.classes = {} + self.names = {} + self.attrs = set() + self.definitions = { + 'def': self.functions, + 'class': self.classes, + 'names': self.names, + 'attrs': self.attrs, + } + + def visit_Name(self, node): + self.names.setdefault(type(node.ctx).__name__, set()).add(node.id) + + def visit_Attribute(self, node): + self.attrs.add(node.attr) + for child in ast.iter_child_nodes(node): + self.visit(child) + + def visit_ClassDef(self, node): + visitor = DefinitionVisitor() + self.classes[node.name] = visitor.definitions + for child in ast.iter_child_nodes(node): + visitor.visit(child) + + def visit_FunctionDef(self, node): + visitor = DefinitionVisitor() + self.functions[node.name] = visitor.definitions + for child in ast.iter_child_nodes(node): + visitor.visit(child) + + +def non_empty(defs): + functions = {name: non_empty(f) for name, f in defs['def'].items()} + classes = {name: non_empty(f) for name, f in defs['class'].items()} + result = {} + if functions: result['def'] = functions + if classes: result['class'] = classes + names = defs['names'] + uses = [] + for name in names.get('Load', ()): + if name not in names.get('Param', ()) and name not in names.get('Store', ()): + uses.append(name) + uses.extend(defs['attrs']) + if uses: result['uses'] = uses + result['names'] = names + result['attrs'] = defs['attrs'] + return result + + +def definitions_in_code(input_code): + input_ast = ast.parse(input_code) + visitor = DefinitionVisitor() + visitor.visit(input_ast) + definitions = non_empty(visitor.definitions) + return definitions + + +def definitions_in_file(filepath): + with open(filepath) as f: + return definitions_in_code(f.read()) + + +def defined_names(prefix, defs, names): + for name, funcs in defs.get('def', {}).items(): + names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + defined_names(prefix + name + ".", funcs, names) + + for name, funcs in defs.get('class', {}).items(): + names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + defined_names(prefix + name + ".", funcs, names) + + +def used_names(prefix, defs, names): + for name, funcs in defs.get('def', {}).items(): + used_names(prefix + name + ".", funcs, names) + + for name, funcs in defs.get('class', {}).items(): + used_names(prefix + name + ".", funcs, names) + + for used in defs.get('uses', ()): + if used in names: + names[used].setdefault('used', []).append(prefix.rstrip('.')) + + +if __name__ == '__main__': + import sys, os, argparse, re + + parser = argparse.ArgumentParser(description='Find definitions.') + parser.add_argument( + "--unused", action="store_true", help="Only list unused definitions" + ) + parser.add_argument( + "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" + ) + parser.add_argument( + "--pattern", action="append", metavar="REGEXP", + help="Search for a pattern" + ) + parser.add_argument( + "directories", nargs='+', metavar="DIR", + help="Directories to search for definitions" + ) + args = parser.parse_args() + + definitions = {} + for directory in args.directories: + for root, dirs, files in os.walk(directory): + for filename in files: + if filename.endswith(".py"): + filepath = os.path.join(root, filename) + definitions[filepath] = definitions_in_file(filepath) + + names = {} + for filepath, defs in definitions.items(): + defined_names(filepath + ":", defs, names) + + for filepath, defs in definitions.items(): + used_names(filepath + ":", defs, names) + + patterns = [re.compile(pattern) for pattern in args.pattern or ()] + ignore = [re.compile(pattern) for pattern in args.ignore or ()] + + result = {} + for name, definition in names.items(): + if patterns and not any(pattern.match(name) for pattern in patterns): + continue + if ignore and any(pattern.match(name) for pattern in ignore): + continue + if args.unused and definition.get('used'): + continue + result[name] = definition + + yaml.dump(result, sys.stdout, default_flow_style=False) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 6aba72e45..62515997b 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -95,8 +95,6 @@ class Store(object): _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] - _execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"] - def runInteraction(self, desc, func, *args, **kwargs): def r(conn): try: diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c3b4d971a..ee3045268 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -77,11 +77,6 @@ class SynapseError(CodeMessageException): ) -class RoomError(SynapseError): - """An error raised when a room event fails.""" - pass - - class RegistrationError(SynapseError): """An error raised when a registration event fails.""" pass diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 21840e4a2..190b03e2f 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -85,12 +85,6 @@ import time logger = logging.getLogger("synapse.app.homeserver") -class GzipFile(File): - def getChild(self, path, request): - child = File.getChild(self, path, request) - return EncodingResourceWrapper(child, [GzipEncoderFactory()]) - - def gz_wrap(r): return EncodingResourceWrapper(r, [GzipEncoderFactory()]) @@ -134,6 +128,7 @@ class SynapseHomeServer(HomeServer): # (It can stay enabled for the API resources: they call # write() with the whole body and then finish() straight # after and so do not trigger the bug. + # GzipFile was removed in commit 184ba09 # return GzipFile(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable? diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4ff20599d..f4dce712f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1456,52 +1456,3 @@ class FederationHandler(BaseHandler): }, "missing": [e.event_id for e in missing_locals], }) - - @defer.inlineCallbacks - def _handle_auth_events(self, origin, auth_events): - auth_ids_to_deferred = {} - - def process_auth_ev(ev): - auth_ids = [e_id for e_id, _ in ev.auth_events] - - prev_ds = [ - auth_ids_to_deferred[i] - for i in auth_ids - if i in auth_ids_to_deferred - ] - - d = defer.Deferred() - - auth_ids_to_deferred[ev.event_id] = d - - @defer.inlineCallbacks - def f(*_): - ev.internal_metadata.outlier = True - - try: - auth = { - (e.type, e.state_key): e for e in auth_events - if e.event_id in auth_ids - } - - yield self._handle_new_event( - origin, ev, auth_events=auth - ) - except: - logger.exception( - "Failed to handle auth event %s", - ev.event_id, - ) - - d.callback(None) - - if prev_ds: - dx = defer.DeferredList(prev_ds) - dx.addBoth(f) - else: - f() - - for e in auth_events: - process_auth_ev(e) - - yield defer.DeferredList(auth_ids_to_deferred.values()) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bb5eef6bb..773f0a2e9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -492,32 +492,6 @@ class RoomMemberHandler(BaseHandler): "user_joined_room", user=user, room_id=room_id ) - @defer.inlineCallbacks - def _should_invite_join(self, room_id, prev_state, do_auth): - logger.debug("_should_invite_join: room_id: %s", room_id) - - # XXX: We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - - # Only do an invite join dance if a) we were invited, - # b) the person inviting was from a differnt HS and c) we are - # not currently in the room - room_host = None - if prev_state and prev_state.membership == Membership.INVITE: - room = yield self.store.get_room(room_id) - inviter = UserID.from_string( - prev_state.sender - ) - - is_remote_invite_join = not self.hs.is_mine(inviter) and not room - room_host = inviter.domain - else: - is_remote_invite_join = False - - defer.returnValue((is_remote_invite_join, room_host)) - @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): """Returns a list of roomids that the user has any of the given diff --git a/synapse/state.py b/synapse/state.py index ed36f844c..bb225c39c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -31,10 +31,6 @@ import hashlib logger = logging.getLogger(__name__) -def _get_state_key_from_event(event): - return event.state_key - - KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 495ef087c..693784ad3 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple - import sys import time import threading @@ -376,9 +374,6 @@ class SQLBaseStore(object): return self.runInteraction(desc, interaction) - def _execute_and_decode(self, desc, query, *args): - return self._execute(desc, self.cursor_to_dict, query, *args) - # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. @@ -691,37 +686,6 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, - retcols=None, allow_none=False, - desc="_simple_selectupdate_one"): - """ Combined SELECT then UPDATE.""" - def func(txn): - ret = None - if retcols: - ret = self._simple_select_one_txn( - txn, - table=table, - keyvalues=keyvalues, - retcols=retcols, - allow_none=allow_none, - ) - - if updatevalues: - self._simple_update_one_txn( - txn, - table=table, - keyvalues=keyvalues, - updatevalues=updatevalues, - ) - - # if txn.rowcount == 0: - # raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched") - - return ret - return self.runInteraction(desc, func) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -743,16 +707,6 @@ class SQLBaseStore(object): raise StoreError(500, "more than one row matched") return self.runInteraction(desc, func) - def _simple_delete(self, table, keyvalues, desc="_simple_delete"): - """Executes a DELETE query on the named table. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - - return self.runInteraction(desc, self._simple_delete_txn) - def _simple_delete_txn(self, txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, @@ -761,24 +715,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def _simple_max_id(self, table): - """Executes a SELECT query on the named table, expecting to return the - max value for the column "id". - - Args: - table : string giving the table name - """ - sql = "SELECT MAX(id) AS id FROM %s" % table - - def func(txn): - txn.execute(sql) - max_id = self.cursor_to_dict(txn)[0]["id"] - if max_id is None: - return 0 - return max_id - - return self.runInteraction("_simple_max_id", func) - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id @@ -791,129 +727,3 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass - - -class Table(object): - """ A base class used to store information about a particular table. - """ - - table_name = None - """ str: The name of the table """ - - fields = None - """ list: The field names """ - - EntryType = None - """ Type: A tuple type used to decode the results """ - - _select_where_clause = "SELECT %s FROM %s WHERE %s" - _select_clause = "SELECT %s FROM %s" - _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)" - - @classmethod - def select_statement(cls, where_clause=None): - """ - Args: - where_clause (str): The WHERE clause to use. - - Returns: - str: An SQL statement to select rows from the table with the given - WHERE clause. - """ - if where_clause: - return cls._select_where_clause % ( - ", ".join(cls.fields), - cls.table_name, - where_clause - ) - else: - return cls._select_clause % ( - ", ".join(cls.fields), - cls.table_name, - ) - - @classmethod - def insert_statement(cls): - return cls._insert_clause % ( - cls.table_name, - ", ".join(cls.fields), - ", ".join(["?"] * len(cls.fields)), - ) - - @classmethod - def decode_single_result(cls, results): - """ Given an iterable of tuples, return a single instance of - `EntryType` or None if the iterable is empty - Args: - results (list): The results list to convert to `EntryType` - Returns: - EntryType: An instance of `EntryType` - """ - results = list(results) - if results: - return cls.EntryType(*results[0]) - else: - return None - - @classmethod - def decode_results(cls, results): - """ Given an iterable of tuples, return a list of `EntryType` - Args: - results (list): The results list to convert to `EntryType` - - Returns: - list: A list of `EntryType` - """ - return [cls.EntryType(*row) for row in results] - - @classmethod - def get_fields_string(cls, prefix=None): - if prefix: - to_join = ("%s.%s" % (prefix, f) for f in cls.fields) - else: - to_join = cls.fields - - return ", ".join(to_join) - - -class JoinHelper(object): - """ Used to help do joins on tables by looking at the tables' fields and - creating a list of unique fields to use with SELECTs and a namedtuple - to dump the results into. - - Attributes: - tables (list): List of `Table` classes - EntryType (type) - """ - - def __init__(self, *tables): - self.tables = tables - - res = [] - for table in self.tables: - res += [f for f in table.fields if f not in res] - - self.EntryType = namedtuple("JoinHelperEntry", res) - - def get_fields(self, **prefixes): - """Get a string representing a list of fields for use in SELECT - statements with the given prefixes applied to each. - - For example:: - - JoinHelper(PdusTable, StateTable).get_fields( - PdusTable="pdus", - StateTable="state" - ) - """ - res = [] - for field in self.EntryType._fields: - for table in self.tables: - if field in table.fields: - res.append("%s.%s" % (prefixes[table.__name__], field)) - break - - return ", ".join(res) - - def decode_results(self, rows): - return [self.EntryType(*row) for row in rows] diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index c1cabbaa6..6d4421dd8 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -154,98 +154,6 @@ class EventFederationStore(SQLBaseStore): return results - def _get_latest_state_in_room(self, txn, room_id, type, state_key): - event_ids = self._simple_select_onecol_txn( - txn, - table="state_forward_extremities", - keyvalues={ - "room_id": room_id, - "type": type, - "state_key": state_key, - }, - retcol="event_id", - ) - - results = [] - for event_id in event_ids: - hashes = self._get_event_reference_hashes_txn(txn, event_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((event_id, prev_hashes)) - - return results - - def _get_prev_events(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=0, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_state(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=True, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_events_and_state(self, txn, event_id, is_state=None): - keyvalues = { - "event_id": event_id, - } - - if is_state is not None: - keyvalues["is_state"] = bool(is_state) - - res = self._simple_select_list_txn( - txn, - table="event_edges", - keyvalues=keyvalues, - retcols=["prev_event_id", "is_state"], - ) - - hashes = self._get_prev_event_hashes_txn(txn, event_id) - - results = [] - for d in res: - edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"]) - edge_hash.update(hashes.get(d["prev_event_id"], {})) - prev_hashes = { - k: encode_base64(v) - for k, v in edge_hash.items() - if k == "sha256" - } - results.append((d["prev_event_id"], prev_hashes, d["is_state"])) - - return results - - def _get_auth_events(self, txn, event_id): - auth_ids = self._simple_select_onecol_txn( - txn, - table="event_auth", - keyvalues={ - "event_id": event_id, - }, - retcol="auth_id", - ) - - results = [] - for auth_id in auth_ids: - hashes = self._get_event_reference_hashes_txn(txn, auth_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((auth_id, prev_hashes)) - - return results - def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 46df6b4d6..416ef6af9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -890,22 +890,11 @@ class EventsStore(SQLBaseStore): return ev - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - def _parse_events_txn(self, txn, rows): event_ids = [r["event_id"] for r in rows] return self._get_events_txn(txn, event_ids) - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None - @defer.inlineCallbacks def count_daily_messages(self): """ diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 00b748f13..345c4e110 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table +from ._base import SQLBaseStore from twisted.internet import defer from synapse.api.errors import StoreError @@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore): ) -class PushersTable(Table): +class PushersTable(object): table_name = "pushers" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 41c939efb..8c40d9a8a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -178,12 +178,6 @@ class RoomMemberStore(SQLBaseStore): return joined_domains - def _get_members_query(self, where_clause, where_values): - return self.runInteraction( - "get_members_query", self._get_members_events_txn, - where_clause, where_values - ).addCallbacks(self._get_events) - def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): rows = self._get_members_rows_txn( txn, diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index ab57b9217..b070be504 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" - def _get_event_content_hashes_txn(self, txn, event_id): - """Get all the hashes for a given Event. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of algorithm -> hash. - """ - query = ( - "SELECT algorithm, hash" - " FROM event_content_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - return dict(txn.fetchall()) - - def _store_event_content_hash_txn(self, txn, event_id, algorithm, - hash_bytes): - """Store a hash for a Event - Args: - txn (cursor): - event_id (str): Id for the Event. - algorithm (str): Hashing algorithm. - hash_bytes (bytes): Hash function output bytes. - """ - self._simple_insert_txn( - txn, - "event_content_hashes", - { - "event_id": event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) - def get_event_reference_hashes(self, event_ids): def f(txn): return [ @@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore): table="event_reference_hashes", values=vals, ) - - def _get_event_signatures_txn(self, txn, event_id): - """Get all the signatures for a given PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of sig name -> dict(key_id -> signature_bytes) - """ - query = ( - "SELECT signature_name, key_id, signature" - " FROM event_signatures" - " WHERE event_id = ? " - ) - txn.execute(query, (event_id, )) - rows = txn.fetchall() - - res = {} - - for name, key, sig in rows: - res.setdefault(name, {})[key] = sig - - return res - - def _store_event_signature_txn(self, txn, event_id, signature_name, key_id, - signature_bytes): - """Store a signature from the origin server for a PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - origin (str): origin of the Event. - key_id (str): Id for the signing key. - signature (bytes): The signature. - """ - self._simple_insert_txn( - txn, - "event_signatures", - { - "event_id": event_id, - "signature_name": signature_name, - "key_id": key_id, - "signature": buffer(signature_bytes), - }, - ) - - def _get_prev_event_hashes_txn(self, txn, event_id): - """Get all the hashes for previous PDUs of a PDU - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. - """ - query = ( - "SELECT prev_event_id, algorithm, hash" - " FROM event_edge_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - results = {} - for prev_event_id, algorithm, hash_bytes in txn.fetchall(): - hashes = results.setdefault(prev_event_id, {}) - hashes[algorithm] = hash_bytes - return results - - def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, - algorithm, hash_bytes): - self._simple_insert_txn( - txn, - "event_edge_hashes", - { - "event_id": event_id, - "prev_event_id": prev_event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 9630efcfc..e935b9443 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import ( from twisted.internet import defer -from synapse.util.stringutils import random_string - import logging logger = logging.getLogger(__name__) @@ -428,7 +426,3 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) - - -def _make_group_id(clock): - return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index aaa3609aa..699083ae1 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource -class NullSource(object): - """This event source never yields any events and its token remains at - zero. It may be useful for unit-testing.""" - def __init__(self, hs): - pass - - def get_new_events_for_user(self, user, from_key, limit): - return defer.succeed(([], from_key)) - - def get_current_key(self, direction='f'): - return defer.succeed(0) - - def get_pagination_rows(self, user, pagination_config, key): - return defer.succeed(([], pagination_config.from_key)) - - class EventSources(object): SOURCE_TYPES = { "room": RoomEventSource, @@ -70,15 +54,3 @@ class EventSources(object): ), ) defer.returnValue(token) - - -class StreamSource(object): - def get_new_events_for_user(self, user, from_key, limit): - """from_key is the key within this event source.""" - raise NotImplementedError("get_new_events_for_user") - - def get_current_key(self): - raise NotImplementedError("get_current_key") - - def get_pagination_rows(self, user, pagination_config, key): - raise NotImplementedError("get_rows") diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 07ff25cef..1d123ccef 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -29,34 +29,6 @@ def unwrapFirstError(failure): return failure.value.subFailure -def unwrap_deferred(d): - """Given a deferred that we know has completed, return its value or raise - the failure as an exception - """ - if not d.called: - raise RuntimeError("deferred has not finished") - - res = [] - - def f(r): - res.append(r) - return r - d.addCallback(f) - - if res: - return res[0] - - def f(r): - res.append(r) - return r - d.addErrback(f) - - if res: - res[0].raiseException() - else: - raise RuntimeError("deferred did not call callbacks") - - class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 2ee3da0b3..29d9bbaad 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -41,6 +41,22 @@ myid = "@apple:test" PATH_PREFIX = "/_matrix/client/api/v1" +class NullSource(object): + """This event source never yields any events and its token remains at + zero. It may be useful for unit-testing.""" + def __init__(self, hs): + pass + + def get_new_events_for_user(self, user, from_key, limit): + return defer.succeed(([], from_key)) + + def get_current_key(self, direction='f'): + return defer.succeed(0) + + def get_pagination_rows(self, user, pagination_config, key): + return defer.succeed(([], pagination_config.from_key)) + + class JustPresenceHandlers(object): def __init__(self, hs): self.presence_handler = PresenceHandler(hs) @@ -243,7 +259,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): # HIDEOUS HACKERY # TODO(paul): This should be injected in via the HomeServer DI system from synapse.streams.events import ( - PresenceEventSource, NullSource, EventSources + PresenceEventSource, EventSources ) old_SOURCE_TYPES = EventSources.SOURCE_TYPES diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 8573f18b5..1ddca1da4 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -185,26 +185,6 @@ class SQLBaseStoreTestCase(unittest.TestCase): [3, 4, 1, 2] ) - @defer.inlineCallbacks - def test_update_one_with_return(self): - self.mock_txn.rowcount = 1 - self.mock_txn.fetchone.return_value = ("Old Value",) - - ret = yield self.datastore._simple_selectupdate_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columname": "New Value"}, - retcols=["columname"] - ) - - self.assertEquals({"columname": "Old Value"}, ret) - self.mock_txn.execute.assert_has_calls([ - call('SELECT columname FROM tablename WHERE keycol = ?', - ['TheKey']), - call("UPDATE tablename SET columname = ? WHERE keycol = ?", - ["New Value", "TheKey"]) - ]) - @defer.inlineCallbacks def test_delete_one(self): self.mock_txn.rowcount = 1