Address some issues from the core conformance suite (#4)

Also address Clippy warnings
This commit is contained in:
Simon Bihel 2021-12-20 16:29:43 +00:00 committed by GitHub
parent 0287a60296
commit c37577f218
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1117 additions and 5035 deletions

5682
js/ui/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -41,11 +41,11 @@
"typescript": "^4.0.3", "typescript": "^4.0.3",
"webpack": "^5.16.0", "webpack": "^5.16.0",
"webpack-cli": "^4.4.0", "webpack-cli": "^4.4.0",
"webpack-dev-server": "^3.11.2" "webpack-dev-server": "^4.6.0"
}, },
"scripts": { "scripts": {
"build": "cross-env NODE_ENV=production webpack", "build": "cross-env NODE_ENV=production webpack",
"dev": "webpack serve --content-base ../../static --port 9080", "dev": "webpack serve --static-directory ../../static --port 9080",
"validate": "svelte-check" "validate": "svelte-check"
}, },
"dependencies": { "dependencies": {

View File

@ -13,6 +13,7 @@
export let redirect: string; export let redirect: string;
export let state: string; export let state: string;
export let oidc_nonce: string; export let oidc_nonce: string;
export let client_id: string;
let uri: string = window.location.href.split('?')[0]; let uri: string = window.location.href.split('?')[0];
@ -90,7 +91,7 @@
client.on('signIn', (result) => { client.on('signIn', (result) => {
console.log(result); console.log(result);
window.location.replace( window.location.replace(
`/sign_in?redirect_uri=${encodeURI(redirect)}&state=${encodeURI(state)}${encodeURI(oidc_nonce_param)}`, `/sign_in?redirect_uri=${encodeURI(redirect)}&state=${encodeURI(state)}&client_id=${encodeURI(client_id)}${encodeURI(oidc_nonce_param)}`,
); );
}); });
</script> </script>

View File

@ -11,7 +11,8 @@ const app = new App({
nonce: params.get('nonce'), nonce: params.get('nonce'),
redirect: params.get('redirect_uri'), redirect: params.get('redirect_uri'),
state: params.get('state'), state: params.get('state'),
oidc_nonce: params.get('oidc_nonce') oidc_nonce: params.get('oidc_nonce'),
client_id: params.get('client_id')
} }
}); });

View File

@ -26,7 +26,7 @@ impl Default for Config {
rsa_pem: None, rsa_pem: None,
redis_url: Url::parse("redis://localhost").unwrap(), redis_url: Url::parse("redis://localhost").unwrap(),
default_clients: HashMap::default(), default_clients: HashMap::default(),
require_secret: true, require_secret: false,
} }
} }
} }

43
src/db.rs Normal file
View File

@ -0,0 +1,43 @@
use anyhow::{anyhow, Result};
use bb8_redis::{bb8::PooledConnection, redis::AsyncCommands, RedisConnectionManager};
use openidconnect::RedirectUrl;
use serde::{Deserialize, Serialize};
const KV_CLIENT_PREFIX: &str = "clients";
#[derive(Serialize, Deserialize)]
pub struct ClientEntry {
pub secret: String,
pub redirect_uris: Vec<RedirectUrl>,
}
pub async fn set_client(
mut conn: PooledConnection<'_, RedisConnectionManager>,
client_id: String,
client_entry: ClientEntry,
) -> Result<()> {
conn.set(
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
pub async fn get_client(
mut conn: PooledConnection<'_, RedisConnectionManager>,
client_id: String,
) -> Result<Option<ClientEntry>> {
let entry: Option<String> = conn
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}

View File

@ -23,16 +23,18 @@ use hex::FromHex;
use iri_string::types::{UriAbsoluteString, UriString}; use iri_string::types::{UriAbsoluteString, UriString};
use openidconnect::{ use openidconnect::{
core::{ core::{
CoreClaimName, CoreClientAuthMethod, CoreClientMetadata, CoreClientRegistrationResponse, CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
CoreGrantType, CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet, CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey, CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims, CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
}, },
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse}, registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
AccessToken, Audience, AuthUrl, ClientId, EmptyAdditionalClaims, url::Url,
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId, EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, ResponseTypes, Scope, JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl, ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
}; };
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rsa::{ use rsa::{
@ -52,20 +54,28 @@ use urlencoding::decode;
use uuid::Uuid; use uuid::Uuid;
mod config; mod config;
mod db;
mod session; mod session;
use db::*;
use session::*; use session::*;
const KID: &str = "key1"; const KID: &str = "key1";
const KV_CLIENT_PREFIX: &str = "clients"; const ENTRY_LIFETIME: usize = 30;
const ENTRY_LIFETIME: usize = 60 * 60 * 24 * 2;
type ConnectionPool = Pool<RedisConnectionManager>; type ConnectionPool = Pool<RedisConnectionManager>;
#[derive(Serialize, Debug)]
pub struct TokenError {
pub error: CoreErrorResponseType,
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum CustomError { pub enum CustomError {
#[error("{0}")] #[error("{0}")]
BadRequest(String), BadRequest(String),
#[error("{0:?}")]
BadRequestToken(Json<TokenError>),
#[error("{0}")] #[error("{0}")]
Unauthorized(String), Unauthorized(String),
#[error(transparent)] #[error(transparent)]
@ -78,11 +88,17 @@ impl IntoResponse for CustomError {
fn into_response(self) -> Response<Self::Body> { fn into_response(self) -> Response<Self::Body> {
match self { match self {
CustomError::BadRequest(_) => (StatusCode::BAD_REQUEST, self.to_string()), CustomError::BadRequest(_) => {
CustomError::Unauthorized(_) => (StatusCode::UNAUTHORIZED, self.to_string()), (StatusCode::BAD_REQUEST, self.to_string()).into_response()
CustomError::Other(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), }
CustomError::BadRequestToken(e) => (StatusCode::BAD_REQUEST, e).into_response(),
CustomError::Unauthorized(_) => {
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
}
CustomError::Other(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
} }
.into_response()
} }
} }
@ -163,29 +179,33 @@ async fn provider_metadata(
.join("register") .join("register")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?, .map_err(|e| anyhow!("Unable to join URL: {}", e))?,
))) )))
.set_token_endpoint_auth_methods_supported(Some(vec![CoreClientAuthMethod::ClientSecretPost])); .set_token_endpoint_auth_methods_supported(Some(vec![
CoreClientAuthMethod::ClientSecretBasic,
CoreClientAuthMethod::ClientSecretPost,
]));
Ok(pm.into()) Ok(pm.into())
} }
#[derive(Deserialize)] #[derive(Serialize, Deserialize)]
struct TokenForm { struct TokenForm {
code: String, code: String,
client_id: String, client_id: Option<String>,
client_secret: Option<String>, client_secret: Option<String>,
grant_type: CoreGrantType, // TODO should just be authorization_code apparently? grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
} }
// TODO should check Authorization header // TODO should check Authorization header
// Actually, client secret can be // Actually, client secret can be
// 1. in the POST (currently supported) // 1. in the POST (currently supported) [x]
// 2. Authorization header // 2. Authorization header [x]
// 3. JWT // 3. JWT [ ]
// 4. signed JWT // 4. signed JWT [ ]
// according to Keycloak // according to Keycloak
async fn token( async fn token(
form: Form<TokenForm>, form: Form<TokenForm>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
Extension(private_key): Extension<RsaPrivateKey>, Extension(private_key): Extension<RsaPrivateKey>,
Extension(config): Extension<config::Config>, Extension(config): Extension<config::Config>,
Extension(pool): Extension<ConnectionPool>, Extension(pool): Extension<ConnectionPool>,
@ -195,29 +215,17 @@ async fn token(
.await .await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?; .map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
if let Some(secret) = form.client_secret.clone() {
let stored_secret: Option<String> = conn
.get(format!("{}/{}", KV_CLIENT_PREFIX, form.client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if stored_secret.is_none() {
Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
))?;
}
if secret != stored_secret.unwrap() {
Err(CustomError::Unauthorized("Bad secret.".to_string()))?;
}
} else if config.require_secret {
Err(CustomError::Unauthorized("Secret required.".to_string()))?;
}
let serialized_entry: Option<Vec<u8>> = conn let serialized_entry: Option<Vec<u8>> = conn
.get(form.code.to_string()) .get(form.code.to_string())
.await .await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?; .map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() { if serialized_entry.is_none() {
Err(CustomError::BadRequest("Unknown code.".to_string()))?; return Err(CustomError::BadRequestToken(
TokenError {
error: CoreErrorResponseType::InvalidGrant,
}
.into(),
));
} }
let code_entry: CodeEntry = bincode::deserialize( let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap()) &hex::decode(serialized_entry.unwrap())
@ -225,9 +233,37 @@ async fn token(
) )
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?; .map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
let client_id = if let Some(c) = form.client_id.clone() {
c
} else {
code_entry.client_id.clone()
};
if let Some(secret) = if let Some(TypedHeader(Authorization(b))) = bearer {
Some(b.token().to_string())
} else {
form.client_secret.clone()
} {
let conn2 = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let client_entry = get_client(conn2, client_id.clone()).await?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
if secret != client_entry.unwrap().secret {
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
}
} else if config.require_secret {
return Err(CustomError::Unauthorized("Secret required.".to_string()));
}
if code_entry.exchange_count > 0 { if code_entry.exchange_count > 0 {
// TODO use Oauth error response // TODO use Oauth error response
Err(anyhow!("Code was previously exchanged."))?; return Err(anyhow!("Code was previously exchanged.").into());
} }
conn.set_ex( conn.set_ex(
form.code.to_string(), form.code.to_string(),
@ -240,10 +276,10 @@ async fn token(
.await .await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?; .map_err(|e| anyhow!("Failed to set kv: {}", e))?;
let access_token = AccessToken::new(form.code.to_string().clone()); let access_token = AccessToken::new(form.code.to_string());
let core_id_token = CoreIdTokenClaims::new( let core_id_token = CoreIdTokenClaims::new(
IssuerUrl::from_url(config.base_url), IssuerUrl::from_url(config.base_url),
vec![Audience::new(form.client_id.clone())], vec![Audience::new(client_id.clone())],
Utc::now() + Duration::seconds(60), Utc::now() + Duration::seconds(60),
Utc::now(), Utc::now(),
StandardClaims::new(SubjectIdentifier::new(code_entry.address)), StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
@ -278,58 +314,162 @@ struct AuthorizeParams {
client_id: String, client_id: String,
redirect_uri: RedirectUrl, redirect_uri: RedirectUrl,
scope: Scope, scope: Scope,
response_type: CoreResponseType, response_type: Option<CoreResponseType>,
state: String, state: Option<String>,
nonce: Option<Nonce>, nonce: Option<Nonce>,
prompt: Option<CoreAuthPrompt>,
request_uri: Option<RequestUrl>,
request: Option<String>,
} }
// TODO handle `registration` parameter // TODO handle `registration` parameter
async fn authorize( async fn authorize(
session: UserSessionFromSession, session: UserSessionFromSession,
params: Query<AuthorizeParams>, params: Query<AuthorizeParams>,
// Extension(private_key): Extension<RsaPrivateKey>, Extension(pool): Extension<ConnectionPool>,
) -> Result<(HeaderMap, Redirect), CustomError> { ) -> Result<(HeaderMap, Redirect), CustomError> {
// TODO: Enforce Client Registration let conn = pool
// let d = std::str::from_utf8( .get()
// &jwk.decrypt( .await
// PaddingScheme::new_pkcs1v15_encrypt(), .map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
// &params.client_id.as_bytes(), let client_entry = get_client(conn, params.client_id.clone())
// ) .await
// .map_err(|e| anyhow!("Failed to decrypt client id: {}", e))?, .map_err(|e| anyhow!("Failed to get kv: {}", e))?;
// ) if client_entry.is_none() {
// .map_err(|e| anyhow!("Failed to decrypt client id: {}", e))? return Err(CustomError::Unauthorized(
// if d != params.redirect_uri.as_str() { "Unrecognised client id.".to_string(),
// return Err(anyhow!("Client id not composed of redirect url")); ));
// }; }
let mut r_u = params.0.redirect_uri.clone().url().clone();
r_u.set_query(None);
let mut r_us: Vec<Url> = client_entry
.unwrap()
.redirect_uris
.iter_mut()
.map(|u| u.url().clone())
.collect();
r_us.iter_mut().for_each(|u| u.set_query(None));
if !r_us.contains(&r_u) {
return Ok((
HeaderMap::new(),
Redirect::to(
"/error?message=unregistered_request_uri"
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
let state = if let Some(s) = params.0.state.clone() {
s
} else if params.0.request_uri.is_some() {
let mut url = params.0.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
} else if params.0.request.is_some() {
let mut url = params.0.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
} else {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing state");
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
};
if let Some(CoreAuthPrompt::None) = params.0.prompt {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
if params.0.response_type.is_none() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing response_type");
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
let response_type = params.0.response_type.as_ref().unwrap();
if params.scope != Scope::new("openid".to_string()) { if params.scope != Scope::new("openid".to_string()) {
Err(anyhow!("Scope not supported"))?; return Err(anyhow!("Scope not supported").into());
} }
let (nonce, headers) = match session { let (nonce, headers) = match session {
UserSessionFromSession::FoundUserSession(nonce) => (nonce, HeaderMap::new()), UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::InvalidUserSession(cookie) => { UserSessionFromSession::Invalid(cookie) => {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, cookie); headers.insert(header::SET_COOKIE, cookie);
return Ok(( return Ok((
headers, headers,
Redirect::to( Redirect::to(
format!( format!(
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}{}", "/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
&params.0.client_id, &params.0.client_id,
&params.0.redirect_uri.to_string(), &params.0.redirect_uri.to_string(),
&params.0.scope.to_string(), &params.0.scope.to_string(),
&params.0.response_type.as_ref(), &response_type.as_ref(),
&params.0.state, &state,
&params.0.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or(String::new()) &params.0.client_id,
&params.0.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
) )
.to_string()
.parse() .parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?, .map_err(|e| anyhow!("Could not parse URI: {}", e))?,
), ),
)); ));
} }
UserSessionFromSession::CreatedFreshUserSession { header, nonce } => { UserSessionFromSession::Created { header, nonce } => {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, header); headers.insert(header::SET_COOKIE, header);
(nonce, headers) (nonce, headers)
@ -346,11 +486,12 @@ async fn authorize(
headers, headers,
Redirect::to( Redirect::to(
format!( format!(
"/?nonce={}&domain={}&redirect_uri={}&state={}{}", "/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
nonce, nonce,
domain, domain,
params.redirect_uri.to_string(), params.redirect_uri.to_string(),
params.state, state,
params.client_id,
oidc_nonce_param oidc_nonce_param
) )
.parse() .parse()
@ -417,6 +558,7 @@ struct CodeEntry {
exchange_count: usize, exchange_count: usize,
address: String, address: String,
nonce: Option<Nonce>, nonce: Option<Nonce>,
client_id: String,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -424,6 +566,7 @@ struct SignInParams {
redirect_uri: RedirectUrl, redirect_uri: RedirectUrl,
state: String, state: String,
oidc_nonce: Option<Nonce>, oidc_nonce: Option<Nonce>,
client_id: String,
} }
async fn sign_in( async fn sign_in(
@ -438,39 +581,39 @@ async fn sign_in(
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?, &decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
) )
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?, .map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
None => Err(anyhow!("No `siwe` cookie"))?, None => {
return Err(anyhow!("No `siwe` cookie").into());
}
}; };
let (nonce, headers) = match session { let (nonce, headers) = match session {
UserSessionFromSession::FoundUserSession(nonce) => (nonce, HeaderMap::new()), UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::InvalidUserSession(header) => { UserSessionFromSession::Invalid(header) => {
headers.insert(header::SET_COOKIE, header); headers.insert(header::SET_COOKIE, header);
return Ok(( return Ok((
headers, headers,
Redirect::to( Redirect::to(
format!( format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}", "/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.0.redirect_uri.to_string(), &params.0.client_id.clone(),
&params.0.redirect_uri.to_string(), &params.0.redirect_uri.to_string(),
&params.0.state, &params.0.state,
) )
.to_string()
.parse() .parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?, .map_err(|e| anyhow!("Could not parse URI: {}", e))?,
), ),
)); ));
} }
UserSessionFromSession::CreatedFreshUserSession { .. } => { UserSessionFromSession::Created { .. } => {
return Ok(( return Ok((
headers, headers,
Redirect::to( Redirect::to(
format!( format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}", "/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.0.redirect_uri.to_string(), &params.0.client_id.clone(),
&params.0.redirect_uri.to_string(), &params.0.redirect_uri.to_string(),
&params.0.state, &params.0.state,
) )
.to_string()
.parse() .parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?, .map_err(|e| anyhow!("Could not parse URI: {}", e))?,
), ),
@ -484,11 +627,12 @@ async fn sign_in(
.chars() .chars()
.skip(2) .skip(2)
.take(130) .take(130)
.collect::<String>() .collect::<String>(),
.clone(),
) { ) {
Ok(s) => s, Ok(s) => s,
Err(e) => Err(CustomError::BadRequest(format!("Bad signature: {}", e)))?, Err(e) => {
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
}
}; };
let message = siwe_cookie let message = siwe_cookie
@ -502,16 +646,17 @@ async fn sign_in(
let domain = params.redirect_uri.url().host().unwrap(); let domain = params.redirect_uri.url().host().unwrap();
if domain.to_string() != siwe_cookie.message.domain { if domain.to_string() != siwe_cookie.message.domain {
Err(anyhow!("Conflicting domains in message and redirect"))? return Err(anyhow!("Conflicting domains in message and redirect").into());
} }
if nonce != siwe_cookie.message.nonce { if nonce != siwe_cookie.message.nonce {
Err(anyhow!("Conflicting nonces in message and session"))? return Err(anyhow!("Conflicting nonces in message and session").into());
} }
let code_entry = CodeEntry { let code_entry = CodeEntry {
address: siwe_cookie.message.address, address: siwe_cookie.message.address,
nonce: params.oidc_nonce.clone(), nonce: params.oidc_nonce.clone(),
exchange_count: 0, exchange_count: 0,
client_id: params.0.client_id.clone(),
}; };
let code = Uuid::new_v4(); let code = Uuid::new_v4();
@ -547,25 +692,31 @@ async fn sign_in(
async fn register( async fn register(
extract::Json(payload): extract::Json<CoreClientMetadata>, extract::Json(payload): extract::Json<CoreClientMetadata>,
Extension(pool): Extension<ConnectionPool>, Extension(pool): Extension<ConnectionPool>,
) -> Result<Json<CoreClientRegistrationResponse>, CustomError> { ) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let secret = Uuid::new_v4(); let secret = Uuid::new_v4();
let mut conn = pool let conn = pool
.get() .get()
.await .await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?; .map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set(format!("{}/{}", KV_CLIENT_PREFIX, id), secret.to_string()) let entry = ClientEntry {
.await secret: secret.to_string(),
.map_err(|e| anyhow!("Failed to set kv: {}", e))?; redirect_uris: payload.redirect_uris().to_vec(),
};
set_client(conn, id.to_string(), entry).await?;
Ok(CoreClientRegistrationResponse::new( Ok((
StatusCode::CREATED,
CoreClientRegistrationResponse::new(
ClientId::new(id.to_string()), ClientId::new(id.to_string()),
payload.redirect_uris().to_vec(), payload.redirect_uris().to_vec(),
EmptyAdditionalClientMetadata::default(), EmptyAdditionalClientMetadata::default(),
EmptyAdditionalClientRegistrationResponse::default(), EmptyAdditionalClientRegistrationResponse::default(),
) )
.into()) .set_client_secret(Some(ClientSecret::new(secret.to_string())))
.into(),
))
} }
// TODO CORS // TODO CORS
@ -586,7 +737,7 @@ async fn userinfo(
.await .await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?; .map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() { if serialized_entry.is_none() {
Err(CustomError::BadRequest("Unknown code.".to_string()))?; return Err(CustomError::BadRequest("Unknown code.".to_string()));
} }
let code_entry: CodeEntry = bincode::deserialize( let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap()) &hex::decode(serialized_entry.unwrap())
@ -616,21 +767,23 @@ async fn main() {
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap(); let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
let pool2 = bb8::Pool::builder().build(manager).await.unwrap(); let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
let mut conn = pool2 for (id, secret) in &config.default_clients.clone() {
let conn = pool2
.get() .get()
.await .await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e)) .map_err(|e| anyhow!("Failed to get connection to database: {}", e))
.unwrap(); .unwrap();
for (id, secret) in &config.default_clients.clone() { let client_entry = ClientEntry {
let _: () = conn secret: secret.to_string(),
.set(format!("{}/{}", KV_CLIENT_PREFIX, id), secret) redirect_uris: vec![],
};
set_client(conn, id.to_string(), client_entry)
.await .await
.map_err(|e| anyhow!("Failed to set kv: {}", e)) .unwrap(); // TODO
.unwrap();
} }
let private_key = if let Some(key) = &config.rsa_pem { let private_key = if let Some(key) = &config.rsa_pem {
RsaPrivateKey::from_pkcs1_pem(&key) RsaPrivateKey::from_pkcs1_pem(key)
.map_err(|e| anyhow!("Failed to load private key: {}", e)) .map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap() .unwrap()
} else { } else {
@ -680,6 +833,17 @@ async fn main() {
}, },
), ),
) )
.route(
"/error",
service_method_routing::get(ServeFile::new("./static/error.html")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route( .route(
"/favicon.png", "/favicon.png",
service_method_routing::get(ServeFile::new("./static/favicon.png")).handle_error( service_method_routing::get(ServeFile::new("./static/favicon.png")).handle_error(

View File

@ -15,9 +15,9 @@ const SESSION_COOKIE_NAME: &str = "session";
const SESSION_KEY: &str = "user_session"; const SESSION_KEY: &str = "user_session";
pub enum UserSessionFromSession { pub enum UserSessionFromSession {
FoundUserSession(String), Found(String),
CreatedFreshUserSession { header: HeaderValue, nonce: String }, Created { header: HeaderValue, nonce: String },
InvalidUserSession(HeaderValue), Invalid(HeaderValue),
} }
#[async_trait] #[async_trait]
@ -52,12 +52,11 @@ where
.and_then(|value| value.to_str().ok()) .and_then(|value| value.to_str().ok())
.map(|header| { .map(|header| {
header header
.split(";") .split(';')
.map(|cookie| Cookie::parse(cookie).ok()) .map(|cookie| Cookie::parse(cookie).ok())
.filter(|cookie| { .find(|cookie| {
cookie.is_some() && cookie.as_ref().unwrap().name() == SESSION_COOKIE_NAME cookie.is_some() && cookie.as_ref().unwrap().name() == SESSION_COOKIE_NAME
}) })
.next()
}) })
.flatten() .flatten()
.flatten() .flatten()
@ -69,7 +68,7 @@ where
session.insert(SESSION_KEY, user_session.clone()).unwrap(); session.insert(SESSION_KEY, user_session.clone()).unwrap();
let cookie = store.store_session(session).await.unwrap().unwrap(); let cookie = store.store_session(session).await.unwrap().unwrap();
return Ok(Self::CreatedFreshUserSession { return Ok(Self::Created {
header: Cookie::new(SESSION_COOKIE_NAME, cookie) header: Cookie::new(SESSION_COOKIE_NAME, cookie)
.to_string() .to_string()
.parse() .parse()
@ -84,9 +83,7 @@ where
debug!("Could not load session"); debug!("Could not load session");
let mut cookie = session_cookie.clone(); let mut cookie = session_cookie.clone();
cookie.make_removal(); cookie.make_removal();
return Ok(Self::InvalidUserSession( return Ok(Self::Invalid(cookie.to_string().parse().unwrap()));
cookie.to_string().parse().unwrap(),
));
} }
}; };
let user_session = if let Some(user_session) = session.get::<UserSession>(SESSION_KEY) { let user_session = if let Some(user_session) = session.get::<UserSession>(SESSION_KEY) {
@ -95,12 +92,10 @@ where
debug!("No `user_session` found in session"); debug!("No `user_session` found in session");
let mut cookie = session_cookie.clone(); let mut cookie = session_cookie.clone();
cookie.make_removal(); cookie.make_removal();
return Ok(Self::InvalidUserSession( return Ok(Self::Invalid(cookie.to_string().parse().unwrap()));
cookie.to_string().parse().unwrap(),
));
}; };
Ok(Self::FoundUserSession(user_session.nonce)) Ok(Self::Found(user_session.nonce))
} }
} }

20
static/error.html Normal file
View File

@ -0,0 +1,20 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<title>SIWE Open ID Connect</title>
<link rel="icon" type="image/png" href="/favicon.png" />
<link
href="https://api.fontshare.com/css?f[]=satoshi@300,301,400,401,500,501,700,701,900,901,1,2&display=swap"
rel="stylesheet"
/>
</head>
<body>
<h1>Invalid request</h1>
</body>
</html>