Complete client configuration endpoints

This commit is contained in:
Simon Bihel 2022-02-14 22:29:12 +00:00
parent 27e36e2aa6
commit e0241feb9a
No known key found for this signature in database
GPG Key ID: B7013150BEAA28FD
7 changed files with 180 additions and 19 deletions

View File

@ -7,7 +7,7 @@ use axum::{
StatusCode, StatusCode,
}, },
response::{self, IntoResponse, Redirect}, response::{self, IntoResponse, Redirect},
routing::{get, get_service, post}, routing::{delete, get, get_service, post},
AddExtensionLayer, Json, Router, AddExtensionLayer, Json, Router,
}; };
use bb8_redis::{bb8, RedisConnectionManager}; use bb8_redis::{bb8, RedisConnectionManager};
@ -220,9 +220,10 @@ async fn sign_in(
async fn register( async fn register(
extract::Json(payload): extract::Json<CoreClientMetadata>, extract::Json(payload): extract::Json<CoreClientMetadata>,
Extension(config): Extension<config::Config>,
Extension(redis_client): Extension<RedisClient>, Extension(redis_client): Extension<RedisClient>,
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> { ) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
let registration = oidc::register(payload, &redis_client).await?; let registration = oidc::register(payload, config.base_url, &redis_client).await?;
Ok((StatusCode::CREATED, registration.into())) Ok((StatusCode::CREATED, registration.into()))
} }
@ -289,9 +290,34 @@ async fn userinfo(
async fn clientinfo( async fn clientinfo(
Path(client_id): Path<String>, Path(client_id): Path<String>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
Extension(redis_client): Extension<RedisClient>, Extension(redis_client): Extension<RedisClient>,
) -> Result<Json<CoreClientMetadata>, CustomError> { ) -> Result<Json<CoreClientMetadata>, CustomError> {
Ok(oidc::clientinfo(client_id, &redis_client).await?.into()) Ok(
oidc::clientinfo(client_id, bearer.map(|b| b.0 .0), &redis_client)
.await?
.into(),
)
}
async fn client_update(
Path(client_id): Path<String>,
extract::Json(payload): extract::Json<CoreClientMetadata>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(), CustomError> {
Ok(oidc::client_update(client_id, payload, bearer.map(|b| b.0 .0), &redis_client).await?)
}
async fn client_delete(
Path(client_id): Path<String>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(StatusCode, ()), CustomError> {
Ok((
StatusCode::NO_CONTENT,
oidc::client_delete(client_id, bearer.map(|b| b.0 .0), &redis_client).await?,
))
} }
async fn healthcheck() {} async fn healthcheck() {}
@ -400,7 +426,9 @@ pub async fn main() {
.route(oidc::AUTHORIZE_PATH, get(authorize)) .route(oidc::AUTHORIZE_PATH, get(authorize))
.route(oidc::REGISTER_PATH, post(register)) .route(oidc::REGISTER_PATH, post(register))
.route(oidc::USERINFO_PATH, get(userinfo).post(userinfo)) .route(oidc::USERINFO_PATH, get(userinfo).post(userinfo))
.route(&format!("{}/:id", oidc::CLIENTINFO_PATH), get(clientinfo)) .route(&format!("{}/:id", oidc::CLIENT_PATH), get(clientinfo))
.route(&format!("{}/:id", oidc::CLIENT_PATH), delete(client_delete))
.route(&format!("{}/:id", oidc::CLIENT_PATH), post(client_update))
.route(oidc::SIGNIN_PATH, get(sign_in)) .route(oidc::SIGNIN_PATH, get(sign_in))
.route("/health", get(healthcheck)) .route("/health", get(healthcheck))
.layer(AddExtensionLayer::new(private_key)) .layer(AddExtensionLayer::new(private_key))

View File

@ -115,6 +115,7 @@ impl DBClient for CFClient {
.map_err(|e| anyhow!("Failed to put KV: {}", e))?; .map_err(|e| anyhow!("Failed to put KV: {}", e))?;
Ok(()) Ok(())
} }
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> { async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
Ok(self Ok(self
.ctx .ctx
@ -125,6 +126,17 @@ impl DBClient for CFClient {
.await .await
.map_err(|e| anyhow!("Failed to get KV: {}", e))?) .map_err(|e| anyhow!("Failed to get KV: {}", e))?)
} }
async fn delete_client(&self, client_id: String) -> Result<()> {
Ok(self
.ctx
.kv(KV_NAMESPACE)
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
.delete(&format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get KV: {}", e))?)
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> { async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let namespace = self let namespace = self
.ctx .ctx

View File

@ -2,7 +2,7 @@ use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{offset::Utc, DateTime}; use chrono::{offset::Utc, DateTime};
use ethers_core::types::H160; use ethers_core::types::H160;
use openidconnect::{core::CoreClientMetadata, Nonce}; use openidconnect::{core::CoreClientMetadata, Nonce, RegistrationAccessToken};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -30,6 +30,7 @@ pub struct CodeEntry {
pub struct ClientEntry { pub struct ClientEntry {
pub secret: String, pub secret: String,
pub metadata: CoreClientMetadata, pub metadata: CoreClientMetadata,
pub access_token: Option<RegistrationAccessToken>,
} }
// Using a trait to easily pass async functions with async_trait // Using a trait to easily pass async functions with async_trait
@ -38,6 +39,7 @@ pub struct ClientEntry {
pub trait DBClient { pub trait DBClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>; async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>;
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>; async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>;
async fn delete_client(&self, client_id: String) -> Result<()>;
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>; async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>;
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>; async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>;
} }

View File

@ -47,6 +47,18 @@ impl DBClient for RedisClient {
} }
} }
async fn delete_client(&self, client_id: String) -> Result<()> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.del(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
Ok(())
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> { async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let mut conn = self let mut conn = self
.pool .pool

View File

@ -15,12 +15,13 @@ use openidconnect::{
}, },
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse}, registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
url::Url, url::Url,
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims, AccessToken, Audience, AuthUrl, ClientConfigUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, EndUserPictureUrl, EndUserUsername, EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, EndUserPictureUrl, EndUserUsername,
IssuerUrl, JsonWebKeyId, JsonWebKeySetUrl, LocalizedClaim, Nonce, PrivateSigningKey, IssuerUrl, JsonWebKeyId, JsonWebKeySetUrl, LocalizedClaim, Nonce, PrivateSigningKey,
RedirectUrl, RegistrationUrl, RequestUrl, ResponseTypes, Scope, StandardClaims, RedirectUrl, RegistrationAccessToken, RegistrationUrl, RequestUrl, ResponseTypes, Scope,
SubjectIdentifier, TokenUrl, UserInfoUrl, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
}; };
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use rsa::{pkcs1::ToRsaPrivateKey, RsaPrivateKey}; use rsa::{pkcs1::ToRsaPrivateKey, RsaPrivateKey};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use siwe::{Message, TimeStamp, Version}; use siwe::{Message, TimeStamp, Version};
@ -48,9 +49,9 @@ pub const JWK_PATH: &str = "/jwk";
pub const TOKEN_PATH: &str = "/token"; pub const TOKEN_PATH: &str = "/token";
pub const AUTHORIZE_PATH: &str = "/authorize"; pub const AUTHORIZE_PATH: &str = "/authorize";
pub const REGISTER_PATH: &str = "/register"; pub const REGISTER_PATH: &str = "/register";
pub const CLIENT_PATH: &str = "/client";
pub const USERINFO_PATH: &str = "/userinfo"; pub const USERINFO_PATH: &str = "/userinfo";
pub const SIGNIN_PATH: &str = "/sign_in"; pub const SIGNIN_PATH: &str = "/sign_in";
pub const CLIENTINFO_PATH: &str = "/clientinfo";
pub const SIWE_COOKIE_KEY: &str = "siwe"; pub const SIWE_COOKIE_KEY: &str = "siwe";
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -535,6 +536,7 @@ pub struct RegisterError {
pub async fn register( pub async fn register(
payload: CoreClientMetadata, payload: CoreClientMetadata,
base_url: Url,
db_client: &DBClientType, db_client: &DBClientType,
) -> Result<CoreClientRegistrationResponse, CustomError> { ) -> Result<CoreClientRegistrationResponse, CustomError> {
let id = Uuid::new_v4(); let id = Uuid::new_v4();
@ -549,9 +551,18 @@ pub async fn register(
} }
} }
let access_token = RegistrationAccessToken::new(
thread_rng()
.sample_iter(&Alphanumeric)
.take(11)
.map(char::from)
.collect(),
);
let entry = ClientEntry { let entry = ClientEntry {
secret: secret.to_string(), secret: secret.to_string(),
metadata: payload, metadata: payload,
access_token: Some(access_token.clone()),
}; };
db_client.set_client(id.to_string(), entry).await?; db_client.set_client(id.to_string(), entry).await?;
@ -561,18 +572,62 @@ pub async fn register(
EmptyAdditionalClientMetadata::default(), EmptyAdditionalClientMetadata::default(),
EmptyAdditionalClientRegistrationResponse::default(), EmptyAdditionalClientRegistrationResponse::default(),
) )
.set_client_secret(Some(ClientSecret::new(secret.to_string())))) .set_client_secret(Some(ClientSecret::new(secret.to_string())))
.set_registration_client_uri(Some(ClientConfigUrl::from_url(
base_url
.join(&format!("{}/{}", CLIENT_PATH, id))
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_registration_access_token(Some(access_token)))
}
async fn client_access(
client_id: String,
bearer: Option<Bearer>,
db_client: &DBClientType,
) -> Result<ClientEntry, CustomError> {
let access_token = if let Some(b) = bearer {
b.token().to_string()
} else {
return Err(CustomError::BadRequest("Missing access token.".to_string()));
};
let client_entry = db_client
.get_client(client_id)
.await?
.ok_or(CustomError::NotFound)?;
let stored_access_token = client_entry.access_token.clone();
if stored_access_token.is_none() || *stored_access_token.unwrap().secret() != access_token {
return Err(CustomError::Unauthorized("Bad access token.".to_string()));
}
Ok(client_entry)
} }
pub async fn clientinfo( pub async fn clientinfo(
client_id: String, client_id: String,
bearer: Option<Bearer>,
db_client: &DBClientType, db_client: &DBClientType,
) -> Result<CoreClientMetadata, CustomError> { ) -> Result<CoreClientMetadata, CustomError> {
let client_entry = db_client Ok(client_access(client_id, bearer, db_client).await?.metadata)
.get_client(client_id) }
.await?
.ok_or_else(|| CustomError::NotFound)?; pub async fn client_delete(
Ok(client_entry.metadata) client_id: String,
bearer: Option<Bearer>,
db_client: &DBClientType,
) -> Result<(), CustomError> {
client_access(client_id.clone(), bearer, db_client).await?;
Ok(db_client.delete_client(client_id).await?)
}
pub async fn client_update(
client_id: String,
payload: CoreClientMetadata,
bearer: Option<Bearer>,
db_client: &DBClientType,
) -> Result<(), CustomError> {
let mut client_entry = client_access(client_id.clone(), bearer, db_client).await?;
client_entry.metadata = payload;
Ok(db_client.set_client(client_id, client_entry).await?)
} }
#[derive(Deserialize)] #[derive(Deserialize)]

View File

@ -226,9 +226,10 @@ pub async fn main(req: Request, env: Env) -> Result<Response> {
}) })
.post_async(oidc::REGISTER_PATH, |mut req, ctx| async move { .post_async(oidc::REGISTER_PATH, |mut req, ctx| async move {
let payload = req.json().await?; let payload = req.json().await?;
let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let url = req.url()?; let url = req.url()?;
let db_client = CFClient { ctx, url }; let db_client = CFClient { ctx, url };
match oidc::register(payload, &db_client).await { match oidc::register(payload, base_url, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?.with_status(201)), Ok(r) => Ok(Response::from_json(&r)?.with_status(201)),
Err(e) => e.into(), Err(e) => e.into(),
} }
@ -236,21 +237,72 @@ pub async fn main(req: Request, env: Env) -> Result<Response> {
.post_async(oidc::USERINFO_PATH, userinfo) .post_async(oidc::USERINFO_PATH, userinfo)
.get_async(oidc::USERINFO_PATH, userinfo) .get_async(oidc::USERINFO_PATH, userinfo)
.get_async( .get_async(
&format!("{}/:id", oidc::CLIENTINFO_PATH), &format!("{}/:id", oidc::CLIENT_PATH),
|req, ctx| async move { |req, ctx| async move {
let client_id = if let Some(id) = ctx.param("id") { let client_id = if let Some(id) = ctx.param("id") {
id.clone() id.clone()
} else { } else {
return Response::error("Bad Request", 400); return Response::error("Bad Request", 400);
}; };
let bearer = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(Bearer::decode);
let url = req.url()?; let url = req.url()?;
let db_client = CFClient { ctx, url }; let db_client = CFClient { ctx, url };
match oidc::clientinfo(client_id, &db_client).await { match oidc::clientinfo(client_id, bearer, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?), Ok(r) => Ok(Response::from_json(&r)?),
Err(e) => e.into(), Err(e) => e.into(),
} }
}, },
) )
.delete_async(
&format!("{}/:id", oidc::CLIENT_PATH),
|req, ctx| async move {
let client_id = if let Some(id) = ctx.param("id") {
id.clone()
} else {
return Response::error("Bad Request", 400);
};
let bearer = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(Bearer::decode);
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::client_delete(client_id, bearer, &db_client).await {
Ok(()) => Ok(Response::empty()?.with_status(204)),
Err(e) => e.into(),
}
},
)
.post_async(
&format!("{}/:id", oidc::CLIENT_PATH),
|mut req, ctx| async move {
let client_id = if let Some(id) = ctx.param("id") {
id.clone()
} else {
return Response::error("Bad Request", 400);
};
let bearer = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(Bearer::decode);
let payload = req.json().await?;
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::client_update(client_id, payload, bearer, &db_client).await {
Ok(()) => Ok(Response::empty()?),
Err(e) => e.into(),
}
},
)
.get_async(oidc::SIGNIN_PATH, |req, ctx| async move { .get_async(oidc::SIGNIN_PATH, |req, ctx| async move {
let url = req.url()?; let url = req.url()?;
let query = url.query().unwrap_or_default(); let query = url.query().unwrap_or_default();

View File

@ -10,7 +10,7 @@ kv_namespaces = [
] ]
[vars] [vars]
WORKERS_RS_VERSION = "0.0.7" WORKERS_RS_VERSION = "0.0.9"
BASE_URL = "https://siweoidc.spruceid.xyz" BASE_URL = "https://siweoidc.spruceid.xyz"
# ETH_PROVIDER = "" # ETH_PROVIDER = ""