2021-11-29 15:05:43 -05:00
use anyhow ::{ anyhow , Result } ;
use async_redis_session ::RedisSessionStore ;
use axum ::{
body ::{ Bytes , Full } ,
error_handling ::HandleErrorExt ,
extract ::{ self , Extension , Form , Query , TypedHeader } ,
http ::{
header ::{ self , HeaderMap } ,
Response , StatusCode ,
} ,
response ::{ IntoResponse , Redirect } ,
routing ::{ get , post , service_method_routing } ,
AddExtensionLayer , Json , Router ,
} ;
use bb8_redis ::{ bb8 , bb8 ::Pool , redis ::AsyncCommands , RedisConnectionManager } ;
use chrono ::{ Duration , Utc } ;
use figment ::{
providers ::{ Env , Format , Serialized , Toml } ,
Figment ,
} ;
use headers ::{ self , authorization ::Bearer , Authorization } ;
use hex ::FromHex ;
use iri_string ::types ::{ UriAbsoluteString , UriString } ;
use openidconnect ::{
core ::{
2021-12-20 11:29:43 -05:00
CoreAuthErrorResponseType , CoreAuthPrompt , CoreClaimName , CoreClientAuthMethod ,
CoreClientMetadata , CoreClientRegistrationResponse , CoreErrorResponseType , CoreGrantType ,
CoreIdToken , CoreIdTokenClaims , CoreIdTokenFields , CoreJsonWebKeySet ,
2021-11-29 15:05:43 -05:00
CoreJwsSigningAlgorithm , CoreProviderMetadata , CoreResponseType , CoreRsaPrivateSigningKey ,
CoreSubjectIdentifierType , CoreTokenResponse , CoreTokenType , CoreUserInfoClaims ,
} ,
registration ::{ EmptyAdditionalClientMetadata , EmptyAdditionalClientRegistrationResponse } ,
2021-12-20 11:29:43 -05:00
url ::Url ,
AccessToken , Audience , AuthUrl , ClientId , ClientSecret , EmptyAdditionalClaims ,
2021-11-29 15:05:43 -05:00
EmptyAdditionalProviderMetadata , EmptyExtraTokenFields , IssuerUrl , JsonWebKeyId ,
2021-12-20 11:29:43 -05:00
JsonWebKeySetUrl , Nonce , PrivateSigningKey , RedirectUrl , RegistrationUrl , RequestUrl ,
ResponseTypes , Scope , StandardClaims , SubjectIdentifier , TokenUrl , UserInfoUrl ,
2021-11-29 15:05:43 -05:00
} ;
use rand ::rngs ::OsRng ;
use rsa ::{
pkcs1 ::{ FromRsaPrivateKey , ToRsaPrivateKey } ,
RsaPrivateKey ,
} ;
use serde ::{ Deserialize , Serialize } ;
use siwe ::eip4361 ::{ Message , Version } ;
use std ::{ convert ::Infallible , net ::SocketAddr , str ::FromStr } ;
use thiserror ::Error ;
use tower_http ::{
services ::{ ServeDir , ServeFile } ,
trace ::TraceLayer ,
} ;
use tracing ::info ;
use urlencoding ::decode ;
use uuid ::Uuid ;
mod config ;
2021-12-20 11:29:43 -05:00
mod db ;
2021-11-29 15:05:43 -05:00
mod session ;
2021-12-20 11:29:43 -05:00
use db ::* ;
2021-11-29 15:05:43 -05:00
use session ::* ;
const KID : & str = " key1 " ;
2021-12-20 11:29:43 -05:00
const ENTRY_LIFETIME : usize = 30 ;
2021-11-29 15:05:43 -05:00
type ConnectionPool = Pool < RedisConnectionManager > ;
2021-12-20 11:29:43 -05:00
#[ derive(Serialize, Debug) ]
pub struct TokenError {
pub error : CoreErrorResponseType ,
}
2021-11-29 15:05:43 -05:00
#[ derive(Debug, Error) ]
pub enum CustomError {
#[ error( " {0} " ) ]
BadRequest ( String ) ,
2021-12-20 11:29:43 -05:00
#[ error( " {0:?} " ) ]
BadRequestToken ( Json < TokenError > ) ,
2021-11-29 15:05:43 -05:00
#[ error( " {0} " ) ]
Unauthorized ( String ) ,
#[ error(transparent) ]
Other ( #[ from ] anyhow ::Error ) ,
}
impl IntoResponse for CustomError {
type Body = Full < Bytes > ;
type BodyError = Infallible ;
fn into_response ( self ) -> Response < Self ::Body > {
match self {
2021-12-20 11:29:43 -05:00
CustomError ::BadRequest ( _ ) = > {
( StatusCode ::BAD_REQUEST , self . to_string ( ) ) . into_response ( )
}
CustomError ::BadRequestToken ( e ) = > ( StatusCode ::BAD_REQUEST , e ) . into_response ( ) ,
CustomError ::Unauthorized ( _ ) = > {
( StatusCode ::UNAUTHORIZED , self . to_string ( ) ) . into_response ( )
}
CustomError ::Other ( _ ) = > {
( StatusCode ::INTERNAL_SERVER_ERROR , self . to_string ( ) ) . into_response ( )
}
2021-11-29 15:05:43 -05:00
}
}
}
async fn jwk_set (
Extension ( private_key ) : Extension < RsaPrivateKey > ,
) -> Result < Json < CoreJsonWebKeySet > , CustomError > {
let pem = private_key
. to_pkcs1_pem ( )
. map_err ( | e | anyhow! ( " Failed to serialise key as PEM: {} " , e ) ) ? ;
let jwks = CoreJsonWebKeySet ::new ( vec! [ CoreRsaPrivateSigningKey ::from_pem (
& pem ,
Some ( JsonWebKeyId ::new ( KID . to_string ( ) ) ) ,
)
. map_err ( | e | anyhow! ( " Invalid RSA private key: {} " , e ) ) ?
. as_verification_key ( ) ] ) ;
Ok ( jwks . into ( ) )
}
async fn provider_metadata (
Extension ( config ) : Extension < config ::Config > ,
) -> Result < Json < CoreProviderMetadata > , CustomError > {
let pm = CoreProviderMetadata ::new (
IssuerUrl ::from_url ( config . base_url . clone ( ) ) ,
AuthUrl ::from_url (
config
. base_url
. join ( " authorize " )
. map_err ( | e | anyhow! ( " Unable to join URL: {} " , e ) ) ? ,
) ,
JsonWebKeySetUrl ::from_url (
config
. base_url
. join ( " jwk " )
. map_err ( | e | anyhow! ( " Unable to join URL: {} " , e ) ) ? ,
) ,
vec! [
ResponseTypes ::new ( vec! [ CoreResponseType ::Code ] ) ,
ResponseTypes ::new ( vec! [ CoreResponseType ::Token , CoreResponseType ::IdToken ] ) ,
] ,
vec! [ CoreSubjectIdentifierType ::Pairwise ] ,
vec! [ CoreJwsSigningAlgorithm ::RsaSsaPssSha256 ] ,
EmptyAdditionalProviderMetadata { } ,
)
. set_token_endpoint ( Some ( TokenUrl ::from_url (
config
. base_url
. join ( " token " )
. map_err ( | e | anyhow! ( " Unable to join URL: {} " , e ) ) ? ,
) ) )
. set_userinfo_endpoint ( Some ( UserInfoUrl ::from_url (
config
. base_url
. join ( " userinfo " )
. map_err ( | e | anyhow! ( " Unable to join URL: {} " , e ) ) ? ,
) ) )
. set_scopes_supported ( Some ( vec! [
Scope ::new ( " openid " . to_string ( ) ) ,
// Scope::new("email".to_string()),
// Scope::new("profile".to_string()),
] ) )
. set_claims_supported ( Some ( vec! [
CoreClaimName ::new ( " sub " . to_string ( ) ) ,
CoreClaimName ::new ( " aud " . to_string ( ) ) ,
// CoreClaimName::new("email".to_string()),
// CoreClaimName::new("email_verified".to_string()),
CoreClaimName ::new ( " exp " . to_string ( ) ) ,
CoreClaimName ::new ( " iat " . to_string ( ) ) ,
CoreClaimName ::new ( " iss " . to_string ( ) ) ,
// CoreClaimName::new("name".to_string()),
// CoreClaimName::new("given_name".to_string()),
// CoreClaimName::new("family_name".to_string()),
// CoreClaimName::new("picture".to_string()),
// CoreClaimName::new("locale".to_string()),
] ) )
. set_registration_endpoint ( Some ( RegistrationUrl ::from_url (
config
. base_url
. join ( " register " )
. map_err ( | e | anyhow! ( " Unable to join URL: {} " , e ) ) ? ,
) ) )
2021-12-20 11:29:43 -05:00
. set_token_endpoint_auth_methods_supported ( Some ( vec! [
CoreClientAuthMethod ::ClientSecretBasic ,
CoreClientAuthMethod ::ClientSecretPost ,
] ) ) ;
2021-11-29 15:05:43 -05:00
Ok ( pm . into ( ) )
}
2021-12-20 11:29:43 -05:00
#[ derive(Serialize, Deserialize) ]
2021-11-29 15:05:43 -05:00
struct TokenForm {
code : String ,
2021-12-20 11:29:43 -05:00
client_id : Option < String > ,
2021-11-29 15:05:43 -05:00
client_secret : Option < String > ,
grant_type : CoreGrantType , // TODO should just be authorization_code apparently?
}
// Client authentication methods listed by Keycloak, and whether this token
// endpoint supports them:
// 1. client secret in the POST body [x]
// 2. client secret in the `Authorization` header (as a bearer token) [x]
// 3. JWT [ ]
// 4. signed JWT [ ]
async fn token (
form : Form < TokenForm > ,
2021-12-20 11:29:43 -05:00
bearer : Option < TypedHeader < Authorization < Bearer > > > ,
2021-11-29 15:05:43 -05:00
Extension ( private_key ) : Extension < RsaPrivateKey > ,
Extension ( config ) : Extension < config ::Config > ,
Extension ( pool ) : Extension < ConnectionPool > ,
) -> Result < Json < CoreTokenResponse > , CustomError > {
let mut conn = pool
. get ( )
. await
. map_err ( | e | anyhow! ( " Failed to get connection to database: {} " , e ) ) ? ;
let serialized_entry : Option < Vec < u8 > > = conn
. get ( form . code . to_string ( ) )
. await
. map_err ( | e | anyhow! ( " Failed to get kv: {} " , e ) ) ? ;
if serialized_entry . is_none ( ) {
2021-12-20 11:29:43 -05:00
return Err ( CustomError ::BadRequestToken (
TokenError {
error : CoreErrorResponseType ::InvalidGrant ,
}
. into ( ) ,
) ) ;
2021-11-29 15:05:43 -05:00
}
let code_entry : CodeEntry = bincode ::deserialize (
& hex ::decode ( serialized_entry . unwrap ( ) )
. map_err ( | e | anyhow! ( " Failed to decode code entry: {} " , e ) ) ? ,
)
. map_err ( | e | anyhow! ( " Failed to deserialize code: {} " , e ) ) ? ;
2021-12-20 11:29:43 -05:00
let client_id = if let Some ( c ) = form . client_id . clone ( ) {
c
} else {
code_entry . client_id . clone ( )
} ;
if let Some ( secret ) = if let Some ( TypedHeader ( Authorization ( b ) ) ) = bearer {
Some ( b . token ( ) . to_string ( ) )
} else {
form . client_secret . clone ( )
} {
let conn2 = pool
. get ( )
. await
. map_err ( | e | anyhow! ( " Failed to get connection to database: {} " , e ) ) ? ;
let client_entry = get_client ( conn2 , client_id . clone ( ) ) . await ? ;
if client_entry . is_none ( ) {
return Err ( CustomError ::Unauthorized (
" Unrecognised client id. " . to_string ( ) ,
) ) ;
}
if secret ! = client_entry . unwrap ( ) . secret {
return Err ( CustomError ::Unauthorized ( " Bad secret. " . to_string ( ) ) ) ;
}
} else if config . require_secret {
return Err ( CustomError ::Unauthorized ( " Secret required. " . to_string ( ) ) ) ;
}
2021-11-29 15:05:43 -05:00
if code_entry . exchange_count > 0 {
// TODO use Oauth error response
2021-12-20 11:29:43 -05:00
return Err ( anyhow! ( " Code was previously exchanged. " ) . into ( ) ) ;
2021-11-29 15:05:43 -05:00
}
conn . set_ex (
form . code . to_string ( ) ,
hex ::encode (
bincode ::serialize ( & code_entry )
. map_err ( | e | anyhow! ( " Failed to serialise code: {} " , e ) ) ? ,
) ,
ENTRY_LIFETIME ,
)
. await
. map_err ( | e | anyhow! ( " Failed to set kv: {} " , e ) ) ? ;
2021-12-20 11:29:43 -05:00
let access_token = AccessToken ::new ( form . code . to_string ( ) ) ;
2021-11-29 15:05:43 -05:00
let core_id_token = CoreIdTokenClaims ::new (
IssuerUrl ::from_url ( config . base_url ) ,
2021-12-20 11:29:43 -05:00
vec! [ Audience ::new ( client_id . clone ( ) ) ] ,
2021-11-29 15:05:43 -05:00
Utc ::now ( ) + Duration ::seconds ( 60 ) ,
Utc ::now ( ) ,
StandardClaims ::new ( SubjectIdentifier ::new ( code_entry . address ) ) ,
EmptyAdditionalClaims { } ,
)
. set_nonce ( code_entry . nonce ) ;
let pem = private_key
. to_pkcs1_pem ( )
. map_err ( | e | anyhow! ( " Failed to serialise key as PEM: {} " , e ) ) ? ;
let id_token = CoreIdToken ::new (
core_id_token ,
& CoreRsaPrivateSigningKey ::from_pem ( & pem , Some ( JsonWebKeyId ::new ( KID . to_string ( ) ) ) )
. map_err ( | e | anyhow! ( " Invalid RSA private key: {} " , e ) ) ? ,
CoreJwsSigningAlgorithm ::RsaSsaPkcs1V15Sha256 ,
Some ( & access_token ) ,
None ,
)
. map_err ( | e | anyhow! ( " {} " , e ) ) ? ;
Ok ( CoreTokenResponse ::new (
access_token ,
CoreTokenType ::Bearer ,
CoreIdTokenFields ::new ( Some ( id_token ) , EmptyExtraTokenFields { } ) ,
)
. into ( ) )
}
#[ derive(Deserialize) ]
struct AuthorizeParams {
client_id : String ,
redirect_uri : RedirectUrl ,
scope : Scope ,
2021-12-20 11:29:43 -05:00
response_type : Option < CoreResponseType > ,
state : Option < String > ,
2021-11-29 15:05:43 -05:00
nonce : Option < Nonce > ,
2021-12-20 11:29:43 -05:00
prompt : Option < CoreAuthPrompt > ,
request_uri : Option < RequestUrl > ,
request : Option < String > ,
2021-11-29 15:05:43 -05:00
}
// TODO handle `registration` parameter
async fn authorize (
session : UserSessionFromSession ,
params : Query < AuthorizeParams > ,
2021-12-20 11:29:43 -05:00
Extension ( pool ) : Extension < ConnectionPool > ,
2021-11-29 15:05:43 -05:00
) -> Result < ( HeaderMap , Redirect ) , CustomError > {
2021-12-20 11:29:43 -05:00
let conn = pool
. get ( )
. await
. map_err ( | e | anyhow! ( " Failed to get connection to database: {} " , e ) ) ? ;
let client_entry = get_client ( conn , params . client_id . clone ( ) )
. await
. map_err ( | e | anyhow! ( " Failed to get kv: {} " , e ) ) ? ;
if client_entry . is_none ( ) {
return Err ( CustomError ::Unauthorized (
" Unrecognised client id. " . to_string ( ) ,
) ) ;
}
let mut r_u = params . 0. redirect_uri . clone ( ) . url ( ) . clone ( ) ;
r_u . set_query ( None ) ;
let mut r_us : Vec < Url > = client_entry
. unwrap ( )
. redirect_uris
. iter_mut ( )
. map ( | u | u . url ( ) . clone ( ) )
. collect ( ) ;
r_us . iter_mut ( ) . for_each ( | u | u . set_query ( None ) ) ;
if ! r_us . contains ( & r_u ) {
return Ok ( (
HeaderMap ::new ( ) ,
Redirect ::to (
" /error?message=unregistered_request_uri "
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
}
let state = if let Some ( s ) = params . 0. state . clone ( ) {
s
} else if params . 0. request_uri . is_some ( ) {
let mut url = params . 0. redirect_uri . url ( ) . clone ( ) ;
url . query_pairs_mut ( ) . append_pair (
" error " ,
CoreAuthErrorResponseType ::RequestUriNotSupported . as_ref ( ) ,
) ;
return Ok ( (
HeaderMap ::new ( ) ,
Redirect ::to (
url . as_str ( )
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
} else if params . 0. request . is_some ( ) {
let mut url = params . 0. redirect_uri . url ( ) . clone ( ) ;
url . query_pairs_mut ( ) . append_pair (
" error " ,
CoreAuthErrorResponseType ::RequestNotSupported . as_ref ( ) ,
) ;
return Ok ( (
HeaderMap ::new ( ) ,
Redirect ::to (
url . as_str ( )
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
} else {
let mut url = params . redirect_uri . url ( ) . clone ( ) ;
url . query_pairs_mut ( )
. append_pair ( " error " , CoreAuthErrorResponseType ::InvalidRequest . as_ref ( ) ) ;
url . query_pairs_mut ( )
. append_pair ( " error_description " , " Missing state " ) ;
return Ok ( (
HeaderMap ::new ( ) ,
Redirect ::to (
url . as_str ( )
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
} ;
if let Some ( CoreAuthPrompt ::None ) = params . 0. prompt {
let mut url = params . redirect_uri . url ( ) . clone ( ) ;
url . query_pairs_mut ( ) . append_pair ( " state " , & state ) ;
url . query_pairs_mut ( ) . append_pair (
" error " ,
CoreAuthErrorResponseType ::InteractionRequired . as_ref ( ) ,
) ;
return Ok ( (
HeaderMap ::new ( ) ,
Redirect ::to (
url . as_str ( )
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
}
if params . 0. response_type . is_none ( ) {
let mut url = params . redirect_uri . url ( ) . clone ( ) ;
url . query_pairs_mut ( ) . append_pair ( " state " , & state ) ;
url . query_pairs_mut ( )
. append_pair ( " error " , CoreAuthErrorResponseType ::InvalidRequest . as_ref ( ) ) ;
url . query_pairs_mut ( )
. append_pair ( " error_description " , " Missing response_type " ) ;
return Ok ( (
HeaderMap ::new ( ) ,
Redirect ::to (
url . as_str ( )
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
}
let response_type = params . 0. response_type . as_ref ( ) . unwrap ( ) ;
2021-11-29 15:05:43 -05:00
if params . scope ! = Scope ::new ( " openid " . to_string ( ) ) {
2021-12-20 11:29:43 -05:00
return Err ( anyhow! ( " Scope not supported " ) . into ( ) ) ;
2021-11-29 15:05:43 -05:00
}
let ( nonce , headers ) = match session {
2021-12-20 11:29:43 -05:00
UserSessionFromSession ::Found ( nonce ) = > ( nonce , HeaderMap ::new ( ) ) ,
UserSessionFromSession ::Invalid ( cookie ) = > {
2021-11-29 15:05:43 -05:00
let mut headers = HeaderMap ::new ( ) ;
headers . insert ( header ::SET_COOKIE , cookie ) ;
return Ok ( (
headers ,
Redirect ::to (
format! (
2021-12-20 11:29:43 -05:00
" /authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{} " ,
2021-11-29 15:05:43 -05:00
& params . 0. client_id ,
& params . 0. redirect_uri . to_string ( ) ,
& params . 0. scope . to_string ( ) ,
2021-12-20 11:29:43 -05:00
& response_type . as_ref ( ) ,
& state ,
& params . 0. client_id ,
& params . 0. nonce . map ( | n | format! ( " &nonce= {} " , n . secret ( ) ) ) . unwrap_or_default ( )
2021-11-29 15:05:43 -05:00
)
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
}
2021-12-20 11:29:43 -05:00
UserSessionFromSession ::Created { header , nonce } = > {
2021-11-29 15:05:43 -05:00
let mut headers = HeaderMap ::new ( ) ;
headers . insert ( header ::SET_COOKIE , header ) ;
( nonce , headers )
}
} ;
let domain = params . redirect_uri . url ( ) . host ( ) . unwrap ( ) ;
let oidc_nonce_param = if let Some ( n ) = & params . nonce {
format! ( " &oidc_nonce= {} " , n . secret ( ) )
} else {
" " . to_string ( )
} ;
Ok ( (
headers ,
Redirect ::to (
format! (
2021-12-20 11:29:43 -05:00
" /?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{} " ,
2021-11-29 15:05:43 -05:00
nonce ,
domain ,
params . redirect_uri . to_string ( ) ,
2021-12-20 11:29:43 -05:00
state ,
params . client_id ,
2021-11-29 15:05:43 -05:00
oidc_nonce_param
)
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) )
}
/// JSON payload of the URL-encoded `siwe` cookie: the signed message plus
/// the wallet's signature over it.
#[derive(Serialize, Deserialize)]
struct SiweCookie {
    /// The Sign-In with Ethereum message the wallet signed.
    message: Web3ModalMessage,
    /// Hex-encoded signature. `sign_in` skips its first two characters
    /// (presumably a `0x` prefix).
    signature: String,
}
/// EIP-4361 message fields as sent by the front end, with camelCase JSON keys
/// and every value still a string.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Web3ModalMessage {
    /// Domain requesting the sign-in.
    pub domain: String,
    /// Signer's address as a hex string; the first two characters are skipped
    /// when parsed.
    pub address: String,
    /// Human-readable statement shown to the user.
    pub statement: String,
    /// URI the signing request refers to; must be absolute.
    pub uri: String,
    /// EIP-4361 message version.
    pub version: String,
    /// Chain identifier.
    pub chain_id: String,
    /// Nonce; `sign_in` checks it against the session nonce.
    pub nonce: String,
    /// Issuance timestamp.
    pub issued_at: String,
    /// Optional expiration timestamp.
    pub expiration_time: Option<String>,
    /// Optional not-before timestamp.
    pub not_before: Option<String>,
    /// Optional request identifier.
    pub request_id: Option<String>,
    /// Optional resource URIs.
    pub resources: Option<Vec<String>>,
}
impl Web3ModalMessage {
pub fn to_eip4361_message ( & self ) -> Result < Message > {
let mut next_resources : Vec < UriString > = Vec ::new ( ) ;
match & self . resources {
Some ( resources ) = > {
for resource in resources {
let x = UriString ::from_str ( resource ) ? ;
next_resources . push ( x )
}
}
None = > { }
}
Ok ( Message {
domain : self . domain . clone ( ) . try_into ( ) ? ,
address : < [ u8 ; 20 ] > ::from_hex ( self . address . chars ( ) . skip ( 2 ) . collect ::< String > ( ) ) ? ,
statement : self . statement . to_string ( ) ,
uri : UriAbsoluteString ::from_str ( & self . uri ) ? ,
version : Version ::from_str ( & self . version ) ? ,
chain_id : self . chain_id . to_string ( ) ,
nonce : self . nonce . to_string ( ) ,
issued_at : self . issued_at . to_string ( ) ,
expiration_time : self . expiration_time . clone ( ) ,
not_before : self . not_before . clone ( ) ,
request_id : self . request_id . clone ( ) ,
resources : next_resources ,
} )
}
}
/// Record stored in Redis (bincode-serialised, then hex-encoded), keyed by
/// the authorization code.
///
/// Field order is part of the bincode encoding; do not reorder.
#[derive(Serialize, Deserialize)]
struct CodeEntry {
    /// Counter checked by the token endpoint; non-zero means already
    /// exchanged.
    exchange_count: usize,
    /// Signer's address; used as the `sub` claim.
    address: String,
    /// OIDC nonce from the authorization request, set on the ID token.
    nonce: Option<Nonce>,
    /// Client the code was issued to.
    client_id: String,
}
/// Query parameters of `GET /sign_in`, as forwarded by the sign-in page.
#[derive(Deserialize)]
struct SignInParams {
    /// Client redirect URI that receives the authorization code.
    redirect_uri: RedirectUrl,
    /// Client state echoed back with the code.
    state: String,
    /// OIDC nonce from the original authorization request, if any.
    oidc_nonce: Option<Nonce>,
    /// Client the code is issued to.
    client_id: String,
}
async fn sign_in (
session : UserSessionFromSession ,
params : Query < SignInParams > ,
TypedHeader ( cookies ) : TypedHeader < headers ::Cookie > ,
Extension ( pool ) : Extension < ConnectionPool > ,
) -> Result < ( HeaderMap , Redirect ) , CustomError > {
let mut headers = HeaderMap ::new ( ) ;
let siwe_cookie : SiweCookie = match cookies . get ( " siwe " ) {
Some ( c ) = > serde_json ::from_str (
& decode ( c ) . map_err ( | e | anyhow! ( " Could not decode siwe cookie: {} " , e ) ) ? ,
)
. map_err ( | e | anyhow! ( " Could not deserialize siwe cookie: {} " , e ) ) ? ,
2021-12-20 11:29:43 -05:00
None = > {
return Err ( anyhow! ( " No `siwe` cookie " ) . into ( ) ) ;
}
2021-11-29 15:05:43 -05:00
} ;
let ( nonce , headers ) = match session {
2021-12-20 11:29:43 -05:00
UserSessionFromSession ::Found ( nonce ) = > ( nonce , HeaderMap ::new ( ) ) ,
UserSessionFromSession ::Invalid ( header ) = > {
2021-11-29 15:05:43 -05:00
headers . insert ( header ::SET_COOKIE , header ) ;
return Ok ( (
headers ,
Redirect ::to (
format! (
" /authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={} " ,
2021-12-20 11:29:43 -05:00
& params . 0. client_id . clone ( ) ,
2021-11-29 15:05:43 -05:00
& params . 0. redirect_uri . to_string ( ) ,
& params . 0. state ,
)
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) ) ;
}
2021-12-20 11:29:43 -05:00
UserSessionFromSession ::Created { .. } = > {
2021-11-29 15:05:43 -05:00
return Ok ( (
headers ,
Redirect ::to (
format! (
" /authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={} " ,
2021-12-20 11:29:43 -05:00
& params . 0. client_id . clone ( ) ,
2021-11-29 15:05:43 -05:00
& params . 0. redirect_uri . to_string ( ) ,
& params . 0. state ,
)
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) )
}
} ;
let signature = match < [ u8 ; 65 ] > ::from_hex (
siwe_cookie
. signature
. chars ( )
. skip ( 2 )
. take ( 130 )
2021-12-20 11:29:43 -05:00
. collect ::< String > ( ) ,
2021-11-29 15:05:43 -05:00
) {
Ok ( s ) = > s ,
2021-12-20 11:29:43 -05:00
Err ( e ) = > {
return Err ( CustomError ::BadRequest ( format! ( " Bad signature: {} " , e ) ) ) ;
}
2021-11-29 15:05:43 -05:00
} ;
let message = siwe_cookie
. message
. to_eip4361_message ( )
. map_err ( | e | anyhow! ( " Failed to serialise message: {} " , e ) ) ? ;
info! ( " {} " , message ) ;
message
. verify_eip191 ( signature )
. map_err ( | e | anyhow! ( " Failed signature validation: {} " , e ) ) ? ;
let domain = params . redirect_uri . url ( ) . host ( ) . unwrap ( ) ;
if domain . to_string ( ) ! = siwe_cookie . message . domain {
2021-12-20 11:29:43 -05:00
return Err ( anyhow! ( " Conflicting domains in message and redirect " ) . into ( ) ) ;
2021-11-29 15:05:43 -05:00
}
if nonce ! = siwe_cookie . message . nonce {
2021-12-20 11:29:43 -05:00
return Err ( anyhow! ( " Conflicting nonces in message and session " ) . into ( ) ) ;
2021-11-29 15:05:43 -05:00
}
let code_entry = CodeEntry {
address : siwe_cookie . message . address ,
nonce : params . oidc_nonce . clone ( ) ,
exchange_count : 0 ,
2021-12-20 11:29:43 -05:00
client_id : params . 0. client_id . clone ( ) ,
2021-11-29 15:05:43 -05:00
} ;
let code = Uuid ::new_v4 ( ) ;
let mut conn = pool
. get ( )
. await
. map_err ( | e | anyhow! ( " Failed to get connection to database: {} " , e ) ) ? ;
conn . set_ex (
code . to_string ( ) ,
hex ::encode (
bincode ::serialize ( & code_entry )
. map_err ( | e | anyhow! ( " Failed to serialise code: {} " , e ) ) ? ,
) ,
ENTRY_LIFETIME ,
)
. await
. map_err ( | e | anyhow! ( " Failed to set kv: {} " , e ) ) ? ;
let mut url = params . redirect_uri . url ( ) . clone ( ) ;
url . query_pairs_mut ( ) . append_pair ( " code " , & code . to_string ( ) ) ;
url . query_pairs_mut ( ) . append_pair ( " state " , & params . state ) ;
Ok ( (
headers ,
Redirect ::to (
url . as_str ( )
. parse ( )
. map_err ( | e | anyhow! ( " Could not parse URI: {} " , e ) ) ? ,
) ,
) )
// TODO clear session
}
async fn register (
extract ::Json ( payload ) : extract ::Json < CoreClientMetadata > ,
Extension ( pool ) : Extension < ConnectionPool > ,
2021-12-20 11:29:43 -05:00
) -> Result < ( StatusCode , Json < CoreClientRegistrationResponse > ) , CustomError > {
2021-11-29 15:05:43 -05:00
let id = Uuid ::new_v4 ( ) ;
let secret = Uuid ::new_v4 ( ) ;
2021-12-20 11:29:43 -05:00
let conn = pool
2021-11-29 15:05:43 -05:00
. get ( )
. await
. map_err ( | e | anyhow! ( " Failed to get connection to database: {} " , e ) ) ? ;
2021-12-20 11:29:43 -05:00
let entry = ClientEntry {
secret : secret . to_string ( ) ,
redirect_uris : payload . redirect_uris ( ) . to_vec ( ) ,
} ;
set_client ( conn , id . to_string ( ) , entry ) . await ? ;
2021-11-29 15:05:43 -05:00
2021-12-20 11:29:43 -05:00
Ok ( (
StatusCode ::CREATED ,
CoreClientRegistrationResponse ::new (
ClientId ::new ( id . to_string ( ) ) ,
payload . redirect_uris ( ) . to_vec ( ) ,
EmptyAdditionalClientMetadata ::default ( ) ,
EmptyAdditionalClientRegistrationResponse ::default ( ) ,
)
. set_client_secret ( Some ( ClientSecret ::new ( secret . to_string ( ) ) ) )
. into ( ) ,
) )
2021-11-29 15:05:43 -05:00
}
// TODO CORS
// TODO need validation of the token
// TODO restrict access token use to only once?
async fn userinfo (
// access_token: AccessTokenUserInfo, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
TypedHeader ( Authorization ( bearer ) ) : TypedHeader < Authorization < Bearer > > , // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
Extension ( pool ) : Extension < ConnectionPool > ,
) -> Result < Json < CoreUserInfoClaims > , CustomError > {
let code = bearer . token ( ) . to_string ( ) ;
let mut conn = pool
. get ( )
. await
. map_err ( | e | anyhow! ( " Failed to get connection to database: {} " , e ) ) ? ;
let serialized_entry : Option < Vec < u8 > > = conn
. get ( code )
. await
. map_err ( | e | anyhow! ( " Failed to get kv: {} " , e ) ) ? ;
if serialized_entry . is_none ( ) {
2021-12-20 11:29:43 -05:00
return Err ( CustomError ::BadRequest ( " Unknown code. " . to_string ( ) ) ) ;
2021-11-29 15:05:43 -05:00
}
let code_entry : CodeEntry = bincode ::deserialize (
& hex ::decode ( serialized_entry . unwrap ( ) )
. map_err ( | e | anyhow! ( " Failed to decode code entry: {} " , e ) ) ? ,
)
. map_err ( | e | anyhow! ( " Failed to deserialize code: {} " , e ) ) ? ;
Ok ( CoreUserInfoClaims ::new (
StandardClaims ::new ( SubjectIdentifier ::new ( code_entry . address ) ) ,
EmptyAdditionalClaims ::default ( ) ,
)
. into ( ) )
}
/// Liveness probe (`GET /health`): does nothing and answers with an empty
/// success response.
async fn healthcheck() {}
#[ tokio::main ]
async fn main ( ) {
let config = Figment ::from ( Serialized ::defaults ( config ::Config ::default ( ) ) )
. merge ( Toml ::file ( " siwe-oidc.toml " ) . nested ( ) )
. merge ( Env ::prefixed ( " SIWEOIDC_ " ) . split ( " __ " ) . global ( ) ) ;
let config = config . extract ::< config ::Config > ( ) . unwrap ( ) ;
tracing_subscriber ::fmt ::init ( ) ;
let manager = RedisConnectionManager ::new ( config . redis_url . clone ( ) ) . unwrap ( ) ;
let pool = bb8 ::Pool ::builder ( ) . build ( manager . clone ( ) ) . await . unwrap ( ) ;
let pool2 = bb8 ::Pool ::builder ( ) . build ( manager ) . await . unwrap ( ) ;
for ( id , secret ) in & config . default_clients . clone ( ) {
2021-12-20 11:29:43 -05:00
let conn = pool2
. get ( )
2021-11-29 15:05:43 -05:00
. await
2021-12-20 11:29:43 -05:00
. map_err ( | e | anyhow! ( " Failed to get connection to database: {} " , e ) )
2021-11-29 15:05:43 -05:00
. unwrap ( ) ;
2021-12-20 11:29:43 -05:00
let client_entry = ClientEntry {
secret : secret . to_string ( ) ,
redirect_uris : vec ! [ ] ,
} ;
set_client ( conn , id . to_string ( ) , client_entry )
. await
. unwrap ( ) ; // TODO
2021-11-29 15:05:43 -05:00
}
let private_key = if let Some ( key ) = & config . rsa_pem {
2021-12-20 11:29:43 -05:00
RsaPrivateKey ::from_pkcs1_pem ( key )
2021-11-29 15:05:43 -05:00
. map_err ( | e | anyhow! ( " Failed to load private key: {} " , e ) )
. unwrap ( )
} else {
info! ( " Generating key... " ) ;
let mut rng = OsRng ;
let bits = 2048 ;
let private = RsaPrivateKey ::new ( & mut rng , bits )
. map_err ( | e | anyhow! ( " Failed to generate a key: {} " , e ) )
. unwrap ( ) ;
info! ( " Generated key. " ) ;
info! ( " {:?} " , private . to_pkcs1_pem ( ) . unwrap ( ) ) ;
private
} ;
let app = Router ::new ( )
. nest (
" /build " ,
service_method_routing ::get ( ServeDir ::new ( " ./static/build " ) ) . handle_error (
| error : std ::io ::Error | {
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Unhandled internal error: {} " , error ) ,
)
} ,
) ,
)
2021-12-17 10:02:07 -05:00
. nest (
" /img " ,
service_method_routing ::get ( ServeDir ::new ( " ./static/img " ) ) . handle_error (
| error : std ::io ::Error | {
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Unhandled internal error: {} " , error ) ,
)
} ,
) ,
)
2021-11-29 15:05:43 -05:00
. route (
" / " ,
service_method_routing ::get ( ServeFile ::new ( " ./static/index.html " ) ) . handle_error (
| error : std ::io ::Error | {
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Unhandled internal error: {} " , error ) ,
)
} ,
) ,
)
2021-12-20 11:29:43 -05:00
. route (
" /error " ,
service_method_routing ::get ( ServeFile ::new ( " ./static/error.html " ) ) . handle_error (
| error : std ::io ::Error | {
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Unhandled internal error: {} " , error ) ,
)
} ,
) ,
)
2021-11-29 15:05:43 -05:00
. route (
" /favicon.png " ,
service_method_routing ::get ( ServeFile ::new ( " ./static/favicon.png " ) ) . handle_error (
| error : std ::io ::Error | {
(
StatusCode ::INTERNAL_SERVER_ERROR ,
format! ( " Unhandled internal error: {} " , error ) ,
)
} ,
) ,
)
. route ( " /.well-known/openid-configuration " , get ( provider_metadata ) )
. route ( " /jwk " , get ( jwk_set ) )
. route ( " /token " , post ( token ) )
. route ( " /authorize " , get ( authorize ) )
. route ( " /register " , post ( register ) )
. route ( " /userinfo " , get ( userinfo ) . post ( userinfo ) )
. route ( " /sign_in " , get ( sign_in ) )
. route ( " /health " , get ( healthcheck ) )
. layer ( AddExtensionLayer ::new ( private_key ) )
. layer ( AddExtensionLayer ::new ( config . clone ( ) ) )
. layer ( AddExtensionLayer ::new ( pool ) )
. layer ( AddExtensionLayer ::new (
RedisSessionStore ::new ( config . redis_url . clone ( ) )
. unwrap ( )
. with_prefix ( " async-sessions/ " ) ,
) )
. layer ( TraceLayer ::new_for_http ( ) ) ;
let addr = SocketAddr ::from ( ( config . address , config . port ) ) ;
tracing ::info! ( " Listening on {} " , addr ) ;
axum ::Server ::bind ( & addr )
. serve ( app . into_make_service ( ) )
. await
. unwrap ( ) ;
}