mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-26 18:36:59 -05:00
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/new_push_rules
This commit is contained in:
commit
118a9eafb3
4
.github/ISSUE_TEMPLATE/BUG_REPORT.md
vendored
4
.github/ISSUE_TEMPLATE/BUG_REPORT.md
vendored
@ -4,12 +4,12 @@ about: Create a report to help us improve
|
||||
|
||||
---
|
||||
|
||||
<!--
|
||||
|
||||
**THIS IS NOT A SUPPORT CHANNEL!**
|
||||
**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**,
|
||||
please ask in **#synapse:matrix.org** (using a matrix.org account if necessary)
|
||||
|
||||
<!--
|
||||
|
||||
If you want to report a security issue, please see https://matrix.org/security-disclosure-policy/
|
||||
|
||||
This is a bug report template. By following the instructions below and
|
||||
|
1
changelog.d/7314.misc
Normal file
1
changelog.d/7314.misc
Normal file
@ -0,0 +1 @@
|
||||
Allow guest access to the `GET /_matrix/client/r0/rooms/{room_id}/members` endpoint, according to MSC2689. Contributed by Awesome Technologies Innovationslabor GmbH.
|
1
changelog.d/7977.bugfix
Normal file
1
changelog.d/7977.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory.
|
1
changelog.d/7987.misc
Normal file
1
changelog.d/7987.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7989.misc
Normal file
1
changelog.d/7989.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7996.bugfix
Normal file
1
changelog.d/7996.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix various comments and minor discrepencies in server notices code.
|
1
changelog.d/7999.bugfix
Normal file
1
changelog.d/7999.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a long standing bug where HTTP HEAD requests resulted in a 400 error.
|
1
changelog.d/8000.doc
Normal file
1
changelog.d/8000.doc
Normal file
@ -0,0 +1 @@
|
||||
Improve workers docs.
|
1
changelog.d/8001.misc
Normal file
1
changelog.d/8001.misc
Normal file
@ -0,0 +1 @@
|
||||
Remove redundant and unreliable signature check for v1 Identity Service lookup responses.
|
1
changelog.d/8003.misc
Normal file
1
changelog.d/8003.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
1
changelog.d/8008.feature
Normal file
1
changelog.d/8008.feature
Normal file
@ -0,0 +1 @@
|
||||
Add rate limiting to users joining rooms.
|
1
changelog.d/8011.bugfix
Normal file
1
changelog.d/8011.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger.
|
1
changelog.d/8012.bugfix
Normal file
1
changelog.d/8012.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger.
|
1
changelog.d/8014.misc
Normal file
1
changelog.d/8014.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
1
changelog.d/8016.misc
Normal file
1
changelog.d/8016.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
1
changelog.d/8024.misc
Normal file
1
changelog.d/8024.misc
Normal file
@ -0,0 +1 @@
|
||||
Reduce less useful output in the newsfragment CI step. Add a link to the changelog section of the contributing guide on error.
|
1
changelog.d/8027.misc
Normal file
1
changelog.d/8027.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
1
changelog.d/8033.misc
Normal file
1
changelog.d/8033.misc
Normal file
@ -0,0 +1 @@
|
||||
Rename storage layer objects to be more sensible.
|
@ -746,6 +746,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
|
||||
# - one for ratelimiting redactions by room admins. If this is not explicitly
|
||||
# set then it uses the same ratelimiting as per rc_message. This is useful
|
||||
# to allow room admins to deal with abuse quickly.
|
||||
# - two for ratelimiting number of rooms a user can join, "local" for when
|
||||
# users are joining rooms the server is already in (this is cheap) vs
|
||||
# "remote" for when users are trying to join rooms not on the server (which
|
||||
# can be more expensive)
|
||||
#
|
||||
# The defaults are as shown below.
|
||||
#
|
||||
@ -771,6 +775,14 @@ log_config: "CONFDIR/SERVERNAME.log.config"
|
||||
#rc_admin_redaction:
|
||||
# per_second: 1
|
||||
# burst_count: 50
|
||||
#
|
||||
#rc_joins:
|
||||
# local:
|
||||
# per_second: 0.1
|
||||
# burst_count: 3
|
||||
# remote:
|
||||
# per_second: 0.01
|
||||
# burst_count: 3
|
||||
|
||||
|
||||
# Ratelimiting settings for incoming federation
|
||||
|
@ -1,7 +1,7 @@
|
||||
worker_app: synapse.app.federation_reader
|
||||
worker_name: federation_reader1
|
||||
|
||||
worker_replication_host: 127.0.0.1
|
||||
worker_replication_port: 9092
|
||||
worker_replication_http_port: 9093
|
||||
|
||||
worker_listeners:
|
||||
|
@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server.
|
||||
|
||||
The directory info is stored in various tables, which can (typically after
|
||||
DB corruption) get stale or out of sync. If this happens, for now the
|
||||
solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql)
|
||||
solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql)
|
||||
and then restart synapse. This should then start a background task to
|
||||
flush the current tables and regenerate the directory.
|
||||
|
@ -23,7 +23,7 @@ The processes communicate with each other via a Synapse-specific protocol called
|
||||
feeds streams of newly written data between processes so they can be kept in
|
||||
sync with the database state.
|
||||
|
||||
When configured to do so, Synapse uses a
|
||||
When configured to do so, Synapse uses a
|
||||
[Redis pub/sub channel](https://redis.io/topics/pubsub) to send the replication
|
||||
stream between all configured Synapse processes. Additionally, processes may
|
||||
make HTTP requests to each other, primarily for operations which need to wait
|
||||
@ -66,23 +66,31 @@ https://hub.docker.com/r/matrixdotorg/synapse/.
|
||||
|
||||
To make effective use of the workers, you will need to configure an HTTP
|
||||
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
|
||||
the correct worker, or to the main synapse instance. See
|
||||
the correct worker, or to the main synapse instance. See
|
||||
[reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse
|
||||
proxy.
|
||||
|
||||
To enable workers you should create a configuration file for each worker
|
||||
process. Each worker configuration file inherits the configuration of the shared
|
||||
homeserver configuration file. You can then override configuration specific to
|
||||
that worker, e.g. the HTTP listener that it provides (if any); logging
|
||||
configuration; etc. You should minimise the number of overrides though to
|
||||
maintain a usable config.
|
||||
When using workers, each worker process has its own configuration file which
|
||||
contains settings specific to that worker, such as the HTTP listener that it
|
||||
provides (if any), logging configuration, etc.
|
||||
|
||||
Normally, the worker processes are configured to read from a shared
|
||||
configuration file as well as the worker-specific configuration files. This
|
||||
makes it easier to keep common configuration settings synchronised across all
|
||||
the processes.
|
||||
|
||||
The main process is somewhat special in this respect: it does not normally
|
||||
need its own configuration file and can take all of its configuration from the
|
||||
shared configuration file.
|
||||
|
||||
|
||||
### Shared Configuration
|
||||
### Shared configuration
|
||||
|
||||
Normally, only a couple of changes are needed to make an existing configuration
|
||||
file suitable for use with workers. First, you need to enable an "HTTP replication
|
||||
listener" for the main process; and secondly, you need to enable redis-based
|
||||
replication. For example:
|
||||
|
||||
Next you need to add both a HTTP replication listener, used for HTTP requests
|
||||
between processes, and redis config to the shared Synapse configuration file
|
||||
(`homeserver.yaml`). For example:
|
||||
|
||||
```yaml
|
||||
# extend the existing `listeners` section. This defines the ports that the
|
||||
@ -105,7 +113,7 @@ Under **no circumstances** should the replication listener be exposed to the
|
||||
public internet; it has no authentication and is unencrypted.
|
||||
|
||||
|
||||
### Worker Configuration
|
||||
### Worker configuration
|
||||
|
||||
In the config file for each worker, you must specify the type of worker
|
||||
application (`worker_app`), and you should specify a unqiue name for the worker
|
||||
@ -145,6 +153,9 @@ plain HTTP endpoint on port 8083 separately serving various endpoints, e.g.
|
||||
Obviously you should configure your reverse-proxy to route the relevant
|
||||
endpoints to the worker (`localhost:8083` in the above example).
|
||||
|
||||
|
||||
### Running Synapse with workers
|
||||
|
||||
Finally, you need to start your worker processes. This can be done with either
|
||||
`synctl` or your distribution's preferred service manager such as `systemd`. We
|
||||
recommend the use of `systemd` where available: for information on setting up
|
||||
@ -407,6 +418,23 @@ all these to be folded into the `generic_worker` app and to use config to define
|
||||
which processes handle the various proccessing such as push notifications.
|
||||
|
||||
|
||||
## Migration from old config
|
||||
|
||||
There are two main independent changes that have been made: introducing Redis
|
||||
support and merging apps into `synapse.app.generic_worker`. Both these changes
|
||||
are backwards compatible and so no changes to the config are required, however
|
||||
server admins are encouraged to plan to migrate to Redis as the old style direct
|
||||
TCP replication config is deprecated.
|
||||
|
||||
To migrate to Redis add the `redis` config as above, and optionally remove the
|
||||
TCP `replication` listener from master and `worker_replication_port` from worker
|
||||
config.
|
||||
|
||||
To migrate apps to use `synapse.app.generic_worker` simply update the
|
||||
`worker_app` option in the worker configs, and where worker are started (e.g.
|
||||
in systemd service files, but not required for synctl).
|
||||
|
||||
|
||||
## Architectural diagram
|
||||
|
||||
The following shows an example setup using Redis and a reverse proxy:
|
||||
|
@ -3,6 +3,8 @@
|
||||
# A script which checks that an appropriate news file has been added on this
|
||||
# branch.
|
||||
|
||||
echo -e "+++ \033[32mChecking newsfragment\033[m"
|
||||
|
||||
set -e
|
||||
|
||||
# make sure that origin/develop is up to date
|
||||
@ -16,6 +18,8 @@ pr="$BUILDKITE_PULL_REQUEST"
|
||||
if ! git diff --quiet FETCH_HEAD... -- debian; then
|
||||
if git diff --quiet FETCH_HEAD... -- debian/changelog; then
|
||||
echo "Updates to debian directory, but no update to the changelog." >&2
|
||||
echo "!! Please see the contributing guide for help writing your changelog entry:" >&2
|
||||
echo "https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#debian-changelog" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
@ -26,7 +30,12 @@ if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
tox -qe check-newsfragment
|
||||
# Print a link to the contributing guide if the user makes a mistake
|
||||
CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing your changelog entry:
|
||||
https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog"
|
||||
|
||||
# If check-newsfragment returns a non-zero exit code, print the contributing guide and exit
|
||||
tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1)
|
||||
|
||||
echo
|
||||
echo "--------------------------"
|
||||
@ -38,6 +47,7 @@ for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do
|
||||
lastchar=`tr -d '\n' < $f | tail -c 1`
|
||||
if [ $lastchar != '.' -a $lastchar != '!' ]; then
|
||||
echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2
|
||||
echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@ -47,5 +57,6 @@ done
|
||||
|
||||
if [[ -n "$pr" && "$matched" -eq 0 ]]; then
|
||||
echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2
|
||||
echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
@ -40,7 +40,7 @@ class MockHomeserver(HomeServer):
|
||||
config.server_name, reactor=reactor, config=config, **kwargs
|
||||
)
|
||||
|
||||
self.version_string = "Synapse/"+get_version_string(synapse)
|
||||
self.version_string = "Synapse/" + get_version_string(synapse)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -86,7 +86,7 @@ if __name__ == "__main__":
|
||||
store = hs.get_datastore()
|
||||
|
||||
async def run_background_updates():
|
||||
await store.db.updates.run_background_updates(sleep=False)
|
||||
await store.db_pool.updates.run_background_updates(sleep=False)
|
||||
# Stop the reactor to exit the script once every background update is run.
|
||||
reactor.stop()
|
||||
|
||||
|
@ -35,31 +35,29 @@ from synapse.logging.context import (
|
||||
make_deferred_yieldable,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore
|
||||
from synapse.storage.data_stores.main.deviceinbox import (
|
||||
DeviceInboxBackgroundUpdateStore,
|
||||
)
|
||||
from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore
|
||||
from synapse.storage.data_stores.main.events_bg_updates import (
|
||||
from synapse.storage.database import DatabasePool, make_conn
|
||||
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.events_bg_updates import (
|
||||
EventsBackgroundUpdatesStore,
|
||||
)
|
||||
from synapse.storage.data_stores.main.media_repository import (
|
||||
from synapse.storage.databases.main.media_repository import (
|
||||
MediaRepositoryBackgroundUpdateStore,
|
||||
)
|
||||
from synapse.storage.data_stores.main.registration import (
|
||||
from synapse.storage.databases.main.registration import (
|
||||
RegistrationBackgroundUpdateStore,
|
||||
find_max_generated_user_id_localpart,
|
||||
)
|
||||
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
|
||||
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
|
||||
from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore
|
||||
from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore
|
||||
from synapse.storage.data_stores.main.stats import StatsStore
|
||||
from synapse.storage.data_stores.main.user_directory import (
|
||||
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.stats import StatsStore
|
||||
from synapse.storage.databases.main.user_directory import (
|
||||
UserDirectoryBackgroundUpdateStore,
|
||||
)
|
||||
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
|
||||
from synapse.storage.database import Database, make_conn
|
||||
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
from synapse.util import Clock
|
||||
@ -175,14 +173,14 @@ class Store(
|
||||
StatsStore,
|
||||
):
|
||||
def execute(self, f, *args, **kwargs):
|
||||
return self.db.runInteraction(f.__name__, f, *args, **kwargs)
|
||||
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
||||
|
||||
def execute_sql(self, sql, *args):
|
||||
def r(txn):
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction("execute_sql", r)
|
||||
return self.db_pool.runInteraction("execute_sql", r)
|
||||
|
||||
def insert_many_txn(self, txn, table, headers, rows):
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||
@ -227,7 +225,7 @@ class Porter(object):
|
||||
async def setup_table(self, table):
|
||||
if table in APPEND_ONLY_TABLES:
|
||||
# It's safe to just carry on inserting.
|
||||
row = await self.postgres_store.db.simple_select_one(
|
||||
row = await self.postgres_store.db_pool.simple_select_one(
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
retcols=("forward_rowid", "backward_rowid"),
|
||||
@ -244,7 +242,7 @@ class Porter(object):
|
||||
) = await self._setup_sent_transactions()
|
||||
backward_chunk = 0
|
||||
else:
|
||||
await self.postgres_store.db.simple_insert(
|
||||
await self.postgres_store.db_pool.simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={
|
||||
"table_name": table,
|
||||
@ -274,7 +272,7 @@ class Porter(object):
|
||||
|
||||
await self.postgres_store.execute(delete_all)
|
||||
|
||||
await self.postgres_store.db.simple_insert(
|
||||
await self.postgres_store.db_pool.simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
|
||||
)
|
||||
@ -318,7 +316,7 @@ class Porter(object):
|
||||
if table == "user_directory_stream_pos":
|
||||
# We need to make sure there is a single row, `(X, null), as that is
|
||||
# what synapse expects to be there.
|
||||
await self.postgres_store.db.simple_insert(
|
||||
await self.postgres_store.db_pool.simple_insert(
|
||||
table=table, values={"stream_id": None}
|
||||
)
|
||||
self.progress.update(table, table_size) # Mark table as done
|
||||
@ -359,7 +357,7 @@ class Porter(object):
|
||||
|
||||
return headers, forward_rows, backward_rows
|
||||
|
||||
headers, frows, brows = await self.sqlite_store.db.runInteraction(
|
||||
headers, frows, brows = await self.sqlite_store.db_pool.runInteraction(
|
||||
"select", r
|
||||
)
|
||||
|
||||
@ -375,7 +373,7 @@ class Porter(object):
|
||||
def insert(txn):
|
||||
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
|
||||
|
||||
self.postgres_store.db.simple_update_one_txn(
|
||||
self.postgres_store.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
@ -413,7 +411,7 @@ class Porter(object):
|
||||
|
||||
return headers, rows
|
||||
|
||||
headers, rows = await self.sqlite_store.db.runInteraction("select", r)
|
||||
headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
|
||||
|
||||
if rows:
|
||||
forward_chunk = rows[-1][0] + 1
|
||||
@ -451,7 +449,7 @@ class Porter(object):
|
||||
],
|
||||
)
|
||||
|
||||
self.postgres_store.db.simple_update_one_txn(
|
||||
self.postgres_store.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": "event_search"},
|
||||
@ -494,7 +492,7 @@ class Porter(object):
|
||||
db_conn, allow_outdated_version=allow_outdated_version
|
||||
)
|
||||
prepare_database(db_conn, engine, config=self.hs_config)
|
||||
store = Store(Database(hs, db_config, engine), db_conn, hs)
|
||||
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
|
||||
db_conn.commit()
|
||||
|
||||
return store
|
||||
@ -502,7 +500,7 @@ class Porter(object):
|
||||
async def run_background_updates_on_postgres(self):
|
||||
# Manually apply all background updates on the PostgreSQL database.
|
||||
postgres_ready = (
|
||||
await self.postgres_store.db.updates.has_completed_background_updates()
|
||||
await self.postgres_store.db_pool.updates.has_completed_background_updates()
|
||||
)
|
||||
|
||||
if not postgres_ready:
|
||||
@ -511,9 +509,9 @@ class Porter(object):
|
||||
self.progress.set_state("Running background updates on PostgreSQL")
|
||||
|
||||
while not postgres_ready:
|
||||
await self.postgres_store.db.updates.do_next_background_update(100)
|
||||
await self.postgres_store.db_pool.updates.do_next_background_update(100)
|
||||
postgres_ready = await (
|
||||
self.postgres_store.db.updates.has_completed_background_updates()
|
||||
self.postgres_store.db_pool.updates.has_completed_background_updates()
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
@ -534,7 +532,7 @@ class Porter(object):
|
||||
|
||||
# Check if all background updates are done, abort if not.
|
||||
updates_complete = (
|
||||
await self.sqlite_store.db.updates.has_completed_background_updates()
|
||||
await self.sqlite_store.db_pool.updates.has_completed_background_updates()
|
||||
)
|
||||
if not updates_complete:
|
||||
end_error = (
|
||||
@ -576,22 +574,24 @@ class Porter(object):
|
||||
)
|
||||
|
||||
try:
|
||||
await self.postgres_store.db.runInteraction("alter_table", alter_table)
|
||||
await self.postgres_store.db_pool.runInteraction(
|
||||
"alter_table", alter_table
|
||||
)
|
||||
except Exception:
|
||||
# On Error Resume Next
|
||||
pass
|
||||
|
||||
await self.postgres_store.db.runInteraction(
|
||||
await self.postgres_store.db_pool.runInteraction(
|
||||
"create_port_table", create_port_table
|
||||
)
|
||||
|
||||
# Step 2. Get tables.
|
||||
self.progress.set_state("Fetching tables")
|
||||
sqlite_tables = await self.sqlite_store.db.simple_select_onecol(
|
||||
sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
|
||||
table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
|
||||
)
|
||||
|
||||
postgres_tables = await self.postgres_store.db.simple_select_onecol(
|
||||
postgres_tables = await self.postgres_store.db_pool.simple_select_onecol(
|
||||
table="information_schema.tables",
|
||||
keyvalues={},
|
||||
retcol="distinct table_name",
|
||||
@ -692,7 +692,7 @@ class Porter(object):
|
||||
|
||||
return headers, [r for r in rows if r[ts_ind] < yesterday]
|
||||
|
||||
headers, rows = await self.sqlite_store.db.runInteraction("select", r)
|
||||
headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
|
||||
|
||||
rows = self._convert_rows("sent_transactions", headers, rows)
|
||||
|
||||
@ -725,7 +725,7 @@ class Porter(object):
|
||||
next_chunk = await self.sqlite_store.execute(get_start_id)
|
||||
next_chunk = max(max_inserted_rowid + 1, next_chunk)
|
||||
|
||||
await self.postgres_store.db.simple_insert(
|
||||
await self.postgres_store.db_pool.simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={
|
||||
"table_name": "sent_transactions",
|
||||
@ -794,14 +794,14 @@ class Porter(object):
|
||||
next_id = curr_id + 1
|
||||
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
|
||||
|
||||
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
|
||||
return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
|
||||
|
||||
def _setup_user_id_seq(self):
|
||||
def r(txn):
|
||||
next_id = find_max_generated_user_id_localpart(txn) + 1
|
||||
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
|
||||
|
||||
return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
|
||||
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
|
||||
|
||||
|
||||
##############################################
|
||||
|
@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
@ -22,7 +21,6 @@ import sys
|
||||
import traceback
|
||||
from typing import Iterable
|
||||
|
||||
from daemonize import Daemonize
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
from twisted.internet import defer, error, reactor
|
||||
@ -34,6 +32,7 @@ from synapse.config.server import ListenerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.daemonize import daemonize_process
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
@ -129,17 +128,8 @@ def start_reactor(
|
||||
if print_pidfile:
|
||||
print(pid_file)
|
||||
|
||||
daemon = Daemonize(
|
||||
app=appname,
|
||||
pid=pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
daemonize_process(pid_file, logger)
|
||||
run()
|
||||
|
||||
|
||||
def quit_with_error(error_string: str) -> NoReturn:
|
||||
@ -278,7 +268,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
|
||||
|
||||
# It is now safe to start your Synapse.
|
||||
hs.start_listening(listeners)
|
||||
hs.get_datastore().db.start_profiling()
|
||||
hs.get_datastore().db_pool.start_profiling()
|
||||
hs.get_pusherpool().start()
|
||||
|
||||
setup_sentry(hs)
|
||||
|
@ -125,15 +125,15 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
from synapse.rest.client.versions import VersionsRestServlet
|
||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.data_stores.main.censor_events import CensorEventsStore
|
||||
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
|
||||
from synapse.storage.data_stores.main.monthly_active_users import (
|
||||
from synapse.storage.databases.main.censor_events import CensorEventsStore
|
||||
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
|
||||
from synapse.storage.databases.main.monthly_active_users import (
|
||||
MonthlyActiveUsersWorkerStore,
|
||||
)
|
||||
from synapse.storage.data_stores.main.presence import UserPresenceState
|
||||
from synapse.storage.data_stores.main.search import SearchWorkerStore
|
||||
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
|
||||
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
|
||||
from synapse.storage.databases.main.presence import UserPresenceState
|
||||
from synapse.storage.databases.main.search import SearchWorkerStore
|
||||
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
||||
from synapse.storage.databases.main.user_directory import UserDirectoryStore
|
||||
from synapse.types import ReadReceipt
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
|
@ -380,13 +380,12 @@ def setup(config_options):
|
||||
|
||||
hs.setup_master()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_acme():
|
||||
async def do_acme() -> bool:
|
||||
"""
|
||||
Reprovision an ACME certificate, if it's required.
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: Whether the cert has been updated.
|
||||
Whether the cert has been updated.
|
||||
"""
|
||||
acme = hs.get_acme_handler()
|
||||
|
||||
@ -405,7 +404,7 @@ def setup(config_options):
|
||||
provision = True
|
||||
|
||||
if provision:
|
||||
yield acme.provision_certificate()
|
||||
await acme.provision_certificate()
|
||||
|
||||
return provision
|
||||
|
||||
@ -415,7 +414,7 @@ def setup(config_options):
|
||||
Provision a certificate from ACME, if required, and reload the TLS
|
||||
certificate if it's renewed.
|
||||
"""
|
||||
reprovisioned = yield do_acme()
|
||||
reprovisioned = yield defer.ensureDeferred(do_acme())
|
||||
if reprovisioned:
|
||||
_base.refresh_certificate(hs)
|
||||
|
||||
@ -427,8 +426,8 @@ def setup(config_options):
|
||||
acme = hs.get_acme_handler()
|
||||
# Start up the webservices which we will respond to ACME
|
||||
# challenges with, and then provision.
|
||||
yield acme.start_listening()
|
||||
yield do_acme()
|
||||
yield defer.ensureDeferred(acme.start_listening())
|
||||
yield defer.ensureDeferred(do_acme())
|
||||
|
||||
# Check if it needs to be reprovisioned every day.
|
||||
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
|
||||
@ -442,7 +441,7 @@ def setup(config_options):
|
||||
|
||||
_base.start(hs, config.listeners)
|
||||
|
||||
hs.get_datastore().db.updates.start_doing_background_updates()
|
||||
hs.get_datastore().db_pool.updates.start_doing_background_updates()
|
||||
except Exception:
|
||||
# Print the exception and bail out.
|
||||
print("Error during startup:", file=sys.stderr)
|
||||
@ -552,8 +551,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
|
||||
#
|
||||
|
||||
# This only reports info about the *main* database.
|
||||
stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
|
||||
stats["database_server_version"] = hs.get_datastore().db.engine.server_version
|
||||
stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
|
||||
stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
|
||||
|
||||
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
|
||||
try:
|
||||
|
@ -175,7 +175,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
urllib.parse.quote(protocol),
|
||||
)
|
||||
try:
|
||||
info = yield self.get_json(uri, {})
|
||||
info = yield defer.ensureDeferred(self.get_json(uri, {}))
|
||||
|
||||
if not _is_valid_3pe_metadata(info):
|
||||
logger.warning(
|
||||
|
@ -100,7 +100,10 @@ class DatabaseConnectionConfig:
|
||||
|
||||
self.name = name
|
||||
self.config = db_config
|
||||
self.data_stores = data_stores
|
||||
|
||||
# The `data_stores` config is actually talking about `databases` (we
|
||||
# changed the name).
|
||||
self.databases = data_stores
|
||||
|
||||
|
||||
class DatabaseConfig(Config):
|
||||
|
@ -93,6 +93,15 @@ class RatelimitConfig(Config):
|
||||
if rc_admin_redaction:
|
||||
self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
|
||||
|
||||
self.rc_joins_local = RateLimitConfig(
|
||||
config.get("rc_joins", {}).get("local", {}),
|
||||
defaults={"per_second": 0.1, "burst_count": 3},
|
||||
)
|
||||
self.rc_joins_remote = RateLimitConfig(
|
||||
config.get("rc_joins", {}).get("remote", {}),
|
||||
defaults={"per_second": 0.01, "burst_count": 3},
|
||||
)
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
## Ratelimiting ##
|
||||
@ -118,6 +127,10 @@ class RatelimitConfig(Config):
|
||||
# - one for ratelimiting redactions by room admins. If this is not explicitly
|
||||
# set then it uses the same ratelimiting as per rc_message. This is useful
|
||||
# to allow room admins to deal with abuse quickly.
|
||||
# - two for ratelimiting number of rooms a user can join, "local" for when
|
||||
# users are joining rooms the server is already in (this is cheap) vs
|
||||
# "remote" for when users are trying to join rooms not on the server (which
|
||||
# can be more expensive)
|
||||
#
|
||||
# The defaults are as shown below.
|
||||
#
|
||||
@ -143,6 +156,14 @@ class RatelimitConfig(Config):
|
||||
#rc_admin_redaction:
|
||||
# per_second: 1
|
||||
# burst_count: 50
|
||||
#
|
||||
#rc_joins:
|
||||
# local:
|
||||
# per_second: 0.1
|
||||
# burst_count: 3
|
||||
# remote:
|
||||
# per_second: 0.01
|
||||
# burst_count: 3
|
||||
|
||||
|
||||
# Ratelimiting settings for incoming federation
|
||||
|
@ -223,8 +223,7 @@ class Keyring(object):
|
||||
|
||||
return results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _start_key_lookups(self, verify_requests):
|
||||
async def _start_key_lookups(self, verify_requests):
|
||||
"""Sets off the key fetches for each verify request
|
||||
|
||||
Once each fetch completes, verify_request.key_ready will be resolved.
|
||||
@ -245,7 +244,7 @@ class Keyring(object):
|
||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||
|
||||
# Wait for any previous lookups to complete before proceeding.
|
||||
yield self.wait_for_previous_lookups(server_to_request_ids.keys())
|
||||
await self.wait_for_previous_lookups(server_to_request_ids.keys())
|
||||
|
||||
# take out a lock on each of the servers by sticking a Deferred in
|
||||
# key_downloads
|
||||
@ -283,15 +282,14 @@ class Keyring(object):
|
||||
except Exception:
|
||||
logger.exception("Error starting key lookups")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_previous_lookups(self, server_names):
|
||||
async def wait_for_previous_lookups(self, server_names) -> None:
|
||||
"""Waits for any previous key lookups for the given servers to finish.
|
||||
|
||||
Args:
|
||||
server_names (Iterable[str]): list of servers which we want to look up
|
||||
|
||||
Returns:
|
||||
Deferred[None]: resolves once all key lookups for the given servers have
|
||||
Resolves once all key lookups for the given servers have
|
||||
completed. Follows the synapse rules of logcontext preservation.
|
||||
"""
|
||||
loop_count = 1
|
||||
@ -309,7 +307,7 @@ class Keyring(object):
|
||||
loop_count,
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
yield defer.DeferredList((w[1] for w in wait_on))
|
||||
await defer.DeferredList((w[1] for w in wait_on))
|
||||
|
||||
loop_count += 1
|
||||
|
||||
@ -326,44 +324,44 @@ class Keyring(object):
|
||||
|
||||
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_iterations():
|
||||
with Measure(self.clock, "get_server_verify_keys"):
|
||||
for f in self._key_fetchers:
|
||||
if not remaining_requests:
|
||||
return
|
||||
yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
|
||||
|
||||
# look for any requests which weren't satisfied
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in remaining_requests:
|
||||
verify_request.key_ready.errback(
|
||||
SynapseError(
|
||||
401,
|
||||
"No key for %s with ids in %s (min_validity %i)"
|
||||
% (
|
||||
verify_request.server_name,
|
||||
verify_request.key_ids,
|
||||
verify_request.minimum_valid_until_ts,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
async def do_iterations():
|
||||
try:
|
||||
with Measure(self.clock, "get_server_verify_keys"):
|
||||
for f in self._key_fetchers:
|
||||
if not remaining_requests:
|
||||
return
|
||||
await self._attempt_key_fetches_with_fetcher(
|
||||
f, remaining_requests
|
||||
)
|
||||
|
||||
def on_err(err):
|
||||
# we don't really expect to get here, because any errors should already
|
||||
# have been caught and logged. But if we do, let's log the error and make
|
||||
# sure that all of the deferreds are resolved.
|
||||
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in remaining_requests:
|
||||
if not verify_request.key_ready.called:
|
||||
verify_request.key_ready.errback(err)
|
||||
# look for any requests which weren't satisfied
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in remaining_requests:
|
||||
verify_request.key_ready.errback(
|
||||
SynapseError(
|
||||
401,
|
||||
"No key for %s with ids in %s (min_validity %i)"
|
||||
% (
|
||||
verify_request.server_name,
|
||||
verify_request.key_ids,
|
||||
verify_request.minimum_valid_until_ts,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
)
|
||||
except Exception as err:
|
||||
# we don't really expect to get here, because any errors should already
|
||||
# have been caught and logged. But if we do, let's log the error and make
|
||||
# sure that all of the deferreds are resolved.
|
||||
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in remaining_requests:
|
||||
if not verify_request.key_ready.called:
|
||||
verify_request.key_ready.errback(err)
|
||||
|
||||
run_in_background(do_iterations).addErrback(on_err)
|
||||
run_in_background(do_iterations)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
|
||||
async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
|
||||
"""Use a key fetcher to attempt to satisfy some key requests
|
||||
|
||||
Args:
|
||||
@ -390,7 +388,7 @@ class Keyring(object):
|
||||
verify_request.minimum_valid_until_ts,
|
||||
)
|
||||
|
||||
results = yield fetcher.get_keys(missing_keys)
|
||||
results = await fetcher.get_keys(missing_keys)
|
||||
|
||||
completed = []
|
||||
for verify_request in remaining_requests:
|
||||
@ -423,7 +421,7 @@ class Keyring(object):
|
||||
|
||||
|
||||
class KeyFetcher(object):
|
||||
def get_keys(self, keys_to_fetch):
|
||||
async def get_keys(self, keys_to_fetch):
|
||||
"""
|
||||
Args:
|
||||
keys_to_fetch (dict[str, dict[str, int]]):
|
||||
@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys(self, keys_to_fetch):
|
||||
async def get_keys(self, keys_to_fetch):
|
||||
"""see KeyFetcher.get_keys"""
|
||||
|
||||
keys_to_fetch = (
|
||||
@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher):
|
||||
for key_id in keys_for_server.keys()
|
||||
)
|
||||
|
||||
res = yield self.store.get_server_verify_keys(keys_to_fetch)
|
||||
res = await self.store.get_server_verify_keys(keys_to_fetch)
|
||||
keys = {}
|
||||
for (server_name, key_id), key in res.items():
|
||||
keys.setdefault(server_name, {})[key_id] = key
|
||||
@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object):
|
||||
self.store = hs.get_datastore()
|
||||
self.config = hs.get_config()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def process_v2_response(self, from_server, response_json, time_added_ms):
|
||||
async def process_v2_response(self, from_server, response_json, time_added_ms):
|
||||
"""Parse a 'Server Keys' structure from the result of a /key request
|
||||
|
||||
This is used to parse either the entirety of the response from
|
||||
@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object):
|
||||
|
||||
key_json_bytes = encode_canonical_json(response_json)
|
||||
|
||||
yield make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
self.client = hs.get_http_client()
|
||||
self.key_servers = self.config.key_servers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys(self, keys_to_fetch):
|
||||
async def get_keys(self, keys_to_fetch):
|
||||
"""see KeyFetcher.get_keys"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_key(key_server):
|
||||
async def get_key(key_server):
|
||||
try:
|
||||
result = yield self.get_server_verify_key_v2_indirect(
|
||||
result = await self.get_server_verify_key_v2_indirect(
|
||||
keys_to_fetch, key_server
|
||||
)
|
||||
return result
|
||||
@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
|
||||
return {}
|
||||
|
||||
results = yield make_deferred_yieldable(
|
||||
results = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[run_in_background(get_key, server) for server in self.key_servers],
|
||||
consumeErrors=True,
|
||||
@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
|
||||
return union_of_keys
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
|
||||
async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
|
||||
"""
|
||||
Args:
|
||||
keys_to_fetch (dict[str, dict[str, int]]):
|
||||
@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
the keys
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
|
||||
dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
|
||||
from server_name -> key_id -> FetchKeyResult
|
||||
|
||||
Raises:
|
||||
@ -632,20 +625,18 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
)
|
||||
|
||||
try:
|
||||
query_response = yield defer.ensureDeferred(
|
||||
self.client.post_json(
|
||||
destination=perspective_name,
|
||||
path="/_matrix/key/v2/query",
|
||||
data={
|
||||
"server_keys": {
|
||||
server_name: {
|
||||
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
||||
for key_id, min_valid_ts in server_keys.items()
|
||||
}
|
||||
for server_name, server_keys in keys_to_fetch.items()
|
||||
query_response = await self.client.post_json(
|
||||
destination=perspective_name,
|
||||
path="/_matrix/key/v2/query",
|
||||
data={
|
||||
"server_keys": {
|
||||
server_name: {
|
||||
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
||||
for key_id, min_valid_ts in server_keys.items()
|
||||
}
|
||||
},
|
||||
)
|
||||
for server_name, server_keys in keys_to_fetch.items()
|
||||
}
|
||||
},
|
||||
)
|
||||
except (NotRetryingDestination, RequestSendFailed) as e:
|
||||
# these both have str() representations which we can't really improve upon
|
||||
@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
try:
|
||||
self._validate_perspectives_response(key_server, response)
|
||||
|
||||
processed_response = yield self.process_v2_response(
|
||||
processed_response = await self.process_v2_response(
|
||||
perspective_name, response, time_added_ms=time_now_ms
|
||||
)
|
||||
except KeyLookupError as e:
|
||||
@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||
)
|
||||
keys.setdefault(server_name, {}).update(processed_response)
|
||||
|
||||
yield self.store.store_server_verify_keys(
|
||||
await self.store.store_server_verify_keys(
|
||||
perspective_name, time_now_ms, added_keys
|
||||
)
|
||||
|
||||
@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
self.clock = hs.get_clock()
|
||||
self.client = hs.get_http_client()
|
||||
|
||||
def get_keys(self, keys_to_fetch):
|
||||
async def get_keys(self, keys_to_fetch):
|
||||
"""
|
||||
Args:
|
||||
keys_to_fetch (dict[str, iterable[str]]):
|
||||
the keys to be fetched. server_name -> key_ids
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
|
||||
dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
|
||||
map from server_name -> key_id -> FetchKeyResult
|
||||
"""
|
||||
|
||||
results = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_key(key_to_fetch_item):
|
||||
async def get_key(key_to_fetch_item):
|
||||
server_name, key_ids = key_to_fetch_item
|
||||
try:
|
||||
keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
|
||||
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
||||
results[server_name] = keys
|
||||
except KeyLookupError as e:
|
||||
logger.warning(
|
||||
@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
except Exception:
|
||||
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
||||
|
||||
return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
|
||||
lambda _: results
|
||||
)
|
||||
return await yieldable_gather_results(
|
||||
get_key, keys_to_fetch.items()
|
||||
).addCallback(lambda _: results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -794,25 +783,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
try:
|
||||
response = yield defer.ensureDeferred(
|
||||
self.client.get_json(
|
||||
destination=server_name,
|
||||
path="/_matrix/key/v2/server/"
|
||||
+ urllib.parse.quote(requested_key_id),
|
||||
ignore_backoff=True,
|
||||
# we only give the remote server 10s to respond. It should be an
|
||||
# easy request to handle, so if it doesn't reply within 10s, it's
|
||||
# probably not going to.
|
||||
#
|
||||
# Furthermore, when we are acting as a notary server, we cannot
|
||||
# wait all day for all of the origin servers, as the requesting
|
||||
# server will otherwise time out before we can respond.
|
||||
#
|
||||
# (Note that get_json may make 4 attempts, so this can still take
|
||||
# almost 45 seconds to fetch the headers, plus up to another 60s to
|
||||
# read the response).
|
||||
timeout=10000,
|
||||
)
|
||||
response = await self.client.get_json(
|
||||
destination=server_name,
|
||||
path="/_matrix/key/v2/server/"
|
||||
+ urllib.parse.quote(requested_key_id),
|
||||
ignore_backoff=True,
|
||||
# we only give the remote server 10s to respond. It should be an
|
||||
# easy request to handle, so if it doesn't reply within 10s, it's
|
||||
# probably not going to.
|
||||
#
|
||||
# Furthermore, when we are acting as a notary server, we cannot
|
||||
# wait all day for all of the origin servers, as the requesting
|
||||
# server will otherwise time out before we can respond.
|
||||
#
|
||||
# (Note that get_json may make 4 attempts, so this can still take
|
||||
# almost 45 seconds to fetch the headers, plus up to another 60s to
|
||||
# read the response).
|
||||
timeout=10000,
|
||||
)
|
||||
except (NotRetryingDestination, RequestSendFailed) as e:
|
||||
# these both have str() representations which we can't really improve
|
||||
@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
% (server_name, response["server_name"])
|
||||
)
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
response_keys = await self.process_v2_response(
|
||||
from_server=server_name,
|
||||
response_json=response,
|
||||
time_added_ms=time_now_ms,
|
||||
)
|
||||
yield self.store.store_server_verify_keys(
|
||||
await self.store.store_server_verify_keys(
|
||||
server_name,
|
||||
time_now_ms,
|
||||
((server_name, key_id, key) for key_id, key in response_keys.items()),
|
||||
@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||
return keys
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_key_deferred(verify_request):
|
||||
async def _handle_key_deferred(verify_request) -> None:
|
||||
"""Waits for the key to become available, and then performs a verification
|
||||
|
||||
Args:
|
||||
verify_request (VerifyJsonRequest):
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
|
||||
Raises:
|
||||
SynapseError if there was a problem performing the verification
|
||||
"""
|
||||
server_name = verify_request.server_name
|
||||
with PreserveLoggingContext():
|
||||
_, key_id, verify_key = yield verify_request.key_ready
|
||||
_, key_id, verify_key = await verify_request.key_ready
|
||||
|
||||
json_object = verify_request.json_object
|
||||
|
||||
|
@ -23,7 +23,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
@ -17,7 +17,6 @@ import logging
|
||||
|
||||
import twisted
|
||||
import twisted.internet.error
|
||||
from twisted.internet import defer
|
||||
from twisted.web import server, static
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
@ -41,8 +40,7 @@ class AcmeHandler(object):
|
||||
self.reactor = hs.get_reactor()
|
||||
self._acme_domain = hs.config.acme_domain
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_listening(self):
|
||||
async def start_listening(self):
|
||||
from synapse.handlers import acme_issuing_service
|
||||
|
||||
# Configure logging for txacme, if you need to debug
|
||||
@ -82,18 +80,17 @@ class AcmeHandler(object):
|
||||
self._issuer._registered = False
|
||||
|
||||
try:
|
||||
yield self._issuer._ensure_registered()
|
||||
await self._issuer._ensure_registered()
|
||||
except Exception:
|
||||
logger.error(ACME_REGISTER_FAIL_ERROR)
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def provision_certificate(self):
|
||||
async def provision_certificate(self):
|
||||
|
||||
logger.warning("Reprovisioning %s", self._acme_domain)
|
||||
|
||||
try:
|
||||
yield self._issuer.issue_cert(self._acme_domain)
|
||||
await self._issuer.issue_cert(self._acme_domain)
|
||||
except Exception:
|
||||
logger.exception("Fail!")
|
||||
raise
|
||||
|
@ -71,7 +71,7 @@ from synapse.replication.http.federation import (
|
||||
)
|
||||
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
|
||||
from synapse.state import StateResolutionStore, resolve_events_with_store
|
||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.distributor import user_joined_room
|
||||
|
@ -22,14 +22,10 @@ import urllib.parse
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from twisted.internet.error import TimeoutError
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
HttpResponseException,
|
||||
@ -628,9 +624,9 @@ class IdentityHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if "mxid" in data:
|
||||
if "signatures" not in data:
|
||||
raise AuthError(401, "No signatures on 3pid binding")
|
||||
await self._verify_any_signature(data, id_server)
|
||||
# note: we used to verify the identity server's signature here, but no longer
|
||||
# require or validate it. See the following for context:
|
||||
# https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
|
||||
return data["mxid"]
|
||||
except TimeoutError:
|
||||
raise SynapseError(500, "Timed out contacting identity server")
|
||||
@ -751,30 +747,6 @@ class IdentityHandler(BaseHandler):
|
||||
mxid = lookup_results["mappings"].get(lookup_value)
|
||||
return mxid
|
||||
|
||||
async def _verify_any_signature(self, data, server_hostname):
|
||||
if server_hostname not in data["signatures"]:
|
||||
raise AuthError(401, "No signature from server %s" % (server_hostname,))
|
||||
for key_name, signature in data["signatures"][server_hostname].items():
|
||||
try:
|
||||
key_data = await self.blacklisting_http_client.get_json(
|
||||
"%s%s/_matrix/identity/api/v1/pubkey/%s"
|
||||
% (id_server_scheme, server_hostname, key_name)
|
||||
)
|
||||
except TimeoutError:
|
||||
raise SynapseError(500, "Timed out contacting identity server")
|
||||
if "public_key" not in key_data:
|
||||
raise AuthError(
|
||||
401, "No public key named %s from %s" % (key_name, server_hostname)
|
||||
)
|
||||
verify_signed_json(
|
||||
data,
|
||||
server_hostname,
|
||||
decode_verify_key_bytes(
|
||||
key_name, decode_base64(key_data["public_key"])
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
async def ask_id_server_for_third_party_invite(
|
||||
self,
|
||||
requester: Requester,
|
||||
|
@ -109,7 +109,7 @@ class InitialSyncHandler(BaseHandler):
|
||||
|
||||
rooms_ret = []
|
||||
|
||||
now_token = await self.hs.get_event_sources().get_current_token()
|
||||
now_token = self.hs.get_event_sources().get_current_token()
|
||||
|
||||
presence_stream = self.hs.get_event_sources().sources["presence"]
|
||||
pagination_config = PaginationConfig(from_token=now_token)
|
||||
@ -360,7 +360,7 @@ class InitialSyncHandler(BaseHandler):
|
||||
current_state.values(), time_now
|
||||
)
|
||||
|
||||
now_token = await self.hs.get_event_sources().get_current_token()
|
||||
now_token = self.hs.get_event_sources().get_current_token()
|
||||
|
||||
limit = pagin_config.limit if pagin_config else None
|
||||
if limit is None:
|
||||
|
@ -45,7 +45,7 @@ from synapse.events.validator import EventValidator
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
|
@ -309,7 +309,7 @@ class PaginationHandler(object):
|
||||
room_token = pagin_config.from_token.room_key
|
||||
else:
|
||||
pagin_config.from_token = (
|
||||
await self.hs.get_event_sources().get_current_token_for_pagination()
|
||||
self.hs.get_event_sources().get_current_token_for_pagination()
|
||||
)
|
||||
room_token = pagin_config.from_token.room_key
|
||||
|
||||
|
@ -38,7 +38,7 @@ from synapse.logging.utils import log_function
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
@ -319,7 +319,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||
is some spurious presence changes that will self-correct.
|
||||
"""
|
||||
# If the DB pool has already terminated, don't try updating
|
||||
if not self.store.db.is_running():
|
||||
if not self.store.db_pool.is_running():
|
||||
return
|
||||
|
||||
logger.info(
|
||||
|
@ -548,7 +548,7 @@ class RegistrationHandler(BaseHandler):
|
||||
address (str|None): the IP address used to perform the registration.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
Awaitable
|
||||
"""
|
||||
if self.hs.config.worker_app:
|
||||
return self._register_client(
|
||||
|
@ -22,7 +22,7 @@ import logging
|
||||
import math
|
||||
import string
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Tuple
|
||||
from typing import Awaitable, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import (
|
||||
EventTypes,
|
||||
@ -1041,7 +1041,7 @@ class RoomEventSource(object):
|
||||
):
|
||||
# We just ignore the key for now.
|
||||
|
||||
to_key = await self.get_current_key()
|
||||
to_key = self.get_current_key()
|
||||
|
||||
from_token = RoomStreamToken.parse(from_key)
|
||||
if from_token.topological:
|
||||
@ -1081,10 +1081,10 @@ class RoomEventSource(object):
|
||||
|
||||
return (events, end_key)
|
||||
|
||||
def get_current_key(self):
|
||||
return self.store.get_room_events_max_id()
|
||||
def get_current_key(self) -> str:
|
||||
return "s%d" % (self.store.get_room_max_stream_ordering(),)
|
||||
|
||||
def get_current_key_for_room(self, room_id):
|
||||
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
|
||||
return self.store.get_room_events_max_id(room_id)
|
||||
|
||||
|
||||
|
@ -22,7 +22,8 @@ from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.api.room_versions import EventFormatVersions
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
from synapse.events import EventBase
|
||||
@ -77,6 +78,17 @@ class RoomMemberHandler(object):
|
||||
if self._is_on_event_persistence_instance:
|
||||
self.persist_event_storage = hs.get_storage().persistence
|
||||
|
||||
self._join_rate_limiter_local = Ratelimiter(
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
|
||||
)
|
||||
self._join_rate_limiter_remote = Ratelimiter(
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
|
||||
)
|
||||
|
||||
# This is only used to get at ratelimit function, and
|
||||
# maybe_kick_guest_users. It's fine there are multiple of these as
|
||||
# it doesn't store state.
|
||||
@ -441,7 +453,28 @@ class RoomMemberHandler(object):
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
if not is_host_in_room:
|
||||
if is_host_in_room:
|
||||
time_now_s = self.clock.time()
|
||||
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
|
||||
requester.user.to_string(),
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
)
|
||||
|
||||
else:
|
||||
time_now_s = self.clock.time()
|
||||
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
|
||||
requester.user.to_string(),
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
)
|
||||
|
||||
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||
if inviter and not self.hs.is_mine(inviter):
|
||||
remote_room_hosts.append(inviter.domain)
|
||||
|
@ -340,7 +340,7 @@ class SearchHandler(BaseHandler):
|
||||
# If client has asked for "context" for each event (i.e. some surrounding
|
||||
# events and state), fetch that
|
||||
if event_context is not None:
|
||||
now_token = await self.hs.get_event_sources().get_current_token()
|
||||
now_token = self.hs.get_event_sources().get_current_token()
|
||||
|
||||
contexts = {}
|
||||
for event in allowed_events:
|
||||
|
@ -232,7 +232,7 @@ class StatsHandler:
|
||||
|
||||
if membership == prev_membership:
|
||||
pass # noop
|
||||
if membership == Membership.JOIN:
|
||||
elif membership == Membership.JOIN:
|
||||
room_stats_delta["joined_members"] += 1
|
||||
elif membership == Membership.INVITE:
|
||||
room_stats_delta["invited_members"] += 1
|
||||
|
@ -961,7 +961,7 @@ class SyncHandler(object):
|
||||
# this is due to some of the underlying streams not supporting the ability
|
||||
# to query up to a given point.
|
||||
# Always use the `now_token` in `SyncResultBuilder`
|
||||
now_token = await self.event_sources.get_current_token()
|
||||
now_token = self.event_sources.get_current_token()
|
||||
|
||||
logger.debug(
|
||||
"Calculating sync response for %r between %s and %s",
|
||||
|
@ -284,8 +284,7 @@ class SimpleHttpClient(object):
|
||||
ip_blacklist=self._ip_blacklist,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def request(self, method, uri, data=None, headers=None):
|
||||
async def request(self, method, uri, data=None, headers=None):
|
||||
"""
|
||||
Args:
|
||||
method (str): HTTP method to use.
|
||||
@ -330,7 +329,7 @@ class SimpleHttpClient(object):
|
||||
self.hs.get_reactor(),
|
||||
cancelled_to_request_timed_out_error,
|
||||
)
|
||||
response = yield make_deferred_yieldable(request_deferred)
|
||||
response = await make_deferred_yieldable(request_deferred)
|
||||
|
||||
incoming_responses_counter.labels(method, response.code).inc()
|
||||
logger.info(
|
||||
@ -353,8 +352,7 @@ class SimpleHttpClient(object):
|
||||
set_tag("error_reason", e.args[0])
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
||||
async def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
||||
"""
|
||||
Args:
|
||||
uri (str):
|
||||
@ -363,7 +361,7 @@ class SimpleHttpClient(object):
|
||||
header name to a list of values for that header
|
||||
|
||||
Returns:
|
||||
Deferred[object]: parsed json
|
||||
object: parsed json
|
||||
|
||||
Raises:
|
||||
HttpResponseException: On a non-2xx HTTP response.
|
||||
@ -386,11 +384,11 @@ class SimpleHttpClient(object):
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
response = await self.request(
|
||||
"POST", uri, headers=Headers(actual_headers), data=query_bytes
|
||||
)
|
||||
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
@ -399,8 +397,7 @@ class SimpleHttpClient(object):
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json_get_json(self, uri, post_json, headers=None):
|
||||
async def post_json_get_json(self, uri, post_json, headers=None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -410,7 +407,7 @@ class SimpleHttpClient(object):
|
||||
header name to a list of values for that header
|
||||
|
||||
Returns:
|
||||
Deferred[object]: parsed json
|
||||
object: parsed json
|
||||
|
||||
Raises:
|
||||
HttpResponseException: On a non-2xx HTTP response.
|
||||
@ -429,11 +426,11 @@ class SimpleHttpClient(object):
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
response = await self.request(
|
||||
"POST", uri, headers=Headers(actual_headers), data=json_str
|
||||
)
|
||||
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
@ -442,8 +439,7 @@ class SimpleHttpClient(object):
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, uri, args={}, headers=None):
|
||||
async def get_json(self, uri, args={}, headers=None):
|
||||
""" Gets some json from the given URI.
|
||||
|
||||
Args:
|
||||
@ -455,7 +451,7 @@ class SimpleHttpClient(object):
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
Raises:
|
||||
HttpResponseException On a non-2xx HTTP response.
|
||||
@ -466,11 +462,10 @@ class SimpleHttpClient(object):
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
body = yield self.get_raw(uri, args, headers=headers)
|
||||
body = await self.get_raw(uri, args, headers=headers)
|
||||
return json.loads(body.decode("utf-8"))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, uri, json_body, args={}, headers=None):
|
||||
async def put_json(self, uri, json_body, args={}, headers=None):
|
||||
""" Puts some json to the given URI.
|
||||
|
||||
Args:
|
||||
@ -483,7 +478,7 @@ class SimpleHttpClient(object):
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
Raises:
|
||||
HttpResponseException On a non-2xx HTTP response.
|
||||
@ -504,11 +499,11 @@ class SimpleHttpClient(object):
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request(
|
||||
response = await self.request(
|
||||
"PUT", uri, headers=Headers(actual_headers), data=json_str
|
||||
)
|
||||
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
@ -517,8 +512,7 @@ class SimpleHttpClient(object):
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_raw(self, uri, args={}, headers=None):
|
||||
async def get_raw(self, uri, args={}, headers=None):
|
||||
""" Gets raw text from the given URI.
|
||||
|
||||
Args:
|
||||
@ -530,7 +524,7 @@ class SimpleHttpClient(object):
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as bytes.
|
||||
Raises:
|
||||
HttpResponseException on a non-2xx HTTP response.
|
||||
@ -543,9 +537,9 @@ class SimpleHttpClient(object):
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request("GET", uri, headers=Headers(actual_headers))
|
||||
response = await self.request("GET", uri, headers=Headers(actual_headers))
|
||||
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
return body
|
||||
@ -557,8 +551,7 @@ class SimpleHttpClient(object):
|
||||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||
# The two should be factored out.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_file(self, url, output_stream, max_size=None, headers=None):
|
||||
async def get_file(self, url, output_stream, max_size=None, headers=None):
|
||||
"""GETs a file from a given URL
|
||||
Args:
|
||||
url (str): The URL to GET
|
||||
@ -574,7 +567,7 @@ class SimpleHttpClient(object):
|
||||
if headers:
|
||||
actual_headers.update(headers)
|
||||
|
||||
response = yield self.request("GET", url, headers=Headers(actual_headers))
|
||||
response = await self.request("GET", url, headers=Headers(actual_headers))
|
||||
|
||||
resp_headers = dict(response.headers.getAllRawHeaders())
|
||||
|
||||
@ -598,7 +591,7 @@ class SimpleHttpClient(object):
|
||||
# straight back in again
|
||||
|
||||
try:
|
||||
length = yield make_deferred_yieldable(
|
||||
length = await make_deferred_yieldable(
|
||||
_readBodyToFile(response, output_stream, max_size)
|
||||
)
|
||||
except SynapseError:
|
||||
|
@ -242,10 +242,12 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||
no appropriate method exists. Can be overriden in sub classes for
|
||||
different routing.
|
||||
"""
|
||||
# Treat HEAD requests as GET requests.
|
||||
request_method = request.method.decode("ascii")
|
||||
if request_method == "HEAD":
|
||||
request_method = "GET"
|
||||
|
||||
method_handler = getattr(
|
||||
self, "_async_render_%s" % (request.method.decode("ascii"),), None
|
||||
)
|
||||
method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
|
||||
if method_handler:
|
||||
raw_callback_return = method_handler(request)
|
||||
|
||||
@ -362,11 +364,15 @@ class JsonResource(DirectServeJsonResource):
|
||||
A tuple of the callback to use, the name of the servlet, and the
|
||||
key word arguments to pass to the callback
|
||||
"""
|
||||
# Treat HEAD requests as GET requests.
|
||||
request_path = request.path.decode("ascii")
|
||||
request_method = request.method
|
||||
if request_method == b"HEAD":
|
||||
request_method = b"GET"
|
||||
|
||||
# Loop through all the registered callbacks to check if the method
|
||||
# and path regex match
|
||||
for path_entry in self.path_regexs.get(request.method, []):
|
||||
for path_entry in self.path_regexs.get(request_method, []):
|
||||
m = path_entry.pattern.match(request_path)
|
||||
if m:
|
||||
# We found a match!
|
||||
@ -579,7 +585,7 @@ def set_cors_headers(request: Request):
|
||||
"""
|
||||
request.setHeader(b"Access-Control-Allow-Origin", b"*")
|
||||
request.setHeader(
|
||||
b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
|
||||
b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
request.setHeader(
|
||||
b"Access-Control-Allow-Headers",
|
||||
|
@ -219,7 +219,7 @@ class ModuleApi(object):
|
||||
Returns:
|
||||
Deferred[object]: result of func
|
||||
"""
|
||||
return self._store.db.runInteraction(desc, func, *args, **kwargs)
|
||||
return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
|
||||
|
||||
def complete_sso_login(
|
||||
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
|
||||
|
@ -320,7 +320,7 @@ class Notifier(object):
|
||||
"""
|
||||
user_stream = self.user_to_user_stream.get(user_id)
|
||||
if user_stream is None:
|
||||
current_token = await self.event_sources.get_current_token()
|
||||
current_token = self.event_sources.get_current_token()
|
||||
if room_ids is None:
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
user_stream = _NotifierUserStream(
|
||||
@ -397,7 +397,7 @@ class Notifier(object):
|
||||
"""
|
||||
from_token = pagination_config.from_token
|
||||
if not from_token:
|
||||
from_token = await self.event_sources.get_current_token()
|
||||
from_token = self.event_sources.get_current_token()
|
||||
|
||||
limit = pagination_config.limit
|
||||
|
||||
|
@ -59,7 +59,6 @@ REQUIREMENTS = [
|
||||
"pyyaml>=3.11",
|
||||
"pyasn1>=0.1.9",
|
||||
"pyasn1-modules>=0.0.7",
|
||||
"daemonize>=2.3.1",
|
||||
"bcrypt>=3.1.0",
|
||||
"pillow>=4.3.0",
|
||||
"sortedcontainers>=1.4.4",
|
||||
|
@ -20,8 +20,6 @@ import urllib
|
||||
from inspect import signature
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
HttpResponseException,
|
||||
@ -101,7 +99,7 @@ class ReplicationEndpoint(object):
|
||||
assert self.METHOD in ("PUT", "POST", "GET")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _serialize_payload(**kwargs):
|
||||
async def _serialize_payload(**kwargs):
|
||||
"""Static method that is called when creating a request.
|
||||
|
||||
Concrete implementations should have explicit parameters (rather than
|
||||
@ -110,9 +108,8 @@ class ReplicationEndpoint(object):
|
||||
argument list.
|
||||
|
||||
Returns:
|
||||
Deferred[dict]|dict: If POST/PUT request then dictionary must be
|
||||
JSON serialisable, otherwise must be appropriate for adding as
|
||||
query args.
|
||||
dict: If POST/PUT request then dictionary must be JSON serialisable,
|
||||
otherwise must be appropriate for adding as query args.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@ -144,8 +141,7 @@ class ReplicationEndpoint(object):
|
||||
instance_map = hs.config.worker.instance_map
|
||||
|
||||
@trace(opname="outgoing_replication_request")
|
||||
@defer.inlineCallbacks
|
||||
def send_request(instance_name="master", **kwargs):
|
||||
async def send_request(instance_name="master", **kwargs):
|
||||
if instance_name == local_instance_name:
|
||||
raise Exception("Trying to send HTTP request to self")
|
||||
if instance_name == "master":
|
||||
@ -159,7 +155,7 @@ class ReplicationEndpoint(object):
|
||||
"Instance %r not in 'instance_map' config" % (instance_name,)
|
||||
)
|
||||
|
||||
data = yield cls._serialize_payload(**kwargs)
|
||||
data = await cls._serialize_payload(**kwargs)
|
||||
|
||||
url_args = [
|
||||
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
|
||||
@ -197,7 +193,7 @@ class ReplicationEndpoint(object):
|
||||
headers = {} # type: Dict[bytes, List[bytes]]
|
||||
inject_active_span_byte_dict(headers, None, check_destination=False)
|
||||
try:
|
||||
result = yield request_func(uri, data, headers=headers)
|
||||
result = await request_func(uri, data, headers=headers)
|
||||
break
|
||||
except CodeMessageException as e:
|
||||
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
|
||||
@ -207,7 +203,7 @@ class ReplicationEndpoint(object):
|
||||
|
||||
# If we timed out we probably don't need to worry about backing
|
||||
# off too much, but lets just wait a little anyway.
|
||||
yield clock.sleep(1)
|
||||
await clock.sleep(1)
|
||||
except HttpResponseException as e:
|
||||
# We convert to SynapseError as we know that it was a SynapseError
|
||||
# on the master process that we should send to the client. (And
|
||||
|
@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(user_id):
|
||||
async def _serialize_payload(user_id):
|
||||
return {}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
|
@ -15,8 +15,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||
self.federation_handler = hs.get_handlers().federation_handler
|
||||
|
||||
@staticmethod
|
||||
@defer.inlineCallbacks
|
||||
def _serialize_payload(store, event_and_contexts, backfilled):
|
||||
async def _serialize_payload(store, event_and_contexts, backfilled):
|
||||
"""
|
||||
Args:
|
||||
store
|
||||
@ -78,9 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||
"""
|
||||
event_payloads = []
|
||||
for event, context in event_and_contexts:
|
||||
serialized_context = yield defer.ensureDeferred(
|
||||
context.serialize(event, store)
|
||||
)
|
||||
serialized_context = await context.serialize(event, store)
|
||||
|
||||
event_payloads.append(
|
||||
{
|
||||
@ -156,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
||||
self.registry = hs.get_federation_registry()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(edu_type, origin, content):
|
||||
async def _serialize_payload(edu_type, origin, content):
|
||||
return {"origin": origin, "content": content}
|
||||
|
||||
async def _handle_request(self, request, edu_type):
|
||||
@ -199,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
||||
self.registry = hs.get_federation_registry()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(query_type, args):
|
||||
async def _serialize_payload(query_type, args):
|
||||
"""
|
||||
Args:
|
||||
query_type (str)
|
||||
@ -240,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(room_id, args):
|
||||
async def _serialize_payload(room_id, args):
|
||||
"""
|
||||
Args:
|
||||
room_id (str)
|
||||
@ -275,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(room_id, room_version):
|
||||
async def _serialize_payload(room_id, room_version):
|
||||
return {"room_version": room_version.identifier}
|
||||
|
||||
async def _handle_request(self, request, room_id):
|
||||
|
@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
||||
async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
||||
"""
|
||||
Args:
|
||||
device_id (str|None): Device ID to use, if None a new one is
|
||||
|
@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
|
||||
async def _serialize_payload(
|
||||
requester, room_id, user_id, remote_room_hosts, content
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
requester(Requester)
|
||||
@ -112,7 +114,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||
self.member_handler = hs.get_room_member_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload( # type: ignore
|
||||
async def _serialize_payload( # type: ignore
|
||||
invite_event_id: str,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
@ -174,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
||||
self.distributor = hs.get_distributor()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(room_id, user_id, change):
|
||||
async def _serialize_payload(room_id, user_id, change):
|
||||
"""
|
||||
Args:
|
||||
room_id (str)
|
||||
|
@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
|
||||
self._presence_handler = hs.get_presence_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(user_id):
|
||||
async def _serialize_payload(user_id):
|
||||
return {}
|
||||
|
||||
async def _handle_request(self, request, user_id):
|
||||
@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||
self._presence_handler = hs.get_presence_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(user_id, state, ignore_status_msg=False):
|
||||
async def _serialize_payload(user_id, state, ignore_status_msg=False):
|
||||
return {
|
||||
"state": state,
|
||||
"ignore_status_msg": ignore_status_msg,
|
||||
|
@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(
|
||||
async def _serialize_payload(
|
||||
user_id,
|
||||
password_hash,
|
||||
was_guest,
|
||||
@ -105,7 +105,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(user_id, auth_result, access_token):
|
||||
async def _serialize_payload(user_id, auth_result, access_token):
|
||||
"""
|
||||
Args:
|
||||
user_id (str): The user ID that consented
|
||||
|
@ -15,8 +15,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
@defer.inlineCallbacks
|
||||
def _serialize_payload(
|
||||
async def _serialize_payload(
|
||||
event_id, store, event, context, requester, ratelimit, extra_users
|
||||
):
|
||||
"""
|
||||
@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||
extra_users (list(UserID)): Any extra users to notify about event
|
||||
"""
|
||||
|
||||
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
|
||||
serialized_context = await context.serialize(event, store)
|
||||
|
||||
payload = {
|
||||
"event": event.get_pdu_json(),
|
||||
|
@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||
self.streams = hs.get_replication_streams()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(stream_name, from_token, upto_token):
|
||||
async def _serialize_payload(stream_name, from_token, upto_token):
|
||||
return {"from_token": from_token, "upto_token": upto_token}
|
||||
|
||||
async def _handle_request(self, request, stream_name):
|
||||
|
@ -16,8 +16,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
|
||||
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseSlavedStore(CacheInvalidationWorkerStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = MultiWriterIdGenerator(
|
||||
|
@ -17,13 +17,13 @@
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
|
||||
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
|
||||
from synapse.storage.data_stores.main.tags import TagsWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||
from synapse.storage.databases.main.tags import TagsWorkerStore
|
||||
|
||||
|
||||
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
self._account_data_id_gen = SlavedIdTracker(
|
||||
db_conn,
|
||||
"account_data",
|
||||
|
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.data_stores.main.appservice import (
|
||||
from synapse.storage.databases.main.appservice import (
|
||||
ApplicationServiceTransactionWorkerStore,
|
||||
ApplicationServiceWorkerStore,
|
||||
)
|
||||
|
@ -13,15 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
|
||||
from synapse.util.caches.descriptors import Cache
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
|
||||
|
||||
class SlavedClientIpStore(BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
self.client_ip_last_seen = Cache(
|
||||
|
@ -16,14 +16,14 @@
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import ToDeviceStream
|
||||
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
|
||||
self._device_inbox_id_gen = SlavedIdTracker(
|
||||
db_conn, "device_inbox", "stream_id"
|
||||
|
@ -16,14 +16,14 @@
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
||||
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
|
||||
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.devices import DeviceWorkerStore
|
||||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
self.hs = hs
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
|
||||
from synapse.storage.databases.main.directory import DirectoryWorkerStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
|
||||
|
@ -15,18 +15,18 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
|
||||
from synapse.storage.data_stores.main.event_push_actions import (
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
|
||||
from synapse.storage.databases.main.event_push_actions import (
|
||||
EventPushActionsWorkerStore,
|
||||
)
|
||||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.data_stores.main.relations import RelationsWorkerStore
|
||||
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
|
||||
from synapse.storage.data_stores.main.stream import StreamWorkerStore
|
||||
from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.relations import RelationsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
||||
from synapse.storage.databases.main.stream import StreamWorkerStore
|
||||
from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
@ -55,11 +55,11 @@ class SlavedEventStore(
|
||||
RelationsWorkerStore,
|
||||
BaseSlavedStore,
|
||||
):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedEventStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"current_state_delta_stream",
|
||||
entity_column="room_id",
|
||||
|
@ -13,14 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.data_stores.main.filtering import FilteringStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.filtering import FilteringStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
|
||||
|
||||
class SlavedFilteringStore(BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
# Filters are immutable so this cache doesn't need to be expired
|
||||
|
@ -16,13 +16,13 @@
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import GroupServerStream
|
||||
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.group_server import GroupServerWorkerStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
self.hs = hs
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.data_stores.main.keys import KeyStore
|
||||
from synapse.storage.databases.main.keys import KeyStore
|
||||
|
||||
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that
|
||||
# the races it creates aren't too bad.
|
||||
|
@ -15,8 +15,8 @@
|
||||
|
||||
from synapse.replication.tcp.streams import PresenceStream
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.data_stores.main.presence import PresenceStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.presence import PresenceStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
|
||||
class SlavedPresenceStore(BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
|
||||
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.storage.data_stores.main.profile import ProfileWorkerStore
|
||||
from synapse.storage.databases.main.profile import ProfileWorkerStore
|
||||
|
||||
|
||||
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
|
||||
|
@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import PushRulesStream
|
||||
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
|
||||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
||||
|
||||
from .events import SlavedEventStore
|
||||
|
||||
|
@ -15,15 +15,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import PushersStream
|
||||
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
|
||||
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SlavedPusherStore, self).__init__(database, db_conn, hs)
|
||||
self._pushers_id_gen = SlavedIdTracker(
|
||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||
|
@ -15,15 +15,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import ReceiptsStream
|
||||
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
|
||||
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
# We instantiate this first as the ReceiptsWorkerStore constructor
|
||||
# needs to be able to call get_max_receipt_stream_id
|
||||
self._receipts_id_gen = SlavedIdTracker(
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
|
||||
from synapse.storage.databases.main.registration import RegistrationWorkerStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
|
||||
|
@ -14,15 +14,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import PublicRoomsStream
|
||||
from synapse.storage.data_stores.main.room import RoomWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.room import RoomWorkerStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
|
||||
class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(RoomStore, self).__init__(database, db_conn, hs)
|
||||
self._public_room_id_gen = SlavedIdTracker(
|
||||
db_conn, "public_room_list_stream", "stream_id"
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.data_stores.main.transactions import TransactionStore
|
||||
from synapse.storage.databases.main.transactions import TransactionStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
|
||||
|
@ -31,7 +31,7 @@ from synapse.rest.admin._base import (
|
||||
assert_user_is_admin,
|
||||
historical_admin_path_patterns,
|
||||
)
|
||||
from synapse.storage.data_stores.main.room import RoomSortOrder
|
||||
from synapse.storage.databases.main.room import RoomSortOrder
|
||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -444,7 +444,7 @@ class RoomMemberListRestServlet(RestServlet):
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
handler = self.message_handler
|
||||
|
||||
# request the state as of a given event, as identified by a stream token,
|
||||
|
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@ -30,7 +29,7 @@ from .filepath import MediaFilePaths
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from .storage_provider import StorageProvider
|
||||
from .storage_provider import StorageProviderWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -50,7 +49,7 @@ class MediaStorage(object):
|
||||
hs: "HomeServer",
|
||||
local_media_directory: str,
|
||||
filepaths: MediaFilePaths,
|
||||
storage_providers: Sequence["StorageProvider"],
|
||||
storage_providers: Sequence["StorageProviderWrapper"],
|
||||
):
|
||||
self.hs = hs
|
||||
self.local_media_directory = local_media_directory
|
||||
@ -115,11 +114,7 @@ class MediaStorage(object):
|
||||
|
||||
async def finish():
|
||||
for provider in self.storage_providers:
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = provider.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
await provider.store_file(path, file_info)
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
@ -153,11 +148,7 @@ class MediaStorage(object):
|
||||
return FileResponder(open(local_path, "rb"))
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = provider.fetch(path, file_info) # type: Any
|
||||
# Fetch is supposed to return an Awaitable[Responder], but guard
|
||||
# against improper implementations.
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
res = await provider.fetch(path, file_info) # type: Any
|
||||
if res:
|
||||
logger.debug("Streaming %s from %s", path, provider)
|
||||
return res
|
||||
@ -184,11 +175,7 @@ class MediaStorage(object):
|
||||
os.makedirs(dirname)
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = provider.fetch(path, file_info) # type: Any
|
||||
# Fetch is supposed to return an Awaitable[Responder], but guard
|
||||
# against improper implementations.
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
res = await provider.fetch(path, file_info) # type: Any
|
||||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(
|
||||
|
@ -586,7 +586,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
|
||||
logger.debug("Running url preview cache expiry")
|
||||
|
||||
if not (await self.store.db.updates.has_completed_background_updates()):
|
||||
if not (await self.store.db_pool.updates.has_completed_background_updates()):
|
||||
logger.info("Still running DB updates; skipping expiry")
|
||||
return
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@ -88,12 +89,18 @@ class StorageProviderWrapper(StorageProvider):
|
||||
return None
|
||||
|
||||
if self.store_synchronous:
|
||||
return await self.backend.store_file(path, file_info)
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = self.backend.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
else:
|
||||
# TODO: Handle errors.
|
||||
def store():
|
||||
async def store():
|
||||
try:
|
||||
return self.backend.store_file(path, file_info)
|
||||
result = self.backend.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
except Exception:
|
||||
logger.exception("Error storing file")
|
||||
|
||||
@ -101,7 +108,11 @@ class StorageProviderWrapper(StorageProvider):
|
||||
return None
|
||||
|
||||
async def fetch(self, path, file_info):
|
||||
return await self.backend.fetch(path, file_info)
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = self.backend.fetch(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
|
||||
|
||||
class FileStorageProviderBackend(StorageProvider):
|
||||
|
@ -105,7 +105,7 @@ from synapse.server_notices.worker_server_notices_sender import (
|
||||
WorkerServerNoticesSender,
|
||||
)
|
||||
from synapse.state import StateHandler, StateResolutionHandler
|
||||
from synapse.storage import DataStore, DataStores, Storage
|
||||
from synapse.storage import Databases, DataStore, Storage
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
@ -280,7 +280,7 @@ class HomeServer(object):
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.start_time = int(self.get_clock().time())
|
||||
self.datastores = DataStores(self.DATASTORE_CLASS, self)
|
||||
self.datastores = Databases(self.DATASTORE_CLASS, self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def setup_master(self):
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.urls import ConsentURIBuilder
|
||||
@ -55,14 +56,11 @@ class ConsentServerNotices(object):
|
||||
|
||||
self._consent_uri_builder = ConsentURIBuilder(hs.config)
|
||||
|
||||
async def maybe_send_server_notice_to_user(self, user_id):
|
||||
async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
|
||||
"""Check if we need to send a notice to this user, and does so if so
|
||||
|
||||
Args:
|
||||
user_id (str): user to check
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
user_id: user to check
|
||||
"""
|
||||
if self._server_notice_content is None:
|
||||
# not enabled
|
||||
@ -105,7 +103,7 @@ class ConsentServerNotices(object):
|
||||
self._users_in_progress.remove(user_id)
|
||||
|
||||
|
||||
def copy_with_str_subst(x, substitutions):
|
||||
def copy_with_str_subst(x: Any, substitutions: Any) -> Any:
|
||||
"""Deep-copy a structure, carrying out string substitions on any strings
|
||||
|
||||
Args:
|
||||
@ -121,7 +119,7 @@ def copy_with_str_subst(x, substitutions):
|
||||
if isinstance(x, dict):
|
||||
return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
|
||||
if isinstance(x, (list, tuple)):
|
||||
return [copy_with_str_subst(y) for y in x]
|
||||
return [copy_with_str_subst(y, substitutions) for y in x]
|
||||
|
||||
# assume it's uninterested and can be shallow-copied.
|
||||
return x
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
from synapse.api.constants import (
|
||||
EventTypes,
|
||||
@ -52,7 +53,7 @@ class ResourceLimitsServerNotices(object):
|
||||
and not hs.config.hs_disabled
|
||||
)
|
||||
|
||||
async def maybe_send_server_notice_to_user(self, user_id):
|
||||
async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
|
||||
"""Check if we need to send a notice to this user, this will be true in
|
||||
two cases.
|
||||
1. The server has reached its limit does not reflect this
|
||||
@ -60,10 +61,7 @@ class ResourceLimitsServerNotices(object):
|
||||
actually the server is fine
|
||||
|
||||
Args:
|
||||
user_id (str): user to check
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
user_id: user to check
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
@ -115,19 +113,21 @@ class ResourceLimitsServerNotices(object):
|
||||
elif not currently_blocked and limit_msg:
|
||||
# Room is not notifying of a block, when it ought to be.
|
||||
await self._apply_limit_block_notification(
|
||||
user_id, limit_msg, limit_type
|
||||
user_id, limit_msg, limit_type # type: ignore
|
||||
)
|
||||
except SynapseError as e:
|
||||
logger.error("Error sending resource limits server notice: %s", e)
|
||||
|
||||
async def _remove_limit_block_notification(self, user_id, ref_events):
|
||||
async def _remove_limit_block_notification(
|
||||
self, user_id: str, ref_events: List[str]
|
||||
) -> None:
|
||||
"""Utility method to remove limit block notifications from the server
|
||||
notices room.
|
||||
|
||||
Args:
|
||||
user_id (str): user to notify
|
||||
ref_events (list[str]): The event_ids of pinned events that are unrelated to
|
||||
limit blocking and need to be preserved.
|
||||
user_id: user to notify
|
||||
ref_events: The event_ids of pinned events that are unrelated to
|
||||
limit blocking and need to be preserved.
|
||||
"""
|
||||
content = {"pinned": ref_events}
|
||||
await self._server_notices_manager.send_notice(
|
||||
@ -135,16 +135,16 @@ class ResourceLimitsServerNotices(object):
|
||||
)
|
||||
|
||||
async def _apply_limit_block_notification(
|
||||
self, user_id, event_body, event_limit_type
|
||||
):
|
||||
self, user_id: str, event_body: str, event_limit_type: str
|
||||
) -> None:
|
||||
"""Utility method to apply limit block notifications in the server
|
||||
notices room.
|
||||
|
||||
Args:
|
||||
user_id (str): user to notify
|
||||
event_body(str): The human readable text that describes the block.
|
||||
event_limit_type(str): Specifies the type of block e.g. monthly active user
|
||||
limit has been exceeded.
|
||||
user_id: user to notify
|
||||
event_body: The human readable text that describes the block.
|
||||
event_limit_type: Specifies the type of block e.g. monthly active user
|
||||
limit has been exceeded.
|
||||
"""
|
||||
content = {
|
||||
"body": event_body,
|
||||
@ -162,7 +162,7 @@ class ResourceLimitsServerNotices(object):
|
||||
user_id, content, EventTypes.Pinned, ""
|
||||
)
|
||||
|
||||
async def _check_and_set_tags(self, user_id, room_id):
|
||||
async def _check_and_set_tags(self, user_id: str, room_id: str) -> None:
|
||||
"""
|
||||
Since server notices rooms were originally not with tags,
|
||||
important to check that tags have been set correctly
|
||||
@ -182,17 +182,16 @@ class ResourceLimitsServerNotices(object):
|
||||
)
|
||||
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||
|
||||
async def _is_room_currently_blocked(self, room_id):
|
||||
async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]:
|
||||
"""
|
||||
Determines if the room is currently blocked
|
||||
|
||||
Args:
|
||||
room_id(str): The room id of the server notices room
|
||||
room_id: The room id of the server notices room
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[bool, List]]:
|
||||
bool: Is the room currently blocked
|
||||
list: The list of pinned events that are unrelated to limit blocking
|
||||
list: The list of pinned event IDs that are unrelated to limit blocking
|
||||
This list can be used as a convenience in the case where the block
|
||||
is to be lifted and the remaining pinned event references need to be
|
||||
preserved
|
||||
@ -207,7 +206,7 @@ class ResourceLimitsServerNotices(object):
|
||||
# The user has yet to join the server notices room
|
||||
pass
|
||||
|
||||
referenced_events = []
|
||||
referenced_events = [] # type: List[str]
|
||||
if pinned_state_event is not None:
|
||||
referenced_events = list(pinned_state_event.content.get("pinned", []))
|
||||
|
||||
|
@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
@ -50,20 +52,21 @@ class ServerNoticesManager(object):
|
||||
return self._config.server_notices_mxid is not None
|
||||
|
||||
async def send_notice(
|
||||
self, user_id, event_content, type=EventTypes.Message, state_key=None
|
||||
):
|
||||
self,
|
||||
user_id: str,
|
||||
event_content: dict,
|
||||
type: str = EventTypes.Message,
|
||||
state_key: Optional[bool] = None,
|
||||
) -> EventBase:
|
||||
"""Send a notice to the given user
|
||||
|
||||
Creates the server notices room, if none exists.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid of user to send event to.
|
||||
event_content (dict): content of event to send
|
||||
type(EventTypes): type of event
|
||||
is_state_event(bool): Is the event a state event
|
||||
|
||||
Returns:
|
||||
Deferred[FrozenEvent]
|
||||
user_id: mxid of user to send event to.
|
||||
event_content: content of event to send
|
||||
type: type of event
|
||||
is_state_event: Is the event a state event
|
||||
"""
|
||||
room_id = await self.get_or_create_notice_room_for_user(user_id)
|
||||
await self.maybe_invite_user_to_room(user_id, room_id)
|
||||
@ -89,17 +92,17 @@ class ServerNoticesManager(object):
|
||||
return event
|
||||
|
||||
@cached()
|
||||
async def get_or_create_notice_room_for_user(self, user_id):
|
||||
async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
|
||||
"""Get the room for notices for a given user
|
||||
|
||||
If we have not yet created a notice room for this user, create it, but don't
|
||||
invite the user to it.
|
||||
|
||||
Args:
|
||||
user_id (str): complete user id for the user we want a room for
|
||||
user_id: complete user id for the user we want a room for
|
||||
|
||||
Returns:
|
||||
str: room id of notice room.
|
||||
room id of notice room.
|
||||
"""
|
||||
if not self.is_enabled():
|
||||
raise Exception("Server notices not enabled")
|
||||
@ -163,7 +166,7 @@ class ServerNoticesManager(object):
|
||||
logger.info("Created server notices room %s for %s", room_id, user_id)
|
||||
return room_id
|
||||
|
||||
async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
|
||||
async def maybe_invite_user_to_room(self, user_id: str, room_id: str) -> None:
|
||||
"""Invite the given user to the given server room, unless the user has already
|
||||
joined or been invited to it.
|
||||
|
||||
|
@ -12,6 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable, Union
|
||||
|
||||
from synapse.server_notices.consent_server_notices import ConsentServerNotices
|
||||
from synapse.server_notices.resource_limits_server_notices import (
|
||||
ResourceLimitsServerNotices,
|
||||
@ -32,22 +34,22 @@ class ServerNoticesSender(object):
|
||||
self._server_notices = (
|
||||
ConsentServerNotices(hs),
|
||||
ResourceLimitsServerNotices(hs),
|
||||
)
|
||||
) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]]
|
||||
|
||||
async def on_user_syncing(self, user_id):
|
||||
async def on_user_syncing(self, user_id: str) -> None:
|
||||
"""Called when the user performs a sync operation.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid of user who synced
|
||||
user_id: mxid of user who synced
|
||||
"""
|
||||
for sn in self._server_notices:
|
||||
await sn.maybe_send_server_notice_to_user(user_id)
|
||||
|
||||
async def on_user_ip(self, user_id):
|
||||
async def on_user_ip(self, user_id: str) -> None:
|
||||
"""Called on the master when a worker process saw a client request.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid
|
||||
user_id: mxid
|
||||
"""
|
||||
# The synchrotrons use a stubbed version of ServerNoticesSender, so
|
||||
# we check for notices to send to the user in on_user_ip as well as
|
||||
|
@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
class WorkerServerNoticesSender(object):
|
||||
@ -24,24 +23,18 @@ class WorkerServerNoticesSender(object):
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
|
||||
def on_user_syncing(self, user_id):
|
||||
async def on_user_syncing(self, user_id: str) -> None:
|
||||
"""Called when the user performs a sync operation.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid of user who synced
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
user_id: mxid of user who synced
|
||||
"""
|
||||
return defer.succeed(None)
|
||||
return None
|
||||
|
||||
def on_user_ip(self, user_id):
|
||||
async def on_user_ip(self, user_id: str) -> None:
|
||||
"""Called on the master when a worker process saw a client request.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
user_id: mxid
|
||||
"""
|
||||
raise AssertionError("on_user_ip unexpectedly called on worker")
|
||||
|
@ -28,7 +28,7 @@ from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import StateMap
|
||||
from synapse.util import Clock
|
||||
|
@ -17,18 +17,19 @@
|
||||
"""
|
||||
The storage layer is split up into multiple parts to allow Synapse to run
|
||||
against different configurations of databases (e.g. single or multiple
|
||||
databases). The `Database` class represents a single physical database. The
|
||||
`data_stores` are classes that talk directly to a `Database` instance and have
|
||||
associated schemas, background updates, etc. On top of those there are classes
|
||||
that provide high level interfaces that combine calls to multiple `data_stores`.
|
||||
databases). The `DatabasePool` class represents connections to a single physical
|
||||
database. The `databases` are classes that talk directly to a `DatabasePool`
|
||||
instance and have associated schemas, background updates, etc. On top of those
|
||||
there are classes that provide high level interfaces that combine calls to
|
||||
multiple `databases`.
|
||||
|
||||
There are also schemas that get applied to every database, regardless of the
|
||||
data stores associated with them (e.g. the schema version tables), which are
|
||||
stored in `synapse.storage.schema`.
|
||||
"""
|
||||
|
||||
from synapse.storage.data_stores import DataStores
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
from synapse.storage.databases import Databases
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.persist_events import EventsPersistenceStorage
|
||||
from synapse.storage.purge_events import PurgeEventsStorage
|
||||
from synapse.storage.state import StateGroupStorage
|
||||
@ -40,7 +41,7 @@ class Storage(object):
|
||||
"""The high level interfaces for talking to various storage layers.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, stores: DataStores):
|
||||
def __init__(self, hs, stores: Databases):
|
||||
# We include the main data store here mainly so that we don't have to
|
||||
# rewrite all the existing code to split it into high vs low level
|
||||
# interfaces.
|
||||
|
@ -23,7 +23,7 @@ from canonicaljson import json
|
||||
|
||||
from synapse.storage.database import LoggingTransaction # noqa: F401
|
||||
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -37,11 +37,11 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
per data store (and not one per physical database).
|
||||
"""
|
||||
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self.database_engine = database.engine
|
||||
self.db = database
|
||||
self.db_pool = database
|
||||
self.rand = random.SystemRandom()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
|
@ -88,7 +88,7 @@ class BackgroundUpdater(object):
|
||||
|
||||
def __init__(self, hs, database):
|
||||
self._clock = hs.get_clock()
|
||||
self.db = database
|
||||
self.db_pool = database
|
||||
|
||||
# if a background update is currently running, its name.
|
||||
self._current_background_update = None # type: Optional[str]
|
||||
@ -139,7 +139,7 @@ class BackgroundUpdater(object):
|
||||
# otherwise, check if there are updates to be run. This is important,
|
||||
# as we may be running on a worker which doesn't perform the bg updates
|
||||
# itself, but still wants to wait for them to happen.
|
||||
updates = await self.db.simple_select_onecol(
|
||||
updates = await self.db_pool.simple_select_onecol(
|
||||
"background_updates",
|
||||
keyvalues=None,
|
||||
retcol="1",
|
||||
@ -160,7 +160,7 @@ class BackgroundUpdater(object):
|
||||
if update_name == self._current_background_update:
|
||||
return False
|
||||
|
||||
update_exists = await self.db.simple_select_one_onecol(
|
||||
update_exists = await self.db_pool.simple_select_one_onecol(
|
||||
"background_updates",
|
||||
keyvalues={"update_name": update_name},
|
||||
retcol="1",
|
||||
@ -189,10 +189,10 @@ class BackgroundUpdater(object):
|
||||
ORDER BY ordering, update_name
|
||||
"""
|
||||
)
|
||||
return self.db.cursor_to_dict(txn)
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
if not self._current_background_update:
|
||||
all_pending_updates = await self.db.runInteraction(
|
||||
all_pending_updates = await self.db_pool.runInteraction(
|
||||
"background_updates", get_background_updates_txn,
|
||||
)
|
||||
if not all_pending_updates:
|
||||
@ -243,7 +243,7 @@ class BackgroundUpdater(object):
|
||||
else:
|
||||
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
|
||||
|
||||
progress_json = await self.db.simple_select_one_onecol(
|
||||
progress_json = await self.db_pool.simple_select_one_onecol(
|
||||
"background_updates",
|
||||
keyvalues={"update_name": update_name},
|
||||
retcol="progress_json",
|
||||
@ -402,7 +402,7 @@ class BackgroundUpdater(object):
|
||||
logger.debug("[SQL] %s", sql)
|
||||
c.execute(sql)
|
||||
|
||||
if isinstance(self.db.engine, engines.PostgresEngine):
|
||||
if isinstance(self.db_pool.engine, engines.PostgresEngine):
|
||||
runner = create_index_psql
|
||||
elif psql_only:
|
||||
runner = None
|
||||
@ -413,7 +413,7 @@ class BackgroundUpdater(object):
|
||||
def updater(progress, batch_size):
|
||||
if runner is not None:
|
||||
logger.info("Adding index %s to %s", index_name, table)
|
||||
yield self.db.runWithConnection(runner)
|
||||
yield self.db_pool.runWithConnection(runner)
|
||||
yield self._end_background_update(update_name)
|
||||
return 1
|
||||
|
||||
@ -433,7 +433,7 @@ class BackgroundUpdater(object):
|
||||
% update_name
|
||||
)
|
||||
self._current_background_update = None
|
||||
return self.db.simple_delete_one(
|
||||
return self.db_pool.simple_delete_one(
|
||||
"background_updates", keyvalues={"update_name": update_name}
|
||||
)
|
||||
|
||||
@ -445,7 +445,7 @@ class BackgroundUpdater(object):
|
||||
progress: The progress of the update.
|
||||
"""
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"background_update_progress",
|
||||
self._background_update_progress_txn,
|
||||
update_name,
|
||||
@ -463,7 +463,7 @@ class BackgroundUpdater(object):
|
||||
|
||||
progress_json = json.dumps(progress)
|
||||
|
||||
self.db.simple_update_one_txn(
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"background_updates",
|
||||
keyvalues={"update_name": update_name},
|
||||
|
@ -279,7 +279,7 @@ class PerformanceCounters(object):
|
||||
return top_n_counters
|
||||
|
||||
|
||||
class Database(object):
|
||||
class DatabasePool(object):
|
||||
"""Wraps a single physical database and connection pool.
|
||||
|
||||
A single database may be used by multiple data stores.
|
||||
|
@ -15,17 +15,17 @@
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.storage.data_stores.main.events import PersistEventsStore
|
||||
from synapse.storage.data_stores.state import StateGroupDataStore
|
||||
from synapse.storage.database import Database, make_conn
|
||||
from synapse.storage.database import DatabasePool, make_conn
|
||||
from synapse.storage.databases.main.events import PersistEventsStore
|
||||
from synapse.storage.databases.state import StateGroupDataStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataStores(object):
|
||||
"""The various data stores.
|
||||
class Databases(object):
|
||||
"""The various databases.
|
||||
|
||||
These are low level interfaces to physical databases.
|
||||
|
||||
@ -51,12 +51,12 @@ class DataStores(object):
|
||||
|
||||
engine.check_database(db_conn)
|
||||
prepare_database(
|
||||
db_conn, engine, hs.config, data_stores=database_config.data_stores,
|
||||
db_conn, engine, hs.config, databases=database_config.databases,
|
||||
)
|
||||
|
||||
database = Database(hs, database_config, engine)
|
||||
database = DatabasePool(hs, database_config, engine)
|
||||
|
||||
if "main" in database_config.data_stores:
|
||||
if "main" in database_config.databases:
|
||||
logger.info("Starting 'main' data store")
|
||||
|
||||
# Sanity check we don't try and configure the main store on
|
||||
@ -73,7 +73,7 @@ class DataStores(object):
|
||||
hs, database, self.main
|
||||
)
|
||||
|
||||
if "state" in database_config.data_stores:
|
||||
if "state" in database_config.databases:
|
||||
logger.info("Starting 'state' data store")
|
||||
|
||||
# Sanity check we don't try and configure the state store on
|
@ -21,7 +21,7 @@ import time
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import (
|
||||
IdGenerator,
|
||||
@ -119,7 +119,7 @@ class DataStore(
|
||||
CacheInvalidationWorkerStore,
|
||||
ServerMetricsStore,
|
||||
):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self.database_engine = database.engine
|
||||
@ -174,7 +174,7 @@ class DataStore(
|
||||
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
|
||||
presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"presence_stream",
|
||||
entity_column="user_id",
|
||||
@ -188,7 +188,7 @@ class DataStore(
|
||||
)
|
||||
|
||||
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
|
||||
device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
|
||||
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"device_inbox",
|
||||
entity_column="user_id",
|
||||
@ -203,7 +203,7 @@ class DataStore(
|
||||
)
|
||||
# The federation outbox and the local device inbox uses the same
|
||||
# stream_id generator.
|
||||
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
|
||||
device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"device_federation_outbox",
|
||||
entity_column="destination",
|
||||
@ -229,7 +229,7 @@ class DataStore(
|
||||
)
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"current_state_delta_stream",
|
||||
entity_column="room_id",
|
||||
@ -243,7 +243,7 @@ class DataStore(
|
||||
prefilled_cache=curr_state_delta_prefill,
|
||||
)
|
||||
|
||||
_group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
|
||||
_group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"local_group_updates",
|
||||
entity_column="user_id",
|
||||
@ -282,7 +282,7 @@ class DataStore(
|
||||
|
||||
txn = db_conn.cursor()
|
||||
txn.execute(sql, (PresenceState.OFFLINE,))
|
||||
rows = self.db.cursor_to_dict(txn)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
txn.close()
|
||||
|
||||
for row in rows:
|
||||
@ -295,7 +295,9 @@ class DataStore(
|
||||
Counts the number of users who used this homeserver in the last 24 hours.
|
||||
"""
|
||||
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
|
||||
return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
|
||||
return self.db_pool.runInteraction(
|
||||
"count_daily_users", self._count_users, yesterday
|
||||
)
|
||||
|
||||
def count_monthly_users(self):
|
||||
"""
|
||||
@ -305,7 +307,7 @@ class DataStore(
|
||||
amongst other things, includes a 3 day grace period before a user counts.
|
||||
"""
|
||||
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"count_monthly_users", self._count_users, thirty_days_ago
|
||||
)
|
||||
|
||||
@ -405,7 +407,7 @@ class DataStore(
|
||||
|
||||
return results
|
||||
|
||||
return self.db.runInteraction("count_r30_users", _count_r30_users)
|
||||
return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
|
||||
|
||||
def _get_start_of_day(self):
|
||||
"""
|
||||
@ -470,7 +472,7 @@ class DataStore(
|
||||
# frequently
|
||||
self._last_user_visit_update = now
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"generate_user_daily_visits", _generate_user_daily_visits
|
||||
)
|
||||
|
||||
@ -481,7 +483,7 @@ class DataStore(
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
return self.db.simple_select_list(
|
||||
return self.db_pool.simple_select_list(
|
||||
table="users",
|
||||
keyvalues={},
|
||||
retcols=[
|
||||
@ -543,10 +545,12 @@ class DataStore(
|
||||
where_clause
|
||||
)
|
||||
txn.execute(sql, args)
|
||||
users = self.db.cursor_to_dict(txn)
|
||||
users = self.db_pool.cursor_to_dict(txn)
|
||||
return users, count
|
||||
|
||||
return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
|
||||
return self.db_pool.runInteraction(
|
||||
"get_users_paginate_txn", get_users_paginate_txn
|
||||
)
|
||||
|
||||
def search_users(self, term):
|
||||
"""Function to search users list for one or more users with
|
||||
@ -558,7 +562,7 @@ class DataStore(
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
return self.db.simple_search_list(
|
||||
return self.db_pool.simple_search_list(
|
||||
table="users",
|
||||
term=term,
|
||||
col="name",
|
@ -23,7 +23,7 @@ from canonicaljson import json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
@ -40,7 +40,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
# the abstract methods being implemented.
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
account_max = self.get_max_account_data_stream_id()
|
||||
self._account_data_stream_cache = StreamChangeCache(
|
||||
"AccountDataAndTagsChangeCache", account_max
|
||||
@ -69,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
|
||||
def get_account_data_for_user_txn(txn):
|
||||
rows = self.db.simple_select_list_txn(
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
"account_data",
|
||||
{"user_id": user_id},
|
||||
@ -80,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
||||
}
|
||||
|
||||
rows = self.db.simple_select_list_txn(
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
"room_account_data",
|
||||
{"user_id": user_id},
|
||||
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
|
||||
return global_account_data, by_room
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"get_account_data_for_user", get_account_data_for_user_txn
|
||||
)
|
||||
|
||||
@ -104,7 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
Deferred: A dict
|
||||
"""
|
||||
result = yield self.db.simple_select_one_onecol(
|
||||
result = yield self.db_pool.simple_select_one_onecol(
|
||||
table="account_data",
|
||||
keyvalues={"user_id": user_id, "account_data_type": data_type},
|
||||
retcol="content",
|
||||
@ -129,7 +129,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_txn(txn):
|
||||
rows = self.db.simple_select_list_txn(
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
"room_account_data",
|
||||
{"user_id": user_id, "room_id": room_id},
|
||||
@ -140,7 +140,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
||||
}
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"get_account_data_for_room", get_account_data_for_room_txn
|
||||
)
|
||||
|
||||
@ -158,7 +158,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_and_type_txn(txn):
|
||||
content_json = self.db.simple_select_one_onecol_txn(
|
||||
content_json = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="room_account_data",
|
||||
keyvalues={
|
||||
@ -172,7 +172,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
|
||||
return db_to_json(content_json) if content_json else None
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||
)
|
||||
|
||||
@ -202,7 +202,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return await self.db.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_updated_global_account_data", get_updated_global_account_data_txn
|
||||
)
|
||||
|
||||
@ -232,7 +232,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return await self.db.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_updated_room_account_data", get_updated_room_account_data_txn
|
||||
)
|
||||
|
||||
@ -277,7 +277,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
if not changed:
|
||||
return defer.succeed(({}, {}))
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
||||
)
|
||||
|
||||
@ -295,7 +295,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
|
||||
|
||||
class AccountDataStore(AccountDataWorkerStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
self._account_data_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"account_data_max_stream_id",
|
||||
@ -333,7 +333,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
# no need to lock here as room_account_data has a unique constraint
|
||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||
# retry if there is a conflict.
|
||||
yield self.db.simple_upsert(
|
||||
yield self.db_pool.simple_upsert(
|
||||
desc="add_room_account_data",
|
||||
table="room_account_data",
|
||||
keyvalues={
|
||||
@ -379,7 +379,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
# no need to lock here as account_data has a unique constraint on
|
||||
# (user_id, account_data_type) so simple_upsert will retry if
|
||||
# there is a conflict.
|
||||
yield self.db.simple_upsert(
|
||||
yield self.db_pool.simple_upsert(
|
||||
desc="add_user_account_data",
|
||||
table="account_data",
|
||||
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
|
||||
@ -427,4 +427,4 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
)
|
||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||
|
||||
return self.db.runInteraction("update_account_data_max_stream_id", _update)
|
||||
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
|
@ -23,8 +23,8 @@ from twisted.internet import defer
|
||||
from synapse.appservice import AppServiceTransaction
|
||||
from synapse.config.appservice import load_appservices
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -49,7 +49,7 @@ def _make_exclusive_regex(services_cache):
|
||||
|
||||
|
||||
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
self.services_cache = load_appservices(
|
||||
hs.hostname, hs.config.app_service_config_files
|
||||
)
|
||||
@ -134,7 +134,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
A Deferred which resolves to a list of ApplicationServices, which
|
||||
may be empty.
|
||||
"""
|
||||
results = yield self.db.simple_select_list(
|
||||
results = yield self.db_pool.simple_select_list(
|
||||
"application_services_state", {"state": state}, ["as_id"]
|
||||
)
|
||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||
@ -156,7 +156,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
Returns:
|
||||
A Deferred which resolves to ApplicationServiceState.
|
||||
"""
|
||||
result = yield self.db.simple_select_one(
|
||||
result = yield self.db_pool.simple_select_one(
|
||||
"application_services_state",
|
||||
{"as_id": service.id},
|
||||
["state"],
|
||||
@ -176,7 +176,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
Returns:
|
||||
A Deferred which resolves when the state was set successfully.
|
||||
"""
|
||||
return self.db.simple_upsert(
|
||||
return self.db_pool.simple_upsert(
|
||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
)
|
||||
|
||||
@ -217,7 +217,9 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
)
|
||||
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
|
||||
|
||||
return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
|
||||
return self.db_pool.runInteraction(
|
||||
"create_appservice_txn", _create_appservice_txn
|
||||
)
|
||||
|
||||
def complete_appservice_txn(self, txn_id, service):
|
||||
"""Completes an application service transaction.
|
||||
@ -250,7 +252,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
)
|
||||
|
||||
# Set current txn_id for AS to 'txn_id'
|
||||
self.db.simple_upsert_txn(
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
"application_services_state",
|
||||
{"as_id": service.id},
|
||||
@ -258,13 +260,13 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
)
|
||||
|
||||
# Delete txn
|
||||
self.db.simple_delete_txn(
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
"application_services_txns",
|
||||
{"txn_id": txn_id, "as_id": service.id},
|
||||
)
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"complete_appservice_txn", _complete_appservice_txn
|
||||
)
|
||||
|
||||
@ -288,7 +290,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
" ORDER BY txn_id ASC LIMIT 1",
|
||||
(service.id,),
|
||||
)
|
||||
rows = self.db.cursor_to_dict(txn)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
@ -296,7 +298,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
|
||||
return entry
|
||||
|
||||
entry = yield self.db.runInteraction(
|
||||
entry = yield self.db_pool.runInteraction(
|
||||
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
|
||||
)
|
||||
|
||||
@ -326,7 +328,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
)
|
||||
|
||||
return self.db.runInteraction(
|
||||
return self.db_pool.runInteraction(
|
||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
|
||||
@ -355,7 +357,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.db.runInteraction(
|
||||
upper_bound, event_ids = yield self.db_pool.runInteraction(
|
||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn
|
||||
)
|
||||
|
@ -26,7 +26,7 @@ from synapse.replication.tcp.streams.events import (
|
||||
EventsStreamEventRow,
|
||||
)
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
@ -39,7 +39,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
|
||||
|
||||
|
||||
class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
@ -92,7 +92,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return await self.db.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
@ -203,7 +203,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
return
|
||||
|
||||
cache_func.invalidate(keys)
|
||||
await self.db.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"invalidate_cache_and_stream",
|
||||
self._send_invalidation_to_replication,
|
||||
cache_func.__name__,
|
||||
@ -288,7 +288,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
if keys is not None:
|
||||
keys = list(keys)
|
||||
|
||||
self.db.simple_insert_txn(
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="cache_invalidation_stream_by_instance",
|
||||
values={
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user