mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-25 11:09:21 -05:00
Merge branch 'develop' into rav/module_api_extensions
This commit is contained in:
commit
107f256cd8
@ -1,6 +1,9 @@
|
||||
Synapse 1.8.0 (2020-01-09)
|
||||
==========================
|
||||
|
||||
**WARNING**: As of this release Synapse will refuse to start if the `log_file` config option is specified. Support for the option was removed in v1.3.0.
|
||||
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
|
@ -133,6 +133,11 @@ sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
|
||||
sudo yum groupinstall "Development Tools"
|
||||
```
|
||||
|
||||
Note that Synapse does not support versions of SQLite before 3.11, and CentOS 7
|
||||
uses SQLite 3.7. You may be able to work around this by installing a more
|
||||
recent SQLite version, but it is recommended that you instead use a Postgres
|
||||
database: see [docs/postgres.md](docs/postgres.md).
|
||||
|
||||
#### macOS
|
||||
|
||||
Installing prerequisites on macOS:
|
||||
|
@ -75,6 +75,15 @@ for example:
|
||||
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||
|
||||
|
||||
Upgrading to v1.8.0
|
||||
===================
|
||||
|
||||
Specifying a ``log_file`` config option will now cause Synapse to refuse to
|
||||
start, and should be replaced by with the ``log_config`` option. Support for
|
||||
the ``log_file`` option was removed in v1.3.0 and has since had no effect.
|
||||
|
||||
|
||||
Upgrading to v1.7.0
|
||||
===================
|
||||
|
||||
|
1
changelog.d/6655.misc
Normal file
1
changelog.d/6655.misc
Normal file
@ -0,0 +1 @@
|
||||
Add `local_current_membership` table for tracking local user membership state in rooms.
|
1
changelog.d/6667.misc
Normal file
1
changelog.d/6667.misc
Normal file
@ -0,0 +1 @@
|
||||
Fixup `synapse.replication` to pass mypy checks.
|
1
changelog.d/6675.removal
Normal file
1
changelog.d/6675.removal
Normal file
@ -0,0 +1 @@
|
||||
Synapse no longer supports versions of SQLite before 3.11, and will refuse to start when configured to use an older version. Administrators are recommended to migrate their database to Postgres (see instructions [here](docs/postgres.md)).
|
1
changelog.d/6681.feature
Normal file
1
changelog.d/6681.feature
Normal file
@ -0,0 +1 @@
|
||||
Add new quarantine media admin APIs to quarantine by media ID or by user who uploaded the media.
|
2
changelog.d/6682.bugfix
Normal file
2
changelog.d/6682.bugfix
Normal file
@ -0,0 +1,2 @@
|
||||
Fix "CRITICAL" errors being logged when a request is received for a uri containing non-ascii characters.
|
||||
|
1
changelog.d/6686.misc
Normal file
1
changelog.d/6686.misc
Normal file
@ -0,0 +1 @@
|
||||
Allow additional_resources to implement IResource directly.
|
1
changelog.d/6687.misc
Normal file
1
changelog.d/6687.misc
Normal file
@ -0,0 +1 @@
|
||||
Allow REST endpoint implementations to raise a RedirectException, which will redirect the user's browser to a given location.
|
1
changelog.d/6689.misc
Normal file
1
changelog.d/6689.misc
Normal file
@ -0,0 +1 @@
|
||||
Updates to the SAML mapping provider API.
|
1
changelog.d/6690.bugfix
Normal file
1
changelog.d/6690.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a bug where we would assign a numeric userid if somebody tried registering with an empty username.
|
1
changelog.d/6691.misc
Normal file
1
changelog.d/6691.misc
Normal file
@ -0,0 +1 @@
|
||||
Remove redundant RegistrationError class.
|
1
changelog.d/6697.misc
Normal file
1
changelog.d/6697.misc
Normal file
@ -0,0 +1 @@
|
||||
Don't block processing of incoming EDUs behind processing PDUs in the same transaction.
|
1
changelog.d/6698.doc
Normal file
1
changelog.d/6698.doc
Normal file
@ -0,0 +1 @@
|
||||
Add more endpoints to the documentation for Synapse workers.
|
@ -22,19 +22,81 @@ It returns a JSON body like the following:
|
||||
}
|
||||
```
|
||||
|
||||
# Quarantine media in a room
|
||||
|
||||
This API 'quarantines' all the media in a room.
|
||||
|
||||
The API is:
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/quarantine_media/<room_id>
|
||||
|
||||
{}
|
||||
```
|
||||
# Quarantine media
|
||||
|
||||
Quarantining media means that it is marked as inaccessible by users. It applies
|
||||
to any local media, and any locally-cached copies of remote media.
|
||||
|
||||
The media file itself (and any thumbnails) is not deleted from the server.
|
||||
|
||||
## Quarantining media by ID
|
||||
|
||||
This API quarantines a single piece of local or remote media.
|
||||
|
||||
Request:
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/media/quarantine/<server_name>/<media_id>
|
||||
|
||||
{}
|
||||
```
|
||||
|
||||
Where `server_name` is in the form of `example.org`, and `media_id` is in the
|
||||
form of `abcdefg12345...`.
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
{}
|
||||
```
|
||||
|
||||
## Quarantining media in a room
|
||||
|
||||
This API quarantines all local and remote media in a room.
|
||||
|
||||
Request:
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/room/<room_id>/media/quarantine
|
||||
|
||||
{}
|
||||
```
|
||||
|
||||
Where `room_id` is in the form of `!roomid12345:example.org`.
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
{
|
||||
"num_quarantined": 10 # The number of media items successfully quarantined
|
||||
}
|
||||
```
|
||||
|
||||
Note that there is a legacy endpoint, `POST
|
||||
/_synapse/admin/v1/quarantine_media/<room_id >`, that operates the same.
|
||||
However, it is deprecated and may be removed in a future release.
|
||||
|
||||
## Quarantining all media of a user
|
||||
|
||||
This API quarantines all *local* media that a *local* user has uploaded. That is to say, if
|
||||
you would like to quarantine media uploaded by a user on a remote homeserver, you should
|
||||
instead use one of the other APIs.
|
||||
|
||||
Request:
|
||||
|
||||
```
|
||||
POST /_synapse/admin/v1/user/<user_id>/media/quarantine
|
||||
|
||||
{}
|
||||
```
|
||||
|
||||
Where `user_id` is in the form of `@bob:example.org`.
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
{
|
||||
"num_quarantined": 10 # The number of media items successfully quarantined
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -168,8 +168,11 @@ endpoints matching the following regular expressions:
|
||||
^/_matrix/federation/v1/make_join/
|
||||
^/_matrix/federation/v1/make_leave/
|
||||
^/_matrix/federation/v1/send_join/
|
||||
^/_matrix/federation/v2/send_join/
|
||||
^/_matrix/federation/v1/send_leave/
|
||||
^/_matrix/federation/v2/send_leave/
|
||||
^/_matrix/federation/v1/invite/
|
||||
^/_matrix/federation/v2/invite/
|
||||
^/_matrix/federation/v1/query_auth/
|
||||
^/_matrix/federation/v1/event_auth/
|
||||
^/_matrix/federation/v1/exchange_third_party_invite/
|
||||
@ -199,7 +202,9 @@ Handles the media repository. It can handle all endpoints starting with:
|
||||
... and the following regular expressions matching media-specific administration APIs:
|
||||
|
||||
^/_synapse/admin/v1/purge_media_cache$
|
||||
^/_synapse/admin/v1/room/.*/media$
|
||||
^/_synapse/admin/v1/room/.*/media.*$
|
||||
^/_synapse/admin/v1/user/.*/media.*$
|
||||
^/_synapse/admin/v1/media/.*$
|
||||
^/_synapse/admin/v1/quarantine_media/.*$
|
||||
|
||||
You should also set `enable_media_repo: False` in the shared configuration
|
||||
@ -288,6 +293,7 @@ file. For example:
|
||||
Handles some event creation. It can handle REST endpoints matching:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/join/
|
||||
^/_matrix/client/(api/v1|r0|unstable)/profile/
|
||||
|
@ -447,20 +447,15 @@ class Porter(object):
|
||||
else:
|
||||
return
|
||||
|
||||
def setup_db(self, db_config: DatabaseConnectionConfig, engine):
|
||||
db_conn = make_conn(db_config, engine)
|
||||
prepare_database(db_conn, engine, config=None)
|
||||
|
||||
db_conn.commit()
|
||||
|
||||
return db_conn
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def build_db_store(self, db_config: DatabaseConnectionConfig):
|
||||
def build_db_store(
|
||||
self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
|
||||
):
|
||||
"""Builds and returns a database store using the provided configuration.
|
||||
|
||||
Args:
|
||||
config: The database configuration
|
||||
db_config: The database configuration
|
||||
allow_outdated_version: True to suppress errors about the database server
|
||||
version being too old to run a complete synapse
|
||||
|
||||
Returns:
|
||||
The built Store object.
|
||||
@ -468,16 +463,16 @@ class Porter(object):
|
||||
self.progress.set_state("Preparing %s" % db_config.config["name"])
|
||||
|
||||
engine = create_engine(db_config.config)
|
||||
conn = self.setup_db(db_config, engine)
|
||||
|
||||
hs = MockHomeserver(self.hs_config)
|
||||
|
||||
store = Store(Database(hs, db_config, engine), conn, hs)
|
||||
|
||||
yield store.db.runInteraction(
|
||||
"%s_engine.check_database" % db_config.config["name"],
|
||||
engine.check_database,
|
||||
)
|
||||
with make_conn(db_config, engine) as db_conn:
|
||||
engine.check_database(
|
||||
db_conn, allow_outdated_version=allow_outdated_version
|
||||
)
|
||||
prepare_database(db_conn, engine, config=self.hs_config)
|
||||
store = Store(Database(hs, db_config, engine), db_conn, hs)
|
||||
db_conn.commit()
|
||||
|
||||
return store
|
||||
|
||||
@ -502,8 +497,10 @@ class Porter(object):
|
||||
@defer.inlineCallbacks
|
||||
def run(self):
|
||||
try:
|
||||
self.sqlite_store = yield self.build_db_store(
|
||||
DatabaseConnectionConfig("master-sqlite", self.sqlite_config)
|
||||
# we allow people to port away from outdated versions of sqlite.
|
||||
self.sqlite_store = self.build_db_store(
|
||||
DatabaseConnectionConfig("master-sqlite", self.sqlite_config),
|
||||
allow_outdated_version=True,
|
||||
)
|
||||
|
||||
# Check if all background updates are done, abort if not.
|
||||
@ -518,7 +515,7 @@ class Porter(object):
|
||||
)
|
||||
defer.returnValue(None)
|
||||
|
||||
self.postgres_store = yield self.build_db_store(
|
||||
self.postgres_store = self.build_db_store(
|
||||
self.hs_config.get_single_database()
|
||||
)
|
||||
|
||||
|
@ -17,13 +17,15 @@
|
||||
"""Contains exceptions and error codes."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
from six import iteritems
|
||||
from six.moves import http_client
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.web import http
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError):
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class RedirectException(CodeMessageException):
|
||||
"""A pseudo-error indicating that we want to redirect the client to a different
|
||||
location
|
||||
|
||||
Attributes:
|
||||
cookies: a list of set-cookies values to add to the response. For example:
|
||||
b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
|
||||
"""
|
||||
|
||||
def __init__(self, location: bytes, http_code: int = http.FOUND):
|
||||
"""
|
||||
|
||||
Args:
|
||||
location: the URI to redirect to
|
||||
http_code: the HTTP response code
|
||||
"""
|
||||
msg = "Redirect to %s" % (location.decode("utf-8"),)
|
||||
super().__init__(code=http_code, msg=msg)
|
||||
self.location = location
|
||||
|
||||
self.cookies = [] # type: List[bytes]
|
||||
|
||||
|
||||
class SynapseError(CodeMessageException):
|
||||
"""A base exception type for matrix errors which have an errcode and error
|
||||
message (as well as an HTTP status code).
|
||||
@ -158,12 +183,6 @@ class UserDeactivatedError(SynapseError):
|
||||
)
|
||||
|
||||
|
||||
class RegistrationError(SynapseError):
|
||||
"""An error raised when a registration event fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FederationDeniedError(SynapseError):
|
||||
"""An error raised when the server tries to federate with a server which
|
||||
is not on its federation whitelist.
|
||||
|
@ -31,7 +31,7 @@ from prometheus_client import Gauge
|
||||
from twisted.application import service
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||
from twisted.web.resource import EncodingResourceWrapper, IResource, NoResource
|
||||
from twisted.web.server import GzipEncoderFactory
|
||||
from twisted.web.static import File
|
||||
|
||||
@ -109,7 +109,16 @@ class SynapseHomeServer(HomeServer):
|
||||
for path, resmodule in additional_resources.items():
|
||||
handler_cls, config = load_module(resmodule)
|
||||
handler = handler_cls(config, module_api)
|
||||
resources[path] = AdditionalResource(self, handler.handle_request)
|
||||
if IResource.providedBy(handler):
|
||||
resource = handler
|
||||
elif hasattr(handler, "handle_request"):
|
||||
resource = AdditionalResource(self, handler.handle_request)
|
||||
else:
|
||||
raise ConfigError(
|
||||
"additional_resource %s does not implement a known interface"
|
||||
% (resmodule["module"],)
|
||||
)
|
||||
resources[path] = resource
|
||||
|
||||
# try to find something useful to redirect '/' to
|
||||
if WEB_CLIENT_PREFIX in resources:
|
||||
|
@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import six
|
||||
from six import iteritems
|
||||
@ -22,6 +23,7 @@ from six import iteritems
|
||||
from canonicaljson import json
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.abstract import isIPAddress
|
||||
from twisted.python import failure
|
||||
|
||||
@ -41,7 +43,11 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
|
||||
from synapse.federation.persistence import TransactionActions
|
||||
from synapse.federation.units import Edu, Transaction
|
||||
from synapse.http.endpoint import parse_server_name
|
||||
from synapse.logging.context import nested_logging_context
|
||||
from synapse.logging.context import (
|
||||
make_deferred_yieldable,
|
||||
nested_logging_context,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.replication.http.federation import (
|
||||
@ -49,7 +55,7 @@ from synapse.replication.http.federation import (
|
||||
ReplicationGetQueryRestServlet,
|
||||
)
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util import glob_to_regex
|
||||
from synapse.util import glob_to_regex, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
@ -160,6 +166,43 @@ class FederationServer(FederationBase):
|
||||
)
|
||||
return 400, response
|
||||
|
||||
# We process PDUs and EDUs in parallel. This is important as we don't
|
||||
# want to block things like to device messages from reaching clients
|
||||
# behind the potentially expensive handling of PDUs.
|
||||
pdu_results, _ = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
self._handle_pdus_in_txn, origin, transaction, request_time
|
||||
),
|
||||
run_in_background(self._handle_edus_in_txn, origin, transaction),
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
|
||||
response = {"pdus": pdu_results}
|
||||
|
||||
logger.debug("Returning: %s", str(response))
|
||||
|
||||
await self.transaction_actions.set_response(origin, transaction, 200, response)
|
||||
return 200, response
|
||||
|
||||
async def _handle_pdus_in_txn(
|
||||
self, origin: str, transaction: Transaction, request_time: int
|
||||
) -> Dict[str, dict]:
|
||||
"""Process the PDUs in a received transaction.
|
||||
|
||||
Args:
|
||||
origin: the server making the request
|
||||
transaction: incoming transaction
|
||||
request_time: timestamp that the HTTP request arrived at
|
||||
|
||||
Returns:
|
||||
A map from event ID of a processed PDU to any errors we should
|
||||
report back to the sending server.
|
||||
"""
|
||||
|
||||
received_pdus_counter.inc(len(transaction.pdus))
|
||||
|
||||
origin_host, _ = parse_server_name(origin)
|
||||
@ -250,20 +293,23 @@ class FederationServer(FederationBase):
|
||||
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
|
||||
)
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in (Edu(**x) for x in transaction.edus):
|
||||
await self.received_edu(origin, edu.edu_type, edu.content)
|
||||
return pdu_results
|
||||
|
||||
response = {"pdus": pdu_results}
|
||||
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
|
||||
"""Process the EDUs in a received transaction.
|
||||
"""
|
||||
|
||||
logger.debug("Returning: %s", str(response))
|
||||
async def _process_edu(edu_dict):
|
||||
received_edus_counter.inc()
|
||||
|
||||
await self.transaction_actions.set_response(origin, transaction, 200, response)
|
||||
return 200, response
|
||||
edu = Edu(**edu_dict)
|
||||
await self.registry.on_edu(edu.edu_type, origin, edu.content)
|
||||
|
||||
async def received_edu(self, origin, edu_type, content):
|
||||
received_edus_counter.inc()
|
||||
await self.registry.on_edu(edu_type, origin, content)
|
||||
await concurrently_execute(
|
||||
_process_edu,
|
||||
getattr(transaction, "edus", []),
|
||||
TRANSACTION_CONCURRENCY_LIMIT,
|
||||
)
|
||||
|
||||
async def on_context_state_request(self, origin, room_id, event_id):
|
||||
origin_host, _ = parse_server_name(origin)
|
||||
|
@ -134,7 +134,7 @@ class AdminHandler(BaseHandler):
|
||||
The returned value is that returned by `writer.finished()`.
|
||||
"""
|
||||
# Get all rooms the user is in or has been in
|
||||
rooms = await self.store.get_rooms_for_user_where_membership_is(
|
||||
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
|
||||
user_id,
|
||||
membership_list=(
|
||||
Membership.JOIN,
|
||||
|
@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
|
||||
user_id (str): The user ID to reject pending invites for.
|
||||
"""
|
||||
user = UserID.from_string(user_id)
|
||||
pending_invites = await self.store.get_invited_rooms_for_user(user_id)
|
||||
pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
|
||||
|
||||
for room in pending_invites:
|
||||
try:
|
||||
|
@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
|
||||
if include_archived:
|
||||
memberships.append(Membership.LEAVE)
|
||||
|
||||
room_list = await self.store.get_rooms_for_user_where_membership_is(
|
||||
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
|
||||
user_id=user_id, membership_list=memberships
|
||||
)
|
||||
|
||||
|
@ -20,13 +20,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
ConsentNotGivenError,
|
||||
RegistrationError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
|
||||
from synapse.config.server import is_threepid_reserved
|
||||
from synapse.http.servlet import assert_params_in_dict
|
||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
||||
@ -165,7 +159,7 @@ class RegistrationHandler(BaseHandler):
|
||||
Returns:
|
||||
Deferred[str]: user_id
|
||||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
SynapseError if there was a problem registering.
|
||||
"""
|
||||
yield self.check_registration_ratelimit(address)
|
||||
|
||||
@ -174,7 +168,7 @@ class RegistrationHandler(BaseHandler):
|
||||
if password:
|
||||
password_hash = yield self._auth_handler.hash(password)
|
||||
|
||||
if localpart:
|
||||
if localpart is not None:
|
||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
||||
|
||||
was_guest = guest_access_token is not None
|
||||
@ -182,7 +176,7 @@ class RegistrationHandler(BaseHandler):
|
||||
if not was_guest:
|
||||
try:
|
||||
int(localpart)
|
||||
raise RegistrationError(
|
||||
raise SynapseError(
|
||||
400, "Numeric user IDs are reserved for guest users."
|
||||
)
|
||||
except ValueError:
|
||||
|
@ -690,7 +690,7 @@ class RoomMemberHandler(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_inviter(self, user_id, room_id):
|
||||
invite = yield self.store.get_invite_for_user_in_room(
|
||||
invite = yield self.store.get_invite_for_local_user_in_room(
|
||||
user_id=user_id, room_id=room_id
|
||||
)
|
||||
if invite:
|
||||
|
@ -24,6 +24,7 @@ from saml2.client import Saml2Client
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest.client.v1.login import SSOAuthHandler
|
||||
from synapse.types import (
|
||||
UserID,
|
||||
@ -59,7 +60,8 @@ class SamlHandler:
|
||||
|
||||
# plugin to do custom mapping from saml response to mxid
|
||||
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
|
||||
hs.config.saml2_user_mapping_provider_config
|
||||
hs.config.saml2_user_mapping_provider_config,
|
||||
ModuleApi(hs, hs.get_auth_handler()),
|
||||
)
|
||||
|
||||
# identifier for the external_ids table
|
||||
@ -112,10 +114,10 @@ class SamlHandler:
|
||||
# the dict.
|
||||
self.expire_sessions()
|
||||
|
||||
user_id = await self._map_saml_response_to_user(resp_bytes)
|
||||
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
|
||||
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
|
||||
async def _map_saml_response_to_user(self, resp_bytes):
|
||||
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
|
||||
try:
|
||||
saml2_auth = self._saml_client.parse_authn_request_response(
|
||||
resp_bytes,
|
||||
@ -183,7 +185,7 @@ class SamlHandler:
|
||||
# Map saml response to user attributes using the configured mapping provider
|
||||
for i in range(1000):
|
||||
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
|
||||
saml2_auth, i
|
||||
saml2_auth, i, client_redirect_url=client_redirect_url,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@ -216,6 +218,8 @@ class SamlHandler:
|
||||
500, "Unable to generate a Matrix ID from the SAML response"
|
||||
)
|
||||
|
||||
logger.info("Mapped SAML user to local part %s", localpart)
|
||||
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=displayname
|
||||
)
|
||||
@ -265,17 +269,21 @@ class SamlConfig(object):
|
||||
class DefaultSamlMappingProvider(object):
|
||||
__version__ = "0.0.1"
|
||||
|
||||
def __init__(self, parsed_config: SamlConfig):
|
||||
def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
|
||||
"""The default SAML user mapping provider
|
||||
|
||||
Args:
|
||||
parsed_config: Module configuration
|
||||
module_api: module api proxy
|
||||
"""
|
||||
self._mxid_source_attribute = parsed_config.mxid_source_attribute
|
||||
self._mxid_mapper = parsed_config.mxid_mapper
|
||||
|
||||
def saml_response_to_user_attributes(
|
||||
self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
|
||||
self,
|
||||
saml_response: saml2.response.AuthnResponse,
|
||||
failures: int,
|
||||
client_redirect_url: str,
|
||||
) -> dict:
|
||||
"""Maps some text from a SAML response to attributes of a new user
|
||||
|
||||
@ -285,6 +293,8 @@ class DefaultSamlMappingProvider(object):
|
||||
failures: How many times a call to this function with this
|
||||
saml_response has resulted in a failure
|
||||
|
||||
client_redirect_url: where the client wants to redirect to
|
||||
|
||||
Returns:
|
||||
dict: A dict containing new user attributes. Possible keys:
|
||||
* mxid_localpart (str): Required. The localpart of the user's mxid
|
||||
|
@ -179,7 +179,7 @@ class SearchHandler(BaseHandler):
|
||||
search_filter = Filter(filter_dict)
|
||||
|
||||
# TODO: Search through left rooms too
|
||||
rooms = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
|
||||
user.to_string(),
|
||||
membership_list=[Membership.JOIN],
|
||||
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
|
||||
|
@ -1662,7 +1662,7 @@ class SyncHandler(object):
|
||||
Membership.BAN,
|
||||
)
|
||||
|
||||
room_list = await self.store.get_rooms_for_user_where_membership_is(
|
||||
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
|
||||
user_id=user_id, membership_list=membership_list
|
||||
)
|
||||
|
||||
|
@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import cgi
|
||||
import collections
|
||||
import html
|
||||
import http.client
|
||||
import logging
|
||||
import types
|
||||
@ -36,6 +36,7 @@ import synapse.metrics
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
RedirectException,
|
||||
SynapseError,
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
@ -153,14 +154,18 @@ def _return_html_error(f, request):
|
||||
|
||||
Args:
|
||||
f (twisted.python.failure.Failure):
|
||||
request (twisted.web.iweb.IRequest):
|
||||
request (twisted.web.server.Request):
|
||||
"""
|
||||
if f.check(CodeMessageException):
|
||||
cme = f.value
|
||||
code = cme.code
|
||||
msg = cme.msg
|
||||
|
||||
if isinstance(cme, SynapseError):
|
||||
if isinstance(cme, RedirectException):
|
||||
logger.info("%s redirect to %s", request, cme.location)
|
||||
request.setHeader(b"location", cme.location)
|
||||
request.cookies.extend(cme.cookies)
|
||||
elif isinstance(cme, SynapseError):
|
||||
logger.info("%s SynapseError: %s - %s", request, code, msg)
|
||||
else:
|
||||
logger.error(
|
||||
@ -178,7 +183,7 @@ def _return_html_error(f, request):
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
|
||||
body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
|
||||
body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
|
||||
request.setResponseCode(code)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%i" % (len(body),))
|
||||
|
@ -88,7 +88,7 @@ class SynapseRequest(Request):
|
||||
def get_redacted_uri(self):
|
||||
uri = self.uri
|
||||
if isinstance(uri, bytes):
|
||||
uri = self.uri.decode("ascii")
|
||||
uri = self.uri.decode("ascii", errors="replace")
|
||||
return redact_uri(uri)
|
||||
|
||||
def get_method(self):
|
||||
|
@ -571,6 +571,9 @@ def run_in_background(f, *args, **kwargs):
|
||||
yield or await on (for instance because you want to pass it to
|
||||
deferred.gatherResults()).
|
||||
|
||||
If f returns a Coroutine object, it will be wrapped into a Deferred (which will have
|
||||
the side effect of executing the coroutine).
|
||||
|
||||
Note that if you completely discard the result, you should make sure that
|
||||
`f` doesn't raise any deferred exceptions, otherwise a scary-looking
|
||||
CRITICAL error about an unhandled error will be logged without much
|
||||
|
@ -21,7 +21,7 @@ from synapse.storage import Storage
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_badge_count(store, user_id):
|
||||
invites = yield store.get_invited_rooms_for_user(user_id)
|
||||
invites = yield store.get_invited_rooms_for_local_user(user_id)
|
||||
joins = yield store.get_rooms_for_user(user_id)
|
||||
|
||||
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
|
||||
|
@ -16,6 +16,7 @@
|
||||
import abc
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from six import raise_from
|
||||
from six.moves import urllib
|
||||
@ -78,9 +79,8 @@ class ReplicationEndpoint(object):
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
NAME = abc.abstractproperty()
|
||||
PATH_ARGS = abc.abstractproperty()
|
||||
|
||||
NAME = abc.abstractproperty() # type: str # type: ignore
|
||||
PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
|
||||
METHOD = "POST"
|
||||
CACHE = True
|
||||
RETRY_ON_TIMEOUT = True
|
||||
@ -171,7 +171,7 @@ class ReplicationEndpoint(object):
|
||||
# have a good idea that the request has either succeeded or failed on
|
||||
# the master, and so whether we should clean up or not.
|
||||
while True:
|
||||
headers = {}
|
||||
headers = {} # type: Dict[bytes, List[bytes]]
|
||||
inject_active_span_byte_dict(headers, None, check_destination=False)
|
||||
try:
|
||||
result = yield request_func(uri, data, headers=headers)
|
||||
@ -207,7 +207,7 @@ class ReplicationEndpoint(object):
|
||||
method = self.METHOD
|
||||
|
||||
if self.CACHE:
|
||||
handler = self._cached_handler
|
||||
handler = self._cached_handler # type: ignore
|
||||
url_args.append("txn_id")
|
||||
|
||||
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import six
|
||||
|
||||
@ -41,7 +41,7 @@ class BaseSlavedStore(SQLBaseStore):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = SlavedIdTracker(
|
||||
db_conn, "cache_invalidation_stream", "stream_id"
|
||||
)
|
||||
) # type: Optional[SlavedIdTracker]
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
||||
@ -62,7 +62,8 @@ class BaseSlavedStore(SQLBaseStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "caches":
|
||||
self._cache_id_gen.advance(token)
|
||||
if self._cache_id_gen:
|
||||
self._cache_id_gen.advance(token)
|
||||
for row in rows:
|
||||
if row.cache_func == CURRENT_STATE_CACHE_NAME:
|
||||
room_id = row.keys[0]
|
||||
|
@ -152,7 +152,7 @@ class SlavedEventStore(
|
||||
|
||||
if etype == EventTypes.Member:
|
||||
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
|
||||
self.get_invited_rooms_for_user.invalidate((state_key,))
|
||||
self.get_invited_rooms_for_local_user.invalidate((state_key,))
|
||||
|
||||
if relates_to:
|
||||
self.get_relations_for_event.invalidate_many((relates_to,))
|
||||
|
@ -29,7 +29,7 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
|
||||
self.presence_stream_cache = StreamChangeCache(
|
||||
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
|
||||
)
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import (
|
||||
)
|
||||
|
||||
from .commands import (
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
RemovePusherCommand,
|
||||
@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
|
||||
|
||||
# Any pending commands to be sent once a new connection has been
|
||||
# established
|
||||
self.pending_commands = []
|
||||
self.pending_commands = [] # type: List[Command]
|
||||
|
||||
# Map from string -> deferred, to wake up when receiveing a SYNC with
|
||||
# the given string.
|
||||
# Used for tests.
|
||||
self.awaiting_syncs = {}
|
||||
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
|
||||
|
||||
# The factory used to create connections.
|
||||
self.factory = None
|
||||
self.factory = None # type: Optional[ReplicationClientFactory]
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
|
||||
# We don't reset the delay any earlier as otherwise if there is a
|
||||
# problem during start up we'll end up tight looping connecting to the
|
||||
# server.
|
||||
self.factory.resetDelay()
|
||||
if self.factory:
|
||||
self.factory.resetDelay()
|
||||
|
@ -20,15 +20,16 @@ allowed to be sent by which side.
|
||||
|
||||
import logging
|
||||
import platform
|
||||
from typing import Tuple, Type
|
||||
|
||||
if platform.python_implementation() == "PyPy":
|
||||
import json
|
||||
|
||||
_json_encoder = json.JSONEncoder()
|
||||
else:
|
||||
import simplejson as json
|
||||
import simplejson as json # type: ignore[no-redef] # noqa: F821
|
||||
|
||||
_json_encoder = json.JSONEncoder(namedtuple_as_object=False)
|
||||
_json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -44,7 +45,7 @@ class Command(object):
|
||||
The default implementation creates a command of form `<NAME> <data>`
|
||||
"""
|
||||
|
||||
NAME = None
|
||||
NAME = None # type: str
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
@ -386,25 +387,24 @@ class UserIpCommand(Command):
|
||||
)
|
||||
|
||||
|
||||
_COMMANDS = (
|
||||
ServerCommand,
|
||||
RdataCommand,
|
||||
PositionCommand,
|
||||
ErrorCommand,
|
||||
PingCommand,
|
||||
NameCommand,
|
||||
ReplicateCommand,
|
||||
UserSyncCommand,
|
||||
FederationAckCommand,
|
||||
SyncCommand,
|
||||
RemovePusherCommand,
|
||||
InvalidateCacheCommand,
|
||||
UserIpCommand,
|
||||
) # type: Tuple[Type[Command], ...]
|
||||
|
||||
# Map of command name to command type.
|
||||
COMMAND_MAP = {
|
||||
cmd.NAME: cmd
|
||||
for cmd in (
|
||||
ServerCommand,
|
||||
RdataCommand,
|
||||
PositionCommand,
|
||||
ErrorCommand,
|
||||
PingCommand,
|
||||
NameCommand,
|
||||
ReplicateCommand,
|
||||
UserSyncCommand,
|
||||
FederationAckCommand,
|
||||
SyncCommand,
|
||||
RemovePusherCommand,
|
||||
InvalidateCacheCommand,
|
||||
UserIpCommand,
|
||||
)
|
||||
}
|
||||
COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
|
||||
|
||||
# The commands the server is allowed to send
|
||||
VALID_SERVER_COMMANDS = (
|
||||
|
@ -53,6 +53,7 @@ import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List, Set, Tuple
|
||||
|
||||
from six import iteritems, iterkeys
|
||||
|
||||
@ -65,13 +66,11 @@ from twisted.python.failure import Failure
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from .commands import (
|
||||
from synapse.replication.tcp.commands import (
|
||||
COMMAND_MAP,
|
||||
VALID_CLIENT_COMMANDS,
|
||||
VALID_SERVER_COMMANDS,
|
||||
Command,
|
||||
ErrorCommand,
|
||||
NameCommand,
|
||||
PingCommand,
|
||||
@ -82,6 +81,10 @@ from .commands import (
|
||||
SyncCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.types import Collection
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from .streams import STREAMS_MAP
|
||||
|
||||
connection_close_counter = Counter(
|
||||
@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
delimiter = b"\n"
|
||||
|
||||
VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
|
||||
VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
|
||||
# Valid commands we expect to receive
|
||||
VALID_INBOUND_COMMANDS = [] # type: Collection[str]
|
||||
|
||||
# Valid commands we can send
|
||||
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
|
||||
|
||||
max_line_buffer = 10000
|
||||
|
||||
@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self.conn_id = random_string(5) # To dedupe in case of name clashes.
|
||||
|
||||
# List of pending commands to send once we've established the connection
|
||||
self.pending_commands = []
|
||||
self.pending_commands = [] # type: List[Command]
|
||||
|
||||
# The LoopingCall for sending pings.
|
||||
self._send_ping_loop = None
|
||||
|
||||
self.inbound_commands_counter = defaultdict(int)
|
||||
self.outbound_commands_counter = defaultdict(int)
|
||||
self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
|
||||
self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("[%s] Connection established", self.id())
|
||||
@ -409,14 +415,14 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
self.streamer = streamer
|
||||
|
||||
# The streams the client has subscribed to and is up to date with
|
||||
self.replication_streams = set()
|
||||
self.replication_streams = set() # type: Set[str]
|
||||
|
||||
# The streams the client is currently subscribing to.
|
||||
self.connecting_streams = set()
|
||||
self.connecting_streams = set() # type: Set[str]
|
||||
|
||||
# Map from stream name to list of updates to send once we've finished
|
||||
# subscribing the client to the stream.
|
||||
self.pending_rdata = {}
|
||||
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
|
||||
|
||||
def connectionMade(self):
|
||||
self.send_command(ServerCommand(self.server_name))
|
||||
@ -642,11 +648,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
# Set of stream names that have been subscribe to, but haven't yet
|
||||
# caught up with. This is used to track when the client has been fully
|
||||
# connected to the remote.
|
||||
self.streams_connecting = set()
|
||||
self.streams_connecting = set() # type: Set[str]
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {}
|
||||
self.pending_batches = {} # type: Dict[str, Any]
|
||||
|
||||
def connectionMade(self):
|
||||
self.send_command(NameCommand(self.client_name))
|
||||
@ -766,7 +772,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
|
||||
op = SIOCINQ
|
||||
else:
|
||||
op = SIOCOUTQ
|
||||
size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
|
||||
size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
|
||||
return size
|
||||
return 0
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from six import itervalues
|
||||
|
||||
@ -79,7 +80,7 @@ class ReplicationStreamer(object):
|
||||
self._replication_torture_level = hs.config.replication_torture_level
|
||||
|
||||
# Current connections.
|
||||
self.connections = []
|
||||
self.connections = [] # type: List[ServerReplicationStreamProtocol]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
|
@ -14,10 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@ -104,8 +104,9 @@ class Stream(object):
|
||||
time it was called up until the point `advance_current_token` was called.
|
||||
"""
|
||||
|
||||
NAME = None # The name of the stream
|
||||
ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
|
||||
NAME = None # type: str # The name of the stream
|
||||
# The type of the row. Used by the default impl of parse_row.
|
||||
ROW_TYPE = None # type: Any
|
||||
_LIMITED = True # Whether the update function takes a limit
|
||||
|
||||
@classmethod
|
||||
@ -231,8 +232,8 @@ class BackfillStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
self.current_token = store.get_current_backfill_token
|
||||
self.update_function = store.get_all_new_backfill_event_rows
|
||||
self.current_token = store.get_current_backfill_token # type: ignore
|
||||
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
|
||||
|
||||
super(BackfillStream, self).__init__(hs)
|
||||
|
||||
@ -246,8 +247,8 @@ class PresenceStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
presence_handler = hs.get_presence_handler()
|
||||
|
||||
self.current_token = store.get_current_presence_token
|
||||
self.update_function = presence_handler.get_all_presence_updates
|
||||
self.current_token = store.get_current_presence_token # type: ignore
|
||||
self.update_function = presence_handler.get_all_presence_updates # type: ignore
|
||||
|
||||
super(PresenceStream, self).__init__(hs)
|
||||
|
||||
@ -260,8 +261,8 @@ class TypingStream(Stream):
|
||||
def __init__(self, hs):
|
||||
typing_handler = hs.get_typing_handler()
|
||||
|
||||
self.current_token = typing_handler.get_current_token
|
||||
self.update_function = typing_handler.get_all_typing_updates
|
||||
self.current_token = typing_handler.get_current_token # type: ignore
|
||||
self.update_function = typing_handler.get_all_typing_updates # type: ignore
|
||||
|
||||
super(TypingStream, self).__init__(hs)
|
||||
|
||||
@ -273,8 +274,8 @@ class ReceiptsStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_receipt_stream_id
|
||||
self.update_function = store.get_all_updated_receipts
|
||||
self.current_token = store.get_max_receipt_stream_id # type: ignore
|
||||
self.update_function = store.get_all_updated_receipts # type: ignore
|
||||
|
||||
super(ReceiptsStream, self).__init__(hs)
|
||||
|
||||
@ -310,8 +311,8 @@ class PushersStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_pushers_stream_token
|
||||
self.update_function = store.get_all_updated_pushers_rows
|
||||
self.current_token = store.get_pushers_stream_token # type: ignore
|
||||
self.update_function = store.get_all_updated_pushers_rows # type: ignore
|
||||
|
||||
super(PushersStream, self).__init__(hs)
|
||||
|
||||
@ -327,8 +328,8 @@ class CachesStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_cache_stream_token
|
||||
self.update_function = store.get_all_updated_caches
|
||||
self.current_token = store.get_cache_stream_token # type: ignore
|
||||
self.update_function = store.get_all_updated_caches # type: ignore
|
||||
|
||||
super(CachesStream, self).__init__(hs)
|
||||
|
||||
@ -343,8 +344,8 @@ class PublicRoomsStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_current_public_room_stream_id
|
||||
self.update_function = store.get_all_new_public_rooms
|
||||
self.current_token = store.get_current_public_room_stream_id # type: ignore
|
||||
self.update_function = store.get_all_new_public_rooms # type: ignore
|
||||
|
||||
super(PublicRoomsStream, self).__init__(hs)
|
||||
|
||||
@ -360,8 +361,8 @@ class DeviceListsStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token
|
||||
self.update_function = store.get_all_device_list_changes_for_remotes
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
|
||||
|
||||
super(DeviceListsStream, self).__init__(hs)
|
||||
|
||||
@ -376,8 +377,8 @@ class ToDeviceStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_to_device_stream_token
|
||||
self.update_function = store.get_all_new_device_messages
|
||||
self.current_token = store.get_to_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_new_device_messages # type: ignore
|
||||
|
||||
super(ToDeviceStream, self).__init__(hs)
|
||||
|
||||
@ -392,8 +393,8 @@ class TagAccountDataStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_account_data_stream_id
|
||||
self.update_function = store.get_all_updated_tags
|
||||
self.current_token = store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = store.get_all_updated_tags # type: ignore
|
||||
|
||||
super(TagAccountDataStream, self).__init__(hs)
|
||||
|
||||
@ -408,7 +409,7 @@ class AccountDataStream(Stream):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.current_token = self.store.get_max_account_data_stream_id
|
||||
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
|
||||
|
||||
super(AccountDataStream, self).__init__(hs)
|
||||
|
||||
@ -434,8 +435,8 @@ class GroupServerStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_group_stream_token
|
||||
self.update_function = store.get_all_groups_changes
|
||||
self.current_token = store.get_group_stream_token # type: ignore
|
||||
self.update_function = store.get_all_groups_changes # type: ignore
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
@ -451,7 +452,7 @@ class UserSignatureStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token
|
||||
self.update_function = store.get_all_user_signature_changes_for_remotes
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
|
||||
|
||||
super(UserSignatureStream, self).__init__(hs)
|
||||
|
@ -13,7 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import heapq
|
||||
from typing import Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
@ -63,7 +65,8 @@ class BaseEventsStreamRow(object):
|
||||
Specifies how to identify, serialize and deserialize the different types.
|
||||
"""
|
||||
|
||||
TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
|
||||
# Unique string that ids the type. Must be overriden in sub classes.
|
||||
TypeId = None # type: str
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, data):
|
||||
@ -99,9 +102,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
|
||||
event_id = attr.ib() # str, optional
|
||||
|
||||
|
||||
TypeToRow = {
|
||||
Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
|
||||
}
|
||||
_EventRows = (
|
||||
EventsStreamEventRow,
|
||||
EventsStreamCurrentStateRow,
|
||||
) # type: Tuple[Type[BaseEventsStreamRow], ...]
|
||||
|
||||
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
|
||||
|
||||
|
||||
class EventsStream(Stream):
|
||||
@ -112,7 +118,7 @@ class EventsStream(Stream):
|
||||
|
||||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
self.current_token = self._store.get_current_events_token
|
||||
self.current_token = self._store.get_current_events_token # type: ignore
|
||||
|
||||
super(EventsStream, self).__init__(hs)
|
||||
|
||||
|
@ -37,7 +37,7 @@ class FederationStream(Stream):
|
||||
def __init__(self, hs):
|
||||
federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.current_token = federation_sender.get_current_token
|
||||
self.update_function = federation_sender.get_replication_rows
|
||||
self.current_token = federation_sender.get_current_token # type: ignore
|
||||
self.update_function = federation_sender.get_replication_rows # type: ignore
|
||||
|
||||
super(FederationStream, self).__init__(hs)
|
||||
|
@ -32,16 +32,24 @@ class QuarantineMediaInRoom(RestServlet):
|
||||
this server.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
|
||||
PATTERNS = (
|
||||
historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
|
||||
+
|
||||
# This path kept around for legacy reasons
|
||||
historical_admin_path_patterns("/quarantine_media/(?P<room_id>![^/]+)")
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request, room_id):
|
||||
async def on_POST(self, request, room_id: str):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
logging.info("Quarantining room: %s", room_id)
|
||||
|
||||
# Quarantine all media in this room
|
||||
num_quarantined = await self.store.quarantine_media_ids_in_room(
|
||||
room_id, requester.user.to_string()
|
||||
)
|
||||
@ -49,6 +57,60 @@ class QuarantineMediaInRoom(RestServlet):
|
||||
return 200, {"num_quarantined": num_quarantined}
|
||||
|
||||
|
||||
class QuarantineMediaByUser(RestServlet):
|
||||
"""Quarantines all local media by a given user so that no one can download it via
|
||||
this server.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns(
|
||||
"/user/(?P<user_id>[^/]+)/media/quarantine"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request, user_id: str):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
logging.info("Quarantining local media by user: %s", user_id)
|
||||
|
||||
# Quarantine all media this user has uploaded
|
||||
num_quarantined = await self.store.quarantine_media_ids_by_user(
|
||||
user_id, requester.user.to_string()
|
||||
)
|
||||
|
||||
return 200, {"num_quarantined": num_quarantined}
|
||||
|
||||
|
||||
class QuarantineMediaByID(RestServlet):
|
||||
"""Quarantines local or remote media by a given ID so that no one can download
|
||||
it via this server.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns(
|
||||
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request, server_name: str, media_id: str):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
|
||||
|
||||
# Quarantine this media id
|
||||
await self.store.quarantine_media_by_id(
|
||||
server_name, media_id, requester.user.to_string()
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
||||
class ListMediaInRoom(RestServlet):
|
||||
"""Lists all of the media in a given room.
|
||||
"""
|
||||
@ -94,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server):
|
||||
"""
|
||||
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||
QuarantineMediaInRoom(hs).register(http_server)
|
||||
QuarantineMediaByID(hs).register(http_server)
|
||||
QuarantineMediaByUser(hs).register(http_server)
|
||||
ListMediaInRoom(hs).register(http_server)
|
||||
|
@ -105,7 +105,7 @@ class ServerNoticesManager(object):
|
||||
|
||||
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
|
||||
|
||||
rooms = yield self._store.get_rooms_for_user_where_membership_is(
|
||||
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
|
||||
user_id, [Membership.INVITE, Membership.JOIN]
|
||||
)
|
||||
system_mxid = self._config.server_notices_mxid
|
||||
|
@ -47,7 +47,7 @@ class DataStores(object):
|
||||
with make_conn(database_config, engine) as db_conn:
|
||||
logger.info("Preparing database %r...", db_name)
|
||||
|
||||
engine.check_database(db_conn.cursor())
|
||||
engine.check_database(db_conn)
|
||||
prepare_database(
|
||||
db_conn, engine, hs.config, data_stores=database_config.data_stores,
|
||||
)
|
||||
|
@ -128,6 +128,7 @@ class EventsStore(
|
||||
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
|
||||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _read_forward_extremities(self):
|
||||
@ -547,6 +548,34 @@ class EventsStore(
|
||||
],
|
||||
)
|
||||
|
||||
# Note: Do we really want to delete rows here (that we do not
|
||||
# subsequently reinsert below)? While technically correct it means
|
||||
# we have no record of the fact the user *was* a member of the
|
||||
# room but got, say, state reset out of it.
|
||||
if to_delete or to_insert:
|
||||
txn.executemany(
|
||||
"DELETE FROM local_current_membership"
|
||||
" WHERE room_id = ? AND user_id = ?",
|
||||
(
|
||||
(room_id, state_key)
|
||||
for etype, state_key in itertools.chain(to_delete, to_insert)
|
||||
if etype == EventTypes.Member and self.is_mine_id(state_key)
|
||||
),
|
||||
)
|
||||
|
||||
if to_insert:
|
||||
txn.executemany(
|
||||
"""INSERT INTO local_current_membership
|
||||
(room_id, user_id, event_id, membership)
|
||||
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
||||
""",
|
||||
[
|
||||
(room_id, key[1], ev_id, ev_id)
|
||||
for key, ev_id in to_insert.items()
|
||||
if key[0] == EventTypes.Member and self.is_mine_id(key[1])
|
||||
],
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self._curr_state_delta_stream_cache.entity_has_changed,
|
||||
room_id,
|
||||
@ -1724,6 +1753,7 @@ class EventsStore(
|
||||
"local_invites",
|
||||
"room_account_data",
|
||||
"room_tags",
|
||||
"local_current_membership",
|
||||
):
|
||||
logger.info("[purge] removing %s from %s", room_id, table)
|
||||
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
|
||||
|
@ -18,7 +18,7 @@ import collections
|
||||
import logging
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from six import integer_types
|
||||
|
||||
@ -399,6 +399,8 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
the associated media
|
||||
"""
|
||||
|
||||
logger.info("Quarantining media in room: %s", room_id)
|
||||
|
||||
def _quarantine_media_in_room_txn(txn):
|
||||
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
||||
total_media_quarantined = 0
|
||||
@ -494,6 +496,118 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
return local_media_mxcs, remote_media_mxcs
|
||||
|
||||
def quarantine_media_by_id(
|
||||
self, server_name: str, media_id: str, quarantined_by: str,
|
||||
):
|
||||
"""quarantines a single local or remote media id
|
||||
|
||||
Args:
|
||||
server_name: The name of the server that holds this media
|
||||
media_id: The ID of the media to be quarantined
|
||||
quarantined_by: The user ID that initiated the quarantine request
|
||||
"""
|
||||
logger.info("Quarantining media: %s/%s", server_name, media_id)
|
||||
is_local = server_name == self.config.server_name
|
||||
|
||||
def _quarantine_media_by_id_txn(txn):
|
||||
local_mxcs = [media_id] if is_local else []
|
||||
remote_mxcs = [(server_name, media_id)] if not is_local else []
|
||||
|
||||
return self._quarantine_media_txn(
|
||||
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||
)
|
||||
|
||||
return self.db.runInteraction(
|
||||
"quarantine_media_by_user", _quarantine_media_by_id_txn
|
||||
)
|
||||
|
||||
def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
|
||||
"""quarantines all local media associated with a single user
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to quarantine media of
|
||||
quarantined_by: The ID of the user who made the quarantine request
|
||||
"""
|
||||
|
||||
def _quarantine_media_by_user_txn(txn):
|
||||
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
|
||||
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
|
||||
|
||||
return self.db.runInteraction(
|
||||
"quarantine_media_by_user", _quarantine_media_by_user_txn
|
||||
)
|
||||
|
||||
def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
|
||||
"""Retrieves local media IDs by a given user
|
||||
|
||||
Args:
|
||||
txn (cursor)
|
||||
user_id: The ID of the user to retrieve media IDs of
|
||||
|
||||
Returns:
|
||||
The local and remote media as a lists of tuples where the key is
|
||||
the hostname and the value is the media ID.
|
||||
"""
|
||||
# Local media
|
||||
sql = """
|
||||
SELECT media_id
|
||||
FROM local_media_repository
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
if filter_quarantined:
|
||||
sql += "AND quarantined_by IS NULL"
|
||||
txn.execute(sql, (user_id,))
|
||||
|
||||
local_media_ids = [row[0] for row in txn]
|
||||
|
||||
# TODO: Figure out all remote media a user has referenced in a message
|
||||
|
||||
return local_media_ids
|
||||
|
||||
def _quarantine_media_txn(
|
||||
self,
|
||||
txn,
|
||||
local_mxcs: List[str],
|
||||
remote_mxcs: List[Tuple[str, str]],
|
||||
quarantined_by: str,
|
||||
) -> int:
|
||||
"""Quarantine local and remote media items
|
||||
|
||||
Args:
|
||||
txn (cursor)
|
||||
local_mxcs: A list of local mxc URLs
|
||||
remote_mxcs: A list of (remote server, media id) tuples representing
|
||||
remote mxc URLs
|
||||
quarantined_by: The ID of the user who initiated the quarantine request
|
||||
Returns:
|
||||
The total number of media items quarantined
|
||||
"""
|
||||
total_media_quarantined = 0
|
||||
|
||||
# Update all the tables to set the quarantined_by flag
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE media_id = ?
|
||||
""",
|
||||
((quarantined_by, media_id) for media_id in local_mxcs),
|
||||
)
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE remote_media_cache
|
||||
SET quarantined_by = ?
|
||||
WHERE media_origin = ? AND media_id = ?
|
||||
""",
|
||||
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
|
||||
)
|
||||
|
||||
total_media_quarantined += len(local_mxcs)
|
||||
total_media_quarantined += len(remote_mxcs)
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
|
||||
class RoomBackgroundUpdateStore(SQLBaseStore):
|
||||
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
||||
|
@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
return {row[0]: row[1] for row in txn}
|
||||
|
||||
@cached()
|
||||
def get_invited_rooms_for_user(self, user_id):
|
||||
""" Get all the rooms the user is invited to
|
||||
def get_invited_rooms_for_local_user(self, user_id):
|
||||
""" Get all the rooms the *local* user is invited to
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
Returns:
|
||||
A deferred list of RoomsForUser.
|
||||
"""
|
||||
|
||||
return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
|
||||
return self.get_rooms_for_local_user_where_membership_is(
|
||||
user_id, [Membership.INVITE]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_invite_for_user_in_room(self, user_id, room_id):
|
||||
"""Gets the invite for the given user and room
|
||||
def get_invite_for_local_user_in_room(self, user_id, room_id):
|
||||
"""Gets the invite for the given *local* user and room
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
Deferred: Resolves to either a RoomsForUser or None if no invite was
|
||||
found.
|
||||
"""
|
||||
invites = yield self.get_invited_rooms_for_user(user_id)
|
||||
invites = yield self.get_invited_rooms_for_local_user(user_id)
|
||||
for invite in invites:
|
||||
if invite.room_id == room_id:
|
||||
return invite
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
|
||||
""" Get all the rooms for this user where the membership for this user
|
||||
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
|
||||
""" Get all the rooms for this *local* user where the membership for this user
|
||||
matches one in the membership list.
|
||||
|
||||
Filters out forgotten rooms.
|
||||
@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
return defer.succeed(None)
|
||||
|
||||
rooms = yield self.db.runInteraction(
|
||||
"get_rooms_for_user_where_membership_is",
|
||||
self._get_rooms_for_user_where_membership_is_txn,
|
||||
"get_rooms_for_local_user_where_membership_is",
|
||||
self._get_rooms_for_local_user_where_membership_is_txn,
|
||||
user_id,
|
||||
membership_list,
|
||||
)
|
||||
@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
|
||||
return [room for room in rooms if room.room_id not in forgotten_rooms]
|
||||
|
||||
def _get_rooms_for_user_where_membership_is_txn(
|
||||
def _get_rooms_for_local_user_where_membership_is_txn(
|
||||
self, txn, user_id, membership_list
|
||||
):
|
||||
|
||||
do_invite = Membership.INVITE in membership_list
|
||||
membership_list = [m for m in membership_list if m != Membership.INVITE]
|
||||
|
||||
results = []
|
||||
if membership_list:
|
||||
if self._current_state_events_membership_up_to_date:
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "c.membership", membership_list
|
||||
)
|
||||
sql = """
|
||||
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
c.type = 'm.room.member'
|
||||
AND state_key = ?
|
||||
AND %s
|
||||
""" % (
|
||||
clause,
|
||||
)
|
||||
else:
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "m.membership", membership_list
|
||||
)
|
||||
sql = """
|
||||
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN room_memberships AS m USING (room_id, event_id)
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
c.type = 'm.room.member'
|
||||
AND state_key = ?
|
||||
AND %s
|
||||
""" % (
|
||||
clause,
|
||||
)
|
||||
|
||||
txn.execute(sql, (user_id, *args))
|
||||
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
|
||||
|
||||
if do_invite:
|
||||
sql = (
|
||||
"SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
|
||||
" FROM local_invites as i"
|
||||
" INNER JOIN events as e USING (event_id)"
|
||||
" WHERE invitee = ? AND locally_rejected is NULL"
|
||||
" AND replaced_by is NULL"
|
||||
# Paranoia check.
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
raise Exception(
|
||||
"Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
|
||||
% (user_id,),
|
||||
)
|
||||
|
||||
txn.execute(sql, (user_id,))
|
||||
results.extend(
|
||||
RoomsForUser(
|
||||
room_id=r["room_id"],
|
||||
sender=r["inviter"],
|
||||
event_id=r["event_id"],
|
||||
stream_ordering=r["stream_ordering"],
|
||||
membership=Membership.INVITE,
|
||||
)
|
||||
for r in self.db.cursor_to_dict(txn)
|
||||
)
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "c.membership", membership_list
|
||||
)
|
||||
|
||||
sql = """
|
||||
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
|
||||
FROM local_current_membership AS c
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
user_id = ?
|
||||
AND %s
|
||||
""" % (
|
||||
clause,
|
||||
)
|
||||
|
||||
txn.execute(sql, (user_id, *args))
|
||||
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
|
||||
|
||||
return results
|
||||
|
||||
@cachedInlineCallbacks(max_entries=500000, iterable=True)
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
def get_rooms_for_user_with_stream_ordering(self, user_id):
|
||||
"""Returns a set of room_ids the user is currently joined to
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
participating in.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
the rooms the user is in currently, along with the stream ordering
|
||||
of the most recent join for that user and room.
|
||||
"""
|
||||
rooms = yield self.get_rooms_for_user_where_membership_is(
|
||||
user_id, membership_list=[Membership.JOIN]
|
||||
)
|
||||
return frozenset(
|
||||
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
|
||||
for r in rooms
|
||||
return self.db.runInteraction(
|
||||
"get_rooms_for_user_with_stream_ordering",
|
||||
self._get_rooms_for_user_with_stream_ordering_txn,
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
|
||||
# We use `current_state_events` here and not `local_current_membership`
|
||||
# as a) this gets called with remote users and b) this only gets called
|
||||
# for rooms the server is participating in.
|
||||
if self._current_state_events_membership_up_to_date:
|
||||
sql = """
|
||||
SELECT room_id, e.stream_ordering
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
c.type = 'm.room.member'
|
||||
AND state_key = ?
|
||||
AND c.membership = ?
|
||||
"""
|
||||
else:
|
||||
sql = """
|
||||
SELECT room_id, e.stream_ordering
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN room_memberships AS m USING (room_id, event_id)
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
c.type = 'm.room.member'
|
||||
AND state_key = ?
|
||||
AND m.membership = ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, Membership.JOIN))
|
||||
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
|
||||
|
||||
return results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_user(self, user_id, on_invalidate=None):
|
||||
"""Returns a set of room_ids the user is currently joined to
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
participating in.
|
||||
"""
|
||||
rooms = yield self.get_rooms_for_user_with_stream_ordering(
|
||||
user_id, on_invalidate=on_invalidate
|
||||
@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||
event.internal_metadata.stream_ordering,
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
|
||||
self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
|
||||
)
|
||||
|
||||
# We update the local_invites table only if the event is "current",
|
||||
@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||
),
|
||||
)
|
||||
|
||||
# We also update the `local_current_membership` table with
|
||||
# latest invite info. This will usually get updated by the
|
||||
# `current_state_events` handling, unless its an outlier.
|
||||
if event.internal_metadata.is_outlier():
|
||||
# This should only happen for out of band memberships, so
|
||||
# we add a paranoia check.
|
||||
assert event.internal_metadata.is_out_of_band_membership()
|
||||
|
||||
self.db.simple_upsert_txn(
|
||||
txn,
|
||||
table="local_current_membership",
|
||||
keyvalues={
|
||||
"room_id": event.room_id,
|
||||
"user_id": event.state_key,
|
||||
},
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"membership": event.membership,
|
||||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def locally_reject_invite(self, user_id, room_id):
|
||||
sql = (
|
||||
@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||
def f(txn, stream_ordering):
|
||||
txn.execute(sql, (stream_ordering, True, room_id, user_id))
|
||||
|
||||
# We also clear this entry from `local_current_membership`.
|
||||
# Ideally we'd point to a leave event, but we don't have one, so
|
||||
# nevermind.
|
||||
self.db.simple_delete_txn(
|
||||
txn,
|
||||
table="local_current_membership",
|
||||
keyvalues={"room_id": room_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
with self._stream_id_gen.get_next() as stream_ordering:
|
||||
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||
|
||||
|
@ -0,0 +1,97 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# We create a new table called `local_current_membership` that stores the latest
|
||||
# membership state of local users in rooms, which helps track leaves/bans/etc
|
||||
# even if the server has left the room (and so has deleted the room from
|
||||
# `current_state_events`). This will also include outstanding invites for local
|
||||
# users for rooms the server isn't in.
|
||||
#
|
||||
# If the server isn't and hasn't been in the room then it will only include
|
||||
# outsstanding invites, and not e.g. pre-emptive bans of local users.
|
||||
#
|
||||
# If the server later rejoins a room `local_current_membership` can simply be
|
||||
# replaced with the new current state of the room (which results in the
|
||||
# equivalent behaviour as if the server had remained in the room).
|
||||
|
||||
|
||||
def run_upgrade(cur, database_engine, config, *args, **kwargs):
|
||||
# We need to do the insert in `run_upgrade` section as we don't have access
|
||||
# to `config` in `run_create`.
|
||||
|
||||
# This upgrade may take a bit of time for large servers (e.g. one minute for
|
||||
# matrix.org) but means we avoid a lots of book keeping required to do it as
|
||||
# a background update.
|
||||
|
||||
# We check if the `current_state_events.membership` is up to date by
|
||||
# checking if the relevant background update has finished. If it has
|
||||
# finished we can avoid doing a join against `room_memberships`, which
|
||||
# speesd things up.
|
||||
cur.execute(
|
||||
"""SELECT 1 FROM background_updates
|
||||
WHERE update_name = 'current_state_events_membership'
|
||||
"""
|
||||
)
|
||||
current_state_membership_up_to_date = not bool(cur.fetchone())
|
||||
|
||||
# Cheekily drop and recreate indices, as that is faster.
|
||||
cur.execute("DROP INDEX local_current_membership_idx")
|
||||
cur.execute("DROP INDEX local_current_membership_room_idx")
|
||||
|
||||
if current_state_membership_up_to_date:
|
||||
sql = """
|
||||
INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
|
||||
SELECT c.room_id, state_key AS user_id, event_id, c.membership
|
||||
FROM current_state_events AS c
|
||||
WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ?
|
||||
"""
|
||||
else:
|
||||
# We can't rely on the membership column, so we need to join against
|
||||
# `room_memberships`.
|
||||
sql = """
|
||||
INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
|
||||
SELECT c.room_id, state_key AS user_id, event_id, r.membership
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN room_memberships AS r USING (event_id)
|
||||
WHERE type = 'm.room.member' and state_key like '%' || ?
|
||||
"""
|
||||
cur.execute(sql, (config.server_name,))
|
||||
|
||||
cur.execute(
|
||||
"CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
|
||||
)
|
||||
cur.execute(
|
||||
"CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
|
||||
)
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE local_current_membership (
|
||||
room_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
membership TEXT NOT NULL
|
||||
)"""
|
||||
)
|
||||
|
||||
cur.execute(
|
||||
"CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
|
||||
)
|
||||
cur.execute(
|
||||
"CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
|
||||
)
|
@ -32,20 +32,7 @@ class PostgresEngine(object):
|
||||
self.synchronous_commit = database_config.get("synchronous_commit", True)
|
||||
self._version = None # unknown as yet
|
||||
|
||||
def check_database(self, txn):
|
||||
txn.execute("SHOW SERVER_ENCODING")
|
||||
rows = txn.fetchall()
|
||||
if rows and rows[0][0] != "UTF8":
|
||||
raise IncorrectDatabaseSetup(
|
||||
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
|
||||
"See docs/postgres.rst for more information." % (rows[0][0],)
|
||||
)
|
||||
|
||||
def convert_param_style(self, sql):
|
||||
return sql.replace("?", "%s")
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
|
||||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
||||
# docs: The number is formed by converting the major, minor, and
|
||||
# revision numbers into two-decimal-digit numbers and appending them
|
||||
@ -53,9 +40,22 @@ class PostgresEngine(object):
|
||||
self._version = db_conn.server_version
|
||||
|
||||
# Are we on a supported PostgreSQL version?
|
||||
if self._version < 90500:
|
||||
if not allow_outdated_version and self._version < 90500:
|
||||
raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
|
||||
|
||||
with db_conn.cursor() as txn:
|
||||
txn.execute("SHOW SERVER_ENCODING")
|
||||
rows = txn.fetchall()
|
||||
if rows and rows[0][0] != "UTF8":
|
||||
raise IncorrectDatabaseSetup(
|
||||
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
|
||||
"See docs/postgres.rst for more information." % (rows[0][0],)
|
||||
)
|
||||
|
||||
def convert_param_style(self, sql):
|
||||
return sql.replace("?", "%s")
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
db_conn.set_isolation_level(
|
||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
)
|
||||
@ -119,8 +119,8 @@ class PostgresEngine(object):
|
||||
Returns:
|
||||
string
|
||||
"""
|
||||
# note that this is a bit of a hack because it relies on on_new_connection
|
||||
# having been called at least once. Still, that should be a safe bet here.
|
||||
# note that this is a bit of a hack because it relies on check_database
|
||||
# having been called. Still, that should be a safe bet here.
|
||||
numver = self._version
|
||||
assert numver is not None
|
||||
|
||||
|
@ -53,8 +53,11 @@ class Sqlite3Engine(object):
|
||||
"""
|
||||
return False
|
||||
|
||||
def check_database(self, txn):
|
||||
pass
|
||||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||
if not allow_outdated_version:
|
||||
version = self.module.sqlite_version_info
|
||||
if version < (3, 11, 0):
|
||||
raise RuntimeError("Synapse requires sqlite 3.11 or above.")
|
||||
|
||||
def convert_param_style(self, sql):
|
||||
return sql
|
||||
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 56
|
||||
SCHEMA_VERSION = 57
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
@ -269,8 +269,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||
one will be randomly generated.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Request must include user id")
|
||||
|
@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||
|
||||
def test_wait_for_sync_for_user_auth_blocking(self):
|
||||
|
||||
user_id1 = "@user1:server"
|
||||
user_id2 = "@user2:server"
|
||||
user_id1 = "@user1:test"
|
||||
user_id2 = "@user2:test"
|
||||
sync_config = self._generate_sync_config(user_id1)
|
||||
|
||||
self.reactor.advance(100) # So we get not 0 time
|
||||
|
@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
def test_invites(self):
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.check("get_invited_rooms_for_user", [USER_ID_2], [])
|
||||
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
|
||||
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
|
||||
|
||||
self.replicate()
|
||||
|
||||
self.check(
|
||||
"get_invited_rooms_for_user",
|
||||
"get_invited_rooms_for_local_user",
|
||||
[USER_ID_2],
|
||||
[
|
||||
RoomsForUser(
|
||||
|
@ -14,11 +14,17 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from binascii import unhexlify
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.rest.admin import VersionServlet
|
||||
from synapse.rest.client.v1 import events, login, room
|
||||
from synapse.rest.client.v2_alpha import groups
|
||||
@ -346,3 +352,338 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
|
||||
|
||||
test_purge_room.skip = "Disabled because it's currently broken"
|
||||
|
||||
|
||||
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
||||
"""Test /quarantine_media admin API.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
synapse.rest.admin.register_servlets_for_media_repo,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
|
||||
# Allow for uploading and downloading to/from the media repo
|
||||
self.media_repo = hs.get_media_repository_resource()
|
||||
self.download_resource = self.media_repo.children[b"download"]
|
||||
self.upload_resource = self.media_repo.children[b"upload"]
|
||||
self.image_data = unhexlify(
|
||||
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
||||
b"0000001f15c4890000000a49444154789c63000100000500010d"
|
||||
b"0a2db40000000049454e44ae426082"
|
||||
)
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.fetches = []
|
||||
|
||||
def get_file(destination, path, output_stream, args=None, max_size=None):
|
||||
"""
|
||||
Returns tuple[int,dict,str,int] of file length, response headers,
|
||||
absolute URI, and response code.
|
||||
"""
|
||||
|
||||
def write_to(r):
|
||||
data, response = r
|
||||
output_stream.write(data)
|
||||
return response
|
||||
|
||||
d = Deferred()
|
||||
d.addCallback(write_to)
|
||||
self.fetches.append((d, destination, path, args))
|
||||
return make_deferred_yieldable(d)
|
||||
|
||||
client = Mock()
|
||||
client.get_file = get_file
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
os.mkdir(self.media_store_path)
|
||||
|
||||
config = self.default_config()
|
||||
config["media_store_path"] = self.media_store_path
|
||||
config["thumbnail_requirements"] = {}
|
||||
config["max_image_pixels"] = 2000000
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
|
||||
hs = self.setup_test_homeserver(config=config, http_client=client)
|
||||
|
||||
return hs
|
||||
|
||||
def test_quarantine_media_requires_admin(self):
|
||||
self.register_user("nonadmin", "pass", admin=False)
|
||||
non_admin_user_tok = self.login("nonadmin", "pass")
|
||||
|
||||
# Attempt quarantine media APIs as non-admin
|
||||
url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
|
||||
request, channel = self.make_request(
|
||||
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
# Expect a forbidden error
|
||||
self.assertEqual(
|
||||
403,
|
||||
int(channel.result["code"]),
|
||||
msg="Expected forbidden on quarantining media as a non-admin",
|
||||
)
|
||||
|
||||
# And the roomID/userID endpoint
|
||||
url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
|
||||
request, channel = self.make_request(
|
||||
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
# Expect a forbidden error
|
||||
self.assertEqual(
|
||||
403,
|
||||
int(channel.result["code"]),
|
||||
msg="Expected forbidden on quarantining media as a non-admin",
|
||||
)
|
||||
|
||||
def test_quarantine_media_by_id(self):
|
||||
self.register_user("id_admin", "pass", admin=True)
|
||||
admin_user_tok = self.login("id_admin", "pass")
|
||||
|
||||
self.register_user("id_nonadmin", "pass", admin=False)
|
||||
non_admin_user_tok = self.login("id_nonadmin", "pass")
|
||||
|
||||
# Upload some media into the room
|
||||
response = self.helper.upload_media(
|
||||
self.upload_resource, self.image_data, tok=admin_user_tok
|
||||
)
|
||||
|
||||
# Extract media ID from the response
|
||||
server_name_and_media_id = response["content_uri"][
|
||||
6:
|
||||
] # Cut off the 'mxc://' bit
|
||||
server_name, media_id = server_name_and_media_id.split("/")
|
||||
|
||||
# Attempt to access the media
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_name_and_media_id,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be successful
|
||||
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
|
||||
|
||||
# Quarantine the media
|
||||
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
|
||||
urllib.parse.quote(server_name),
|
||||
urllib.parse.quote(media_id),
|
||||
)
|
||||
request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
|
||||
self.render(request)
|
||||
self.pump(1.0)
|
||||
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
|
||||
|
||||
# Attempt to access the media
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_name_and_media_id,
|
||||
shorthand=False,
|
||||
access_token=admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_name_and_media_id
|
||||
),
|
||||
)
|
||||
|
||||
def test_quarantine_all_media_in_room(self):
|
||||
self.register_user("room_admin", "pass", admin=True)
|
||||
admin_user_tok = self.login("room_admin", "pass")
|
||||
|
||||
non_admin_user = self.register_user("room_nonadmin", "pass", admin=False)
|
||||
non_admin_user_tok = self.login("room_nonadmin", "pass")
|
||||
|
||||
room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok)
|
||||
self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
|
||||
|
||||
# Upload some media
|
||||
response_1 = self.helper.upload_media(
|
||||
self.upload_resource, self.image_data, tok=non_admin_user_tok
|
||||
)
|
||||
response_2 = self.helper.upload_media(
|
||||
self.upload_resource, self.image_data, tok=non_admin_user_tok
|
||||
)
|
||||
|
||||
# Extract mxcs
|
||||
mxc_1 = response_1["content_uri"]
|
||||
mxc_2 = response_2["content_uri"]
|
||||
|
||||
# Send it into the room
|
||||
self.helper.send_event(
|
||||
room_id,
|
||||
"m.room.message",
|
||||
content={"body": "image-1", "msgtype": "m.image", "url": mxc_1},
|
||||
txn_id="111",
|
||||
tok=non_admin_user_tok,
|
||||
)
|
||||
self.helper.send_event(
|
||||
room_id,
|
||||
"m.room.message",
|
||||
content={"body": "image-2", "msgtype": "m.image", "url": mxc_2},
|
||||
txn_id="222",
|
||||
tok=non_admin_user_tok,
|
||||
)
|
||||
|
||||
# Quarantine all media in the room
|
||||
url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
|
||||
room_id
|
||||
)
|
||||
request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
|
||||
self.render(request)
|
||||
self.pump(1.0)
|
||||
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
|
||||
self.assertEqual(
|
||||
json.loads(channel.result["body"].decode("utf-8")),
|
||||
{"num_quarantined": 2},
|
||||
"Expected 2 quarantined items",
|
||||
)
|
||||
|
||||
# Convert mxc URLs to server/media_id strings
|
||||
server_and_media_id_1 = mxc_1[6:]
|
||||
server_and_media_id_2 = mxc_2[6:]
|
||||
|
||||
# Test that we cannot download any of the media anymore
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_and_media_id_1,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id_1
|
||||
),
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_and_media_id_2,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id_2
|
||||
),
|
||||
)
|
||||
|
||||
def test_quarantine_all_media_by_user(self):
|
||||
self.register_user("user_admin", "pass", admin=True)
|
||||
admin_user_tok = self.login("user_admin", "pass")
|
||||
|
||||
non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
|
||||
non_admin_user_tok = self.login("user_nonadmin", "pass")
|
||||
|
||||
# Upload some media
|
||||
response_1 = self.helper.upload_media(
|
||||
self.upload_resource, self.image_data, tok=non_admin_user_tok
|
||||
)
|
||||
response_2 = self.helper.upload_media(
|
||||
self.upload_resource, self.image_data, tok=non_admin_user_tok
|
||||
)
|
||||
|
||||
# Extract media IDs
|
||||
server_and_media_id_1 = response_1["content_uri"][6:]
|
||||
server_and_media_id_2 = response_2["content_uri"][6:]
|
||||
|
||||
# Quarantine all media by this user
|
||||
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
|
||||
non_admin_user
|
||||
)
|
||||
request, channel = self.make_request(
|
||||
"POST", url.encode("ascii"), access_token=admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.pump(1.0)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(
|
||||
json.loads(channel.result["body"].decode("utf-8")),
|
||||
{"num_quarantined": 2},
|
||||
"Expected 2 quarantined items",
|
||||
)
|
||||
|
||||
# Attempt to access each piece of media
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_and_media_id_1,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id_1,
|
||||
),
|
||||
)
|
||||
|
||||
# Attempt to access each piece of media
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_and_media_id_2,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id_2
|
||||
),
|
||||
)
|
||||
|
@ -21,6 +21,8 @@ import time
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
|
||||
from tests.server import make_request, render
|
||||
@ -160,3 +162,38 @@ class RestHelper(object):
|
||||
)
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def upload_media(
|
||||
self,
|
||||
resource: Resource,
|
||||
image_data: bytes,
|
||||
tok: str,
|
||||
filename: str = "test.png",
|
||||
expect_code: int = 200,
|
||||
) -> dict:
|
||||
"""Upload a piece of test media to the media repo
|
||||
Args:
|
||||
resource: The resource that will handle the upload request
|
||||
image_data: The image data to upload
|
||||
tok: The user token to use during the upload
|
||||
filename: The filename of the media to be uploaded
|
||||
expect_code: The return code to expect from attempting to upload the media
|
||||
"""
|
||||
image_length = len(image_data)
|
||||
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
|
||||
request, channel = make_request(
|
||||
self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
|
||||
)
|
||||
request.requestHeaders.addRawHeader(
|
||||
b"Content-Length", str(image_length).encode("UTF-8")
|
||||
)
|
||||
request.render(resource)
|
||||
self.hs.get_reactor().pump([100])
|
||||
|
||||
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
|
||||
expect_code,
|
||||
int(channel.result["code"]),
|
||||
channel.result["body"],
|
||||
)
|
||||
|
||||
return channel.json_body
|
||||
|
@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Make sure the invite is here.
|
||||
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
|
||||
pending_invites = self.get_success(
|
||||
store.get_invited_rooms_for_local_user(invitee_id)
|
||||
)
|
||||
self.assertEqual(len(pending_invites), 1, pending_invites)
|
||||
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
|
||||
|
||||
@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
||||
self.deactivate(invitee_id, invitee_tok)
|
||||
|
||||
# Check that the invite isn't there anymore.
|
||||
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
|
||||
pending_invites = self.get_success(
|
||||
store.get_invited_rooms_for_local_user(invitee_id)
|
||||
)
|
||||
self.assertEqual(len(pending_invites), 0, pending_invites)
|
||||
|
||||
# Check that the membership of @invitee:test in the room is now "leave".
|
||||
memberships = self.get_success(
|
||||
store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
|
||||
store.get_rooms_for_local_user_where_membership_is(
|
||||
invitee_id, [Membership.LEAVE]
|
||||
)
|
||||
)
|
||||
self.assertEqual(len(memberships), 1, memberships)
|
||||
self.assertEqual(memberships[0].room_id, room_id, memberships)
|
||||
|
@ -15,8 +15,6 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.rest.client.v1 import login, room
|
||||
@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
"red", http_client=None, federation_client=Mock()
|
||||
)
|
||||
return hs
|
||||
|
||||
def test_sync_argless(self):
|
||||
request, channel = self.make_request("GET", "/sync")
|
||||
self.render(request)
|
||||
|
@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
||||
|
||||
rooms_for_user = self.get_success(
|
||||
self.store.get_rooms_for_user_where_membership_is(
|
||||
self.store.get_rooms_for_local_user_where_membership_is(
|
||||
self.u_alice, [Membership.JOIN]
|
||||
)
|
||||
)
|
||||
|
@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.api.errors import Codes, RedirectException, SynapseError
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
JsonResource,
|
||||
wrap_html_request_handler,
|
||||
)
|
||||
from synapse.http.site import SynapseSite, logger
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util import Clock
|
||||
@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
|
||||
|
||||
|
||||
class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
||||
class TestResource(DirectServeResource):
|
||||
callback = None
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
return await self.callback(request)
|
||||
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
||||
def test_good_response(self):
|
||||
def callback(request):
|
||||
request.write(b"response")
|
||||
request.finish()
|
||||
|
||||
res = WrapHtmlRequestHandlerTests.TestResource()
|
||||
res.callback = callback
|
||||
|
||||
request, channel = make_request(self.reactor, b"GET", b"/path")
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
body = channel.result["body"]
|
||||
self.assertEqual(body, b"response")
|
||||
|
||||
def test_redirect_exception(self):
|
||||
"""
|
||||
If the callback raises a RedirectException, it is turned into a 30x
|
||||
with the right location.
|
||||
"""
|
||||
|
||||
def callback(request, **kwargs):
|
||||
raise RedirectException(b"/look/an/eagle", 301)
|
||||
|
||||
res = WrapHtmlRequestHandlerTests.TestResource()
|
||||
res.callback = callback
|
||||
|
||||
request, channel = make_request(self.reactor, b"GET", b"/path")
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"301")
|
||||
headers = channel.result["headers"]
|
||||
location_headers = [v for k, v in headers if k == b"Location"]
|
||||
self.assertEqual(location_headers, [b"/look/an/eagle"])
|
||||
|
||||
def test_redirect_exception_with_cookie(self):
|
||||
"""
|
||||
If the callback raises a RedirectException which sets a cookie, that is
|
||||
returned too
|
||||
"""
|
||||
|
||||
def callback(request, **kwargs):
|
||||
e = RedirectException(b"/no/over/there", 304)
|
||||
e.cookies.append(b"session=yespls")
|
||||
raise e
|
||||
|
||||
res = WrapHtmlRequestHandlerTests.TestResource()
|
||||
res.callback = callback
|
||||
|
||||
request, channel = make_request(self.reactor, b"GET", b"/path")
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"304")
|
||||
headers = channel.result["headers"]
|
||||
location_headers = [v for k, v in headers if k == b"Location"]
|
||||
self.assertEqual(location_headers, [b"/no/over/there"])
|
||||
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
|
||||
self.assertEqual(cookies_headers, [b"session=yespls"])
|
||||
|
||||
|
||||
class SiteTestCase(unittest.HomeserverTestCase):
|
||||
def test_lose_connection(self):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user