Cloudflare Worker version (#6)
Refactor/generalise API/DB interactions out of OIDC.
This commit is contained in:
parent
9d725552e0
commit
bbcacf4232
13
.github/workflows/ci.yml
vendored
13
.github/workflows/ci.yml
vendored
@ -8,12 +8,25 @@ env:
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cargo_target: "x86_64-unknown-linux-gnu"
|
||||
- cargo_target: "wasm32-unknown-unknown"
|
||||
steps:
|
||||
- name: Clone repo
|
||||
uses: actions/checkout@master
|
||||
- name: Add targets
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
- name: Build
|
||||
env:
|
||||
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
|
||||
run: cargo build --verbose
|
||||
- name: Clippy
|
||||
env:
|
||||
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
|
||||
run: RUSTFLAGS="-Dwarnings" cargo clippy
|
||||
- name: Fmt
|
||||
env:
|
||||
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
|
||||
run: cargo fmt -- --check
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
/target
|
||||
/static/build
|
||||
wrangler.toml
|
||||
|
618
Cargo.lock
generated
618
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
57
Cargo.toml
57
Cargo.toml
@ -5,35 +5,68 @@ edition = "2021"
|
||||
authors = ["Spruce Systems, Inc."]
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/spruceid/siwe-oidc/"
|
||||
description = "OpenID Connect Identity Provider for Sign-In with Ethereum."
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.51"
|
||||
axum = { version = "0.3.4", features = ["headers"] }
|
||||
chrono = "0.4.19"
|
||||
headers = "0.3.5"
|
||||
hex = "0.4.3"
|
||||
iri-string = { version = "0.4", features = ["serde-std"] }
|
||||
openidconnect = "2.1.2"
|
||||
# openidconnect = "2.1.2"
|
||||
openidconnect = { git = "https://github.com/sbihel/openidconnect-rs", branch = "main", default-features = false, features = ["reqwest", "rustls-tls", "rustcrypto"] }
|
||||
rand = "0.8.4"
|
||||
rsa = { version = "0.5.0", features = ["alloc"] }
|
||||
rust-argon2 = "0.8"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0.72"
|
||||
siwe = "0.1"
|
||||
async-session = "3.0.0"
|
||||
siwe = "0.1.2"
|
||||
thiserror = "1.0.30"
|
||||
tokio = { version = "1.14.0", features = ["full"] }
|
||||
tower-http = { version = "0.2.0", features = ["fs", "trace", "cors"] }
|
||||
tracing = "0.1.29"
|
||||
tracing-subscriber = { version = "0.3.2", features = ["env-filter"] }
|
||||
url = { version = "2.2", features = ["serde"] }
|
||||
urlencoding = "2.1.0"
|
||||
uuid = { version = "0.8", features = ["serde", "v4"] }
|
||||
figment = { version = "0.10.6", features = ["toml", "env"] }
|
||||
sha2 = "0.9.0"
|
||||
cookie = "0.15.1"
|
||||
bincode = "1.3.3"
|
||||
async-trait = "0.1.52"
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
async-session = "3.0.0"
|
||||
axum = { version = "0.4.3", features = ["headers"] }
|
||||
# axum-debug = "0.3.2"
|
||||
chrono = "0.4.19"
|
||||
figment = { version = "0.10.6", features = ["toml", "env"] }
|
||||
tokio = { version = "1.14.0", features = ["full"] }
|
||||
tower-http = { version = "0.2.0", features = ["fs", "trace", "cors"] }
|
||||
tracing-subscriber = { version = "0.3.2", features = ["env-filter"] }
|
||||
bb8-redis = "0.10.1"
|
||||
async-redis-session = "0.2.2"
|
||||
uuid = { version = "0.8", features = ["serde", "v4"] }
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
# cached = { version = "0.26", default-features = false }
|
||||
chrono = { version = "0.4.19", features = ["wasmbind"] }
|
||||
console_error_panic_hook = { version = "0.1" }
|
||||
# console_log = "0.2"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
# log = "0.4"
|
||||
matchit = "0.4.2"
|
||||
serde_urlencoded = "0.7.0"
|
||||
uuid = { version = "0.8", features = ["serde", "v4", "wasm-bindgen"] }
|
||||
wee_alloc = { version = "0.4" }
|
||||
worker = "0.0.7"
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z"
|
||||
lto = true
|
||||
|
||||
# [target.'cfg(target_arch = "wasm32")'.profile.release]
|
||||
# opt-level = "z"
|
||||
|
||||
# [target.'cfg(target_arch = "wasm32")'.profile.debug]
|
||||
# opt-level = "z"
|
||||
# lto = false
|
||||
|
||||
[package.metadata.wasm-pack.profile.profiling]
|
||||
wasm-opt = ['-g', '-O']
|
||||
|
61
README.md
61
README.md
@ -2,11 +2,54 @@
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Dependencies
|
||||
Two versions are available, a stand-alone binary (using Axum and Redis) and a
|
||||
Cloudflare Worker. They use the same code base and are selected at compile time
|
||||
(compiling for `wasm32` will make the Worker version).
|
||||
|
||||
### Cloudflare Worker
|
||||
|
||||
You will need [`wrangler`](https://github.com/cloudflare/wrangler).
|
||||
|
||||
Then copy the configuration file template:
|
||||
```bash
|
||||
cp wrangler_example.toml wrangler.toml
|
||||
```
|
||||
|
||||
Replacing the following fields:
|
||||
- `account_id`: your Cloudflare account ID;
|
||||
- `zone_id`: (Optional) DNS zone ID; and
|
||||
- `kv_namespaces`: a KV namespace ID (created with `wrangler kv:namespace create SIWE-OIDC`).
|
||||
|
||||
At this point, you should be able to create/publish the worker:
|
||||
```
|
||||
wrangler publish
|
||||
```
|
||||
|
||||
The IdP currently only supports having the **frontend under the same subdomain as
|
||||
the API**. Here is the configuration for Cloudflare Pages:
|
||||
- `Build command`: `cd js/ui && npm install && npm run build`;
|
||||
- `Build output directory`: `/static`; and
|
||||
- `Root directory`: `/`.
|
||||
And you will need to add some rules to do the routing between the Page and the
|
||||
Worker. Here are the rules for the Worker (the Page being used as the fallback
|
||||
on the subdomain):
|
||||
```
|
||||
siweoidc.example.com/s*
|
||||
siweoidc.example.com/u*
|
||||
siweoidc.example.com/r*
|
||||
siweoidc.example.com/a*
|
||||
siweoidc.example.com/t*
|
||||
siweoidc.example.com/j*
|
||||
siweoidc.example.com/.w*
|
||||
```
|
||||
|
||||
### Stand-Alone Binary
|
||||
|
||||
#### Dependencies
|
||||
|
||||
Redis, or a Redis compatible database (e.g. MemoryDB in AWS), is required.
|
||||
|
||||
### Starting the IdP
|
||||
#### Starting the IdP
|
||||
|
||||
The Docker image is available at `ghcr.io/spruceid/siwe_oidc:0.1.0`. Here is an
|
||||
example usage:
|
||||
@ -35,9 +78,23 @@ For the core OIDC information, it is available under
|
||||
|
||||
* Additional information, from native projects (e.g. ENS domains), to more
|
||||
traditional ones (e.g. email).
|
||||
* PKCE support (code challenge).
|
||||
* Browser session support for the Worker version.
|
||||
|
||||
## Development
|
||||
|
||||
### Cloudflare Worker
|
||||
|
||||
```bash
|
||||
wrangler dev
|
||||
```
|
||||
You can now use http://127.0.0.1:8787/.well-known/openid-configuration.
|
||||
|
||||
> At the moment it's not possible to use it end-to-end with the frontend as they
|
||||
> need to share the same host (i.e. port), unless using a local load-balancer.
|
||||
|
||||
### Stand Alone Binary
|
||||
|
||||
A Docker Compose is available to test the IdP locally with Keycloak.
|
||||
|
||||
1. You will first need to run:
|
||||
|
2895
js/ui/package-lock.json
generated
2895
js/ui/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -2,28 +2,28 @@
|
||||
"name": "svelte-app",
|
||||
"version": "1.0.0",
|
||||
"devDependencies": {
|
||||
"@tsconfig/svelte": "^1.0.10",
|
||||
"@types/node": "^14.11.1",
|
||||
"@typescript-eslint/eslint-plugin": "^4.21.0",
|
||||
"@typescript-eslint/parser": "^4.21.0",
|
||||
"@tsconfig/svelte": "^3.0.0",
|
||||
"@types/node": "^17.0.7",
|
||||
"@typescript-eslint/eslint-plugin": "^5.9.0",
|
||||
"@typescript-eslint/parser": "^5.9.0",
|
||||
"assert": "^2.0.0",
|
||||
"autoprefixer": "^10.2.5",
|
||||
"base64-loader": "^1.0.0",
|
||||
"buffer": "^6.0.3",
|
||||
"cross-env": "^7.0.3",
|
||||
"crypto-browserify": "^3.12.0",
|
||||
"css-loader": "^5.0.1",
|
||||
"css-loader": "^6.5.1",
|
||||
"cssnano": "^5.0.8",
|
||||
"dotenv-webpack": "^7.0.3",
|
||||
"eslint": "^7.23.0",
|
||||
"eslint": "^8.6.0",
|
||||
"eslint-config-prettier": "^8.1.0",
|
||||
"eslint-plugin-svelte3": "^3.1.2",
|
||||
"https-browserify": "^1.0.0",
|
||||
"mini-css-extract-plugin": "^1.3.4",
|
||||
"mini-css-extract-plugin": "^2.4.5",
|
||||
"os-browserify": "^0.3.0",
|
||||
"postcss": "^8.2.8",
|
||||
"postcss-load-config": "^3.0.1",
|
||||
"postcss-loader": "^5.2.0",
|
||||
"postcss-loader": "^6.2.1",
|
||||
"precss": "^4.0.0",
|
||||
"prettier": "^2.2.1",
|
||||
"prettier-plugin-svelte": "^2.2.0",
|
||||
@ -31,12 +31,12 @@
|
||||
"stream-browserify": "^3.0.0",
|
||||
"stream-http": "^3.2.0",
|
||||
"svelte": "^3.31.2",
|
||||
"svelte-check": "^1.0.46",
|
||||
"svelte-check": "^2.2.11",
|
||||
"svelte-loader": "^3.0.0",
|
||||
"svelte-preprocess": "^4.3.0",
|
||||
"svg-url-loader": "^7.1.1",
|
||||
"tailwindcss": "^2.0.4",
|
||||
"ts-loader": "^8.0.4",
|
||||
"tailwindcss": "^3.0.9",
|
||||
"ts-loader": "^9.2.6",
|
||||
"tslib": "^2.0.1",
|
||||
"typescript": "^4.0.3",
|
||||
"webpack": "^5.16.0",
|
||||
@ -54,6 +54,7 @@
|
||||
"@toruslabs/torus-embed": "^1.18.3",
|
||||
"@walletconnect/web3-provider": "^1.6.6",
|
||||
"fortmatic": "^2.2.1",
|
||||
"url": "^0.11.0",
|
||||
"walletlink": "^2.2.8"
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
{
|
||||
"extends": "@tsconfig/svelte/tsconfig.json",
|
||||
"include": ["src/**/*", "src/node_modules/**/*"],
|
||||
"exclude": ["node_modules/*", "__sapper__/*", "static/*"]
|
||||
"extends": "@tsconfig/svelte/tsconfig.json",
|
||||
"include": ["src/**/*", "src/node_modules/**/*"],
|
||||
"exclude": ["node_modules/*", "__sapper__/*", "static/*"],
|
||||
"compilerOptions": {
|
||||
"types": ["node", "svelte"]
|
||||
}
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ module.exports = {
|
||||
path: false,
|
||||
process: require.resolve('process/browser'),
|
||||
stream: require.resolve('stream-browserify'),
|
||||
url: require.resolve("url")
|
||||
// util: false,
|
||||
}
|
||||
},
|
||||
|
364
src/axum_lib.rs
Normal file
364
src/axum_lib.rs
Normal file
@ -0,0 +1,364 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_redis_session::RedisSessionStore;
|
||||
use axum::{
|
||||
extract::{self, Extension, Form, Query, TypedHeader},
|
||||
http::{
|
||||
header::{self, HeaderMap},
|
||||
StatusCode,
|
||||
},
|
||||
response::{self, IntoResponse, Redirect},
|
||||
routing::{get, get_service, post},
|
||||
AddExtensionLayer, Json, Router,
|
||||
};
|
||||
use bb8_redis::{bb8, RedisConnectionManager};
|
||||
use figment::{
|
||||
providers::{Env, Format, Serialized, Toml},
|
||||
Figment,
|
||||
};
|
||||
use headers::{
|
||||
self,
|
||||
authorization::{Basic, Bearer},
|
||||
Authorization,
|
||||
};
|
||||
use openidconnect::core::{
|
||||
CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata,
|
||||
CoreResponseType, CoreTokenResponse, CoreUserInfoClaims,
|
||||
};
|
||||
use rand::rngs::OsRng;
|
||||
use rsa::{
|
||||
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
|
||||
RsaPrivateKey,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use tower_http::{
|
||||
services::{ServeDir, ServeFile},
|
||||
trace::TraceLayer,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use super::config;
|
||||
use super::oidc::{self, CustomError};
|
||||
use super::session::*;
|
||||
use ::siwe_oidc::db::*;
|
||||
|
||||
impl IntoResponse for CustomError {
|
||||
fn into_response(self) -> response::Response {
|
||||
match self {
|
||||
CustomError::BadRequest(_) => {
|
||||
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
|
||||
}
|
||||
CustomError::BadRequestToken(e) => {
|
||||
(StatusCode::BAD_REQUEST, Json::from(e)).into_response()
|
||||
}
|
||||
CustomError::Unauthorized(_) => {
|
||||
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
|
||||
}
|
||||
CustomError::Redirect(uri) => Redirect::to(
|
||||
uri.parse().unwrap(),
|
||||
// .map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
)
|
||||
.into_response(),
|
||||
CustomError::Other(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn jwk_set(
|
||||
Extension(private_key): Extension<RsaPrivateKey>,
|
||||
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
|
||||
let jwks = oidc::jwks(private_key)?;
|
||||
Ok(jwks.into())
|
||||
}
|
||||
|
||||
async fn provider_metadata(
|
||||
Extension(config): Extension<config::Config>,
|
||||
) -> Result<Json<CoreProviderMetadata>, CustomError> {
|
||||
Ok(oidc::metadata(config.base_url)?.into())
|
||||
}
|
||||
|
||||
// TODO should check Authorization header
|
||||
// Actually, client secret can be
|
||||
// 1. in the POST (currently supported) [x]
|
||||
// 2. Authorization header [x]
|
||||
// 3. JWT [ ]
|
||||
// 4. signed JWT [ ]
|
||||
// according to Keycloak
|
||||
|
||||
async fn token(
|
||||
Form(form): Form<oidc::TokenForm>,
|
||||
bearer: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
basic: Option<TypedHeader<Authorization<Basic>>>,
|
||||
Extension(private_key): Extension<RsaPrivateKey>,
|
||||
Extension(config): Extension<config::Config>,
|
||||
Extension(redis_client): Extension<RedisClient>,
|
||||
) -> Result<Json<CoreTokenResponse>, CustomError> {
|
||||
let secret = if let Some(b) = bearer {
|
||||
Some(b.0 .0.token().to_string())
|
||||
} else {
|
||||
basic.map(|b| b.0 .0.password().to_string())
|
||||
};
|
||||
let token_response = oidc::token(
|
||||
form,
|
||||
secret,
|
||||
private_key,
|
||||
config.base_url,
|
||||
config.require_secret,
|
||||
&redis_client,
|
||||
)
|
||||
.await?;
|
||||
Ok(token_response.into())
|
||||
}
|
||||
|
||||
// TODO handle `registration` parameter
|
||||
async fn authorize(
|
||||
session: UserSessionFromSession,
|
||||
Query(params): Query<oidc::AuthorizeParams>,
|
||||
Extension(redis_client): Extension<RedisClient>,
|
||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||
let (nonce, headers) = match session {
|
||||
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
||||
UserSessionFromSession::Invalid(cookie) => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::SET_COOKIE, cookie);
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
|
||||
¶ms.client_id,
|
||||
¶ms.redirect_uri.to_string(),
|
||||
¶ms.scope.to_string(),
|
||||
¶ms.response_type.unwrap_or(CoreResponseType::Code).as_ref(),
|
||||
¶ms.state.unwrap_or_default(),
|
||||
¶ms.client_id,
|
||||
¶ms.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
UserSessionFromSession::Created { header, nonce } => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::SET_COOKIE, header);
|
||||
(nonce, headers)
|
||||
}
|
||||
};
|
||||
|
||||
let url = oidc::authorize(params, nonce, &redis_client).await?;
|
||||
Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
async fn sign_in(
|
||||
session: UserSessionFromSession,
|
||||
Query(params): Query<oidc::SignInParams>,
|
||||
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
||||
Extension(redis_client): Extension<RedisClient>,
|
||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||
let (nonce, headers) = match session {
|
||||
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
||||
UserSessionFromSession::Invalid(header) => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::SET_COOKIE, header);
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||
¶ms.client_id.clone(),
|
||||
¶ms.redirect_uri.to_string(),
|
||||
¶ms.state,
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
UserSessionFromSession::Created { .. } => {
|
||||
return Ok((
|
||||
HeaderMap::new(),
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||
¶ms.client_id.clone(),
|
||||
¶ms.redirect_uri.to_string(),
|
||||
¶ms.state,
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let url = oidc::sign_in(params, Some(nonce), cookies, &redis_client).await?;
|
||||
|
||||
Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
// TODO clear session
|
||||
}
|
||||
|
||||
async fn register(
|
||||
extract::Json(payload): extract::Json<CoreClientMetadata>,
|
||||
Extension(redis_client): Extension<RedisClient>,
|
||||
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
|
||||
let registration = oidc::register(payload, &redis_client).await?;
|
||||
Ok((StatusCode::CREATED, registration.into()))
|
||||
}
|
||||
|
||||
// TODO CORS
|
||||
// TODO need validation of the token
|
||||
async fn userinfo(
|
||||
payload: Option<Form<oidc::UserInfoPayload>>,
|
||||
bearer: Option<TypedHeader<Authorization<Bearer>>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
||||
Extension(redis_client): Extension<RedisClient>,
|
||||
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
|
||||
let payload = if let Some(Form(p)) = payload {
|
||||
p
|
||||
} else {
|
||||
oidc::UserInfoPayload { access_token: None }
|
||||
};
|
||||
let claims = oidc::userinfo(bearer.map(|b| b.0 .0), payload, &redis_client).await?;
|
||||
Ok(claims.into())
|
||||
}
|
||||
|
||||
async fn healthcheck() {}
|
||||
|
||||
pub async fn main() {
|
||||
let config = Figment::from(Serialized::defaults(config::Config::default()))
|
||||
.merge(Toml::file("siwe-oidc.toml").nested())
|
||||
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
|
||||
let config = config.extract::<config::Config>().unwrap();
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
|
||||
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
|
||||
// let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
|
||||
|
||||
let redis_client = RedisClient { pool };
|
||||
|
||||
for (id, secret) in &config.default_clients.clone() {
|
||||
let client_entry = ClientEntry {
|
||||
secret: secret.to_string(),
|
||||
redirect_uris: vec![],
|
||||
};
|
||||
redis_client
|
||||
.set_client(id.to_string(), client_entry)
|
||||
.await
|
||||
.unwrap(); // TODO
|
||||
}
|
||||
|
||||
let private_key = if let Some(key) = &config.rsa_pem {
|
||||
RsaPrivateKey::from_pkcs1_pem(key)
|
||||
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||
.unwrap()
|
||||
} else {
|
||||
info!("Generating key...");
|
||||
let mut rng = OsRng;
|
||||
let bits = 2048;
|
||||
let private = RsaPrivateKey::new(&mut rng, bits)
|
||||
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
|
||||
.unwrap();
|
||||
|
||||
info!("Generated key.");
|
||||
info!("{:?}", private.to_pkcs1_pem().unwrap());
|
||||
private
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
"/build",
|
||||
get_service(ServeDir::new("./static/build")).handle_error(
|
||||
|error: std::io::Error| async move {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.nest(
|
||||
"/img",
|
||||
get_service(ServeDir::new("./static/img")).handle_error(
|
||||
|error: std::io::Error| async move {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/",
|
||||
get_service(ServeFile::new("./static/index.html")).handle_error(
|
||||
|error: std::io::Error| async move {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/error",
|
||||
get_service(ServeFile::new("./static/error.html")).handle_error(
|
||||
|error: std::io::Error| async move {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/favicon.png",
|
||||
get_service(ServeFile::new("./static/favicon.png")).handle_error(
|
||||
|error: std::io::Error| async move {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(oidc::METADATA_PATH, get(provider_metadata))
|
||||
.route(oidc::JWK_PATH, get(jwk_set))
|
||||
.route(oidc::TOKEN_PATH, post(token))
|
||||
.route(oidc::AUTHORIZE_PATH, get(authorize))
|
||||
.route(oidc::REGISTER_PATH, post(register))
|
||||
.route(oidc::USERINFO_PATH, get(userinfo).post(userinfo))
|
||||
.route(oidc::SIGNIN_PATH, get(sign_in))
|
||||
.route("/health", get(healthcheck))
|
||||
.layer(AddExtensionLayer::new(private_key))
|
||||
.layer(AddExtensionLayer::new(config.clone()))
|
||||
.layer(AddExtensionLayer::new(redis_client))
|
||||
.layer(AddExtensionLayer::new(
|
||||
RedisSessionStore::new(config.redis_url.clone())
|
||||
.unwrap()
|
||||
.with_prefix("async-sessions/"),
|
||||
))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
let addr = SocketAddr::from((config.address, config.port));
|
||||
tracing::info!("Listening on {}", addr);
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
43
src/db.rs
43
src/db.rs
@ -1,43 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use bb8_redis::{bb8::PooledConnection, redis::AsyncCommands, RedisConnectionManager};
|
||||
use openidconnect::RedirectUrl;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const KV_CLIENT_PREFIX: &str = "clients";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ClientEntry {
|
||||
pub secret: String,
|
||||
pub redirect_uris: Vec<RedirectUrl>,
|
||||
}
|
||||
|
||||
pub async fn set_client(
|
||||
mut conn: PooledConnection<'_, RedisConnectionManager>,
|
||||
client_id: String,
|
||||
client_entry: ClientEntry,
|
||||
) -> Result<()> {
|
||||
conn.set(
|
||||
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
||||
serde_json::to_string(&client_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_client(
|
||||
mut conn: PooledConnection<'_, RedisConnectionManager>,
|
||||
client_id: String,
|
||||
) -> Result<Option<ClientEntry>> {
|
||||
let entry: Option<String> = conn
|
||||
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if let Some(e) = entry {
|
||||
Ok(serde_json::from_str(&e)
|
||||
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
199
src/db/cf.rs
Normal file
199
src/db/cf.rs
Normal file
@ -0,0 +1,199 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
// use cached::{stores::TimedCache, Cached};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use matchit::Node;
|
||||
use std::collections::HashMap;
|
||||
use worker::*;
|
||||
|
||||
use super::*;
|
||||
|
||||
const KV_NAMESPACE: &str = "SIWE-OIDC";
|
||||
const DO_NAMESPACE: &str = "SIWE-OIDC-CODES";
|
||||
|
||||
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
|
||||
// Heavily relying on:
|
||||
// A Durable Object is given 30 seconds of additional CPU time for every
|
||||
// request it processes, including WebSocket messages. In the absence of
|
||||
// failures, in-memory state should not be reset after less than 30 seconds of
|
||||
// inactivity.
|
||||
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
|
||||
|
||||
// Wanted to use TimedCache but it (probably) crashes because it's using std::time::Instant which isn't available on wasm32.
|
||||
|
||||
#[durable_object]
|
||||
pub struct DOCodes {
|
||||
// codes: TimedCache<String, CodeEntry>,
|
||||
codes: HashMap<String, (DateTime<Utc>, CodeEntry)>,
|
||||
// state: State,
|
||||
// env: Env,
|
||||
}
|
||||
|
||||
#[durable_object]
|
||||
impl DurableObject for DOCodes {
|
||||
fn new(state: State, _env: Env) -> Self {
|
||||
Self {
|
||||
// codes: TimedCache::with_lifespan(ENTRY_LIFETIME.try_into().unwrap()),
|
||||
codes: HashMap::new(),
|
||||
// state,
|
||||
// env,
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch(&mut self, mut req: Request) -> worker::Result<Response> {
|
||||
// Can't use the Router because we need to reference self (thus move the var to the closure)
|
||||
if matches!(req.method(), Method::Get) {
|
||||
let mut matcher = Node::new();
|
||||
matcher.insert("/:code", ())?;
|
||||
let path = req.path();
|
||||
let matched = match matcher.at(&path) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return Response::error("Bad request", 400),
|
||||
};
|
||||
let code = if let Some(c) = matched.params.get("code") {
|
||||
c
|
||||
} else {
|
||||
return Response::error("Bad request", 400);
|
||||
};
|
||||
if let Some(c) = self.codes.get(code) {
|
||||
if c.0 + Duration::seconds(ENTRY_LIFETIME.try_into().unwrap()) < Utc::now() {
|
||||
self.codes.remove(code);
|
||||
Response::error("Not found", 404)
|
||||
} else {
|
||||
Response::from_json(&c.1)
|
||||
}
|
||||
} else {
|
||||
Response::error("Not found", 404)
|
||||
}
|
||||
} else if matches!(req.method(), Method::Post) {
|
||||
let mut matcher = Node::new();
|
||||
matcher.insert("/:code", ())?;
|
||||
let path = req.path();
|
||||
let matched = match matcher.at(&path) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return Response::error("Bad request", 400),
|
||||
};
|
||||
let code = if let Some(c) = matched.params.get("code") {
|
||||
c
|
||||
} else {
|
||||
return Response::error("Bad request", 400);
|
||||
};
|
||||
let code_entry = match req.json().await {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Response::error(format!("Bad request: {}", e), 400),
|
||||
};
|
||||
self.codes
|
||||
.insert(code.to_string(), (Utc::now(), code_entry));
|
||||
Response::empty()
|
||||
} else {
|
||||
Response::error("Method Not Allowed", 405)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CFClient {
|
||||
pub ctx: RouteContext<()>,
|
||||
pub url: Url,
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||
impl DBClient for CFClient {
|
||||
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
|
||||
self.ctx
|
||||
.kv(KV_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
|
||||
.put(
|
||||
&format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
||||
serde_json::to_string(&client_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to build KV put: {}", e))?
|
||||
// TODO put some sort of expiration for dynamic registration
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to put KV: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
|
||||
let entry = self
|
||||
.ctx
|
||||
.kv(KV_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
|
||||
.get(&format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get KV: {}", e))?
|
||||
.map(|e| e.as_string());
|
||||
if let Some(e) = entry {
|
||||
Ok(serde_json::from_str(&e)
|
||||
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
|
||||
let namespace = self
|
||||
.ctx
|
||||
.durable_object(DO_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
|
||||
let stub = namespace
|
||||
.id_from_name(&code)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
|
||||
.get_stub()
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
|
||||
let mut headers = Headers::new();
|
||||
headers.set("Content-Type", "application/json").unwrap();
|
||||
let mut url = self.url.clone();
|
||||
url.set_path(&code);
|
||||
url.set_query(None);
|
||||
let req = Request::new_with_init(
|
||||
url.as_str(),
|
||||
&RequestInit {
|
||||
body: Some(wasm_bindgen::JsValue::from_str(
|
||||
&serde_json::to_string(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialize: {}", e))?,
|
||||
)),
|
||||
method: Method::Post,
|
||||
headers,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to construct request for Durable Object: {}", e))?;
|
||||
let res = stub
|
||||
.fetch_with_request(req)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
|
||||
match res.status_code() {
|
||||
200 => Ok(()),
|
||||
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
|
||||
}
|
||||
}
|
||||
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
|
||||
let namespace = self
|
||||
.ctx
|
||||
.durable_object(DO_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
|
||||
let stub = namespace
|
||||
.id_from_name(&code)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
|
||||
.get_stub()
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
|
||||
let mut url = self.url.clone();
|
||||
url.set_path(&code);
|
||||
url.set_query(None);
|
||||
let mut res = stub
|
||||
.fetch_with_str(url.as_str())
|
||||
.await
|
||||
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
|
||||
match res.status_code() {
|
||||
200 => Ok(Some(res.json().await.map_err(|e| {
|
||||
anyhow!(
|
||||
"Response to Durable Object failed to be deserialized: {}",
|
||||
e
|
||||
)
|
||||
})?)),
|
||||
404 => Ok(None),
|
||||
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
|
||||
}
|
||||
}
|
||||
}
|
40
src/db/mod.rs
Normal file
40
src/db/mod.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use openidconnect::{Nonce, RedirectUrl};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod redis;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use redis::RedisClient;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
mod cf;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use cf::CFClient;
|
||||
|
||||
const KV_CLIENT_PREFIX: &str = "clients";
|
||||
const ENTRY_LIFETIME: usize = 30;
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct CodeEntry {
|
||||
pub exchange_count: usize,
|
||||
pub address: String,
|
||||
pub nonce: Option<Nonce>,
|
||||
pub client_id: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct ClientEntry {
|
||||
pub secret: String,
|
||||
pub redirect_uris: Vec<RedirectUrl>,
|
||||
}
|
||||
|
||||
// Using a trait to easily pass async functions with async_trait
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||
pub trait DBClient {
|
||||
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>;
|
||||
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>;
|
||||
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>;
|
||||
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>;
|
||||
}
|
89
src/db/redis.rs
Normal file
89
src/db/redis.rs
Normal file
@ -0,0 +1,89 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisClient {
|
||||
pub pool: Pool<RedisConnectionManager>,
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||
impl DBClient for RedisClient {
|
||||
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
|
||||
conn.set(
|
||||
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
||||
serde_json::to_string(&client_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let entry: Option<String> = conn
|
||||
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if let Some(e) = entry {
|
||||
Ok(serde_json::from_str(&e)
|
||||
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
conn.set_ex(
|
||||
code.to_string(),
|
||||
hex::encode(
|
||||
bincode::serialize(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
||||
),
|
||||
ENTRY_LIFETIME,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let serialized_entry: Option<Vec<u8>> = conn
|
||||
.get(code)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if serialized_entry.is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
let code_entry: CodeEntry = bincode::deserialize(
|
||||
&hex::decode(serialized_entry.unwrap())
|
||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
||||
Ok(Some(code_entry))
|
||||
}
|
||||
}
|
18
src/lib.rs
Normal file
18
src/lib.rs
Normal file
@ -0,0 +1,18 @@
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use worker::*;
|
||||
|
||||
pub mod db;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub mod oidc;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
mod worker_lib;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use worker_lib::main as worker_main;
|
||||
// pub use worker_lib::main;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[event(fetch)]
|
||||
pub async fn main(req: Request, env: Env) -> Result<Response> {
|
||||
worker_main(req, env).await
|
||||
}
|
889
src/main.rs
889
src/main.rs
@ -1,882 +1,19 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_redis_session::RedisSessionStore;
|
||||
use axum::{
|
||||
body::{Bytes, Full},
|
||||
error_handling::HandleErrorExt,
|
||||
extract::{self, Extension, Form, Query, TypedHeader},
|
||||
http::{
|
||||
header::{self, HeaderMap},
|
||||
Response, StatusCode,
|
||||
},
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::{get, post, service_method_routing},
|
||||
AddExtensionLayer, Json, Router,
|
||||
};
|
||||
use bb8_redis::{bb8, bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
||||
use chrono::{Duration, Utc};
|
||||
use figment::{
|
||||
providers::{Env, Format, Serialized, Toml},
|
||||
Figment,
|
||||
};
|
||||
use headers::{self, authorization::Bearer, Authorization};
|
||||
use hex::FromHex;
|
||||
use iri_string::types::{UriAbsoluteString, UriString};
|
||||
use openidconnect::{
|
||||
core::{
|
||||
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
|
||||
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
|
||||
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
|
||||
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
|
||||
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
|
||||
},
|
||||
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
|
||||
url::Url,
|
||||
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
|
||||
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
|
||||
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
|
||||
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
|
||||
};
|
||||
use rand::rngs::OsRng;
|
||||
use rsa::{
|
||||
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
|
||||
RsaPrivateKey,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siwe::eip4361::{Message, Version};
|
||||
use std::{convert::Infallible, net::SocketAddr, str::FromStr};
|
||||
use thiserror::Error;
|
||||
use tower_http::{
|
||||
services::{ServeDir, ServeFile},
|
||||
trace::TraceLayer,
|
||||
};
|
||||
use tracing::info;
|
||||
use urlencoding::decode;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod axum_lib;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod config;
|
||||
mod db;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod oidc;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod session;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use axum_lib::main as axum_main;
|
||||
|
||||
use db::*;
|
||||
use session::*;
|
||||
|
||||
const KID: &str = "key1";
|
||||
const ENTRY_LIFETIME: usize = 30;
|
||||
|
||||
type ConnectionPool = Pool<RedisConnectionManager>;
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct TokenError {
|
||||
pub error: CoreErrorResponseType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CustomError {
|
||||
#[error("{0}")]
|
||||
BadRequest(String),
|
||||
#[error("{0:?}")]
|
||||
BadRequestToken(Json<TokenError>),
|
||||
#[error("{0}")]
|
||||
Unauthorized(String),
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl IntoResponse for CustomError {
|
||||
type Body = Full<Bytes>;
|
||||
type BodyError = Infallible;
|
||||
|
||||
fn into_response(self) -> Response<Self::Body> {
|
||||
match self {
|
||||
CustomError::BadRequest(_) => {
|
||||
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
|
||||
}
|
||||
CustomError::BadRequestToken(e) => (StatusCode::BAD_REQUEST, e).into_response(),
|
||||
CustomError::Unauthorized(_) => {
|
||||
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
|
||||
}
|
||||
CustomError::Other(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn jwk_set(
|
||||
Extension(private_key): Extension<RsaPrivateKey>,
|
||||
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
|
||||
let pem = private_key
|
||||
.to_pkcs1_pem()
|
||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
|
||||
&pem,
|
||||
Some(JsonWebKeyId::new(KID.to_string())),
|
||||
)
|
||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
|
||||
.as_verification_key()]);
|
||||
Ok(jwks.into())
|
||||
}
|
||||
|
||||
async fn provider_metadata(
|
||||
Extension(config): Extension<config::Config>,
|
||||
) -> Result<Json<CoreProviderMetadata>, CustomError> {
|
||||
let pm = CoreProviderMetadata::new(
|
||||
IssuerUrl::from_url(config.base_url.clone()),
|
||||
AuthUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("authorize")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
),
|
||||
JsonWebKeySetUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("jwk")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
),
|
||||
vec![
|
||||
ResponseTypes::new(vec![CoreResponseType::Code]),
|
||||
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
|
||||
],
|
||||
vec![CoreSubjectIdentifierType::Pairwise],
|
||||
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
|
||||
EmptyAdditionalProviderMetadata {},
|
||||
)
|
||||
.set_token_endpoint(Some(TokenUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("token")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("userinfo")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_scopes_supported(Some(vec![
|
||||
Scope::new("openid".to_string()),
|
||||
// Scope::new("email".to_string()),
|
||||
// Scope::new("profile".to_string()),
|
||||
]))
|
||||
.set_claims_supported(Some(vec![
|
||||
CoreClaimName::new("sub".to_string()),
|
||||
CoreClaimName::new("aud".to_string()),
|
||||
// CoreClaimName::new("email".to_string()),
|
||||
// CoreClaimName::new("email_verified".to_string()),
|
||||
CoreClaimName::new("exp".to_string()),
|
||||
CoreClaimName::new("iat".to_string()),
|
||||
CoreClaimName::new("iss".to_string()),
|
||||
// CoreClaimName::new("name".to_string()),
|
||||
// CoreClaimName::new("given_name".to_string()),
|
||||
// CoreClaimName::new("family_name".to_string()),
|
||||
// CoreClaimName::new("picture".to_string()),
|
||||
// CoreClaimName::new("locale".to_string()),
|
||||
]))
|
||||
.set_registration_endpoint(Some(RegistrationUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("register")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_token_endpoint_auth_methods_supported(Some(vec![
|
||||
CoreClientAuthMethod::ClientSecretBasic,
|
||||
CoreClientAuthMethod::ClientSecretPost,
|
||||
]));
|
||||
|
||||
Ok(pm.into())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TokenForm {
|
||||
code: String,
|
||||
client_id: Option<String>,
|
||||
client_secret: Option<String>,
|
||||
grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
|
||||
}
|
||||
|
||||
// TODO should check Authorization header
|
||||
// Actually, client secret can be
|
||||
// 1. in the POST (currently supported) [x]
|
||||
// 2. Authorization header [x]
|
||||
// 3. JWT [ ]
|
||||
// 4. signed JWT [ ]
|
||||
// according to Keycloak
|
||||
|
||||
async fn token(
|
||||
form: Form<TokenForm>,
|
||||
bearer: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
Extension(private_key): Extension<RsaPrivateKey>,
|
||||
Extension(config): Extension<config::Config>,
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<Json<CoreTokenResponse>, CustomError> {
|
||||
let mut conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
|
||||
let serialized_entry: Option<Vec<u8>> = conn
|
||||
.get(form.code.to_string())
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if serialized_entry.is_none() {
|
||||
return Err(CustomError::BadRequestToken(
|
||||
TokenError {
|
||||
error: CoreErrorResponseType::InvalidGrant,
|
||||
}
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
let code_entry: CodeEntry = bincode::deserialize(
|
||||
&hex::decode(serialized_entry.unwrap())
|
||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
||||
|
||||
let client_id = if let Some(c) = form.client_id.clone() {
|
||||
c
|
||||
} else {
|
||||
code_entry.client_id.clone()
|
||||
};
|
||||
|
||||
if let Some(secret) = if let Some(TypedHeader(Authorization(b))) = bearer {
|
||||
Some(b.token().to_string())
|
||||
} else {
|
||||
form.client_secret.clone()
|
||||
} {
|
||||
let conn2 = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let client_entry = get_client(conn2, client_id.clone()).await?;
|
||||
if client_entry.is_none() {
|
||||
return Err(CustomError::Unauthorized(
|
||||
"Unrecognised client id.".to_string(),
|
||||
));
|
||||
}
|
||||
if secret != client_entry.unwrap().secret {
|
||||
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
|
||||
}
|
||||
} else if config.require_secret {
|
||||
return Err(CustomError::Unauthorized("Secret required.".to_string()));
|
||||
}
|
||||
|
||||
if code_entry.exchange_count > 0 {
|
||||
// TODO use Oauth error response
|
||||
return Err(anyhow!("Code was previously exchanged.").into());
|
||||
}
|
||||
conn.set_ex(
|
||||
form.code.to_string(),
|
||||
hex::encode(
|
||||
bincode::serialize(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
||||
),
|
||||
ENTRY_LIFETIME,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
|
||||
let access_token = AccessToken::new(form.code.to_string());
|
||||
let core_id_token = CoreIdTokenClaims::new(
|
||||
IssuerUrl::from_url(config.base_url),
|
||||
vec![Audience::new(client_id.clone())],
|
||||
Utc::now() + Duration::seconds(60),
|
||||
Utc::now(),
|
||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||
EmptyAdditionalClaims {},
|
||||
)
|
||||
.set_nonce(code_entry.nonce);
|
||||
|
||||
let pem = private_key
|
||||
.to_pkcs1_pem()
|
||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||
|
||||
let id_token = CoreIdToken::new(
|
||||
core_id_token,
|
||||
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
|
||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
|
||||
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
|
||||
Some(&access_token),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| anyhow!("{}", e))?;
|
||||
|
||||
Ok(CoreTokenResponse::new(
|
||||
access_token,
|
||||
CoreTokenType::Bearer,
|
||||
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeParams {
|
||||
client_id: String,
|
||||
redirect_uri: RedirectUrl,
|
||||
scope: Scope,
|
||||
response_type: Option<CoreResponseType>,
|
||||
state: Option<String>,
|
||||
nonce: Option<Nonce>,
|
||||
prompt: Option<CoreAuthPrompt>,
|
||||
request_uri: Option<RequestUrl>,
|
||||
request: Option<String>,
|
||||
}
|
||||
|
||||
// TODO handle `registration` parameter
|
||||
async fn authorize(
|
||||
session: UserSessionFromSession,
|
||||
params: Query<AuthorizeParams>,
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||
let conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let client_entry = get_client(conn, params.client_id.clone())
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if client_entry.is_none() {
|
||||
return Err(CustomError::Unauthorized(
|
||||
"Unrecognised client id.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut r_u = params.0.redirect_uri.clone().url().clone();
|
||||
r_u.set_query(None);
|
||||
let mut r_us: Vec<Url> = client_entry
|
||||
.unwrap()
|
||||
.redirect_uris
|
||||
.iter_mut()
|
||||
.map(|u| u.url().clone())
|
||||
.collect();
|
||||
r_us.iter_mut().for_each(|u| u.set_query(None));
|
||||
if !r_us.contains(&r_u) {
|
||||
return Ok((
|
||||
HeaderMap::new(),
|
||||
Redirect::to(
|
||||
"/error?message=unregistered_request_uri"
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let state = if let Some(s) = params.0.state.clone() {
|
||||
s
|
||||
} else if params.0.request_uri.is_some() {
|
||||
let mut url = params.0.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair(
|
||||
"error",
|
||||
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
|
||||
);
|
||||
return Ok((
|
||||
HeaderMap::new(),
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
} else if params.0.request.is_some() {
|
||||
let mut url = params.0.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair(
|
||||
"error",
|
||||
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
|
||||
);
|
||||
return Ok((
|
||||
HeaderMap::new(),
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
} else {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error_description", "Missing state");
|
||||
return Ok((
|
||||
HeaderMap::new(),
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
};
|
||||
|
||||
if let Some(CoreAuthPrompt::None) = params.0.prompt {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair("state", &state);
|
||||
url.query_pairs_mut().append_pair(
|
||||
"error",
|
||||
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
|
||||
);
|
||||
return Ok((
|
||||
HeaderMap::new(),
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
if params.0.response_type.is_none() {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair("state", &state);
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error_description", "Missing response_type");
|
||||
return Ok((
|
||||
HeaderMap::new(),
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
let response_type = params.0.response_type.as_ref().unwrap();
|
||||
|
||||
if params.scope != Scope::new("openid".to_string()) {
|
||||
return Err(anyhow!("Scope not supported").into());
|
||||
}
|
||||
|
||||
let (nonce, headers) = match session {
|
||||
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
||||
UserSessionFromSession::Invalid(cookie) => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::SET_COOKIE, cookie);
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
|
||||
¶ms.0.client_id,
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.scope.to_string(),
|
||||
&response_type.as_ref(),
|
||||
&state,
|
||||
¶ms.0.client_id,
|
||||
¶ms.0.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
UserSessionFromSession::Created { header, nonce } => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::SET_COOKIE, header);
|
||||
(nonce, headers)
|
||||
}
|
||||
};
|
||||
|
||||
let domain = params.redirect_uri.url().host().unwrap();
|
||||
let oidc_nonce_param = if let Some(n) = ¶ms.nonce {
|
||||
format!("&oidc_nonce={}", n.secret())
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
|
||||
nonce,
|
||||
domain,
|
||||
params.redirect_uri.to_string(),
|
||||
state,
|
||||
params.client_id,
|
||||
oidc_nonce_param
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SiweCookie {
|
||||
message: Web3ModalMessage,
|
||||
signature: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct Web3ModalMessage {
|
||||
pub domain: String,
|
||||
pub address: String,
|
||||
pub statement: String,
|
||||
pub uri: String,
|
||||
pub version: String,
|
||||
pub chain_id: String,
|
||||
pub nonce: String,
|
||||
pub issued_at: String,
|
||||
pub expiration_time: Option<String>,
|
||||
pub not_before: Option<String>,
|
||||
pub request_id: Option<String>,
|
||||
pub resources: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Web3ModalMessage {
|
||||
pub fn to_eip4361_message(&self) -> Result<Message> {
|
||||
let mut next_resources: Vec<UriString> = Vec::new();
|
||||
match &self.resources {
|
||||
Some(resources) => {
|
||||
for resource in resources {
|
||||
let x = UriString::from_str(resource)?;
|
||||
next_resources.push(x)
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
Ok(Message {
|
||||
domain: self.domain.clone().try_into()?,
|
||||
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
|
||||
statement: self.statement.to_string(),
|
||||
uri: UriAbsoluteString::from_str(&self.uri)?,
|
||||
version: Version::from_str(&self.version)?,
|
||||
chain_id: self.chain_id.to_string(),
|
||||
nonce: self.nonce.to_string(),
|
||||
issued_at: self.issued_at.to_string(),
|
||||
expiration_time: self.expiration_time.clone(),
|
||||
not_before: self.not_before.clone(),
|
||||
request_id: self.request_id.clone(),
|
||||
resources: next_resources,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct CodeEntry {
|
||||
exchange_count: usize,
|
||||
address: String,
|
||||
nonce: Option<Nonce>,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SignInParams {
|
||||
redirect_uri: RedirectUrl,
|
||||
state: String,
|
||||
oidc_nonce: Option<Nonce>,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
async fn sign_in(
|
||||
session: UserSessionFromSession,
|
||||
params: Query<SignInParams>,
|
||||
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||
let mut headers = HeaderMap::new();
|
||||
let siwe_cookie: SiweCookie = match cookies.get("siwe") {
|
||||
Some(c) => serde_json::from_str(
|
||||
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
|
||||
None => {
|
||||
return Err(anyhow!("No `siwe` cookie").into());
|
||||
}
|
||||
};
|
||||
|
||||
let (nonce, headers) = match session {
|
||||
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
||||
UserSessionFromSession::Invalid(header) => {
|
||||
headers.insert(header::SET_COOKIE, header);
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||
¶ms.0.client_id.clone(),
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.state,
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
UserSessionFromSession::Created { .. } => {
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||
¶ms.0.client_id.clone(),
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.state,
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let signature = match <[u8; 65]>::from_hex(
|
||||
siwe_cookie
|
||||
.signature
|
||||
.chars()
|
||||
.skip(2)
|
||||
.take(130)
|
||||
.collect::<String>(),
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
|
||||
}
|
||||
};
|
||||
|
||||
let message = siwe_cookie
|
||||
.message
|
||||
.to_eip4361_message()
|
||||
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
|
||||
info!("{}", message);
|
||||
message
|
||||
.verify_eip191(signature)
|
||||
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
|
||||
|
||||
let domain = params.redirect_uri.url().host().unwrap();
|
||||
if domain.to_string() != siwe_cookie.message.domain {
|
||||
return Err(anyhow!("Conflicting domains in message and redirect").into());
|
||||
}
|
||||
if nonce != siwe_cookie.message.nonce {
|
||||
return Err(anyhow!("Conflicting nonces in message and session").into());
|
||||
}
|
||||
|
||||
let code_entry = CodeEntry {
|
||||
address: siwe_cookie.message.address,
|
||||
nonce: params.oidc_nonce.clone(),
|
||||
exchange_count: 0,
|
||||
client_id: params.0.client_id.clone(),
|
||||
};
|
||||
|
||||
let code = Uuid::new_v4();
|
||||
let mut conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
conn.set_ex(
|
||||
code.to_string(),
|
||||
hex::encode(
|
||||
bincode::serialize(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
||||
),
|
||||
ENTRY_LIFETIME,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair("code", &code.to_string());
|
||||
url.query_pairs_mut().append_pair("state", ¶ms.state);
|
||||
Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
// TODO clear session
|
||||
}
|
||||
|
||||
async fn register(
|
||||
extract::Json(payload): extract::Json<CoreClientMetadata>,
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
|
||||
let id = Uuid::new_v4();
|
||||
let secret = Uuid::new_v4();
|
||||
|
||||
let conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let entry = ClientEntry {
|
||||
secret: secret.to_string(),
|
||||
redirect_uris: payload.redirect_uris().to_vec(),
|
||||
};
|
||||
set_client(conn, id.to_string(), entry).await?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
CoreClientRegistrationResponse::new(
|
||||
ClientId::new(id.to_string()),
|
||||
payload.redirect_uris().to_vec(),
|
||||
EmptyAdditionalClientMetadata::default(),
|
||||
EmptyAdditionalClientRegistrationResponse::default(),
|
||||
)
|
||||
.set_client_secret(Some(ClientSecret::new(secret.to_string())))
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
|
||||
// TODO CORS
|
||||
// TODO need validation of the token
|
||||
// TODO restrict access token use to only once?
|
||||
async fn userinfo(
|
||||
// access_token: AccessTokenUserInfo, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
||||
TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
|
||||
let code = bearer.token().to_string();
|
||||
let mut conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let serialized_entry: Option<Vec<u8>> = conn
|
||||
.get(code)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if serialized_entry.is_none() {
|
||||
return Err(CustomError::BadRequest("Unknown code.".to_string()));
|
||||
}
|
||||
let code_entry: CodeEntry = bincode::deserialize(
|
||||
&hex::decode(serialized_entry.unwrap())
|
||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
||||
|
||||
Ok(CoreUserInfoClaims::new(
|
||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||
EmptyAdditionalClaims::default(),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
async fn healthcheck() {}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let config = Figment::from(Serialized::defaults(config::Config::default()))
|
||||
.merge(Toml::file("siwe-oidc.toml").nested())
|
||||
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
|
||||
let config = config.extract::<config::Config>().unwrap();
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
|
||||
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
|
||||
let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
|
||||
|
||||
for (id, secret) in &config.default_clients.clone() {
|
||||
let conn = pool2
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))
|
||||
.unwrap();
|
||||
let client_entry = ClientEntry {
|
||||
secret: secret.to_string(),
|
||||
redirect_uris: vec![],
|
||||
};
|
||||
set_client(conn, id.to_string(), client_entry)
|
||||
.await
|
||||
.unwrap(); // TODO
|
||||
}
|
||||
|
||||
let private_key = if let Some(key) = &config.rsa_pem {
|
||||
RsaPrivateKey::from_pkcs1_pem(key)
|
||||
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||
.unwrap()
|
||||
} else {
|
||||
info!("Generating key...");
|
||||
let mut rng = OsRng;
|
||||
let bits = 2048;
|
||||
let private = RsaPrivateKey::new(&mut rng, bits)
|
||||
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
|
||||
.unwrap();
|
||||
|
||||
info!("Generated key.");
|
||||
info!("{:?}", private.to_pkcs1_pem().unwrap());
|
||||
private
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
"/build",
|
||||
service_method_routing::get(ServeDir::new("./static/build")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.nest(
|
||||
"/img",
|
||||
service_method_routing::get(ServeDir::new("./static/img")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/",
|
||||
service_method_routing::get(ServeFile::new("./static/index.html")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/error",
|
||||
service_method_routing::get(ServeFile::new("./static/error.html")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/favicon.png",
|
||||
service_method_routing::get(ServeFile::new("./static/favicon.png")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route("/.well-known/openid-configuration", get(provider_metadata))
|
||||
.route("/jwk", get(jwk_set))
|
||||
.route("/token", post(token))
|
||||
.route("/authorize", get(authorize))
|
||||
.route("/register", post(register))
|
||||
.route("/userinfo", get(userinfo).post(userinfo))
|
||||
.route("/sign_in", get(sign_in))
|
||||
.route("/health", get(healthcheck))
|
||||
.layer(AddExtensionLayer::new(private_key))
|
||||
.layer(AddExtensionLayer::new(config.clone()))
|
||||
.layer(AddExtensionLayer::new(pool))
|
||||
.layer(AddExtensionLayer::new(
|
||||
RedisSessionStore::new(config.redis_url.clone())
|
||||
.unwrap()
|
||||
.with_prefix("async-sessions/"),
|
||||
))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
let addr = SocketAddr::from((config.address, config.port));
|
||||
tracing::info!("Listening on {}", addr);
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
axum_main().await
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
fn main() {}
|
||||
|
523
src/oidc.rs
Normal file
523
src/oidc.rs
Normal file
@ -0,0 +1,523 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{Duration, Utc};
|
||||
use headers::{self, authorization::Bearer};
|
||||
use hex::FromHex;
|
||||
use iri_string::types::UriString;
|
||||
use openidconnect::{
|
||||
core::{
|
||||
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
|
||||
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
|
||||
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
|
||||
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
|
||||
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
|
||||
},
|
||||
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
|
||||
url::Url,
|
||||
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
|
||||
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
|
||||
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
|
||||
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
|
||||
};
|
||||
use rsa::{pkcs1::ToRsaPrivateKey, RsaPrivateKey};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siwe::eip4361::{Message, Version};
|
||||
use std::str::FromStr;
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
use urlencoding::decode;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use super::db::*;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use siwe_oidc::db::*;
|
||||
|
||||
const KID: &str = "key1";
|
||||
pub const METADATA_PATH: &str = "/.well-known/openid-configuration";
|
||||
pub const JWK_PATH: &str = "/jwk";
|
||||
pub const TOKEN_PATH: &str = "/token";
|
||||
pub const AUTHORIZE_PATH: &str = "/authorize";
|
||||
pub const REGISTER_PATH: &str = "/register";
|
||||
pub const USERINFO_PATH: &str = "/userinfo";
|
||||
pub const SIGNIN_PATH: &str = "/sign_in";
|
||||
pub const SIWE_COOKIE_KEY: &str = "siwe";
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
type DBClientType = (dyn DBClient + Sync);
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
type DBClientType = dyn DBClient;
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct TokenError {
|
||||
pub error: CoreErrorResponseType,
|
||||
pub error_description: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CustomError {
|
||||
#[error("{0}")]
|
||||
BadRequest(String),
|
||||
#[error("{0:?}")]
|
||||
BadRequestToken(TokenError),
|
||||
#[error("{0}")]
|
||||
Unauthorized(String),
|
||||
#[error("{0:?}")]
|
||||
Redirect(String),
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
pub fn jwks(private_key: RsaPrivateKey) -> Result<CoreJsonWebKeySet, CustomError> {
|
||||
let pem = private_key
|
||||
.to_pkcs1_pem()
|
||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
|
||||
&pem,
|
||||
Some(JsonWebKeyId::new(KID.to_string())),
|
||||
)
|
||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
|
||||
.as_verification_key()]);
|
||||
Ok(jwks)
|
||||
}
|
||||
|
||||
pub fn metadata(base_url: Url) -> Result<CoreProviderMetadata, CustomError> {
|
||||
let pm = CoreProviderMetadata::new(
|
||||
IssuerUrl::from_url(base_url.clone()),
|
||||
AuthUrl::from_url(
|
||||
base_url
|
||||
.join(AUTHORIZE_PATH)
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
),
|
||||
JsonWebKeySetUrl::from_url(
|
||||
base_url
|
||||
.join(JWK_PATH)
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
),
|
||||
vec![
|
||||
ResponseTypes::new(vec![CoreResponseType::Code]),
|
||||
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
|
||||
],
|
||||
vec![CoreSubjectIdentifierType::Pairwise],
|
||||
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
|
||||
EmptyAdditionalProviderMetadata {},
|
||||
)
|
||||
.set_token_endpoint(Some(TokenUrl::from_url(
|
||||
base_url
|
||||
.join(TOKEN_PATH)
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
|
||||
base_url
|
||||
.join(USERINFO_PATH)
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_scopes_supported(Some(vec![
|
||||
Scope::new("openid".to_string()),
|
||||
// Scope::new("email".to_string()),
|
||||
// Scope::new("profile".to_string()),
|
||||
]))
|
||||
.set_claims_supported(Some(vec![
|
||||
CoreClaimName::new("sub".to_string()),
|
||||
CoreClaimName::new("aud".to_string()),
|
||||
// CoreClaimName::new("email".to_string()),
|
||||
// CoreClaimName::new("email_verified".to_string()),
|
||||
CoreClaimName::new("exp".to_string()),
|
||||
CoreClaimName::new("iat".to_string()),
|
||||
CoreClaimName::new("iss".to_string()),
|
||||
// CoreClaimName::new("name".to_string()),
|
||||
// CoreClaimName::new("given_name".to_string()),
|
||||
// CoreClaimName::new("family_name".to_string()),
|
||||
// CoreClaimName::new("picture".to_string()),
|
||||
// CoreClaimName::new("locale".to_string()),
|
||||
]))
|
||||
.set_registration_endpoint(Some(RegistrationUrl::from_url(
|
||||
base_url
|
||||
.join(REGISTER_PATH)
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_token_endpoint_auth_methods_supported(Some(vec![
|
||||
CoreClientAuthMethod::ClientSecretBasic,
|
||||
CoreClientAuthMethod::ClientSecretPost,
|
||||
]));
|
||||
|
||||
Ok(pm)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TokenForm {
|
||||
pub code: String,
|
||||
pub client_id: Option<String>,
|
||||
pub client_secret: Option<String>,
|
||||
pub grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
|
||||
}
|
||||
|
||||
pub async fn token(
|
||||
form: TokenForm,
|
||||
// From the request's Authorization header
|
||||
secret: Option<String>,
|
||||
private_key: RsaPrivateKey,
|
||||
base_url: Url,
|
||||
require_secret: bool,
|
||||
db_client: &DBClientType,
|
||||
) -> Result<CoreTokenResponse, CustomError> {
|
||||
let code_entry = if let Some(c) = db_client.get_code(form.code.to_string()).await? {
|
||||
c
|
||||
} else {
|
||||
return Err(CustomError::BadRequestToken(TokenError {
|
||||
error: CoreErrorResponseType::InvalidGrant,
|
||||
error_description: "Unknown code.".to_string(),
|
||||
}));
|
||||
};
|
||||
|
||||
let client_id = if let Some(c) = form.client_id.clone() {
|
||||
c
|
||||
} else {
|
||||
code_entry.client_id.clone()
|
||||
};
|
||||
|
||||
if let Some(secret) = if let Some(b) = secret {
|
||||
Some(b)
|
||||
} else {
|
||||
form.client_secret.clone()
|
||||
} {
|
||||
let client_entry = db_client.get_client(client_id.clone()).await?;
|
||||
if client_entry.is_none() {
|
||||
return Err(CustomError::Unauthorized(
|
||||
"Unrecognised client id.".to_string(),
|
||||
));
|
||||
}
|
||||
if secret != client_entry.unwrap().secret {
|
||||
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
|
||||
}
|
||||
} else if require_secret {
|
||||
return Err(CustomError::Unauthorized("Secret required.".to_string()));
|
||||
}
|
||||
|
||||
if code_entry.exchange_count > 0 {
|
||||
// TODO use Oauth error response
|
||||
return Err(CustomError::BadRequestToken(TokenError {
|
||||
error: CoreErrorResponseType::InvalidGrant,
|
||||
error_description: "Code was previously exchanged.".to_string(),
|
||||
}));
|
||||
}
|
||||
let mut code_entry2 = code_entry.clone();
|
||||
code_entry2.exchange_count += 1;
|
||||
db_client
|
||||
.set_code(form.code.to_string(), code_entry2)
|
||||
.await?;
|
||||
let access_token = AccessToken::new(form.code);
|
||||
let core_id_token = CoreIdTokenClaims::new(
|
||||
IssuerUrl::from_url(base_url),
|
||||
vec![Audience::new(client_id.clone())],
|
||||
Utc::now() + Duration::seconds(60),
|
||||
Utc::now(),
|
||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||
EmptyAdditionalClaims {},
|
||||
)
|
||||
.set_nonce(code_entry.nonce);
|
||||
|
||||
let pem = private_key
|
||||
.to_pkcs1_pem()
|
||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||
|
||||
let id_token = CoreIdToken::new(
|
||||
core_id_token,
|
||||
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
|
||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
|
||||
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
|
||||
Some(&access_token),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| anyhow!("{}", e))?;
|
||||
|
||||
Ok(CoreTokenResponse::new(
|
||||
access_token,
|
||||
CoreTokenType::Bearer,
|
||||
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AuthorizeParams {
|
||||
pub client_id: String,
|
||||
pub redirect_uri: RedirectUrl,
|
||||
pub scope: Scope,
|
||||
pub response_type: Option<CoreResponseType>,
|
||||
pub state: Option<String>,
|
||||
pub nonce: Option<Nonce>,
|
||||
pub prompt: Option<CoreAuthPrompt>,
|
||||
pub request_uri: Option<RequestUrl>,
|
||||
pub request: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn authorize(
|
||||
params: AuthorizeParams,
|
||||
nonce: String,
|
||||
db_client: &DBClientType,
|
||||
) -> Result<String, CustomError> {
|
||||
let client_entry = db_client
|
||||
.get_client(params.client_id.clone())
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if client_entry.is_none() {
|
||||
return Err(CustomError::Unauthorized(
|
||||
"Unrecognised client id.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut r_u = params.redirect_uri.clone().url().clone();
|
||||
r_u.set_query(None);
|
||||
let mut r_us: Vec<Url> = client_entry
|
||||
.unwrap()
|
||||
.redirect_uris
|
||||
.iter_mut()
|
||||
.map(|u| u.url().clone())
|
||||
.collect();
|
||||
r_us.iter_mut().for_each(|u| u.set_query(None));
|
||||
if !r_us.contains(&r_u) {
|
||||
return Err(CustomError::Redirect(
|
||||
"/error?message=unregistered_request_uri".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let state = if let Some(s) = params.state.clone() {
|
||||
s
|
||||
} else if params.request_uri.is_some() {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair(
|
||||
"error",
|
||||
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
|
||||
);
|
||||
return Err(CustomError::Redirect(url.to_string()));
|
||||
} else if params.request.is_some() {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair(
|
||||
"error",
|
||||
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
|
||||
);
|
||||
return Err(CustomError::Redirect(url.to_string()));
|
||||
} else {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error_description", "Missing state");
|
||||
return Err(CustomError::Redirect(url.to_string()));
|
||||
};
|
||||
|
||||
if let Some(CoreAuthPrompt::None) = params.prompt {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair("state", &state);
|
||||
url.query_pairs_mut().append_pair(
|
||||
"error",
|
||||
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
|
||||
);
|
||||
return Err(CustomError::Redirect(url.to_string()));
|
||||
}
|
||||
|
||||
if params.response_type.is_none() {
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair("state", &state);
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
||||
url.query_pairs_mut()
|
||||
.append_pair("error_description", "Missing response_type");
|
||||
return Err(CustomError::Redirect(url.to_string()));
|
||||
}
|
||||
let _response_type = params.response_type.as_ref().unwrap();
|
||||
|
||||
if params.scope != Scope::new("openid".to_string()) {
|
||||
return Err(anyhow!("Scope not supported").into());
|
||||
}
|
||||
|
||||
let domain = params.redirect_uri.url().host().unwrap();
|
||||
let oidc_nonce_param = if let Some(n) = ¶ms.nonce {
|
||||
format!("&oidc_nonce={}", n.secret())
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
Ok(format!(
|
||||
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
|
||||
nonce,
|
||||
domain,
|
||||
params.redirect_uri.to_string(),
|
||||
state,
|
||||
params.client_id,
|
||||
oidc_nonce_param
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SiweCookie {
|
||||
message: Web3ModalMessage,
|
||||
signature: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct Web3ModalMessage {
|
||||
pub domain: String,
|
||||
pub address: String,
|
||||
pub statement: String,
|
||||
pub uri: String,
|
||||
pub version: String,
|
||||
pub chain_id: String,
|
||||
pub nonce: String,
|
||||
pub issued_at: String,
|
||||
pub expiration_time: Option<String>,
|
||||
pub not_before: Option<String>,
|
||||
pub request_id: Option<String>,
|
||||
pub resources: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Web3ModalMessage {
|
||||
fn to_eip4361_message(&self) -> Result<Message> {
|
||||
let mut next_resources: Vec<UriString> = Vec::new();
|
||||
match &self.resources {
|
||||
Some(resources) => {
|
||||
for resource in resources {
|
||||
let x = UriString::from_str(resource)?;
|
||||
next_resources.push(x)
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
Ok(Message {
|
||||
domain: self.domain.clone().try_into()?,
|
||||
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
|
||||
statement: self.statement.to_string(),
|
||||
uri: UriString::from_str(&self.uri)?,
|
||||
version: Version::from_str(&self.version)?,
|
||||
chain_id: self.chain_id.to_string(),
|
||||
nonce: self.nonce.to_string(),
|
||||
issued_at: self.issued_at.to_string(),
|
||||
expiration_time: self.expiration_time.clone(),
|
||||
not_before: self.not_before.clone(),
|
||||
request_id: self.request_id.clone(),
|
||||
resources: next_resources,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SignInParams {
|
||||
pub redirect_uri: RedirectUrl,
|
||||
pub state: String,
|
||||
pub oidc_nonce: Option<Nonce>,
|
||||
pub client_id: String,
|
||||
}
|
||||
|
||||
pub async fn sign_in(
|
||||
params: SignInParams,
|
||||
expected_nonce: Option<String>,
|
||||
cookies: headers::Cookie,
|
||||
db_client: &DBClientType,
|
||||
) -> Result<Url, CustomError> {
|
||||
let siwe_cookie: SiweCookie = match cookies.get(SIWE_COOKIE_KEY) {
|
||||
Some(c) => serde_json::from_str(
|
||||
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
|
||||
None => {
|
||||
return Err(anyhow!("No `siwe` cookie").into());
|
||||
}
|
||||
};
|
||||
|
||||
let signature = match <[u8; 65]>::from_hex(
|
||||
siwe_cookie
|
||||
.signature
|
||||
.chars()
|
||||
.skip(2)
|
||||
.take(130)
|
||||
.collect::<String>(),
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
|
||||
}
|
||||
};
|
||||
|
||||
let message = siwe_cookie
|
||||
.message
|
||||
.to_eip4361_message()
|
||||
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
|
||||
info!("{}", message);
|
||||
message
|
||||
.verify(signature)
|
||||
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
|
||||
|
||||
let domain = params.redirect_uri.url().host().unwrap();
|
||||
if domain.to_string() != siwe_cookie.message.domain {
|
||||
return Err(anyhow!("Conflicting domains in message and redirect").into());
|
||||
}
|
||||
if expected_nonce.is_some() && expected_nonce.unwrap() != siwe_cookie.message.nonce {
|
||||
return Err(anyhow!("Conflicting nonces in message and session").into());
|
||||
}
|
||||
|
||||
let code_entry = CodeEntry {
|
||||
address: siwe_cookie.message.address,
|
||||
nonce: params.oidc_nonce.clone(),
|
||||
exchange_count: 0,
|
||||
client_id: params.client_id.clone(),
|
||||
};
|
||||
|
||||
let code = Uuid::new_v4();
|
||||
db_client.set_code(code.to_string(), code_entry).await?;
|
||||
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair("code", &code.to_string());
|
||||
url.query_pairs_mut().append_pair("state", ¶ms.state);
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
pub async fn register(
|
||||
payload: CoreClientMetadata,
|
||||
db_client: &DBClientType,
|
||||
) -> Result<CoreClientRegistrationResponse, CustomError> {
|
||||
let id = Uuid::new_v4();
|
||||
let secret = Uuid::new_v4();
|
||||
|
||||
let entry = ClientEntry {
|
||||
secret: secret.to_string(),
|
||||
redirect_uris: payload.redirect_uris().to_vec(),
|
||||
};
|
||||
db_client.set_client(id.to_string(), entry).await?;
|
||||
|
||||
Ok(CoreClientRegistrationResponse::new(
|
||||
ClientId::new(id.to_string()),
|
||||
payload.redirect_uris().to_vec(),
|
||||
EmptyAdditionalClientMetadata::default(),
|
||||
EmptyAdditionalClientRegistrationResponse::default(),
|
||||
)
|
||||
.set_client_secret(Some(ClientSecret::new(secret.to_string()))))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UserInfoPayload {
|
||||
pub access_token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn userinfo(
|
||||
bearer: Option<Bearer>,
|
||||
payload: UserInfoPayload,
|
||||
db_client: &DBClientType,
|
||||
) -> Result<CoreUserInfoClaims, CustomError> {
|
||||
let code = if let Some(b) = bearer {
|
||||
b.token().to_string()
|
||||
} else if let Some(c) = payload.access_token {
|
||||
c
|
||||
} else {
|
||||
return Err(CustomError::BadRequest("Missing access token.".to_string()));
|
||||
};
|
||||
let code_entry = if let Some(c) = db_client.get_code(code).await? {
|
||||
c
|
||||
} else {
|
||||
return Err(CustomError::BadRequest("Unknown code.".to_string()));
|
||||
};
|
||||
|
||||
Ok(CoreUserInfoClaims::new(
|
||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||
EmptyAdditionalClaims::default(),
|
||||
))
|
||||
}
|
210
src/worker_lib.rs
Normal file
210
src/worker_lib.rs
Normal file
@ -0,0 +1,210 @@
|
||||
use anyhow::anyhow;
|
||||
use headers::{
|
||||
self,
|
||||
authorization::{Basic, Bearer, Credentials},
|
||||
Authorization, Header, HeaderValue,
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey};
|
||||
use worker::*;
|
||||
|
||||
use super::db::CFClient;
|
||||
use super::oidc::{self, CustomError, TokenForm, UserInfoPayload};
|
||||
|
||||
const BASE_URL_KEY: &str = "BASE_URL";
|
||||
const RSA_PEM_KEY: &str = "RSA_PEM";
|
||||
|
||||
// https://github.com/cloudflare/workers-rs/issues/64
|
||||
// #[global_allocator]
|
||||
// static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
|
||||
|
||||
impl From<CustomError> for Result<Response> {
|
||||
fn from(error: CustomError) -> Self {
|
||||
match error {
|
||||
CustomError::BadRequest(_) => Response::error(&error.to_string(), 400),
|
||||
CustomError::BadRequestToken(e) => Response::from_json(&e).map(|r| r.with_status(400)),
|
||||
CustomError::Unauthorized(_) => Response::error(&error.to_string(), 401),
|
||||
CustomError::Redirect(uri) => Response::redirect(uri.parse().unwrap()),
|
||||
CustomError::Other(_) => Response::error(&error.to_string(), 500),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn main(req: Request, env: Env) -> Result<Response> {
|
||||
console_error_panic_hook::set_once();
|
||||
// tracing_subscriber::fmt::init();
|
||||
// console_log::init_with_level(log::Level::Info).expect("error initializing log");
|
||||
|
||||
let userinfo = |mut req: Request, ctx: RouteContext<()>| async move {
|
||||
let bearer = req
|
||||
.headers()
|
||||
.get(Authorization::<Bearer>::name().as_str())?
|
||||
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
|
||||
.as_ref()
|
||||
.and_then(Bearer::decode);
|
||||
let payload = if bearer.is_none() {
|
||||
match req.form_data().await {
|
||||
Ok(f) => {
|
||||
let access_token = if let Some(FormEntry::Field(a)) = f.get("access_token") {
|
||||
Some(a)
|
||||
} else {
|
||||
return Response::error("Missing code", 400);
|
||||
};
|
||||
UserInfoPayload { access_token }
|
||||
}
|
||||
Err(_) => return Response::error("Bad request", 400),
|
||||
}
|
||||
} else {
|
||||
UserInfoPayload { access_token: None }
|
||||
};
|
||||
let url = req.url()?;
|
||||
let db_client = CFClient { ctx, url };
|
||||
match oidc::userinfo(bearer, payload, &db_client).await {
|
||||
Ok(r) => Ok(Response::from_json(&r)?),
|
||||
Err(e) => e.into(),
|
||||
}
|
||||
};
|
||||
|
||||
let router = Router::new();
|
||||
router
|
||||
.get_async(oidc::METADATA_PATH, |_req, ctx| async move {
|
||||
match oidc::metadata(ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap()) {
|
||||
Ok(m) => Response::from_json(&m),
|
||||
Err(e) => e.into(),
|
||||
}
|
||||
})
|
||||
.get_async(oidc::JWK_PATH, |_req, ctx| async move {
|
||||
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
|
||||
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||
.unwrap();
|
||||
match oidc::jwks(private_key) {
|
||||
Ok(m) => Response::from_json(&m),
|
||||
Err(e) => e.into(),
|
||||
}
|
||||
})
|
||||
.post_async(oidc::TOKEN_PATH, |mut req, ctx| async move {
|
||||
let form_data = req.form_data().await?;
|
||||
let code = if let Some(FormEntry::Field(c)) = form_data.get("code") {
|
||||
c
|
||||
} else {
|
||||
return Response::error("Missing code", 400);
|
||||
};
|
||||
let client_id = match form_data.get("client_id") {
|
||||
Some(FormEntry::Field(c)) => Some(c),
|
||||
None => None,
|
||||
_ => return Response::error("Client ID not a field", 400),
|
||||
};
|
||||
let client_secret = match form_data.get("client_secret") {
|
||||
Some(FormEntry::Field(c)) => Some(c),
|
||||
None => None,
|
||||
_ => return Response::error("Client secret not a field", 400),
|
||||
};
|
||||
let grant_type = if let Some(FormEntry::Field(c)) = form_data.get("code") {
|
||||
if let Ok(cc) = serde_json::from_str(&format!("\"{}\"", c)) {
|
||||
cc
|
||||
} else {
|
||||
return Response::error("Invalid grant type", 400);
|
||||
}
|
||||
} else {
|
||||
return Response::error("Missing grant type", 400);
|
||||
};
|
||||
let secret = req
|
||||
.headers()
|
||||
.get(Authorization::<Bearer>::name().as_str())?
|
||||
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
|
||||
.as_ref()
|
||||
.and_then(|b| {
|
||||
if b.to_str().unwrap().starts_with("Bearer") {
|
||||
Bearer::decode(b).map(|bb| bb.token().to_string())
|
||||
} else {
|
||||
Basic::decode(b).map(|bb| bb.password().to_string())
|
||||
}
|
||||
});
|
||||
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
|
||||
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||
.unwrap();
|
||||
let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
|
||||
let url = req.url()?;
|
||||
let db_client = CFClient { ctx, url };
|
||||
let token_response = oidc::token(
|
||||
TokenForm {
|
||||
code,
|
||||
client_id,
|
||||
client_secret,
|
||||
grant_type,
|
||||
},
|
||||
secret,
|
||||
private_key,
|
||||
base_url,
|
||||
false,
|
||||
&db_client,
|
||||
)
|
||||
.await;
|
||||
match token_response {
|
||||
Ok(m) => Response::from_json(&m),
|
||||
Err(e) => e.into(),
|
||||
}
|
||||
})
|
||||
// TODO add browser session
|
||||
.get_async(oidc::AUTHORIZE_PATH, |req, ctx| async move {
|
||||
let base_url: Url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
|
||||
let url = req.url()?;
|
||||
let query = url.query().unwrap_or_default();
|
||||
let params = match serde_urlencoded::from_str(query) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
|
||||
};
|
||||
let nonce = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(16)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
let url = req.url()?;
|
||||
let db_client = CFClient { ctx, url };
|
||||
match oidc::authorize(params, nonce, &db_client).await {
|
||||
Ok(url) => Response::redirect(base_url.join(&url).unwrap()),
|
||||
Err(e) => match e {
|
||||
CustomError::Redirect(url) => {
|
||||
CustomError::Redirect(base_url.join(&url).unwrap().to_string())
|
||||
}
|
||||
c => c,
|
||||
}
|
||||
.into(),
|
||||
}
|
||||
})
|
||||
.post_async(oidc::REGISTER_PATH, |mut req, ctx| async move {
|
||||
let payload = req.json().await?;
|
||||
let url = req.url()?;
|
||||
let db_client = CFClient { ctx, url };
|
||||
match oidc::register(payload, &db_client).await {
|
||||
Ok(r) => Ok(Response::from_json(&r)?.with_status(201)),
|
||||
Err(e) => e.into(),
|
||||
}
|
||||
})
|
||||
.post_async(oidc::USERINFO_PATH, userinfo)
|
||||
.get_async(oidc::USERINFO_PATH, userinfo)
|
||||
.get_async(oidc::SIGNIN_PATH, |req, ctx| async move {
|
||||
let url = req.url()?;
|
||||
let query = url.query().unwrap_or_default();
|
||||
let params = match serde_urlencoded::from_str(query) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
|
||||
};
|
||||
let cookies = req
|
||||
.headers()
|
||||
.get(headers::Cookie::name().as_str())?
|
||||
.and_then(|c| HeaderValue::from_str(&c).ok())
|
||||
.and_then(|c| headers::Cookie::decode(&mut [c].iter()).ok());
|
||||
if cookies.is_none() {
|
||||
return Response::error("Missing cookies", 400);
|
||||
}
|
||||
let url = req.url()?;
|
||||
let db_client = CFClient { ctx, url };
|
||||
match oidc::sign_in(params, None, cookies.unwrap(), &db_client).await {
|
||||
Ok(url) => Response::redirect(url),
|
||||
Err(e) => e.into(),
|
||||
}
|
||||
})
|
||||
.run(req, env)
|
||||
.await
|
||||
}
|
35
wrangle_example.toml
Normal file
35
wrangle_example.toml
Normal file
@ -0,0 +1,35 @@
|
||||
name = "siwe_oidc"
|
||||
type = "javascript"
|
||||
account_id = ""
|
||||
# zone_id = ""
|
||||
workers_dev = false
|
||||
compatibility_date = "2021-12-20"
|
||||
|
||||
kv_namespaces = [
|
||||
{ binding = "SIWE-OIDC", id = "", preview_id = "" }
|
||||
]
|
||||
|
||||
[vars]
|
||||
WORKERS_RS_VERSION = "0.0.7"
|
||||
BASE_URL = "https://siweoidc.spruceid.xyz"
|
||||
|
||||
[durable_objects]
|
||||
bindings = [
|
||||
{ name = "SIWE-OIDC-CODES", class_name = "DOCodes" }
|
||||
]
|
||||
|
||||
[[migrations]]
|
||||
tag = "v1"
|
||||
new_classes = ["DOCodes"]
|
||||
|
||||
[build]
|
||||
command = "cargo install -q worker-build && worker-build --release"
|
||||
|
||||
[build.upload]
|
||||
dir = "build/worker"
|
||||
format = "modules"
|
||||
main = "./shim.mjs"
|
||||
|
||||
[[build.upload.rules]]
|
||||
globs = ["**/*.wasm"]
|
||||
type = "CompiledWasm"
|
Loading…
Reference in New Issue
Block a user