Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/new_push_rules

This commit is contained in:
Brendan Abolivier 2020-08-06 10:52:50 +01:00
commit 118a9eafb3
399 changed files with 2142 additions and 1732 deletions


@ -4,12 +4,12 @@ about: Create a report to help us improve
---
<!--
**THIS IS NOT A SUPPORT CHANNEL!**
**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**,
please ask in **#synapse:matrix.org** (using a matrix.org account if necessary)
<!--
If you want to report a security issue, please see https://matrix.org/security-disclosure-policy/
This is a bug report template. By following the instructions below and

changelog.d/7314.misc Normal file

@ -0,0 +1 @@
Allow guest access to the `GET /_matrix/client/r0/rooms/{room_id}/members` endpoint, according to MSC2689. Contributed by Awesome Technologies Innovationslabor GmbH.

changelog.d/7977.bugfix Normal file

@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory.

changelog.d/7987.misc Normal file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/7989.misc Normal file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/7996.bugfix Normal file

@ -0,0 +1 @@
Fix various comments and minor discrepancies in server notices code.

changelog.d/7999.bugfix Normal file

@ -0,0 +1 @@
Fix a long-standing bug where HTTP HEAD requests resulted in a 400 error.

changelog.d/8000.doc Normal file

@ -0,0 +1 @@
Improve workers docs.

changelog.d/8001.misc Normal file

@ -0,0 +1 @@
Remove redundant and unreliable signature check for v1 Identity Service lookup responses.

changelog.d/8003.misc Normal file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8008.feature Normal file

@ -0,0 +1 @@
Add rate limiting to users joining rooms.

changelog.d/8011.bugfix Normal file

@ -0,0 +1 @@
Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger.

changelog.d/8012.bugfix Normal file

@ -0,0 +1 @@
Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger.

changelog.d/8014.misc Normal file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8016.misc Normal file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8024.misc Normal file

@ -0,0 +1 @@
Reduce less useful output in the newsfragment CI step. Add a link to the changelog section of the contributing guide on error.

changelog.d/8027.misc Normal file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8033.misc Normal file

@ -0,0 +1 @@
Rename storage layer objects to be more sensible.


@ -746,6 +746,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
# - two for ratelimiting number of rooms a user can join, "local" for when
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
#
# The defaults are as shown below.
#
@ -771,6 +775,14 @@ log_config: "CONFDIR/SERVERNAME.log.config"
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
#
#rc_joins:
# local:
# per_second: 0.1
# burst_count: 3
# remote:
# per_second: 0.01
# burst_count: 3
# Ratelimiting settings for incoming federation
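The `per_second` / `burst_count` pair behaves like a token bucket: a burst of up to `burst_count` joins is allowed immediately, after which further attempts are limited to the refill rate. A minimal standalone sketch of those semantics (hypothetical code, not Synapse's actual `Ratelimiter` implementation):

```python
class TokenBucket:
    """Sketch of per_second/burst_count ratelimiting semantics."""

    def __init__(self, per_second: float, burst_count: int, start: float = 0.0):
        self.per_second = per_second        # refill rate (tokens per second)
        self.burst_count = burst_count      # bucket capacity
        self.tokens = float(burst_count)    # start with a full bucket
        self.last = start

    def allow(self, now: float) -> bool:
        # Refill tokens accumulated since the last check, capped at capacity.
        elapsed = now - self.last
        self.tokens = min(self.burst_count, self.tokens + elapsed * self.per_second)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

# With the documented remote-join defaults (0.01/sec, burst 3), three
# immediate attempts pass and the fourth is limited.
bucket = TokenBucket(per_second=0.01, burst_count=3)
results = [bucket.allow(now=0.0) for _ in range(4)]
```

After a long enough wait the bucket refills: `bucket.allow(now=100.0)` succeeds again, since 100 seconds at 0.01 tokens/second yields one token.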


@ -1,7 +1,7 @@
worker_app: synapse.app.federation_reader
worker_name: federation_reader1
worker_replication_host: 127.0.0.1
worker_replication_port: 9092
worker_replication_http_port: 9093
worker_listeners:


@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server.
The directory info is stored in various tables, which can (typically after
DB corruption) get stale or out of sync. If this happens, for now the
solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql)
solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql)
and then restart synapse. This should then start a background task to
flush the current tables and regenerate the directory.


@ -70,19 +70,27 @@ the correct worker, or to the main synapse instance. See
[reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse
proxy.

To enable workers you should create a configuration file for each worker
process. Each worker configuration file inherits the configuration of the shared
homeserver configuration file. You can then override configuration specific to
that worker, e.g. the HTTP listener that it provides (if any); logging
configuration; etc. You should minimise the number of overrides though to
maintain a usable config.

When using workers, each worker process has its own configuration file which
contains settings specific to that worker, such as the HTTP listener that it
provides (if any), logging configuration, etc.

Normally, the worker processes are configured to read from a shared
configuration file as well as the worker-specific configuration files. This
makes it easier to keep common configuration settings synchronised across all
the processes.

The main process is somewhat special in this respect: it does not normally
need its own configuration file and can take all of its configuration from the
shared configuration file.
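The layered-configuration idea described above can be sketched as a simple dict overlay, with worker-specific keys winning over the shared ones (a hypothetical illustration, not Synapse's actual config loader):

```python
# Shared homeserver settings, read by every process.
shared = {"server_name": "example.com", "listeners": [{"port": 8008}]}

# Worker-specific settings, read only by this worker.
worker = {
    "worker_app": "synapse.app.generic_worker",
    "worker_name": "federation_reader1",
    "listeners": [{"port": 8083}],
}

# The effective config: shared values, overridden by worker-specific ones.
effective = {**shared, **worker}
```

Here `server_name` comes from the shared file, while the worker's own `listeners` entry replaces the main process's listener.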
### Shared configuration
Normally, only a couple of changes are needed to make an existing configuration
file suitable for use with workers. First, you need to enable an "HTTP replication
listener" for the main process; and secondly, you need to enable redis-based
replication. For example:
Next you need to add both a HTTP replication listener, used for HTTP requests
between processes, and redis config to the shared Synapse configuration file
(`homeserver.yaml`). For example:
```yaml
# extend the existing `listeners` section. This defines the ports that the
@ -105,7 +113,7 @@ Under **no circumstances** should the replication listener be exposed to the
public internet; it has no authentication and is unencrypted.
### Worker configuration
In the config file for each worker, you must specify the type of worker
application (`worker_app`), and you should specify a unique name for the worker
@ -145,6 +153,9 @@ plain HTTP endpoint on port 8083 separately serving various endpoints, e.g.
Obviously you should configure your reverse-proxy to route the relevant
endpoints to the worker (`localhost:8083` in the above example).
### Running Synapse with workers
Finally, you need to start your worker processes. This can be done with either
`synctl` or your distribution's preferred service manager such as `systemd`. We
recommend the use of `systemd` where available: for information on setting up
@ -407,6 +418,23 @@ all these to be folded into the `generic_worker` app and to use config to define
which processes handle the various processing such as push notifications.
## Migration from old config
There are two main independent changes that have been made: introducing Redis
support and merging apps into `synapse.app.generic_worker`. Both changes
are backwards compatible, so no config changes are required; however, server
admins are encouraged to plan a migration to Redis, as the old-style direct
TCP replication config is deprecated.
To migrate to Redis add the `redis` config as above, and optionally remove the
TCP `replication` listener from master and `worker_replication_port` from worker
config.
To migrate apps to use `synapse.app.generic_worker`, simply update the
`worker_app` option in the worker configs, and wherever the workers are started (e.g.
in systemd service files; this is not required for synctl).
## Architectural diagram

The following shows an example setup using Redis and a reverse proxy:


@ -3,6 +3,8 @@
# A script which checks that an appropriate news file has been added on this # A script which checks that an appropriate news file has been added on this
# branch. # branch.
echo -e "+++ \033[32mChecking newsfragment\033[m"
set -e

# make sure that origin/develop is up to date
@ -16,6 +18,8 @@ pr="$BUILDKITE_PULL_REQUEST"
if ! git diff --quiet FETCH_HEAD... -- debian; then
    if git diff --quiet FETCH_HEAD... -- debian/changelog; then
        echo "Updates to debian directory, but no update to the changelog." >&2
        echo "!! Please see the contributing guide for help writing your changelog entry:" >&2
        echo "https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#debian-changelog" >&2
        exit 1
    fi
fi
@ -26,7 +30,12 @@ if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then
    exit 0
fi
tox -qe check-newsfragment

# Print a link to the contributing guide if the user makes a mistake
CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing your changelog entry:
https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog"
# If check-newsfragment returns a non-zero exit code, print the contributing guide and exit
tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1)
echo
echo "--------------------------"
@ -38,6 +47,7 @@ for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do
    lastchar=`tr -d '\n' < $f | tail -c 1`
    if [ $lastchar != '.' -a $lastchar != '!' ]; then
        echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2
        echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2
        exit 1
    fi
@ -47,5 +57,6 @@ done
if [[ -n "$pr" && "$matched" -eq 0 ]]; then
    echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2
    echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2
    exit 1
fi


@ -40,7 +40,7 @@ class MockHomeserver(HomeServer):
config.server_name, reactor=reactor, config=config, **kwargs
)
self.version_string = "Synapse/"+get_version_string(synapse)
self.version_string = "Synapse/" + get_version_string(synapse)
if __name__ == "__main__":
@ -86,7 +86,7 @@ if __name__ == "__main__":
store = hs.get_datastore()
async def run_background_updates():
await store.db.updates.run_background_updates(sleep=False)
await store.db_pool.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run.
reactor.stop()


@ -35,31 +35,29 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.data_stores.main.deviceinbox import (
DeviceInboxBackgroundUpdateStore,
)
from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.data_stores.main.events_bg_updates import (
EventsBackgroundUpdatesStore,
)
from synapse.storage.data_stores.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
from synapse.storage.data_stores.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore
from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore
from synapse.storage.data_stores.main.stats import StatsStore
from synapse.storage.data_stores.main.user_directory import (
UserDirectoryBackgroundUpdateStore,
)
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database, make_conn
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.databases.main.events_bg_updates import (
EventsBackgroundUpdatesStore,
)
from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.user_directory import (
UserDirectoryBackgroundUpdateStore,
)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util import Clock
@ -175,14 +173,14 @@ class Store(
StatsStore,
):
def execute(self, f, *args, **kwargs):
return self.db.runInteraction(f.__name__, f, *args, **kwargs)
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()
return self.db.runInteraction("execute_sql", r)
return self.db_pool.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
@ -227,7 +225,7 @@ class Porter(object):
async def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
row = await self.postgres_store.db.simple_select_one(
row = await self.postgres_store.db_pool.simple_select_one(
table="port_from_sqlite3",
keyvalues={"table_name": table},
retcols=("forward_rowid", "backward_rowid"),
@ -244,7 +242,7 @@ class Porter(object):
) = await self._setup_sent_transactions()
backward_chunk = 0
else:
await self.postgres_store.db.simple_insert(
await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3",
values={
"table_name": table,
@ -274,7 +272,7 @@ class Porter(object):
await self.postgres_store.execute(delete_all)
await self.postgres_store.db.simple_insert(
await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
)
@ -318,7 +316,7 @@ class Porter(object):
if table == "user_directory_stream_pos":
# We need to make sure there is a single row, `(X, null)`, as that is
# what synapse expects to be there.
await self.postgres_store.db.simple_insert(
await self.postgres_store.db_pool.simple_insert(
table=table, values={"stream_id": None}
)
self.progress.update(table, table_size)  # Mark table as done
@ -359,7 +357,7 @@ class Porter(object):
return headers, forward_rows, backward_rows
headers, frows, brows = await self.sqlite_store.db.runInteraction(
headers, frows, brows = await self.sqlite_store.db_pool.runInteraction(
"select", r
)
@ -375,7 +373,7 @@ class Porter(object):
def insert(txn):
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store.db.simple_update_one_txn(
self.postgres_store.db_pool.simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
@ -413,7 +411,7 @@ class Porter(object):
return headers, rows
headers, rows = await self.sqlite_store.db.runInteraction("select", r)
headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
if rows:
forward_chunk = rows[-1][0] + 1
@ -451,7 +449,7 @@ class Porter(object):
],
)
self.postgres_store.db.simple_update_one_txn(
self.postgres_store.db_pool.simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": "event_search"},
@ -494,7 +492,7 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
store = Store(Database(hs, db_config, engine), db_conn, hs)
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
db_conn.commit()
return store
@ -502,7 +500,7 @@ class Porter(object):
async def run_background_updates_on_postgres(self):
# Manually apply all background updates on the PostgreSQL database.
postgres_ready = (
await self.postgres_store.db.updates.has_completed_background_updates()
await self.postgres_store.db_pool.updates.has_completed_background_updates()
)
if not postgres_ready:
@ -511,9 +509,9 @@ class Porter(object):
self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready:
await self.postgres_store.db.updates.do_next_background_update(100)
await self.postgres_store.db_pool.updates.do_next_background_update(100)
postgres_ready = await (
self.postgres_store.db.updates.has_completed_background_updates()
self.postgres_store.db_pool.updates.has_completed_background_updates()
)
async def run(self):
@ -534,7 +532,7 @@ class Porter(object):
# Check if all background updates are done, abort if not.
updates_complete = (
await self.sqlite_store.db.updates.has_completed_background_updates()
await self.sqlite_store.db_pool.updates.has_completed_background_updates()
)
if not updates_complete:
end_error = (
@ -576,22 +574,24 @@ class Porter(object):
)
try:
await self.postgres_store.db.runInteraction("alter_table", alter_table)
await self.postgres_store.db_pool.runInteraction(
"alter_table", alter_table
)
except Exception:
# On Error Resume Next
pass
await self.postgres_store.db.runInteraction(
await self.postgres_store.db_pool.runInteraction(
"create_port_table", create_port_table
)
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = await self.sqlite_store.db.simple_select_onecol(
sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
)
postgres_tables = await self.postgres_store.db.simple_select_onecol(
postgres_tables = await self.postgres_store.db_pool.simple_select_onecol(
table="information_schema.tables",
keyvalues={},
retcol="distinct table_name",
@ -692,7 +692,7 @@ class Porter(object):
return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = await self.sqlite_store.db.runInteraction("select", r)
headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
rows = self._convert_rows("sent_transactions", headers, rows)
@ -725,7 +725,7 @@ class Porter(object):
next_chunk = await self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk)
await self.postgres_store.db.simple_insert(
await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3",
values={
"table_name": "sent_transactions",
@ -794,14 +794,14 @@ class Porter(object):
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
def _setup_user_id_seq(self):
def r(txn):
next_id = find_max_generated_user_id_localpart(txn) + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
##############################################
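Throughout the porter, `runInteraction(desc, func)` runs `func` inside a single database transaction and returns its result, committing on success and rolling back on error. A minimal standalone sketch of that pattern over sqlite3 (a hypothetical helper, not Synapse's actual `DatabasePool.runInteraction`):

```python
import sqlite3

def run_interaction(conn, desc, func, *args):
    """Run `func(cursor, *args)` in one transaction; commit or roll back."""
    cur = conn.cursor()
    try:
        result = func(cur, *args)
        conn.commit()
        return result
    except Exception:
        conn.rollback()
        raise

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE port_from_sqlite3 (table_name TEXT, forward_rowid INTEGER)")

def insert_row(txn):
    # All statements here share one transaction, as in the porter's `r(txn)` closures.
    txn.execute("INSERT INTO port_from_sqlite3 VALUES (?, ?)", ("events", 1))
    txn.execute(
        "SELECT forward_rowid FROM port_from_sqlite3 WHERE table_name = ?", ("events",)
    )
    return txn.fetchone()[0]

rowid = run_interaction(conn, "insert", insert_row)
```

The `desc` argument is unused in this sketch; in the real code it names the transaction for logging and metrics.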


@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import os
@ -22,7 +21,6 @@ import sys
import traceback
from typing import Iterable
from daemonize import Daemonize
from typing_extensions import NoReturn
from twisted.internet import defer, error, reactor
@ -34,6 +32,7 @@ from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@ -129,16 +128,7 @@ def start_reactor(
if print_pidfile:
print(pid_file)
daemon = Daemonize(
app=appname,
pid=pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
daemonize_process(pid_file, logger)
else:
run()
@ -278,7 +268,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.get_datastore().db.start_profiling()
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
setup_sentry(hs)


@ -125,15 +125,15 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.data_stores.main.censor_events import CensorEventsStore
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.search import SearchWorkerStore
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.databases.main.presence import UserPresenceState
from synapse.storage.databases.main.search import SearchWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree


@ -380,13 +380,12 @@ def setup(config_options):
hs.setup_master()
@defer.inlineCallbacks
def do_acme():
async def do_acme() -> bool:
"""
Reprovision an ACME certificate, if it's required.
Returns:
Deferred[bool]: Whether the cert has been updated.
Whether the cert has been updated.
"""
acme = hs.get_acme_handler()
@ -405,7 +404,7 @@ def setup(config_options):
provision = True
if provision:
yield acme.provision_certificate()
await acme.provision_certificate()
return provision
@ -415,7 +414,7 @@ def setup(config_options):
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
reprovisioned = yield do_acme()
reprovisioned = yield defer.ensureDeferred(do_acme())
if reprovisioned:
_base.refresh_certificate(hs)
@ -427,8 +426,8 @@ def setup(config_options):
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
yield acme.start_listening()
yield defer.ensureDeferred(acme.start_listening())
yield do_acme()
yield defer.ensureDeferred(do_acme())
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@ -442,7 +441,7 @@ def setup(config_options):
_base.start(hs, config.listeners)
hs.get_datastore().db.updates.start_doing_background_updates()
hs.get_datastore().db_pool.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
@ -552,8 +551,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
# #
# This only reports info about the *main* database. # This only reports info about the *main* database.
stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
stats["database_server_version"] = hs.get_datastore().db.engine.server_version stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try: try:
@ -175,7 +175,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol), urllib.parse.quote(protocol),
) )
try: try:
info = yield self.get_json(uri, {}) info = yield defer.ensureDeferred(self.get_json(uri, {}))
if not _is_valid_3pe_metadata(info): if not _is_valid_3pe_metadata(info):
logger.warning( logger.warning(
@ -100,7 +100,10 @@ class DatabaseConnectionConfig:
self.name = name self.name = name
self.config = db_config self.config = db_config
self.data_stores = data_stores
# The `data_stores` config is actually talking about `databases` (we
# changed the name).
self.databases = data_stores
class DatabaseConfig(Config): class DatabaseConfig(Config):
@ -93,6 +93,15 @@ class RatelimitConfig(Config):
if rc_admin_redaction: if rc_admin_redaction:
self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
self.rc_joins_local = RateLimitConfig(
config.get("rc_joins", {}).get("local", {}),
defaults={"per_second": 0.1, "burst_count": 3},
)
self.rc_joins_remote = RateLimitConfig(
config.get("rc_joins", {}).get("remote", {}),
defaults={"per_second": 0.01, "burst_count": 3},
)
def generate_config_section(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
## Ratelimiting ## ## Ratelimiting ##
@ -118,6 +127,10 @@ class RatelimitConfig(Config):
# - one for ratelimiting redactions by room admins. If this is not explicitly # - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful # set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly. # to allow room admins to deal with abuse quickly.
# - two for ratelimiting number of rooms a user can join, "local" for when
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
# #
# The defaults are as shown below. # The defaults are as shown below.
# #
@ -143,6 +156,14 @@ class RatelimitConfig(Config):
#rc_admin_redaction: #rc_admin_redaction:
# per_second: 1 # per_second: 1
# burst_count: 50 # burst_count: 50
#
#rc_joins:
# local:
# per_second: 0.1
# burst_count: 3
# remote:
# per_second: 0.01
# burst_count: 3
# Ratelimiting settings for incoming federation # Ratelimiting settings for incoming federation
@ -223,8 +223,7 @@ class Keyring(object):
return results return results
@defer.inlineCallbacks async def _start_key_lookups(self, verify_requests):
def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request """Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved. Once each fetch completes, verify_request.key_ready will be resolved.
@ -245,7 +244,7 @@ class Keyring(object):
server_to_request_ids.setdefault(server_name, set()).add(request_id) server_to_request_ids.setdefault(server_name, set()).add(request_id)
# Wait for any previous lookups to complete before proceeding. # Wait for any previous lookups to complete before proceeding.
yield self.wait_for_previous_lookups(server_to_request_ids.keys()) await self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in # take out a lock on each of the servers by sticking a Deferred in
# key_downloads # key_downloads
@ -283,15 +282,14 @@ class Keyring(object):
except Exception: except Exception:
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
@defer.inlineCallbacks async def wait_for_previous_lookups(self, server_names) -> None:
def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
Args: Args:
server_names (Iterable[str]): list of servers which we want to look up server_names (Iterable[str]): list of servers which we want to look up
Returns: Returns:
Deferred[None]: resolves once all key lookups for the given servers have Resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation. completed. Follows the synapse rules of logcontext preservation.
""" """
loop_count = 1 loop_count = 1
@ -309,7 +307,7 @@ class Keyring(object):
loop_count, loop_count,
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on)) await defer.DeferredList((w[1] for w in wait_on))
loop_count += 1 loop_count += 1
@ -326,13 +324,15 @@ class Keyring(object):
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@defer.inlineCallbacks async def do_iterations():
def do_iterations(): try:
with Measure(self.clock, "get_server_verify_keys"): with Measure(self.clock, "get_server_verify_keys"):
for f in self._key_fetchers: for f in self._key_fetchers:
if not remaining_requests: if not remaining_requests:
return return
yield self._attempt_key_fetches_with_fetcher(f, remaining_requests) await self._attempt_key_fetches_with_fetcher(
f, remaining_requests
)
# look for any requests which weren't satisfied # look for any requests which weren't satisfied
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -349,8 +349,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
) )
except Exception as err:
def on_err(err):
# we don't really expect to get here, because any errors should already # we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make # have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved. # sure that all of the deferreds are resolved.
@ -360,10 +359,9 @@ class Keyring(object):
if not verify_request.key_ready.called: if not verify_request.key_ready.called:
verify_request.key_ready.errback(err) verify_request.key_ready.errback(err)
run_in_background(do_iterations).addErrback(on_err) run_in_background(do_iterations)
@defer.inlineCallbacks async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests """Use a key fetcher to attempt to satisfy some key requests
Args: Args:
@ -390,7 +388,7 @@ class Keyring(object):
verify_request.minimum_valid_until_ts, verify_request.minimum_valid_until_ts,
) )
results = yield fetcher.get_keys(missing_keys) results = await fetcher.get_keys(missing_keys)
completed = [] completed = []
for verify_request in remaining_requests: for verify_request in remaining_requests:
@ -423,7 +421,7 @@ class Keyring(object):
class KeyFetcher(object): class KeyFetcher(object):
def get_keys(self, keys_to_fetch): async def get_keys(self, keys_to_fetch):
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch (dict[str, dict[str, int]]):
@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def get_keys(self, keys_to_fetch):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
keys_to_fetch = ( keys_to_fetch = (
@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher):
for key_id in keys_for_server.keys() for key_id in keys_for_server.keys()
) )
res = yield self.store.get_server_verify_keys(keys_to_fetch) res = await self.store.get_server_verify_keys(keys_to_fetch)
keys = {} keys = {}
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key keys.setdefault(server_name, {})[key_id] = key
@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.config = hs.get_config() self.config = hs.get_config()
@defer.inlineCallbacks async def process_v2_response(self, from_server, response_json, time_added_ms):
def process_v2_response(self, from_server, response_json, time_added_ms):
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from This is used to parse either the entirety of the response from
@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object):
key_json_bytes = encode_canonical_json(response_json) key_json_bytes = encode_canonical_json(response_json)
yield make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background( run_in_background(
@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_http_client() self.client = hs.get_http_client()
self.key_servers = self.config.key_servers self.key_servers = self.config.key_servers
@defer.inlineCallbacks async def get_keys(self, keys_to_fetch):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
@defer.inlineCallbacks async def get_key(key_server):
def get_key(key_server):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server keys_to_fetch, key_server
) )
return result return result
@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return {} return {}
results = yield make_deferred_yieldable( results = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[run_in_background(get_key, server) for server in self.key_servers], [run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True, consumeErrors=True,
@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys return union_of_keys
@defer.inlineCallbacks async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch (dict[str, dict[str, int]]):
@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
the keys the keys
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
from server_name -> key_id -> FetchKeyResult from server_name -> key_id -> FetchKeyResult
Raises: Raises:
@ -632,8 +625,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
) )
try: try:
query_response = yield defer.ensureDeferred( query_response = await self.client.post_json(
self.client.post_json(
destination=perspective_name, destination=perspective_name,
path="/_matrix/key/v2/query", path="/_matrix/key/v2/query",
data={ data={
@ -646,7 +638,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
} }
}, },
) )
)
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon # these both have str() representations which we can't really improve upon
raise KeyLookupError(str(e)) raise KeyLookupError(str(e))
@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
try: try:
self._validate_perspectives_response(key_server, response) self._validate_perspectives_response(key_server, response)
processed_response = yield self.process_v2_response( processed_response = await self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms perspective_name, response, time_added_ms=time_now_ms
) )
except KeyLookupError as e: except KeyLookupError as e:
@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
) )
keys.setdefault(server_name, {}).update(processed_response) keys.setdefault(server_name, {}).update(processed_response)
yield self.store.store_server_verify_keys( await self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys perspective_name, time_now_ms, added_keys
) )
@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_http_client()
def get_keys(self, keys_to_fetch): async def get_keys(self, keys_to_fetch):
""" """
Args: Args:
keys_to_fetch (dict[str, iterable[str]]): keys_to_fetch (dict[str, iterable[str]]):
the keys to be fetched. server_name -> key_ids the keys to be fetched. server_name -> key_ids
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
map from server_name -> key_id -> FetchKeyResult map from server_name -> key_id -> FetchKeyResult
""" """
results = {} results = {}
@defer.inlineCallbacks async def get_key(key_to_fetch_item):
def get_key(key_to_fetch_item):
server_name, key_ids = key_to_fetch_item server_name, key_ids = key_to_fetch_item
try: try:
keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids) keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys results[server_name] = keys
except KeyLookupError as e: except KeyLookupError as e:
logger.warning( logger.warning(
@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception: except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name) logger.exception("Error getting keys %s from %s", key_ids, server_name)
return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback( return await yieldable_gather_results(
lambda _: results get_key, keys_to_fetch.items()
) ).addCallback(lambda _: results)
@defer.inlineCallbacks async def get_server_verify_key_v2_direct(self, server_name, key_ids):
def get_server_verify_key_v2_direct(self, server_name, key_ids):
""" """
Args: Args:
@ -794,8 +783,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
try: try:
response = yield defer.ensureDeferred( response = await self.client.get_json(
self.client.get_json(
destination=server_name, destination=server_name,
path="/_matrix/key/v2/server/" path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id), + urllib.parse.quote(requested_key_id),
@ -813,7 +801,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
# read the response). # read the response).
timeout=10000, timeout=10000,
) )
)
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve # these both have str() representations which we can't really improve
# upon # upon
@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
% (server_name, response["server_name"]) % (server_name, response["server_name"])
) )
response_keys = yield self.process_v2_response( response_keys = await self.process_v2_response(
from_server=server_name, from_server=server_name,
response_json=response, response_json=response,
time_added_ms=time_now_ms, time_added_ms=time_now_ms,
) )
yield self.store.store_server_verify_keys( await self.store.store_server_verify_keys(
server_name, server_name,
time_now_ms, time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()), ((server_name, key_id, key) for key_id, key in response_keys.items()),
@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys return keys
@defer.inlineCallbacks async def _handle_key_deferred(verify_request) -> None:
def _handle_key_deferred(verify_request):
"""Waits for the key to become available, and then performs a verification """Waits for the key to become available, and then performs a verification
Args: Args:
verify_request (VerifyJsonRequest): verify_request (VerifyJsonRequest):
Returns:
Deferred[None]
Raises: Raises:
SynapseError if there was a problem performing the verification SynapseError if there was a problem performing the verification
""" """
server_name = verify_request.server_name server_name = verify_request.server_name
with PreserveLoggingContext(): with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.key_ready _, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object json_object = verify_request.json_object
@ -23,7 +23,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap from synapse.types import StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore from synapse.storage.databases.main import DataStore
@attr.s(slots=True) @attr.s(slots=True)
@ -17,7 +17,6 @@ import logging
import twisted import twisted
import twisted.internet.error import twisted.internet.error
from twisted.internet import defer
from twisted.web import server, static from twisted.web import server, static
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -41,8 +40,7 @@ class AcmeHandler(object):
self.reactor = hs.get_reactor() self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain self._acme_domain = hs.config.acme_domain
@defer.inlineCallbacks async def start_listening(self):
def start_listening(self):
from synapse.handlers import acme_issuing_service from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug # Configure logging for txacme, if you need to debug
@ -82,18 +80,17 @@ class AcmeHandler(object):
self._issuer._registered = False self._issuer._registered = False
try: try:
yield self._issuer._ensure_registered() await self._issuer._ensure_registered()
except Exception: except Exception:
logger.error(ACME_REGISTER_FAIL_ERROR) logger.error(ACME_REGISTER_FAIL_ERROR)
raise raise
@defer.inlineCallbacks async def provision_certificate(self):
def provision_certificate(self):
logger.warning("Reprovisioning %s", self._acme_domain) logger.warning("Reprovisioning %s", self._acme_domain)
try: try:
yield self._issuer.issue_cert(self._acme_domain) await self._issuer.issue_cert(self._acme_domain)
except Exception: except Exception:
logger.exception("Fail!") logger.exception("Fail!")
raise raise
@ -71,7 +71,7 @@ from synapse.replication.http.federation import (
) )
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
@ -22,14 +22,10 @@ import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet.error import TimeoutError from twisted.internet.error import TimeoutError
from synapse.api.errors import ( from synapse.api.errors import (
AuthError,
CodeMessageException, CodeMessageException,
Codes, Codes,
HttpResponseException, HttpResponseException,
@ -628,9 +624,9 @@ class IdentityHandler(BaseHandler):
) )
if "mxid" in data: if "mxid" in data:
if "signatures" not in data: # note: we used to verify the identity server's signature here, but no longer
raise AuthError(401, "No signatures on 3pid binding") # require or validate it. See the following for context:
await self._verify_any_signature(data, id_server) # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
return data["mxid"] return data["mxid"]
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
@ -751,30 +747,6 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value) mxid = lookup_results["mappings"].get(lookup_value)
return mxid return mxid
async def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
key_data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if "public_key" not in key_data:
raise AuthError(
401, "No public key named %s from %s" % (key_name, server_hostname)
)
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(
key_name, decode_base64(key_data["public_key"])
),
)
return
async def ask_id_server_for_third_party_invite( async def ask_id_server_for_third_party_invite(
self, self,
requester: Requester, requester: Requester,
@ -109,7 +109,7 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = [] rooms_ret = []
now_token = await self.hs.get_event_sources().get_current_token() now_token = self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"] presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token) pagination_config = PaginationConfig(from_token=now_token)
@ -360,7 +360,7 @@ class InitialSyncHandler(BaseHandler):
current_state.values(), time_now current_state.values(), time_now
) )
now_token = await self.hs.get_event_sources().get_current_token() now_token = self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None
if limit is None: if limit is None:
@ -45,7 +45,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
Collection, Collection,
@ -309,7 +309,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
else: else:
pagin_config.from_token = ( pagin_config.from_token = (
await self.hs.get_event_sources().get_current_token_for_pagination() self.hs.get_event_sources().get_current_token_for_pagination()
) )
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
@ -38,7 +38,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage.data_stores.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -319,7 +319,7 @@ class PresenceHandler(BasePresenceHandler):
is some spurious presence changes that will self-correct. is some spurious presence changes that will self-correct.
""" """
# If the DB pool has already terminated, don't try updating # If the DB pool has already terminated, don't try updating
if not self.store.db.is_running(): if not self.store.db_pool.is_running():
return return
logger.info( logger.info(
@ -548,7 +548,7 @@ class RegistrationHandler(BaseHandler):
address (str|None): the IP address used to perform the registration. address (str|None): the IP address used to perform the registration.
Returns: Returns:
Deferred Awaitable
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
return self._register_client( return self._register_client(
@ -22,7 +22,7 @@ import logging
import math import math
import string import string
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Tuple from typing import Awaitable, Optional, Tuple
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, EventTypes,
@ -1041,7 +1041,7 @@ class RoomEventSource(object):
): ):
# We just ignore the key for now. # We just ignore the key for now.
to_key = await self.get_current_key() to_key = self.get_current_key()
from_token = RoomStreamToken.parse(from_key) from_token = RoomStreamToken.parse(from_key)
if from_token.topological: if from_token.topological:
@ -1081,10 +1081,10 @@ class RoomEventSource(object):
return (events, end_key) return (events, end_key)
def get_current_key(self): def get_current_key(self) -> str:
return self.store.get_room_events_max_id() return "s%d" % (self.store.get_room_max_stream_ordering(),)
def get_current_key_for_room(self, room_id): def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id) return self.store.get_room_events_max_id(room_id)
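The rewritten `get_current_key` above no longer queries the database: the current token is simply `"s"` followed by the maximum stream ordering, per the `"s%d"` format. In isolation (hypothetical helper mirroring that format):

```python
def current_stream_token(max_stream_ordering: int) -> str:
    # Mirrors the new RoomEventSource.get_current_key: "s" + max stream ordering.
    return "s%d" % (max_stream_ordering,)
```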
@ -22,7 +22,8 @@ from unpaddedbase64 import encode_base64
from synapse import types from synapse import types
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase from synapse.events import EventBase
@ -77,6 +78,17 @@ class RoomMemberHandler(object):
if self._is_on_event_persistence_instance: if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence self.persist_event_storage = hs.get_storage().persistence
self._join_rate_limiter_local = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
self._join_rate_limiter_remote = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
# This is only used to get at ratelimit function, and # This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as # maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state. # it doesn't store state.
@ -441,7 +453,28 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: if is_host_in_room:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
requester.user.to_string(),
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
else:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
requester.user.to_string(),
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
inviter = await self._get_inviter(target.to_string(), room_id) inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
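In both new branches above, the limiter hands back an absolute `time_allowed` timestamp, which is converted into the relative `retry_after_ms` that `LimitExceededError` reports to the client. The conversion, pulled out into a hypothetical helper:

```python
def retry_after_ms(time_allowed: float, time_now_s: float) -> int:
    # Absolute "next allowed" time (seconds) -> relative delay in milliseconds,
    # as carried in LimitExceededError above.
    return int(1000 * (time_allowed - time_now_s))
```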
@ -340,7 +340,7 @@ class SearchHandler(BaseHandler):
# If client has asked for "context" for each event (i.e. some surrounding # If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that # events and state), fetch that
if event_context is not None: if event_context is not None:
now_token = await self.hs.get_event_sources().get_current_token() now_token = self.hs.get_event_sources().get_current_token()
contexts = {} contexts = {}
for event in allowed_events: for event in allowed_events:
@ -232,7 +232,7 @@ class StatsHandler:
if membership == prev_membership: if membership == prev_membership:
pass # noop pass # noop
if membership == Membership.JOIN: elif membership == Membership.JOIN:
room_stats_delta["joined_members"] += 1 room_stats_delta["joined_members"] += 1
elif membership == Membership.INVITE: elif membership == Membership.INVITE:
room_stats_delta["invited_members"] += 1 room_stats_delta["invited_members"] += 1
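The `if` → `elif` above is the membership-count fix: without it, an event whose membership equals `prev_membership` fell through the no-op branch and was still counted as a fresh join, inflating the room directory's member counts. A simplified reproduction (hypothetical helpers, plain strings standing in for the `Membership` constants):

```python
def joined_delta_buggy(membership: str, prev_membership: str) -> int:
    delta = 0
    if membership == prev_membership:
        pass  # noop
    if membership == "join":  # bug: also runs when membership was unchanged
        delta += 1
    return delta


def joined_delta_fixed(membership: str, prev_membership: str) -> int:
    delta = 0
    if membership == prev_membership:
        pass  # noop
    elif membership == "join":  # only counts genuine transitions to "join"
        delta += 1
    return delta
```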
@ -961,7 +961,7 @@ class SyncHandler(object):
# this is due to some of the underlying streams not supporting the ability # this is due to some of the underlying streams not supporting the ability
# to query up to a given point. # to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder` # Always use the `now_token` in `SyncResultBuilder`
now_token = await self.event_sources.get_current_token() now_token = self.event_sources.get_current_token()
logger.debug( logger.debug(
"Calculating sync response for %r between %s and %s", "Calculating sync response for %r between %s and %s",
@@ -284,8 +284,7 @@ class SimpleHttpClient(object):
             ip_blacklist=self._ip_blacklist,
         )

-    @defer.inlineCallbacks
-    def request(self, method, uri, data=None, headers=None):
+    async def request(self, method, uri, data=None, headers=None):
         """
         Args:
             method (str): HTTP method to use.
@@ -330,7 +329,7 @@ class SimpleHttpClient(object):
                 self.hs.get_reactor(),
                 cancelled_to_request_timed_out_error,
             )
-            response = yield make_deferred_yieldable(request_deferred)
+            response = await make_deferred_yieldable(request_deferred)

             incoming_responses_counter.labels(method, response.code).inc()
             logger.info(
@@ -353,8 +352,7 @@ class SimpleHttpClient(object):
                 set_tag("error_reason", e.args[0])
                 raise

-    @defer.inlineCallbacks
-    def post_urlencoded_get_json(self, uri, args={}, headers=None):
+    async def post_urlencoded_get_json(self, uri, args={}, headers=None):
         """
         Args:
             uri (str):
@@ -363,7 +361,7 @@ class SimpleHttpClient(object):
                 header name to a list of values for that header
         Returns:
-            Deferred[object]: parsed json
+            object: parsed json
         Raises:
             HttpResponseException: On a non-2xx HTTP response.
@@ -386,11 +384,11 @@ class SimpleHttpClient(object):
         if headers:
             actual_headers.update(headers)

-        response = yield self.request(
+        response = await self.request(
             "POST", uri, headers=Headers(actual_headers), data=query_bytes
         )

-        body = yield make_deferred_yieldable(readBody(response))
+        body = await make_deferred_yieldable(readBody(response))

         if 200 <= response.code < 300:
             return json.loads(body.decode("utf-8"))
@@ -399,8 +397,7 @@ class SimpleHttpClient(object):
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )

-    @defer.inlineCallbacks
-    def post_json_get_json(self, uri, post_json, headers=None):
+    async def post_json_get_json(self, uri, post_json, headers=None):
         """
         Args:
@@ -410,7 +407,7 @@ class SimpleHttpClient(object):
                 header name to a list of values for that header
         Returns:
-            Deferred[object]: parsed json
+            object: parsed json
         Raises:
             HttpResponseException: On a non-2xx HTTP response.
@@ -429,11 +426,11 @@ class SimpleHttpClient(object):
         if headers:
             actual_headers.update(headers)

-        response = yield self.request(
+        response = await self.request(
             "POST", uri, headers=Headers(actual_headers), data=json_str
         )

-        body = yield make_deferred_yieldable(readBody(response))
+        body = await make_deferred_yieldable(readBody(response))

         if 200 <= response.code < 300:
             return json.loads(body.decode("utf-8"))
@@ -442,8 +439,7 @@ class SimpleHttpClient(object):
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )

-    @defer.inlineCallbacks
-    def get_json(self, uri, args={}, headers=None):
+    async def get_json(self, uri, args={}, headers=None):
         """ Gets some json from the given URI.

         Args:
@@ -455,7 +451,7 @@ class SimpleHttpClient(object):
             headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
                 header name to a list of values for that header
         Returns:
-            Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+            Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
         Raises:
             HttpResponseException On a non-2xx HTTP response.
@@ -466,11 +462,10 @@ class SimpleHttpClient(object):
         if headers:
             actual_headers.update(headers)

-        body = yield self.get_raw(uri, args, headers=headers)
+        body = await self.get_raw(uri, args, headers=headers)
         return json.loads(body.decode("utf-8"))

-    @defer.inlineCallbacks
-    def put_json(self, uri, json_body, args={}, headers=None):
+    async def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.

         Args:
@@ -483,7 +478,7 @@ class SimpleHttpClient(object):
             headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
                 header name to a list of values for that header
         Returns:
-            Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+            Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
         Raises:
             HttpResponseException On a non-2xx HTTP response.
@@ -504,11 +499,11 @@ class SimpleHttpClient(object):
         if headers:
             actual_headers.update(headers)

-        response = yield self.request(
+        response = await self.request(
             "PUT", uri, headers=Headers(actual_headers), data=json_str
         )

-        body = yield make_deferred_yieldable(readBody(response))
+        body = await make_deferred_yieldable(readBody(response))

         if 200 <= response.code < 300:
             return json.loads(body.decode("utf-8"))
@@ -517,8 +512,7 @@ class SimpleHttpClient(object):
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )

-    @defer.inlineCallbacks
-    def get_raw(self, uri, args={}, headers=None):
+    async def get_raw(self, uri, args={}, headers=None):
         """ Gets raw text from the given URI.

         Args:
@@ -530,7 +524,7 @@ class SimpleHttpClient(object):
             headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
                 header name to a list of values for that header
         Returns:
-            Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+            Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as bytes.
         Raises:
             HttpResponseException on a non-2xx HTTP response.
@@ -543,9 +537,9 @@ class SimpleHttpClient(object):
         if headers:
             actual_headers.update(headers)

-        response = yield self.request("GET", uri, headers=Headers(actual_headers))
+        response = await self.request("GET", uri, headers=Headers(actual_headers))

-        body = yield make_deferred_yieldable(readBody(response))
+        body = await make_deferred_yieldable(readBody(response))

         if 200 <= response.code < 300:
             return body
@@ -557,8 +551,7 @@ class SimpleHttpClient(object):
     # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
     # The two should be factored out.

-    @defer.inlineCallbacks
-    def get_file(self, url, output_stream, max_size=None, headers=None):
+    async def get_file(self, url, output_stream, max_size=None, headers=None):
         """GETs a file from a given URL

         Args:
             url (str): The URL to GET
@@ -574,7 +567,7 @@ class SimpleHttpClient(object):
         if headers:
             actual_headers.update(headers)

-        response = yield self.request("GET", url, headers=Headers(actual_headers))
+        response = await self.request("GET", url, headers=Headers(actual_headers))

         resp_headers = dict(response.headers.getAllRawHeaders())
@@ -598,7 +591,7 @@ class SimpleHttpClient(object):
         # straight back in again

         try:
-            length = yield make_deferred_yieldable(
+            length = await make_deferred_yieldable(
                 _readBodyToFile(response, output_stream, max_size)
             )
         except SynapseError:
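These hunks all follow one mechanical pattern: drop `@defer.inlineCallbacks`, declare the function `async def`, and replace each `yield deferred` with `await`. The shape of the conversion can be sketched with plain asyncio stand-ins (`fake_get_raw` below is hypothetical, not the Twisted client):

```python
import asyncio
import json

# Before (Twisted style):
#     @defer.inlineCallbacks
#     def get_json(self, uri):
#         body = yield self.get_raw(uri)
#         return json.loads(body.decode("utf-8"))

async def fake_get_raw(uri):
    # Stand-in for SimpleHttpClient.get_raw: pretend the bytes were fetched.
    await asyncio.sleep(0)
    return b'{"ok": true}'

async def get_json(uri):
    # After: a native coroutine. `await` replaces `yield`, no decorator needed.
    body = await fake_get_raw(uri)
    return json.loads(body.decode("utf-8"))

result = asyncio.run(get_json("https://example.test/json"))
print(result)  # {'ok': True}
```

The docstring edits in the diff follow from the same change: a coroutine's declared return type is the value itself, not a `Deferred` wrapping it.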

@@ -242,10 +242,12 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
         no appropriate method exists. Can be overriden in sub classes for
         different routing.
         """
+        # Treat HEAD requests as GET requests.
+        request_method = request.method.decode("ascii")
+        if request_method == "HEAD":
+            request_method = "GET"

-        method_handler = getattr(
-            self, "_async_render_%s" % (request.method.decode("ascii"),), None
-        )
+        method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
             raw_callback_return = method_handler(request)
@@ -362,11 +364,15 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # Treat HEAD requests as GET requests.
         request_path = request.path.decode("ascii")
+        request_method = request.method
+        if request_method == b"HEAD":
+            request_method = b"GET"

         # Loop through all the registered callbacks to check if the method
         # and path regex match
-        for path_entry in self.path_regexs.get(request.method, []):
+        for path_entry in self.path_regexs.get(request_method, []):
             m = path_entry.pattern.match(request_path)
             if m:
                 # We found a match!
@@ -579,7 +585,7 @@ def set_cors_headers(request: Request):
     """
     request.setHeader(b"Access-Control-Allow-Origin", b"*")
     request.setHeader(
-        b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
+        b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
     )
     request.setHeader(
         b"Access-Control-Allow-Headers",
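The HEAD fix above rewrites the method name before the handler lookup, because HEAD must behave exactly like GET minus the response body. A minimal dispatcher sketch of the same idea (handler names here are hypothetical):

```python
# Sketch of HEAD-as-GET routing: rewrite the method string before looking up
# the handler, so servers without explicit HEAD handlers stop returning 400.

def lookup_handler(method: str, handlers: dict):
    # Treat HEAD requests as GET requests.
    if method == "HEAD":
        method = "GET"
    return handlers.get(method)

handlers = {"GET": lambda: (200, "body")}

print(lookup_handler("GET", handlers)())   # (200, 'body')
print(lookup_handler("HEAD", handlers)())  # (200, 'body')
print(lookup_handler("POST", handlers))    # None
```

In the real server the transport layer is still responsible for suppressing the body on a HEAD response; the routing change only picks the handler.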

@@ -219,7 +219,7 @@ class ModuleApi(object):
         Returns:
             Deferred[object]: result of func
         """
-        return self._store.db.runInteraction(desc, func, *args, **kwargs)
+        return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)

     def complete_sso_login(
         self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str

@@ -320,7 +320,7 @@ class Notifier(object):
         """
         user_stream = self.user_to_user_stream.get(user_id)
         if user_stream is None:
-            current_token = await self.event_sources.get_current_token()
+            current_token = self.event_sources.get_current_token()
             if room_ids is None:
                 room_ids = await self.store.get_rooms_for_user(user_id)
             user_stream = _NotifierUserStream(
@@ -397,7 +397,7 @@ class Notifier(object):
         """
         from_token = pagination_config.from_token
         if not from_token:
-            from_token = await self.event_sources.get_current_token()
+            from_token = self.event_sources.get_current_token()

         limit = pagination_config.limit
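These hunks are the inverse of the async conversion: `get_current_token` became a plain synchronous method, so every call site drops its `await`. A sketch with a hypothetical `EventSources` stand-in (not the Synapse class):

```python
import asyncio

class EventSources:
    # Hypothetical stand-in: the current token is now read from in-memory
    # stream positions, so no I/O (and no coroutine) is needed.
    def __init__(self):
        self._stream_position = 42

    def get_current_token(self):
        return self._stream_position

async def get_events_for(sources):
    # Before: from_token = await sources.get_current_token()
    # After:  a plain call; awaiting a non-awaitable would raise TypeError.
    from_token = sources.get_current_token()
    return from_token

print(asyncio.run(get_events_for(EventSources())))  # 42
```

Calling code stays inside coroutines; only the token lookup itself is synchronous now.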

@@ -59,7 +59,6 @@ REQUIREMENTS = [
     "pyyaml>=3.11",
     "pyasn1>=0.1.9",
     "pyasn1-modules>=0.0.7",
-    "daemonize>=2.3.1",
     "bcrypt>=3.1.0",
     "pillow>=4.3.0",
     "sortedcontainers>=1.4.4",

@@ -20,8 +20,6 @@ import urllib
 from inspect import signature
 from typing import Dict, List, Tuple

-from twisted.internet import defer
-
 from synapse.api.errors import (
     CodeMessageException,
     HttpResponseException,
@@ -101,7 +99,7 @@ class ReplicationEndpoint(object):
         assert self.METHOD in ("PUT", "POST", "GET")

     @abc.abstractmethod
-    def _serialize_payload(**kwargs):
+    async def _serialize_payload(**kwargs):
         """Static method that is called when creating a request.

         Concrete implementations should have explicit parameters (rather than
@@ -110,9 +108,8 @@ class ReplicationEndpoint(object):
         argument list.

         Returns:
-            Deferred[dict]|dict: If POST/PUT request then dictionary must be
-                JSON serialisable, otherwise must be appropriate for adding as
-                query args.
+            dict: If POST/PUT request then dictionary must be JSON serialisable,
+                otherwise must be appropriate for adding as query args.
         """
         return {}
@@ -144,8 +141,7 @@ class ReplicationEndpoint(object):
         instance_map = hs.config.worker.instance_map

         @trace(opname="outgoing_replication_request")
-        @defer.inlineCallbacks
-        def send_request(instance_name="master", **kwargs):
+        async def send_request(instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
@@ -159,7 +155,7 @@ class ReplicationEndpoint(object):
                     "Instance %r not in 'instance_map' config" % (instance_name,)
                 )

-            data = yield cls._serialize_payload(**kwargs)
+            data = await cls._serialize_payload(**kwargs)

             url_args = [
                 urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
@@ -197,7 +193,7 @@ class ReplicationEndpoint(object):
                     headers = {}  # type: Dict[bytes, List[bytes]]
                     inject_active_span_byte_dict(headers, None, check_destination=False)
                     try:
-                        result = yield request_func(uri, data, headers=headers)
+                        result = await request_func(uri, data, headers=headers)
                         break
                     except CodeMessageException as e:
                         if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@@ -207,7 +203,7 @@ class ReplicationEndpoint(object):

                         # If we timed out we probably don't need to worry about backing
                         # off too much, but lets just wait a little anyway.
-                        yield clock.sleep(1)
+                        await clock.sleep(1)
                     except HttpResponseException as e:
                         # We convert to SynapseError as we know that it was a SynapseError
                         # on the master process that we should send to the client. (And
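With `_serialize_payload` declared `async` on the abstract base, every subclass in the following hunks becomes an async staticmethod, whether or not it actually needs to await anything. A sketch of the pattern under assumed names (these classes are illustrative, not the Synapse ones):

```python
import abc
import asyncio

class ReplicationEndpointSketch(abc.ABC):
    # Mirrors the base-class shape: the payload serializer is an async
    # staticmethod; callers uniformly `await` it.
    @staticmethod
    @abc.abstractmethod
    async def _serialize_payload(**kwargs):
        """Return a JSON-serialisable dict of the request payload."""
        return {}

class SendEduSketch(ReplicationEndpointSketch):
    @staticmethod
    async def _serialize_payload(edu_type, origin, content):
        # No I/O needed here; returning a plain dict still works under
        # `async def`. Overrides that must hit the database can `await`.
        return {"origin": origin, "content": content}

payload = asyncio.run(
    SendEduSketch._serialize_payload("m.typing", "example.test", {"x": 1})
)
print(payload)  # {'origin': 'example.test', 'content': {'x': 1}}
```

Making the abstract method uniformly async lets the caller drop the `yield`/`ensureDeferred` dance around implementations that were previously a mix of plain functions and `inlineCallbacks` generators.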

@@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
         self.clock = hs.get_clock()

     @staticmethod
-    def _serialize_payload(user_id):
+    async def _serialize_payload(user_id):
         return {}

     async def _handle_request(self, request, user_id):

@@ -15,8 +15,6 @@
 import logging

-from twisted.internet import defer
-
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
@@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler

     @staticmethod
-    @defer.inlineCallbacks
-    def _serialize_payload(store, event_and_contexts, backfilled):
+    async def _serialize_payload(store, event_and_contexts, backfilled):
         """
         Args:
             store
@@ -78,9 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         """
         event_payloads = []
         for event, context in event_and_contexts:
-            serialized_context = yield defer.ensureDeferred(
-                context.serialize(event, store)
-            )
+            serialized_context = await context.serialize(event, store)

             event_payloads.append(
                 {
@@ -156,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
         self.registry = hs.get_federation_registry()

     @staticmethod
-    def _serialize_payload(edu_type, origin, content):
+    async def _serialize_payload(edu_type, origin, content):
         return {"origin": origin, "content": content}

     async def _handle_request(self, request, edu_type):
@@ -199,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         self.registry = hs.get_federation_registry()

     @staticmethod
-    def _serialize_payload(query_type, args):
+    async def _serialize_payload(query_type, args):
         """
         Args:
             query_type (str)
@@ -240,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
         self.store = hs.get_datastore()

     @staticmethod
-    def _serialize_payload(room_id, args):
+    async def _serialize_payload(room_id, args):
         """
         Args:
             room_id (str)
@@ -275,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
         self.store = hs.get_datastore()

     @staticmethod
-    def _serialize_payload(room_id, room_version):
+    async def _serialize_payload(room_id, room_version):
         return {"room_version": room_version.identifier}

     async def _handle_request(self, request, room_id):

@@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()

     @staticmethod
-    def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+    async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
         """
         Args:
             device_id (str|None): Device ID to use, if None a new one is

@@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         self.clock = hs.get_clock()

     @staticmethod
-    def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+    async def _serialize_payload(
+        requester, room_id, user_id, remote_room_hosts, content
+    ):
         """
         Args:
             requester(Requester)
@@ -112,7 +114,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         self.member_handler = hs.get_room_member_handler()

     @staticmethod
-    def _serialize_payload(  # type: ignore
+    async def _serialize_payload(  # type: ignore
         invite_event_id: str,
         txn_id: Optional[str],
         requester: Requester,
@@ -174,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         self.distributor = hs.get_distributor()

     @staticmethod
-    def _serialize_payload(room_id, user_id, change):
+    async def _serialize_payload(room_id, user_id, change):
         """
         Args:
             room_id (str)

@@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
         self._presence_handler = hs.get_presence_handler()

     @staticmethod
-    def _serialize_payload(user_id):
+    async def _serialize_payload(user_id):
         return {}

     async def _handle_request(self, request, user_id):
@@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
         self._presence_handler = hs.get_presence_handler()

     @staticmethod
-    def _serialize_payload(user_id, state, ignore_status_msg=False):
+    async def _serialize_payload(user_id, state, ignore_status_msg=False):
         return {
             "state": state,
             "ignore_status_msg": ignore_status_msg,

@@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()

     @staticmethod
-    def _serialize_payload(
+    async def _serialize_payload(
         user_id,
         password_hash,
         was_guest,
@@ -105,7 +105,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()

     @staticmethod
-    def _serialize_payload(user_id, auth_result, access_token):
+    async def _serialize_payload(user_id, auth_result, access_token):
         """
         Args:
             user_id (str): The user ID that consented

@@ -15,8 +15,6 @@
 import logging

-from twisted.internet import defer
-
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
@@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
         self.clock = hs.get_clock()

     @staticmethod
-    @defer.inlineCallbacks
-    def _serialize_payload(
+    async def _serialize_payload(
         event_id, store, event, context, requester, ratelimit, extra_users
     ):
         """
@@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             extra_users (list(UserID)): Any extra users to notify about event
         """

-        serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
+        serialized_context = await context.serialize(event, store)

         payload = {
             "event": event.get_pdu_json(),

@@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
         self.streams = hs.get_replication_streams()

     @staticmethod
-    def _serialize_payload(stream_name, from_token, upto_token):
+    async def _serialize_payload(stream_name, from_token, upto_token):
         return {"from_token": from_token, "upto_token": upto_token}

     async def _handle_request(self, request, stream_name):

@@ -16,8 +16,8 @@
 import logging
 from typing import Optional

-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)

         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = MultiWriterIdGenerator(

@@ -17,13 +17,13 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
-from synapse.storage.data_stores.main.tags import TagsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore

 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn,
             "account_data",

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.databases.main.appservice import (
     ApplicationServiceTransactionWorkerStore,
     ApplicationServiceWorkerStore,
 )

@@ -13,15 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.descriptors import Cache

 from ._base import BaseSlavedStore

 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedClientIpStore, self).__init__(database, db_conn, hs)

         self.client_ip_last_seen = Cache(

@@ -16,14 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache

 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_inbox", "stream_id"

@@ -16,14 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.data_stores.main.devices import DeviceWorkerStore
-from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.devices import DeviceWorkerStore
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache

 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedDeviceStore, self).__init__(database, db_conn, hs)

         self.hs = hs

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore

 from ._base import BaseSlavedStore

View File

@ -15,18 +15,18 @@
# limitations under the License. # limitations under the License.
import logging import logging
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.data_stores.main.event_push_actions import ( from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
) )
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.relations import RelationsWorkerStore from synapse.storage.databases.main.relations import RelationsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -55,11 +55,11 @@ class SlavedEventStore(
RelationsWorkerStore, RelationsWorkerStore,
BaseSlavedStore, BaseSlavedStore,
): ):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedEventStore, self).__init__(database, db_conn, hs) super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"current_state_delta_stream", "current_state_delta_stream",
entity_column="room_id", entity_column="room_id",

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.filtering import FilteringStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore): class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedFilteringStore, self).__init__(database, db_conn, hs) super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired # Filters are immutable so this cache doesn't need to be expired

View File

@ -16,13 +16,13 @@
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs self.hs = hs

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.keys import KeyStore from synapse.storage.databases.main.keys import KeyStore
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that # KeyStore isn't really safe to use from a worker, but for now we do so and hope that
# the races it creates aren't too bad. # the races it creates aren't too bad.

View File

@ -15,8 +15,8 @@
from synapse.replication.tcp.streams import PresenceStream from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.presence import PresenceStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore): class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs) super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.storage.data_stores.main.profile import ProfileWorkerStore from synapse.storage.databases.main.profile import ProfileWorkerStore
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import PushRulesStream from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore from .events import SlavedEventStore

View File

@ -15,15 +15,15 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import PushersStream from synapse.replication.tcp.streams import PushersStream
from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.pusher import PusherWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPusherStore, self).__init__(database, db_conn, hs) super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker( self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]

View File

@ -15,15 +15,15 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor # We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id # needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker( self._receipts_id_gen = SlavedIdTracker(

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.registration import RegistrationWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore

View File

@ -14,15 +14,15 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import PublicRoomsStream from synapse.replication.tcp.streams import PublicRoomsStream
from synapse.storage.data_stores.main.room import RoomWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.room import RoomWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore): class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs) super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker( self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.transactions import TransactionStore from synapse.storage.databases.main.transactions import TransactionStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore

View File

@ -31,7 +31,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin, assert_user_is_admin,
historical_admin_path_patterns, historical_admin_path_patterns,
) )
from synapse.storage.data_stores.main.room import RoomSortOrder from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.types import RoomAlias, RoomID, UserID, create_requester
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -444,7 +444,7 @@ class RoomMemberListRestServlet(RestServlet):
async def on_GET(self, request, room_id): async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler handler = self.message_handler
# request the state as of a given event, as identified by a stream token, # request the state as of a given event, as identified by a stream token,
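The one-line change above passes `allow_guest=True` so that guest accounts can hit `GET /rooms/{room_id}/members` (per MSC2689). A toy sketch of that opt-in guard, with a plain dict standing in for Synapse's real request and auth objects:

```python
class AuthError(Exception):
    """Raised when a request fails authentication."""

def get_user_by_req(request, allow_guest=False):
    # Guests are rejected unless the servlet opts in with allow_guest=True;
    # the dict-shaped `request` is a stand-in for Synapse's real Request.
    user = request["user"]
    if user.get("is_guest") and not allow_guest:
        raise AuthError("Guest access not allowed")
    return user

guest_req = {"user": {"id": "@guest:example.org", "is_guest": True}}
print(get_user_by_req(guest_req, allow_guest=True)["id"])  # @guest:example.org
try:
    get_user_by_req(guest_req)  # the pre-MSC2689 behaviour for /members
except AuthError as exc:
    print("rejected:", exc)  # rejected: Guest access not allowed
```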

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import inspect
import logging import logging
import os import os
import shutil import shutil
@ -30,7 +29,7 @@ from .filepath import MediaFilePaths
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from .storage_provider import StorageProvider from .storage_provider import StorageProviderWrapper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,7 +49,7 @@ class MediaStorage(object):
hs: "HomeServer", hs: "HomeServer",
local_media_directory: str, local_media_directory: str,
filepaths: MediaFilePaths, filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"], storage_providers: Sequence["StorageProviderWrapper"],
): ):
self.hs = hs self.hs = hs
self.local_media_directory = local_media_directory self.local_media_directory = local_media_directory
@ -115,11 +114,7 @@ class MediaStorage(object):
async def finish(): async def finish():
for provider in self.storage_providers: for provider in self.storage_providers:
# store_file is supposed to return an Awaitable, but guard await provider.store_file(path, file_info)
# against improper implementations.
result = provider.store_file(path, file_info)
if inspect.isawaitable(result):
await result
finished_called[0] = True finished_called[0] = True
@ -153,11 +148,7 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb")) return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers: for provider in self.storage_providers:
res = provider.fetch(path, file_info) # type: Any res = await provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable[Responder], but guard
# against improper implementations.
if inspect.isawaitable(res):
res = await res
if res: if res:
logger.debug("Streaming %s from %s", path, provider) logger.debug("Streaming %s from %s", path, provider)
return res return res
@ -184,11 +175,7 @@ class MediaStorage(object):
os.makedirs(dirname) os.makedirs(dirname)
for provider in self.storage_providers: for provider in self.storage_providers:
res = provider.fetch(path, file_info) # type: Any res = await provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable[Responder], but guard
# against improper implementations.
if inspect.isawaitable(res):
res = await res
if res: if res:
with res: with res:
consumer = BackgroundFileConsumer( consumer = BackgroundFileConsumer(

View File

@ -586,7 +586,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Running url preview cache expiry") logger.debug("Running url preview cache expiry")
if not (await self.store.db.updates.has_completed_background_updates()): if not (await self.store.db_pool.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry") logger.info("Still running DB updates; skipping expiry")
return return

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import os import os
import shutil import shutil
@ -88,12 +89,18 @@ class StorageProviderWrapper(StorageProvider):
return None return None
if self.store_synchronous: if self.store_synchronous:
return await self.backend.store_file(path, file_info) # store_file is supposed to return an Awaitable, but guard
# against improper implementations.
result = self.backend.store_file(path, file_info)
if inspect.isawaitable(result):
return await result
else: else:
# TODO: Handle errors. # TODO: Handle errors.
def store(): async def store():
try: try:
return self.backend.store_file(path, file_info) result = self.backend.store_file(path, file_info)
if inspect.isawaitable(result):
return await result
except Exception: except Exception:
logger.exception("Error storing file") logger.exception("Error storing file")
@ -101,7 +108,11 @@ class StorageProviderWrapper(StorageProvider):
return None return None
async def fetch(self, path, file_info): async def fetch(self, path, file_info):
return await self.backend.fetch(path, file_info) # fetch is supposed to return an Awaitable, but guard
# against improper implementations.
result = self.backend.fetch(path, file_info)
if inspect.isawaitable(result):
return await result
class FileStorageProviderBackend(StorageProvider): class FileStorageProviderBackend(StorageProvider):
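This refactor moves the `inspect.isawaitable` guard out of `MediaStorage` and into `StorageProviderWrapper`, so callers can simply `await provider.store_file(...)` / `await provider.fetch(...)`. The guard pattern itself, sketched standalone with hypothetical backends:

```python
import asyncio
import inspect

async def maybe_await(value):
    # store_file/fetch are supposed to return an Awaitable, but guard
    # against improper (synchronous) implementations.
    if inspect.isawaitable(value):
        return await value
    return value

def legacy_backend(path):        # hypothetical backend returning a plain value
    return "stored:" + path

async def modern_backend(path):  # hypothetical well-behaved awaitable backend
    return "stored:" + path

async def main():
    return (
        await maybe_await(legacy_backend("a")),
        await maybe_await(modern_backend("b")),
    )

print(asyncio.run(main()))  # ('stored:a', 'stored:b')
```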

View File

@ -105,7 +105,7 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender, WorkerServerNoticesSender,
) )
from synapse.state import StateHandler, StateResolutionHandler from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore, DataStores, Storage from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.util import Clock from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
@ -280,7 +280,7 @@ class HomeServer(object):
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.start_time = int(self.get_clock().time()) self.start_time = int(self.get_clock().time())
self.datastores = DataStores(self.DATASTORE_CLASS, self) self.datastores = Databases(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.") logger.info("Finished setting up.")
def setup_master(self): def setup_master(self):

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
@ -55,14 +56,11 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config) self._consent_uri_builder = ConsentURIBuilder(hs.config)
async def maybe_send_server_notice_to_user(self, user_id): async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
"""Check if we need to send a notice to this user, and does so if so """Check if we need to send a notice to this user, and does so if so
Args: Args:
user_id (str): user to check user_id: user to check
Returns:
Deferred
""" """
if self._server_notice_content is None: if self._server_notice_content is None:
# not enabled # not enabled
@ -105,7 +103,7 @@ class ConsentServerNotices(object):
self._users_in_progress.remove(user_id) self._users_in_progress.remove(user_id)
def copy_with_str_subst(x, substitutions): def copy_with_str_subst(x: Any, substitutions: Any) -> Any:
"""Deep-copy a structure, carrying out string substitions on any strings """Deep-copy a structure, carrying out string substitions on any strings
Args: Args:
@ -121,7 +119,7 @@ def copy_with_str_subst(x, substitutions):
if isinstance(x, dict): if isinstance(x, dict):
return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()} return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
return [copy_with_str_subst(y) for y in x] return [copy_with_str_subst(y, substitutions) for y in x]
# assume it's uninteresting and can be shallow-copied. # assume it's uninteresting and can be shallow-copied.
return x return x
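The fix above passes `substitutions` down the list/tuple recursion, which the pre-fix code omitted (and which would have raised a `TypeError`). The corrected function in full, runnable form:

```python
def copy_with_str_subst(x, substitutions):
    """Deep-copy a structure, carrying out %-style substitutions on any
    strings. The list/tuple branch passes `substitutions` down the
    recursion, which is exactly the bug this commit fixes."""
    if isinstance(x, str):
        return x % substitutions
    if isinstance(x, dict):
        return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
    if isinstance(x, (list, tuple)):
        return [copy_with_str_subst(y, substitutions) for y in x]
    # assume it's uninteresting and can be shallow-copied.
    return x

print(copy_with_str_subst(
    {"body": "Hello %(name)s", "tags": ["hi %(name)s"]},
    {"name": "Alice"},
))  # {'body': 'Hello Alice', 'tags': ['hi Alice']}
```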

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, EventTypes,
@ -52,7 +53,7 @@ class ResourceLimitsServerNotices(object):
and not hs.config.hs_disabled and not hs.config.hs_disabled
) )
async def maybe_send_server_notice_to_user(self, user_id): async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
"""Check if we need to send a notice to this user, this will be true in """Check if we need to send a notice to this user, this will be true in
two cases. two cases.
1. The server has reached its limit but the room state does not reflect this 1. The server has reached its limit but the room state does not reflect this
@ -60,10 +61,7 @@ class ResourceLimitsServerNotices(object):
actually the server is fine actually the server is fine
Args: Args:
user_id (str): user to check user_id: user to check
Returns:
Deferred
""" """
if not self._enabled: if not self._enabled:
return return
@ -115,18 +113,20 @@ class ResourceLimitsServerNotices(object):
elif not currently_blocked and limit_msg: elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be. # Room is not notifying of a block, when it ought to be.
await self._apply_limit_block_notification( await self._apply_limit_block_notification(
user_id, limit_msg, limit_type user_id, limit_msg, limit_type # type: ignore
) )
except SynapseError as e: except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e) logger.error("Error sending resource limits server notice: %s", e)
async def _remove_limit_block_notification(self, user_id, ref_events): async def _remove_limit_block_notification(
self, user_id: str, ref_events: List[str]
) -> None:
"""Utility method to remove limit block notifications from the server """Utility method to remove limit block notifications from the server
notices room. notices room.
Args: Args:
user_id (str): user to notify user_id: user to notify
ref_events (list[str]): The event_ids of pinned events that are unrelated to ref_events: The event_ids of pinned events that are unrelated to
limit blocking and need to be preserved. limit blocking and need to be preserved.
""" """
content = {"pinned": ref_events} content = {"pinned": ref_events}
@ -135,15 +135,15 @@ class ResourceLimitsServerNotices(object):
) )
async def _apply_limit_block_notification( async def _apply_limit_block_notification(
self, user_id, event_body, event_limit_type self, user_id: str, event_body: str, event_limit_type: str
): ) -> None:
"""Utility method to apply limit block notifications in the server """Utility method to apply limit block notifications in the server
notices room. notices room.
Args: Args:
user_id (str): user to notify user_id: user to notify
event_body(str): The human readable text that describes the block. event_body: The human readable text that describes the block.
event_limit_type(str): Specifies the type of block e.g. monthly active user event_limit_type: Specifies the type of block e.g. monthly active user
limit has been exceeded. limit has been exceeded.
""" """
content = { content = {
@ -162,7 +162,7 @@ class ResourceLimitsServerNotices(object):
user_id, content, EventTypes.Pinned, "" user_id, content, EventTypes.Pinned, ""
) )
async def _check_and_set_tags(self, user_id, room_id): async def _check_and_set_tags(self, user_id: str, room_id: str) -> None:
""" """
Since server notices rooms were originally not with tags, Since server notices rooms were originally not with tags,
important to check that tags have been set correctly important to check that tags have been set correctly
@ -182,17 +182,16 @@ class ResourceLimitsServerNotices(object):
) )
self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
async def _is_room_currently_blocked(self, room_id): async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]:
""" """
Determines if the room is currently blocked Determines if the room is currently blocked
Args: Args:
room_id(str): The room id of the server notices room room_id: The room id of the server notices room
Returns: Returns:
Deferred[Tuple[bool, List]]:
bool: Is the room currently blocked bool: Is the room currently blocked
list: The list of pinned events that are unrelated to limit blocking list: The list of pinned event IDs that are unrelated to limit blocking
This list can be used as a convenience in the case where the block This list can be used as a convenience in the case where the block
is to be lifted and the remaining pinned event references need to be is to be lifted and the remaining pinned event references need to be
preserved preserved
@ -207,7 +206,7 @@ class ResourceLimitsServerNotices(object):
# The user has yet to join the server notices room # The user has yet to join the server notices room
pass pass
referenced_events = [] referenced_events = [] # type: List[str]
if pinned_state_event is not None: if pinned_state_event is not None:
referenced_events = list(pinned_state_event.content.get("pinned", [])) referenced_events = list(pinned_state_event.content.get("pinned", []))

View File

@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.events import EventBase
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -50,20 +52,21 @@ class ServerNoticesManager(object):
return self._config.server_notices_mxid is not None return self._config.server_notices_mxid is not None
async def send_notice( async def send_notice(
self, user_id, event_content, type=EventTypes.Message, state_key=None self,
): user_id: str,
event_content: dict,
type: str = EventTypes.Message,
state_key: Optional[str] = None,
) -> EventBase:
"""Send a notice to the given user """Send a notice to the given user
Creates the server notices room, if none exists. Creates the server notices room, if none exists.
Args: Args:
user_id (str): mxid of user to send event to. user_id: mxid of user to send event to.
event_content (dict): content of event to send event_content: content of event to send
type(EventTypes): type of event type: type of event
is_state_event(bool): Is the event a state event state_key: The state key of the event, if it is a state event
Returns:
Deferred[FrozenEvent]
""" """
room_id = await self.get_or_create_notice_room_for_user(user_id) room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id) await self.maybe_invite_user_to_room(user_id, room_id)
@ -89,17 +92,17 @@ class ServerNoticesManager(object):
return event return event
@cached() @cached()
async def get_or_create_notice_room_for_user(self, user_id): async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
"""Get the room for notices for a given user """Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't If we have not yet created a notice room for this user, create it, but don't
invite the user to it. invite the user to it.
Args: Args:
user_id (str): complete user id for the user we want a room for user_id: complete user id for the user we want a room for
Returns: Returns:
str: room id of notice room. room id of notice room.
""" """
if not self.is_enabled(): if not self.is_enabled():
raise Exception("Server notices not enabled") raise Exception("Server notices not enabled")
@ -163,7 +166,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id) logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id return room_id
async def maybe_invite_user_to_room(self, user_id: str, room_id: str): async def maybe_invite_user_to_room(self, user_id: str, room_id: str) -> None:
"""Invite the given user to the given server room, unless the user has already """Invite the given user to the given server room, unless the user has already
joined or been invited to it. joined or been invited to it.
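`get_or_create_notice_room_for_user` stays decorated with `@cached()` after becoming a coroutine, so repeated calls for the same user reuse one room id. A toy per-instance version of that memoisation (Synapse's real `@cached` descriptor additionally handles invalidation, cache sizing and cache keys):

```python
import asyncio
import functools

def cached():
    """Toy memoiser for async methods: results are cached per instance
    and per argument tuple. A sketch, not Synapse's implementation."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, *args):
            cache = self.__dict__.setdefault("_%s_cache" % func.__name__, {})
            if args not in cache:
                cache[args] = await func(self, *args)
            return cache[args]
        return wrapper
    return decorator

class Manager:
    def __init__(self):
        self.created = 0

    @cached()
    async def get_or_create_notice_room_for_user(self, user_id):
        self.created += 1  # only incremented on a cache miss
        return "!room%d:server" % self.created

async def main():
    m = Manager()
    first = await m.get_or_create_notice_room_for_user("@u:hs")
    second = await m.get_or_create_notice_room_for_user("@u:hs")
    return first, second, m.created

print(asyncio.run(main()))  # ('!room1:server', '!room1:server', 1)
```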

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, Union
from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import ( from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices, ResourceLimitsServerNotices,
@ -32,22 +34,22 @@ class ServerNoticesSender(object):
self._server_notices = ( self._server_notices = (
ConsentServerNotices(hs), ConsentServerNotices(hs),
ResourceLimitsServerNotices(hs), ResourceLimitsServerNotices(hs),
) ) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]]
async def on_user_syncing(self, user_id): async def on_user_syncing(self, user_id: str) -> None:
"""Called when the user performs a sync operation. """Called when the user performs a sync operation.
Args: Args:
user_id (str): mxid of user who synced user_id: mxid of user who synced
""" """
for sn in self._server_notices: for sn in self._server_notices:
await sn.maybe_send_server_notice_to_user(user_id) await sn.maybe_send_server_notice_to_user(user_id)
async def on_user_ip(self, user_id): async def on_user_ip(self, user_id: str) -> None:
"""Called on the master when a worker process saw a client request. """Called on the master when a worker process saw a client request.
Args: Args:
user_id (str): mxid user_id: mxid
""" """
# The synchrotrons use a stubbed version of ServerNoticesSender, so # The synchrotrons use a stubbed version of ServerNoticesSender, so
# we check for notices to send to the user in on_user_ip as well as # we check for notices to send to the user in on_user_ip as well as


@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
class WorkerServerNoticesSender(object): class WorkerServerNoticesSender(object):
@@ -24,24 +23,18 @@ class WorkerServerNoticesSender(object):
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
def on_user_syncing(self, user_id): async def on_user_syncing(self, user_id: str) -> None:
"""Called when the user performs a sync operation. """Called when the user performs a sync operation.
Args: Args:
user_id (str): mxid of user who synced user_id: mxid of user who synced
Returns:
Deferred
""" """
return defer.succeed(None) return None
def on_user_ip(self, user_id): async def on_user_ip(self, user_id: str) -> None:
"""Called on the master when a worker process saw a client request. """Called on the master when a worker process saw a client request.
Args: Args:
user_id (str): mxid user_id: mxid
Returns:
Deferred
""" """
raise AssertionError("on_user_ip unexpectedly called on worker") raise AssertionError("on_user_ip unexpectedly called on worker")


@@ -28,7 +28,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap from synapse.types import StateMap
from synapse.util import Clock from synapse.util import Clock


@@ -17,18 +17,19 @@
""" """
The storage layer is split up into multiple parts to allow Synapse to run The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple against different configurations of databases (e.g. single or multiple
databases). The `Database` class represents a single physical database. The databases). The `DatabasePool` class represents connections to a single physical
`data_stores` are classes that talk directly to a `Database` instance and have database. The `databases` are classes that talk directly to a `DatabasePool`
associated schemas, background updates, etc. On top of those there are classes instance and have associated schemas, background updates, etc. On top of those
that provide high level interfaces that combine calls to multiple `data_stores`. there are classes that provide high level interfaces that combine calls to
multiple `databases`.
There are also schemas that get applied to every database, regardless of the There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`. stored in `synapse.storage.schema`.
""" """
from synapse.storage.data_stores import DataStores from synapse.storage.databases import Databases
from synapse.storage.data_stores.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage from synapse.storage.state import StateGroupStorage
@@ -40,7 +41,7 @@ class Storage(object):
"""The high level interfaces for talking to various storage layers. """The high level interfaces for talking to various storage layers.
""" """
def __init__(self, hs, stores: DataStores): def __init__(self, hs, stores: Databases):
# We include the main data store here mainly so that we don't have to # We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level # rewrite all the existing code to split it into high vs low level
# interfaces. # interfaces.


@@ -23,7 +23,7 @@ from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -37,11 +37,11 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database). per data store (and not one per physical database).
""" """
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine
self.db = database self.db_pool = database
self.rand = random.SystemRandom() self.rand = random.SystemRandom()
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):


@@ -88,7 +88,7 @@ class BackgroundUpdater(object):
def __init__(self, hs, database): def __init__(self, hs, database):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.db = database self.db_pool = database
# if a background update is currently running, its name. # if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str] self._current_background_update = None # type: Optional[str]
@@ -139,7 +139,7 @@ class BackgroundUpdater(object):
# otherwise, check if there are updates to be run. This is important, # otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates # as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen. # itself, but still wants to wait for them to happen.
updates = await self.db.simple_select_onecol( updates = await self.db_pool.simple_select_onecol(
"background_updates", "background_updates",
keyvalues=None, keyvalues=None,
retcol="1", retcol="1",
@@ -160,7 +160,7 @@ class BackgroundUpdater(object):
if update_name == self._current_background_update: if update_name == self._current_background_update:
return False return False
update_exists = await self.db.simple_select_one_onecol( update_exists = await self.db_pool.simple_select_one_onecol(
"background_updates", "background_updates",
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},
retcol="1", retcol="1",
@@ -189,10 +189,10 @@ class BackgroundUpdater(object):
ORDER BY ordering, update_name ORDER BY ordering, update_name
""" """
) )
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
if not self._current_background_update: if not self._current_background_update:
all_pending_updates = await self.db.runInteraction( all_pending_updates = await self.db_pool.runInteraction(
"background_updates", get_background_updates_txn, "background_updates", get_background_updates_txn,
) )
if not all_pending_updates: if not all_pending_updates:
@@ -243,7 +243,7 @@ class BackgroundUpdater(object):
else: else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
progress_json = await self.db.simple_select_one_onecol( progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates", "background_updates",
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},
retcol="progress_json", retcol="progress_json",
@@ -402,7 +402,7 @@ class BackgroundUpdater(object):
logger.debug("[SQL] %s", sql) logger.debug("[SQL] %s", sql)
c.execute(sql) c.execute(sql)
if isinstance(self.db.engine, engines.PostgresEngine): if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql runner = create_index_psql
elif psql_only: elif psql_only:
runner = None runner = None
@@ -413,7 +413,7 @@ class BackgroundUpdater(object):
def updater(progress, batch_size): def updater(progress, batch_size):
if runner is not None: if runner is not None:
logger.info("Adding index %s to %s", index_name, table) logger.info("Adding index %s to %s", index_name, table)
yield self.db.runWithConnection(runner) yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name) yield self._end_background_update(update_name)
return 1 return 1
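The dispatch a few lines up picks an index-creation runner per database engine: Postgres always gets the psql runner, while a `psql_only` update is skipped entirely (runner `None`) on SQLite. A sketch of that selection logic, with placeholder classes and string stand-ins for the runner functions:

```python
class PostgresEngine:
    """Placeholder for synapse.storage.engines.PostgresEngine."""

class Sqlite3Engine:
    """Placeholder for the SQLite engine."""

def choose_index_runner(engine, psql_only: bool = False):
    # Mirrors the dispatch above: Postgres gets the psql path,
    # and a psql-only index is silently skipped on SQLite.
    if isinstance(engine, PostgresEngine):
        return "create_index_psql"
    elif psql_only:
        return None
    return "create_index_sqlite"

assert choose_index_runner(PostgresEngine()) == "create_index_psql"
assert choose_index_runner(Sqlite3Engine(), psql_only=True) is None
assert choose_index_runner(Sqlite3Engine()) == "create_index_sqlite"
```

When the runner is `None`, the updater still marks the background update finished so the queue keeps draining.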
@@ -433,7 +433,7 @@ class BackgroundUpdater(object):
% update_name % update_name
) )
self._current_background_update = None self._current_background_update = None
return self.db.simple_delete_one( return self.db_pool.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name} "background_updates", keyvalues={"update_name": update_name}
) )
@@ -445,7 +445,7 @@ class BackgroundUpdater(object):
progress: The progress of the update. progress: The progress of the update.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"background_update_progress", "background_update_progress",
self._background_update_progress_txn, self._background_update_progress_txn,
update_name, update_name,
@@ -463,7 +463,7 @@ class BackgroundUpdater(object):
progress_json = json.dumps(progress) progress_json = json.dumps(progress)
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
"background_updates", "background_updates",
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},


@@ -279,7 +279,7 @@ class PerformanceCounters(object):
return top_n_counters return top_n_counters
class Database(object): class DatabasePool(object):
"""Wraps a single physical database and connection pool. """Wraps a single physical database and connection pool.
A single database may be used by multiple data stores. A single database may be used by multiple data stores.


@@ -15,17 +15,17 @@
import logging import logging
from synapse.storage.data_stores.main.events import PersistEventsStore from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.data_stores.state import StateGroupDataStore from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.database import Database, make_conn from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DataStores(object): class Databases(object):
"""The various data stores. """The various databases.
These are low level interfaces to physical databases. These are low level interfaces to physical databases.
@@ -51,12 +51,12 @@ class DataStores(object):
engine.check_database(db_conn) engine.check_database(db_conn)
prepare_database( prepare_database(
db_conn, engine, hs.config, data_stores=database_config.data_stores, db_conn, engine, hs.config, databases=database_config.databases,
) )
database = Database(hs, database_config, engine) database = DatabasePool(hs, database_config, engine)
if "main" in database_config.data_stores: if "main" in database_config.databases:
logger.info("Starting 'main' data store") logger.info("Starting 'main' data store")
# Sanity check we don't try and configure the main store on # Sanity check we don't try and configure the main store on
@@ -73,7 +73,7 @@ class DataStores(object):
hs, database, self.main hs, database, self.main
) )
if "state" in database_config.data_stores: if "state" in database_config.databases:
logger.info("Starting 'state' data store") logger.info("Starting 'state' data store")
# Sanity check we don't try and configure the state store on # Sanity check we don't try and configure the state store on


@@ -21,7 +21,7 @@ import time
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
IdGenerator, IdGenerator,
@@ -119,7 +119,7 @@ class DataStore(
CacheInvalidationWorkerStore, CacheInvalidationWorkerStore,
ServerMetricsStore, ServerMetricsStore,
): ):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine
@@ -174,7 +174,7 @@ class DataStore(
self._presence_on_startup = self._get_active_presence(db_conn) self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.db.get_cache_dict( presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
db_conn, db_conn,
"presence_stream", "presence_stream",
entity_column="user_id", entity_column="user_id",
@@ -188,7 +188,7 @@ class DataStore(
) )
max_device_inbox_id = self._device_inbox_id_gen.get_current_token() max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"device_inbox", "device_inbox",
entity_column="user_id", entity_column="user_id",
@@ -203,7 +203,7 @@ class DataStore(
) )
# The federation outbox and the local device inbox use the same # The federation outbox and the local device inbox use the same
# stream_id generator. # stream_id generator.
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"device_federation_outbox", "device_federation_outbox",
entity_column="destination", entity_column="destination",
@@ -229,7 +229,7 @@ class DataStore(
) )
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"current_state_delta_stream", "current_state_delta_stream",
entity_column="room_id", entity_column="room_id",
@@ -243,7 +243,7 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill, prefilled_cache=curr_state_delta_prefill,
) )
_group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"local_group_updates", "local_group_updates",
entity_column="user_id", entity_column="user_id",
@@ -282,7 +282,7 @@ class DataStore(
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,)) txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
txn.close() txn.close()
for row in rows: for row in rows:
@@ -295,7 +295,9 @@ class DataStore(
Counts the number of users who used this homeserver in the last 24 hours. Counts the number of users who used this homeserver in the last 24 hours.
""" """
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
return self.db.runInteraction("count_daily_users", self._count_users, yesterday) return self.db_pool.runInteraction(
"count_daily_users", self._count_users, yesterday
)
def count_monthly_users(self): def count_monthly_users(self):
""" """
@@ -305,7 +307,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts. amongst other things, includes a 3 day grace period before a user counts.
""" """
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
return self.db.runInteraction( return self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago "count_monthly_users", self._count_users, thirty_days_ago
) )
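`count_daily_users` and `count_monthly_users` both derive their cutoff by subtracting a fixed window, in milliseconds, from the current clock time. A quick standalone check of that arithmetic (the timestamp value is arbitrary):

```python
DAY_MS = 1000 * 60 * 60 * 24  # milliseconds in 24 hours

def cutoffs(now_ms: int):
    """Return the 24-hour and 30-day cutoffs used by the counters above."""
    yesterday = now_ms - DAY_MS
    thirty_days_ago = now_ms - DAY_MS * 30
    return yesterday, thirty_days_ago

now = 1_600_000_000_000  # an arbitrary epoch-milliseconds value
yesterday, thirty_days_ago = cutoffs(now)
assert now - yesterday == 86_400_000           # one day in ms
assert now - thirty_days_ago == 2_592_000_000  # thirty days in ms
```

Users with activity newer than the cutoff are counted; the monthly variant additionally applies the MAU grace period described in the docstring.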
@@ -405,7 +407,7 @@ class DataStore(
return results return results
return self.db.runInteraction("count_r30_users", _count_r30_users) return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self): def _get_start_of_day(self):
""" """
@@ -470,7 +472,7 @@ class DataStore(
# frequently # frequently
self._last_user_visit_update = now self._last_user_visit_update = now
return self.db.runInteraction( return self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits "generate_user_daily_visits", _generate_user_daily_visits
) )
@@ -481,7 +483,7 @@ class DataStore(
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
return self.db.simple_select_list( return self.db_pool.simple_select_list(
table="users", table="users",
keyvalues={}, keyvalues={},
retcols=[ retcols=[
@@ -543,10 +545,12 @@ class DataStore(
where_clause where_clause
) )
txn.execute(sql, args) txn.execute(sql, args)
users = self.db.cursor_to_dict(txn) users = self.db_pool.cursor_to_dict(txn)
return users, count return users, count
return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) return self.db_pool.runInteraction(
"get_users_paginate_txn", get_users_paginate_txn
)
def search_users(self, term): def search_users(self, term):
"""Function to search users list for one or more users with """Function to search users list for one or more users with
@@ -558,7 +562,7 @@ class DataStore(
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
return self.db.simple_search_list( return self.db_pool.simple_search_list(
table="users", table="users",
term=term, term=term,
col="name", col="name",


@@ -23,7 +23,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -40,7 +40,7 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented. # the abstract methods being implemented.
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id() account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache( self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max "AccountDataAndTagsChangeCache", account_max
@@ -69,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore):
""" """
def get_account_data_for_user_txn(txn): def get_account_data_for_user_txn(txn):
rows = self.db.simple_select_list_txn( rows = self.db_pool.simple_select_list_txn(
txn, txn,
"account_data", "account_data",
{"user_id": user_id}, {"user_id": user_id},
@@ -80,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows row["account_data_type"]: db_to_json(row["content"]) for row in rows
} }
rows = self.db.simple_select_list_txn( rows = self.db_pool.simple_select_list_txn(
txn, txn,
"room_account_data", "room_account_data",
{"user_id": user_id}, {"user_id": user_id},
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room return global_account_data, by_room
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn "get_account_data_for_user", get_account_data_for_user_txn
) )
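`get_account_data_for_user_txn` shapes the raw rows into two maps: global account data keyed by type, and per-room data keyed by room and then type. A standalone sketch of that shaping, using `json.loads` in place of `db_to_json` and made-up row values:

```python
import json

def shape_account_data(global_rows, room_rows):
    """Sketch of the row-shaping in get_account_data_for_user_txn."""
    global_account_data = {
        row["account_data_type"]: json.loads(row["content"])
        for row in global_rows
    }
    by_room = {}
    for row in room_rows:
        room_data = by_room.setdefault(row["room_id"], {})
        room_data[row["account_data_type"]] = json.loads(row["content"])
    return global_account_data, by_room

g, by_room = shape_account_data(
    [{"account_data_type": "m.push_rules", "content": '{"global": []}'}],
    [{"room_id": "!a:hs", "account_data_type": "m.tag", "content": '{"tags": {}}'}],
)
assert g["m.push_rules"] == {"global": []}
assert by_room["!a:hs"]["m.tag"] == {"tags": {}}
```

Both queries run inside a single `runInteraction`, so the two maps are read from a consistent snapshot.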
@@ -104,7 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred: A dict Deferred: A dict
""" """
result = yield self.db.simple_select_one_onecol( result = yield self.db_pool.simple_select_one_onecol(
table="account_data", table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type}, keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content", retcol="content",
@@ -129,7 +129,7 @@ class AccountDataWorkerStore(SQLBaseStore):
""" """
def get_account_data_for_room_txn(txn): def get_account_data_for_room_txn(txn):
rows = self.db.simple_select_list_txn( rows = self.db_pool.simple_select_list_txn(
txn, txn,
"room_account_data", "room_account_data",
{"user_id": user_id, "room_id": room_id}, {"user_id": user_id, "room_id": room_id},
@@ -140,7 +140,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows row["account_data_type"]: db_to_json(row["content"]) for row in rows
} }
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn "get_account_data_for_room", get_account_data_for_room_txn
) )
@@ -158,7 +158,7 @@ class AccountDataWorkerStore(SQLBaseStore):
""" """
def get_account_data_for_room_and_type_txn(txn): def get_account_data_for_room_and_type_txn(txn):
content_json = self.db.simple_select_one_onecol_txn( content_json = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="room_account_data", table="room_account_data",
keyvalues={ keyvalues={
@@ -172,7 +172,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None return db_to_json(content_json) if content_json else None
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
) )
@@ -202,7 +202,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn "get_updated_global_account_data", get_updated_global_account_data_txn
) )
@@ -232,7 +232,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn "get_updated_room_account_data", get_updated_room_account_data_txn
) )
@@ -277,7 +277,7 @@ class AccountDataWorkerStore(SQLBaseStore):
if not changed: if not changed:
return defer.succeed(({}, {})) return defer.succeed(({}, {}))
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
) )
@@ -295,7 +295,7 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore): class AccountDataStore(AccountDataWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
db_conn, db_conn,
"account_data_max_stream_id", "account_data_max_stream_id",
@@ -333,7 +333,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as room_account_data has a unique constraint # no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will # on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict. # retry if there is a conflict.
yield self.db.simple_upsert( yield self.db_pool.simple_upsert(
desc="add_room_account_data", desc="add_room_account_data",
table="room_account_data", table="room_account_data",
keyvalues={ keyvalues={
@@ -379,7 +379,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as account_data has a unique constraint on # no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if # (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict. # there is a conflict.
yield self.db.simple_upsert( yield self.db_pool.simple_upsert(
desc="add_user_account_data", desc="add_user_account_data",
table="account_data", table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type}, keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -427,4 +427,4 @@ class AccountDataStore(AccountDataWorkerStore):
) )
txn.execute(update_max_id_sql, (next_id, next_id)) txn.execute(update_max_id_sql, (next_id, next_id))
return self.db.runInteraction("update_account_data_max_stream_id", _update) return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
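The comments in the hunks above note that `simple_upsert` needs no application-level lock because the tables carry unique constraints: on a conflicting concurrent write, the database itself resolves the race. A sketch of the same idea in plain SQL (SQLite ≥ 3.24 `ON CONFLICT` syntax; the table and values are illustrative, not Synapse's schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE account_data (
           user_id TEXT, account_data_type TEXT, content TEXT,
           UNIQUE (user_id, account_data_type))"""
)

def simple_upsert(user_id, data_type, content):
    # The unique constraint lets the database arbitrate concurrent
    # writers: the second writer updates instead of failing.
    conn.execute(
        """INSERT INTO account_data (user_id, account_data_type, content)
           VALUES (?, ?, ?)
           ON CONFLICT (user_id, account_data_type)
           DO UPDATE SET content = excluded.content""",
        (user_id, data_type, content),
    )

simple_upsert("@alice:hs", "m.direct", '{"v": 1}')
simple_upsert("@alice:hs", "m.direct", '{"v": 2}')  # overwrites, no error
rows = conn.execute("SELECT content FROM account_data").fetchall()
assert rows == [('{"v": 2}',)]
```

Without such a constraint, a check-then-insert sequence would need retries or locking to stay correct under concurrency.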


@@ -23,8 +23,8 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.events_worker import EventsWorkerStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore): class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.services_cache = load_appservices( self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files hs.hostname, hs.config.app_service_config_files
) )
@@ -134,7 +134,7 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which A Deferred which resolves to a list of ApplicationServices, which
may be empty. may be empty.
""" """
results = yield self.db.simple_select_list( results = yield self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"] "application_services_state", {"state": state}, ["as_id"]
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
@@ -156,7 +156,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns: Returns:
A Deferred which resolves to ApplicationServiceState. A Deferred which resolves to ApplicationServiceState.
""" """
result = yield self.db.simple_select_one( result = yield self.db_pool.simple_select_one(
"application_services_state", "application_services_state",
{"as_id": service.id}, {"as_id": service.id},
["state"], ["state"],
@@ -176,7 +176,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns: Returns:
A Deferred which resolves when the state was set successfully. A Deferred which resolves when the state was set successfully.
""" """
return self.db.simple_upsert( return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state} "application_services_state", {"as_id": service.id}, {"state": state}
) )
@@ -217,7 +217,9 @@ class ApplicationServiceTransactionWorkerStore(
) )
return AppServiceTransaction(service=service, id=new_txn_id, events=events) return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) return self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
)
def complete_appservice_txn(self, txn_id, service): def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction. """Completes an application service transaction.
@@ -250,7 +252,7 @@ class ApplicationServiceTransactionWorkerStore(
) )
# Set current txn_id for AS to 'txn_id' # Set current txn_id for AS to 'txn_id'
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
"application_services_state", "application_services_state",
{"as_id": service.id}, {"as_id": service.id},
@@ -258,13 +260,13 @@ class ApplicationServiceTransactionWorkerStore(
) )
# Delete txn # Delete txn
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
"application_services_txns", "application_services_txns",
{"txn_id": txn_id, "as_id": service.id}, {"txn_id": txn_id, "as_id": service.id},
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn "complete_appservice_txn", _complete_appservice_txn
) )
@@ -288,7 +290,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1", " ORDER BY txn_id ASC LIMIT 1",
(service.id,), (service.id,),
) )
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if not rows: if not rows:
return None return None
@@ -296,7 +298,7 @@ class ApplicationServiceTransactionWorkerStore(
return entry return entry
entry = yield self.db.runInteraction( entry = yield self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
) )
@@ -326,7 +328,7 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn "set_appservice_last_pos", set_appservice_last_pos_txn
) )
@@ -355,7 +357,7 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows] return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.db.runInteraction( upper_bound, event_ids = yield self.db_pool.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn "get_new_events_for_appservice", get_new_events_for_appservice_txn
) )

View File

@@ -26,7 +26,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
@@ -39,7 +39,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -92,7 +92,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             return updates, upto_token, limited
 
-        return await self.db.runInteraction(
+        return await self.db_pool.runInteraction(
            "get_all_updated_caches", get_all_updated_caches_txn
        )
@@ -203,7 +203,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 return
 
             cache_func.invalidate(keys)
 
-        await self.db.runInteraction(
+        await self.db_pool.runInteraction(
            "invalidate_cache_and_stream",
            self._send_invalidation_to_replication,
            cache_func.__name__,
@@ -288,7 +288,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if keys is not None:
             keys = list(keys)
 
-        self.db.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn,
             table="cache_invalidation_stream_by_instance",
             values={
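The recurring change across these hunks is a single attribute rename: storage classes now reach the database through `self.db_pool` (typed as `DatabasePool`) rather than `self.db`, while the call pattern itself is unchanged. A minimal sketch of that pattern, using stub classes that stand in for Synapse's real `DatabasePool` and store (all bodies here are illustrative, not Synapse's implementation):

```python
class DatabasePool:
    """Stand-in for synapse.storage.database.DatabasePool."""

    def runInteraction(self, desc, func, *args):
        # Real Synapse runs `func` inside a DB transaction on a thread pool
        # and returns an awaitable; this stub just calls it with a fake txn.
        txn = object()
        return func(txn, *args)

    def simple_delete_txn(self, txn, table, keyvalues):
        # Real implementation issues a DELETE; the stub records the call.
        return ("DELETE", table, keyvalues)


class ApplicationServiceTransactionWorkerStore:
    def __init__(self, database: DatabasePool):
        self.db_pool = database  # previously exposed as `self.db`

    def complete_appservice_txn(self, txn_id, as_id):
        def _complete_appservice_txn(txn):
            # Delete the completed AS transaction row.
            return self.db_pool.simple_delete_txn(
                txn, "application_services_txns", {"txn_id": txn_id, "as_id": as_id}
            )

        return self.db_pool.runInteraction(
            "complete_appservice_txn", _complete_appservice_txn
        )
```

The rename is purely mechanical, which is why every hunk above touches only the receiver (`self.db` → `self.db_pool`) or the type annotation (`Database` → `DatabasePool`) and leaves the surrounding logic alone.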
