From e0241feb9aa3dd6a9960e81e25625c3491ad9095 Mon Sep 17 00:00:00 2001 From: Simon Bihel Date: Mon, 14 Feb 2022 22:29:12 +0000 Subject: [PATCH] Complete client configuration endpoints --- src/axum_lib.rs | 36 ++++++++++++++++++--- src/db/cf.rs | 12 +++++++ src/db/mod.rs | 4 ++- src/db/redis.rs | 12 +++++++ src/oidc.rs | 75 +++++++++++++++++++++++++++++++++++++------ src/worker_lib.rs | 58 +++++++++++++++++++++++++++++++-- wrangler_example.toml | 2 +- 7 files changed, 180 insertions(+), 19 deletions(-) diff --git a/src/axum_lib.rs b/src/axum_lib.rs index 9450e19..bf839bf 100644 --- a/src/axum_lib.rs +++ b/src/axum_lib.rs @@ -7,7 +7,7 @@ use axum::{ StatusCode, }, response::{self, IntoResponse, Redirect}, - routing::{get, get_service, post}, + routing::{delete, get, get_service, post}, AddExtensionLayer, Json, Router, }; use bb8_redis::{bb8, RedisConnectionManager}; @@ -220,9 +220,10 @@ async fn sign_in( async fn register( extract::Json(payload): extract::Json, + Extension(config): Extension, Extension(redis_client): Extension, ) -> Result<(StatusCode, Json), CustomError> { - let registration = oidc::register(payload, &redis_client).await?; + let registration = oidc::register(payload, config.base_url, &redis_client).await?; Ok((StatusCode::CREATED, registration.into())) } @@ -289,9 +290,34 @@ async fn userinfo( async fn clientinfo( Path(client_id): Path, + bearer: Option>>, Extension(redis_client): Extension, ) -> Result, CustomError> { - Ok(oidc::clientinfo(client_id, &redis_client).await?.into()) + Ok( + oidc::clientinfo(client_id, bearer.map(|b| b.0 .0), &redis_client) + .await? + .into(), + ) +} + +async fn client_update( + Path(client_id): Path, + extract::Json(payload): extract::Json, + bearer: Option>>, + Extension(redis_client): Extension, +) -> Result<(), CustomError> { + Ok(oidc::client_update(client_id, payload, bearer.map(|b| b.0 .0), &redis_client).await?) +} + +async fn client_delete( + Path(client_id): Path, + bearer: Option>>, + Extension(redis_client): Extension, +) -> Result<(StatusCode, ()), CustomError> { + Ok(( + StatusCode::NO_CONTENT, + oidc::client_delete(client_id, bearer.map(|b| b.0 .0), &redis_client).await?, + )) } async fn healthcheck() {} @@ -400,7 +426,9 @@ pub async fn main() { .route(oidc::AUTHORIZE_PATH, get(authorize)) .route(oidc::REGISTER_PATH, post(register)) .route(oidc::USERINFO_PATH, get(userinfo).post(userinfo)) - .route(&format!("{}/:id", oidc::CLIENTINFO_PATH), get(clientinfo)) + .route(&format!("{}/:id", oidc::CLIENT_PATH), get(clientinfo)) + .route(&format!("{}/:id", oidc::CLIENT_PATH), delete(client_delete)) + .route(&format!("{}/:id", oidc::CLIENT_PATH), post(client_update)) .route(oidc::SIGNIN_PATH, get(sign_in)) .route("/health", get(healthcheck)) .layer(AddExtensionLayer::new(private_key)) diff --git a/src/db/cf.rs b/src/db/cf.rs index 097602a..c4fff58 100644 --- a/src/db/cf.rs +++ b/src/db/cf.rs @@ -115,6 +115,7 @@ impl DBClient for CFClient { .map_err(|e| anyhow!("Failed to put KV: {}", e))?; Ok(()) } + async fn get_client(&self, client_id: String) -> Result> { Ok(self .ctx @@ -125,6 +126,17 @@ impl DBClient for CFClient { .await .map_err(|e| anyhow!("Failed to get KV: {}", e))?) } + + async fn delete_client(&self, client_id: String) -> Result<()> { + Ok(self + .ctx + .kv(KV_NAMESPACE) + .map_err(|e| anyhow!("Failed to get KV store: {}", e))? + .delete(&format!("{}/{}", KV_CLIENT_PREFIX, client_id)) + .await + .map_err(|e| anyhow!("Failed to get KV: {}", e))?) + } + async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> { let namespace = self .ctx diff --git a/src/db/mod.rs b/src/db/mod.rs index 7c2bf9a..a80f43b 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use chrono::{offset::Utc, DateTime}; use ethers_core::types::H160; -use openidconnect::{core::CoreClientMetadata, Nonce}; +use openidconnect::{core::CoreClientMetadata, Nonce, RegistrationAccessToken}; use serde::{Deserialize, Serialize}; #[cfg(not(target_arch = "wasm32"))] @@ -30,6 +30,7 @@ pub struct CodeEntry { pub struct ClientEntry { pub secret: String, pub metadata: CoreClientMetadata, + pub access_token: Option, } // Using a trait to easily pass async functions with async_trait @@ -38,6 +39,7 @@ pub struct ClientEntry { pub trait DBClient { async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>; async fn get_client(&self, client_id: String) -> Result>; + async fn delete_client(&self, client_id: String) -> Result<()>; async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>; async fn get_code(&self, code: String) -> Result>; } diff --git a/src/db/redis.rs b/src/db/redis.rs index ab81ea6..7912022 100644 --- a/src/db/redis.rs +++ b/src/db/redis.rs @@ -47,6 +47,18 @@ impl DBClient for RedisClient { } } + async fn delete_client(&self, client_id: String) -> Result<()> { + let mut conn = self + .pool + .get() + .await + .map_err(|e| anyhow!("Failed to get connection to database: {}", e))?; + conn.del(format!("{}/{}", KV_CLIENT_PREFIX, client_id)) + .await + .map_err(|e| anyhow!("Failed to get kv: {}", e))?; + Ok(()) + } + async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> { let mut conn = self .pool diff --git a/src/oidc.rs b/src/oidc.rs index ca56f9a..664d738 100644 --- a/src/oidc.rs +++ b/src/oidc.rs @@ -15,12 +15,13 @@ use openidconnect::{ }, registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse}, url::Url, - AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims, + AccessToken, Audience, AuthUrl, ClientConfigUrl, ClientId, ClientSecret, EmptyAdditionalClaims, EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, EndUserPictureUrl, EndUserUsername, IssuerUrl, JsonWebKeyId, JsonWebKeySetUrl, LocalizedClaim, Nonce, PrivateSigningKey, - RedirectUrl, RegistrationUrl, RequestUrl, ResponseTypes, Scope, StandardClaims, - SubjectIdentifier, TokenUrl, UserInfoUrl, + RedirectUrl, RegistrationAccessToken, RegistrationUrl, RequestUrl, ResponseTypes, Scope, + StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl, }; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; use rsa::{pkcs1::ToRsaPrivateKey, RsaPrivateKey}; use serde::{Deserialize, Serialize}; use siwe::{Message, TimeStamp, Version}; @@ -48,9 +49,9 @@ pub const JWK_PATH: &str = "/jwk"; pub const TOKEN_PATH: &str = "/token"; pub const AUTHORIZE_PATH: &str = "/authorize"; pub const REGISTER_PATH: &str = "/register"; +pub const CLIENT_PATH: &str = "/client"; pub const USERINFO_PATH: &str = "/userinfo"; pub const SIGNIN_PATH: &str = "/sign_in"; -pub const CLIENTINFO_PATH: &str = "/clientinfo"; pub const SIWE_COOKIE_KEY: &str = "siwe"; #[cfg(not(target_arch = "wasm32"))] @@ -535,6 +536,7 @@ pub struct RegisterError { pub async fn register( payload: CoreClientMetadata, + base_url: Url, db_client: &DBClientType, ) -> Result { let id = Uuid::new_v4(); @@ -549,9 +551,18 @@ pub async fn register( } } + let access_token = RegistrationAccessToken::new( + thread_rng() + .sample_iter(&Alphanumeric) + .take(11) + .map(char::from) + .collect(), + ); + let entry = ClientEntry { secret: secret.to_string(), metadata: payload, + access_token: Some(access_token.clone()), }; db_client.set_client(id.to_string(), entry).await?; @@ -561,18 +572,62 @@ pub async fn register( EmptyAdditionalClientMetadata::default(), EmptyAdditionalClientRegistrationResponse::default(), ) - .set_client_secret(Some(ClientSecret::new(secret.to_string())))) + .set_client_secret(Some(ClientSecret::new(secret.to_string()))) + .set_registration_client_uri(Some(ClientConfigUrl::from_url( + base_url + .join(&format!("{}/{}", CLIENT_PATH, id)) + .map_err(|e| anyhow!("Unable to join URL: {}", e))?, + ))) + .set_registration_access_token(Some(access_token))) +} + +async fn client_access( + client_id: String, + bearer: Option, + db_client: &DBClientType, +) -> Result { + let access_token = if let Some(b) = bearer { + b.token().to_string() + } else { + return Err(CustomError::BadRequest("Missing access token.".to_string())); + }; + let client_entry = db_client + .get_client(client_id) + .await? + .ok_or(CustomError::NotFound)?; + let stored_access_token = client_entry.access_token.clone(); + if stored_access_token.is_none() || *stored_access_token.unwrap().secret() != access_token { + return Err(CustomError::Unauthorized("Bad access token.".to_string())); + } + Ok(client_entry) } pub async fn clientinfo( client_id: String, + bearer: Option, db_client: &DBClientType, ) -> Result { - let client_entry = db_client - .get_client(client_id) - .await? - .ok_or_else(|| CustomError::NotFound)?; - Ok(client_entry.metadata) + Ok(client_access(client_id, bearer, db_client).await?.metadata) +} + +pub async fn client_delete( + client_id: String, + bearer: Option, + db_client: &DBClientType, +) -> Result<(), CustomError> { + client_access(client_id.clone(), bearer, db_client).await?; + Ok(db_client.delete_client(client_id).await?) +} + +pub async fn client_update( + client_id: String, + payload: CoreClientMetadata, + bearer: Option, + db_client: &DBClientType, +) -> Result<(), CustomError> { + let mut client_entry = client_access(client_id.clone(), bearer, db_client).await?; + client_entry.metadata = payload; + Ok(db_client.set_client(client_id, client_entry).await?) } #[derive(Deserialize)] diff --git a/src/worker_lib.rs b/src/worker_lib.rs index 6e47398..86dfa6b 100644 --- a/src/worker_lib.rs +++ b/src/worker_lib.rs @@ -226,9 +226,10 @@ pub async fn main(req: Request, env: Env) -> Result { }) .post_async(oidc::REGISTER_PATH, |mut req, ctx| async move { let payload = req.json().await?; + let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap(); let url = req.url()?; let db_client = CFClient { ctx, url }; - match oidc::register(payload, &db_client).await { + match oidc::register(payload, base_url, &db_client).await { Ok(r) => Ok(Response::from_json(&r)?.with_status(201)), Err(e) => e.into(), } @@ -236,21 +237,72 @@ pub async fn main(req: Request, env: Env) -> Result { .post_async(oidc::USERINFO_PATH, userinfo) .get_async(oidc::USERINFO_PATH, userinfo) .get_async( - &format!("{}/:id", oidc::CLIENTINFO_PATH), + &format!("{}/:id", oidc::CLIENT_PATH), |req, ctx| async move { let client_id = if let Some(id) = ctx.param("id") { id.clone() } else { return Response::error("Bad Request", 400); }; + let bearer = req + .headers() + .get(Authorization::::name().as_str())? + .and_then(|b| HeaderValue::from_str(b.as_ref()).ok()) + .as_ref() + .and_then(Bearer::decode); let url = req.url()?; let db_client = CFClient { ctx, url }; - match oidc::clientinfo(client_id, &db_client).await { + match oidc::clientinfo(client_id, bearer, &db_client).await { Ok(r) => Ok(Response::from_json(&r)?), Err(e) => e.into(), } }, ) + .delete_async( + &format!("{}/:id", oidc::CLIENT_PATH), + |req, ctx| async move { + let client_id = if let Some(id) = ctx.param("id") { + id.clone() + } else { + return Response::error("Bad Request", 400); + }; + let bearer = req + .headers() + .get(Authorization::::name().as_str())? + .and_then(|b| HeaderValue::from_str(b.as_ref()).ok()) + .as_ref() + .and_then(Bearer::decode); + let url = req.url()?; + let db_client = CFClient { ctx, url }; + match oidc::client_delete(client_id, bearer, &db_client).await { + Ok(()) => Ok(Response::empty()?.with_status(204)), + Err(e) => e.into(), + } + }, + ) + .post_async( + &format!("{}/:id", oidc::CLIENT_PATH), + |mut req, ctx| async move { + let client_id = if let Some(id) = ctx.param("id") { + id.clone() + } else { + return Response::error("Bad Request", 400); + }; + let bearer = req + .headers() + .get(Authorization::::name().as_str())? + .and_then(|b| HeaderValue::from_str(b.as_ref()).ok()) + .as_ref() + .and_then(Bearer::decode); + let payload = req.json().await?; + let url = req.url()?; + let db_client = CFClient { ctx, url }; + match oidc::client_update(client_id, payload, bearer, &db_client).await { + Ok(()) => Ok(Response::empty()?), + Err(e) => e.into(), + } + }, + ) .get_async(oidc::SIGNIN_PATH, |req, ctx| async move { let url = req.url()?; let query = url.query().unwrap_or_default(); diff --git a/wrangler_example.toml b/wrangler_example.toml index 05da884..d4be4ca 100644 --- a/wrangler_example.toml +++ b/wrangler_example.toml @@ -10,7 +10,7 @@ kv_namespaces = [ ] [vars] -WORKERS_RS_VERSION = "0.0.7" +WORKERS_RS_VERSION = "0.0.9" BASE_URL = "https://siweoidc.spruceid.xyz" # ETH_PROVIDER = ""