siwe-oidc/src/worker_lib.rs

211 lines
8.4 KiB
Rust
Raw Normal View History

use anyhow::anyhow;
use headers::{
self,
authorization::{Basic, Bearer, Credentials},
Authorization, Header, HeaderValue,
};
use rand::{distributions::Alphanumeric, Rng};
use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey};
use worker::*;
use super::db::CFClient;
use super::oidc::{self, CustomError, TokenForm, UserInfoPayload};
/// Name of the worker variable (read with `ctx.var`) holding this service's base URL;
/// it is parsed as a URL and used for provider metadata and for resolving redirects.
const BASE_URL_KEY: &str = "BASE_URL";
/// Name of the worker secret (read with `ctx.secret`) holding the RSA private key as a
/// PKCS#1 PEM string; it is used for the JWKS endpoint and for signing tokens.
const RSA_PEM_KEY: &str = "RSA_PEM";
// https://github.com/cloudflare/workers-rs/issues/64
// #[global_allocator]
// static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
impl From<CustomError> for Result<Response> {
fn from(error: CustomError) -> Self {
match error {
CustomError::BadRequest(_) => Response::error(&error.to_string(), 400),
CustomError::BadRequestToken(e) => Response::from_json(&e).map(|r| r.with_status(400)),
CustomError::Unauthorized(_) => Response::error(&error.to_string(), 401),
CustomError::Redirect(uri) => Response::redirect(uri.parse().unwrap()),
CustomError::Other(_) => Response::error(&error.to_string(), 500),
}
}
}
pub async fn main(req: Request, env: Env) -> Result<Response> {
console_error_panic_hook::set_once();
// tracing_subscriber::fmt::init();
// console_log::init_with_level(log::Level::Info).expect("error initializing log");
let userinfo = |mut req: Request, ctx: RouteContext<()>| async move {
let bearer = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(Bearer::decode);
let payload = if bearer.is_none() {
match req.form_data().await {
Ok(f) => {
let access_token = if let Some(FormEntry::Field(a)) = f.get("access_token") {
Some(a)
} else {
return Response::error("Missing code", 400);
};
UserInfoPayload { access_token }
}
Err(_) => return Response::error("Bad request", 400),
}
} else {
UserInfoPayload { access_token: None }
};
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::userinfo(bearer, payload, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?),
Err(e) => e.into(),
}
};
let router = Router::new();
router
.get_async(oidc::METADATA_PATH, |_req, ctx| async move {
match oidc::metadata(ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap()) {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
.get_async(oidc::JWK_PATH, |_req, ctx| async move {
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap();
match oidc::jwks(private_key) {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
.post_async(oidc::TOKEN_PATH, |mut req, ctx| async move {
let form_data = req.form_data().await?;
let code = if let Some(FormEntry::Field(c)) = form_data.get("code") {
c
} else {
return Response::error("Missing code", 400);
};
let client_id = match form_data.get("client_id") {
Some(FormEntry::Field(c)) => Some(c),
None => None,
_ => return Response::error("Client ID not a field", 400),
};
let client_secret = match form_data.get("client_secret") {
Some(FormEntry::Field(c)) => Some(c),
None => None,
_ => return Response::error("Client secret not a field", 400),
};
let grant_type = if let Some(FormEntry::Field(c)) = form_data.get("code") {
if let Ok(cc) = serde_json::from_str(&format!("\"{}\"", c)) {
cc
} else {
return Response::error("Invalid grant type", 400);
}
} else {
return Response::error("Missing grant type", 400);
};
let secret = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(|b| {
if b.to_str().unwrap().starts_with("Bearer") {
Bearer::decode(b).map(|bb| bb.token().to_string())
} else {
Basic::decode(b).map(|bb| bb.password().to_string())
}
});
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap();
let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let url = req.url()?;
let db_client = CFClient { ctx, url };
let token_response = oidc::token(
TokenForm {
code,
client_id,
client_secret,
grant_type,
},
secret,
private_key,
base_url,
false,
&db_client,
)
.await;
match token_response {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
// TODO add browser session
.get_async(oidc::AUTHORIZE_PATH, |req, ctx| async move {
let base_url: Url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let url = req.url()?;
let query = url.query().unwrap_or_default();
let params = match serde_urlencoded::from_str(query) {
Ok(p) => p,
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
};
let nonce = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect();
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::authorize(params, nonce, &db_client).await {
Ok(url) => Response::redirect(base_url.join(&url).unwrap()),
Err(e) => match e {
CustomError::Redirect(url) => {
CustomError::Redirect(base_url.join(&url).unwrap().to_string())
}
c => c,
}
.into(),
}
})
.post_async(oidc::REGISTER_PATH, |mut req, ctx| async move {
let payload = req.json().await?;
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::register(payload, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?.with_status(201)),
Err(e) => e.into(),
}
})
.post_async(oidc::USERINFO_PATH, userinfo)
.get_async(oidc::USERINFO_PATH, userinfo)
.get_async(oidc::SIGNIN_PATH, |req, ctx| async move {
let url = req.url()?;
let query = url.query().unwrap_or_default();
let params = match serde_urlencoded::from_str(query) {
Ok(p) => p,
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
};
let cookies = req
.headers()
.get(headers::Cookie::name().as_str())?
.and_then(|c| HeaderValue::from_str(&c).ok())
.and_then(|c| headers::Cookie::decode(&mut [c].iter()).ok());
if cookies.is_none() {
return Response::error("Missing cookies", 400);
}
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::sign_in(params, None, cookies.unwrap(), &db_client).await {
Ok(url) => Response::redirect(url),
Err(e) => e.into(),
}
})
.run(req, env)
.await
}