JWT support for userinfo

This commit is contained in:
Simon Bihel 2022-01-19 17:07:55 +00:00
parent bd3d2e8a1e
commit 5c0b748373
No known key found for this signature in database
GPG Key ID: B7013150BEAA28FD
4 changed files with 129 additions and 40 deletions

View File

@ -18,11 +18,11 @@ use figment::{
use headers::{ use headers::{
self, self,
authorization::{Basic, Bearer}, authorization::{Basic, Bearer},
Authorization, Authorization, ContentType, Header,
}; };
use openidconnect::core::{ use openidconnect::core::{
CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata, CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata,
CoreResponseType, CoreTokenResponse, CoreUserInfoClaims, CoreResponseType, CoreTokenResponse, CoreUserInfoClaims, CoreUserInfoJsonWebToken,
}; };
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rsa::{ use rsa::{
@ -221,20 +221,64 @@ async fn register(
Ok((StatusCode::CREATED, registration.into())) Ok((StatusCode::CREATED, registration.into()))
} }
struct UserInfoResponseJWT(Json<CoreUserInfoJsonWebToken>);
impl IntoResponse for UserInfoResponseJWT {
fn into_response(self) -> response::Response {
response::Response::builder()
.status(StatusCode::OK)
.header(ContentType::name(), "application/jwt")
.body(
serde_json::to_string(&self.0 .0)
.unwrap()
.replace('"', "")
.into_response()
.into_body(),
)
.unwrap()
}
}
enum UserInfoResponse {
Json(Json<CoreUserInfoClaims>),
Jwt(UserInfoResponseJWT),
}
impl IntoResponse for UserInfoResponse {
fn into_response(self) -> response::Response {
match self {
UserInfoResponse::Json(j) => j.into_response(),
UserInfoResponse::Jwt(j) => j.into_response(),
}
}
}
// TODO CORS // TODO CORS
// TODO need validation of the token // TODO need validation of the token
async fn userinfo( async fn userinfo(
Extension(private_key): Extension<RsaPrivateKey>,
Extension(config): Extension<config::Config>,
payload: Option<Form<oidc::UserInfoPayload>>, payload: Option<Form<oidc::UserInfoPayload>>,
bearer: Option<TypedHeader<Authorization<Bearer>>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs bearer: Option<TypedHeader<Authorization<Bearer>>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
Extension(redis_client): Extension<RedisClient>, Extension(redis_client): Extension<RedisClient>,
) -> Result<Json<CoreUserInfoClaims>, CustomError> { ) -> Result<UserInfoResponse, CustomError> {
let payload = if let Some(Form(p)) = payload { let payload = if let Some(Form(p)) = payload {
p p
} else { } else {
oidc::UserInfoPayload { access_token: None } oidc::UserInfoPayload { access_token: None }
}; };
let claims = oidc::userinfo(bearer.map(|b| b.0 .0), payload, &redis_client).await?; let claims = oidc::userinfo(
Ok(claims.into()) config.base_url,
private_key,
bearer.map(|b| b.0 .0),
payload,
&redis_client,
)
.await?;
Ok(match claims {
oidc::UserInfoResponse::Json(c) => UserInfoResponse::Json(c.into()),
oidc::UserInfoResponse::Jwt(c) => UserInfoResponse::Jwt(UserInfoResponseJWT(c.into())),
})
} }
async fn healthcheck() {} async fn healthcheck() {}
@ -253,16 +297,16 @@ pub async fn main() {
let redis_client = RedisClient { pool }; let redis_client = RedisClient { pool };
for (id, secret) in &config.default_clients.clone() { // for (id, secret) in &config.default_clients.clone() {
let client_entry = ClientEntry { // let client_entry = ClientEntry {
secret: secret.to_string(), // secret: secret.to_string(),
redirect_uris: vec![], // redirect_uris: vec![],
}; // };
redis_client // redis_client
.set_client(id.to_string(), client_entry) // .set_client(id.to_string(), client_entry)
.await // .await
.unwrap(); // TODO // .unwrap(); // TODO
} // }
let private_key = if let Some(key) = &config.rsa_pem { let private_key = if let Some(key) = &config.rsa_pem {
RsaPrivateKey::from_pkcs1_pem(key) RsaPrivateKey::from_pkcs1_pem(key)

View File

@ -1,7 +1,7 @@
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{offset::Utc, DateTime}; use chrono::{offset::Utc, DateTime};
use openidconnect::{Nonce, RedirectUrl}; use openidconnect::{core::CoreClientMetadata, Nonce};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -28,7 +28,7 @@ pub struct CodeEntry {
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct ClientEntry { pub struct ClientEntry {
pub secret: String, pub secret: String,
pub redirect_uris: Vec<RedirectUrl>, pub metadata: CoreClientMetadata,
} }
// Using a trait to easily pass async functions with async_trait // Using a trait to easily pass async functions with async_trait

View File

@ -10,6 +10,7 @@ use openidconnect::{
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet, CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey, CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims, CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
CoreUserInfoJsonWebToken,
}, },
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse}, registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
url::Url, url::Url,
@ -32,6 +33,7 @@ use super::db::*;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use siwe_oidc::db::*; use siwe_oidc::db::*;
const SIGNING_ALG: [CoreJwsSigningAlgorithm; 1] = [CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256];
const KID: &str = "key1"; const KID: &str = "key1";
pub const METADATA_PATH: &str = "/.well-known/openid-configuration"; pub const METADATA_PATH: &str = "/.well-known/openid-configuration";
pub const JWK_PATH: &str = "/jwk"; pub const JWK_PATH: &str = "/jwk";
@ -67,16 +69,17 @@ pub enum CustomError {
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
pub fn jwks(private_key: RsaPrivateKey) -> Result<CoreJsonWebKeySet, CustomError> { fn jwk(private_key: RsaPrivateKey) -> Result<CoreRsaPrivateSigningKey> {
let pem = private_key let pem = private_key
.to_pkcs1_pem() .to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?; .map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem( CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
&pem, .map_err(|e| anyhow!("Invalid RSA private key: {}", e))
Some(JsonWebKeyId::new(KID.to_string())), }
)
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))? pub fn jwks(private_key: RsaPrivateKey) -> Result<CoreJsonWebKeySet, CustomError> {
.as_verification_key()]); let signing_key = jwk(private_key)?;
let jwks = CoreJsonWebKeySet::new(vec![signing_key.as_verification_key()]);
Ok(jwks) Ok(jwks)
} }
@ -98,7 +101,7 @@ pub fn metadata(base_url: Url) -> Result<CoreProviderMetadata, CustomError> {
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]), ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
], ],
vec![CoreSubjectIdentifierType::Pairwise], vec![CoreSubjectIdentifierType::Pairwise],
vec![CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256], SIGNING_ALG.to_vec(),
EmptyAdditionalProviderMetadata {}, EmptyAdditionalProviderMetadata {},
) )
.set_token_endpoint(Some(TokenUrl::from_url( .set_token_endpoint(Some(TokenUrl::from_url(
@ -111,6 +114,7 @@ pub fn metadata(base_url: Url) -> Result<CoreProviderMetadata, CustomError> {
.join(USERINFO_PATH) .join(USERINFO_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?, .map_err(|e| anyhow!("Unable to join URL: {}", e))?,
))) )))
.set_userinfo_signing_alg_values_supported(Some(SIGNING_ALG.to_vec()))
.set_scopes_supported(Some(vec![ .set_scopes_supported(Some(vec![
Scope::new("openid".to_string()), Scope::new("openid".to_string()),
// Scope::new("email".to_string()), // Scope::new("email".to_string()),
@ -138,6 +142,7 @@ pub fn metadata(base_url: Url) -> Result<CoreProviderMetadata, CustomError> {
.set_token_endpoint_auth_methods_supported(Some(vec![ .set_token_endpoint_auth_methods_supported(Some(vec![
CoreClientAuthMethod::ClientSecretBasic, CoreClientAuthMethod::ClientSecretBasic,
CoreClientAuthMethod::ClientSecretPost, CoreClientAuthMethod::ClientSecretPost,
CoreClientAuthMethod::PrivateKeyJwt,
])); ]));
Ok(pm) Ok(pm)
@ -274,7 +279,9 @@ pub async fn authorize(
r_u.set_query(None); r_u.set_query(None);
let mut r_us: Vec<Url> = client_entry let mut r_us: Vec<Url> = client_entry
.unwrap() .unwrap()
.redirect_uris .metadata
.redirect_uris()
.clone()
.iter_mut() .iter_mut()
.map(|u| u.url().clone()) .map(|u| u.url().clone())
.collect(); .collect();
@ -343,12 +350,7 @@ pub async fn authorize(
}; };
Ok(format!( Ok(format!(
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}", "/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
nonce, nonce, domain, *params.redirect_uri, state, params.client_id, oidc_nonce_param
domain,
params.redirect_uri.to_string(),
state,
params.client_id,
oidc_nonce_param
)) ))
} }
@ -484,15 +486,17 @@ pub async fn register(
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let secret = Uuid::new_v4(); let secret = Uuid::new_v4();
let redirect_uris = payload.redirect_uris().to_vec();
let entry = ClientEntry { let entry = ClientEntry {
secret: secret.to_string(), secret: secret.to_string(),
redirect_uris: payload.redirect_uris().to_vec(), metadata: payload,
}; };
db_client.set_client(id.to_string(), entry).await?; db_client.set_client(id.to_string(), entry).await?;
Ok(CoreClientRegistrationResponse::new( Ok(CoreClientRegistrationResponse::new(
ClientId::new(id.to_string()), ClientId::new(id.to_string()),
payload.redirect_uris().to_vec(), redirect_uris,
EmptyAdditionalClientMetadata::default(), EmptyAdditionalClientMetadata::default(),
EmptyAdditionalClientRegistrationResponse::default(), EmptyAdditionalClientRegistrationResponse::default(),
) )
@ -504,11 +508,18 @@ pub struct UserInfoPayload {
pub access_token: Option<String>, pub access_token: Option<String>,
} }
pub enum UserInfoResponse {
Json(CoreUserInfoClaims),
Jwt(CoreUserInfoJsonWebToken),
}
pub async fn userinfo( pub async fn userinfo(
base_url: Url,
private_key: RsaPrivateKey,
bearer: Option<Bearer>, bearer: Option<Bearer>,
payload: UserInfoPayload, payload: UserInfoPayload,
db_client: &DBClientType, db_client: &DBClientType,
) -> Result<CoreUserInfoClaims, CustomError> { ) -> Result<UserInfoResponse, CustomError> {
let code = if let Some(b) = bearer { let code = if let Some(b) = bearer {
b.token().to_string() b.token().to_string()
} else if let Some(c) = payload.access_token { } else if let Some(c) = payload.access_token {
@ -522,8 +533,26 @@ pub async fn userinfo(
return Err(CustomError::BadRequest("Unknown code.".to_string())); return Err(CustomError::BadRequest("Unknown code.".to_string()));
}; };
Ok(CoreUserInfoClaims::new( let client_entry = if let Some(c) = db_client.get_client(code_entry.client_id.clone()).await? {
c
} else {
return Err(CustomError::BadRequest("Unknown client.".to_string()));
};
let response = CoreUserInfoClaims::new(
StandardClaims::new(SubjectIdentifier::new(code_entry.address)), StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims::default(), EmptyAdditionalClaims::default(),
)
.set_issuer(Some(IssuerUrl::from_url(base_url.clone())))
.set_audiences(Some(vec![Audience::new(code_entry.client_id)]));
match client_entry.metadata.userinfo_signed_response_alg() {
None => Ok(UserInfoResponse::Json(response)),
Some(alg) => {
let signing_key = jwk(private_key)?;
Ok(UserInfoResponse::Jwt(
CoreUserInfoJsonWebToken::new(response, &signing_key, alg.clone())
.map_err(|_| anyhow!("Error signing response."))?,
)) ))
} }
}
}

View File

@ -2,7 +2,7 @@ use anyhow::anyhow;
use headers::{ use headers::{
self, self,
authorization::{Basic, Bearer, Credentials}, authorization::{Basic, Bearer, Credentials},
Authorization, Header, HeaderValue, Authorization, ContentType, Header, HeaderValue,
}; };
use rand::{distributions::Alphanumeric, Rng}; use rand::{distributions::Alphanumeric, Rng};
use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey}; use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey};
@ -57,10 +57,26 @@ pub async fn main(req: Request, env: Env) -> Result<Response> {
} else { } else {
UserInfoPayload { access_token: None } UserInfoPayload { access_token: None }
}; };
let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap();
let url = req.url()?; let url = req.url()?;
let db_client = CFClient { ctx, url }; let db_client = CFClient { ctx, url };
match oidc::userinfo(bearer, payload, &db_client).await { match oidc::userinfo(base_url, private_key, bearer, payload, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?), Ok(oidc::UserInfoResponse::Json(r)) => Ok(Response::from_json(&r)?),
Ok(oidc::UserInfoResponse::Jwt(r)) => {
let mut headers = Headers::new();
headers.append(&ContentType::name().to_string(), "application/jwt")?;
Ok(Response::from_bytes(
serde_json::to_string(&r)
.unwrap()
.replace('"', "")
.as_bytes()
.to_vec(),
)?
.with_headers(headers))
}
Err(e) => e.into(), Err(e) => e.into(),
} }
}; };