mirror of
				https://git.anonymousland.org/anonymousland/synapse.git
				synced 2025-11-04 01:04:02 -05:00 
			
		
		
		
	Remove concept of a non-limited stream. (#7011)
This commit is contained in:
		
							parent
							
								
									caec7d4fa0
								
							
						
					
					
						commit
						fdb1344716
					
				
					 8 changed files with 72 additions and 68 deletions
				
			
		
							
								
								
									
										1
									
								
								changelog.d/7011.misc
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/7011.misc
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Remove concept of a non-limited stream.
 | 
			
		||||
| 
						 | 
				
			
			@ -747,7 +747,7 @@ class PresenceHandler(object):
 | 
			
		|||
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    async def get_all_presence_updates(self, last_id, current_id):
 | 
			
		||||
    async def get_all_presence_updates(self, last_id, current_id, limit):
 | 
			
		||||
        """
 | 
			
		||||
        Gets a list of presence update rows from between the given stream ids.
 | 
			
		||||
        Each row has:
 | 
			
		||||
| 
						 | 
				
			
			@ -762,7 +762,7 @@ class PresenceHandler(object):
 | 
			
		|||
        """
 | 
			
		||||
        # TODO(markjh): replicate the unpersisted changes.
 | 
			
		||||
        # This could use the in-memory stores for recent changes.
 | 
			
		||||
        rows = await self.store.get_all_presence_updates(last_id, current_id)
 | 
			
		||||
        rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
 | 
			
		||||
        return rows
 | 
			
		||||
 | 
			
		||||
    def notify_new_event(self):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,6 +15,7 @@
 | 
			
		|||
 | 
			
		||||
import logging
 | 
			
		||||
from collections import namedtuple
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -257,7 +258,13 @@ class TypingHandler(object):
 | 
			
		|||
            "typing_key", self._latest_room_serial, rooms=[member.room_id]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def get_all_typing_updates(self, last_id, current_id):
 | 
			
		||||
    async def get_all_typing_updates(
 | 
			
		||||
        self, last_id: int, current_id: int, limit: int
 | 
			
		||||
    ) -> List[dict]:
 | 
			
		||||
        """Get up to `limit` typing updates between the given tokens, earliest
 | 
			
		||||
        updates first.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if last_id == current_id:
 | 
			
		||||
            return []
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -275,7 +282,7 @@ class TypingHandler(object):
 | 
			
		|||
                typing = self._room_typing[room_id]
 | 
			
		||||
                rows.append((serial, room_id, list(typing)))
 | 
			
		||||
        rows.sort()
 | 
			
		||||
        return rows
 | 
			
		||||
        return rows[:limit]
 | 
			
		||||
 | 
			
		||||
    def get_current_token(self):
 | 
			
		||||
        return self._latest_room_serial
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -166,11 +166,6 @@ class ReplicationStreamer(object):
 | 
			
		|||
                self.pending_updates = False
 | 
			
		||||
 | 
			
		||||
                with Measure(self.clock, "repl.stream.get_updates"):
 | 
			
		||||
                    # First we tell the streams that they should update their
 | 
			
		||||
                    # current tokens.
 | 
			
		||||
                    for stream in self.streams:
 | 
			
		||||
                        stream.advance_current_token()
 | 
			
		||||
 | 
			
		||||
                    all_streams = self.streams
 | 
			
		||||
 | 
			
		||||
                    if self._replication_torture_level is not None:
 | 
			
		||||
| 
						 | 
				
			
			@ -180,7 +175,7 @@ class ReplicationStreamer(object):
 | 
			
		|||
                        random.shuffle(all_streams)
 | 
			
		||||
 | 
			
		||||
                    for stream in all_streams:
 | 
			
		||||
                        if stream.last_token == stream.upto_token:
 | 
			
		||||
                        if stream.last_token == stream.current_token():
 | 
			
		||||
                            continue
 | 
			
		||||
 | 
			
		||||
                        if self._replication_torture_level:
 | 
			
		||||
| 
						 | 
				
			
			@ -192,7 +187,7 @@ class ReplicationStreamer(object):
 | 
			
		|||
                            "Getting stream: %s: %s -> %s",
 | 
			
		||||
                            stream.NAME,
 | 
			
		||||
                            stream.last_token,
 | 
			
		||||
                            stream.upto_token,
 | 
			
		||||
                            stream.current_token(),
 | 
			
		||||
                        )
 | 
			
		||||
                        try:
 | 
			
		||||
                            updates, current_token = await stream.get_updates()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,10 +17,12 @@
 | 
			
		|||
import itertools
 | 
			
		||||
import logging
 | 
			
		||||
from collections import namedtuple
 | 
			
		||||
from typing import Any, List, Optional
 | 
			
		||||
from typing import Any, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import attr
 | 
			
		||||
 | 
			
		||||
from synapse.types import JsonDict
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -119,13 +121,12 @@ class Stream(object):
 | 
			
		|||
    """Base class for the streams.
 | 
			
		||||
 | 
			
		||||
    Provides a `get_updates()` function that returns new updates since the last
 | 
			
		||||
    time it was called up until the point `advance_current_token` was called.
 | 
			
		||||
    time it was called.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    NAME = None  # type: str  # The name of the stream
 | 
			
		||||
    # The type of the row. Used by the default impl of parse_row.
 | 
			
		||||
    ROW_TYPE = None  # type: Any
 | 
			
		||||
    _LIMITED = True  # Whether the update function takes a limit
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def parse_row(cls, row):
 | 
			
		||||
| 
						 | 
				
			
			@ -146,26 +147,15 @@ class Stream(object):
 | 
			
		|||
        # The token from which we last asked for updates
 | 
			
		||||
        self.last_token = self.current_token()
 | 
			
		||||
 | 
			
		||||
        # The token that we will get updates up to
 | 
			
		||||
        self.upto_token = self.current_token()
 | 
			
		||||
 | 
			
		||||
    def advance_current_token(self):
 | 
			
		||||
        """Updates `upto_token` to "now", which updates up until which point
 | 
			
		||||
        get_updates[_since] will fetch rows till.
 | 
			
		||||
        """
 | 
			
		||||
        self.upto_token = self.current_token()
 | 
			
		||||
 | 
			
		||||
    def discard_updates_and_advance(self):
 | 
			
		||||
        """Called when the stream should advance but the updates would be discarded,
 | 
			
		||||
        e.g. when there are no currently connected workers.
 | 
			
		||||
        """
 | 
			
		||||
        self.upto_token = self.current_token()
 | 
			
		||||
        self.last_token = self.upto_token
 | 
			
		||||
        self.last_token = self.current_token()
 | 
			
		||||
 | 
			
		||||
    async def get_updates(self):
 | 
			
		||||
        """Gets all updates since the last time this function was called (or
 | 
			
		||||
        since the stream was constructed if it hadn't been called before),
 | 
			
		||||
        until the `upto_token`
 | 
			
		||||
        since the stream was constructed if it hadn't been called before).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Deferred[Tuple[List[Tuple[int, Any]], int]]:
 | 
			
		||||
| 
						 | 
				
			
			@ -178,44 +168,45 @@ class Stream(object):
 | 
			
		|||
 | 
			
		||||
        return updates, current_token
 | 
			
		||||
 | 
			
		||||
    async def get_updates_since(self, from_token):
 | 
			
		||||
    async def get_updates_since(
 | 
			
		||||
        self, from_token: int
 | 
			
		||||
    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
 | 
			
		||||
        """Like get_updates except allows specifying from when we should
 | 
			
		||||
        stream updates
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Deferred[Tuple[List[Tuple[int, Any]], int]]:
 | 
			
		||||
                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
 | 
			
		||||
                list of ``(token, row)`` entries. ``row`` will be json-serialised and
 | 
			
		||||
                sent over the replication stream.
 | 
			
		||||
            Resolves to a pair `(updates, new_last_token)`, where `updates` is
 | 
			
		||||
            a list of `(token, row)` entries and `new_last_token` is the new
 | 
			
		||||
            position in stream.
 | 
			
		||||
        """
 | 
			
		||||
        if from_token in ("NOW", "now"):
 | 
			
		||||
            return [], self.upto_token
 | 
			
		||||
 | 
			
		||||
        current_token = self.upto_token
 | 
			
		||||
        if from_token in ("NOW", "now"):
 | 
			
		||||
            return [], self.current_token()
 | 
			
		||||
 | 
			
		||||
        current_token = self.current_token()
 | 
			
		||||
 | 
			
		||||
        from_token = int(from_token)
 | 
			
		||||
 | 
			
		||||
        if from_token == current_token:
 | 
			
		||||
            return [], current_token
 | 
			
		||||
 | 
			
		||||
        logger.info("get_updates_since: %s", self.__class__)
 | 
			
		||||
        if self._LIMITED:
 | 
			
		||||
            rows = await self.update_function(
 | 
			
		||||
                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
 | 
			
		||||
            )
 | 
			
		||||
        rows = await self.update_function(
 | 
			
		||||
            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
 | 
			
		||||
            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
 | 
			
		||||
        else:
 | 
			
		||||
            rows = await self.update_function(from_token, current_token)
 | 
			
		||||
        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
 | 
			
		||||
        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
 | 
			
		||||
 | 
			
		||||
        updates = [(row[0], row[1:]) for row in rows]
 | 
			
		||||
 | 
			
		||||
        # check we didn't get more rows than the limit.
 | 
			
		||||
        # doing it like this allows the update_function to be a generator.
 | 
			
		||||
        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
 | 
			
		||||
        if len(updates) >= MAX_EVENTS_BEHIND:
 | 
			
		||||
            raise Exception("stream %s has fallen behind" % (self.NAME))
 | 
			
		||||
 | 
			
		||||
        # The update function didn't hit the limit, so we must have got all
 | 
			
		||||
        # the updates to `current_token`, and can return that as our new
 | 
			
		||||
        # stream position.
 | 
			
		||||
        return updates, current_token
 | 
			
		||||
 | 
			
		||||
    def current_token(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -227,9 +218,8 @@ class Stream(object):
 | 
			
		|||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def update_function(self, from_token, current_token, limit=None):
 | 
			
		||||
        """Get updates between from_token and to_token. If Stream._LIMITED is
 | 
			
		||||
        True then limit is provided, otherwise it's not.
 | 
			
		||||
    def update_function(self, from_token, current_token, limit):
 | 
			
		||||
        """Get updates between from_token and current_token.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Deferred(list(tuple)): the first entry in the tuple is the token for
 | 
			
		||||
| 
						 | 
				
			
			@ -257,7 +247,6 @@ class BackfillStream(Stream):
 | 
			
		|||
 | 
			
		||||
class PresenceStream(Stream):
 | 
			
		||||
    NAME = "presence"
 | 
			
		||||
    _LIMITED = False
 | 
			
		||||
    ROW_TYPE = PresenceStreamRow
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
| 
						 | 
				
			
			@ -272,7 +261,6 @@ class PresenceStream(Stream):
 | 
			
		|||
 | 
			
		||||
class TypingStream(Stream):
 | 
			
		||||
    NAME = "typing"
 | 
			
		||||
    _LIMITED = False
 | 
			
		||||
    ROW_TYPE = TypingStreamRow
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
| 
						 | 
				
			
			@ -372,7 +360,6 @@ class DeviceListsStream(Stream):
 | 
			
		|||
    """
 | 
			
		||||
 | 
			
		||||
    NAME = "device_lists"
 | 
			
		||||
    _LIMITED = False
 | 
			
		||||
    ROW_TYPE = DeviceListsStreamRow
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
| 
						 | 
				
			
			@ -462,7 +449,6 @@ class UserSignatureStream(Stream):
 | 
			
		|||
    """
 | 
			
		||||
 | 
			
		||||
    NAME = "user_signature"
 | 
			
		||||
    _LIMITED = False
 | 
			
		||||
    ROW_TYPE = UserSignatureStreamRow
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -576,7 +576,7 @@ class DeviceWorkerStore(SQLBaseStore):
 | 
			
		|||
            return set()
 | 
			
		||||
 | 
			
		||||
    async def get_all_device_list_changes_for_remotes(
 | 
			
		||||
        self, from_key: int, to_key: int
 | 
			
		||||
        self, from_key: int, to_key: int, limit: int,
 | 
			
		||||
    ) -> List[Tuple[int, str]]:
 | 
			
		||||
        """Return a list of `(stream_id, entity)` which is the combined list of
 | 
			
		||||
        changes to devices and which destinations need to be poked. Entity is
 | 
			
		||||
| 
						 | 
				
			
			@ -592,10 +592,16 @@ class DeviceWorkerStore(SQLBaseStore):
 | 
			
		|||
                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
 | 
			
		||||
            ) AS e
 | 
			
		||||
            WHERE ? < stream_id AND stream_id <= ?
 | 
			
		||||
            LIMIT ?
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        return await self.db.execute(
 | 
			
		||||
            "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
 | 
			
		||||
            "get_all_device_list_changes_for_remotes",
 | 
			
		||||
            None,
 | 
			
		||||
            sql,
 | 
			
		||||
            from_key,
 | 
			
		||||
            to_key,
 | 
			
		||||
            limit,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @cached(max_entries=10000)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 | 
			
		|||
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
 | 
			
		||||
    def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
 | 
			
		||||
        """Return a list of changes from the user signature stream to notify remotes.
 | 
			
		||||
        Note that the user signature stream represents when a user signs their
 | 
			
		||||
        device with their user-signing key, which is not published to other
 | 
			
		||||
| 
						 | 
				
			
			@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 | 
			
		|||
            Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
 | 
			
		||||
        """
 | 
			
		||||
        sql = """
 | 
			
		||||
            SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
 | 
			
		||||
            SELECT stream_id, from_user_id AS user_id
 | 
			
		||||
            FROM user_signature_stream
 | 
			
		||||
            WHERE ? < stream_id AND stream_id <= ?
 | 
			
		||||
            GROUP BY user_id
 | 
			
		||||
            ORDER BY stream_id ASC
 | 
			
		||||
            LIMIT ?
 | 
			
		||||
        """
 | 
			
		||||
        return self.db.execute(
 | 
			
		||||
            "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
 | 
			
		||||
            "get_all_user_signature_changes_for_remotes",
 | 
			
		||||
            None,
 | 
			
		||||
            sql,
 | 
			
		||||
            from_key,
 | 
			
		||||
            to_key,
 | 
			
		||||
            limit,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
 | 
			
		|||
                    "status_msg": state.status_msg,
 | 
			
		||||
                    "currently_active": state.currently_active,
 | 
			
		||||
                }
 | 
			
		||||
                for state in presence_states
 | 
			
		||||
                for stream_id, state in zip(stream_orderings, presence_states)
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
 | 
			
		|||
            )
 | 
			
		||||
            txn.execute(sql + clause, [stream_id] + list(args))
 | 
			
		||||
 | 
			
		||||
    def get_all_presence_updates(self, last_id, current_id):
 | 
			
		||||
    def get_all_presence_updates(self, last_id, current_id, limit):
 | 
			
		||||
        if last_id == current_id:
 | 
			
		||||
            return defer.succeed([])
 | 
			
		||||
 | 
			
		||||
        def get_all_presence_updates_txn(txn):
 | 
			
		||||
            sql = (
 | 
			
		||||
                "SELECT stream_id, user_id, state, last_active_ts,"
 | 
			
		||||
                " last_federation_update_ts, last_user_sync_ts, status_msg,"
 | 
			
		||||
                " currently_active"
 | 
			
		||||
                " FROM presence_stream"
 | 
			
		||||
                " WHERE ? < stream_id AND stream_id <= ?"
 | 
			
		||||
            )
 | 
			
		||||
            txn.execute(sql, (last_id, current_id))
 | 
			
		||||
            sql = """
 | 
			
		||||
                SELECT stream_id, user_id, state, last_active_ts,
 | 
			
		||||
                    last_federation_update_ts, last_user_sync_ts,
 | 
			
		||||
                    status_msg,
 | 
			
		||||
                currently_active
 | 
			
		||||
                FROM presence_stream
 | 
			
		||||
                WHERE ? < stream_id AND stream_id <= ?
 | 
			
		||||
                ORDER BY stream_id ASC
 | 
			
		||||
                LIMIT ?
 | 
			
		||||
            """
 | 
			
		||||
            txn.execute(sql, (last_id, current_id, limit))
 | 
			
		||||
            return txn.fetchall()
 | 
			
		||||
 | 
			
		||||
        return self.db.runInteraction(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue