mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
MSC4108 implementation (#17056)
Co-authored-by: Hugh Nimmo-Smith <hughns@element.io> Co-authored-by: Hugh Nimmo-Smith <hughns@users.noreply.github.com> Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
parent
646cb6ff24
commit
2e92b718d5
164
Cargo.lock
generated
164
Cargo.lock
generated
@ -59,6 +59,12 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.6.0"
|
||||
@ -92,9 +98,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.5"
|
||||
version = "0.10.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adfbc57365a37acbd2ebf2b64d7e69bb766e2fea813521ed536f5d0520dcf86c"
|
||||
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
@ -117,6 +123,19 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers"
|
||||
version = "0.4.0"
|
||||
@ -182,6 +201,15 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc"
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
|
||||
dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
@ -266,6 +294,12 @@ version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.76"
|
||||
@ -369,6 +403,36 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.16"
|
||||
@ -461,6 +525,17 @@ dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.10.0"
|
||||
@ -489,6 +564,7 @@ name = "synapse"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64",
|
||||
"blake2",
|
||||
"bytes",
|
||||
"headers",
|
||||
@ -496,12 +572,15 @@ dependencies = [
|
||||
"http",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"mime",
|
||||
"pyo3",
|
||||
"pyo3-log",
|
||||
"pythonize",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"ulid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -516,6 +595,17 @@ version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||
|
||||
[[package]]
|
||||
name = "ulid"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34778c17965aa2a08913b57e1f34db9b4a63f5de31768b55bf20d2795f921259"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"rand",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.5"
|
||||
@ -534,6 +624,76 @@ version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"wasm-bindgen-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"log",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"wasm-bindgen-macro-support",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
|
||||
|
||||
[[package]]
|
||||
name = "web-time"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.36.1"
|
||||
|
1
changelog.d/17056.feature
Normal file
1
changelog.d/17056.feature
Normal file
@ -0,0 +1 @@
|
||||
Implement the rendezvous mechanism described by MSC4108.
|
@ -23,11 +23,13 @@ name = "synapse.synapse_rust"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.63"
|
||||
base64 = "0.21.7"
|
||||
bytes = "1.6.0"
|
||||
headers = "0.4.0"
|
||||
http = "1.1.0"
|
||||
lazy_static = "1.4.0"
|
||||
log = "0.4.17"
|
||||
mime = "0.3.17"
|
||||
pyo3 = { version = "0.20.0", features = [
|
||||
"macros",
|
||||
"anyhow",
|
||||
@ -37,8 +39,10 @@ pyo3 = { version = "0.20.0", features = [
|
||||
pyo3-log = "0.9.0"
|
||||
pythonize = "0.20.0"
|
||||
regex = "1.6.0"
|
||||
sha2 = "0.10.8"
|
||||
serde = { version = "1.0.144", features = ["derive"] }
|
||||
serde_json = "1.0.85"
|
||||
ulid = "1.1.2"
|
||||
|
||||
[features]
|
||||
extension-module = ["pyo3/extension-module"]
|
||||
|
@ -7,6 +7,7 @@ pub mod errors;
|
||||
pub mod events;
|
||||
pub mod http;
|
||||
pub mod push;
|
||||
pub mod rendezvous;
|
||||
|
||||
lazy_static! {
|
||||
static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init();
|
||||
@ -45,6 +46,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
acl::register_module(py, m)?;
|
||||
push::register_module(py, m)?;
|
||||
events::register_module(py, m)?;
|
||||
rendezvous::register_module(py, m)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
315
rust/src/rendezvous/mod.rs
Normal file
315
rust/src/rendezvous/mod.rs
Normal file
@ -0,0 +1,315 @@
|
||||
/*
|
||||
* This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
*
|
||||
* Copyright (C) 2024 New Vector, Ltd
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as
|
||||
* published by the Free Software Foundation, either version 3 of the
|
||||
* License, or (at your option) any later version.
|
||||
*
|
||||
* See the GNU Affero General Public License for more details:
|
||||
* <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
*
|
||||
*/
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use bytes::Bytes;
|
||||
use headers::{
|
||||
AccessControlAllowOrigin, AccessControlExposeHeaders, CacheControl, ContentLength, ContentType,
|
||||
HeaderMapExt, IfMatch, IfNoneMatch, Pragma,
|
||||
};
|
||||
use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri};
|
||||
use mime::Mime;
|
||||
use pyo3::{
|
||||
exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyObject, PyResult,
|
||||
Python, ToPyObject,
|
||||
};
|
||||
use ulid::Ulid;
|
||||
|
||||
use self::session::Session;
|
||||
use crate::{
|
||||
errors::{NotFoundError, SynapseError},
|
||||
http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt},
|
||||
};
|
||||
|
||||
mod session;
|
||||
|
||||
// n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers.
|
||||
fn prepare_headers(headers: &mut HeaderMap, session: &Session) {
|
||||
headers.typed_insert(AccessControlAllowOrigin::ANY);
|
||||
headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG]));
|
||||
headers.typed_insert(Pragma::no_cache());
|
||||
headers.typed_insert(CacheControl::new().with_no_store());
|
||||
headers.typed_insert(session.etag());
|
||||
headers.typed_insert(session.expires());
|
||||
headers.typed_insert(session.last_modified());
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct RendezvousHandler {
|
||||
base: Uri,
|
||||
clock: PyObject,
|
||||
sessions: BTreeMap<Ulid, Session>,
|
||||
capacity: usize,
|
||||
max_content_length: u64,
|
||||
ttl: Duration,
|
||||
}
|
||||
|
||||
impl RendezvousHandler {
|
||||
/// Check the input headers of a request which sets data for a session, and return the content type.
|
||||
fn check_input_headers(&self, headers: &HeaderMap) -> PyResult<Mime> {
|
||||
let ContentLength(content_length) = headers.typed_get_required()?;
|
||||
|
||||
if content_length > self.max_content_length {
|
||||
return Err(SynapseError::new(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"Payload too large".to_owned(),
|
||||
"M_TOO_LARGE",
|
||||
None,
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
let content_type: ContentType = headers.typed_get_required()?;
|
||||
|
||||
// Content-Type must be text/plain
|
||||
if content_type != ContentType::text() {
|
||||
return Err(SynapseError::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Content-Type must be text/plain".to_owned(),
|
||||
"M_INVALID_PARAM",
|
||||
None,
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(content_type.into())
|
||||
}
|
||||
|
||||
/// Evict expired sessions and remove the oldest sessions until we're under the capacity.
|
||||
fn evict(&mut self, now: SystemTime) {
|
||||
// First remove all the entries which expired
|
||||
self.sessions.retain(|_, session| !session.expired(now));
|
||||
|
||||
// Then we remove the oldest entires until we're under the limit
|
||||
while self.sessions.len() > self.capacity {
|
||||
self.sessions.pop_first();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl RendezvousHandler {
|
||||
#[new]
|
||||
#[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=60*1000))]
|
||||
fn new(
|
||||
py: Python<'_>,
|
||||
homeserver: &PyAny,
|
||||
capacity: usize,
|
||||
max_content_length: u64,
|
||||
eviction_interval: u64,
|
||||
ttl: u64,
|
||||
) -> PyResult<Py<Self>> {
|
||||
let base: String = homeserver
|
||||
.getattr("config")?
|
||||
.getattr("server")?
|
||||
.getattr("public_baseurl")?
|
||||
.extract()?;
|
||||
let base = Uri::try_from(format!("{base}_synapse/client/rendezvous"))
|
||||
.map_err(|_| PyValueError::new_err("Invalid base URI"))?;
|
||||
|
||||
let clock = homeserver.call_method0("get_clock")?.to_object(py);
|
||||
|
||||
// Construct a Python object so that we can get a reference to the
|
||||
// evict method and schedule it to run.
|
||||
let self_ = Py::new(
|
||||
py,
|
||||
Self {
|
||||
base,
|
||||
clock,
|
||||
sessions: BTreeMap::new(),
|
||||
capacity,
|
||||
max_content_length,
|
||||
ttl: Duration::from_millis(ttl),
|
||||
},
|
||||
)?;
|
||||
|
||||
let evict = self_.getattr(py, "_evict")?;
|
||||
homeserver.call_method0("get_clock")?.call_method(
|
||||
"looping_call",
|
||||
(evict, eviction_interval),
|
||||
None,
|
||||
)?;
|
||||
|
||||
Ok(self_)
|
||||
}
|
||||
|
||||
fn _evict(&mut self, py: Python<'_>) -> PyResult<()> {
|
||||
let clock = self.clock.as_ref(py);
|
||||
let now: u64 = clock.call_method0("time_msec")?.extract()?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
self.evict(now);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> {
|
||||
let request = http_request_from_twisted(twisted_request)?;
|
||||
|
||||
let content_type = self.check_input_headers(request.headers())?;
|
||||
|
||||
let clock = self.clock.as_ref(py);
|
||||
let now: u64 = clock.call_method0("time_msec")?.extract()?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
|
||||
// We trigger an immediate eviction if we're at 2x the capacity
|
||||
if self.sessions.len() >= self.capacity * 2 {
|
||||
self.evict(now);
|
||||
}
|
||||
|
||||
// Generate a new ULID for the session from the current time.
|
||||
let id = Ulid::from_datetime(now);
|
||||
|
||||
let uri = format!("{base}/{id}", base = self.base);
|
||||
|
||||
let body = request.into_body();
|
||||
|
||||
let session = Session::new(body, content_type, now, self.ttl);
|
||||
|
||||
let response = serde_json::json!({
|
||||
"url": uri,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let mut response = Response::new(response.as_bytes());
|
||||
*response.status_mut() = StatusCode::CREATED;
|
||||
response.headers_mut().typed_insert(ContentType::json());
|
||||
prepare_headers(response.headers_mut(), &session);
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
self.sessions.insert(id, session);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_get(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
|
||||
let request = http_request_from_twisted(twisted_request)?;
|
||||
|
||||
let if_none_match: Option<IfNoneMatch> = request.headers().typed_get_optional()?;
|
||||
|
||||
let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
|
||||
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
|
||||
let session = self
|
||||
.sessions
|
||||
.get(&id)
|
||||
.filter(|s| !s.expired(now))
|
||||
.ok_or_else(NotFoundError::new)?;
|
||||
|
||||
if let Some(if_none_match) = if_none_match {
|
||||
if !if_none_match.precondition_passes(&session.etag()) {
|
||||
let mut response = Response::new(Bytes::new());
|
||||
*response.status_mut() = StatusCode::NOT_MODIFIED;
|
||||
prepare_headers(response.headers_mut(), session);
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let mut response = Response::new(session.data());
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
let headers = response.headers_mut();
|
||||
prepare_headers(headers, session);
|
||||
headers.typed_insert(session.content_type());
|
||||
headers.typed_insert(session.content_length());
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_put(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
|
||||
let request = http_request_from_twisted(twisted_request)?;
|
||||
|
||||
let content_type = self.check_input_headers(request.headers())?;
|
||||
|
||||
let if_match: IfMatch = request.headers().typed_get_required()?;
|
||||
|
||||
let data = request.into_body();
|
||||
|
||||
let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
|
||||
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
|
||||
let session = self
|
||||
.sessions
|
||||
.get_mut(&id)
|
||||
.filter(|s| !s.expired(now))
|
||||
.ok_or_else(NotFoundError::new)?;
|
||||
|
||||
if !if_match.precondition_passes(&session.etag()) {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers, session);
|
||||
|
||||
let mut additional_fields = HashMap::with_capacity(1);
|
||||
additional_fields.insert(
|
||||
String::from("org.matrix.msc4108.errcode"),
|
||||
String::from("M_CONCURRENT_WRITE"),
|
||||
);
|
||||
|
||||
return Err(SynapseError::new(
|
||||
StatusCode::PRECONDITION_FAILED,
|
||||
"ETag does not match".to_owned(),
|
||||
"M_UNKNOWN", // Would be M_CONCURRENT_WRITE
|
||||
Some(additional_fields),
|
||||
Some(headers),
|
||||
));
|
||||
}
|
||||
|
||||
session.update(data, content_type, now);
|
||||
|
||||
let mut response = Response::new(Bytes::new());
|
||||
*response.status_mut() = StatusCode::ACCEPTED;
|
||||
prepare_headers(response.headers_mut(), session);
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_delete(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> {
|
||||
let _request = http_request_from_twisted(twisted_request)?;
|
||||
|
||||
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
|
||||
let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?;
|
||||
|
||||
let mut response = Response::new(Bytes::new());
|
||||
*response.status_mut() = StatusCode::NO_CONTENT;
|
||||
response
|
||||
.headers_mut()
|
||||
.typed_insert(AccessControlAllowOrigin::ANY);
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let child_module = PyModule::new(py, "rendezvous")?;
|
||||
|
||||
child_module.add_class::<RendezvousHandler>()?;
|
||||
|
||||
m.add_submodule(child_module)?;
|
||||
|
||||
// We need to manually add the module to sys.modules to make `from
|
||||
// synapse.synapse_rust import rendezvous` work.
|
||||
py.import("sys")?
|
||||
.getattr("modules")?
|
||||
.set_item("synapse.synapse_rust.rendezvous", child_module)?;
|
||||
|
||||
Ok(())
|
||||
}
|
91
rust/src/rendezvous/session.rs
Normal file
91
rust/src/rendezvous/session.rs
Normal file
@ -0,0 +1,91 @@
|
||||
/*
|
||||
* This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
*
|
||||
* Copyright (C) 2024 New Vector, Ltd
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as
|
||||
* published by the Free Software Foundation, either version 3 of the
|
||||
* License, or (at your option) any later version.
|
||||
*
|
||||
* See the GNU Affero General Public License for more details:
|
||||
* <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
*/
|
||||
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use bytes::Bytes;
|
||||
use headers::{ContentLength, ContentType, ETag, Expires, LastModified};
|
||||
use mime::Mime;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// A single session, containing data, metadata, and expiry information.
|
||||
pub struct Session {
|
||||
hash: [u8; 32],
|
||||
data: Bytes,
|
||||
content_type: Mime,
|
||||
last_modified: SystemTime,
|
||||
expires: SystemTime,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Create a new session with the given data, content type, and time-to-live.
|
||||
pub fn new(data: Bytes, content_type: Mime, now: SystemTime, ttl: Duration) -> Self {
|
||||
let hash = Sha256::digest(&data).into();
|
||||
Self {
|
||||
hash,
|
||||
data,
|
||||
content_type,
|
||||
expires: now + ttl,
|
||||
last_modified: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the session has expired at the given time.
|
||||
pub fn expired(&self, now: SystemTime) -> bool {
|
||||
self.expires <= now
|
||||
}
|
||||
|
||||
/// Update the session with new data, content type, and last modified time.
|
||||
pub fn update(&mut self, data: Bytes, content_type: Mime, now: SystemTime) {
|
||||
self.hash = Sha256::digest(&data).into();
|
||||
self.data = data;
|
||||
self.content_type = content_type;
|
||||
self.last_modified = now;
|
||||
}
|
||||
|
||||
/// Returns the Content-Type header of the session.
|
||||
pub fn content_type(&self) -> ContentType {
|
||||
self.content_type.clone().into()
|
||||
}
|
||||
|
||||
/// Returns the Content-Length header of the session.
|
||||
pub fn content_length(&self) -> ContentLength {
|
||||
ContentLength(self.data.len() as _)
|
||||
}
|
||||
|
||||
/// Returns the ETag header of the session.
|
||||
pub fn etag(&self) -> ETag {
|
||||
let encoded = URL_SAFE_NO_PAD.encode(self.hash);
|
||||
// SAFETY: Base64 encoding is URL-safe, so ETag-safe
|
||||
format!("\"{encoded}\"")
|
||||
.parse()
|
||||
.expect("base64-encoded hash should be URL-safe")
|
||||
}
|
||||
|
||||
/// Returns the Last-Modified header of the session.
|
||||
pub fn last_modified(&self) -> LastModified {
|
||||
self.last_modified.into()
|
||||
}
|
||||
|
||||
/// Returns the Expires header of the session.
|
||||
pub fn expires(&self) -> Expires {
|
||||
self.expires.into()
|
||||
}
|
||||
|
||||
/// Returns the current data stored in the session.
|
||||
pub fn data(&self) -> Bytes {
|
||||
self.data.clone()
|
||||
}
|
||||
}
|
@ -413,12 +413,22 @@ class ExperimentalConfig(Config):
|
||||
)
|
||||
|
||||
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
|
||||
self.msc4108_enabled = experimental.get("msc4108_enabled", False)
|
||||
|
||||
self.msc4108_delegation_endpoint: Optional[str] = experimental.get(
|
||||
"msc4108_delegation_endpoint", None
|
||||
)
|
||||
|
||||
if self.msc4108_delegation_endpoint is not None and not self.msc3861.enabled:
|
||||
if (
|
||||
self.msc4108_enabled or self.msc4108_delegation_endpoint is not None
|
||||
) and not self.msc3861.enabled:
|
||||
raise ConfigError(
|
||||
"MSC4108 requires MSC3861 to be enabled",
|
||||
("experimental", "msc4108_delegation_endpoint"),
|
||||
)
|
||||
|
||||
if self.msc4108_delegation_endpoint is not None and self.msc4108_enabled:
|
||||
raise ConfigError(
|
||||
"You cannot have MSC4108 both enabled and delegated at the same time",
|
||||
("experimental", "msc4108_delegation_endpoint"),
|
||||
)
|
||||
|
@ -909,8 +909,9 @@ def set_cors_headers(request: "SynapseRequest") -> None:
|
||||
request.setHeader(
|
||||
b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
if request.path is not None and request.path.startswith(
|
||||
b"/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
|
||||
if request.path is not None and (
|
||||
request.path == b"/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
|
||||
or request.path.startswith(b"/_synapse/client/rendezvous")
|
||||
):
|
||||
request.setHeader(
|
||||
b"Access-Control-Allow-Headers",
|
||||
|
@ -97,9 +97,25 @@ class MSC4108DelegationRendezvousServlet(RestServlet):
|
||||
)
|
||||
|
||||
|
||||
class MSC4108RendezvousServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/org.matrix.msc4108/rendezvous$", releases=[], v1=False, unstable=True
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer") -> None:
|
||||
super().__init__()
|
||||
self._handler = hs.get_rendezvous_handler()
|
||||
|
||||
def on_POST(self, request: SynapseRequest) -> None:
|
||||
self._handler.handle_post(request)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3886_endpoint is not None:
|
||||
MSC3886RendezvousServlet(hs).register(http_server)
|
||||
|
||||
if hs.config.experimental.msc4108_enabled:
|
||||
MSC4108RendezvousServlet(hs).register(http_server)
|
||||
|
||||
if hs.config.experimental.msc4108_delegation_endpoint is not None:
|
||||
MSC4108DelegationRendezvousServlet(hs).register(http_server)
|
||||
|
@ -141,8 +141,13 @@ class VersionsRestServlet(RestServlet):
|
||||
# Allows clients to handle push for encrypted events.
|
||||
"org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events,
|
||||
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
|
||||
"org.matrix.msc4108": self.config.experimental.msc4108_delegation_endpoint
|
||||
is not None,
|
||||
"org.matrix.msc4108": (
|
||||
self.config.experimental.msc4108_enabled
|
||||
or (
|
||||
self.config.experimental.msc4108_delegation_endpoint
|
||||
is not None
|
||||
)
|
||||
),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
@ -26,6 +26,7 @@ from twisted.web.resource import Resource
|
||||
from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
|
||||
from synapse.rest.synapse.client.pick_idp import PickIdpResource
|
||||
from synapse.rest.synapse.client.pick_username import pick_username_resource
|
||||
from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource
|
||||
from synapse.rest.synapse.client.sso_register import SsoRegisterResource
|
||||
from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
|
||||
|
||||
@ -76,6 +77,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
|
||||
# To be removed in Synapse v1.32.0.
|
||||
resources["/_matrix/saml2"] = res
|
||||
|
||||
if hs.config.experimental.msc4108_enabled:
|
||||
resources["/_synapse/client/rendezvous"] = MSC4108RendezvousSessionResource(hs)
|
||||
|
||||
return resources
|
||||
|
||||
|
||||
|
58
synapse/rest/synapse/client/rendezvous.py
Normal file
58
synapse/rest/synapse/client/rendezvous.py
Normal file
@ -0,0 +1,58 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from synapse.api.errors import UnrecognizedRequestError
|
||||
from synapse.http.server import DirectServeJsonResource
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MSC4108RendezvousSessionResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs: "HomeServer") -> None:
|
||||
super().__init__()
|
||||
self._handler = hs.get_rendezvous_handler()
|
||||
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
postpath: List[bytes] = request.postpath # type: ignore
|
||||
if len(postpath) != 1:
|
||||
raise UnrecognizedRequestError()
|
||||
session_id = postpath[0].decode("ascii")
|
||||
|
||||
self._handler.handle_get(request, session_id)
|
||||
|
||||
def _async_render_PUT(self, request: SynapseRequest) -> None:
|
||||
postpath: List[bytes] = request.postpath # type: ignore
|
||||
if len(postpath) != 1:
|
||||
raise UnrecognizedRequestError()
|
||||
session_id = postpath[0].decode("ascii")
|
||||
|
||||
self._handler.handle_put(request, session_id)
|
||||
|
||||
def _async_render_DELETE(self, request: SynapseRequest) -> None:
|
||||
postpath: List[bytes] = request.postpath # type: ignore
|
||||
if len(postpath) != 1:
|
||||
raise UnrecognizedRequestError()
|
||||
session_id = postpath[0].decode("ascii")
|
||||
|
||||
self._handler.handle_delete(request, session_id)
|
@ -143,6 +143,7 @@ from synapse.state import StateHandler, StateResolutionHandler
|
||||
from synapse.storage import Databases
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.synapse_rust.rendezvous import RendezvousHandler
|
||||
from synapse.types import DomainSpecificString, ISynapseReactor
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
@ -859,6 +860,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
def get_room_forgetter_handler(self) -> RoomForgetterHandler:
|
||||
return RoomForgetterHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_rendezvous_handler(self) -> RendezvousHandler:
|
||||
return RendezvousHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_outbound_redis_connection(self) -> "ConnectionHandler":
|
||||
"""
|
||||
|
30
synapse/synapse_rust/rendezvous.pyi
Normal file
30
synapse/synapse_rust/rendezvous.pyi
Normal file
@ -0,0 +1,30 @@
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
from twisted.web.iweb import IRequest
|
||||
|
||||
from synapse.server import HomeServer
|
||||
|
||||
class RendezvousHandler:
|
||||
def __init__(
|
||||
self,
|
||||
homeserver: HomeServer,
|
||||
/,
|
||||
capacity: int = 100,
|
||||
max_content_length: int = 4 * 1024, # MSC4108 specifies 4KB
|
||||
eviction_interval: int = 60 * 1000,
|
||||
ttl: int = 60 * 1000,
|
||||
) -> None: ...
|
||||
def handle_post(self, request: IRequest) -> None: ...
|
||||
def handle_get(self, request: IRequest, session_id: str) -> None: ...
|
||||
def handle_put(self, request: IRequest, session_id: str) -> None: ...
|
||||
def handle_delete(self, request: IRequest, session_id: str) -> None: ...
|
@ -2,7 +2,7 @@
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
@ -19,9 +19,14 @@
|
||||
#
|
||||
#
|
||||
|
||||
from typing import Dict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.rest.client import rendezvous
|
||||
from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
@ -42,6 +47,12 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
|
||||
self.hs = self.setup_test_homeserver()
|
||||
return self.hs
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
return {
|
||||
**super().create_resource_dict(),
|
||||
"/_synapse/client/rendezvous": MSC4108RendezvousSessionResource(self.hs),
|
||||
}
|
||||
|
||||
def test_disabled(self) -> None:
|
||||
channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
|
||||
self.assertEqual(channel.code, 404)
|
||||
@ -75,3 +86,391 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
|
||||
self.assertEqual(channel.code, 307)
|
||||
self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"])
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc4108_enabled": True,
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": "https://issuer",
|
||||
"client_id": "client_id",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "client_secret",
|
||||
"admin_token": "admin_token_value",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_msc4108(self) -> None:
|
||||
"""
|
||||
Test the MSC4108 rendezvous endpoint, including:
|
||||
- Creating a session
|
||||
- Getting the data back
|
||||
- Updating the data
|
||||
- Deleting the data
|
||||
- ETag handling
|
||||
"""
|
||||
# We can post arbitrary data to the endpoint
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 201)
|
||||
self.assertSubstring("/_synapse/client/rendezvous/", channel.json_body["url"])
|
||||
headers = dict(channel.headers.getAllRawHeaders())
|
||||
self.assertIn(b"ETag", headers)
|
||||
self.assertIn(b"Expires", headers)
|
||||
self.assertEqual(headers[b"Content-Type"], [b"application/json"])
|
||||
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
|
||||
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
|
||||
self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
|
||||
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
|
||||
self.assertIn("url", channel.json_body)
|
||||
self.assertTrue(channel.json_body["url"].startswith("https://"))
|
||||
|
||||
url = urlparse(channel.json_body["url"])
|
||||
session_endpoint = url.path
|
||||
etag = headers[b"ETag"][0]
|
||||
|
||||
# We can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
headers = dict(channel.headers.getAllRawHeaders())
|
||||
self.assertEqual(headers[b"ETag"], [etag])
|
||||
self.assertIn(b"Expires", headers)
|
||||
self.assertEqual(headers[b"Content-Type"], [b"text/plain"])
|
||||
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
|
||||
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
|
||||
self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
|
||||
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
|
||||
self.assertEqual(channel.text_body, "foo=bar")
|
||||
|
||||
# We can make sure the data hasn't changed
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
custom_headers=[("If-None-Match", etag)],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 304)
|
||||
|
||||
# We can update the data
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
"foo=baz",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
custom_headers=[("If-Match", etag)],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 202)
|
||||
headers = dict(channel.headers.getAllRawHeaders())
|
||||
old_etag = etag
|
||||
new_etag = headers[b"ETag"][0]
|
||||
|
||||
# If we try to update it again with the old etag, it should fail
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
"bar=baz",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
custom_headers=[("If-Match", old_etag)],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 412)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN")
|
||||
self.assertEqual(
|
||||
channel.json_body["org.matrix.msc4108.errcode"], "M_CONCURRENT_WRITE"
|
||||
)
|
||||
|
||||
# If we try to get with the old etag, we should get the updated data
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
custom_headers=[("If-None-Match", old_etag)],
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
headers = dict(channel.headers.getAllRawHeaders())
|
||||
self.assertEqual(headers[b"ETag"], [new_etag])
|
||||
self.assertEqual(channel.text_body, "foo=baz")
|
||||
|
||||
# We can delete the data
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 204)
|
||||
|
||||
# If we try to get the data again, it should fail
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc4108_enabled": True,
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": "https://issuer",
|
||||
"client_id": "client_id",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "client_secret",
|
||||
"admin_token": "admin_token_value",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_msc4108_expiration(self) -> None:
|
||||
"""
|
||||
Test that entries are evicted after a TTL.
|
||||
"""
|
||||
# Start a new session
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 201)
|
||||
session_endpoint = urlparse(channel.json_body["url"]).path
|
||||
|
||||
# Sanity check that we can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.text_body, "foo=bar")
|
||||
|
||||
# Advance the clock, TTL of entries is 1 minute
|
||||
self.reactor.advance(60)
|
||||
|
||||
# Get the data back, it should be gone
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc4108_enabled": True,
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": "https://issuer",
|
||||
"client_id": "client_id",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "client_secret",
|
||||
"admin_token": "admin_token_value",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_msc4108_capacity(self) -> None:
|
||||
"""
|
||||
Test that a capacity limit is enforced on the rendezvous sessions, as old
|
||||
entries are evicted at an interval when the limit is reached.
|
||||
"""
|
||||
# Start a new session
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 201)
|
||||
session_endpoint = urlparse(channel.json_body["url"]).path
|
||||
|
||||
# Sanity check that we can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.text_body, "foo=bar")
|
||||
|
||||
# Start a lot of new sessions
|
||||
for _ in range(100):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 201)
|
||||
|
||||
# Get the data back, it should still be there, as the eviction hasn't run yet
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Advance the clock, as it will trigger the eviction
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Get the data back, it should be gone
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc4108_enabled": True,
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": "https://issuer",
|
||||
"client_id": "client_id",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "client_secret",
|
||||
"admin_token": "admin_token_value",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_msc4108_hard_capacity(self) -> None:
|
||||
"""
|
||||
Test that a hard capacity limit is enforced on the rendezvous sessions, as old
|
||||
entries are evicted immediately when the limit is reached.
|
||||
"""
|
||||
# Start a new session
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 201)
|
||||
session_endpoint = urlparse(channel.json_body["url"]).path
|
||||
# We advance the clock to make sure that this entry is the "lowest" in the session list
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Sanity check that we can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.text_body, "foo=bar")
|
||||
|
||||
# Start a lot of new sessions
|
||||
for _ in range(200):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 201)
|
||||
|
||||
# Get the data back, it should already be gone as we hit the hard limit
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"experimental_features": {
|
||||
"msc4108_enabled": True,
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": "https://issuer",
|
||||
"client_id": "client_id",
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": "client_secret",
|
||||
"admin_token": "admin_token_value",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_msc4108_content_type(self) -> None:
|
||||
"""
|
||||
Test that the content-type is restricted to text/plain.
|
||||
"""
|
||||
# We cannot post invalid content-type arbitrary data to the endpoint
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_is_form=True,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 400)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
|
||||
|
||||
# Make a valid request
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
msc4108_endpoint,
|
||||
"foo=bar",
|
||||
content_type=b"text/plain",
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 201)
|
||||
url = urlparse(channel.json_body["url"])
|
||||
session_endpoint = url.path
|
||||
headers = dict(channel.headers.getAllRawHeaders())
|
||||
etag = headers[b"ETag"][0]
|
||||
|
||||
# We can't update the data with invalid content-type
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
"foo=baz",
|
||||
content_is_form=True,
|
||||
access_token=None,
|
||||
custom_headers=[("If-Match", etag)],
|
||||
)
|
||||
self.assertEqual(channel.code, 400)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
|
||||
|
@ -351,6 +351,7 @@ def make_request(
|
||||
request: Type[Request] = SynapseRequest,
|
||||
shorthand: bool = True,
|
||||
federation_auth_origin: Optional[bytes] = None,
|
||||
content_type: Optional[bytes] = None,
|
||||
content_is_form: bool = False,
|
||||
await_result: bool = True,
|
||||
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||
@ -373,6 +374,8 @@ def make_request(
|
||||
with the usual REST API path, if it doesn't contain it.
|
||||
federation_auth_origin: if set to not-None, we will add a fake
|
||||
Authorization header pretenting to be the given server name.
|
||||
content_type: The content-type to use for the request. If not set then will default to
|
||||
application/json unless content_is_form is true.
|
||||
content_is_form: Whether the content is URL encoded form data. Adds the
|
||||
'Content-Type': 'application/x-www-form-urlencoded' header.
|
||||
await_result: whether to wait for the request to complete rendering. If true,
|
||||
@ -436,7 +439,9 @@ def make_request(
|
||||
)
|
||||
|
||||
if content:
|
||||
if content_is_form:
|
||||
if content_type is not None:
|
||||
req.requestHeaders.addRawHeader(b"Content-Type", content_type)
|
||||
elif content_is_form:
|
||||
req.requestHeaders.addRawHeader(
|
||||
b"Content-Type", b"application/x-www-form-urlencoded"
|
||||
)
|
||||
|
@ -523,6 +523,7 @@ class HomeserverTestCase(TestCase):
|
||||
request: Type[Request] = SynapseRequest,
|
||||
shorthand: bool = True,
|
||||
federation_auth_origin: Optional[bytes] = None,
|
||||
content_type: Optional[bytes] = None,
|
||||
content_is_form: bool = False,
|
||||
await_result: bool = True,
|
||||
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||
@ -541,6 +542,9 @@ class HomeserverTestCase(TestCase):
|
||||
with the usual REST API path, if it doesn't contain it.
|
||||
federation_auth_origin: if set to not-None, we will add a fake
|
||||
Authorization header pretenting to be the given server name.
|
||||
|
||||
content_type: The content-type to use for the request. If not set then will default to
|
||||
application/json unless content_is_form is true.
|
||||
content_is_form: Whether the content is URL encoded form data. Adds the
|
||||
'Content-Type': 'application/x-www-form-urlencoded' header.
|
||||
|
||||
@ -566,6 +570,7 @@ class HomeserverTestCase(TestCase):
|
||||
request,
|
||||
shorthand,
|
||||
federation_auth_origin,
|
||||
content_type,
|
||||
content_is_form,
|
||||
await_result,
|
||||
custom_headers,
|
||||
|
Loading…
Reference in New Issue
Block a user