diff --git a/CHANGES.rst b/CHANGES.rst index 6c85241ea..82247fa52 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,96 @@ +Changes in synapse v0.21.0 (2017-05-18) +======================================= + +No changes since v0.21.0-rc3 + + +Changes in synapse v0.21.0-rc3 (2017-05-17) +=========================================== + +Features: + +* Add per user rate-limiting overrides (PR #2208) +* Add config option to limit maximum number of events requested by ``/sync`` + and ``/messages`` (PR #2221) Thanks to @psaavedra! + + +Changes: + +* Various small performance fixes (PR #2201, #2202, #2224, #2226, #2227, #2228, + #2229) +* Update username availability checker API (PR #2209, #2213) +* When purging, don't de-delta state groups we're about to delete (PR #2214) +* Documentation to check synapse version (PR #2215) Thanks to @hamber-dick! +* Add an index to event_search to speed up purge history API (PR #2218) + + +Bug fixes: + +* Fix API to allow clients to upload one-time-keys with new sigs (PR #2206) + + +Changes in synapse v0.21.0-rc2 (2017-05-08) +=========================================== + +Changes: + +* Always mark remotes as up if we receive a signed request from them (PR #2190) + + +Bug fixes: + +* Fix bug where users got pushed for rooms they had muted (PR #2200) + + +Changes in synapse v0.21.0-rc1 (2017-05-08) +=========================================== + +Features: + +* Add username availability checker API (PR #2183) +* Add read marker API (PR #2120) + + +Changes: + +* Enable guest access for the 3pl/3pid APIs (PR #1986) +* Add setting to support TURN for guests (PR #2011) +* Various performance improvements (PR #2075, #2076, #2080, #2083, #2108, + #2158, #2176, #2185) +* Make synctl a bit more user friendly (PR #2078, #2127) Thanks @APwhitehat! +* Replace HTTP replication with TCP replication (PR #2082, #2097, #2098, + #2099, #2103, #2014, #2016, #2115, #2116, #2117) +* Support authenticated SMTP (PR #2102) Thanks @DanielDent! +* Add a counter metric for successfully-sent transactions (PR #2121) +* Propagate errors sensibly from proxied IS requests (PR #2147) +* Add more granular event send metrics (PR #2178) + + + +Bug fixes: + +* Fix nuke-room script to work with current schema (PR #1927) Thanks + @zuckschwerdt! +* Fix db port script to not assume postgres tables are in the public schema + (PR #2024) Thanks @jerrykan! +* Fix getting latest device IP for user with no devices (PR #2118) +* Fix rejection of invites to unreachable servers (PR #2145) +* Fix code for reporting old verify keys in synapse (PR #2156) +* Fix invite state to always include all events (PR #2163) +* Fix bug where synapse would always fetch state for any missing event (PR #2170) +* Fix a leak with timed out HTTP connections (PR #2180) +* Fix bug where we didn't time out HTTP requests to ASes (PR #2192) + + +Docs: + +* Clarify doc for SQLite to PostgreSQL port (PR #1961) Thanks @benhylau! +* Fix typo in synctl help (PR #2107) Thanks @HarHarLinks! +* ``web_client_location`` documentation fix (PR #2131) Thanks @matthewjwolff! +* Update README.rst with FreeBSD changes (PR #2132) Thanks @feld! +* Clarify setting up metrics (PR #2149) Thanks @encks! + + Changes in synapse v0.20.0 (2017-04-11) ======================================= diff --git a/README.rst b/README.rst index c92546124..35141ac71 100644 --- a/README.rst +++ b/README.rst @@ -109,10 +109,10 @@ Installing prerequisites on ArchLinux:: sudo pacman -S base-devel python2 python-pip \ python-setuptools python-virtualenv sqlite3 -Installing prerequisites on CentOS 7:: +Installing prerequisites on CentOS 7 or Fedora 25:: sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ - lcms2-devel libwebp-devel tcl-devel tk-devel \ + lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \ python-virtualenv libffi-devel openssl-devel sudo yum groupinstall "Development Tools" @@ -246,6 +246,25 @@ Setting up a TURN server For reliable VoIP calls to be routed via this homeserver, you MUST configure a TURN server. See ``_ for details. +IPv6 +---- + +As of Synapse 0.19 we finally support IPv6, many thanks to @kyrias and @glyph +for providing PR #1696. + +However, for federation to work on hosts with IPv6 DNS servers you **must** +be running Twisted 17.1.0 or later - see https://github.com/matrix-org/synapse/issues/1002 +for details. We can't make Synapse depend on Twisted 17.1 by default +yet as it will break most older distributions (see https://github.com/matrix-org/synapse/pull/1909) +so if you are using operating system dependencies you'll have to install your +own Twisted 17.1 package via pip or backports etc. + +If you're running in a virtualenv then pip should have installed the newest +Twisted automatically, but if your virtualenv is old you will need to manually +upgrade to a newer Twisted dependency via: + + pip install Twisted>=17.1.0 + Running Synapse =============== @@ -336,8 +355,11 @@ ArchLinux --------- The quickest way to get up and running with ArchLinux is probably with the community package -https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in all -the necessary dependencies. +https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in most of +the necessary dependencies. If the default web client is to be served (enabled by default in +the generated config), +https://www.archlinux.org/packages/community/any/python2-matrix-angular-sdk/ will also need to +be installed. Alternatively, to install using pip a few changes may be needed as ArchLinux defaults to python 3, but synapse currently assumes python 2.7 by default: @@ -374,7 +396,7 @@ FreeBSD Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from: - - Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean`` + - Ports: ``cd /usr/ports/net-im/py-matrix-synapse && make install clean`` - Packages: ``pkg install py27-matrix-synapse`` diff --git a/UPGRADE.rst b/UPGRADE.rst index 9f044719a..6164df883 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -28,6 +28,15 @@ running: git pull # Update the versions of synapse's python dependencies. python synapse/python_dependencies.py | xargs -n1 pip install --upgrade + +To check whether your update was sucessfull, run: + +.. code:: bash + + # replace your.server.domain with ther domain of your synaspe homeserver + curl https:///_matrix/federation/v1/version + +So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version. Upgrading to v0.15.0 diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 418689731..c833f3f31 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -36,15 +36,13 @@ class HttpClient(object): the request body. This will be encoded as JSON. Returns: - Deferred: Succeeds when we get *any* HTTP response. - - The result of the deferred is a tuple of `(code, response)`, - where `response` is a dict representing the decoded JSON body. + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. """ pass def get_json(self, url, args=None): - """ Get's some json from the given host homeserver and path + """ Gets some json from the given host homeserver and path Args: url (str): The URL to GET data from. @@ -54,10 +52,8 @@ class HttpClient(object): and *not* a string. Returns: - Deferred: Succeeds when we get *any* HTTP response. - - The result of the deferred is a tuple of `(code, response)`, - where `response` is a dict representing the decoded JSON body. + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. """ pass @@ -214,4 +210,4 @@ class _JsonProducer(object): pass def stopProducing(self): - pass \ No newline at end of file + pass diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst new file mode 100644 index 000000000..1c9c5a6bd --- /dev/null +++ b/docs/admin_api/user_admin_api.rst @@ -0,0 +1,73 @@ +Query Account +============= + +This API returns information about a specific user account. + +The api is:: + + GET /_matrix/client/r0/admin/whois/ + +including an ``access_token`` of a server admin. + +It returns a JSON body like the following: + +.. code:: json + + { + "user_id": "", + "devices": { + "": { + "sessions": [ + { + "connections": [ + { + "ip": "1.2.3.4", + "last_seen": 1417222374433, + "user_agent": "Mozilla/5.0 ..." + }, + { + "ip": "1.2.3.10", + "last_seen": 1417222374500, + "user_agent": "Dalvik/2.1.0 ..." + } + ] + } + ] + } + } + } + +``last_seen`` is measured in milliseconds since the Unix epoch. + +Deactivate Account +================== + +This API deactivates an account. It removes active access tokens, resets the +password, and deletes third-party IDs (to prevent the user requesting a +password reset). + +The api is:: + + POST /_matrix/client/r0/admin/deactivate/ + +including an ``access_token`` of a server admin, and an empty request body. + + +Reset password +============== + +Changes the password of another user. + +The api is:: + + POST /_matrix/client/r0/admin/reset_password/ + +with a body of: + +.. code:: json + + { + "new_password": "" + } + +including an ``access_token`` of a server admin. diff --git a/docs/metrics-howto.rst b/docs/metrics-howto.rst index 7390ab85c..143cd0f42 100644 --- a/docs/metrics-howto.rst +++ b/docs/metrics-howto.rst @@ -21,13 +21,12 @@ How to monitor Synapse metrics using Prometheus 3. Add a prometheus target for synapse. - It needs to set the ``metrics_path`` to a non-default value:: + It needs to set the ``metrics_path`` to a non-default value (under ``scrape_configs``):: - job_name: "synapse" metrics_path: "/_synapse/metrics" static_configs: - - targets: - "my.server.here:9092" + - targets: ["my.server.here:9092"] If your prometheus is older than 1.5.2, you will need to replace ``static_configs`` in the above with ``target_groups``. diff --git a/docs/postgres.rst b/docs/postgres.rst index 402ff9a4d..b592801e9 100644 --- a/docs/postgres.rst +++ b/docs/postgres.rst @@ -112,9 +112,9 @@ script one last time, e.g. if the SQLite database is at ``homeserver.db`` run:: synapse_port_db --sqlite-database homeserver.db \ - --postgres-config database_config.yaml + --postgres-config homeserver-postgres.yaml Once that has completed, change the synapse config to point at the PostgreSQL -database configuration file using the ``database_config`` parameter (see -`Synapse Config`_) and restart synapse. Synapse should now be running against +database configuration file ``homeserver-postgres.yaml`` (i.e. rename it to +``homeserver.yaml``) and restart synapse. Synapse should now be running against PostgreSQL. diff --git a/docs/replication.rst b/docs/replication.rst index 7e37e7198..310abb348 100644 --- a/docs/replication.rst +++ b/docs/replication.rst @@ -26,28 +26,10 @@ expose the append-only log to the readers should be fairly minimal. Architecture ------------ -The Replication API -~~~~~~~~~~~~~~~~~~~ +The Replication Protocol +~~~~~~~~~~~~~~~~~~~~~~~~ -Synapse will optionally expose a long poll HTTP API for extracting updates. The -API will have a similar shape to /sync in that clients provide tokens -indicating where in the log they have reached and a timeout. The synapse server -then either responds with updates immediately if it already has updates or it -waits until the timeout for more updates. If the timeout expires and nothing -happened then the server returns an empty response. - -However unlike the /sync API this replication API is returning synapse specific -data rather than trying to implement a matrix specification. The replication -results are returned as arrays of rows where the rows are mostly lifted -directly from the database. This avoids unnecessary JSON parsing on the server -and hopefully avoids an impedance mismatch between the data returned and the -required updates to the datastore. - -This does not replicate all the database tables as many of the database tables -are indexes that can be recovered from the contents of other tables. - -The format and parameters for the api are documented in -``synapse/replication/resource.py``. +See ``tcp_replication.rst`` The Slaved DataStore diff --git a/docs/tcp_replication.rst b/docs/tcp_replication.rst new file mode 100644 index 000000000..62225ba6f --- /dev/null +++ b/docs/tcp_replication.rst @@ -0,0 +1,223 @@ +TCP Replication +=============== + +Motivation +---------- + +Previously the workers used an HTTP long poll mechanism to get updates from the +master, which had the problem of causing a lot of duplicate work on the server. +This TCP protocol replaces those APIs with the aim of increased efficiency. + + + +Overview +-------- + +The protocol is based on fire and forget, line based commands. An example flow +would be (where '>' indicates master to worker and '<' worker to master flows):: + + > SERVER example.com + < REPLICATE events 53 + > RDATA events 54 ["$foo1:bar.com", ...] + > RDATA events 55 ["$foo4:bar.com", ...] + +The example shows the server accepting a new connection and sending its identity +with the ``SERVER`` command, followed by the client asking to subscribe to the +``events`` stream from the token ``53``. The server then periodically sends ``RDATA`` +commands which have the format ``RDATA ``, where the +format of ```` is defined by the individual streams. + +Error reporting happens by either the client or server sending an `ERROR` +command, and usually the connection will be closed. + + +Since the protocol is a simple line based, its possible to manually connect to +the server using a tool like netcat. A few things should be noted when manually +using the protocol: + +* When subscribing to a stream using ``REPLICATE``, the special token ``NOW`` can + be used to get all future updates. The special stream name ``ALL`` can be used + with ``NOW`` to subscribe to all available streams. +* The federation stream is only available if federation sending has been + disabled on the main process. +* The server will only time connections out that have sent a ``PING`` command. + If a ping is sent then the connection will be closed if no further commands + are receieved within 15s. Both the client and server protocol implementations + will send an initial PING on connection and ensure at least one command every + 5s is sent (not necessarily ``PING``). +* ``RDATA`` commands *usually* include a numeric token, however if the stream + has multiple rows to replicate per token the server will send multiple + ``RDATA`` commands, with all but the last having a token of ``batch``. See + the documentation on ``commands.RdataCommand`` for further details. + + +Architecture +------------ + +The basic structure of the protocol is line based, where the initial word of +each line specifies the command. The rest of the line is parsed based on the +command. For example, the `RDATA` command is defined as:: + + RDATA + +(Note that `` may contains spaces, but cannot contain newlines.) + +Blank lines are ignored. + + +Keep alives +~~~~~~~~~~~ + +Both sides are expected to send at least one command every 5s or so, and +should send a ``PING`` command if necessary. If either side do not receive a +command within e.g. 15s then the connection should be closed. + +Because the server may be connected to manually using e.g. netcat, the timeouts +aren't enabled until an initial ``PING`` command is seen. Both the client and +server implementations below send a ``PING`` command immediately on connection to +ensure the timeouts are enabled. + +This ensures that both sides can quickly realize if the tcp connection has gone +and handle the situation appropriately. + + +Start up +~~~~~~~~ + +When a new connection is made, the server: + +* Sends a ``SERVER`` command, which includes the identity of the server, allowing + the client to detect if its connected to the expected server +* Sends a ``PING`` command as above, to enable the client to time out connections + promptly. + +The client: + +* Sends a ``NAME`` command, allowing the server to associate a human friendly + name with the connection. This is optional. +* Sends a ``PING`` as above +* For each stream the client wishes to subscribe to it sends a ``REPLICATE`` + with the stream_name and token it wants to subscribe from. +* On receipt of a ``SERVER`` command, checks that the server name matches the + expected server name. + + +Error handling +~~~~~~~~~~~~~~ + +If either side detects an error it can send an ``ERROR`` command and close the +connection. + +If the client side loses the connection to the server it should reconnect, +following the steps above. + + +Congestion +~~~~~~~~~~ + +If the server sends messages faster than the client can consume them the server +will first buffer a (fairly large) number of commands and then disconnect the +client. This ensures that we don't queue up an unbounded number of commands in +memory and gives us a potential oppurtunity to squawk loudly. When/if the client +recovers it can reconnect to the server and ask for missed messages. + + +Reliability +~~~~~~~~~~~ + +In general the replication stream should be considered an unreliable transport +since e.g. commands are not resent if the connection disappears. + +The exception to that are the replication streams, i.e. RDATA commands, since +these include tokens which can be used to restart the stream on connection +errors. + +The client should keep track of the token in the last RDATA command received +for each stream so that on reconneciton it can start streaming from the correct +place. Note: not all RDATA have valid tokens due to batching. See +``RdataCommand`` for more details. + + +Example +~~~~~~~ + +An example iteraction is shown below. Each line is prefixed with '>' or '<' to +indicate which side is sending, these are *not* included on the wire:: + + * connection established * + > SERVER localhost:8823 + > PING 1490197665618 + < NAME synapse.app.appservice + < PING 1490197665618 + < REPLICATE events 1 + < REPLICATE backfill 1 + < REPLICATE caches 1 + > POSITION events 1 + > POSITION backfill 1 + > POSITION caches 1 + > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] + < PING 1490197675618 + > ERROR server stopping + * connection closed by server * + +The ``POSITION`` command sent by the server is used to set the clients position +without needing to send data with the ``RDATA`` command. + + +An example of a batched set of ``RDATA`` is:: + + > RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] + > RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] + > RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] + > RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] + +In this case the client shouldn't advance their caches token until it sees the +the last ``RDATA``. + + +List of commands +~~~~~~~~~~~~~~~~ + +The list of valid commands, with which side can send it: server (S) or client (C): + +SERVER (S) + Sent at the start to identify which server the client is talking to + +RDATA (S) + A single update in a stream + +POSITION (S) + The position of the stream has been updated + +ERROR (S, C) + There was an error + +PING (S, C) + Sent periodically to ensure the connection is still alive + +NAME (C) + Sent at the start by client to inform the server who they are + +REPLICATE (C) + Asks the server to replicate a given stream + +USER_SYNC (C) + A user has started or stopped syncing + +FEDERATION_ACK (C) + Acknowledge receipt of some federation data + +REMOVE_PUSHER (C) + Inform the server a pusher should be removed + +INVALIDATE_CACHE (C) + Inform the server a cache should be invalidated + +SYNC (S, C) + Used exclusively in tests + + +See ``synapse/replication/tcp/commands.py`` for a detailed description and the +format of each command. diff --git a/docs/turn-howto.rst b/docs/turn-howto.rst index 04c010071..e48628ce6 100644 --- a/docs/turn-howto.rst +++ b/docs/turn-howto.rst @@ -50,14 +50,37 @@ You may be able to setup coturn via your package manager, or set it up manually pwgen -s 64 1 - 5. Ensure youe firewall allows traffic into the TURN server on - the ports you've configured it to listen on (remember to allow - both TCP and UDP if you've enabled both). + 5. Consider your security settings. TURN lets users request a relay + which will connect to arbitrary IP addresses and ports. At the least + we recommend: - 6. If you've configured coturn to support TLS/DTLS, generate or + # VoIP traffic is all UDP. There is no reason to let users connect to arbitrary TCP endpoints via the relay. + no-tcp-relay + + # don't let the relay ever try to connect to private IP address ranges within your network (if any) + # given the turn server is likely behind your firewall, remember to include any privileged public IPs too. + denied-peer-ip=10.0.0.0-10.255.255.255 + denied-peer-ip=192.168.0.0-192.168.255.255 + denied-peer-ip=172.16.0.0-172.31.255.255 + + # special case the turn server itself so that client->TURN->TURN->client flows work + allowed-peer-ip=10.0.0.1 + + # consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS. + user-quota=12 # 4 streams per video call, so 12 streams = 3 simultaneous relayed calls per user. + total-quota=1200 + + Ideally coturn should refuse to relay traffic which isn't SRTP; + see https://github.com/matrix-org/synapse/issues/2009 + + 6. Ensure your firewall allows traffic into the TURN server on + the ports you've configured it to listen on (remember to allow + both TCP and UDP TURN traffic) + + 7. If you've configured coturn to support TLS/DTLS, generate or import your private key and certificate. - 7. Start the turn server:: + 8. Start the turn server:: bin/turnserver -o @@ -83,12 +106,19 @@ Your home server configuration file needs the following extra keys: to refresh credentials. The TURN REST API specification recommends one day (86400000). + 4. "turn_allow_guests": Whether to allow guest users to use the TURN + server. This is enabled by default, as otherwise VoIP will not + work reliably for guests. However, it does introduce a security risk + as it lets guests connect to arbitrary endpoints without having gone + through a CAPTCHA or similar to register a real account. + As an example, here is the relevant section of the config file for matrix.org:: turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ] turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons turn_user_lifetime: 86400000 + turn_allow_guests: True Now, restart synapse:: diff --git a/docs/workers.rst b/docs/workers.rst index 65b6e690f..2d3df9159 100644 --- a/docs/workers.rst +++ b/docs/workers.rst @@ -12,7 +12,7 @@ across multiple processes is a recipe for disaster, plus you should be using postgres anyway if you care about scalability). The workers communicate with the master synapse process via a synapse-specific -HTTP protocol called 'replication' - analogous to MySQL or Postgres style +TCP protocol called 'replication' - analogous to MySQL or Postgres style database replication; feeding a stream of relevant data to the workers so they can be kept in sync with the main synapse process and database state. @@ -21,16 +21,11 @@ To enable workers, you need to add a replication listener to the master synapse, listeners: - port: 9092 bind_address: '127.0.0.1' - type: http - tls: false - x_forwarded: false - resources: - - names: [replication] - compress: false + type: replication Under **no circumstances** should this replication API listener be exposed to the public internet; it currently implements no authentication whatsoever and is -unencrypted HTTP. +unencrypted. You then create a set of configs for the various worker processes. These should be worker configuration files should be stored in a dedicated subdirectory, to allow @@ -50,14 +45,16 @@ e.g. the HTTP listener that it provides (if any); logging configuration; etc. You should minimise the number of overrides though to maintain a usable config. You must specify the type of worker application (worker_app) and the replication -endpoint that it's talking to on the main synapse process (worker_replication_url). +endpoint that it's talking to on the main synapse process (worker_replication_host +and worker_replication_port). For instance:: worker_app: synapse.app.synchrotron # The replication listener on the synapse to talk to. - worker_replication_url: http://127.0.0.1:9092/_synapse/replication + worker_replication_host: 127.0.0.1 + worker_replication_port: 9092 worker_listeners: - type: http @@ -95,4 +92,3 @@ To manipulate a specific worker, you pass the -w option to synctl:: All of the above is highly experimental and subject to change as Synapse evolves, but documenting it here to help folks needing highly scalable Synapses similar to the one running matrix.org! - diff --git a/scripts-dev/nuke-room-from-db.sh b/scripts-dev/nuke-room-from-db.sh index 58c036c89..1201d176c 100755 --- a/scripts-dev/nuke-room-from-db.sh +++ b/scripts-dev/nuke-room-from-db.sh @@ -9,16 +9,39 @@ ROOMID="$1" sqlite3 homeserver.db < last_sync_ms. Lists the users that have stopped syncing + # but we haven't notified the master of that yet + self.users_going_offline = {} + + self._send_stop_syncing_loop = self.clock.looping_call( + self.send_stop_syncing, 10 * 1000 + ) + self.process_id = random_string(16) logger.info("Presence process_id is %r", self.process_id) - self._sending_sync = False - self._need_to_send_sync = False - self.clock.looping_call( - self._send_syncing_users_regularly, - UPDATE_SYNCING_USERS_MS, - ) + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms) - reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + def mark_as_coming_online(self, user_id): + """A user has started syncing. Send a UserSync to the master, unless they + had recently stopped syncing. + + Args: + user_id (str) + """ + going_offline = self.users_going_offline.pop(user_id, None) + if not going_offline: + # Safe to skip because we haven't yet told the master they were offline + self.send_user_sync(user_id, True, self.clock.time_msec()) + + def mark_as_going_offline(self, user_id): + """A user has stopped syncing. We wait before notifying the master as + its likely they'll come back soon. This allows us to avoid sending + a stopped syncing immediately followed by a started syncing notification + to the master + + Args: + user_id (str) + """ + self.users_going_offline[user_id] = self.clock.time_msec() + + def send_stop_syncing(self): + """Check if there are any users who have stopped syncing a while ago + and haven't come back yet. If there are poke the master about them. + """ + now = self.clock.time_msec() + for user_id, last_sync_ms in self.users_going_offline.items(): + if now - last_sync_ms > 10 * 1000: + self.users_going_offline.pop(user_id, None) + self.send_user_sync(user_id, False, last_sync_ms) def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? @@ -139,18 +162,16 @@ class SynchrotronPresence(object): get_states = PresenceHandler.get_states.__func__ get_state = PresenceHandler.get_state.__func__ - _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ - @defer.inlineCallbacks def user_syncing(self, user_id, affect_presence): if affect_presence: curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_states = yield self.current_state_for_users([user_id]) - if prev_states[user_id].state == PresenceState.OFFLINE: - # TODO: Don't block the sync request on this HTTP hit. - yield self._send_syncing_users_now() + + # If we went from no in flight sync to some, notify replication + if self.user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) def _end(): # We check that the user_id is in user_to_num_current_syncs because @@ -159,6 +180,10 @@ class SynchrotronPresence(object): if affect_presence and user_id in self.user_to_num_current_syncs: self.user_to_num_current_syncs[user_id] -= 1 + # If we went from one in flight sync to non, notify replication + if self.user_to_num_current_syncs[user_id] == 0: + self.mark_as_going_offline(user_id) + @contextlib.contextmanager def _user_syncing(): try: @@ -166,56 +191,12 @@ class SynchrotronPresence(object): finally: _end() - defer.returnValue(_user_syncing()) - - @defer.inlineCallbacks - def _on_shutdown(self): - # When the synchrotron is shutdown tell the master to clear the in - # progress syncs for this process - self.user_to_num_current_syncs.clear() - yield self._send_syncing_users_now() - - def _send_syncing_users_regularly(self): - # Only send an update if we aren't in the middle of sending one. - if not self._sending_sync: - preserve_fn(self._send_syncing_users_now)() - - @defer.inlineCallbacks - def _send_syncing_users_now(self): - if self._sending_sync: - # We don't want to race with sending another update. - # Instead we wait for that update to finish and send another - # update afterwards. - self._need_to_send_sync = True - return - - # Flag that we are sending an update. - self._sending_sync = True - - yield self.http_client.post_json_get_json(self.syncing_users_url, { - "process_id": self.process_id, - "syncing_users": [ - user_id for user_id, count in self.user_to_num_current_syncs.items() - if count > 0 - ], - }) - - # Unset the flag as we are no longer sending an update. - self._sending_sync = False - if self._need_to_send_sync: - # If something happened while we were sending the update then - # we might need to send another update. - # TODO: Check if the update that was sent matches the current state - # as we only need to send an update if they are different. - self._need_to_send_sync = False - yield self._send_syncing_users_now() + return defer.succeed(_user_syncing()) @defer.inlineCallbacks def notify_from_replication(self, states, stream_id): - parties = yield self._get_interested_parties( - states, calculate_remote_hosts=False - ) - room_ids_to_states, users_to_states, _ = parties + parties = yield get_interested_parties(self.store, states) + room_ids_to_states, users_to_states = parties self.notifier.on_new_event( "presence_key", stream_id, rooms=room_ids_to_states.keys(), @@ -223,26 +204,24 @@ class SynchrotronPresence(object): ) @defer.inlineCallbacks - def process_replication(self, result): - stream = result.get("presence", {"rows": []}) - states = [] - for row in stream["rows"]: - ( - position, user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - currently_active - ) = row - state = UserPresenceState( - user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - currently_active - ) - self.user_to_current_state[user_id] = state - states.append(state) + def process_replication_rows(self, token, rows): + states = [UserPresenceState( + row.user_id, row.state, row.last_active_ts, + row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg, + row.currently_active + ) for row in rows] - if states and "position" in stream: - stream_id = int(stream["position"]) - yield self.notify_from_replication(states, stream_id) + for state in states: + self.user_to_current_state[row.user_id] = state + + stream_id = token + yield self.notify_from_replication(states, stream_id) + + def get_currently_syncing_users(self): + return [ + user_id for user_id, count in self.user_to_num_current_syncs.iteritems() + if count > 0 + ] class SynchrotronTyping(object): @@ -257,16 +236,12 @@ class SynchrotronTyping(object): # value which we *must* use for the next replication request. return {"typing": self._latest_room_serial} - def process_replication(self, result): - stream = result.get("typing") - if stream: - self._latest_room_serial = int(stream["position"]) + def process_replication_rows(self, token, rows): + self._latest_room_serial = token - for row in stream["rows"]: - position, room_id, typing_json = row - typing = json.loads(typing_json) - self._room_serials[room_id] = position - self._room_typing[room_id] = typing + for row in rows: + self._room_serials[row.room_id] = token + self._room_typing[row.room_id] = row.user_ids class SynchrotronApplicationService(object): @@ -351,118 +326,10 @@ class SynchrotronServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - notifier = self.get_notifier() - presence_handler = self.get_presence_handler() - typing_handler = self.get_typing_handler() + self.get_tcp_replication().start_replication(self) - def notify_from_stream( - result, stream_name, stream_key, room=None, user=None - ): - stream = result.get(stream_name) - if stream: - position_index = stream["field_names"].index("position") - if room: - room_index = stream["field_names"].index(room) - if user: - user_index = stream["field_names"].index(user) - - users = () - rooms = () - for row in stream["rows"]: - position = row[position_index] - - if user: - users = (row[user_index],) - - if room: - rooms = (row[room_index],) - - notifier.on_new_event( - stream_key, position, users=users, rooms=rooms - ) - - @defer.inlineCallbacks - def notify_device_list_update(result): - stream = result.get("device_lists") - if not stream: - return - - position_index = stream["field_names"].index("position") - user_index = stream["field_names"].index("user_id") - - for row in stream["rows"]: - position = row[position_index] - user_id = row[user_index] - - room_ids = yield store.get_rooms_for_user(user_id) - - notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) - - @defer.inlineCallbacks - def notify(result): - stream = result.get("events") - if stream: - max_position = stream["position"] - - event_map = yield store.get_events([row[1] for row in stream["rows"]]) - - for row in stream["rows"]: - position = row[0] - event_id = row[1] - event = event_map.get(event_id, None) - if not event: - continue - - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - notifier.on_new_room_event( - event, position, max_position, extra_users - ) - - notify_from_stream( - result, "push_rules", "push_rules_key", user="user_id" - ) - notify_from_stream( - result, "user_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "room_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "tag_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "receipts", "receipt_key", room="room_id" - ) - notify_from_stream( - result, "typing", "typing_key", room="room_id" - ) - notify_from_stream( - result, "to_device", "to_device_key", user="user_id" - ) - yield notify_device_list_update(result) - - while True: - try: - args = store.stream_positions() - args.update(typing_handler.stream_positions()) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - typing_handler.process_replication(result) - yield presence_handler.process_replication(result) - yield notify(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return SyncReplicationHandler(self) def build_presence_handler(self): return SynchrotronPresence(self) @@ -471,6 +338,79 @@ class SynchrotronServer(HomeServer): return SynchrotronTyping(self) +class SyncReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(SyncReplicationHandler, self).__init__(hs.get_datastore()) + + self.store = hs.get_datastore() + self.typing_handler = hs.get_typing_handler() + self.presence_handler = hs.get_presence_handler() + self.notifier = hs.get_notifier() + + self.presence_handler.sync_callback = self.send_user_sync + + def on_rdata(self, stream_name, token, rows): + super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + + preserve_fn(self.process_and_notify)(stream_name, token, rows) + + def get_streams_to_replicate(self): + args = super(SyncReplicationHandler, self).get_streams_to_replicate() + args.update(self.typing_handler.stream_positions()) + return args + + def get_currently_syncing_users(self): + return self.presence_handler.get_currently_syncing_users() + + @defer.inlineCallbacks + def process_and_notify(self, stream_name, token, rows): + if stream_name == "events": + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + event = yield self.store.get_event(row.event_id) + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event( + event, token, max_token, extra_users + ) + elif stream_name == "push_rules": + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows], + ) + elif stream_name in ("account_data", "tag_account_data",): + self.notifier.on_new_event( + "account_data_key", token, users=[row.user_id for row in rows], + ) + elif stream_name == "receipts": + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "typing": + self.typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "to_device": + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event( + "to_device_key", token, users=entities, + ) + elif stream_name == "device_lists": + all_room_ids = set() + for row in rows: + room_ids = yield self.store.get_rooms_for_user(row.user_id) + all_room_ids.update(room_ids) + self.notifier.on_new_event( + "device_list_key", token, rooms=all_room_ids, + ) + elif stream_name == "presence": + yield self.presence_handler.process_replication_rows(token, rows) + + def start(config_options): try: config = HomeServerConfig.load_config( @@ -514,7 +454,6 @@ def start(config_options): def start(): ss.get_datastore().start_profiling() - ss.replicate() ss.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 23eb6a1ec..3bd7ef7bb 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -125,7 +125,7 @@ def main(): "configfile", nargs="?", default="homeserver.yaml", - help="the homeserver config file, defaults to homserver.yaml", + help="the homeserver config file, defaults to homeserver.yaml", ) parser.add_argument( "-w", "--worker", @@ -202,7 +202,8 @@ def main(): worker_app = worker_config["worker_app"] worker_pidfile = worker_config["worker_pid_file"] worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize # TODO print something more user friendly + assert worker_daemonize, "In config %r: expected '%s' to be True" % ( + worker_configfile, "worker_daemonize") worker_cache_factor = worker_config.get("synctl_cache_factor") workers.append(Worker( worker_app, worker_configfile, worker_pidfile, worker_cache_factor, @@ -233,6 +234,9 @@ def main(): if action == "start" or action == "restart": if start_stop_synapse: + # Check if synapse is already running + if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())): + abort("synapse.app.homeserver already running") start(configfile) for worker in workers: diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b0106a359..7346206bb 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.api.constants import EventTypes +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer @@ -124,29 +125,23 @@ class ApplicationService(object): raise ValueError( "Expected bool for 'exclusive' in ns '%s'" % ns ) - if not isinstance(regex_obj.get("regex"), basestring): + regex = regex_obj.get("regex") + if isinstance(regex, basestring): + regex_obj["regex"] = re.compile(regex) # Pre-compile regex + else: raise ValueError( "Expected string for 'regex' in ns '%s'" % ns ) return namespaces - def _matches_regex(self, test_string, namespace_key, return_obj=False): - if not isinstance(test_string, basestring): - logger.error( - "Expected a string to test regex against, but got %s", - test_string - ) - return False - + def _matches_regex(self, test_string, namespace_key): for regex_obj in self.namespaces[namespace_key]: - if re.match(regex_obj["regex"], test_string): - if return_obj: - return regex_obj - return True - return False + if regex_obj["regex"].match(test_string): + return regex_obj + return None def _is_exclusive(self, ns_key, test_string): - regex_obj = self._matches_regex(test_string, ns_key, return_obj=True) + regex_obj = self._matches_regex(test_string, ns_key) if regex_obj: return regex_obj["exclusive"] return False @@ -166,7 +161,14 @@ class ApplicationService(object): if not store: defer.returnValue(False) - member_list = yield store.get_users_in_room(event.room_id) + does_match = yield self._matches_user_in_member_list(event.room_id, store) + defer.returnValue(does_match) + + @cachedInlineCallbacks(num_args=1, cache_context=True) + def _matches_user_in_member_list(self, room_id, store, cache_context): + member_list = yield store.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) # check joined member events for user_id in member_list: @@ -219,10 +221,10 @@ class ApplicationService(object): ) def is_interested_in_alias(self, alias): - return self._matches_regex(alias, ApplicationService.NS_ALIASES) + return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) def is_interested_in_room(self, room_id): - return self._matches_regex(room_id, ApplicationService.NS_ROOMS) + return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) def is_exclusive_user(self, user_id): return ( diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 0030b5db1..fe156b693 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -71,6 +71,15 @@ class EmailConfig(Config): self.email_riot_base_url = email_config.get( "riot_base_url", None ) + self.email_smtp_user = email_config.get( + "smtp_user", None + ) + self.email_smtp_pass = email_config.get( + "smtp_pass", None + ) + self.require_transport_security = email_config.get( + "require_transport_security", False + ) if "app_name" in email_config: self.email_app_name = email_config["app_name"] else: @@ -91,10 +100,17 @@ class EmailConfig(Config): # Defining a custom URL for Riot is only needed if email notifications # should contain links to a self-hosted installation of Riot; when set # the "app_name" setting is ignored. + # + # If your SMTP server requires authentication, the optional smtp_user & + # smtp_pass variables should be used + # #email: # enable_notifs: false # smtp_host: "localhost" # smtp_port: 25 + # smtp_user: "exampleusername" + # smtp_pass: "examplepassword" + # require_transport_security: False # notif_from: "Your Friendly %(app)s Home Server " # app_name: Matrix # template_dir: res/templates diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 87e500c97..f7e03c4cd 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -69,6 +69,7 @@ class RegistrationConfig(Config): trusted_third_party_id_servers: - matrix.org - vector.im + - riot.im """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/server.py b/synapse/config/server.py index 1f9999d57..3910b9dc3 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -35,6 +35,8 @@ class ServerConfig(Config): # "disable" federation self.send_federation = config.get("send_federation", True) + self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -144,6 +146,12 @@ class ServerConfig(Config): # Whether to serve a web client from the HTTP/HTTPS root resource. web_client: True + # The root directory to server for the above web client. + # If left undefined, synapse will serve the matrix-angular-sdk web client. + # Make sure matrix-angular-sdk is installed with pip if web_client is True + # and web_client_location is undefined + # web_client_location: "/path/to/web/root" + # The public-facing base URL for the client API (not including _matrix/...) # public_baseurl: https://example.com:8448/ @@ -155,6 +163,10 @@ class ServerConfig(Config): # The GC threshold parameters to pass to `gc.set_threshold`, if defined # gc_thresholds: [700, 10, 10] + # Set the limit on the returned events in the timeline in the get + # and sync operations. The default value is -1, means no upper limit. + # filter_timeline_limit: 5000 + # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: diff --git a/synapse/config/voip.py b/synapse/config/voip.py index eeb693027..3a4e16fa9 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -23,6 +23,7 @@ class VoipConfig(Config): self.turn_username = config.get("turn_username") self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) + self.turn_allow_guests = config.get("turn_allow_guests", True) def default_config(self, **kwargs): return """\ @@ -41,4 +42,11 @@ class VoipConfig(Config): # How long generated TURN credentials last turn_user_lifetime: "1h" + + # Whether guests should be allowed to use the TURN server. + # This defaults to True, otherwise VoIP will be unreliable for guests. + # However, it does introduce a slight security risk as it allows users to + # connect to arbitrary endpoints without having first signed up for a + # valid account (e.g. by passing a CAPTCHA). + turn_allow_guests: True """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index b165c67ee..ea48d931a 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -28,7 +28,9 @@ class WorkerConfig(Config): self.worker_pid_file = config.get("worker_pid_file") self.worker_log_file = config.get("worker_log_file") self.worker_log_config = config.get("worker_log_config") - self.worker_replication_url = config.get("worker_replication_url") + self.worker_replication_host = config.get("worker_replication_host", None) + self.worker_replication_port = config.get("worker_replication_port", None) + self.worker_name = config.get("worker_name", self.worker_app) if self.worker_listeners: for listener in self.worker_listeners: diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6be18880b..e9a732ff0 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -50,6 +50,7 @@ class EventContext(object): "prev_group", "delta_ids", "prev_state_events", + "app_service", ] def __init__(self): @@ -68,3 +69,5 @@ class EventContext(object): self.delta_ids = None self.prev_state_events = None + + self.app_service = None diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 5bbaef818..824f4a42e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -225,7 +225,22 @@ def format_event_for_client_v2_without_room_id(d): def serialize_event(e, time_now_ms, as_client_event=True, event_format=format_event_for_client_v1, - token_id=None, only_event_fields=None): + token_id=None, only_event_fields=None, is_invite=False): + """Serialize event for clients + + Args: + e (EventBase) + time_now_ms (int) + as_client_event (bool) + event_format + token_id + only_event_fields + is_invite (bool): Whether this is an invite that is being sent to the + invitee + + Returns: + dict + """ # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e @@ -251,6 +266,12 @@ def serialize_event(e, time_now_ms, as_client_event=True, if txn_id is not None: d["unsigned"]["transaction_id"] = txn_id + # If this is an invite for somebody else, then we don't care about the + # invite_room_state as that's meant solely for the invitee. Other clients + # will already have the state since they're in the room. + if not is_invite: + d["unsigned"].pop("invite_room_state", None) + if as_client_event: d = event_format(d) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index deee0f490..861441708 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -474,8 +474,13 @@ class FederationClient(FederationBase): content (object): Any additional data to put into the content field of the event. Return: - A tuple of (origin (str), event (object)) where origin is the remote - homeserver which generated the event. + Deferred: resolves to a tuple of (origin (str), event (object)) + where origin is the remote homeserver which generated the event. + + Fails with a ``CodeMessageException`` if the chosen remote server + returns a 300/400 code. + + Fails with a ``RuntimeError`` if no servers were reachable. """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: @@ -528,6 +533,27 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_join(self, destinations, pdu): + """Sends a join event to one of a list of homeservers. + + Doing so will cause the remote server to add the event to the graph, + and send the event out to the rest of the federation. + + Args: + destinations (str): Candidate homeservers which are probably + participating in the room. + pdu (BaseEvent): event to be sent + + Return: + Deferred: resolves to a dict with members ``origin`` (a string + giving the serer the event was sent to, ``state`` (?) and + ``auth_chain``. + + Fails with a ``CodeMessageException`` if the chosen remote server + returns a 300/400 code. + + Fails with a ``RuntimeError`` if no servers were reachable. + """ + for destination in destinations: if destination == self.server_name: continue @@ -635,6 +661,26 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_leave(self, destinations, pdu): + """Sends a leave event to one of a list of homeservers. + + Doing so will cause the remote server to add the event to the graph, + and send the event out to the rest of the federation. + + This is mostly useful to reject received invites. + + Args: + destinations (str): Candidate homeservers which are probably + participating in the room. + pdu (BaseEvent): event to be sent + + Return: + Deferred: resolves to None. + + Fails with a ``CodeMessageException`` if the chosen remote server + returns a non-200 code. + + Fails with a ``RuntimeError`` if no servers were reachable. + """ for destination in destinations: if destination == self.server_name: continue diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index bc20b9c20..51e3fdea0 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -440,6 +440,16 @@ class FederationServer(FederationBase): key_id: json.loads(json_bytes) } + logger.info( + "Claimed one-time-keys: %s", + ",".join(( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in json_result.iteritems() + for device_id, device_keys in user_keys.iteritems() + for key_id, _ in device_keys.iteritems() + )), + ) + defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index bbb019522..93e5acebc 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,23 +31,21 @@ Events are replicated via a separate events stream. from .units import Edu +from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure import synapse.metrics from blist import sorteddict -import ujson +from collections import namedtuple + +import logging + +logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -PRESENCE_TYPE = "p" -KEYED_EDU_TYPE = "k" -EDU_TYPE = "e" -FAILURE_TYPE = "f" -DEVICE_MESSAGE_TYPE = "d" - - class FederationRemoteSendQueue(object): """A drop in replacement for TransactionQueue""" @@ -55,18 +53,19 @@ class FederationRemoteSendQueue(object): self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id - self.presence_map = {} - self.presence_changed = sorteddict() + self.presence_map = {} # Pending presence map user_id -> UserPresenceState + self.presence_changed = sorteddict() # Stream position -> user_id - self.keyed_edu = {} - self.keyed_edu_changed = sorteddict() + self.keyed_edu = {} # (destination, key) -> EDU + self.keyed_edu_changed = sorteddict() # stream position -> (destination, key) - self.edus = sorteddict() + self.edus = sorteddict() # stream position -> Edu - self.failures = sorteddict() + self.failures = sorteddict() # stream position -> (destination, Failure) - self.device_messages = sorteddict() + self.device_messages = sorteddict() # stream position -> destination self.pos = 1 self.pos_time = sorteddict() @@ -122,7 +121,9 @@ class FederationRemoteSendQueue(object): del self.presence_changed[key] user_ids = set( - user_id for uids in self.presence_changed.values() for _, user_id in uids + user_id + for uids in self.presence_changed.itervalues() + for user_id in uids ) to_del = [ @@ -189,18 +190,20 @@ class FederationRemoteSendQueue(object): self.notifier.on_new_replication_data() - def send_presence(self, destination, states): - """As per TransactionQueue""" + def send_presence(self, states): + """As per TransactionQueue + + Args: + states (list(UserPresenceState)) + """ pos = self._next_pos() - self.presence_map.update({ - state.user_id: state - for state in states - }) + # We only want to send presence for our own users, so lets always just + # filter here just in case. + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - self.presence_changed[pos] = [ - (destination, state.user_id) for state in states - ] + self.presence_map.update({state.user_id: state for state in local_states}) + self.presence_changed[pos] = [state.user_id for state in local_states] self.notifier.on_new_replication_data() @@ -220,10 +223,15 @@ class FederationRemoteSendQueue(object): def get_current_token(self): return self.pos - 1 - def get_replication_rows(self, token, limit, federation_ack=None): - """ + def federation_ack(self, token): + self._clear_queue_before_pos(token) + + def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): + """Get rows to be sent over federation between the two tokens + Args: - token (int) + from_token (int) + to_token(int) limit (int) federation_ack (int): Optional. The position where the worker is explicitly acknowledged it has handled. Allows us to drop @@ -232,9 +240,11 @@ class FederationRemoteSendQueue(object): # TODO: Handle limit. # To handle restarts where we wrap around - if token > self.pos: - token = -1 + if from_token > self.pos: + from_token = -1 + # list of tuple(int, BaseFederationRow), where the first is the position + # of the federation stream. rows = [] # There should be only one reader, so lets delete everything its @@ -244,62 +254,295 @@ class FederationRemoteSendQueue(object): # Fetch changed presence keys = self.presence_changed.keys() - i = keys.bisect_right(token) - dest_user_ids = set( - (pos, dest_user_id) - for pos in keys[i:] - for dest_user_id in self.presence_changed[pos] - ) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + dest_user_ids = [ + (pos, user_id) + for pos in keys[i:j] + for user_id in self.presence_changed[pos] + ] - for (key, (dest, user_id)) in dest_user_ids: - rows.append((key, PRESENCE_TYPE, ujson.dumps({ - "destination": dest, - "state": self.presence_map[user_id].as_dict(), - }))) + for (key, user_id) in dest_user_ids: + rows.append((key, PresenceRow( + state=self.presence_map[user_id], + ))) # Fetch changes keyed edus keys = self.keyed_edu_changed.keys() - i = keys.bisect_right(token) - keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + # We purposefully clobber based on the key here, python dict comprehensions + # always use the last value, so this will correctly point to the last + # stream position. + keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]} - for (pos, (destination, edu_key)) in keyed_edus: - rows.append( - (pos, KEYED_EDU_TYPE, ujson.dumps({ - "key": edu_key, - "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), - })) - ) + for ((destination, edu_key), pos) in keyed_edus.iteritems(): + rows.append((pos, KeyedEduRow( + key=edu_key, + edu=self.keyed_edu[(destination, edu_key)], + ))) # Fetch changed edus keys = self.edus.keys() - i = keys.bisect_right(token) - edus = set((k, self.edus[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + edus = ((k, self.edus[k]) for k in keys[i:j]) for (pos, edu) in edus: - rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + rows.append((pos, EduRow(edu))) # Fetch changed failures keys = self.failures.keys() - i = keys.bisect_right(token) - failures = set((k, self.failures[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + failures = ((k, self.failures[k]) for k in keys[i:j]) for (pos, (destination, failure)) in failures: - rows.append((pos, FAILURE_TYPE, ujson.dumps({ - "destination": destination, - "failure": failure, - }))) + rows.append((pos, FailureRow( + destination=destination, + failure=failure, + ))) # Fetch changed device messages keys = self.device_messages.keys() - i = keys.bisect_right(token) - device_messages = set((k, self.device_messages[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + device_messages = {self.device_messages[k]: k for k in keys[i:j]} - for (pos, destination) in device_messages: - rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ - "destination": destination, - }))) + for (destination, pos) in device_messages.iteritems(): + rows.append((pos, DeviceRow( + destination=destination, + ))) # Sort rows based on pos rows.sort() - return rows + return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + + +class BaseFederationRow(object): + """Base class for rows to be sent in the federation stream. + + Specifies how to identify, serialize and deserialize the different types. + """ + + TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + + @staticmethod + def from_data(data): + """Parse the data from the federation stream into a row. + + Args: + data: The value of ``data`` from FederationStreamRow.data, type + depends on the type of stream + """ + raise NotImplementedError() + + def to_data(self): + """Serialize this row to be sent over the federation stream. + + Returns: + The value to be sent in FederationStreamRow.data. The type depends + on the type of stream. + """ + raise NotImplementedError() + + def add_to_buffer(self, buff): + """Add this row to the appropriate field in the buffer ready for this + to be sent over federation. + + We use a buffer so that we can batch up events that have come in at + the same time and send them all at once. + + Args: + buff (BufferedToSend) + """ + raise NotImplementedError() + + +class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( + "state", # UserPresenceState +))): + TypeId = "p" + + @staticmethod + def from_data(data): + return PresenceRow( + state=UserPresenceState.from_dict(data) + ) + + def to_data(self): + return self.state.as_dict() + + def add_to_buffer(self, buff): + buff.presence.append(self.state) + + +class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( + "key", # tuple(str) - the edu key passed to send_edu + "edu", # Edu +))): + """Streams EDUs that have an associated key that is ued to clobber. For example, + typing EDUs clobber based on room_id. + """ + + TypeId = "k" + + @staticmethod + def from_data(data): + return KeyedEduRow( + key=tuple(data["key"]), + edu=Edu(**data["edu"]), + ) + + def to_data(self): + return { + "key": self.key, + "edu": self.edu.get_internal_dict(), + } + + def add_to_buffer(self, buff): + buff.keyed_edus.setdefault( + self.edu.destination, {} + )[self.key] = self.edu + + +class EduRow(BaseFederationRow, namedtuple("EduRow", ( + "edu", # Edu +))): + """Streams EDUs that don't have keys. See KeyedEduRow + """ + TypeId = "e" + + @staticmethod + def from_data(data): + return EduRow(Edu(**data)) + + def to_data(self): + return self.edu.get_internal_dict() + + def add_to_buffer(self, buff): + buff.edus.setdefault(self.edu.destination, []).append(self.edu) + + +class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( + "destination", # str + "failure", +))): + """Streams failures to a remote server. Failures are issued when there was + something wrong with a transaction the remote sent us, e.g. it included + an event that was invalid. + """ + + TypeId = "f" + + @staticmethod + def from_data(data): + return FailureRow( + destination=data["destination"], + failure=data["failure"], + ) + + def to_data(self): + return { + "destination": self.destination, + "failure": self.failure, + } + + def add_to_buffer(self, buff): + buff.failures.setdefault(self.destination, []).append(self.failure) + + +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( + "destination", # str +))): + """Streams the fact that either a) there is pending to device messages for + users on the remote, or b) a local users device has changed and needs to + be sent to the remote. + """ + TypeId = "d" + + @staticmethod + def from_data(data): + return DeviceRow(destination=data["destination"]) + + def to_data(self): + return {"destination": self.destination} + + def add_to_buffer(self, buff): + buff.device_destinations.add(self.destination) + + +TypeToRow = { + Row.TypeId: Row + for Row in ( + PresenceRow, + KeyedEduRow, + EduRow, + FailureRow, + DeviceRow, + ) +} + + +ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( + "presence", # list(UserPresenceState) + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "failures", # dict of destination -> [failures] + "device_destinations", # set of destinations +)) + + +def process_rows_for_federation(transaction_queue, rows): + """Parse a list of rows from the federation stream and put them in the + transaction queue ready for sending to the relevant homeservers. + + Args: + transaction_queue (TransactionQueue) + rows (list(synapse.replication.tcp.streams.FederationStreamRow)) + """ + + # The federation stream contains a bunch of different types of + # rows that need to be handled differently. We parse the rows, put + # them into the appropriate collection and then send them off. + + buff = ParsedFederationStreamData( + presence=[], + keyed_edus={}, + edus={}, + failures={}, + device_destinations=set(), + ) + + # Parse the rows in the stream and add to the buffer + for row in rows: + if row.type not in TypeToRow: + logger.error("Unrecognized federation row type %r", row.type) + continue + + RowType = TypeToRow[row.type] + parsed_row = RowType.from_data(row.data) + parsed_row.add_to_buffer(buff) + + if buff.presence: + transaction_queue.send_presence(buff.presence) + + for destination, edu_map in buff.keyed_edus.iteritems(): + for key, edu in edu_map.items(): + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=key, + ) + + for destination, edu_list in buff.edus.iteritems(): + for edu in edu_list: + transaction_queue.send_edu( + edu.destination, edu.edu_type, edu.content, key=None, + ) + + for destination, failure_list in buff.failures.iteritems(): + for failure in failure_list: + transaction_queue.send_failure(destination, failure) + + for destination in buff.device_destinations: + transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index c27ce7c5f..a15198e05 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -21,11 +21,10 @@ from .units import Transaction, Edu from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import preserve_context_over_fn, preserve_fn from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.metrics import measure_func -from synapse.types import get_domain_from_id -from synapse.handlers.presence import format_user_presence_state +from synapse.handlers.presence import format_user_presence_state, get_interested_remotes import synapse.metrics import logging @@ -41,6 +40,8 @@ sent_pdus_destination_dist = client_metrics.register_distribution( ) sent_edus_counter = client_metrics.register_counter("sent_edus") +sent_transactions_counter = client_metrics.register_counter("sent_transactions") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -77,8 +78,18 @@ class TransactionQueue(object): # destination -> list of tuple(edu, deferred) self.pending_edus_by_dest = edus = {} - # Presence needs to be separate as we send single aggragate EDUs + # Map of user_id -> UserPresenceState for all the pending presence + # to be sent out by user_id. Entries here get processed and put in + # pending_presence_by_dest + self.pending_presence = {} + + # Map of destination -> user_id -> UserPresenceState of pending presence + # to be sent to each destinations self.pending_presence_by_dest = presence = {} + + # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered + # based on their key (e.g. typing events by room_id) + # Map of destination -> (edu_type, key) -> Edu self.pending_edus_keyed_by_dest = edus_keyed = {} metrics.register_callback( @@ -113,6 +124,8 @@ class TransactionQueue(object): self._is_processing = False self._last_poked_id = -1 + self._processing_pending_presence = False + def can_send_to(self, destination): """Can we send messages to the given server? @@ -169,15 +182,12 @@ class TransactionQueue(object): # Otherwise if the last member on a server in a room is # banned then it won't receive the event because it won't # be in the room after the ban. - users_in_room = yield self.state.get_current_user_in_room( + destinations = yield self.state.get_current_hosts_in_room( event.room_id, latest_event_ids=[ prev_id for prev_id, _ in event.prev_events ], ) - destinations = set( - get_domain_from_id(user_id) for user_id in users_in_room - ) if send_on_behalf_of is not None: # If we are sending the event on behalf of another server # then it already has the event and there is no reason to @@ -224,17 +234,71 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def send_presence(self, destination, states): - if not self.can_send_to(destination): - return + @preserve_fn # the caller should not yield on this + @defer.inlineCallbacks + def send_presence(self, states): + """Send the new presence states to the appropriate destinations. - self.pending_presence_by_dest.setdefault(destination, {}).update({ + This actually queues up the presence states ready for sending and + triggers a background task to process them and send out the transactions. + + Args: + states (list(UserPresenceState)) + """ + + # First we queue up the new presence by user ID, so multiple presence + # updates in quick successtion are correctly handled + # We only want to send presence for our own users, so lets always just + # filter here just in case. + self.pending_presence.update({ state.user_id: state for state in states + if self.is_mine_id(state.user_id) }) - preserve_context_over_fn( - self._attempt_new_transaction, destination - ) + # We then handle the new pending presence in batches, first figuring + # out the destinations we need to send each state to and then poking it + # to attempt a new transaction. We linearize this so that we don't + # accidentally mess up the ordering and send multiple presence updates + # in the wrong order + if self._processing_pending_presence: + return + + self._processing_pending_presence = True + try: + while True: + states_map = self.pending_presence + self.pending_presence = {} + + if not states_map: + break + + yield self._process_presence_inner(states_map.values()) + finally: + self._processing_pending_presence = False + + @measure_func("txnqueue._process_presence") + @defer.inlineCallbacks + def _process_presence_inner(self, states): + """Given a list of states populate self.pending_presence_by_dest and + poke to send a new transaction to each destination + + Args: + states (list(UserPresenceState)) + """ + hosts_and_states = yield get_interested_remotes(self.store, states, self.state) + + for destinations, states in hosts_and_states: + for destination in destinations: + if not self.can_send_to(destination): + continue + + self.pending_presence_by_dest.setdefault( + destination, {} + ).update({ + state.user_id: state for state in states + }) + + preserve_fn(self._attempt_new_transaction)(destination) def send_edu(self, destination, edu_type, content, key=None): edu = Edu( @@ -374,6 +438,7 @@ class TransactionQueue(object): destination, pending_pdus, pending_edus, pending_failures, ) if success: + sent_transactions_counter.inc() # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages if device_message_edus: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 15a03378f..52b2a717d 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -193,6 +193,26 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def make_membership_event(self, destination, room_id, user_id, membership): + """Asks a remote server to build and sign us a membership event + + Note that this does not append any events to any graphs. + + Args: + destination (str): address of remote homeserver + room_id (str): room to join/leave + user_id (str): user to be joined/left + membership (str): one of join/leave + + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body (ie, the new event). + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. + + Fails with ``NotRetryingDestination`` if we are not yet ready + to retry this server. + """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( @@ -201,11 +221,23 @@ class TransportLayerClient(object): ) path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id) + ignore_backoff = False + retry_on_dns_fail = False + + if membership == Membership.LEAVE: + # we particularly want to do our best to send leave events. The + # problem is that if it fails, we won't retry it later, so if the + # remote server was just having a momentary blip, the room will be + # out of sync. + ignore_backoff = True + retry_on_dns_fail = True + content = yield self.client.get_json( destination=destination, path=path, - retry_on_dns_fail=False, + retry_on_dns_fail=retry_on_dns_fail, timeout=20000, + ignore_backoff=ignore_backoff, ) defer.returnValue(content) @@ -232,6 +264,12 @@ class TransportLayerClient(object): destination=destination, path=path, data=content, + + # we want to do our best to send this through. The problem is + # that if it fails, we won't retry it later, so if the remote + # server was just having a momentary blip, the room will be out of + # sync. + ignore_backoff=True, ) defer.returnValue(response) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index c840da834..3d676e7d8 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -24,6 +24,7 @@ from synapse.http.servlet import ( ) from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string +from synapse.util.logcontext import preserve_fn from synapse.types import ThirdPartyInstanceID import functools @@ -79,6 +80,7 @@ class Authenticator(object): def __init__(self, hs): self.keyring = hs.get_keyring() self.server_name = hs.hostname + self.store = hs.get_datastore() # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks @@ -138,6 +140,13 @@ class Authenticator(object): logger.info("Request from %s", origin) request.authenticated_entity = origin + # If we get a valid signed request from the other side, its probably + # alive + retry_timings = yield self.store.get_destination_retry_timings(origin) + if retry_timings and retry_timings["retry_last_ts"]: + logger.info("Marking origin %r as up", origin) + preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0) + defer.returnValue(origin) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index e83adc833..faa5609c0 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -53,7 +53,20 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() - def ratelimit(self, requester): + @defer.inlineCallbacks + def ratelimit(self, requester, update=True): + """Ratelimits requests. + + Args: + requester (Requester) + update (bool): Whether to record that a request is being processed. + Set to False when doing multiple checks for one request (e.g. + to check up front if we would reject the request), and set to + True for the last call for a given request. + + Raises: + LimitExceededError if the request should be ratelimited + """ time_now = self.clock.time() user_id = requester.user.to_string() @@ -67,10 +80,25 @@ class BaseHandler(object): if requester.app_service and not requester.app_service.is_rate_limited(): return + # Check if there is a per user override in the DB. + override = yield self.store.get_ratelimit_for_user(user_id) + if override: + # If overriden with a null Hz then ratelimiting has been entirely + # disabled for the user + if not override.messages_per_second: + return + + messages_per_second = override.messages_per_second + burst_count = override.burst_count + else: + messages_per_second = self.hs.config.rc_messages_per_second + burst_count = self.hs.config.rc_message_burst_count + allowed, time_allowed = self.ratelimiter.send_message( user_id, time_now, - msg_rate_hz=self.hs.config.rc_messages_per_second, - burst_count=self.hs.config.rc_message_burst_count, + msg_rate_hz=messages_per_second, + burst_count=burst_count, + update=update, ) if not allowed: raise LimitExceededError( diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c22f65ce5..982cda3ed 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -17,6 +17,7 @@ from synapse.api.constants import EventTypes from synapse.util import stringutils from synapse.util.async import Linearizer from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.retryutils import NotRetryingDestination from synapse.util.metrics import measure_func from synapse.types import get_domain_from_id, RoomStreamToken from twisted.internet import defer @@ -425,12 +426,38 @@ class DeviceListEduUpdater(object): # This can happen since we batch updates return + # Given a list of updates we check if we need to resync. This + # happens if we've missed updates. resync = yield self._need_to_do_resync(user_id, pending_updates) if resync: # Fetch all devices for the user. origin = get_domain_from_id(user_id) - result = yield self.federation.query_user_devices(origin, user_id) + try: + result = yield self.federation.query_user_devices(origin, user_id) + except NotRetryingDestination: + # TODO: Remember that we are now out of sync and try again + # later + logger.warn( + "Failed to handle device list update for %s," + " we're not retrying the remote", + user_id, + ) + # We abort on exceptions rather than accepting the update + # as otherwise synapse will 'forget' that its device list + # is out of date. If we bail then we will retry the resync + # next time we get a device list update for this user_id. + # This makes it more likely that the device lists will + # eventually become consistent. + return + except Exception: + # TODO: Remember that we are now out of sync and try again + # later + logger.exception( + "Failed to handle device list update for %s", user_id + ) + return + stream_id = result["stream_id"] devices = result["devices"] yield self.store.update_remote_device_list_cache( diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index c2b38d72a..668a90e49 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, CodeMessageException from synapse.types import get_domain_from_id -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -145,7 +145,7 @@ class E2eKeysHandler(object): "status": 503, "message": e.message } - yield preserve_context_over_deferred(defer.gatherResults([ + yield make_deferred_yieldable(defer.gatherResults([ preserve_fn(do_remote_query)(destination) for destination in remote_queries_not_in_cache ])) @@ -257,11 +257,21 @@ class E2eKeysHandler(object): "status": 503, "message": e.message } - yield preserve_context_over_deferred(defer.gatherResults([ + yield make_deferred_yieldable(defer.gatherResults([ preserve_fn(claim_client_keys)(destination) for destination in remote_queries ])) + logger.info( + "Claimed one-time-keys: %s", + ",".join(( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in json_result.iteritems() + for device_id, device_keys in user_keys.iteritems() + for key_id, _ in device_keys.iteritems() + )), + ) + defer.returnValue({ "one_time_keys": json_result, "failures": failures @@ -288,19 +298,8 @@ class E2eKeysHandler(object): one_time_keys = keys.get("one_time_keys", None) if one_time_keys: - logger.info( - "Adding %d one_time_keys for device %r for user %r at %d", - len(one_time_keys), device_id, user_id, time_now - ) - key_list = [] - for key_id, key_json in one_time_keys.items(): - algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, encode_canonical_json(key_json) - )) - - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, key_list + yield self._upload_one_time_keys_for_user( + user_id, device_id, time_now, one_time_keys, ) # the device should have been registered already, but it may have been @@ -313,3 +312,58 @@ class E2eKeysHandler(object): result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue({"one_time_key_counts": result}) + + @defer.inlineCallbacks + def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, + one_time_keys): + logger.info( + "Adding one_time_keys %r for device %r for user %r at %d", + one_time_keys.keys(), device_id, user_id, time_now, + ) + + # make a list of (alg, id, key) tuples + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, key_obj + )) + + # First we check if we have already persisted any of the keys. + existing_key_map = yield self.store.get_e2e_one_time_keys( + user_id, device_id, [k_id for _, k_id, _ in key_list] + ) + + new_keys = [] # Keys that we need to insert. (alg, id, json) tuples. + for algorithm, key_id, key in key_list: + ex_json = existing_key_map.get((algorithm, key_id), None) + if ex_json: + if not _one_time_keys_match(ex_json, key): + raise SynapseError( + 400, + ("One time key %s:%s already exists. " + "Old key: %s; new key: %r") % + (algorithm, key_id, ex_json, key) + ) + else: + new_keys.append((algorithm, key_id, encode_canonical_json(key))) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, new_keys + ) + + +def _one_time_keys_match(old_key_json, new_key): + old_key = json.loads(old_key_json) + + # if either is a string rather than an object, they must match exactly + if not isinstance(old_key, dict) or not isinstance(new_key, dict): + return old_key == new_key + + # otherwise, we strip off the 'signatures' if any, because it's legitimate + # for different upload attempts to have different signatures. + old_key.pop("signatures", None) + new_key_copy = dict(new_key) + new_key_copy.pop("signatures", None) + + return old_key == new_key_copy diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 53f929639..52d97dfbf 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -28,7 +28,7 @@ from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred + preserve_fn, preserve_context_over_deferred ) from synapse.util.metrics import measure_func from synapse.util.logutils import log_function @@ -172,8 +172,22 @@ class FederationHandler(BaseHandler): origin, pdu, prevs, min_depth ) - prevs = {e_id for e_id, _ in pdu.prev_events} - seen = set(have_seen.keys()) + # Update the set of things we've seen after trying to + # fetch the missing stuff + have_seen = yield self.store.have_events(prevs) + seen = set(have_seen.iterkeys()) + + if not prevs - seen: + logger.info( + "Found all missing prev events for %s", pdu.event_id + ) + elif prevs - seen: + logger.info( + "Not fetching %d missing events for room %r,event %s: %r...", + len(prevs - seen), pdu.room_id, pdu.event_id, + list(prevs - seen)[:5], + ) + if prevs - seen: logger.info( "Still missing %d events for room %r: %r...", @@ -208,19 +222,15 @@ class FederationHandler(BaseHandler): Args: origin (str): Origin of the pdu. Will be called to get the missing events pdu: received pdu - prevs (str[]): List of event ids which we are missing + prevs (set(str)): List of event ids which we are missing min_depth (int): Minimum depth of events to return. - - Returns: - Deferred: updated have_seen dictionary """ # We recalculate seen, since it may have changed. have_seen = yield self.store.have_events(prevs) seen = set(have_seen.keys()) if not prevs - seen: - # nothing left to do - defer.returnValue(have_seen) + return latest = yield self.store.get_latest_event_ids_in_room( pdu.room_id @@ -232,8 +242,8 @@ class FederationHandler(BaseHandler): latest |= seen logger.info( - "Missing %d events for room %r: %r...", - len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] + "Missing %d events for room %r pdu %s: %r...", + len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5] ) # XXX: we set timeout to 10s to help workaround @@ -265,22 +275,23 @@ class FederationHandler(BaseHandler): timeout=10000, ) + logger.info( + "Got %d events: %r...", + len(missing_events), [e.event_id for e in missing_events[:5]] + ) + # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) for e in missing_events: + logger.info("Handling found event %s", e.event_id) yield self.on_receive_pdu( origin, e, get_missing=False ) - have_seen = yield self.store.have_events( - [ev for ev, _ in pdu.prev_events] - ) - defer.returnValue(have_seen) - @log_function @defer.inlineCallbacks def _process_received_pdu(self, origin, pdu, state, auth_chain): @@ -369,13 +380,6 @@ class FederationHandler(BaseHandler): affected=event.event_id, ) - # if we're receiving valid events from an origin, - # it's probably a good idea to mark it as not in retry-state - # for sending (although this is a bit of a leap) - retry_timings = yield self.store.get_destination_retry_timings(origin) - if retry_timings and retry_timings["retry_last_ts"]: - self.store.set_destination_retry_timings(origin, 0, 0) - room = yield self.store.get_room(event.room_id) if not room: @@ -394,11 +398,10 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) if event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -916,11 +919,10 @@ class FederationHandler(BaseHandler): origin, auth_chain, state, event ) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[joinee] - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[joinee] + ) logger.debug("Finished joining %s to %s", joinee, room_id) finally: @@ -1035,10 +1037,9 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: @@ -1084,29 +1085,22 @@ class FederationHandler(BaseHandler): ) target_user = UserID.from_string(event.state_key) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[target_user], - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[target_user], + ) defer.returnValue(event) @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - try: - origin, event = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" - ) - event = self._sign_event(event) - except SynapseError: - raise - except CodeMessageException as e: - logger.warn("Failed to reject invite: %s", e) - raise SynapseError(500, "Failed to reject invite") + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave" + ) + event = self._sign_event(event) # Try the host that we succesfully called /make_leave/ on first for # the /send_leave/ request. @@ -1116,16 +1110,10 @@ class FederationHandler(BaseHandler): except ValueError: pass - try: - yield self.replication_layer.send_leave( - target_hosts, - event - ) - except SynapseError: - raise - except CodeMessageException as e: - logger.warn("Failed to reject invite: %s", e) - raise SynapseError(500, "Failed to reject invite") + yield self.replication_layer.send_leave( + target_hosts, + event + ) context = yield self.state_handler.compute_event_context(event) @@ -1246,10 +1234,9 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) defer.returnValue(None) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 6a53c5eb4..9efcdff1d 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import ( - CodeMessageException + MatrixCodeMessageException, CodeMessageException ) from ._base import BaseHandler from synapse.util.async import run_on_reactor @@ -90,6 +90,9 @@ class IdentityHandler(BaseHandler): ), {'sid': creds['sid'], 'client_secret': client_secret} ) + except MatrixCodeMessageException as e: + logger.info("getValidated3pid failed with Matrix error: %r", e) + raise SynapseError(e.code, e.msg, e.errcode) except CodeMessageException as e: data = json.loads(e.msg) @@ -159,6 +162,9 @@ class IdentityHandler(BaseHandler): params ) defer.returnValue(data) + except MatrixCodeMessageException as e: + logger.info("Proxied requestToken failed with Matrix error: %r", e) + raise SynapseError(e.code, e.msg, e.errcode) except CodeMessageException as e: logger.info("Proxied requestToken failed: %r", e) raise e @@ -193,6 +199,9 @@ class IdentityHandler(BaseHandler): params ) defer.returnValue(data) + except MatrixCodeMessageException as e: + logger.info("Proxied requestToken failed with Matrix error: %r", e) + raise SynapseError(e.code, e.msg, e.errcode) except CodeMessageException as e: logger.info("Proxied requestToken failed: %r", e) raise e diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7a498af5a..196925eda 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -175,7 +175,8 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None): + def create_event(self, requester, event_dict, token_id=None, txn_id=None, + prev_event_ids=None): """ Given a dict from a client, create a new event. @@ -185,6 +186,7 @@ class MessageHandler(BaseHandler): Adds display names to Join membership events. Args: + requester event_dict (dict): An entire event token_id (str) txn_id (str) @@ -226,6 +228,7 @@ class MessageHandler(BaseHandler): event, context = yield self._create_new_client_event( builder=builder, + requester=requester, prev_event_ids=prev_event_ids, ) @@ -251,17 +254,7 @@ class MessageHandler(BaseHandler): # We check here if we are currently being rate limited, so that we # don't do unnecessary work. We check again just before we actually # send the event. - time_now = self.clock.time() - allowed, time_allowed = self.ratelimiter.send_message( - event.sender, time_now, - msg_rate_hz=self.hs.config.rc_messages_per_second, - burst_count=self.hs.config.rc_message_burst_count, - update=False, - ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), - ) + yield self.ratelimit(requester, update=False) user = UserID.from_string(event.sender) @@ -319,6 +312,7 @@ class MessageHandler(BaseHandler): See self.create_event and self.send_nonmember_event. """ event, context = yield self.create_event( + requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id @@ -416,7 +410,7 @@ class MessageHandler(BaseHandler): @measure_func("_create_new_client_event") @defer.inlineCallbacks - def _create_new_client_event(self, builder, prev_event_ids=None): + def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): if prev_event_ids: prev_events = yield self.store.add_event_hashes(prev_event_ids) prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) @@ -456,6 +450,8 @@ class MessageHandler(BaseHandler): state_handler = self.state_handler context = yield state_handler.compute_event_context(builder) + if requester: + context.app_service = requester.app_service if builder.is_state(): builder.prev_state = yield self.store.add_event_hashes( @@ -493,7 +489,7 @@ class MessageHandler(BaseHandler): # We now need to go and hit out to wherever we need to hit out to. if ratelimit: - self.ratelimit(requester) + yield self.ratelimit(requester) try: yield self.auth.check_from_context(event, context) @@ -531,9 +527,9 @@ class MessageHandler(BaseHandler): state_to_include_ids = [ e_id - for k, e_id in context.current_state_ids.items() + for k, e_id in context.current_state_ids.iteritems() if k[0] in self.hs.config.room_invite_state_types - or k[0] == EventTypes.Member and k[1] == event.sender + or k == (EventTypes.Member, event.sender) ] state_to_include = yield self.store.get_events(state_to_include_ids) @@ -545,7 +541,7 @@ class MessageHandler(BaseHandler): "content": e.content, "sender": e.sender, } - for e in state_to_include.values() + for e in state_to_include.itervalues() ] invitee = UserID.from_string(event.state_key) @@ -612,12 +608,9 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _notify(): yield run_on_reactor() - yield self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) preserve_fn(_notify)() - - # If invite, remove room_state from unsigned before sending. - event.unsigned.pop("invite_room_state", None) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1ede117c7..c7c0b0a1e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -30,6 +30,7 @@ from synapse.api.constants import PresenceState from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.async import Linearizer from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -187,6 +188,7 @@ class PresenceHandler(object): # process_id to millisecond timestamp last updated. self.external_process_to_current_syncs = {} self.external_process_last_updated_ms = {} + self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to @@ -316,11 +318,7 @@ class PresenceHandler(object): if to_federation_ping: federation_presence_out_counter.inc_by(len(to_federation_ping)) - _, _, hosts_to_states = yield self._get_interested_parties( - to_federation_ping.values() - ) - - self._push_to_remotes(hosts_to_states) + self._push_to_remotes(to_federation_ping.values()) def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as @@ -508,6 +506,73 @@ class PresenceHandler(object): self.external_process_last_updated_ms[process_id] = self.clock.time_msec() self.external_process_to_current_syncs[process_id] = syncing_user_ids + @defer.inlineCallbacks + def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): + """Update the syncing users for an external process as a delta. + + Args: + process_id (str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. + user_id (str): The user who has started or stopped syncing + is_syncing (bool): Whether or not the user is now syncing + sync_time_msec(int): Time in ms when the user was last syncing + """ + with (yield self.external_sync_linearizer.queue(process_id)): + prev_state = yield self.current_state_for_user(user_id) + + process_presence = self.external_process_to_current_syncs.setdefault( + process_id, set() + ) + + updates = [] + if is_syncing and user_id not in process_presence: + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=sync_time_msec, + last_user_sync_ts=sync_time_msec, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=sync_time_msec, + )) + process_presence.add(user_id) + elif user_id in process_presence: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=sync_time_msec, + )) + + if not is_syncing: + process_presence.discard(user_id) + + if updates: + yield self._update_states(updates) + + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + + @defer.inlineCallbacks + def update_external_syncs_clear(self, process_id): + """Marks all users that had been marked as syncing by a given process + as offline. + + Used when the process has stopped/disappeared. + """ + with (yield self.external_sync_linearizer.queue(process_id)): + process_presence = self.external_process_to_current_syncs.pop( + process_id, set() + ) + prev_states = yield self.current_state_for_users(process_presence) + time_now_ms = self.clock.time_msec() + + yield self._update_states([ + prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + ) + for prev_state in prev_states.itervalues() + ]) + self.external_process_last_updated_ms.pop(process_id, None) + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. @@ -527,14 +592,14 @@ class PresenceHandler(object): for user_id in user_ids } - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = yield self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] + missing = [user_id for user_id, state in states.iteritems() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) @@ -545,54 +610,6 @@ class PresenceHandler(object): defer.returnValue(states) - @defer.inlineCallbacks - def _get_interested_parties(self, states, calculate_remote_hosts=True): - """Given a list of states return which entities (rooms, users, servers) - are interested in the given states. - - Returns: - 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`, - with each item being a dict of `entity_name` -> `[UserPresenceState]` - """ - room_ids_to_states = {} - users_to_states = {} - for state in states: - room_ids = yield self.store.get_rooms_for_user(state.user_id) - for room_id in room_ids: - room_ids_to_states.setdefault(room_id, []).append(state) - - plist = yield self.store.get_presence_list_observers_accepted(state.user_id) - for u in plist: - users_to_states.setdefault(u, []).append(state) - - # Always notify self - users_to_states.setdefault(state.user_id, []).append(state) - - hosts_to_states = {} - if calculate_remote_hosts: - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - if not local_states: - continue - - hosts = yield self.store.get_hosts_in_room(room_id) - - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) - - for user_id, states in users_to_states.items(): - local_states = filter(lambda s: self.is_mine_id(s.user_id), states) - if not local_states: - continue - - host = get_domain_from_id(user_id) - hosts_to_states.setdefault(host, []).extend(local_states) - - # TODO: de-dup hosts_to_states, as a single host might have multiple - # of same presence - - defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states)) - @defer.inlineCallbacks def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to @@ -600,34 +617,33 @@ class PresenceHandler(object): """ stream_id, max_token = yield self.store.update_presence(states) - parties = yield self._get_interested_parties(states) - room_ids_to_states, users_to_states, hosts_to_states = parties + parties = yield get_interested_parties(self.store, states) + room_ids_to_states, users_to_states = parties self.notifier.on_new_event( "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states.keys()] + users=[UserID.from_string(u) for u in users_to_states] ) - self._push_to_remotes(hosts_to_states) + self._push_to_remotes(states) @defer.inlineCallbacks def notify_for_states(self, state, stream_id): - parties = yield self._get_interested_parties([state]) - room_ids_to_states, users_to_states, hosts_to_states = parties + parties = yield get_interested_parties(self.store, [state]) + room_ids_to_states, users_to_states = parties self.notifier.on_new_event( "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states.keys()] + users=[UserID.from_string(u) for u in users_to_states] ) - def _push_to_remotes(self, hosts_to_states): + def _push_to_remotes(self, states): """Sends state updates to remote servers. Args: - hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` + states (list(UserPresenceState)) """ - for host, states in hosts_to_states.items(): - self.federation.send_presence(host, states) + self.federation.send_presence(states) @defer.inlineCallbacks def incoming_presence(self, origin, content): @@ -764,18 +780,17 @@ class PresenceHandler(object): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. - user_ids = yield self.store.get_users_in_room(room_id) if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) - hosts = set(get_domain_from_id(u) for u in user_ids) - self._push_to_remotes({host: (state,) for host in hosts}) + self._push_to_remotes([state]) else: + user_ids = yield self.store.get_users_in_room(room_id) user_ids = filter(self.is_mine_id, user_ids) states = yield self.current_state_for_users(user_ids) - self._push_to_remotes({user.domain: states.values()}) + self._push_to_remotes(states.values()) @defer.inlineCallbacks def get_presence_list(self, observer_user, accepted=None): @@ -1275,3 +1290,66 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): persist_and_notify = True return new_state, persist_and_notify, federation_ping + + +@defer.inlineCallbacks +def get_interested_parties(store, states): + """Given a list of states return which entities (rooms, users) + are interested in the given states. + + Args: + states (list(UserPresenceState)) + + Returns: + 2-tuple: `(room_ids_to_states, users_to_states)`, + with each item being a dict of `entity_name` -> `[UserPresenceState]` + """ + room_ids_to_states = {} + users_to_states = {} + for state in states: + room_ids = yield store.get_rooms_for_user(state.user_id) + for room_id in room_ids: + room_ids_to_states.setdefault(room_id, []).append(state) + + plist = yield store.get_presence_list_observers_accepted(state.user_id) + for u in plist: + users_to_states.setdefault(u, []).append(state) + + # Always notify self + users_to_states.setdefault(state.user_id, []).append(state) + + defer.returnValue((room_ids_to_states, users_to_states)) + + +@defer.inlineCallbacks +def get_interested_remotes(store, states, state_handler): + """Given a list of presence states figure out which remote servers + should be sent which. + + All the presence states should be for local users only. + + Args: + store (DataStore) + states (list(UserPresenceState)) + + Returns: + Deferred list of ([destinations], [UserPresenceState]), where for + each row the list of UserPresenceState should be sent to each + destination + """ + hosts_and_states = [] + + # First we look up the rooms each user is in (as well as any explicit + # subscriptions), then for each distinct room we look up the remote + # hosts in those rooms. + room_ids_to_states, users_to_states = yield get_interested_parties(store, states) + + for room_id, states in room_ids_to_states.iteritems(): + hosts = yield state_handler.get_current_hosts_in_room(room_id) + hosts_and_states.append((hosts, states)) + + for user_id, states in users_to_states.iteritems(): + host = get_domain_from_id(user_id) + hosts_and_states.append(([host], states)) + + defer.returnValue(hosts_and_states) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 9bf638f81..7abee98de 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler): if not self.hs.is_mine(user): return - self.ratelimit(requester) + yield self.ratelimit(requester) room_ids = yield self.store.get_rooms_for_user( user.to_string(), diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py new file mode 100644 index 000000000..b5b0303d5 --- /dev/null +++ b/synapse/handlers/read_marker.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseHandler + +from twisted.internet import defer + +from synapse.util.async import Linearizer + +import logging +logger = logging.getLogger(__name__) + + +class ReadMarkerHandler(BaseHandler): + def __init__(self, hs): + super(ReadMarkerHandler, self).__init__(hs) + self.server_name = hs.config.server_name + self.store = hs.get_datastore() + self.read_marker_linearizer = Linearizer(name="read_marker") + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def received_client_read_marker(self, room_id, user_id, event_id): + """Updates the read marker for a given user in a given room if the event ID given + is ahead in the stream relative to the current read marker. + + This uses a notifier to indicate that account data should be sent down /sync if + the read marker has changed. + """ + + with (yield self.read_marker_linearizer.queue((room_id, user_id))): + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + + existing_read_marker = account_data.get("m.fully_read", None) + + should_update = True + + if existing_read_marker: + # Only update if the new marker is ahead in the stream + should_update = yield self.store.is_event_after( + event_id, + existing_read_marker['event_id'] + ) + + if should_update: + content = { + "event_id": event_id + } + max_id = yield self.store.add_account_data_to_room( + user_id, room_id, "m.fully_read", content + ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 03c6a85fc..ee3a2269a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -54,6 +54,13 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME ) + if not localpart: + raise SynapseError( + 400, + "User ID cannot be empty", + Codes.INVALID_USERNAME + ) + if localpart[0] == '_': raise SynapseError( 400, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 99cb7db0d..d2a0d6520 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -75,7 +75,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - self.ratelimit(requester) + yield self.ratelimit(requester) if "room_alias_name" in config: for wchar in string.whitespace: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2052d6d05..1ca88517a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -70,6 +70,7 @@ class RoomMemberHandler(BaseHandler): content["kind"] = "guest" event, context = yield msg_handler.create_event( + requester, { "type": EventTypes.Member, "content": content, @@ -139,13 +140,6 @@ class RoomMemberHandler(BaseHandler): ) yield user_joined_room(self.distributor, user, room_id) - def reject_remote_invite(self, user_id, room_id, remote_room_hosts): - return self.hs.get_handlers().federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - user_id - ) - @defer.inlineCallbacks def update_membership( self, @@ -286,13 +280,21 @@ class RoomMemberHandler(BaseHandler): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - + fed_handler = self.hs.get_handlers().federation_handler try: - ret = yield self.reject_remote_invite( - target.to_string(), room_id, remote_room_hosts + ret = yield fed_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + target.to_string(), ) defer.returnValue(ret) - except SynapseError as e: + except Exception as e: + # if we were unable to reject the exception, just mark + # it as rejected on our end and plough ahead. + # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # logger.warn("Failed to reject invite: %s", e) yield self.store.locally_reject_invite( @@ -737,10 +739,11 @@ class RoomMemberHandler(BaseHandler): if len(current_state_ids) == 1 and create_event_id: defer.returnValue(self.hs.is_mine_id(create_event_id)) - for (etype, state_key), event_id in current_state_ids.items(): + for etype, state_key in current_state_ids: if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): continue + event_id = current_state_ids[(etype, state_key)] event = yield self.store.get_event(event_id, allow_none=True) if not event: continue diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0eea7f8f9..3b7818af5 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -24,7 +24,6 @@ from synapse.types import UserID, get_domain_from_id import logging from collections import namedtuple -import ujson as json logger = logging.getLogger(__name__) @@ -288,11 +287,13 @@ class TypingHandler(object): for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps(list(typing), ensure_ascii=False) - rows.append((serial, room_id, typing_bytes)) + rows.append((serial, room_id, list(typing))) rows.sort() return rows + def get_current_token(self): + return self._latest_room_serial + class TypingNotificationEventSource(object): def __init__(self, hs): diff --git a/synapse/http/client.py b/synapse/http/client.py index ca2f770f5..9eba046bb 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -16,9 +16,10 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import ( - CodeMessageException, SynapseError, Codes, + CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) from synapse.util.logcontext import preserve_context_over_fn +from synapse.util import logcontext import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -72,39 +73,45 @@ class SimpleHttpClient(object): contextFactory=hs.get_http_client_context_factory() ) self.user_agent = hs.version_string + self.clock = hs.get_clock() if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) + @defer.inlineCallbacks def request(self, method, uri, *args, **kwargs): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.inc(method) - d = preserve_context_over_fn( - self.agent.request, - method, uri, *args, **kwargs - ) + + def send_request(): + request_deferred = self.agent.request( + method, uri, *args, **kwargs + ) + + return self.clock.time_bound_deferred( + request_deferred, + time_out=60, + ) logger.info("Sending request %s %s", method, uri) - def _cb(response): + try: + with logcontext.PreserveLoggingContext(): + response = yield send_request() + incoming_responses_counter.inc(method, response.code) logger.info( "Received response to %s %s: %s", method, uri, response.code ) - return response - - def _eb(failure): + defer.returnValue(response) + except Exception as e: incoming_responses_counter.inc(method, "ERR") logger.info( "Error sending request to %s %s: %s %s", - method, uri, failure.type, failure.getErrorMessage() + method, uri, type(e).__name__, e.message ) - return failure - - d.addCallbacks(_cb, _eb) - - return d + raise e @defer.inlineCallbacks def post_urlencoded_get_json(self, uri, args={}): @@ -145,6 +152,11 @@ class SimpleHttpClient(object): body = yield preserve_context_over_fn(readBody, response) + if 200 <= response.code < 300: + defer.returnValue(json.loads(body)) + else: + raise self._exceptionFromFailedRequest(response, body) + defer.returnValue(json.loads(body)) @defer.inlineCallbacks @@ -164,8 +176,11 @@ class SimpleHttpClient(object): On a non-2xx HTTP response. The response body will be used as the error message. """ - body = yield self.get_raw(uri, args) - defer.returnValue(json.loads(body)) + try: + body = yield self.get_raw(uri, args) + defer.returnValue(json.loads(body)) + except CodeMessageException as e: + raise self._exceptionFromFailedRequest(e.code, e.msg) @defer.inlineCallbacks def put_json(self, uri, json_body, args={}): @@ -246,6 +261,15 @@ class SimpleHttpClient(object): else: raise CodeMessageException(response.code, body) + def _exceptionFromFailedRequest(self, response, body): + try: + jsonBody = json.loads(body) + errcode = jsonBody['errcode'] + error = jsonBody['error'] + return MatrixCodeMessageException(response.code, error, errcode) + except (ValueError, KeyError): + return CodeMessageException(response.code, body) + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 62b4d7e93..747a791f8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -125,6 +125,8 @@ class MatrixFederationHttpClient(object): code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. + (May also fail with plenty of other Exceptions for things like DNS + failures, connection failures, SSL failures.) """ limiter = yield synapse.util.retryutils.get_retry_limiter( destination, @@ -302,8 +304,10 @@ class MatrixFederationHttpClient(object): Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. On a 4xx or 5xx error response a - CodeMessageException is raised. + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. @@ -360,8 +364,10 @@ class MatrixFederationHttpClient(object): try the request anyway. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result - will be the decoded JSON body. On a 4xx or 5xx error response a - CodeMessageException is raised. + will be the decoded JSON body. + + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. @@ -410,10 +416,11 @@ class MatrixFederationHttpClient(object): ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. Returns: - Deferred: Succeeds when we get *any* HTTP response. + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. - The result of the deferred is a tuple of `(code, response)`, - where `response` is a dict representing the decoded JSON body. + Fails with ``HTTPRequestException`` if we get an HTTP response + code >= 300. Fails with ``NotRetryingDestination`` if we are not yet ready to retry this server. diff --git a/synapse/notifier.py b/synapse/notifier.py index 7eeba6d28..48566187a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -163,6 +163,8 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = [] + self.replication_callbacks = [] + self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -202,7 +204,12 @@ class Notifier(object): lambda: len(self.user_to_user_stream), ) - @preserve_fn + def add_replication_callback(self, cb): + """Add a callback that will be called when some new data is available. + Callback is not given any arguments. + """ + self.replication_callbacks.append(cb) + def on_new_room_event(self, event, room_stream_id, max_room_stream_id, extra_users=[]): """ Used by handlers to inform the notifier something has happened @@ -216,15 +223,13 @@ class Notifier(object): until all previous events have been persisted before notifying the client streams. """ - with PreserveLoggingContext(): - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) - self._notify_pending_new_room_events(max_room_stream_id) + self.pending_new_room_events.append(( + room_stream_id, event, extra_users + )) + self._notify_pending_new_room_events(max_room_stream_id) - self.notify_replication() + self.notify_replication() - @preserve_fn def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. @@ -242,14 +247,16 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) - @preserve_fn def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - self.appservice_handler.notify_interested_services(room_stream_id) + preserve_fn(self.appservice_handler.notify_interested_services)( + room_stream_id) if self.federation_sender: - self.federation_sender.notify_new_events(room_stream_id) + preserve_fn(self.federation_sender.notify_new_events)( + room_stream_id + ) if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -260,7 +267,6 @@ class Notifier(object): rooms=[event.room_id], ) - @preserve_fn def on_new_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend event wise. @@ -287,7 +293,6 @@ class Notifier(object): self.notify_replication() - @preserve_fn def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" @@ -510,6 +515,9 @@ class Notifier(object): self.replication_deferred = ObservableDeferred(defer.Deferred()) deferred.callback(None) + for cb in self.replication_callbacks: + preserve_fn(cb)() + @defer.inlineCallbacks def wait_for_replication(self, callback, timeout): """Wait for an event to happen. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 78b095c90..f943ff640 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -87,7 +87,11 @@ class BulkPushRuleEvaluator: condition_cache = {} for uid, rules in self.rules_by_user.items(): - display_name = room_members.get(uid, {}).get("display_name", None) + display_name = None + profile_info = room_members.get(uid) + if profile_info: + display_name = profile_info.display_name + if not display_name: # Handle the case where we are pushing a membership event to # that user, as they might not be already joined. diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3a50c72e0..f83aa7625 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -200,7 +200,11 @@ class Mailer(object): yield sendmail( self.hs.config.email_smtp_host, raw_from, raw_to, multipart_msg.as_string(), - port=self.hs.config.email_smtp_port + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security ) @defer.inlineCallbacks diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 287df94b4..6835f54e9 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,15 +17,12 @@ from twisted.internet import defer from synapse.push.presentable_names import ( calculate_room_name, name_from_member_event ) -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @defer.inlineCallbacks def get_badge_count(store, user_id): - invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.get_invited_rooms_for_user)(user_id), - preserve_fn(store.get_rooms_for_user)(user_id), - ], consumeErrors=True)) + invites = yield store.get_invited_rooms_for_user(user_id) + joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user( user_id, "m.read", diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py deleted file mode 100644 index c05a50d7a..000000000 --- a/synapse/replication/expire_cache.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.http.server import respond_with_json_bytes, request_handler -from synapse.http.servlet import parse_json_object_from_request - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET - - -class ExpireCacheResource(Resource): - """ - HTTP endpoint for expiring storage caches. - - POST /_synapse/replication/expire_cache HTTP/1.1 - Content-Type: application/json - - { - "invalidate": [ - { - "name": "func_name", - "keys": ["key1", "key2"] - } - ] - } - """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.store = hs.get_datastore() - self.version_string = hs.version_string - self.clock = hs.get_clock() - - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - - @request_handler() - def _async_render_POST(self, request): - content = parse_json_object_from_request(request) - - for row in content["invalidate"]: - name = row["name"] - keys = tuple(row["keys"]) - - getattr(self.store, name).invalidate(keys) - - respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/presence_resource.py b/synapse/replication/presence_resource.py deleted file mode 100644 index fc18130ab..000000000 --- a/synapse/replication/presence_resource.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.http.server import respond_with_json_bytes, request_handler -from synapse.http.servlet import parse_json_object_from_request - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer - - -class PresenceResource(Resource): - """ - HTTP endpoint for marking users as syncing. - - POST /_synapse/replication/presence HTTP/1.1 - Content-Type: application/json - - { - "process_id": "", - "syncing_users": [""] - } - """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.version_string = hs.version_string - self.clock = hs.get_clock() - self.presence_handler = hs.get_presence_handler() - - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - - @request_handler() - @defer.inlineCallbacks - def _async_render_POST(self, request): - content = parse_json_object_from_request(request) - - process_id = content["process_id"] - syncing_user_ids = content["syncing_users"] - - yield self.presence_handler.update_external_syncs( - process_id, set(syncing_user_ids) - ) - - respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/pusher_resource.py b/synapse/replication/pusher_resource.py deleted file mode 100644 index 9b01ab3c1..000000000 --- a/synapse/replication/pusher_resource.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.http.server import respond_with_json_bytes, request_handler -from synapse.http.servlet import parse_json_object_from_request - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer - - -class PusherResource(Resource): - """ - HTTP endpoint for deleting rejected pushers - """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.version_string = hs.version_string - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self.clock = hs.get_clock() - - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - - @request_handler() - @defer.inlineCallbacks - def _async_render_POST(self, request): - content = parse_json_object_from_request(request) - - for remove in content["remove"]: - yield self.store.delete_pusher_by_app_id_pushkey_user_id( - remove["app_id"], - remove["push_key"], - remove["user_id"], - ) - - self.notifier.on_new_replication_data() - - respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py deleted file mode 100644 index 03930fe95..000000000 --- a/synapse/replication/resource.py +++ /dev/null @@ -1,576 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.http.servlet import parse_integer, parse_string -from synapse.http.server import request_handler, finish_request -from synapse.replication.pusher_resource import PusherResource -from synapse.replication.presence_resource import PresenceResource -from synapse.replication.expire_cache import ExpireCacheResource -from synapse.api.errors import SynapseError - -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET -from twisted.internet import defer - -import ujson as json - -import collections -import logging - -logger = logging.getLogger(__name__) - -REPLICATION_PREFIX = "/_synapse/replication" - -STREAM_NAMES = ( - ("events",), - ("presence",), - ("typing",), - ("receipts",), - ("user_account_data", "room_account_data", "tag_account_data",), - ("backfill",), - ("push_rules",), - ("pushers",), - ("caches",), - ("to_device",), - ("public_rooms",), - ("federation",), - ("device_lists",), -) - - -class ReplicationResource(Resource): - """ - HTTP endpoint for extracting data from synapse. - - The streams of data returned by the endpoint are controlled by the - parameters given to the API. To return a given stream pass a query - parameter with a position in the stream to return data from or the - special value "-1" to return data from the start of the stream. - - If there is no data for any of the supplied streams after the given - position then the request will block until there is data for one - of the streams. This allows clients to long-poll this API. - - The possible streams are: - - * "streams": A special stream returing the positions of other streams. - * "events": The new events seen on the server. - * "presence": Presence updates. - * "typing": Typing updates. - * "receipts": Receipt updates. - * "user_account_data": Top-level per user account data. - * "room_account_data: Per room per user account data. - * "tag_account_data": Per room per user tags. - * "backfill": Old events that have been backfilled from other servers. - * "push_rules": Per user changes to push rules. - * "pushers": Per user changes to their pushers. - * "caches": Cache invalidations. - - The API takes two additional query parameters: - - * "timeout": How long to wait before returning an empty response. - * "limit": The maximum number of rows to return for the selected streams. - - The response is a JSON object with keys for each stream with updates. Under - each key is a JSON object with: - - * "position": The current position of the stream. - * "field_names": The names of the fields in each row. - * "rows": The updates as an array of arrays. - - There are a number of ways this API could be used: - - 1) To replicate the contents of the backing database to another database. - 2) To be notified when the contents of a shared backing database changes. - 3) To "tail" the activity happening on a server for debugging. - - In the first case the client would track all of the streams and store it's - own copy of the data. - - In the second case the client might theoretically just be able to follow - the "streams" stream to track where the other streams are. However in - practise it will probably need to get the contents of the streams in - order to expire the any in-memory caches. Whether it gets the contents - of the streams from this replication API or directly from the backing - store is a matter of taste. - - In the third case the client would use the "streams" stream to find what - streams are available and their current positions. Then it can start - long-polling this replication API for new data on those streams. - """ - - def __init__(self, hs): - Resource.__init__(self) # Resource is old-style, so no super() - - self.version_string = hs.version_string - self.store = hs.get_datastore() - self.sources = hs.get_event_sources() - self.presence_handler = hs.get_presence_handler() - self.typing_handler = hs.get_typing_handler() - self.federation_sender = hs.get_federation_sender() - self.notifier = hs.notifier - self.clock = hs.get_clock() - self.config = hs.get_config() - - self.putChild("remove_pushers", PusherResource(hs)) - self.putChild("syncing_users", PresenceResource(hs)) - self.putChild("expire_cache", ExpireCacheResource(hs)) - - def render_GET(self, request): - self._async_render_GET(request) - return NOT_DONE_YET - - @defer.inlineCallbacks - def current_replication_token(self): - stream_token = yield self.sources.get_current_token() - backfill_token = yield self.store.get_current_backfill_token() - push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() - pushers_token = self.store.get_pushers_stream_token() - caches_token = self.store.get_cache_stream_token() - public_rooms_token = self.store.get_current_public_room_stream_id() - federation_token = self.federation_sender.get_current_token() - device_list_token = self.store.get_device_stream_token() - - defer.returnValue(_ReplicationToken( - room_stream_token, - int(stream_token.presence_key), - int(stream_token.typing_key), - int(stream_token.receipt_key), - int(stream_token.account_data_key), - backfill_token, - push_rules_token, - pushers_token, - 0, # State stream is no longer a thing - caches_token, - int(stream_token.to_device_key), - int(public_rooms_token), - int(federation_token), - int(device_list_token), - )) - - @request_handler() - @defer.inlineCallbacks - def _async_render_GET(self, request): - limit = parse_integer(request, "limit", 100) - timeout = parse_integer(request, "timeout", 10 * 1000) - - request.setHeader(b"Content-Type", b"application/json") - - request_streams = { - name: parse_integer(request, name) - for names in STREAM_NAMES for name in names - } - request_streams["streams"] = parse_string(request, "streams") - - federation_ack = parse_integer(request, "federation_ack", None) - - def replicate(): - return self.replicate( - request_streams, limit, - federation_ack=federation_ack - ) - - writer = yield self.notifier.wait_for_replication(replicate, timeout) - result = writer.finish() - - for stream_name, stream_content in result.items(): - logger.info( - "Replicating %d rows of %s from %s -> %s", - len(stream_content["rows"]), - stream_name, - request_streams.get(stream_name), - stream_content["position"], - ) - - request.write(json.dumps(result, ensure_ascii=False)) - finish_request(request) - - @defer.inlineCallbacks - def replicate(self, request_streams, limit, federation_ack=None): - writer = _Writer() - current_token = yield self.current_replication_token() - logger.debug("Replicating up to %r", current_token) - - if limit == 0: - raise SynapseError(400, "Limit cannot be 0") - - yield self.account_data(writer, current_token, limit, request_streams) - yield self.events(writer, current_token, limit, request_streams) - # TODO: implement limit - yield self.presence(writer, current_token, request_streams) - yield self.typing(writer, current_token, request_streams) - yield self.receipts(writer, current_token, limit, request_streams) - yield self.push_rules(writer, current_token, limit, request_streams) - yield self.pushers(writer, current_token, limit, request_streams) - yield self.caches(writer, current_token, limit, request_streams) - yield self.to_device(writer, current_token, limit, request_streams) - yield self.public_rooms(writer, current_token, limit, request_streams) - yield self.device_lists(writer, current_token, limit, request_streams) - self.federation(writer, current_token, limit, request_streams, federation_ack) - self.streams(writer, current_token, request_streams) - - logger.debug("Replicated %d rows", writer.total) - defer.returnValue(writer) - - def streams(self, writer, current_token, request_streams): - request_token = request_streams.get("streams") - - streams = [] - - if request_token is not None: - if request_token == "-1": - for names, position in zip(STREAM_NAMES, current_token): - streams.extend((name, position) for name in names) - else: - items = zip( - STREAM_NAMES, - current_token, - _ReplicationToken(request_token) - ) - for names, current_id, last_id in items: - if last_id < current_id: - streams.extend((name, current_id) for name in names) - - if streams: - writer.write_header_and_rows( - "streams", streams, ("name", "position"), - position=str(current_token) - ) - - @defer.inlineCallbacks - def events(self, writer, current_token, limit, request_streams): - request_events = request_streams.get("events") - request_backfill = request_streams.get("backfill") - - if request_events is not None or request_backfill is not None: - if request_events is None: - request_events = current_token.events - if request_backfill is None: - request_backfill = current_token.backfill - - no_new_tokens = ( - request_events == current_token.events - and request_backfill == current_token.backfill - ) - if no_new_tokens: - return - - res = yield self.store.get_all_new_events( - request_backfill, request_events, - current_token.backfill, current_token.events, - limit - ) - - upto_events_token = _position_from_rows( - res.new_forward_events, current_token.events - ) - - upto_backfill_token = _position_from_rows( - res.new_backfill_events, current_token.backfill - ) - - if request_events != upto_events_token: - writer.write_header_and_rows("events", res.new_forward_events, ( - "position", "event_id", "room_id", "type", "state_key", - ), position=upto_events_token) - - if request_backfill != upto_backfill_token: - writer.write_header_and_rows("backfill", res.new_backfill_events, ( - "position", "event_id", "room_id", "type", "state_key", "redacts", - ), position=upto_backfill_token) - - writer.write_header_and_rows( - "forward_ex_outliers", res.forward_ex_outliers, - ("position", "event_id", "state_group"), - ) - writer.write_header_and_rows( - "backward_ex_outliers", res.backward_ex_outliers, - ("position", "event_id", "state_group"), - ) - - @defer.inlineCallbacks - def presence(self, writer, current_token, request_streams): - current_position = current_token.presence - - request_presence = request_streams.get("presence") - - if request_presence is not None and request_presence != current_position: - presence_rows = yield self.presence_handler.get_all_presence_updates( - request_presence, current_position - ) - upto_token = _position_from_rows(presence_rows, current_position) - writer.write_header_and_rows("presence", presence_rows, ( - "position", "user_id", "state", "last_active_ts", - "last_federation_update_ts", "last_user_sync_ts", - "status_msg", "currently_active", - ), position=upto_token) - - @defer.inlineCallbacks - def typing(self, writer, current_token, request_streams): - current_position = current_token.typing - - request_typing = request_streams.get("typing") - - if request_typing is not None and request_typing != current_position: - # If they have a higher token than current max, we can assume that - # they had been talking to a previous instance of the master. Since - # we reset the token on restart, the best (but hacky) thing we can - # do is to simply resend down all the typing notifications. - if request_typing > current_position: - request_typing = 0 - - typing_rows = yield self.typing_handler.get_all_typing_updates( - request_typing, current_position - ) - upto_token = _position_from_rows(typing_rows, current_position) - writer.write_header_and_rows("typing", typing_rows, ( - "position", "room_id", "typing" - ), position=upto_token) - - @defer.inlineCallbacks - def receipts(self, writer, current_token, limit, request_streams): - current_position = current_token.receipts - - request_receipts = request_streams.get("receipts") - - if request_receipts is not None and request_receipts != current_position: - receipts_rows = yield self.store.get_all_updated_receipts( - request_receipts, current_position, limit - ) - upto_token = _position_from_rows(receipts_rows, current_position) - writer.write_header_and_rows("receipts", receipts_rows, ( - "position", "room_id", "receipt_type", "user_id", "event_id", "data" - ), position=upto_token) - - @defer.inlineCallbacks - def account_data(self, writer, current_token, limit, request_streams): - current_position = current_token.account_data - - user_account_data = request_streams.get("user_account_data") - room_account_data = request_streams.get("room_account_data") - tag_account_data = request_streams.get("tag_account_data") - - if user_account_data is not None or room_account_data is not None: - if user_account_data is None: - user_account_data = current_position - if room_account_data is None: - room_account_data = current_position - - no_new_tokens = ( - user_account_data == current_position - and room_account_data == current_position - ) - if no_new_tokens: - return - - user_rows, room_rows = yield self.store.get_all_updated_account_data( - user_account_data, room_account_data, current_position, limit - ) - - upto_users_token = _position_from_rows(user_rows, current_position) - upto_rooms_token = _position_from_rows(room_rows, current_position) - - writer.write_header_and_rows("user_account_data", user_rows, ( - "position", "user_id", "type", "content" - ), position=upto_users_token) - writer.write_header_and_rows("room_account_data", room_rows, ( - "position", "user_id", "room_id", "type", "content" - ), position=upto_rooms_token) - - if tag_account_data is not None: - tag_rows = yield self.store.get_all_updated_tags( - tag_account_data, current_position, limit - ) - upto_tag_token = _position_from_rows(tag_rows, current_position) - writer.write_header_and_rows("tag_account_data", tag_rows, ( - "position", "user_id", "room_id", "tags" - ), position=upto_tag_token) - - @defer.inlineCallbacks - def push_rules(self, writer, current_token, limit, request_streams): - current_position = current_token.push_rules - - push_rules = request_streams.get("push_rules") - - if push_rules is not None and push_rules != current_position: - rows = yield self.store.get_all_push_rule_updates( - push_rules, current_position, limit - ) - upto_token = _position_from_rows(rows, current_position) - writer.write_header_and_rows("push_rules", rows, ( - "position", "event_stream_ordering", "user_id", "rule_id", "op", - "priority_class", "priority", "conditions", "actions" - ), position=upto_token) - - @defer.inlineCallbacks - def pushers(self, writer, current_token, limit, request_streams): - current_position = current_token.pushers - - pushers = request_streams.get("pushers") - - if pushers is not None and pushers != current_position: - updated, deleted = yield self.store.get_all_updated_pushers( - pushers, current_position, limit - ) - upto_token = _position_from_rows(updated, current_position) - writer.write_header_and_rows("pushers", updated, ( - "position", "user_id", "access_token", "profile_tag", "kind", - "app_id", "app_display_name", "device_display_name", "pushkey", - "ts", "lang", "data" - ), position=upto_token) - writer.write_header_and_rows("deleted_pushers", deleted, ( - "position", "user_id", "app_id", "pushkey" - ), position=upto_token) - - @defer.inlineCallbacks - def caches(self, writer, current_token, limit, request_streams): - current_position = current_token.caches - - caches = request_streams.get("caches") - - if caches is not None and caches != current_position: - updated_caches = yield self.store.get_all_updated_caches( - caches, current_position, limit - ) - upto_token = _position_from_rows(updated_caches, current_position) - writer.write_header_and_rows("caches", updated_caches, ( - "position", "cache_func", "keys", "invalidation_ts" - ), position=upto_token) - - @defer.inlineCallbacks - def to_device(self, writer, current_token, limit, request_streams): - current_position = current_token.to_device - - to_device = request_streams.get("to_device") - - if to_device is not None and to_device != current_position: - to_device_rows = yield self.store.get_all_new_device_messages( - to_device, current_position, limit - ) - upto_token = _position_from_rows(to_device_rows, current_position) - writer.write_header_and_rows("to_device", to_device_rows, ( - "position", "user_id", "device_id", "message_json" - ), position=upto_token) - - @defer.inlineCallbacks - def public_rooms(self, writer, current_token, limit, request_streams): - current_position = current_token.public_rooms - - public_rooms = request_streams.get("public_rooms") - - if public_rooms is not None and public_rooms != current_position: - public_rooms_rows = yield self.store.get_all_new_public_rooms( - public_rooms, current_position, limit - ) - upto_token = _position_from_rows(public_rooms_rows, current_position) - writer.write_header_and_rows("public_rooms", public_rooms_rows, ( - "position", "room_id", "visibility", "appservice_id", "network_id", - ), position=upto_token) - - def federation(self, writer, current_token, limit, request_streams, federation_ack): - if self.config.send_federation: - return - - current_position = current_token.federation - - federation = request_streams.get("federation") - - if federation is not None and federation != current_position: - federation_rows = self.federation_sender.get_replication_rows( - federation, limit, federation_ack=federation_ack, - ) - upto_token = _position_from_rows(federation_rows, current_position) - writer.write_header_and_rows("federation", federation_rows, ( - "position", "type", "content", - ), position=upto_token) - - @defer.inlineCallbacks - def device_lists(self, writer, current_token, limit, request_streams): - current_position = current_token.device_lists - - device_lists = request_streams.get("device_lists") - - if device_lists is not None and device_lists != current_position: - changes = yield self.store.get_all_device_list_changes_for_remotes( - device_lists, - ) - writer.write_header_and_rows("device_lists", changes, ( - "position", "user_id", "destination", - ), position=current_position) - - -class _Writer(object): - """Writes the streams as a JSON object as the response to the request""" - def __init__(self): - self.streams = {} - self.total = 0 - - def write_header_and_rows(self, name, rows, fields, position=None): - if position is None: - if rows: - position = rows[-1][0] - else: - return - - self.streams[name] = { - "position": position if type(position) is int else str(position), - "field_names": fields, - "rows": rows, - } - - self.total += len(rows) - - def __nonzero__(self): - return bool(self.total) - - def finish(self): - return self.streams - - -class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( - "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", - "federation", "device_lists", -))): - __slots__ = [] - - def __new__(cls, *args): - if len(args) == 1: - streams = [int(value) for value in args[0].split("_")] - if len(streams) < len(cls._fields): - streams.extend([0] * (len(cls._fields) - len(streams))) - return cls(*streams) - else: - return super(_ReplicationToken, cls).__new__(cls, *args) - - def __str__(self): - return "_".join(str(value) for value in self) - - -def _position_from_rows(rows, current_position): - """Calculates a position to return for a stream. Ideally we want to return the - position of the last row, as that will be the most correct. However, if there - are no rows we fall back to using the current position to stop us from - repeatedly hitting the storage layer unncessarily thinking there are updates. - (Not all advances of the token correspond to an actual update) - - We can't just always return the current position, as we often limit the - number of rows we replicate, and so the stream may lag. The assumption is - that if the storage layer returns no new rows then we are not lagging and - we are at the `current_position`. - """ - if rows: - return rows[-1][0] - return current_position diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index ab133db87..b96264116 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,7 +15,6 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from twisted.internet import defer from ._slaved_id_tracker import SlavedIdTracker @@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore): else: self._cache_id_gen = None - self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" - self.http_client = hs.get_simple_http_client() + self.hs = hs def stream_positions(self): pos = {} @@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore): pos["caches"] = self._cache_id_gen.get_current_token() return pos - def process_replication(self, result): - stream = result.get("caches") - if stream: - for row in stream["rows"]: - ( - position, cache_func, keys, invalidation_ts, - ) = row - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "caches": + self._cache_id_gen.advance(token) + for row in rows: try: - getattr(self, cache_func).invalidate(tuple(keys)) + getattr(self, row.cache_func).invalidate(tuple(row.keys)) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. pass - self._cache_id_gen.advance(int(stream["position"])) - return defer.succeed(None) def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) txn.call_after(self._send_invalidation_poke, cache_func, keys) - @defer.inlineCallbacks def _send_invalidation_poke(self, cache_func, keys): - try: - yield self.http_client.post_json_get_json(self.expire_cache_url, { - "invalidate": [{ - "name": cache_func.__name__, - "keys": list(keys), - }] - }) - except: - logger.exception("Failed to poke on expire_cache") + self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 77c64722c..efbd87918 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore): result["tag_account_data"] = position return result - def process_replication(self, result): - stream = result.get("user_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id, data_type = row[:3] - self.get_global_account_data_by_type_for_user.invalidate( - (data_type, user_id,) - ) - self.get_account_data_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "tag_account_data": + self._account_data_id_gen.advance(token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - stream = result.get("room_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - self.get_account_data_for_user.invalidate((user_id,)) + elif stream_name == "account_data": + self._account_data_id_gen.advance(token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id,) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - stream = result.get("tag_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - self.get_tags_for_user.invalidate((user_id,)) - self._account_data_stream_cache.entity_has_changed( - user_id, position - ) - - return super(SlavedAccountDataStore, self).process_replication(result) + return super(SlavedAccountDataStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index f9102e0d8..6f3fb6477 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore): result["to_device"] = self._device_inbox_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("to_device") - if stream: - self._device_inbox_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - entity = row[1] - - if entity.startswith("@"): + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "to_device": + self._device_inbox_id_gen.advance(token) + for row in rows: + if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) else: self._device_federation_outbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) - - return super(SlavedDeviceInboxStore, self).process_replication(result) + return super(SlavedDeviceInboxStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index ca46aa17b..4d4a43547 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore): result["device_lists"] = self._device_list_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("device_lists") - if stream: - self._device_list_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - user_id = row[1] - destination = row[2] - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "device_lists": + self._device_list_id_gen.advance(token) + for row in rows: self._device_list_stream_cache.entity_has_changed( - user_id, stream_id + row.user_id, token ) - if destination: + if row.destination: self._device_list_federation_stream_cache.entity_has_changed( - destination, stream_id + row.destination, token ) - - return super(SlavedDeviceStore, self).process_replication(result) + return super(SlavedDeviceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d4db1e452..fcaf58b93 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -71,6 +71,7 @@ class SlavedEventStore(BaseSlavedStore): # to reach inside the __dict__ to extract them. get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"] get_users_who_share_room_with_user = ( RoomMemberStore.__dict__["get_users_who_share_room_with_user"] ) @@ -101,9 +102,6 @@ class SlavedEventStore(BaseSlavedStore): _get_state_groups_from_groups_txn = ( DataStore._get_state_groups_from_groups_txn.__func__ ) - _get_state_group_from_group = ( - StateStore.__dict__["_get_state_group_from_group"] - ) get_recent_event_ids_for_room = ( StreamStore.__dict__["get_recent_event_ids_for_room"] ) @@ -146,6 +144,9 @@ class SlavedEventStore(BaseSlavedStore): RoomMemberStore.__dict__["_get_joined_users_from_context"] ) + get_joined_hosts = DataStore.get_joined_hosts.__func__ + _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"] + get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ get_room_events_stream_for_rooms = ( DataStore.get_room_events_stream_for_rooms.__func__ @@ -201,48 +202,25 @@ class SlavedEventStore(BaseSlavedStore): result["backfill"] = -self._backfill_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("events") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - - if stream["rows"]: - logger.info("Got %d event rows", len(stream["rows"])) - - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=False, + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_event( + token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + backfilled=False, ) - - stream = result.get("backfill") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=True, + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + for row in rows: + self.invalidate_caches_for_event( + -token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + backfilled=True, ) - - stream = result.get("forward_ex_outliers") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - stream = result.get("backward_ex_outliers") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - return super(SlavedEventStore, self).process_replication(result) - - def _process_replication_row(self, row, backfilled): - stream_ordering = row[0] if not backfilled else -row[0] - self.invalidate_caches_for_event( - stream_ordering, row[1], row[2], row[3], row[4], row[5], - backfilled=backfilled, + return super(SlavedEventStore, self).process_replication_rows( + stream_name, token, rows ) def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index e4a2414d7..cfb928018 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -39,6 +39,16 @@ class SlavedPresenceStore(BaseSlavedStore): _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] + # XXX: This is a bit broken because we don't persist the accepted list in a + # way that can be replicated. This means that we don't have a way to + # invalidate the cache correctly. + get_presence_list_accepted = PresenceStore.__dict__[ + "get_presence_list_accepted" + ] + get_presence_list_observers_accepted = PresenceStore.__dict__[ + "get_presence_list_observers_accepted" + ] + def get_current_presence_token(self): return self._presence_id_gen.get_current_token() @@ -48,15 +58,14 @@ class SlavedPresenceStore(BaseSlavedStore): result["presence"] = position return result - def process_replication(self, result): - stream = result.get("presence") - if stream: - self._presence_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "presence": + self._presence_id_gen.advance(token) + for row in rows: self.presence_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - self._get_presence_for_user.invalidate((user_id,)) - - return super(SlavedPresenceStore, self).process_replication(result) + self._get_presence_for_user.invalidate((row.user_id,)) + return super(SlavedPresenceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 21ceb0213..83e880fdd 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore): result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("push_rules") - if stream: - for row in stream["rows"]: - position = row[0] - user_id = row[2] - self.get_push_rules_for_user.invalidate((user_id,)) - self.get_push_rules_enabled_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "push_rules": + self._push_rules_stream_id_gen.advance(token) + for row in rows: + self.get_push_rules_for_user.invalidate((row.user_id,)) + self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - self._push_rules_stream_id_gen.advance(int(stream["position"])) - - return super(SlavedPushRuleStore, self).process_replication(result) + return super(SlavedPushRuleStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index d88206b3b..4e8d68ece 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore): result["pushers"] = self._pushers_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - stream = result.get("deleted_pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - return super(SlavedPusherStore, self).process_replication(result) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "pushers": + self._pushers_id_gen.advance(token) + return super(SlavedPusherStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ac9662d39..b371574ec 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore): result["receipts"] = self._receipts_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("receipts") - if stream: - self._receipts_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, room_id, receipt_type, user_id = row[:4] - self.invalidate_caches_for_receipt(room_id, receipt_type, user_id) - self._receipts_stream_cache.entity_has_changed(room_id, position) - - return super(SlavedReceiptsStore, self).process_replication(result) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_linearized_receipts_for_room.invalidate_many((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "receipts": + self._receipts_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) + + return super(SlavedReceiptsStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 6df9a25ef..f51038403 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore): result["public_rooms"] = self._public_room_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("public_rooms") - if stream: - self._public_room_id_gen.advance(int(stream["position"])) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "public_rooms": + self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication(result) + return super(RoomStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py new file mode 100644 index 000000000..81c2ea7ee --- /dev/null +++ b/synapse/replication/tcp/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module implements the TCP replication protocol used by synapse to +communicate between the master process and its workers (when they're enabled). + +Further details can be found in docs/tcp_replication.rst + + +Structure of the module: + * client.py - the client classes used for workers to connect to master + * command.py - the definitions of all the valid commands + * protocol.py - contains bot the client and server protocol implementations, + these should not be used directly + * resource.py - the server classes that accepts and handle client connections + * streams.py - the definitons of all the valid streams + +""" diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py new file mode 100644 index 000000000..90fb6c133 --- /dev/null +++ b/synapse/replication/tcp/client.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A replication client for use by synapse workers. +""" + +from twisted.internet import reactor, defer +from twisted.internet.protocol import ReconnectingClientFactory + +from .commands import ( + FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, +) +from .protocol import ClientReplicationStreamProtocol + +import logging + +logger = logging.getLogger(__name__) + + +class ReplicationClientFactory(ReconnectingClientFactory): + """Factory for building connections to the master. Will reconnect if the + connection is lost. + + Accepts a handler that will be called when new data is available or data + is required. + """ + maxDelay = 5 # Try at least once every N seconds + + def __init__(self, hs, client_name, handler): + self.client_name = client_name + self.handler = handler + self.server_name = hs.config.server_name + self._clock = hs.get_clock() # As self.clock is defined in super class + + reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + + def startedConnecting(self, connector): + logger.info("Connecting to replication: %r", connector.getDestination()) + + def buildProtocol(self, addr): + logger.info("Connected to replication: %r", addr) + self.resetDelay() + return ClientReplicationStreamProtocol( + self.client_name, self.server_name, self._clock, self.handler + ) + + def clientConnectionLost(self, connector, reason): + logger.error("Lost replication conn: %r", reason) + ReconnectingClientFactory.clientConnectionLost(self, connector, reason) + + def clientConnectionFailed(self, connector, reason): + logger.error("Failed to connect to replication: %r", reason) + ReconnectingClientFactory.clientConnectionFailed( + self, connector, reason + ) + + +class ReplicationClientHandler(object): + """A base handler that can be passed to the ReplicationClientFactory. + + By default proxies incoming replication data to the SlaveStore. + """ + def __init__(self, store): + self.store = store + + # The current connection. None if we are currently (re)connecting + self.connection = None + + # Any pending commands to be sent once a new connection has been + # established + self.pending_commands = [] + + # Map from string -> deferred, to wake up when receiveing a SYNC with + # the given string. + # Used for tests. + self.awaiting_syncs = {} + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + reactor.connectTCP(host, port, factory) + + def on_rdata(self, stream_name, token, rows): + """Called when we get new replication data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + logger.info("Received rdata %s -> %s", stream_name, token) + self.store.process_replication_rows(stream_name, token, rows) + + def on_position(self, stream_name, token): + """Called when we get new position data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + self.store.process_replication_rows(stream_name, token, []) + + def on_sync(self, data): + """When we received a SYNC we wake up any deferreds that were waiting + for the sync with the given data. + + Used by tests. + """ + d = self.awaiting_syncs.pop(data, None) + if d: + d.callback(data) + + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns a dictionary of stream name to token. + """ + args = self.store.stream_positions() + user_account_data = args.pop("user_account_data", None) + room_account_data = args.pop("room_account_data", None) + if user_account_data: + args["account_data"] = user_account_data + elif room_account_data: + args["account_data"] = room_account_data + return args + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. (Overriden by the synchrotron's only) + """ + return [] + + def send_command(self, cmd): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self.connection: + self.connection.send_command(cmd) + else: + logger.warn("Queuing command as not connected: %r", cmd.NAME) + self.pending_commands.append(cmd) + + def send_federation_ack(self, token): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + + def send_remove_pusher(self, app_id, push_key, user_id): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func, keys): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def await_sync(self, data): + """Returns a deferred that is resolved when we receive a SYNC command + with given data. + + Used by tests. + """ + return self.awaiting_syncs.setdefault(data, defer.Deferred()) + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self.connection = connection + if connection: + for cmd in self.pending_commands: + connection.send_command(cmd) + self.pending_commands = [] diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py new file mode 100644 index 000000000..84d2a2272 --- /dev/null +++ b/synapse/replication/tcp/commands.py @@ -0,0 +1,346 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the various valid commands + +The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are +allowed to be sent by which side. +""" + +import logging +import ujson as json + + +logger = logging.getLogger(__name__) + + +class Command(object): + """The base command class. + + All subclasses must set the NAME variable which equates to the name of the + command on the wire. + + A full command line on the wire is constructed from `NAME + " " + to_line()` + + The default implementation creates a command of form ` ` + """ + NAME = None + + def __init__(self, data): + self.data = data + + @classmethod + def from_line(cls, line): + """Deserialises a line from the wire into this command. `line` does not + include the command. + """ + return cls(line) + + def to_line(self): + """Serialises the comamnd for the wire. Does not include the command + prefix. + """ + return self.data + + +class ServerCommand(Command): + """Sent by the server on new connection and includes the server_name. + + Format:: + + SERVER + """ + NAME = "SERVER" + + +class RdataCommand(Command): + """Sent by server when a subscribed stream has an update. + + Format:: + + RDATA + + The `` may either be a numeric stream id OR "batch". The latter case + is used to support sending multiple updates with the same stream ID. This + is done by sending an RDATA for each row, with all but the last RDATA having + a token of "batch" and the last having the final stream ID. + + The client should batch all incoming RDATA with a token of "batch" (per + stream_name) until it sees an RDATA with a numeric stream ID. + + `` of "batch" maps to the instance variable `token` being None. + + An example of a batched series of RDATA:: + + RDATA presence batch ["@foo:example.com", "online", ...] + RDATA presence batch ["@bar:example.com", "online", ...] + RDATA presence 59 ["@baz:example.com", "online", ...] + """ + NAME = "RDATA" + + def __init__(self, stream_name, token, row): + self.stream_name = stream_name + self.token = token + self.row = row + + @classmethod + def from_line(cls, line): + stream_name, token, row_json = line.split(" ", 2) + return cls( + stream_name, + None if token == "batch" else int(token), + json.loads(row_json) + ) + + def to_line(self): + return " ".join(( + self.stream_name, + str(self.token) if self.token is not None else "batch", + json.dumps(self.row), + )) + + +class PositionCommand(Command): + """Sent by the client to tell the client the stream postition without + needing to send an RDATA. + """ + NAME = "POSITION" + + def __init__(self, stream_name, token): + self.stream_name = stream_name + self.token = token + + @classmethod + def from_line(cls, line): + stream_name, token = line.split(" ", 1) + return cls(stream_name, int(token)) + + def to_line(self): + return " ".join((self.stream_name, str(self.token),)) + + +class ErrorCommand(Command): + """Sent by either side if there was an ERROR. The data is a string describing + the error. + """ + NAME = "ERROR" + + +class PingCommand(Command): + """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """ + NAME = "PING" + + +class NameCommand(Command): + """Sent by client to inform the server of the client's identity. The data + is the name + """ + NAME = "NAME" + + +class ReplicateCommand(Command): + """Sent by the client to subscribe to the stream. + + Format:: + + REPLICATE + + Where may be either: + * a numeric stream_id to stream updates from + * "NOW" to stream all subsequent updates. + + The can be "ALL" to subscribe to all known streams, in which + case the must be set to "NOW", i.e.:: + + REPLICATE ALL NOW + """ + NAME = "REPLICATE" + + def __init__(self, stream_name, token): + self.stream_name = stream_name + self.token = token + + @classmethod + def from_line(cls, line): + stream_name, token = line.split(" ", 1) + if token in ("NOW", "now"): + token = "NOW" + else: + token = int(token) + return cls(stream_name, token) + + def to_line(self): + return " ".join((self.stream_name, str(self.token),)) + + +class UserSyncCommand(Command): + """Sent by the client to inform the server that a user has started or + stopped syncing. Used to calculate presence on the master. + + Includes a timestamp of when the last user sync was. + + Format:: + + USER_SYNC + + Where is either "start" or "stop" + """ + NAME = "USER_SYNC" + + def __init__(self, user_id, is_syncing, last_sync_ms): + self.user_id = user_id + self.is_syncing = is_syncing + self.last_sync_ms = last_sync_ms + + @classmethod + def from_line(cls, line): + user_id, state, last_sync_ms = line.split(" ", 2) + + if state not in ("start", "end"): + raise Exception("Invalid USER_SYNC state %r" % (state,)) + + return cls(user_id, state == "start", int(last_sync_ms)) + + def to_line(self): + return " ".join(( + self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), + )) + + +class FederationAckCommand(Command): + """Sent by the client when it has processed up to a given point in the + federation stream. This allows the master to drop in-memory caches of the + federation stream. + + This must only be sent from one worker (i.e. the one sending federation) + + Format:: + + FEDERATION_ACK + """ + NAME = "FEDERATION_ACK" + + def __init__(self, token): + self.token = token + + @classmethod + def from_line(cls, line): + return cls(int(line)) + + def to_line(self): + return str(self.token) + + +class SyncCommand(Command): + """Used for testing. The client protocol implementation allows waiting + on a SYNC command with a specified data. + """ + NAME = "SYNC" + + +class RemovePusherCommand(Command): + """Sent by the client to request the master remove the given pusher. + + Format:: + + REMOVE_PUSHER + """ + NAME = "REMOVE_PUSHER" + + def __init__(self, app_id, push_key, user_id): + self.user_id = user_id + self.app_id = app_id + self.push_key = push_key + + @classmethod + def from_line(cls, line): + app_id, push_key, user_id = line.split(" ", 2) + + return cls(app_id, push_key, user_id) + + def to_line(self): + return " ".join((self.app_id, self.push_key, self.user_id)) + + +class InvalidateCacheCommand(Command): + """Sent by the client to invalidate an upstream cache. + + THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE + NOT DISASTROUS IF WE DROP ON THE FLOOR. + + Mainly used to invalidate destination retry timing caches. + + Format:: + + INVALIDATE_CACHE + + Where is a json list. + """ + NAME = "INVALIDATE_CACHE" + + def __init__(self, cache_func, keys): + self.cache_func = cache_func + self.keys = keys + + @classmethod + def from_line(cls, line): + cache_func, keys_json = line.split(" ", 1) + + return cls(cache_func, json.loads(keys_json)) + + def to_line(self): + return " ".join((self.cache_func, json.dumps(self.keys))) + + +# Map of command name to command type. +COMMAND_MAP = { + cmd.NAME: cmd + for cmd in ( + ServerCommand, + RdataCommand, + PositionCommand, + ErrorCommand, + PingCommand, + NameCommand, + ReplicateCommand, + UserSyncCommand, + FederationAckCommand, + SyncCommand, + RemovePusherCommand, + InvalidateCacheCommand, + ) +} + +# The commands the server is allowed to send +VALID_SERVER_COMMANDS = ( + ServerCommand.NAME, + RdataCommand.NAME, + PositionCommand.NAME, + ErrorCommand.NAME, + PingCommand.NAME, + SyncCommand.NAME, +) + +# The commands the client is allowed to send +VALID_CLIENT_COMMANDS = ( + NameCommand.NAME, + ReplicateCommand.NAME, + PingCommand.NAME, + UserSyncCommand.NAME, + FederationAckCommand.NAME, + RemovePusherCommand.NAME, + InvalidateCacheCommand.NAME, + ErrorCommand.NAME, +) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py new file mode 100644 index 000000000..9fee2a484 --- /dev/null +++ b/synapse/replication/tcp/protocol.py @@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the implementation of both the client and server +protocols. + +The basic structure of the protocol is line based, where the initial word of +each line specifies the command. The rest of the line is parsed based on the +command. For example, the `RDATA` command is defined as:: + + RDATA + +(Note that `` may contains spaces, but cannot contain newlines.) + +Blank lines are ignored. + +# Example + +An example iteraction is shown below. Each line is prefixed with '>' or '<' to +indicate which side is sending, these are *not* included on the wire:: + + * connection established * + > SERVER localhost:8823 + > PING 1490197665618 + < NAME synapse.app.appservice + < PING 1490197665618 + < REPLICATE events 1 + < REPLICATE backfill 1 + < REPLICATE caches 1 + > POSITION events 1 + > POSITION backfill 1 + > POSITION caches 1 + > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] + < PING 1490197675618 + > ERROR server stopping + * connection closed by server * +""" + +from twisted.internet import defer +from twisted.protocols.basic import LineOnlyReceiver +from twisted.python.failure import Failure + +from commands import ( + COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, + ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, + NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand, +) +from streams import STREAMS_MAP + +from synapse.util.stringutils import random_string +from synapse.metrics.metric import CounterMetric + +import logging +import synapse.metrics +import struct +import fcntl + + +metrics = synapse.metrics.get_metrics_for(__name__) + +connection_close_counter = metrics.register_counter( + "close_reason", labels=["reason_type"], +) + + +# A list of all connected protocols. This allows us to send metrics about the +# connections. +connected_connections = [] + + +logger = logging.getLogger(__name__) + + +PING_TIME = 5000 +PING_TIMEOUT_MULTIPLIER = 5 +PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER + + +class ConnectionStates(object): + CONNECTING = "connecting" + ESTABLISHED = "established" + PAUSED = "paused" + CLOSED = "closed" + + +class BaseReplicationStreamProtocol(LineOnlyReceiver): + """Base replication protocol shared between client and server. + + Reads lines (ignoring blank ones) and parses them into command classes, + asserting that they are valid for the given direction, i.e. server commands + are only sent by the server. + + On receiving a new command it calls `on_` with the parsed + command. + + It also sends `PING` periodically, and correctly times out remote connections + (if they send a `PING` command) + """ + delimiter = b'\n' + + VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive + VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send + + max_line_buffer = 10000 + + def __init__(self, clock): + self.clock = clock + + self.last_received_command = self.clock.time_msec() + self.last_sent_command = 0 + self.time_we_closed = None # When we requested the connection be closed + + self.received_ping = False # Have we reecived a ping from the other side + + self.state = ConnectionStates.CONNECTING + + self.name = "anon" # The name sent by a client. + self.conn_id = random_string(5) # To dedupe in case of name clashes. + + # List of pending commands to send once we've established the connection + self.pending_commands = [] + + # The LoopingCall for sending pings. + self._send_ping_loop = None + + self.inbound_commands_counter = CounterMetric( + "inbound_commands", labels=["command"], + ) + self.outbound_commands_counter = CounterMetric( + "outbound_commands", labels=["command"], + ) + + def connectionMade(self): + logger.info("[%s] Connection established", self.id()) + + self.state = ConnectionStates.ESTABLISHED + + connected_connections.append(self) # Register connection for metrics + + self.transport.registerProducer(self, True) # For the *Producing callbacks + + self._send_pending_commands() + + # Starts sending pings + self._send_ping_loop = self.clock.looping_call(self.send_ping, 5000) + + # Always send the initial PING so that the other side knows that they + # can time us out. + self.send_command(PingCommand(self.clock.time_msec())) + + def send_ping(self): + """Periodically sends a ping and checks if we should close the connection + due to the other side timing out. + """ + now = self.clock.time_msec() + + if self.time_we_closed: + if now - self.time_we_closed > PING_TIMEOUT_MS: + logger.info( + "[%s] Failed to close connection gracefully, aborting", self.id() + ) + self.transport.abortConnection() + else: + if now - self.last_sent_command >= PING_TIME: + self.send_command(PingCommand(now)) + + if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: + logger.info( + "[%s] Connection hasn't received command in %r ms. Closing.", + self.id(), now - self.last_received_command + ) + self.send_error("ping timeout") + + def lineReceived(self, line): + """Called when we've received a line + """ + if line.strip() == "": + # Ignore blank lines + return + + line = line.decode("utf-8") + cmd_name, rest_of_line = line.split(" ", 1) + + if cmd_name not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd_name) + self.send_error("invalid command: %s", cmd_name) + return + + self.last_received_command = self.clock.time_msec() + + self.inbound_commands_counter.inc(cmd_name) + + cmd_cls = COMMAND_MAP[cmd_name] + try: + cmd = cmd_cls.from_line(rest_of_line) + except Exception as e: + logger.exception( + "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line + ) + self.send_error( + "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) + ) + return + + # Now lets try and call on_ function + try: + getattr(self, "on_%s" % (cmd_name,))(cmd) + except Exception: + logger.exception("[%s] Failed to handle line: %r", self.id(), line) + + def close(self): + logger.warn("[%s] Closing connection", self.id()) + self.time_we_closed = self.clock.time_msec() + self.transport.loseConnection() + self.on_connection_closed() + + def send_error(self, error_string, *args): + """Send an error to remote and close the connection. + """ + self.send_command(ErrorCommand(error_string % args)) + self.close() + + def send_command(self, cmd, do_buffer=True): + """Send a command if connection has been established. + + Args: + cmd (Command) + do_buffer (bool): Whether to buffer the message or always attempt + to send the command. This is mostly used to send an error + message if we're about to close the connection due our buffers + becoming full. + """ + if self.state == ConnectionStates.CLOSED: + logger.info("[%s] Not sending, connection closed", self.id()) + return + + if do_buffer and self.state != ConnectionStates.ESTABLISHED: + self._queue_command(cmd) + return + + self.outbound_commands_counter.inc(cmd.NAME) + + string = "%s %s" % (cmd.NAME, cmd.to_line(),) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + self.sendLine(string.encode("utf-8")) + + self.last_sent_command = self.clock.time_msec() + + def _queue_command(self, cmd): + """Queue the command until the connection is ready to write to again. + """ + logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + self.pending_commands.append(cmd) + + if len(self.pending_commands) > self.max_line_buffer: + # The other side is failing to keep up and out buffers are becoming + # full, so lets close the connection. + # XXX: should we squawk more loudly? + logger.error("[%s] Remote failed to keep up", self.id()) + self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False) + self.close() + + def _send_pending_commands(self): + """Send any queued commandes + """ + pending = self.pending_commands + self.pending_commands = [] + for cmd in pending: + self.send_command(cmd) + + def on_PING(self, line): + self.received_ping = True + + def on_ERROR(self, cmd): + logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) + + def pauseProducing(self): + """This is called when both the kernel send buffer and the twisted + tcp connection send buffers have become full. + + We don't actually have any control over those sizes, so we buffer some + commands ourselves before knifing the connection due to the remote + failing to keep up. + """ + logger.info("[%s] Pause producing", self.id()) + self.state = ConnectionStates.PAUSED + + def resumeProducing(self): + """The remote has caught up after we started buffering! + """ + logger.info("[%s] Resume producing", self.id()) + self.state = ConnectionStates.ESTABLISHED + self._send_pending_commands() + + def stopProducing(self): + """We're never going to send any more data (normally because either + we or the remote has closed the connection) + """ + logger.info("[%s] Stop producing", self.id()) + self.on_connection_closed() + + def connectionLost(self, reason): + logger.info("[%s] Replication connection closed: %r", self.id(), reason) + if isinstance(reason, Failure): + connection_close_counter.inc(reason.type.__name__) + else: + connection_close_counter.inc(reason.__class__.__name__) + + try: + # Remove us from list of connections to be monitored + connected_connections.remove(self) + except ValueError: + pass + + # Stop the looping call sending pings. + if self._send_ping_loop and self._send_ping_loop.running: + self._send_ping_loop.stop() + + self.on_connection_closed() + + def on_connection_closed(self): + logger.info("[%s] Connection was closed", self.id()) + + self.state = ConnectionStates.CLOSED + self.pending_commands = [] + + if self.transport: + self.transport.unregisterProducer() + + def __str__(self): + return "ReplicationConnection" % ( + self.name, self.conn_id, self.addr, + ) + + def id(self): + return "%s-%s" % (self.name, self.conn_id) + + +class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): + VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS + VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS + + def __init__(self, server_name, clock, streamer, addr): + BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + + self.server_name = server_name + self.streamer = streamer + self.addr = addr + + # The streams the client has subscribed to and is up to date with + self.replication_streams = set() + + # The streams the client is currently subscribing to. + self.connecting_streams = set() + + # Map from stream name to list of updates to send once we've finished + # subscribing the client to the stream. + self.pending_rdata = {} + + def connectionMade(self): + self.send_command(ServerCommand(self.server_name)) + BaseReplicationStreamProtocol.connectionMade(self) + self.streamer.new_connection(self) + + def on_NAME(self, cmd): + logger.info("[%s] Renamed to %r", self.id(), cmd.data) + self.name = cmd.data + + def on_USER_SYNC(self, cmd): + self.streamer.on_user_sync( + self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, + ) + + def on_REPLICATE(self, cmd): + stream_name = cmd.stream_name + token = cmd.token + + if stream_name == "ALL": + # Subscribe to all streams we're publishing to. + for stream in self.streamer.streams_by_name.iterkeys(): + self.subscribe_to_stream(stream, token) + else: + self.subscribe_to_stream(stream_name, token) + + def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) + + def on_REMOVE_PUSHER(self, cmd): + self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + + def on_INVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + + @defer.inlineCallbacks + def subscribe_to_stream(self, stream_name, token): + """Subscribe the remote to a streams. + + This invloves checking if they've missed anything and sending those + updates down if they have. During that time new updates for the stream + are queued and sent once we've sent down any missed updates. + """ + self.replication_streams.discard(stream_name) + self.connecting_streams.add(stream_name) + + try: + # Get missing updates + updates, current_token = yield self.streamer.get_stream_updates( + stream_name, token, + ) + + # Send all the missing updates + for update in updates: + token, row = update[0], update[1] + self.send_command(RdataCommand(stream_name, token, row)) + + # We send a POSITION command to ensure that they have an up to + # date token (especially useful if we didn't send any updates + # above) + self.send_command(PositionCommand(stream_name, current_token)) + + # Now we can send any updates that came in while we were subscribing + pending_rdata = self.pending_rdata.pop(stream_name, []) + for token, update in pending_rdata: + # Only send updates newer than the current token + if token > current_token: + self.send_command(RdataCommand(stream_name, token, update)) + + # They're now fully subscribed + self.replication_streams.add(stream_name) + except Exception as e: + logger.exception("[%s] Failed to handle REPLICATE command", self.id()) + self.send_error("failed to handle replicate: %r", e) + finally: + self.connecting_streams.discard(stream_name) + + def stream_update(self, stream_name, token, data): + """Called when a new update is available to stream to clients. + + We need to check if the client is interested in the stream or not + """ + if stream_name in self.replication_streams: + # The client is subscribed to the stream + self.send_command(RdataCommand(stream_name, token, data)) + elif stream_name in self.connecting_streams: + # The client is being subscribed to the stream + logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) + self.pending_rdata.setdefault(stream_name, []).append((token, data)) + else: + # The client isn't subscribed + logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + + def send_sync(self, data): + self.send_command(SyncCommand(data)) + + def on_connection_closed(self): + BaseReplicationStreamProtocol.on_connection_closed(self) + self.streamer.lost_connection(self) + + +class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): + VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS + VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS + + def __init__(self, client_name, server_name, clock, handler): + BaseReplicationStreamProtocol.__init__(self, clock) + + self.client_name = client_name + self.server_name = server_name + self.handler = handler + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self.pending_batches = {} + + def connectionMade(self): + self.send_command(NameCommand(self.client_name)) + BaseReplicationStreamProtocol.connectionMade(self) + + # Once we've connected subscribe to the necessary streams + for stream_name, token in self.handler.get_streams_to_replicate().iteritems(): + self.replicate(stream_name, token) + + # Tell the server if we have any users currently syncing (should only + # happen on synchrotrons) + currently_syncing = self.handler.get_currently_syncing_users() + now = self.clock.time_msec() + for user_id in currently_syncing: + self.send_command(UserSyncCommand(user_id, True, now)) + + # We've now finished connecting to so inform the client handler + self.handler.update_connection(self) + + def on_SERVER(self, cmd): + if cmd.data != self.server_name: + logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) + self.send_error("Wrong remote") + + def on_RDATA(self, cmd): + try: + row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) + except Exception: + logger.exception( + "[%s] Failed to parse RDATA: %r %r", + self.id(), cmd.stream_name, cmd.row + ) + raise + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream. Batch + # until we get an update for the stream with a non None token + self.pending_batches.setdefault(cmd.stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self.pending_batches.pop(cmd.stream_name, []) + rows.append(row) + + self.handler.on_rdata(cmd.stream_name, cmd.token, rows) + + def on_POSITION(self, cmd): + self.handler.on_position(cmd.stream_name, cmd.token) + + def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) + + def replicate(self, stream_name, token): + """Send the subscription request to the server + """ + if stream_name not in STREAMS_MAP: + raise Exception("Invalid stream name %r" % (stream_name,)) + + logger.info( + "[%s] Subscribing to replication stream: %r from %r", + self.id(), stream_name, token + ) + + self.send_command(ReplicateCommand(stream_name, token)) + + def on_connection_closed(self): + BaseReplicationStreamProtocol.on_connection_closed(self) + self.handler.update_connection(None) + + +# The following simply registers metrics for the replication connections + +metrics.register_callback( + "pending_commands", + lambda: { + (p.name, p.conn_id): len(p.pending_commands) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +def transport_buffer_size(protocol): + if protocol.transport: + size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen + return size + return 0 + + +metrics.register_callback( + "transport_send_buffer", + lambda: { + (p.name, p.conn_id): transport_buffer_size(p) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +def transport_kernel_read_buffer_size(protocol, read=True): + SIOCINQ = 0x541B + SIOCOUTQ = 0x5411 + + if protocol.transport: + fileno = protocol.transport.getHandle().fileno() + if read: + op = SIOCINQ + else: + op = SIOCOUTQ + size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + return size + return 0 + + +metrics.register_callback( + "transport_kernel_send_buffer", + lambda: { + (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +metrics.register_callback( + "transport_kernel_read_buffer", + lambda: { + (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +metrics.register_callback( + "inbound_commands", + lambda: { + (k[0], p.name, p.conn_id): count + for p in connected_connections + for k, count in p.inbound_commands_counter.counts.iteritems() + }, + labels=["command", "name", "conn_id"], +) + +metrics.register_callback( + "outbound_commands", + lambda: { + (k[0], p.name, p.conn_id): count + for p in connected_connections + for k, count in p.outbound_commands_counter.counts.iteritems() + }, + labels=["command", "name", "conn_id"], +) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py new file mode 100644 index 000000000..8b2c4c304 --- /dev/null +++ b/synapse/replication/tcp/resource.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The server side of the replication stream. +""" + +from twisted.internet import defer, reactor +from twisted.internet.protocol import Factory + +from streams import STREAMS_MAP, FederationStream +from protocol import ServerReplicationStreamProtocol + +from synapse.util.metrics import Measure, measure_func + +import logging +import synapse.metrics + + +metrics = synapse.metrics.get_metrics_for(__name__) +stream_updates_counter = metrics.register_counter( + "stream_updates", labels=["stream_name"] +) +user_sync_counter = metrics.register_counter("user_sync") +federation_ack_counter = metrics.register_counter("federation_ack") +remove_pusher_counter = metrics.register_counter("remove_pusher") +invalidate_cache_counter = metrics.register_counter("invalidate_cache") + +logger = logging.getLogger(__name__) + + +class ReplicationStreamProtocolFactory(Factory): + """Factory for new replication connections. + """ + def __init__(self, hs): + self.streamer = ReplicationStreamer(hs) + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + + def buildProtocol(self, addr): + return ServerReplicationStreamProtocol( + self.server_name, + self.clock, + self.streamer, + addr + ) + + +class ReplicationStreamer(object): + """Handles replication connections. + + This needs to be poked when new replication data may be available. When new + data is available it will propagate to all connected clients. + """ + + def __init__(self, hs): + self.store = hs.get_datastore() + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + + # Current connections. + self.connections = [] + + metrics.register_callback("total_connections", lambda: len(self.connections)) + + # List of streams that clients can subscribe to. + # We only support federation stream if federation sending hase been + # disabled on the master. + self.streams = [ + stream(hs) for stream in STREAMS_MAP.itervalues() + if stream != FederationStream or not hs.config.send_federation + ] + + self.streams_by_name = {stream.NAME: stream for stream in self.streams} + + metrics.register_callback( + "connections_per_stream", + lambda: { + (stream_name,): len([ + conn for conn in self.connections + if stream_name in conn.replication_streams + ]) + for stream_name in self.streams_by_name + }, + labels=["stream_name"], + ) + + self.federation_sender = None + if not hs.config.send_federation: + self.federation_sender = hs.get_federation_sender() + + hs.get_notifier().add_replication_callback(self.on_notifier_poke) + + # Keeps track of whether we are currently checking for updates + self.is_looping = False + self.pending_updates = False + + reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) + + def on_shutdown(self): + # close all connections on shutdown + for conn in self.connections: + conn.send_error("server shutting down") + + @defer.inlineCallbacks + def on_notifier_poke(self): + """Checks if there is actually any new data and sends it to the + connections if there are. + + This should get called each time new data is available, even if it + is currently being executed, so that nothing gets missed + """ + if not self.connections: + # Don't bother if nothing is listening. We still need to advance + # the stream tokens otherwise they'll fall beihind forever + for stream in self.streams: + stream.discard_updates_and_advance() + return + + # If we're in the process of checking for new updates, mark that fact + # and return + if self.is_looping: + logger.debug("Noitifier poke loop already running") + self.pending_updates = True + return + + self.pending_updates = True + self.is_looping = True + + try: + # Keep looping while there have been pokes about potential updates. + # This protects against the race where a stream we already checked + # gets an update while we're handling other streams. + while self.pending_updates: + self.pending_updates = False + + with Measure(self.clock, "repl.stream.get_updates"): + # First we tell the streams that they should update their + # current tokens. + for stream in self.streams: + stream.advance_current_token() + + for stream in self.streams: + if stream.last_token == stream.upto_token: + continue + + logger.debug( + "Getting stream: %s: %s -> %s", + stream.NAME, stream.last_token, stream.upto_token + ) + updates, current_token = yield stream.get_updates() + + logger.debug( + "Sending %d updates to %d connections", + len(updates), len(self.connections), + ) + + if updates: + logger.info( + "Streaming: %s -> %s", stream.NAME, updates[-1][0] + ) + stream_updates_counter.inc_by(len(updates), stream.NAME) + + # Some streams return multiple rows with the same stream IDs, + # we need to make sure they get sent out in batches. We do + # this by setting the current token to all but the last of + # a series of updates with the same token to have a None + # token. See RdataCommand for more details. + batched_updates = _batch_updates(updates) + + for conn in self.connections: + for token, row in batched_updates: + try: + conn.stream_update(stream.NAME, token, row) + except Exception: + logger.exception("Failed to replicate") + + logger.debug("No more pending updates, breaking poke loop") + finally: + self.pending_updates = False + self.is_looping = False + + @measure_func("repl.get_stream_updates") + def get_stream_updates(self, stream_name, token): + """For a given stream get all updates since token. This is called when + a client first subscribes to a stream. + """ + stream = self.streams_by_name.get(stream_name, None) + if not stream: + raise Exception("unknown stream %s", stream_name) + + return stream.get_updates_since(token) + + @measure_func("repl.federation_ack") + def federation_ack(self, token): + """We've received an ack for federation stream from a client. + """ + federation_ack_counter.inc() + if self.federation_sender: + self.federation_sender.federation_ack(token) + + @measure_func("repl.on_user_sync") + def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + """A client has started/stopped syncing on a worker. + """ + user_sync_counter.inc() + self.presence_handler.update_external_syncs_row( + conn_id, user_id, is_syncing, last_sync_ms, + ) + + @measure_func("repl.on_remove_pusher") + @defer.inlineCallbacks + def on_remove_pusher(self, app_id, push_key, user_id): + """A client has asked us to remove a pusher + """ + remove_pusher_counter.inc() + yield self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id=app_id, pushkey=push_key, user_id=user_id + ) + + self.notifier.on_new_replication_data() + + @measure_func("repl.on_invalidate_cache") + def on_invalidate_cache(self, cache_func, keys): + """The client has asked us to invalidate a cache + """ + invalidate_cache_counter.inc() + getattr(self.store, cache_func).invalidate(tuple(keys)) + + def send_sync_to_all_connections(self, data): + """Sends a SYNC command to all clients. + + Used in tests. + """ + for conn in self.connections: + conn.send_sync(data) + + def new_connection(self, connection): + """A new client connection has been established + """ + self.connections.append(connection) + + def lost_connection(self, connection): + """A client connection has been lost + """ + try: + self.connections.remove(connection) + except ValueError: + pass + + # We need to tell the presence handler that the connection has been + # lost so that it can handle any ongoing syncs on that connection. + self.presence_handler.update_external_syncs_clear(connection.conn_id) + + +def _batch_updates(updates): + """Takes a list of updates of form [(token, row)] and sets the token to + None for all rows where the next row has the same token. This is used to + implement batching. + + For example: + + [(1, _), (1, _), (2, _), (3, _), (3, _)] + + becomes: + + [(None, _), (1, _), (2, _), (None, _), (3, _)] + """ + if not updates: + return [] + + new_updates = [] + for i, update in enumerate(updates[:-1]): + if update[0] == updates[i + 1][0]: + new_updates.append((None, update[1])) + else: + new_updates.append(update) + + new_updates.append(updates[-1]) + return new_updates diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py new file mode 100644 index 000000000..369d5f242 --- /dev/null +++ b/synapse/replication/tcp/streams.py @@ -0,0 +1,464 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines all the valid streams that clients can subscribe to, and the format +of the rows returned by each stream. + +Each stream is defined by the following information: + + stream name: The name of the stream + row type: The type that is used to serialise/deserialse the row + current_token: The function that returns the current token for the stream + update_function: The function that returns a list of updates between two tokens +""" + +from twisted.internet import defer +from collections import namedtuple + +import logging + + +logger = logging.getLogger(__name__) + + +MAX_EVENTS_BEHIND = 10000 + + +EventStreamRow = namedtuple("EventStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +BackfillStreamRow = namedtuple("BackfillStreamRow", ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional +)) +PresenceStreamRow = namedtuple("PresenceStreamRow", ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool +)) +TypingStreamRow = namedtuple("TypingStreamRow", ( + "room_id", # str + "user_ids", # list(str) +)) +ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict +)) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( + "user_id", # str +)) +PushersStreamRow = namedtuple("PushersStreamRow", ( + "user_id", # str + "app_id", # str + "pushkey", # str + "deleted", # bool +)) +CachesStreamRow = namedtuple("CachesStreamRow", ( + "cache_func", # str + "keys", # list(str) + "invalidation_ts", # int +)) +PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional +)) +DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( + "user_id", # str + "destination", # str +)) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( + "entity", # str +)) +FederationStreamRow = namedtuple("FederationStreamRow", ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow +)) +TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( + "user_id", # str + "room_id", # str + "data", # dict +)) +AccountDataStreamRow = namedtuple("AccountDataStream", ( + "user_id", # str + "room_id", # str + "data_type", # str + "data", # dict +)) + + +class Stream(object): + """Base class for the streams. + + Provides a `get_updates()` function that returns new updates since the last + time it was called up until the point `advance_current_token` was called. + """ + NAME = None # The name of the stream + ROW_TYPE = None # The type of the row + _LIMITED = True # Whether the update function takes a limit + + def __init__(self, hs): + # The token from which we last asked for updates + self.last_token = self.current_token() + + # The token that we will get updates up to + self.upto_token = self.current_token() + + def advance_current_token(self): + """Updates `upto_token` to "now", which updates up until which point + get_updates[_since] will fetch rows till. + """ + self.upto_token = self.current_token() + + def discard_updates_and_advance(self): + """Called when the stream should advance but the updates would be discarded, + e.g. when there are no currently connected workers. + """ + self.upto_token = self.current_token() + self.last_token = self.upto_token + + @defer.inlineCallbacks + def get_updates(self): + """Gets all updates since the last time this function was called (or + since the stream was constructed if it hadn't been called before), + until the `upto_token` + + Returns: + (list(ROW_TYPE), int): list of updates plus the token used as an + upper bound of the updates (i.e. the "current token") + """ + updates, current_token = yield self.get_updates_since(self.last_token) + self.last_token = current_token + + defer.returnValue((updates, current_token)) + + @defer.inlineCallbacks + def get_updates_since(self, from_token): + """Like get_updates except allows specifying from when we should + stream updates + + Returns: + (list(ROW_TYPE), int): list of updates plus the token used as an + upper bound of the updates (i.e. the "current token") + """ + if from_token in ("NOW", "now"): + defer.returnValue(([], self.upto_token)) + + current_token = self.upto_token + + from_token = int(from_token) + + if from_token == current_token: + defer.returnValue(([], current_token)) + + if self._LIMITED: + rows = yield self.update_function( + from_token, current_token, + limit=MAX_EVENTS_BEHIND + 1, + ) + + if len(rows) >= MAX_EVENTS_BEHIND: + raise Exception("stream %s has fallen behined" % (self.NAME)) + else: + rows = yield self.update_function( + from_token, current_token, + ) + + updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows] + + defer.returnValue((updates, current_token)) + + def current_token(self): + """Gets the current token of the underlying streams. Should be provided + by the sub classes + + Returns: + int + """ + raise NotImplementedError() + + def update_function(self, from_token, current_token, limit=None): + """Get updates between from_token and to_token. If Stream._LIMITED is + True then limit is provided, otherwise it's not. + + Returns: + Deferred(list(tuple)): the first entry in the tuple is the token for + that update, and the rest of the tuple gets used to construct + a ``ROW_TYPE`` instance + """ + raise NotImplementedError() + + +class EventsStream(Stream): + """We received a new event, or an event went from being an outlier to not + """ + NAME = "events" + ROW_TYPE = EventStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + self.current_token = store.get_current_events_token + self.update_function = store.get_all_new_forward_event_rows + + super(EventsStream, self).__init__(hs) + + +class BackfillStream(Stream): + """We fetched some old events and either we had never seen that event before + or it went from being an outlier to not. + """ + NAME = "backfill" + ROW_TYPE = BackfillStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + self.current_token = store.get_current_backfill_token + self.update_function = store.get_all_new_backfill_event_rows + + super(BackfillStream, self).__init__(hs) + + +class PresenceStream(Stream): + NAME = "presence" + _LIMITED = False + ROW_TYPE = PresenceStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + presence_handler = hs.get_presence_handler() + + self.current_token = store.get_current_presence_token + self.update_function = presence_handler.get_all_presence_updates + + super(PresenceStream, self).__init__(hs) + + +class TypingStream(Stream): + NAME = "typing" + _LIMITED = False + ROW_TYPE = TypingStreamRow + + def __init__(self, hs): + typing_handler = hs.get_typing_handler() + + self.current_token = typing_handler.get_current_token + self.update_function = typing_handler.get_all_typing_updates + + super(TypingStream, self).__init__(hs) + + +class ReceiptsStream(Stream): + NAME = "receipts" + ROW_TYPE = ReceiptsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_max_receipt_stream_id + self.update_function = store.get_all_updated_receipts + + super(ReceiptsStream, self).__init__(hs) + + +class PushRulesStream(Stream): + """A user has changed their push rules + """ + NAME = "push_rules" + ROW_TYPE = PushRulesStreamRow + + def __init__(self, hs): + self.store = hs.get_datastore() + super(PushRulesStream, self).__init__(hs) + + def current_token(self): + push_rules_token, _ = self.store.get_push_rules_stream_token() + return push_rules_token + + @defer.inlineCallbacks + def update_function(self, from_token, to_token, limit): + rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + defer.returnValue([(row[0], row[2]) for row in rows]) + + +class PushersStream(Stream): + """A user has added/changed/removed a pusher + """ + NAME = "pushers" + ROW_TYPE = PushersStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_pushers_stream_token + self.update_function = store.get_all_updated_pushers_rows + + super(PushersStream, self).__init__(hs) + + +class CachesStream(Stream): + """A cache was invalidated on the master and no other stream would invalidate + the cache on the workers + """ + NAME = "caches" + ROW_TYPE = CachesStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_cache_stream_token + self.update_function = store.get_all_updated_caches + + super(CachesStream, self).__init__(hs) + + +class PublicRoomsStream(Stream): + """The public rooms list changed + """ + NAME = "public_rooms" + ROW_TYPE = PublicRoomsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_current_public_room_stream_id + self.update_function = store.get_all_new_public_rooms + + super(PublicRoomsStream, self).__init__(hs) + + +class DeviceListsStream(Stream): + """Someone added/changed/removed a device + """ + NAME = "device_lists" + _LIMITED = False + ROW_TYPE = DeviceListsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_device_stream_token + self.update_function = store.get_all_device_list_changes_for_remotes + + super(DeviceListsStream, self).__init__(hs) + + +class ToDeviceStream(Stream): + """New to_device messages for a client + """ + NAME = "to_device" + ROW_TYPE = ToDeviceStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_to_device_stream_token + self.update_function = store.get_all_new_device_messages + + super(ToDeviceStream, self).__init__(hs) + + +class FederationStream(Stream): + """Data to be sent over federation. Only available when master has federation + sending disabled. + """ + NAME = "federation" + ROW_TYPE = FederationStreamRow + + def __init__(self, hs): + federation_sender = hs.get_federation_sender() + + self.current_token = federation_sender.get_current_token + self.update_function = federation_sender.get_replication_rows + + super(FederationStream, self).__init__(hs) + + +class TagAccountDataStream(Stream): + """Someone added/removed a tag for a room + """ + NAME = "tag_account_data" + ROW_TYPE = TagAccountDataStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_max_account_data_stream_id + self.update_function = store.get_all_updated_tags + + super(TagAccountDataStream, self).__init__(hs) + + +class AccountDataStream(Stream): + """Global or per room account data was changed + """ + NAME = "account_data" + ROW_TYPE = AccountDataStreamRow + + def __init__(self, hs): + self.store = hs.get_datastore() + + self.current_token = self.store.get_max_account_data_stream_id + + super(AccountDataStream, self).__init__(hs) + + @defer.inlineCallbacks + def update_function(self, from_token, to_token, limit): + global_results, room_results = yield self.store.get_all_updated_account_data( + from_token, from_token, to_token, limit + ) + + results = list(room_results) + results.extend( + (stream_id, user_id, None, account_data_type, content,) + for stream_id, user_id, account_data_type, content in global_results + ) + + defer.returnValue(results) + + +STREAMS_MAP = { + stream.NAME: stream + for stream in ( + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + ) +} diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index f9f5a3e07..aa8d874f9 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -40,6 +40,7 @@ from synapse.rest.client.v2_alpha import ( register, auth, receipts, + read_marker, keys, tokenrefresh, tags, @@ -88,6 +89,7 @@ class ClientRestResource(JsonResource): register.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) + read_marker.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 8930f1826..f15aa5c13 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -39,6 +39,7 @@ class ClientDirectoryServer(ClientV1RestServlet): def __init__(self, hs): super(ClientDirectoryServer, self).__init__(hs) + self.store = hs.get_datastore() self.handlers = hs.get_handlers() @defer.inlineCallbacks @@ -70,7 +71,10 @@ class ClientDirectoryServer(ClientV1RestServlet): logger.debug("Got servers: %s", servers) # TODO(erikj): Check types. - # TODO(erikj): Check that room exists + + room = yield self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Room does not exist") dir_handler = self.handlers.directory_handler diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 0bdd6b5b3..cd388770c 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -164,6 +164,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): else: msg_handler = self.handlers.message_handler event, context = yield msg_handler.create_event( + requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id, @@ -406,7 +407,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet): users_with_profile = yield self.state.get_current_user_in_room(room_id) defer.returnValue((200, { - "joined": users_with_profile + "joined": { + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in users_with_profile.iteritems() + } })) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 03141c623..c43b30b73 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req( + request, + self.hs.config.turn_allow_guests + ) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 20e765f48..1f5bc24cc 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -47,3 +47,13 @@ def client_v2_patterns(path_regex, releases=(0,), new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns + + +def set_timeline_upper_limit(filter_json, filter_timeline_limit): + if filter_timeline_limit < 0: + return # no upper limits + timeline = filter_json.get('room', {}).get('timeline', {}) + if 'limit' in timeline: + filter_json['room']['timeline']["limit"] = min( + filter_json['room']['timeline']['limit'], + filter_timeline_limit) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index b16079cec..0e0a187ef 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -16,7 +16,7 @@ from ._base import client_v2_patterns from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from twisted.internet import defer @@ -82,6 +82,13 @@ class RoomAccountDataServlet(RestServlet): body = parse_json_object_from_request(request) + if account_data_type == "m.fully_read": + raise SynapseError( + 405, + "Cannot set m.fully_read through this API." + " Use /rooms/!roomId:server.name/read_markers" + ) + max_id = yield self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index b4084fec6..d2b2fd66e 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -20,6 +20,7 @@ from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID from ._base import client_v2_patterns +from ._base import set_timeline_upper_limit import logging @@ -85,6 +86,11 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) + set_timeline_upper_limit( + content, + self.hs.config.filter_timeline_limit + ) + filter_id = yield self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content, diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py new file mode 100644 index 000000000..2f8784fe0 --- /dev/null +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from ._base import client_v2_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class ReadMarkerRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/rooms/(?P[^/]*)/read_markers$") + + def __init__(self, hs): + super(ReadMarkerRestServlet, self).__init__() + self.auth = hs.get_auth() + self.receipts_handler = hs.get_receipts_handler() + self.read_marker_handler = hs.get_read_marker_handler() + self.presence_handler = hs.get_presence_handler() + + @defer.inlineCallbacks + def on_POST(self, request, room_id): + requester = yield self.auth.get_user_by_req(request) + + yield self.presence_handler.bump_presence_active_time(requester.user) + + body = parse_json_object_from_request(request) + + read_event_id = body.get("m.read", None) + if read_event_id: + yield self.receipts_handler.received_client_receipt( + room_id, + "m.read", + user_id=requester.user.to_string(), + event_id=read_event_id + ) + + read_marker_event_id = body.get("m.fully_read", None) + if read_marker_event_id: + yield self.read_marker_handler.received_client_read_marker( + room_id, + user_id=requester.user.to_string(), + event_id=read_marker_event_id + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReadMarkerRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3acf4eacd..1421c1815 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -21,7 +21,7 @@ from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import ( - RestServlet, parse_json_object_from_request, assert_params_in_request + RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string ) from synapse.util.msisdn import phone_number_to_msisdn @@ -31,6 +31,7 @@ import logging import hmac from hashlib import sha1 from synapse.util.async import run_on_reactor +from synapse.util.ratelimitutils import FederationRateLimiter # We ought to be using hmac.compare_digest() but on older pythons it doesn't @@ -115,6 +116,44 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) +class UsernameAvailabilityRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/available") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(UsernameAvailabilityRestServlet, self).__init__() + self.hs = hs + self.registration_handler = hs.get_handlers().registration_handler + self.ratelimiter = FederationRateLimiter( + hs.get_clock(), + # Time window of 2s + window_size=2000, + # Artificially delay requests if rate > sleep_limit/window_size + sleep_limit=1, + # Amount of artificial delay to apply + sleep_msec=1000, + # Error with 429 if more than reject_limit requests are queued + reject_limit=1, + # Allow 1 request at a time + concurrent_requests=1, + ) + + @defer.inlineCallbacks + def on_GET(self, request): + ip = self.hs.get_ip_from_request(request) + with self.ratelimiter.ratelimit(ip) as wait_deferred: + yield wait_deferred + + username = parse_string(request, "username", required=True) + + yield self.registration_handler.check_username(username) + + defer.returnValue((200, {"available": True})) + + class RegisterRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register$") @@ -555,4 +594,5 @@ class RegisterRestServlet(RestServlet): def register_servlets(hs, http_server): EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) + UsernameAvailabilityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a7a9e0a79..771e127ab 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -28,6 +28,7 @@ from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState from ._base import client_v2_patterns +from ._base import set_timeline_upper_limit import itertools import logging @@ -78,6 +79,7 @@ class SyncRestServlet(RestServlet): def __init__(self, hs): super(SyncRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() @@ -121,6 +123,8 @@ class SyncRestServlet(RestServlet): if filter_id.startswith('{'): try: filter_object = json.loads(filter_id) + set_timeline_upper_limit(filter_object, + self.hs.config.filter_timeline_limit) except: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) @@ -253,6 +257,7 @@ class SyncRestServlet(RestServlet): invite = serialize_event( room.invite, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, + is_invite=True, ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 31f94bc6e..6fceb23e2 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -36,7 +36,7 @@ class ThirdPartyProtocolsServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols() defer.returnValue((200, protocols)) @@ -54,7 +54,7 @@ class ThirdPartyProtocolServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( only_protocol=protocol, @@ -77,7 +77,7 @@ class ThirdPartyUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) @@ -101,7 +101,7 @@ class ThirdPartyLocationServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index ff95269ba..be68d9a09 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -84,12 +84,11 @@ class LocalKey(Resource): } old_verify_keys = {} - for key in self.config.old_signing_keys: - key_id = "%s:%s" % (key.alg, key.version) + for key_id, key in self.config.old_signing_keys.items(): verify_key_bytes = key.encode() old_verify_keys[key_id] = { u"key": encode_base64(verify_key_bytes), - u"expired_ts": key.expired, + u"expired_ts": key.expired_ts, } tls_fingerprints = self.config.tls_fingerprints diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index c43b185e0..caca96c22 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -34,6 +34,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \ from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.retryutils import NotRetryingDestination import os import errno @@ -181,7 +182,8 @@ class MediaRepository(object): logger.exception("Failed to fetch remote media %s/%s", server_name, media_id) raise - + except NotRetryingDestination: + logger.warn("Not retrying destination %r", server_name) except Exception: logger.exception("Failed to fetch remote media %s/%s", server_name, media_id) diff --git a/synapse/server.py b/synapse/server.py index c57703204..12754c89a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -48,6 +48,7 @@ from synapse.handlers.typing import TypingHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.receipts import ReceiptsHandler +from synapse.handlers.read_marker import ReadMarkerHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier @@ -132,6 +133,8 @@ class HomeServer(object): 'federation_sender', 'receipts_handler', 'macaroon_generator', + 'tcp_replication', + 'read_marker_handler', ] def __init__(self, hostname, **kwargs): @@ -290,6 +293,12 @@ class HomeServer(object): def build_receipts_handler(self): return ReceiptsHandler(self) + def build_read_marker_handler(self): + return ReadMarkerHandler(self) + + def build_tcp_replication(self): + raise NotImplementedError() + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state.py b/synapse/state.py index f6b83d888..02fee47f3 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -175,6 +175,17 @@ class StateHandler(object): ) defer.returnValue(joined_users) + @defer.inlineCallbacks + def get_current_hosts_in_room(self, room_id, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + logger.debug("calling resolve_state_groups from get_current_hosts_in_room") + entry = yield self.resolve_state_groups(room_id, latest_event_ids) + joined_hosts = yield self.store.get_joined_hosts( + room_id, entry.state_id, entry.state + ) + defer.returnValue(joined_hosts) + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): """Build an EventContext structure for the event. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c659004e8..58b73af7d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -60,12 +60,12 @@ class LoggingTransaction(object): object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) - def call_after(self, callback, *args): + def call_after(self, callback, *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ - self.after_callbacks.append((callback, args)) + self.after_callbacks.append((callback, args, kwargs)) def __getattr__(self, name): return getattr(self.txn, name) @@ -319,8 +319,8 @@ class SQLBaseStore(object): inner_func, *args, **kwargs ) finally: - for after_callback, after_args in after_callbacks: - after_callback(*after_args) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) defer.returnValue(result) @defer.inlineCallbacks diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 813ad59e5..7157fb1df 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -210,7 +210,9 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_handlers[update_name] = update_handler def register_background_index_update(self, update_name, index_name, - table, columns, where_clause=None): + table, columns, where_clause=None, + unique=False, + psql_only=False): """Helper for store classes to do a background index addition To use: @@ -226,48 +228,80 @@ class BackgroundUpdateStore(SQLBaseStore): index_name (str): name of index to add table (str): table to add index to columns (list[str]): columns/expressions to include in index + unique (bool): true to make a UNIQUE index + psql_only: true to only create this index on psql databases (useful + for virtual sqlite tables) """ - # if this is postgres, we add the indexes concurrently. Otherwise - # we fall back to doing it inline - if isinstance(self.database_engine, engines.PostgresEngine): - conc = True - else: - conc = False - # We don't use partial indices on SQLite as it wasn't introduced - # until 3.8, and wheezy has 3.7 - where_clause = None - - sql = ( - "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" - " %(where_clause)s" - ) % { - "conc": "CONCURRENTLY" if conc else "", - "name": index_name, - "table": table, - "columns": ", ".join(columns), - "where_clause": "WHERE " + where_clause if where_clause else "" - } - - def create_index_concurrently(conn): + def create_index_psql(conn): conn.rollback() # postgres insists on autocommit for the index conn.set_session(autocommit=True) - c = conn.cursor() - c.execute(sql) - conn.set_session(autocommit=False) - def create_index(conn): + try: + c = conn.cursor() + + # If a previous attempt to create the index was interrupted, + # we may already have a half-built index. Let's just drop it + # before trying to create it again. + + sql = "DROP INDEX IF EXISTS %s" % (index_name,) + logger.debug("[SQL] %s", sql) + c.execute(sql) + + sql = ( + "CREATE %(unique)s INDEX CONCURRENTLY %(name)s" + " ON %(table)s" + " (%(columns)s) %(where_clause)s" + ) % { + "unique": "UNIQUE" if unique else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + "where_clause": "WHERE " + where_clause if where_clause else "" + } + logger.debug("[SQL] %s", sql) + c.execute(sql) + finally: + conn.set_session(autocommit=False) + + def create_index_sqlite(conn): + # Sqlite doesn't support concurrent creation of indexes. + # + # We don't use partial indices on SQLite as it wasn't introduced + # until 3.8, and wheezy has 3.7 + # + # We assume that sqlite doesn't give us invalid indices; however + # we may still end up with the index existing but the + # background_updates not having been recorded if synapse got shut + # down at the wrong moment - hance we use IF NOT EXISTS. (SQLite + # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.) + sql = ( + "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s" + " (%(columns)s)" + ) % { + "unique": "UNIQUE" if unique else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + } + c = conn.cursor() + logger.debug("[SQL] %s", sql) c.execute(sql) + if isinstance(self.database_engine, engines.PostgresEngine): + runner = create_index_psql + elif psql_only: + runner = None + else: + runner = create_index_sqlite + @defer.inlineCallbacks def updater(progress, batch_size): - logger.info("Adding index %s to %s", index_name, table) - if conc: - yield self.runWithConnection(create_index_concurrently) - else: - yield self.runWithConnection(create_index) + if runner is not None: + logger.info("Adding index %s to %s", index_name, table) + yield self.runWithConnection(runner) yield self._end_background_update(update_name) defer.returnValue(1) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 71e5ea112..747d2df62 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -33,6 +33,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, + max_entries=5000, ) super(ClientIpStore, self).__init__(hs) @@ -120,6 +121,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): where_clauses.append("(user_id = ? AND device_id = ?)") bindings.extend((user_id, device_id)) + if not where_clauses: + return [] + inner_select = ( "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " "WHERE %(where)s " diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 2714519d2..0b62b493d 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -325,23 +325,26 @@ class DeviceInboxStore(BackgroundUpdateStore): # we return. upper_pos = min(current_pos, last_pos + limit) sql = ( - "SELECT stream_id, user_id" + "SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" + " GROUP BY user_id" ) txn.execute(sql, (last_pos, upper_pos)) rows = txn.fetchall() sql = ( - "SELECT stream_id, destination" + "SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" + " GROUP BY destination" ) txn.execute(sql, (last_pos, upper_pos)) rows.extend(txn) + # Order by ascending stream ordering + rows.sort() + return rows return self.runInteraction( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 53e36791d..d9936c88b 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -18,7 +18,7 @@ import ujson as json from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore +from ._base import SQLBaseStore, Cache from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks @@ -29,6 +29,14 @@ class DeviceStore(SQLBaseStore): def __init__(self, hs): super(DeviceStore, self).__init__(hs) + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. + self.device_id_exists_cache = Cache( + name="device_id_exists", + keylen=2, + max_entries=10000, + ) + self._clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) @@ -54,6 +62,10 @@ class DeviceStore(SQLBaseStore): defer.Deferred: boolean whether the device was inserted or an existing device existed with that ID. """ + key = (user_id, device_id) + if self.device_id_exists_cache.get(key, None): + defer.returnValue(False) + try: inserted = yield self._simple_insert( "devices", @@ -65,6 +77,7 @@ class DeviceStore(SQLBaseStore): desc="store_device", or_ignore=True, ) + self.device_id_exists_cache.prefill(key, True) defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" @@ -93,6 +106,7 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + @defer.inlineCallbacks def delete_device(self, user_id, device_id): """Delete a device. @@ -102,12 +116,15 @@ class DeviceStore(SQLBaseStore): Returns: defer.Deferred """ - return self._simple_delete_one( + yield self._simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, desc="delete_device", ) + self.device_id_exists_cache.invalidate((user_id, device_id)) + + @defer.inlineCallbacks def delete_devices(self, user_id, device_ids): """Deletes several devices. @@ -117,13 +134,15 @@ class DeviceStore(SQLBaseStore): Returns: defer.Deferred """ - return self._simple_delete_many( + yield self._simple_delete_many( table="devices", column="device_id", iterable=device_ids, keyvalues={"user_id": user_id}, desc="delete_devices", ) + for device_id in device_ids: + self.device_id_exists_cache.invalidate((user_id, device_id)) def update_device(self, user_id, device_id, new_display_name=None): """Update a device. @@ -533,7 +552,7 @@ class DeviceStore(SQLBaseStore): rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) defer.returnValue(set(row[0] for row in rows)) - def get_all_device_list_changes_for_remotes(self, from_key): + def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. @@ -541,11 +560,11 @@ class DeviceStore(SQLBaseStore): sql = """ SELECT stream_id, user_id, destination FROM device_lists_stream LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) - WHERE stream_id > ? + WHERE ? < stream_id AND stream_id <= ? """ return self._execute( "get_all_device_list_changes_for_remotes", None, - sql, from_key, + sql, from_key, to_key ) @defer.inlineCallbacks diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 7cbc1470f..e00f31da2 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,7 +14,7 @@ # limitations under the License. from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.util.caches.descriptors import cached from canonicaljson import encode_canonical_json import ujson as json @@ -123,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore): return result @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): - """Insert some new one time keys for a device. + def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + """Retrieve a number of one-time keys for a user - Checks if any of the keys are already inserted, if they are then check - if they match. If they don't then we raise an error. + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + key_ids(list[str]): list of key ids (excluding algorithm) to + retrieve + + Returns: + deferred resolving to Dict[(str, str), str]: map from (algorithm, + key_id) to json string for key """ - # First we check if we have already persisted any of the keys. rows = yield self._simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", - iterable=[key_id for _, key_id, _ in key_list], + iterable=key_ids, retcols=("algorithm", "key_id", "key_json",), keyvalues={ "user_id": user_id, @@ -143,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore): desc="add_e2e_one_time_keys_check", ) - existing_key_map = { + defer.returnValue({ (row["algorithm"], row["key_id"]): row["key_json"] for row in rows - } + }) - new_keys = [] # Keys that we need to insert - for algorithm, key_id, json_bytes in key_list: - ex_bytes = existing_key_map.get((algorithm, key_id), None) - if ex_bytes: - if json_bytes != ex_bytes: - raise SynapseError( - 400, "One time key with key_id %r already exists" % (key_id,) - ) - else: - new_keys.append((algorithm, key_id, json_bytes)) + @defer.inlineCallbacks + def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + """Insert some new one time keys for a device. Errors if any of the + keys already exist. + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + time_now(long): insertion time to record (ms since epoch) + new_keys(iterable[(str, str, str)]: keys to add - each a tuple of + (algorithm, key_id, key json) + """ def _add_e2e_one_time_keys(txn): # We are protected from race between lookup and insertion due to @@ -177,10 +185,14 @@ class EndToEndKeyStore(SQLBaseStore): for algorithm, key_id, json_bytes in new_keys ], ) + txn.call_after( + self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) + ) yield self.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) + @cached(max_entries=10000) def count_e2e_one_time_keys(self, user_id, device_id): """ Count the number of one time keys the server has for a device Returns: @@ -225,6 +237,9 @@ class EndToEndKeyStore(SQLBaseStore): ) for user_id, device_id, algorithm, key_id in delete: txn.execute(sql, (user_id, device_id, algorithm, key_id)) + txn.call_after( + self.count_e2e_one_time_keys.invalidate, (user_id, device_id,) + ) return result return self.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys @@ -242,3 +257,4 @@ class EndToEndKeyStore(SQLBaseStore): keyvalues={"user_id": user_id, "device_id": device_id}, desc="delete_e2e_one_time_keys_by_device" ) + self.count_e2e_one_time_keys.invalidate((user_id, device_id,)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 3f6833fad..c4aeb4880 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.state import resolve_events from synapse.util.caches.descriptors import cached +from synapse.types import get_domain_from_id from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict @@ -49,6 +50,9 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) persist_event_counter = metrics.register_counter("persisted_events") +event_counter = metrics.register_counter( + "persisted_events_sep", labels=["type", "origin_type", "origin_entity"] +) def encode_json(json_object): @@ -203,6 +207,18 @@ class EventsStore(SQLBaseStore): where_clause="contains_url = true AND outlier = false", ) + # an event_id index on event_search is useful for the purge_history + # api. Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + self._event_persist_queue = _EventPeristenceQueue() def persist_events(self, events_and_contexts, backfilled=False): @@ -370,6 +386,23 @@ class EventsStore(SQLBaseStore): new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) + for event, context in chunk: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) + + event_counter.inc(event.type, origin_type, origin_entity) + + for room_id, (_, _, new_state) in current_state_for_room.iteritems(): + self.get_current_state_ids.prefill( + (room_id, ), new_state + ) @defer.inlineCallbacks def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): @@ -419,10 +452,10 @@ class EventsStore(SQLBaseStore): Assumes that we are only persisting events for one room at a time. Returns: - 2-tuple (to_delete, to_insert) where both are state dicts, i.e. - (type, state_key) -> event_id. `to_delete` are the entries to + 3-tuple (to_delete, to_insert, new_state) where both are state dicts, + i.e. (type, state_key) -> event_id. `to_delete` are the entries to first be deleted from current_state_events, `to_insert` are entries - to insert. + to insert. `new_state` is the full set of state. May return None if there are no changes to be applied. """ # Now we need to work out the different state sets for @@ -529,7 +562,7 @@ class EventsStore(SQLBaseStore): if ev_id in events_to_insert } - defer.returnValue((to_delete, to_insert)) + defer.returnValue((to_delete, to_insert, current_state)) @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, @@ -682,7 +715,7 @@ class EventsStore(SQLBaseStore): def _update_current_state_txn(self, txn, state_delta_by_room): for room_id, current_state_tuple in state_delta_by_room.iteritems(): - to_delete, to_insert = current_state_tuple + to_delete, to_insert, _ = current_state_tuple txn.executemany( "DELETE FROM current_state_events WHERE event_id = ?", [(ev_id,) for ev_id in to_delete.itervalues()], @@ -1327,11 +1360,26 @@ class EventsStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected): + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (list(str)): list of event_ids to fetch + allow_rejected (bool): Whether to teturn events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ event_map = {} for event_id in events: - ret = self._get_event_cache.get((event_id,), None) + ret = self._get_event_cache.get( + (event_id,), None, + update_metrics=update_metrics, + ) if not ret: continue @@ -1771,6 +1819,94 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + return self.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + return self.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): @@ -1903,6 +2039,8 @@ class EventsStore(SQLBaseStore): 400, "topological_ordering is greater than forward extremeties" ) + logger.debug("[purge] looking for events to delete") + txn.execute( "SELECT event_id, state_key FROM events" " LEFT JOIN state_events USING (room_id, event_id)" @@ -1911,9 +2049,19 @@ class EventsStore(SQLBaseStore): ) event_rows = txn.fetchall() + to_delete = [ + (event_id,) for event_id, state_key in event_rows + if state_key is None and not self.hs.is_mine_id(event_id) + ] + logger.info( + "[purge] found %i events before cutoff, of which %i are remote" + " non-state events to delete", len(event_rows), len(to_delete)) + for event_id, state_key in event_rows: txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + logger.debug("[purge] Finding new backward extremities") + # We calculate the new entries for the backward extremeties by finding # all events that point to events that are to be purged txn.execute( @@ -1926,6 +2074,8 @@ class EventsStore(SQLBaseStore): ) new_backwards_extrems = txn.fetchall() + logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems) + txn.execute( "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) @@ -1940,6 +2090,8 @@ class EventsStore(SQLBaseStore): ] ) + logger.debug("[purge] finding redundant state groups") + # Get all state groups that are only referenced by events that are # to be deleted. txn.execute( @@ -1955,15 +2107,20 @@ class EventsStore(SQLBaseStore): ) state_rows = txn.fetchall() - state_groups_to_delete = [sg for sg, in state_rows] + logger.debug("[purge] found %i redundant state groups", len(state_rows)) + + # make a set of the redundant state groups, so that we can look them up + # efficiently + state_groups_to_delete = set([sg for sg, in state_rows]) # Now we get all the state groups that rely on these state groups - new_state_edges = [] - chunks = [ - state_groups_to_delete[i:i + 100] - for i in xrange(0, len(state_groups_to_delete), 100) - ] - for chunk in chunks: + logger.debug("[purge] finding state groups which depend on redundant" + " state groups") + remaining_state_groups = [] + for i in xrange(0, len(state_rows), 100): + chunk = [sg for sg, in state_rows[i:i + 100]] + # look for state groups whose prev_state_group is one we are about + # to delete rows = self._simple_select_many_txn( txn, table="state_group_edges", @@ -1972,21 +2129,28 @@ class EventsStore(SQLBaseStore): retcols=["state_group"], keyvalues={}, ) - new_state_edges.extend(row["state_group"] for row in rows) + remaining_state_groups.extend( + row["state_group"] for row in rows - # Now we turn the state groups that reference to-be-deleted state groups - # to non delta versions. - for new_state_edge in new_state_edges: - curr_state = self._get_state_groups_from_groups_txn( - txn, [new_state_edge], types=None + # exclude state groups we are about to delete: no point in + # updating them + if row["state_group"] not in state_groups_to_delete ) - curr_state = curr_state[new_state_edge] + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.debug("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn( + txn, [sg], types=None + ) + curr_state = curr_state[sg] self._simple_delete_txn( txn, table="state_groups_state", keyvalues={ - "state_group": new_state_edge, + "state_group": sg, } ) @@ -1994,7 +2158,7 @@ class EventsStore(SQLBaseStore): txn, table="state_group_edges", keyvalues={ - "state_group": new_state_edge, + "state_group": sg, } ) @@ -2003,7 +2167,7 @@ class EventsStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": new_state_edge, + "state_group": sg, "room_id": room_id, "type": key[0], "state_key": key[1], @@ -2013,6 +2177,7 @@ class EventsStore(SQLBaseStore): ], ) + logger.debug("[purge] removing redundant state groups") txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", state_rows @@ -2021,22 +2186,21 @@ class EventsStore(SQLBaseStore): "DELETE FROM state_groups WHERE id = ?", state_rows ) + # Delete all non-state + logger.debug("[purge] removing events from event_to_state_groups") txn.executemany( "DELETE FROM event_to_state_groups WHERE event_id = ?", [(event_id,) for event_id, _ in event_rows] ) + logger.debug("[purge] updating room_depth") txn.execute( "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", (topological_ordering, room_id,) ) # Delete all remote non-state events - to_delete = [ - (event_id,) for event_id, state_key in event_rows - if state_key is None and not self.hs.is_mine_id(event_id) - ] for table in ( "events", "event_json", @@ -2052,16 +2216,15 @@ class EventsStore(SQLBaseStore): "event_signatures", "rejections", ): + logger.debug("[purge] removing remote non-state events from %s", table) + txn.executemany( "DELETE FROM %s WHERE event_id = ?" % (table,), to_delete ) - txn.executemany( - "DELETE FROM events WHERE event_id = ?", - to_delete - ) # Mark all state and own events as outliers + logger.debug("[purge] marking remaining events as outliers") txn.executemany( "UPDATE events SET outlier = ?" " WHERE event_id = ?", @@ -2071,6 +2234,30 @@ class EventsStore(SQLBaseStore): ] ) + logger.info("[purge] done") + + @defer.inlineCallbacks + def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = yield self._get_event_ordering(event_id1) + to_2, so_2 = yield self._get_event_ordering(event_id2) + defer.returnValue((to_1, so_1) > (to_2, so_2)) + + @defer.inlineCallbacks + def _get_event_ordering(self, event_id): + res = yield self._simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"]))) + AllNewEventsResult = namedtuple("AllNewEventsResult", [ "new_forward_events", "new_backfill_events", diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index cbec25596..0a819d32c 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -16,6 +16,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.push.baserules import list_with_base_rules +from synapse.api.constants import EventTypes from twisted.internet import defer import logging @@ -184,6 +185,18 @@ class PushRuleStore(SQLBaseStore): if uid in local_users_in_room: user_ids.add(uid) + forgotten = yield self.who_forgot_in_room( + event.room_id, on_invalidate=cache_context.invalidate, + ) + + for row in forgotten: + user_id = row["user_id"] + event_id = row["event_id"] + + mem_id = current_state_ids.get((EventTypes.Member, user_id), None) + if event_id == mem_id: + user_ids.discard(user_id) + rules_by_user = yield self.bulk_get_push_rules( user_ids, on_invalidate=cache_context.invalidate, ) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8cc9f0353..34d2f82b7 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -135,6 +135,48 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers", get_all_updated_pushers_txn ) + def get_all_updated_pushers_rows(self, last_id, current_id, limit): + """Get all the pushers that have changed between the given tokens. + + Returns: + Deferred(list(tuple)): each tuple consists of: + stream_id (str) + user_id (str) + app_id (str) + pushkey (str) + was_deleted (bool): whether the pusher was added/updated (False) + or deleted (True) + """ + + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_pushers_rows_txn(txn): + sql = ( + "SELECT id, user_name, app_id, pushkey" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + results = [list(row) + [False] for row in txn] + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + results.extend(list(row) + [True] for row in txn) + results.sort() # Sort so that they're ordered by stream id + + return results + return self.runInteraction( + "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn + ) + @cachedInlineCallbacks(num_args=1, max_entries=15000) def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 6b0f8c278..efb90c3c9 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -47,10 +47,13 @@ class ReceiptsStore(SQLBaseStore): # Returns an ObservableDeferred res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) - if res and res.called and user_id in res.result: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return + if res: + if isinstance(res, defer.Deferred) and res.called: + res = res.result + if user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return self.get_users_with_read_receipts_in_room.invalidate((room_id,)) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index e4c56cc17..5d543652b 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from ._base import SQLBaseStore from .engines import PostgresEngine, Sqlite3Engine @@ -33,6 +33,11 @@ OpsLevel = collections.namedtuple( ("ban_level", "kick_level", "redact_level",) ) +RatelimitOverride = collections.namedtuple( + "RatelimitOverride", + ("messages_per_second", "burst_count",) +) + class RoomStore(SQLBaseStore): @@ -473,3 +478,32 @@ class RoomStore(SQLBaseStore): return self.runInteraction( "get_all_new_public_rooms", get_all_new_public_rooms ) + + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) + else: + defer.returnValue(None) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 367dbbbcf..0829ae5be 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,7 +18,9 @@ from twisted.internet import defer from collections import namedtuple from ._base import SQLBaseStore +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.stringutils import to_ascii from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id @@ -35,6 +37,13 @@ RoomsForUser = namedtuple( ) +# We store this using a namedtuple so that we save about 3x space over using a +# dict. +ProfileInfo = namedtuple( + "ProfileInfo", ("avatar_url", "display_name") +) + + _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" @@ -139,7 +148,7 @@ class RoomMemberStore(SQLBaseStore): hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) defer.returnValue(hosts) - @cached(max_entries=500000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): def f(txn): sql = ( @@ -152,7 +161,7 @@ class RoomMemberStore(SQLBaseStore): ) txn.execute(sql, (room_id, Membership.JOIN,)) - return [r[0] for r in txn] + return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) @cached() @@ -378,7 +387,9 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_users_from_context( - event.room_id, state_group, context.current_state_ids, event=event, + event.room_id, state_group, context.current_state_ids, + event=event, + context=context, ) def get_joined_users_from_state(self, room_id, state_group, state_ids): @@ -396,46 +407,95 @@ class RoomMemberStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context, event=None): + cache_context, event=None, context=None): # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None + users_in_room = {} member_event_ids = [ e_id for key, e_id in current_state_ids.iteritems() if key[0] == EventTypes.Member ] - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids, - retcols=['user_id', 'display_name', 'avatar_url'], - keyvalues={ - "membership": Membership.JOIN, - }, - batch_size=500, - desc="_get_joined_users_from_context", + if context is not None: + # If we have a context with a delta from a previous state group, + # check if we also have the result from the previous group in cache. + # If we do then we can reuse that result and simply update it with + # any membership changes in `delta_ids` + if context.prev_group and context.delta_ids: + prev_res = self._get_joined_users_from_context.cache.get( + (room_id, context.prev_group), None + ) + if prev_res and isinstance(prev_res, dict): + users_in_room = dict(prev_res) + member_event_ids = [ + e_id + for key, e_id in context.delta_ids.iteritems() + if key[0] == EventTypes.Member + ] + for etype, state_key in context.delta_ids: + users_in_room.pop(state_key, None) + + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_cache( + member_event_ids, + allow_rejected=False, + update_metrics=False, ) - users_in_room = { - row["user_id"]: { - "display_name": row["display_name"], - "avatar_url": row["avatar_url"], - } - for row in rows - } + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( + display_name=to_ascii( + ev_entry.event.content.get("displayname", None) + ), + avatar_url=to_ascii( + ev_entry.event.content.get("avatar_url", None) + ), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=missing_member_event_ids, + retcols=('user_id', 'display_name', 'avatar_url',), + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=500, + desc="_get_joined_users_from_context", + ) + + users_in_room.update({ + to_ascii(row["user_id"]): ProfileInfo( + avatar_url=to_ascii(row["avatar_url"]), + display_name=to_ascii(row["display_name"]), + ) + for row in rows + }) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[event.state_key] = { - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } + users_in_room[to_ascii(event.state_key)] = ProfileInfo( + display_name=to_ascii(event.content.get("displayname", None)), + avatar_url=to_ascii(event.content.get("avatar_url", None)), + ) defer.returnValue(users_in_room) @@ -474,6 +534,45 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(False) + def get_joined_hosts(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + def _get_joined_hosts(self, room_id, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + joined_hosts = set() + for etype, state_key in current_state_ids: + if etype == EventTypes.Member: + try: + host = get_domain_from_id(state_key) + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + if host in joined_hosts: + continue + + event_id = current_state_ids[(etype, state_key)] + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + joined_hosts.add(intern_string(host)) + + defer.returnValue(joined_hosts) + @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py index 784f3b348..20ad8bd5a 100644 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/schema/delta/37/remove_auth_idx.py @@ -36,6 +36,10 @@ DROP INDEX IF EXISTS transactions_have_ref; -- and is used incredibly rarely. DROP INDEX IF EXISTS events_order_topo_stream_room; +-- an equivalent index to this actually gets re-created in delta 41, because it +-- turned out that deleting it wasn't a great plan :/. In any case, let's +-- delete it here, and delta 41 will create a new one with an added UNIQUE +-- constraint DROP INDEX IF EXISTS event_search_ev_idx; """ diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql new file mode 100644 index 000000000..5d9cfecf3 --- /dev/null +++ b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_event_id_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql new file mode 100644 index 000000000..a194bf023 --- /dev/null +++ b/synapse/storage/schema/delta/41/ratelimit.sql @@ -0,0 +1,22 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE ratelimit_override ( + user_id TEXT NOT NULL, + messages_per_second BIGINT, + burst_count BIGINT +); + +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index fb23f6f46..85acf2ad1 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,8 +14,9 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches import intern_string +from synapse.util.stringutils import to_ascii from synapse.storage.engines import PostgresEngine from twisted.internet import defer @@ -69,17 +70,33 @@ class StateStore(SQLBaseStore): where_clause="type='m.room.member'", ) - @cachedInlineCallbacks(max_entries=100000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_current_state_ids(self, room_id): - rows = yield self._simple_select_list( - table="current_state_events", - keyvalues={"room_id": room_id}, - retcols=["event_id", "type", "state_key"], - desc="_calculate_state_delta", + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,) + ) + + return { + (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn + } + + return self.runInteraction( + "get_current_state_ids", + _get_current_state_ids_txn, ) - defer.returnValue({ - (r["type"], r["state_key"]): r["event_id"] for r in rows - }) @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): @@ -210,6 +227,18 @@ class StateStore(SQLBaseStore): ], ) + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=context.state_group, + value=dict(context.current_state_ids), + full=True, + ) + self._simple_insert_many_txn( txn, table="event_to_state_groups", @@ -263,12 +292,7 @@ class StateStore(SQLBaseStore): return count - @cached(num_args=2, max_entries=100000, iterable=True) - def _get_state_group_from_group(self, group, types): - raise NotImplementedError() - - @cachedList(cached_method_name="_get_state_group_from_group", - list_name="groups", num_args=2, inlineCallbacks=True) + @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ @@ -496,7 +520,7 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_ids_for_events([event_id], types) defer.returnValue(state_map[event_id]) - @cached(num_args=2, max_entries=100000) + @cached(num_args=2, max_entries=50000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", @@ -644,7 +668,7 @@ class StateStore(SQLBaseStore): state_dict = results[group] state_dict.update( - ((intern_string(k[0]), intern_string(k[1])), v) + ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) for k, v in group_state_dict.iteritems() ) diff --git a/synapse/types.py b/synapse/types.py index 9666f9d73..445bdcb4d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -56,10 +56,10 @@ def create_requester(user_id, access_token_id=None, is_guest=False, def get_domain_from_id(string): - try: - return string.split(":", 1)[1] - except IndexError: + idx = string.find(":") + if idx == -1: raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[idx + 1:] class DomainSpecificString( @@ -216,9 +216,7 @@ class StreamToken( return self def copy_and_replace(self, key, new_value): - d = self._asdict() - d[key] = new_value - return StreamToken(**d) + return self._replace(**{key: new_value}) StreamToken.START = StreamToken( diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 98a5a26ac..2a2360ab5 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class DeferredTimedOutError(SynapseError): def __init__(self): - super(SynapseError, self).__init__(504, "Timed out") + super(DeferredTimedOutError, self).__init__(504, "Timed out") def unwrapFirstError(failure): diff --git a/synapse/util/async.py b/synapse/util/async.py index 35380bf8e..1453faf0e 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -89,6 +89,11 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) def observe(self): + """Observe the underlying deferred. + + Can return either a deferred if the underlying deferred is still pending + (or has failed), or the actual value. Callers may need to use maybeDeferred. + """ if not self._result: d = defer.Deferred() @@ -101,7 +106,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return res if success else defer.fail(res) def observers(self): return self._observers diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 8a7774a88..4a83c46d9 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -14,13 +14,10 @@ # limitations under the License. import synapse.metrics -from lrucache import LruCache import os CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -DEBUG_CACHES = False - metrics = synapse.metrics.get_metrics_for("synapse.util.caches") caches_by_name = {} @@ -40,10 +37,6 @@ def register_cache(name, cache): ) -_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR)) -_stirng_cache_metrics = register_cache("string_cache", _string_cache) - - KNOWN_KEYS = { key: key for key in ( @@ -67,14 +60,16 @@ KNOWN_KEYS = { def intern_string(string): - """Takes a (potentially) unicode string and interns using custom cache + """Takes a (potentially) unicode string and interns it if it's ascii """ - new_str = _string_cache.setdefault(string, string) - if new_str is string: - _stirng_cache_metrics.inc_hits() - else: - _stirng_cache_metrics.inc_misses() - return new_str + if string is None: + return None + + try: + string = string.encode("ascii") + return intern(string) + except UnicodeEncodeError: + return string def intern_dict(dictionary): @@ -87,13 +82,9 @@ def intern_dict(dictionary): def _intern_known_values(key, value): - intern_str_keys = ("event_id", "room_id") - intern_unicode_keys = ("sender", "user_id", "type", "state_key") + intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",) - if key in intern_str_keys: - return intern(value.encode('ascii')) - - if key in intern_unicode_keys: + if key in intern_keys: return intern_string(value) return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5c30ed235..48dcbafee 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,8 +18,9 @@ from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError, logcontext from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry +from synapse.util.stringutils import to_ascii -from . import DEBUG_CACHES, register_cache +from . import register_cache from twisted.internet import defer from collections import namedtuple @@ -76,7 +77,7 @@ class Cache(object): self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, - size_callback=(lambda d: len(d.result)) if iterable else None, + size_callback=(lambda d: len(d)) if iterable else None, ) self.name = name @@ -95,13 +96,26 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel, callback=None): + def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): + """Looks the key up in the caches. + + Args: + key(tuple) + default: What is returned if key is not in the caches. If not + specified then function throws KeyError instead + callback(fn): Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + Either a Deferred or the raw result + """ callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) if val is not _CacheSentinel: if val.sequence == self.sequence: val.callbacks.update(callbacks) - self.metrics.inc_hits() + if update_metrics: + self.metrics.inc_hits() return val.deferred val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) @@ -109,7 +123,8 @@ class Cache(object): self.metrics.inc_hits() return val - self.metrics.inc_misses() + if update_metrics: + self.metrics.inc_misses() if default is _CacheSentinel: raise KeyError() @@ -137,7 +152,7 @@ class Cache(object): if self.sequence == entry.sequence: existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry is entry: - self.cache.set(key, entry.deferred, entry.callbacks) + self.cache.set(key, result, entry.callbacks) else: entry.invalidate() else: @@ -152,10 +167,6 @@ class Cache(object): def invalidate(self, key): self.check_thread() - if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) @@ -224,8 +235,20 @@ class _CacheDescriptorBase(object): ) self.num_args = num_args + + # list of the names of the args used as the cache key self.arg_names = all_args[1:num_args + 1] + # self.arg_defaults is a map of arg name to its default value for each + # argument that has a default value + if arg_spec.defaults: + self.arg_defaults = dict(zip( + all_args[-len(arg_spec.defaults):], + arg_spec.defaults + )) + else: + self.arg_defaults = {} + if "cache_context" in self.arg_names: raise Exception( "cache_context arg cannot be included among the cache keys" @@ -289,18 +312,47 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) + def get_cache_key_gen(args, kwargs): + """Given some args/kwargs return a generator that resolves into + the cache_key. + + We loop through each arg name, looking up if its in the `kwargs`, + otherwise using the next argument in `args`. If there are no more + args then we try looking the arg name up in the defaults + """ + pos = 0 + for nm in self.arg_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield self.arg_defaults[nm] + + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. + if self.num_args == 1: + nm = self.arg_names[0] + + def get_cache_key(args, kwargs): + if nm in kwargs: + return kwargs[nm] + elif len(args): + return args[0] + else: + return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): + return tuple(get_cache_key_gen(args, kwargs)) + @functools.wraps(self.orig) def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - # Add temp cache_context so inspect.getcallargs doesn't explode - if self.add_cache_context: - kwargs["cache_context"] = None - - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + cache_key = get_cache_key(args, kwargs) # Add our own `cache_context` to argument list if the wrapped function # has asked for one @@ -310,20 +362,10 @@ class CacheDescriptor(_CacheDescriptorBase): try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) - observer = cached_result_d.observe() - if DEBUG_CACHES: - @defer.inlineCallbacks - def check_result(cached_result): - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, cache_key, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) - observer.addCallback(check_result) + if isinstance(cached_result_d, ObservableDeferred): + observer = cached_result_d.observe() + else: + observer = cached_result_d except KeyError: ret = defer.maybeDeferred( @@ -337,16 +379,30 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) + # If our cache_key is a string, try to convert to ascii to save + # a bit of space in large caches + if isinstance(cache_key, basestring): + cache_key = to_ascii(cache_key) + result_d = ObservableDeferred(ret, consumeErrors=True) cache.set(cache_key, result_d, callback=invalidate_callback) observer = result_d.observe() - return logcontext.make_deferred_yieldable(observer) + if isinstance(observer, defer.Deferred): + return logcontext.make_deferred_yieldable(observer) + else: + return observer + + if self.num_args == 1: + wrapped.invalidate = lambda key: cache.invalidate(key[0]) + wrapped.prefill = lambda key, val: cache.prefill(key[0], val) + else: + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill - wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all - wrapped.invalidate_many = cache.invalidate_many - wrapped.prefill = cache.prefill wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -419,7 +475,9 @@ class CacheListDescriptor(_CacheDescriptorBase): try: res = cache.get(tuple(key), callback=invalidate_callback) - if not res.has_succeeded(): + if not isinstance(res, ObservableDeferred): + results[arg] = res + elif not res.has_succeeded(): res = res.observe() res.addCallback(lambda r, arg: (arg, r), arg) cached_defers[arg] = res diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 857afee7c..990216145 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -334,12 +334,8 @@ def preserve_fn(f): LoggingContext.set_current_context(LoggingContext.sentinel) return result - # XXX: why is this here rather than inside g? surely we want to preserve - # the context from the time the function was called, not when it was - # wrapped? - current = LoggingContext.current_context() - def g(*args, **kwargs): + current = LoggingContext.current_context() res = f(*args, **kwargs) if isinstance(res, defer.Deferred) and not res.called: # The function will have reset the context before returning, so diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index a100f151d..95a6168e1 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -40,3 +40,17 @@ def is_ascii(s): return False else: return True + + +def to_ascii(s): + """Converts a string to ascii if it is ascii, otherwise leave it alone. + + If given None then will return None. + """ + if s is None: + return None + + try: + return s.encode("ascii") + except UnicodeEncodeError: + return s diff --git a/synapse/visibility.py b/synapse/visibility.py index 31659156a..c4dd9ae2c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): events ([synapse.events.EventBase]): list of events to filter """ forgotten = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.who_forgot_in_room)( + defer.maybeDeferred( + preserve_fn(store.who_forgot_in_room), room_id, ) for room_id in frozenset(e.room_id for e in events) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index aa8cc5055..7586ea905 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -19,10 +19,12 @@ from twisted.internet import defer from mock import Mock from tests import unittest +import re + def _regex(regex, exclusive=True): return { - "regex": regex, + "regex": re.compile(regex), "exclusive": exclusive } diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 878a54dc3..19f5ed6bc 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,6 +14,7 @@ # limitations under the License. import mock +from synapse.api import errors from twisted.internet import defer import synapse.api.errors @@ -44,3 +45,134 @@ class E2eKeysHandlerTestCase(unittest.TestCase): local_user = "@boris:" + self.hs.hostname res = yield self.handler.query_local_devices({local_user: None}) self.assertDictEqual(res, {local_user: {}}) + + @defer.inlineCallbacks + def test_reupload_one_time_keys(self): + """we should be able to re-upload the same keys""" + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + keys = { + "alg1:k1": "key1", + "alg2:k2": { + "key": "key2", + "signatures": {"k1": "sig1"} + }, + "alg2:k3": { + "key": "key3", + }, + } + + res = yield self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys}, + ) + self.assertDictEqual(res, { + "one_time_key_counts": {"alg1": 1, "alg2": 2} + }) + + # we should be able to change the signature without a problem + keys["alg2:k2"]["signatures"]["k1"] = "sig2" + res = yield self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys}, + ) + self.assertDictEqual(res, { + "one_time_key_counts": {"alg1": 1, "alg2": 2} + }) + + @defer.inlineCallbacks + def test_change_one_time_keys(self): + """attempts to change one-time-keys should be rejected""" + + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + keys = { + "alg1:k1": "key1", + "alg2:k2": { + "key": "key2", + "signatures": {"k1": "sig1"} + }, + "alg2:k3": { + "key": "key3", + }, + } + + res = yield self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys}, + ) + self.assertDictEqual(res, { + "one_time_key_counts": {"alg1": 1, "alg2": 2} + }) + + try: + yield self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}, + ) + self.fail("No error when changing string key") + except errors.SynapseError: + pass + + try: + yield self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}, + ) + self.fail("No error when replacing dict key with string") + except errors.SynapseError: + pass + + try: + yield self.handler.upload_keys_for_user( + local_user, device_id, { + "one_time_keys": {"alg1:k1": {"key": "key"}} + }, + ) + self.fail("No error when replacing string key with dict") + except errors.SynapseError: + pass + + try: + yield self.handler.upload_keys_for_user( + local_user, device_id, { + "one_time_keys": { + "alg2:k2": { + "key": "key3", + "signatures": {"k1": "sig1"}, + } + }, + }, + ) + self.fail("No error when replacing dict key") + except errors.SynapseError: + pass + + @unittest.DEBUG + @defer.inlineCallbacks + def test_claim_one_time_key(self): + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + keys = { + "alg1:k1": "key1", + } + + res = yield self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys}, + ) + self.assertDictEqual(res, { + "one_time_key_counts": {"alg1": 1} + }) + + res2 = yield self.handler.claim_one_time_keys({ + "one_time_keys": { + local_user: { + device_id: "alg1" + } + } + }, timeout=None) + self.assertEqual(res2, { + "failures": {}, + "one_time_keys": { + local_user: { + device_id: { + "alg1:k1": "key1" + } + } + } + }) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index b82868054..81063f19a 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from twisted.internet import defer, reactor from tests import unittest from mock import Mock, NonCallableMock from tests.utils import setup_test_homeserver -from synapse.replication.resource import ReplicationResource +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.replication.tcp.client import ( + ReplicationClientHandler, ReplicationClientFactory, +) class BaseSlavedStoreTestCase(unittest.TestCase): @@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase): ) self.hs.get_ratelimiter().send_message.return_value = (True, 0) - self.replication = ReplicationResource(self.hs) - self.master_store = self.hs.get_datastore() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 + server_factory = ReplicationStreamProtocolFactory(self.hs) + listener = reactor.listenUNIX("\0xxx", server_factory) + self.addCleanup(listener.stopListening) + self.streamer = server_factory.streamer + + self.replication_handler = ReplicationClientHandler(self.slaved_store) + client_factory = ReplicationClientFactory( + self.hs, "client_name", self.replication_handler + ) + client_connector = reactor.connectUNIX("\0xxx", client_factory) + self.addCleanup(client_factory.stopTrying) + self.addCleanup(client_connector.disconnect) + @defer.inlineCallbacks def replicate(self): - streams = self.slaved_store.stream_positions() - writer = yield self.replication.replicate(streams, 100) - result = writer.finish() - yield self.slaved_store.process_replication(result) + yield self.streamer.on_notifier_poke() + d = self.replication_handler.await_sync("replication_test") + self.streamer.send_sync_to_all_connections("replication_test") + yield d @defer.inlineCallbacks def check(self, method, args, expected_result=None): diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py deleted file mode 100644 index 429b37e36..000000000 --- a/tests/replication/test_resource.py +++ /dev/null @@ -1,204 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import json - -from mock import Mock, NonCallableMock -from twisted.internet import defer - -import synapse.types -from synapse.replication.resource import ReplicationResource -from synapse.types import UserID -from tests import unittest -from tests.utils import setup_test_homeserver - - -class ReplicationResourceCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - ) - self.user_id = "@seeing:red" - self.user = UserID.from_string(self.user_id) - - self.hs.get_ratelimiter().send_message.return_value = (True, 0) - - self.resource = ReplicationResource(self.hs) - - @defer.inlineCallbacks - def test_streams(self): - # Passing "-1" returns the current stream positions - code, body = yield self.get(streams="-1") - self.assertEquals(code, 200) - self.assertEquals(body["streams"]["field_names"], ["name", "position"]) - position = body["streams"]["position"] - # Passing the current position returns an empty response after the - # timeout - get = self.get(streams=str(position), timeout="0") - self.hs.clock.advance_time_msec(1) - code, body = yield get - self.assertEquals(code, 200) - self.assertEquals(body, {}) - - @defer.inlineCallbacks - def test_events(self): - get = self.get(events="-1", timeout="0") - yield self.hs.get_handlers().room_creation_handler.create_room( - synapse.types.create_requester(self.user), {} - ) - code, body = yield get - self.assertEquals(code, 200) - self.assertEquals(body["events"]["field_names"], [ - "position", "event_id", "room_id", "type", "state_key", - ]) - - @defer.inlineCallbacks - def test_presence(self): - get = self.get(presence="-1") - yield self.hs.get_presence_handler().set_state( - self.user, {"presence": "online"} - ) - code, body = yield get - self.assertEquals(code, 200) - self.assertEquals(body["presence"]["field_names"], [ - "position", "user_id", "state", "last_active_ts", - "last_federation_update_ts", "last_user_sync_ts", - "status_msg", "currently_active", - ]) - - @defer.inlineCallbacks - def test_typing(self): - room_id = yield self.create_room() - get = self.get(typing="-1") - yield self.hs.get_typing_handler().started_typing( - self.user, self.user, room_id, timeout=2 - ) - code, body = yield get - self.assertEquals(code, 200) - self.assertEquals(body["typing"]["field_names"], [ - "position", "room_id", "typing" - ]) - - @defer.inlineCallbacks - def test_receipts(self): - room_id = yield self.create_room() - event_id = yield self.send_text_message(room_id, "Hello, World") - get = self.get(receipts="-1") - yield self.hs.get_receipts_handler().received_client_receipt( - room_id, "m.read", self.user_id, event_id - ) - code, body = yield get - self.assertEquals(code, 200) - self.assertEquals(body["receipts"]["field_names"], [ - "position", "room_id", "receipt_type", "user_id", "event_id", "data" - ]) - - def _test_timeout(stream): - """Check that a request for the given stream timesout""" - @defer.inlineCallbacks - def test_timeout(self): - get = self.get(**{stream: "-1", "timeout": "0"}) - self.hs.clock.advance_time_msec(1) - code, body = yield get - self.assertEquals(code, 200) - self.assertEquals(body.get("rows", []), []) - test_timeout.__name__ = "test_timeout_%s" % (stream) - return test_timeout - - test_timeout_events = _test_timeout("events") - test_timeout_presence = _test_timeout("presence") - test_timeout_typing = _test_timeout("typing") - test_timeout_receipts = _test_timeout("receipts") - test_timeout_user_account_data = _test_timeout("user_account_data") - test_timeout_room_account_data = _test_timeout("room_account_data") - test_timeout_tag_account_data = _test_timeout("tag_account_data") - test_timeout_backfill = _test_timeout("backfill") - test_timeout_push_rules = _test_timeout("push_rules") - test_timeout_pushers = _test_timeout("pushers") - test_timeout_state = _test_timeout("state") - - @defer.inlineCallbacks - def send_text_message(self, room_id, message): - handler = self.hs.get_handlers().message_handler - event = yield handler.create_and_send_nonmember_event( - synapse.types.create_requester(self.user), - { - "type": "m.room.message", - "content": {"body": "message", "msgtype": "m.text"}, - "room_id": room_id, - "sender": self.user.to_string(), - } - ) - defer.returnValue(event.event_id) - - @defer.inlineCallbacks - def create_room(self): - result = yield self.hs.get_handlers().room_creation_handler.create_room( - synapse.types.create_requester(self.user), {} - ) - defer.returnValue(result["room_id"]) - - @defer.inlineCallbacks - def get(self, **params): - request = NonCallableMock(spec_set=[ - "write", "finish", "setResponseCode", "setHeader", "args", - "method", "processing" - ]) - - request.method = "GET" - request.args = {k: [v] for k, v in params.items()} - - @contextlib.contextmanager - def processing(): - yield - request.processing = processing - - yield self.resource._async_render_GET(request) - self.assertTrue(request.finish.called) - - if request.setResponseCode.called: - response_code = request.setResponseCode.call_args[0][0] - else: - response_code = 200 - - response_json = "".join( - call[0][0] for call in request.write.call_args_list - ) - response_body = json.loads(response_json) - - if response_code == 200: - self.check_response(response_body) - - defer.returnValue((response_code, response_body)) - - def check_response(self, response_body): - for name, stream in response_body.items(): - self.assertIn("field_names", stream) - field_names = stream["field_names"] - self.assertIn("rows", stream) - for row in stream["rows"]: - self.assertEquals( - len(row), len(field_names), - "%s: len(row = %r) == len(field_names = %r)" % ( - name, row, field_names - ) - ) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py index 38556da9a..024ac1506 100644 --- a/tests/storage/event_injector.py +++ b/tests/storage/event_injector.py @@ -27,10 +27,10 @@ class EventInjector: self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def create_room(self, room): + def create_room(self, room, user): builder = self.event_builder_factory.new({ "type": EventTypes.Create, - "sender": "", + "sender": user.to_string(), "room_id": room.to_string(), "content": {}, }) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8361dd8ce..281eb1625 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo").result, d.result) + self.assertEquals(a.func("foo"), d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 3762b38e3..14443b53b 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -50,7 +50,7 @@ class EventsStoreTestCase(unittest.TestCase): # Create something to report room = RoomID.from_string("!abc123:test") user = UserID.from_string("@raccoonlover:test") - yield self.event_injector.create_room(room) + yield self.event_injector.create_room(room, user) self.base_event = yield self._get_last_stream_token() diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 4414e8677..3f14ab503 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -175,3 +175,41 @@ class DescriptorTestCase(unittest.TestCase): logcontext.LoggingContext.sentinel) return d1 + + @defer.inlineCallbacks + def test_cache_default_args(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2=2, arg3=3): + return self.mock(arg1, arg2, arg3) + + obj = Cls() + + obj.mock.return_value = 'fish' + r = yield obj.fn(1, 2, 3) + self.assertEqual(r, 'fish') + obj.mock.assert_called_once_with(1, 2, 3) + obj.mock.reset_mock() + + # a call with same params shouldn't call the mock again + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + obj.mock.assert_not_called() + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = 'chips' + r = yield obj.fn(2, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_called_once_with(2, 3, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(2, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_not_called() diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index 7e289715b..d3a8630c2 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase): # before the cache expires returns a resolved deferred. get_result_at_11 = self.cache.get(11, "key") self.assertIsNotNone(get_result_at_11) - self.assertTrue(get_result_at_11.called) + if isinstance(get_result_at_11, Deferred): + # The cache may return the actual result rather than a deferred + self.assertTrue(get_result_at_11.called) # Check that getting the key after the deferred has resolved # after the cache expires returns None