diff --git a/src/axum_lib.rs b/src/axum_lib.rs index fac1e96..8edb886 100644 --- a/src/axum_lib.rs +++ b/src/axum_lib.rs @@ -18,11 +18,11 @@ use figment::{ use headers::{ self, authorization::{Basic, Bearer}, - Authorization, + Authorization, ContentType, Header, }; use openidconnect::core::{ CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata, - CoreResponseType, CoreTokenResponse, CoreUserInfoClaims, + CoreResponseType, CoreTokenResponse, CoreUserInfoClaims, CoreUserInfoJsonWebToken, }; use rand::rngs::OsRng; use rsa::{ @@ -221,20 +221,64 @@ async fn register( Ok((StatusCode::CREATED, registration.into())) } +struct UserInfoResponseJWT(Json); + +impl IntoResponse for UserInfoResponseJWT { + fn into_response(self) -> response::Response { + response::Response::builder() + .status(StatusCode::OK) + .header(ContentType::name(), "application/jwt") + .body( + serde_json::to_string(&self.0 .0) + .unwrap() + .replace('"', "") + .into_response() + .into_body(), + ) + .unwrap() + } +} + +enum UserInfoResponse { + Json(Json), + Jwt(UserInfoResponseJWT), +} + +impl IntoResponse for UserInfoResponse { + fn into_response(self) -> response::Response { + match self { + UserInfoResponse::Json(j) => j.into_response(), + UserInfoResponse::Jwt(j) => j.into_response(), + } + } +} + // TODO CORS // TODO need validation of the token async fn userinfo( + Extension(private_key): Extension, + Extension(config): Extension, payload: Option>, bearer: Option>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs Extension(redis_client): Extension, -) -> Result, CustomError> { +) -> Result { let payload = if let Some(Form(p)) = payload { p } else { oidc::UserInfoPayload { access_token: None } }; - let claims = oidc::userinfo(bearer.map(|b| b.0 .0), payload, &redis_client).await?; - Ok(claims.into()) + let claims = oidc::userinfo( + config.base_url, + private_key, + bearer.map(|b| b.0 .0), + payload, + &redis_client, + ) + .await?; + Ok(match claims { + oidc::UserInfoResponse::Json(c) => UserInfoResponse::Json(c.into()), + oidc::UserInfoResponse::Jwt(c) => UserInfoResponse::Jwt(UserInfoResponseJWT(c.into())), + }) } async fn healthcheck() {} @@ -253,16 +297,16 @@ pub async fn main() { let redis_client = RedisClient { pool }; - for (id, secret) in &config.default_clients.clone() { - let client_entry = ClientEntry { - secret: secret.to_string(), - redirect_uris: vec![], - }; - redis_client - .set_client(id.to_string(), client_entry) - .await - .unwrap(); // TODO - } + // for (id, secret) in &config.default_clients.clone() { + // let client_entry = ClientEntry { + // secret: secret.to_string(), + // redirect_uris: vec![], + // }; + // redis_client + // .set_client(id.to_string(), client_entry) + // .await + // .unwrap(); // TODO + // } let private_key = if let Some(key) = &config.rsa_pem { RsaPrivateKey::from_pkcs1_pem(key) diff --git a/src/db/mod.rs b/src/db/mod.rs index 2fb058d..382c845 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,7 +1,7 @@ use anyhow::Result; use async_trait::async_trait; use chrono::{offset::Utc, DateTime}; -use openidconnect::{Nonce, RedirectUrl}; +use openidconnect::{core::CoreClientMetadata, Nonce}; use serde::{Deserialize, Serialize}; #[cfg(not(target_arch = "wasm32"))] @@ -28,7 +28,7 @@ pub struct CodeEntry { #[derive(Clone, Serialize, Deserialize)] pub struct ClientEntry { pub secret: String, - pub redirect_uris: Vec, + pub metadata: CoreClientMetadata, } // Using a trait to easily pass async functions with async_trait diff --git a/src/oidc.rs b/src/oidc.rs index d5bc2a6..9ccf615 100644 --- a/src/oidc.rs +++ b/src/oidc.rs @@ -10,6 +10,7 @@ use openidconnect::{ CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet, CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey, CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims, + CoreUserInfoJsonWebToken, }, registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse}, url::Url, @@ -32,6 +33,7 @@ use super::db::*; #[cfg(not(target_arch = "wasm32"))] use siwe_oidc::db::*; +const SIGNING_ALG: [CoreJwsSigningAlgorithm; 1] = [CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256]; const KID: &str = "key1"; pub const METADATA_PATH: &str = "/.well-known/openid-configuration"; pub const JWK_PATH: &str = "/jwk"; @@ -67,16 +69,17 @@ pub enum CustomError { Other(#[from] anyhow::Error), } -pub fn jwks(private_key: RsaPrivateKey) -> Result { +fn jwk(private_key: RsaPrivateKey) -> Result { let pem = private_key .to_pkcs1_pem() .map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?; - let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem( - &pem, - Some(JsonWebKeyId::new(KID.to_string())), - ) - .map_err(|e| anyhow!("Invalid RSA private key: {}", e))? - .as_verification_key()]); + CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string()))) + .map_err(|e| anyhow!("Invalid RSA private key: {}", e)) +} + +pub fn jwks(private_key: RsaPrivateKey) -> Result { + let signing_key = jwk(private_key)?; + let jwks = CoreJsonWebKeySet::new(vec![signing_key.as_verification_key()]); Ok(jwks) } @@ -98,7 +101,7 @@ pub fn metadata(base_url: Url) -> Result { ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]), ], vec![CoreSubjectIdentifierType::Pairwise], - vec![CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256], + SIGNING_ALG.to_vec(), EmptyAdditionalProviderMetadata {}, ) .set_token_endpoint(Some(TokenUrl::from_url( @@ -111,6 +114,7 @@ pub fn metadata(base_url: Url) -> Result { .join(USERINFO_PATH) .map_err(|e| anyhow!("Unable to join URL: {}", e))?, ))) + .set_userinfo_signing_alg_values_supported(Some(SIGNING_ALG.to_vec())) .set_scopes_supported(Some(vec![ Scope::new("openid".to_string()), // Scope::new("email".to_string()), @@ -138,6 +142,7 @@ pub fn metadata(base_url: Url) -> Result { .set_token_endpoint_auth_methods_supported(Some(vec![ CoreClientAuthMethod::ClientSecretBasic, CoreClientAuthMethod::ClientSecretPost, + CoreClientAuthMethod::PrivateKeyJwt, ])); Ok(pm) @@ -274,7 +279,9 @@ pub async fn authorize( r_u.set_query(None); let mut r_us: Vec = client_entry .unwrap() - .redirect_uris + .metadata + .redirect_uris() + .clone() .iter_mut() .map(|u| u.url().clone()) .collect(); @@ -343,12 +350,7 @@ pub async fn authorize( }; Ok(format!( "/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}", - nonce, - domain, - params.redirect_uri.to_string(), - state, - params.client_id, - oidc_nonce_param + nonce, domain, *params.redirect_uri, state, params.client_id, oidc_nonce_param )) } @@ -484,15 +486,17 @@ pub async fn register( let id = Uuid::new_v4(); let secret = Uuid::new_v4(); + let redirect_uris = payload.redirect_uris().to_vec(); + let entry = ClientEntry { secret: secret.to_string(), - redirect_uris: payload.redirect_uris().to_vec(), + metadata: payload, }; db_client.set_client(id.to_string(), entry).await?; Ok(CoreClientRegistrationResponse::new( ClientId::new(id.to_string()), - payload.redirect_uris().to_vec(), + redirect_uris, EmptyAdditionalClientMetadata::default(), EmptyAdditionalClientRegistrationResponse::default(), ) @@ -504,11 +508,18 @@ pub struct UserInfoPayload { pub access_token: Option, } +pub enum UserInfoResponse { + Json(CoreUserInfoClaims), + Jwt(CoreUserInfoJsonWebToken), +} + pub async fn userinfo( + base_url: Url, + private_key: RsaPrivateKey, bearer: Option, payload: UserInfoPayload, db_client: &DBClientType, -) -> Result { +) -> Result { let code = if let Some(b) = bearer { b.token().to_string() } else if let Some(c) = payload.access_token { @@ -522,8 +533,26 @@ pub async fn userinfo( return Err(CustomError::BadRequest("Unknown code.".to_string())); }; - Ok(CoreUserInfoClaims::new( + let client_entry = if let Some(c) = db_client.get_client(code_entry.client_id.clone()).await? { + c + } else { + return Err(CustomError::BadRequest("Unknown client.".to_string())); + }; + + let response = CoreUserInfoClaims::new( StandardClaims::new(SubjectIdentifier::new(code_entry.address)), EmptyAdditionalClaims::default(), - )) + ) + .set_issuer(Some(IssuerUrl::from_url(base_url.clone()))) + .set_audiences(Some(vec![Audience::new(code_entry.client_id)])); + match client_entry.metadata.userinfo_signed_response_alg() { + None => Ok(UserInfoResponse::Json(response)), + Some(alg) => { + let signing_key = jwk(private_key)?; + Ok(UserInfoResponse::Jwt( + CoreUserInfoJsonWebToken::new(response, &signing_key, alg.clone()) + .map_err(|_| anyhow!("Error signing response."))?, + )) + } + } } diff --git a/src/worker_lib.rs b/src/worker_lib.rs index 43e0a9f..c1cde49 100644 --- a/src/worker_lib.rs +++ b/src/worker_lib.rs @@ -2,7 +2,7 @@ use anyhow::anyhow; use headers::{ self, authorization::{Basic, Bearer, Credentials}, - Authorization, Header, HeaderValue, + Authorization, ContentType, Header, HeaderValue, }; use rand::{distributions::Alphanumeric, Rng}; use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey}; @@ -57,10 +57,26 @@ pub async fn main(req: Request, env: Env) -> Result { } else { UserInfoPayload { access_token: None } }; + let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap(); + let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string()) + .map_err(|e| anyhow!("Failed to load private key: {}", e)) + .unwrap(); let url = req.url()?; let db_client = CFClient { ctx, url }; - match oidc::userinfo(bearer, payload, &db_client).await { - Ok(r) => Ok(Response::from_json(&r)?), + match oidc::userinfo(base_url, private_key, bearer, payload, &db_client).await { + Ok(oidc::UserInfoResponse::Json(r)) => Ok(Response::from_json(&r)?), + Ok(oidc::UserInfoResponse::Jwt(r)) => { + let mut headers = Headers::new(); + headers.append(&ContentType::name().to_string(), "application/jwt")?; + Ok(Response::from_bytes( + serde_json::to_string(&r) + .unwrap() + .replace('"', "") + .as_bytes() + .to_vec(), + )? + .with_headers(headers)) + } Err(e) => e.into(), } };