Cloudflare Worker version (#6)

Refactor/generalise API/DB interactions out of OIDC.
This commit is contained in:
Simon Bihel 2022-01-11 10:43:06 +00:00 committed by GitHub
parent 9d725552e0
commit bbcacf4232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 3236 additions and 2854 deletions

View File

@ -8,12 +8,25 @@ env:
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
include:
- cargo_target: "x86_64-unknown-linux-gnu"
- cargo_target: "wasm32-unknown-unknown"
steps:
- name: Clone repo
uses: actions/checkout@master
- name: Add targets
run: rustup target add wasm32-unknown-unknown
- name: Build
env:
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
run: cargo build --verbose
- name: Clippy
env:
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
run: RUSTFLAGS="-Dwarnings" cargo clippy
- name: Fmt
env:
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
run: cargo fmt -- --check

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
/target
/static/build
wrangler.toml

618
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -5,35 +5,68 @@ edition = "2021"
authors = ["Spruce Systems, Inc."]
license = "MIT OR Apache-2.0"
repository = "https://github.com/spruceid/siwe-oidc/"
description = "OpenID Connect Identity Provider for Sign-In with Ethereum."
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
crate-type = ["cdylib", "rlib"]
[dependencies]
anyhow = "1.0.51"
axum = { version = "0.3.4", features = ["headers"] }
chrono = "0.4.19"
headers = "0.3.5"
hex = "0.4.3"
iri-string = { version = "0.4", features = ["serde-std"] }
openidconnect = "2.1.2"
# openidconnect = "2.1.2"
openidconnect = { git = "https://github.com/sbihel/openidconnect-rs", branch = "main", default-features = false, features = ["reqwest", "rustls-tls", "rustcrypto"] }
rand = "0.8.4"
rsa = { version = "0.5.0", features = ["alloc"] }
rust-argon2 = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.72"
siwe = "0.1"
async-session = "3.0.0"
siwe = "0.1.2"
thiserror = "1.0.30"
tokio = { version = "1.14.0", features = ["full"] }
tower-http = { version = "0.2.0", features = ["fs", "trace", "cors"] }
tracing = "0.1.29"
tracing-subscriber = { version = "0.3.2", features = ["env-filter"] }
url = { version = "2.2", features = ["serde"] }
urlencoding = "2.1.0"
uuid = { version = "0.8", features = ["serde", "v4"] }
figment = { version = "0.10.6", features = ["toml", "env"] }
sha2 = "0.9.0"
cookie = "0.15.1"
bincode = "1.3.3"
async-trait = "0.1.52"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
async-session = "3.0.0"
axum = { version = "0.4.3", features = ["headers"] }
# axum-debug = "0.3.2"
chrono = "0.4.19"
figment = { version = "0.10.6", features = ["toml", "env"] }
tokio = { version = "1.14.0", features = ["full"] }
tower-http = { version = "0.2.0", features = ["fs", "trace", "cors"] }
tracing-subscriber = { version = "0.3.2", features = ["env-filter"] }
bb8-redis = "0.10.1"
async-redis-session = "0.2.2"
uuid = { version = "0.8", features = ["serde", "v4"] }
[target.'cfg(target_arch = "wasm32")'.dependencies]
# cached = { version = "0.26", default-features = false }
chrono = { version = "0.4.19", features = ["wasmbind"] }
console_error_panic_hook = { version = "0.1" }
# console_log = "0.2"
getrandom = { version = "0.2", features = ["js"] }
# log = "0.4"
matchit = "0.4.2"
serde_urlencoded = "0.7.0"
uuid = { version = "0.8", features = ["serde", "v4", "wasm-bindgen"] }
wee_alloc = { version = "0.4" }
worker = "0.0.7"
[profile.release]
opt-level = "z"
lto = true
# [target.'cfg(target_arch = "wasm32")'.profile.release]
# opt-level = "z"
# [target.'cfg(target_arch = "wasm32")'.profile.debug]
# opt-level = "z"
# lto = false
[package.metadata.wasm-pack.profile.profiling]
wasm-opt = ['-g', '-O']

View File

@ -2,11 +2,54 @@
## Getting Started
### Dependencies
Two versions are available, a stand-alone binary (using Axum and Redis) and a
Cloudflare Worker. They use the same code base and are selected at compile time
(compiling for `wasm32` will make the Worker version).
### Cloudflare Worker
You will need [`wrangler`](https://github.com/cloudflare/wrangler).
Then copy the configuration file template:
```bash
cp wrangler_example.toml wrangler.toml
```
Replacing the following fields:
- `account_id`: your Cloudflare account ID;
- `zone_id`: (Optional) DNS zone ID; and
- `kv_namespaces`: a KV namespace ID (created with `wrangler kv:namespace create SIWE-OIDC`).
At this point, you should be able to create/publish the worker:
```
wrangler publish
```
The IdP currently only supports having the **frontend under the same subdomain as
the API**. Here is the configuration for Cloudflare Pages:
- `Build command`: `cd js/ui && npm install && npm run build`;
- `Build output directory`: `/static`; and
- `Root directory`: `/`.
And you will need to add some rules to do the routing between the Page and the
Worker. Here are the rules for the Worker (the Page being used as the fallback
on the subdomain):
```
siweoidc.example.com/s*
siweoidc.example.com/u*
siweoidc.example.com/r*
siweoidc.example.com/a*
siweoidc.example.com/t*
siweoidc.example.com/j*
siweoidc.example.com/.w*
```
### Stand-Alone Binary
#### Dependencies
Redis, or a Redis compatible database (e.g. MemoryDB in AWS), is required.
### Starting the IdP
#### Starting the IdP
The Docker image is available at `ghcr.io/spruceid/siwe_oidc:0.1.0`. Here is an
example usage:
@ -35,9 +78,23 @@ For the core OIDC information, it is available under
* Additional information, from native projects (e.g. ENS domains), to more
traditional ones (e.g. email).
* PKCE support (code challenge).
* Browser session support for the Worker version.
## Development
### Cloudflare Worker
```bash
wrangler dev
```
You can now use http://127.0.0.1:8787/.well-known/openid-configuration.
> At the moment it's not possible to use it end-to-end with the frontend as they
> need to share the same host (i.e. port), unless using a local load-balancer.
### Stand Alone Binary
A Docker Compose is available to test the IdP locally with Keycloak.
1. You will first need to run:

2895
js/ui/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -2,28 +2,28 @@
"name": "svelte-app",
"version": "1.0.0",
"devDependencies": {
"@tsconfig/svelte": "^1.0.10",
"@types/node": "^14.11.1",
"@typescript-eslint/eslint-plugin": "^4.21.0",
"@typescript-eslint/parser": "^4.21.0",
"@tsconfig/svelte": "^3.0.0",
"@types/node": "^17.0.7",
"@typescript-eslint/eslint-plugin": "^5.9.0",
"@typescript-eslint/parser": "^5.9.0",
"assert": "^2.0.0",
"autoprefixer": "^10.2.5",
"base64-loader": "^1.0.0",
"buffer": "^6.0.3",
"cross-env": "^7.0.3",
"crypto-browserify": "^3.12.0",
"css-loader": "^5.0.1",
"css-loader": "^6.5.1",
"cssnano": "^5.0.8",
"dotenv-webpack": "^7.0.3",
"eslint": "^7.23.0",
"eslint": "^8.6.0",
"eslint-config-prettier": "^8.1.0",
"eslint-plugin-svelte3": "^3.1.2",
"https-browserify": "^1.0.0",
"mini-css-extract-plugin": "^1.3.4",
"mini-css-extract-plugin": "^2.4.5",
"os-browserify": "^0.3.0",
"postcss": "^8.2.8",
"postcss-load-config": "^3.0.1",
"postcss-loader": "^5.2.0",
"postcss-loader": "^6.2.1",
"precss": "^4.0.0",
"prettier": "^2.2.1",
"prettier-plugin-svelte": "^2.2.0",
@ -31,12 +31,12 @@
"stream-browserify": "^3.0.0",
"stream-http": "^3.2.0",
"svelte": "^3.31.2",
"svelte-check": "^1.0.46",
"svelte-check": "^2.2.11",
"svelte-loader": "^3.0.0",
"svelte-preprocess": "^4.3.0",
"svg-url-loader": "^7.1.1",
"tailwindcss": "^2.0.4",
"ts-loader": "^8.0.4",
"tailwindcss": "^3.0.9",
"ts-loader": "^9.2.6",
"tslib": "^2.0.1",
"typescript": "^4.0.3",
"webpack": "^5.16.0",
@ -54,6 +54,7 @@
"@toruslabs/torus-embed": "^1.18.3",
"@walletconnect/web3-provider": "^1.6.6",
"fortmatic": "^2.2.1",
"url": "^0.11.0",
"walletlink": "^2.2.8"
}
}

View File

@ -1,5 +1,8 @@
{
"extends": "@tsconfig/svelte/tsconfig.json",
"include": ["src/**/*", "src/node_modules/**/*"],
"exclude": ["node_modules/*", "__sapper__/*", "static/*"]
"extends": "@tsconfig/svelte/tsconfig.json",
"include": ["src/**/*", "src/node_modules/**/*"],
"exclude": ["node_modules/*", "__sapper__/*", "static/*"],
"compilerOptions": {
"types": ["node", "svelte"]
}
}

View File

@ -27,6 +27,7 @@ module.exports = {
path: false,
process: require.resolve('process/browser'),
stream: require.resolve('stream-browserify'),
url: require.resolve("url")
// util: false,
}
},

364
src/axum_lib.rs Normal file
View File

@ -0,0 +1,364 @@
use anyhow::{anyhow, Result};
use async_redis_session::RedisSessionStore;
use axum::{
extract::{self, Extension, Form, Query, TypedHeader},
http::{
header::{self, HeaderMap},
StatusCode,
},
response::{self, IntoResponse, Redirect},
routing::{get, get_service, post},
AddExtensionLayer, Json, Router,
};
use bb8_redis::{bb8, RedisConnectionManager};
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use headers::{
self,
authorization::{Basic, Bearer},
Authorization,
};
use openidconnect::core::{
CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata,
CoreResponseType, CoreTokenResponse, CoreUserInfoClaims,
};
use rand::rngs::OsRng;
use rsa::{
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
RsaPrivateKey,
};
use std::net::SocketAddr;
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tracing::info;
use super::config;
use super::oidc::{self, CustomError};
use super::session::*;
use ::siwe_oidc::db::*;
impl IntoResponse for CustomError {
fn into_response(self) -> response::Response {
match self {
CustomError::BadRequest(_) => {
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
}
CustomError::BadRequestToken(e) => {
(StatusCode::BAD_REQUEST, Json::from(e)).into_response()
}
CustomError::Unauthorized(_) => {
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
}
CustomError::Redirect(uri) => Redirect::to(
uri.parse().unwrap(),
// .map_err(|e| anyhow!("Could not parse URI: {}", e))?,
)
.into_response(),
CustomError::Other(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
}
}
}
async fn jwk_set(
Extension(private_key): Extension<RsaPrivateKey>,
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
let jwks = oidc::jwks(private_key)?;
Ok(jwks.into())
}
async fn provider_metadata(
Extension(config): Extension<config::Config>,
) -> Result<Json<CoreProviderMetadata>, CustomError> {
Ok(oidc::metadata(config.base_url)?.into())
}
// TODO should check Authorization header
// Actually, client secret can be
// 1. in the POST (currently supported) [x]
// 2. Authorization header [x]
// 3. JWT [ ]
// 4. signed JWT [ ]
// according to Keycloak
async fn token(
Form(form): Form<oidc::TokenForm>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
basic: Option<TypedHeader<Authorization<Basic>>>,
Extension(private_key): Extension<RsaPrivateKey>,
Extension(config): Extension<config::Config>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<Json<CoreTokenResponse>, CustomError> {
let secret = if let Some(b) = bearer {
Some(b.0 .0.token().to_string())
} else {
basic.map(|b| b.0 .0.password().to_string())
};
let token_response = oidc::token(
form,
secret,
private_key,
config.base_url,
config.require_secret,
&redis_client,
)
.await?;
Ok(token_response.into())
}
// TODO handle `registration` parameter
async fn authorize(
session: UserSessionFromSession,
Query(params): Query<oidc::AuthorizeParams>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(cookie) => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, cookie);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
&params.client_id,
&params.redirect_uri.to_string(),
&params.scope.to_string(),
&params.response_type.unwrap_or(CoreResponseType::Code).as_ref(),
&params.state.unwrap_or_default(),
&params.client_id,
&params.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { header, nonce } => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, header);
(nonce, headers)
}
};
let url = oidc::authorize(params, nonce, &redis_client).await?;
Ok((
headers,
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
async fn sign_in(
session: UserSessionFromSession,
Query(params): Query<oidc::SignInParams>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(header) => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, header);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.client_id.clone(),
&params.redirect_uri.to_string(),
&params.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { .. } => {
return Ok((
HeaderMap::new(),
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.client_id.clone(),
&params.redirect_uri.to_string(),
&params.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
};
let url = oidc::sign_in(params, Some(nonce), cookies, &redis_client).await?;
Ok((
headers,
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
// TODO clear session
}
async fn register(
extract::Json(payload): extract::Json<CoreClientMetadata>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
let registration = oidc::register(payload, &redis_client).await?;
Ok((StatusCode::CREATED, registration.into()))
}
// TODO CORS
// TODO need validation of the token
async fn userinfo(
payload: Option<Form<oidc::UserInfoPayload>>,
bearer: Option<TypedHeader<Authorization<Bearer>>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
Extension(redis_client): Extension<RedisClient>,
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
let payload = if let Some(Form(p)) = payload {
p
} else {
oidc::UserInfoPayload { access_token: None }
};
let claims = oidc::userinfo(bearer.map(|b| b.0 .0), payload, &redis_client).await?;
Ok(claims.into())
}
async fn healthcheck() {}
pub async fn main() {
let config = Figment::from(Serialized::defaults(config::Config::default()))
.merge(Toml::file("siwe-oidc.toml").nested())
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
let config = config.extract::<config::Config>().unwrap();
tracing_subscriber::fmt::init();
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
// let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
let redis_client = RedisClient { pool };
for (id, secret) in &config.default_clients.clone() {
let client_entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: vec![],
};
redis_client
.set_client(id.to_string(), client_entry)
.await
.unwrap(); // TODO
}
let private_key = if let Some(key) = &config.rsa_pem {
RsaPrivateKey::from_pkcs1_pem(key)
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap()
} else {
info!("Generating key...");
let mut rng = OsRng;
let bits = 2048;
let private = RsaPrivateKey::new(&mut rng, bits)
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
.unwrap();
info!("Generated key.");
info!("{:?}", private.to_pkcs1_pem().unwrap());
private
};
let app = Router::new()
.nest(
"/build",
get_service(ServeDir::new("./static/build")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.nest(
"/img",
get_service(ServeDir::new("./static/img")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/",
get_service(ServeFile::new("./static/index.html")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/error",
get_service(ServeFile::new("./static/error.html")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/favicon.png",
get_service(ServeFile::new("./static/favicon.png")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(oidc::METADATA_PATH, get(provider_metadata))
.route(oidc::JWK_PATH, get(jwk_set))
.route(oidc::TOKEN_PATH, post(token))
.route(oidc::AUTHORIZE_PATH, get(authorize))
.route(oidc::REGISTER_PATH, post(register))
.route(oidc::USERINFO_PATH, get(userinfo).post(userinfo))
.route(oidc::SIGNIN_PATH, get(sign_in))
.route("/health", get(healthcheck))
.layer(AddExtensionLayer::new(private_key))
.layer(AddExtensionLayer::new(config.clone()))
.layer(AddExtensionLayer::new(redis_client))
.layer(AddExtensionLayer::new(
RedisSessionStore::new(config.redis_url.clone())
.unwrap()
.with_prefix("async-sessions/"),
))
.layer(TraceLayer::new_for_http());
let addr = SocketAddr::from((config.address, config.port));
tracing::info!("Listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}

View File

@ -1,43 +0,0 @@
use anyhow::{anyhow, Result};
use bb8_redis::{bb8::PooledConnection, redis::AsyncCommands, RedisConnectionManager};
use openidconnect::RedirectUrl;
use serde::{Deserialize, Serialize};
const KV_CLIENT_PREFIX: &str = "clients";
#[derive(Serialize, Deserialize)]
pub struct ClientEntry {
pub secret: String,
pub redirect_uris: Vec<RedirectUrl>,
}
pub async fn set_client(
mut conn: PooledConnection<'_, RedisConnectionManager>,
client_id: String,
client_entry: ClientEntry,
) -> Result<()> {
conn.set(
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
pub async fn get_client(
mut conn: PooledConnection<'_, RedisConnectionManager>,
client_id: String,
) -> Result<Option<ClientEntry>> {
let entry: Option<String> = conn
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}

199
src/db/cf.rs Normal file
View File

@ -0,0 +1,199 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
// use cached::{stores::TimedCache, Cached};
use chrono::{DateTime, Duration, Utc};
use matchit::Node;
use std::collections::HashMap;
use worker::*;
use super::*;
const KV_NAMESPACE: &str = "SIWE-OIDC";
const DO_NAMESPACE: &str = "SIWE-OIDC-CODES";
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
// Heavily relying on:
// A Durable Object is given 30 seconds of additional CPU time for every
// request it processes, including WebSocket messages. In the absence of
// failures, in-memory state should not be reset after less than 30 seconds of
// inactivity.
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
// Wanted to use TimedCache but it (probably) crashes because it's using std::time::Instant which isn't available on wasm32.
#[durable_object]
pub struct DOCodes {
// codes: TimedCache<String, CodeEntry>,
codes: HashMap<String, (DateTime<Utc>, CodeEntry)>,
// state: State,
// env: Env,
}
#[durable_object]
impl DurableObject for DOCodes {
fn new(state: State, _env: Env) -> Self {
Self {
// codes: TimedCache::with_lifespan(ENTRY_LIFETIME.try_into().unwrap()),
codes: HashMap::new(),
// state,
// env,
}
}
async fn fetch(&mut self, mut req: Request) -> worker::Result<Response> {
// Can't use the Router because we need to reference self (thus move the var to the closure)
if matches!(req.method(), Method::Get) {
let mut matcher = Node::new();
matcher.insert("/:code", ())?;
let path = req.path();
let matched = match matcher.at(&path) {
Ok(m) => m,
Err(_) => return Response::error("Bad request", 400),
};
let code = if let Some(c) = matched.params.get("code") {
c
} else {
return Response::error("Bad request", 400);
};
if let Some(c) = self.codes.get(code) {
if c.0 + Duration::seconds(ENTRY_LIFETIME.try_into().unwrap()) < Utc::now() {
self.codes.remove(code);
Response::error("Not found", 404)
} else {
Response::from_json(&c.1)
}
} else {
Response::error("Not found", 404)
}
} else if matches!(req.method(), Method::Post) {
let mut matcher = Node::new();
matcher.insert("/:code", ())?;
let path = req.path();
let matched = match matcher.at(&path) {
Ok(m) => m,
Err(_) => return Response::error("Bad request", 400),
};
let code = if let Some(c) = matched.params.get("code") {
c
} else {
return Response::error("Bad request", 400);
};
let code_entry = match req.json().await {
Ok(p) => p,
Err(e) => return Response::error(format!("Bad request: {}", e), 400),
};
self.codes
.insert(code.to_string(), (Utc::now(), code_entry));
Response::empty()
} else {
Response::error("Method Not Allowed", 405)
}
}
}
pub struct CFClient {
pub ctx: RouteContext<()>,
pub url: Url,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl DBClient for CFClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
self.ctx
.kv(KV_NAMESPACE)
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
.put(
&format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to build KV put: {}", e))?
// TODO put some sort of expiration for dynamic registration
.execute()
.await
.map_err(|e| anyhow!("Failed to put KV: {}", e))?;
Ok(())
}
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
let entry = self
.ctx
.kv(KV_NAMESPACE)
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
.get(&format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get KV: {}", e))?
.map(|e| e.as_string());
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let namespace = self
.ctx
.durable_object(DO_NAMESPACE)
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
let stub = namespace
.id_from_name(&code)
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
.get_stub()
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
let mut headers = Headers::new();
headers.set("Content-Type", "application/json").unwrap();
let mut url = self.url.clone();
url.set_path(&code);
url.set_query(None);
let req = Request::new_with_init(
url.as_str(),
&RequestInit {
body: Some(wasm_bindgen::JsValue::from_str(
&serde_json::to_string(&code_entry)
.map_err(|e| anyhow!("Failed to serialize: {}", e))?,
)),
method: Method::Post,
headers,
..Default::default()
},
)
.map_err(|e| anyhow!("Failed to construct request for Durable Object: {}", e))?;
let res = stub
.fetch_with_request(req)
.await
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
match res.status_code() {
200 => Ok(()),
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
}
}
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
let namespace = self
.ctx
.durable_object(DO_NAMESPACE)
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
let stub = namespace
.id_from_name(&code)
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
.get_stub()
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
let mut url = self.url.clone();
url.set_path(&code);
url.set_query(None);
let mut res = stub
.fetch_with_str(url.as_str())
.await
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
match res.status_code() {
200 => Ok(Some(res.json().await.map_err(|e| {
anyhow!(
"Response to Durable Object failed to be deserialized: {}",
e
)
})?)),
404 => Ok(None),
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
}
}
}

40
src/db/mod.rs Normal file
View File

@ -0,0 +1,40 @@
use anyhow::Result;
use async_trait::async_trait;
use openidconnect::{Nonce, RedirectUrl};
use serde::{Deserialize, Serialize};
#[cfg(not(target_arch = "wasm32"))]
mod redis;
#[cfg(not(target_arch = "wasm32"))]
pub use redis::RedisClient;
#[cfg(target_arch = "wasm32")]
mod cf;
#[cfg(target_arch = "wasm32")]
pub use cf::CFClient;
const KV_CLIENT_PREFIX: &str = "clients";
const ENTRY_LIFETIME: usize = 30;
#[derive(Clone, Serialize, Deserialize)]
pub struct CodeEntry {
pub exchange_count: usize,
pub address: String,
pub nonce: Option<Nonce>,
pub client_id: String,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientEntry {
pub secret: String,
pub redirect_uris: Vec<RedirectUrl>,
}
// Using a trait to easily pass async functions with async_trait
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait DBClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>;
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>;
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>;
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>;
}

89
src/db/redis.rs Normal file
View File

@ -0,0 +1,89 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
use super::*;
#[derive(Clone)]
pub struct RedisClient {
pub pool: Pool<RedisConnectionManager>,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl DBClient for RedisClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set(
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let entry: Option<String> = conn
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set_ex(
code.to_string(),
hex::encode(
bincode::serialize(&code_entry)
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
),
ENTRY_LIFETIME,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let serialized_entry: Option<Vec<u8>> = conn
.get(code)
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() {
return Ok(None);
}
let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap())
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
Ok(Some(code_entry))
}
}

18
src/lib.rs Normal file
View File

@ -0,0 +1,18 @@
#[cfg(target_arch = "wasm32")]
use worker::*;
pub mod db;
#[cfg(target_arch = "wasm32")]
pub mod oidc;
#[cfg(target_arch = "wasm32")]
mod worker_lib;
#[cfg(target_arch = "wasm32")]
use worker_lib::main as worker_main;
// pub use worker_lib::main;
#[cfg(target_arch = "wasm32")]
#[event(fetch)]
pub async fn main(req: Request, env: Env) -> Result<Response> {
worker_main(req, env).await
}

View File

@ -1,882 +1,19 @@
use anyhow::{anyhow, Result};
use async_redis_session::RedisSessionStore;
use axum::{
body::{Bytes, Full},
error_handling::HandleErrorExt,
extract::{self, Extension, Form, Query, TypedHeader},
http::{
header::{self, HeaderMap},
Response, StatusCode,
},
response::{IntoResponse, Redirect},
routing::{get, post, service_method_routing},
AddExtensionLayer, Json, Router,
};
use bb8_redis::{bb8, bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
use chrono::{Duration, Utc};
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use headers::{self, authorization::Bearer, Authorization};
use hex::FromHex;
use iri_string::types::{UriAbsoluteString, UriString};
use openidconnect::{
core::{
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
},
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
url::Url,
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
};
use rand::rngs::OsRng;
use rsa::{
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
RsaPrivateKey,
};
use serde::{Deserialize, Serialize};
use siwe::eip4361::{Message, Version};
use std::{convert::Infallible, net::SocketAddr, str::FromStr};
use thiserror::Error;
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tracing::info;
use urlencoding::decode;
use uuid::Uuid;
#[cfg(not(target_arch = "wasm32"))]
mod axum_lib;
#[cfg(not(target_arch = "wasm32"))]
mod config;
mod db;
#[cfg(not(target_arch = "wasm32"))]
mod oidc;
#[cfg(not(target_arch = "wasm32"))]
mod session;
#[cfg(not(target_arch = "wasm32"))]
use axum_lib::main as axum_main;
use db::*;
use session::*;
const KID: &str = "key1";
const ENTRY_LIFETIME: usize = 30;
type ConnectionPool = Pool<RedisConnectionManager>;
#[derive(Serialize, Debug)]
pub struct TokenError {
pub error: CoreErrorResponseType,
}
#[derive(Debug, Error)]
pub enum CustomError {
#[error("{0}")]
BadRequest(String),
#[error("{0:?}")]
BadRequestToken(Json<TokenError>),
#[error("{0}")]
Unauthorized(String),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl IntoResponse for CustomError {
type Body = Full<Bytes>;
type BodyError = Infallible;
fn into_response(self) -> Response<Self::Body> {
match self {
CustomError::BadRequest(_) => {
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
}
CustomError::BadRequestToken(e) => (StatusCode::BAD_REQUEST, e).into_response(),
CustomError::Unauthorized(_) => {
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
}
CustomError::Other(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
}
}
}
async fn jwk_set(
Extension(private_key): Extension<RsaPrivateKey>,
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
&pem,
Some(JsonWebKeyId::new(KID.to_string())),
)
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
.as_verification_key()]);
Ok(jwks.into())
}
async fn provider_metadata(
Extension(config): Extension<config::Config>,
) -> Result<Json<CoreProviderMetadata>, CustomError> {
let pm = CoreProviderMetadata::new(
IssuerUrl::from_url(config.base_url.clone()),
AuthUrl::from_url(
config
.base_url
.join("authorize")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
JsonWebKeySetUrl::from_url(
config
.base_url
.join("jwk")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
vec![
ResponseTypes::new(vec![CoreResponseType::Code]),
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
],
vec![CoreSubjectIdentifierType::Pairwise],
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
EmptyAdditionalProviderMetadata {},
)
.set_token_endpoint(Some(TokenUrl::from_url(
config
.base_url
.join("token")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
config
.base_url
.join("userinfo")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_scopes_supported(Some(vec![
Scope::new("openid".to_string()),
// Scope::new("email".to_string()),
// Scope::new("profile".to_string()),
]))
.set_claims_supported(Some(vec![
CoreClaimName::new("sub".to_string()),
CoreClaimName::new("aud".to_string()),
// CoreClaimName::new("email".to_string()),
// CoreClaimName::new("email_verified".to_string()),
CoreClaimName::new("exp".to_string()),
CoreClaimName::new("iat".to_string()),
CoreClaimName::new("iss".to_string()),
// CoreClaimName::new("name".to_string()),
// CoreClaimName::new("given_name".to_string()),
// CoreClaimName::new("family_name".to_string()),
// CoreClaimName::new("picture".to_string()),
// CoreClaimName::new("locale".to_string()),
]))
.set_registration_endpoint(Some(RegistrationUrl::from_url(
config
.base_url
.join("register")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_token_endpoint_auth_methods_supported(Some(vec![
CoreClientAuthMethod::ClientSecretBasic,
CoreClientAuthMethod::ClientSecretPost,
]));
Ok(pm.into())
}
#[derive(Serialize, Deserialize)]
struct TokenForm {
code: String,
client_id: Option<String>,
client_secret: Option<String>,
grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
}
// TODO should check Authorization header
// Actually, client secret can be
// 1. in the POST (currently supported) [x]
// 2. Authorization header [x]
// 3. JWT [ ]
// 4. signed JWT [ ]
// according to Keycloak
async fn token(
form: Form<TokenForm>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
Extension(private_key): Extension<RsaPrivateKey>,
Extension(config): Extension<config::Config>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<Json<CoreTokenResponse>, CustomError> {
let mut conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let serialized_entry: Option<Vec<u8>> = conn
.get(form.code.to_string())
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() {
return Err(CustomError::BadRequestToken(
TokenError {
error: CoreErrorResponseType::InvalidGrant,
}
.into(),
));
}
let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap())
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
let client_id = if let Some(c) = form.client_id.clone() {
c
} else {
code_entry.client_id.clone()
};
if let Some(secret) = if let Some(TypedHeader(Authorization(b))) = bearer {
Some(b.token().to_string())
} else {
form.client_secret.clone()
} {
let conn2 = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let client_entry = get_client(conn2, client_id.clone()).await?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
if secret != client_entry.unwrap().secret {
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
}
} else if config.require_secret {
return Err(CustomError::Unauthorized("Secret required.".to_string()));
}
if code_entry.exchange_count > 0 {
// TODO use Oauth error response
return Err(anyhow!("Code was previously exchanged.").into());
}
conn.set_ex(
form.code.to_string(),
hex::encode(
bincode::serialize(&code_entry)
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
),
ENTRY_LIFETIME,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
let access_token = AccessToken::new(form.code.to_string());
let core_id_token = CoreIdTokenClaims::new(
IssuerUrl::from_url(config.base_url),
vec![Audience::new(client_id.clone())],
Utc::now() + Duration::seconds(60),
Utc::now(),
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims {},
)
.set_nonce(code_entry.nonce);
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let id_token = CoreIdToken::new(
core_id_token,
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
Some(&access_token),
None,
)
.map_err(|e| anyhow!("{}", e))?;
Ok(CoreTokenResponse::new(
access_token,
CoreTokenType::Bearer,
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
)
.into())
}
#[derive(Deserialize)]
struct AuthorizeParams {
client_id: String,
redirect_uri: RedirectUrl,
scope: Scope,
response_type: Option<CoreResponseType>,
state: Option<String>,
nonce: Option<Nonce>,
prompt: Option<CoreAuthPrompt>,
request_uri: Option<RequestUrl>,
request: Option<String>,
}
// TODO handle `registration` parameter
async fn authorize(
session: UserSessionFromSession,
params: Query<AuthorizeParams>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let client_entry = get_client(conn, params.client_id.clone())
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
let mut r_u = params.0.redirect_uri.clone().url().clone();
r_u.set_query(None);
let mut r_us: Vec<Url> = client_entry
.unwrap()
.redirect_uris
.iter_mut()
.map(|u| u.url().clone())
.collect();
r_us.iter_mut().for_each(|u| u.set_query(None));
if !r_us.contains(&r_u) {
return Ok((
HeaderMap::new(),
Redirect::to(
"/error?message=unregistered_request_uri"
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
let state = if let Some(s) = params.0.state.clone() {
s
} else if params.0.request_uri.is_some() {
let mut url = params.0.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
} else if params.0.request.is_some() {
let mut url = params.0.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
} else {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing state");
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
};
if let Some(CoreAuthPrompt::None) = params.0.prompt {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
if params.0.response_type.is_none() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing response_type");
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
let response_type = params.0.response_type.as_ref().unwrap();
if params.scope != Scope::new("openid".to_string()) {
return Err(anyhow!("Scope not supported").into());
}
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(cookie) => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, cookie);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
&params.0.client_id,
&params.0.redirect_uri.to_string(),
&params.0.scope.to_string(),
&response_type.as_ref(),
&state,
&params.0.client_id,
&params.0.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { header, nonce } => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, header);
(nonce, headers)
}
};
let domain = params.redirect_uri.url().host().unwrap();
let oidc_nonce_param = if let Some(n) = &params.nonce {
format!("&oidc_nonce={}", n.secret())
} else {
"".to_string()
};
Ok((
headers,
Redirect::to(
format!(
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
nonce,
domain,
params.redirect_uri.to_string(),
state,
params.client_id,
oidc_nonce_param
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
#[derive(Serialize, Deserialize)]
struct SiweCookie {
message: Web3ModalMessage,
signature: String,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Web3ModalMessage {
pub domain: String,
pub address: String,
pub statement: String,
pub uri: String,
pub version: String,
pub chain_id: String,
pub nonce: String,
pub issued_at: String,
pub expiration_time: Option<String>,
pub not_before: Option<String>,
pub request_id: Option<String>,
pub resources: Option<Vec<String>>,
}
impl Web3ModalMessage {
pub fn to_eip4361_message(&self) -> Result<Message> {
let mut next_resources: Vec<UriString> = Vec::new();
match &self.resources {
Some(resources) => {
for resource in resources {
let x = UriString::from_str(resource)?;
next_resources.push(x)
}
}
None => {}
}
Ok(Message {
domain: self.domain.clone().try_into()?,
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
statement: self.statement.to_string(),
uri: UriAbsoluteString::from_str(&self.uri)?,
version: Version::from_str(&self.version)?,
chain_id: self.chain_id.to_string(),
nonce: self.nonce.to_string(),
issued_at: self.issued_at.to_string(),
expiration_time: self.expiration_time.clone(),
not_before: self.not_before.clone(),
request_id: self.request_id.clone(),
resources: next_resources,
})
}
}
#[derive(Serialize, Deserialize)]
struct CodeEntry {
exchange_count: usize,
address: String,
nonce: Option<Nonce>,
client_id: String,
}
#[derive(Deserialize)]
struct SignInParams {
redirect_uri: RedirectUrl,
state: String,
oidc_nonce: Option<Nonce>,
client_id: String,
}
async fn sign_in(
session: UserSessionFromSession,
params: Query<SignInParams>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let mut headers = HeaderMap::new();
let siwe_cookie: SiweCookie = match cookies.get("siwe") {
Some(c) => serde_json::from_str(
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
)
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
None => {
return Err(anyhow!("No `siwe` cookie").into());
}
};
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(header) => {
headers.insert(header::SET_COOKIE, header);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.0.client_id.clone(),
&params.0.redirect_uri.to_string(),
&params.0.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { .. } => {
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.0.client_id.clone(),
&params.0.redirect_uri.to_string(),
&params.0.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
};
let signature = match <[u8; 65]>::from_hex(
siwe_cookie
.signature
.chars()
.skip(2)
.take(130)
.collect::<String>(),
) {
Ok(s) => s,
Err(e) => {
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
}
};
let message = siwe_cookie
.message
.to_eip4361_message()
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
info!("{}", message);
message
.verify_eip191(signature)
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
let domain = params.redirect_uri.url().host().unwrap();
if domain.to_string() != siwe_cookie.message.domain {
return Err(anyhow!("Conflicting domains in message and redirect").into());
}
if nonce != siwe_cookie.message.nonce {
return Err(anyhow!("Conflicting nonces in message and session").into());
}
let code_entry = CodeEntry {
address: siwe_cookie.message.address,
nonce: params.oidc_nonce.clone(),
exchange_count: 0,
client_id: params.0.client_id.clone(),
};
let code = Uuid::new_v4();
let mut conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set_ex(
code.to_string(),
hex::encode(
bincode::serialize(&code_entry)
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
),
ENTRY_LIFETIME,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("code", &code.to_string());
url.query_pairs_mut().append_pair("state", &params.state);
Ok((
headers,
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
// TODO clear session
}
async fn register(
extract::Json(payload): extract::Json<CoreClientMetadata>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
let id = Uuid::new_v4();
let secret = Uuid::new_v4();
let conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: payload.redirect_uris().to_vec(),
};
set_client(conn, id.to_string(), entry).await?;
Ok((
StatusCode::CREATED,
CoreClientRegistrationResponse::new(
ClientId::new(id.to_string()),
payload.redirect_uris().to_vec(),
EmptyAdditionalClientMetadata::default(),
EmptyAdditionalClientRegistrationResponse::default(),
)
.set_client_secret(Some(ClientSecret::new(secret.to_string())))
.into(),
))
}
// TODO CORS
// TODO need validation of the token
// TODO restrict access token use to only once?
async fn userinfo(
// access_token: AccessTokenUserInfo, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
Extension(pool): Extension<ConnectionPool>,
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
let code = bearer.token().to_string();
let mut conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let serialized_entry: Option<Vec<u8>> = conn
.get(code)
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() {
return Err(CustomError::BadRequest("Unknown code.".to_string()));
}
let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap())
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
Ok(CoreUserInfoClaims::new(
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims::default(),
)
.into())
}
async fn healthcheck() {}
#[cfg(not(target_arch = "wasm32"))]
#[tokio::main]
async fn main() {
let config = Figment::from(Serialized::defaults(config::Config::default()))
.merge(Toml::file("siwe-oidc.toml").nested())
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
let config = config.extract::<config::Config>().unwrap();
tracing_subscriber::fmt::init();
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
for (id, secret) in &config.default_clients.clone() {
let conn = pool2
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))
.unwrap();
let client_entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: vec![],
};
set_client(conn, id.to_string(), client_entry)
.await
.unwrap(); // TODO
}
let private_key = if let Some(key) = &config.rsa_pem {
RsaPrivateKey::from_pkcs1_pem(key)
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap()
} else {
info!("Generating key...");
let mut rng = OsRng;
let bits = 2048;
let private = RsaPrivateKey::new(&mut rng, bits)
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
.unwrap();
info!("Generated key.");
info!("{:?}", private.to_pkcs1_pem().unwrap());
private
};
let app = Router::new()
.nest(
"/build",
service_method_routing::get(ServeDir::new("./static/build")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.nest(
"/img",
service_method_routing::get(ServeDir::new("./static/img")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/",
service_method_routing::get(ServeFile::new("./static/index.html")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/error",
service_method_routing::get(ServeFile::new("./static/error.html")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/favicon.png",
service_method_routing::get(ServeFile::new("./static/favicon.png")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route("/.well-known/openid-configuration", get(provider_metadata))
.route("/jwk", get(jwk_set))
.route("/token", post(token))
.route("/authorize", get(authorize))
.route("/register", post(register))
.route("/userinfo", get(userinfo).post(userinfo))
.route("/sign_in", get(sign_in))
.route("/health", get(healthcheck))
.layer(AddExtensionLayer::new(private_key))
.layer(AddExtensionLayer::new(config.clone()))
.layer(AddExtensionLayer::new(pool))
.layer(AddExtensionLayer::new(
RedisSessionStore::new(config.redis_url.clone())
.unwrap()
.with_prefix("async-sessions/"),
))
.layer(TraceLayer::new_for_http());
let addr = SocketAddr::from((config.address, config.port));
tracing::info!("Listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
axum_main().await
}
#[cfg(target_arch = "wasm32")]
fn main() {}

523
src/oidc.rs Normal file
View File

@ -0,0 +1,523 @@
use anyhow::{anyhow, Result};
use chrono::{Duration, Utc};
use headers::{self, authorization::Bearer};
use hex::FromHex;
use iri_string::types::UriString;
use openidconnect::{
core::{
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
},
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
url::Url,
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
};
use rsa::{pkcs1::ToRsaPrivateKey, RsaPrivateKey};
use serde::{Deserialize, Serialize};
use siwe::eip4361::{Message, Version};
use std::str::FromStr;
use thiserror::Error;
use tracing::info;
use urlencoding::decode;
use uuid::Uuid;
#[cfg(target_arch = "wasm32")]
use super::db::*;
#[cfg(not(target_arch = "wasm32"))]
use siwe_oidc::db::*;
const KID: &str = "key1";
pub const METADATA_PATH: &str = "/.well-known/openid-configuration";
pub const JWK_PATH: &str = "/jwk";
pub const TOKEN_PATH: &str = "/token";
pub const AUTHORIZE_PATH: &str = "/authorize";
pub const REGISTER_PATH: &str = "/register";
pub const USERINFO_PATH: &str = "/userinfo";
pub const SIGNIN_PATH: &str = "/sign_in";
pub const SIWE_COOKIE_KEY: &str = "siwe";
#[cfg(not(target_arch = "wasm32"))]
type DBClientType = (dyn DBClient + Sync);
#[cfg(target_arch = "wasm32")]
type DBClientType = dyn DBClient;
#[derive(Serialize, Debug)]
pub struct TokenError {
pub error: CoreErrorResponseType,
pub error_description: String,
}
#[derive(Debug, Error)]
pub enum CustomError {
#[error("{0}")]
BadRequest(String),
#[error("{0:?}")]
BadRequestToken(TokenError),
#[error("{0}")]
Unauthorized(String),
#[error("{0:?}")]
Redirect(String),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub fn jwks(private_key: RsaPrivateKey) -> Result<CoreJsonWebKeySet, CustomError> {
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
&pem,
Some(JsonWebKeyId::new(KID.to_string())),
)
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
.as_verification_key()]);
Ok(jwks)
}
pub fn metadata(base_url: Url) -> Result<CoreProviderMetadata, CustomError> {
let pm = CoreProviderMetadata::new(
IssuerUrl::from_url(base_url.clone()),
AuthUrl::from_url(
base_url
.join(AUTHORIZE_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
JsonWebKeySetUrl::from_url(
base_url
.join(JWK_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
vec![
ResponseTypes::new(vec![CoreResponseType::Code]),
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
],
vec![CoreSubjectIdentifierType::Pairwise],
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
EmptyAdditionalProviderMetadata {},
)
.set_token_endpoint(Some(TokenUrl::from_url(
base_url
.join(TOKEN_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
base_url
.join(USERINFO_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_scopes_supported(Some(vec![
Scope::new("openid".to_string()),
// Scope::new("email".to_string()),
// Scope::new("profile".to_string()),
]))
.set_claims_supported(Some(vec![
CoreClaimName::new("sub".to_string()),
CoreClaimName::new("aud".to_string()),
// CoreClaimName::new("email".to_string()),
// CoreClaimName::new("email_verified".to_string()),
CoreClaimName::new("exp".to_string()),
CoreClaimName::new("iat".to_string()),
CoreClaimName::new("iss".to_string()),
// CoreClaimName::new("name".to_string()),
// CoreClaimName::new("given_name".to_string()),
// CoreClaimName::new("family_name".to_string()),
// CoreClaimName::new("picture".to_string()),
// CoreClaimName::new("locale".to_string()),
]))
.set_registration_endpoint(Some(RegistrationUrl::from_url(
base_url
.join(REGISTER_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_token_endpoint_auth_methods_supported(Some(vec![
CoreClientAuthMethod::ClientSecretBasic,
CoreClientAuthMethod::ClientSecretPost,
]));
Ok(pm)
}
#[derive(Serialize, Deserialize)]
pub struct TokenForm {
pub code: String,
pub client_id: Option<String>,
pub client_secret: Option<String>,
pub grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
}
pub async fn token(
form: TokenForm,
// From the request's Authorization header
secret: Option<String>,
private_key: RsaPrivateKey,
base_url: Url,
require_secret: bool,
db_client: &DBClientType,
) -> Result<CoreTokenResponse, CustomError> {
let code_entry = if let Some(c) = db_client.get_code(form.code.to_string()).await? {
c
} else {
return Err(CustomError::BadRequestToken(TokenError {
error: CoreErrorResponseType::InvalidGrant,
error_description: "Unknown code.".to_string(),
}));
};
let client_id = if let Some(c) = form.client_id.clone() {
c
} else {
code_entry.client_id.clone()
};
if let Some(secret) = if let Some(b) = secret {
Some(b)
} else {
form.client_secret.clone()
} {
let client_entry = db_client.get_client(client_id.clone()).await?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
if secret != client_entry.unwrap().secret {
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
}
} else if require_secret {
return Err(CustomError::Unauthorized("Secret required.".to_string()));
}
if code_entry.exchange_count > 0 {
// TODO use Oauth error response
return Err(CustomError::BadRequestToken(TokenError {
error: CoreErrorResponseType::InvalidGrant,
error_description: "Code was previously exchanged.".to_string(),
}));
}
let mut code_entry2 = code_entry.clone();
code_entry2.exchange_count += 1;
db_client
.set_code(form.code.to_string(), code_entry2)
.await?;
let access_token = AccessToken::new(form.code);
let core_id_token = CoreIdTokenClaims::new(
IssuerUrl::from_url(base_url),
vec![Audience::new(client_id.clone())],
Utc::now() + Duration::seconds(60),
Utc::now(),
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims {},
)
.set_nonce(code_entry.nonce);
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let id_token = CoreIdToken::new(
core_id_token,
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
Some(&access_token),
None,
)
.map_err(|e| anyhow!("{}", e))?;
Ok(CoreTokenResponse::new(
access_token,
CoreTokenType::Bearer,
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
))
}
#[derive(Deserialize)]
pub struct AuthorizeParams {
pub client_id: String,
pub redirect_uri: RedirectUrl,
pub scope: Scope,
pub response_type: Option<CoreResponseType>,
pub state: Option<String>,
pub nonce: Option<Nonce>,
pub prompt: Option<CoreAuthPrompt>,
pub request_uri: Option<RequestUrl>,
pub request: Option<String>,
}
pub async fn authorize(
params: AuthorizeParams,
nonce: String,
db_client: &DBClientType,
) -> Result<String, CustomError> {
let client_entry = db_client
.get_client(params.client_id.clone())
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
let mut r_u = params.redirect_uri.clone().url().clone();
r_u.set_query(None);
let mut r_us: Vec<Url> = client_entry
.unwrap()
.redirect_uris
.iter_mut()
.map(|u| u.url().clone())
.collect();
r_us.iter_mut().for_each(|u| u.set_query(None));
if !r_us.contains(&r_u) {
return Err(CustomError::Redirect(
"/error?message=unregistered_request_uri".to_string(),
));
}
let state = if let Some(s) = params.state.clone() {
s
} else if params.request_uri.is_some() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
);
return Err(CustomError::Redirect(url.to_string()));
} else if params.request.is_some() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
);
return Err(CustomError::Redirect(url.to_string()));
} else {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing state");
return Err(CustomError::Redirect(url.to_string()));
};
if let Some(CoreAuthPrompt::None) = params.prompt {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
);
return Err(CustomError::Redirect(url.to_string()));
}
if params.response_type.is_none() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing response_type");
return Err(CustomError::Redirect(url.to_string()));
}
let _response_type = params.response_type.as_ref().unwrap();
if params.scope != Scope::new("openid".to_string()) {
return Err(anyhow!("Scope not supported").into());
}
let domain = params.redirect_uri.url().host().unwrap();
let oidc_nonce_param = if let Some(n) = &params.nonce {
format!("&oidc_nonce={}", n.secret())
} else {
"".to_string()
};
Ok(format!(
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
nonce,
domain,
params.redirect_uri.to_string(),
state,
params.client_id,
oidc_nonce_param
))
}
#[derive(Serialize, Deserialize)]
pub struct SiweCookie {
message: Web3ModalMessage,
signature: String,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Web3ModalMessage {
pub domain: String,
pub address: String,
pub statement: String,
pub uri: String,
pub version: String,
pub chain_id: String,
pub nonce: String,
pub issued_at: String,
pub expiration_time: Option<String>,
pub not_before: Option<String>,
pub request_id: Option<String>,
pub resources: Option<Vec<String>>,
}
impl Web3ModalMessage {
fn to_eip4361_message(&self) -> Result<Message> {
let mut next_resources: Vec<UriString> = Vec::new();
match &self.resources {
Some(resources) => {
for resource in resources {
let x = UriString::from_str(resource)?;
next_resources.push(x)
}
}
None => {}
}
Ok(Message {
domain: self.domain.clone().try_into()?,
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
statement: self.statement.to_string(),
uri: UriString::from_str(&self.uri)?,
version: Version::from_str(&self.version)?,
chain_id: self.chain_id.to_string(),
nonce: self.nonce.to_string(),
issued_at: self.issued_at.to_string(),
expiration_time: self.expiration_time.clone(),
not_before: self.not_before.clone(),
request_id: self.request_id.clone(),
resources: next_resources,
})
}
}
#[derive(Deserialize)]
pub struct SignInParams {
pub redirect_uri: RedirectUrl,
pub state: String,
pub oidc_nonce: Option<Nonce>,
pub client_id: String,
}
pub async fn sign_in(
params: SignInParams,
expected_nonce: Option<String>,
cookies: headers::Cookie,
db_client: &DBClientType,
) -> Result<Url, CustomError> {
let siwe_cookie: SiweCookie = match cookies.get(SIWE_COOKIE_KEY) {
Some(c) => serde_json::from_str(
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
)
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
None => {
return Err(anyhow!("No `siwe` cookie").into());
}
};
let signature = match <[u8; 65]>::from_hex(
siwe_cookie
.signature
.chars()
.skip(2)
.take(130)
.collect::<String>(),
) {
Ok(s) => s,
Err(e) => {
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
}
};
let message = siwe_cookie
.message
.to_eip4361_message()
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
info!("{}", message);
message
.verify(signature)
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
let domain = params.redirect_uri.url().host().unwrap();
if domain.to_string() != siwe_cookie.message.domain {
return Err(anyhow!("Conflicting domains in message and redirect").into());
}
if expected_nonce.is_some() && expected_nonce.unwrap() != siwe_cookie.message.nonce {
return Err(anyhow!("Conflicting nonces in message and session").into());
}
let code_entry = CodeEntry {
address: siwe_cookie.message.address,
nonce: params.oidc_nonce.clone(),
exchange_count: 0,
client_id: params.client_id.clone(),
};
let code = Uuid::new_v4();
db_client.set_code(code.to_string(), code_entry).await?;
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("code", &code.to_string());
url.query_pairs_mut().append_pair("state", &params.state);
Ok(url)
}
pub async fn register(
payload: CoreClientMetadata,
db_client: &DBClientType,
) -> Result<CoreClientRegistrationResponse, CustomError> {
let id = Uuid::new_v4();
let secret = Uuid::new_v4();
let entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: payload.redirect_uris().to_vec(),
};
db_client.set_client(id.to_string(), entry).await?;
Ok(CoreClientRegistrationResponse::new(
ClientId::new(id.to_string()),
payload.redirect_uris().to_vec(),
EmptyAdditionalClientMetadata::default(),
EmptyAdditionalClientRegistrationResponse::default(),
)
.set_client_secret(Some(ClientSecret::new(secret.to_string()))))
}
#[derive(Deserialize)]
pub struct UserInfoPayload {
pub access_token: Option<String>,
}
pub async fn userinfo(
bearer: Option<Bearer>,
payload: UserInfoPayload,
db_client: &DBClientType,
) -> Result<CoreUserInfoClaims, CustomError> {
let code = if let Some(b) = bearer {
b.token().to_string()
} else if let Some(c) = payload.access_token {
c
} else {
return Err(CustomError::BadRequest("Missing access token.".to_string()));
};
let code_entry = if let Some(c) = db_client.get_code(code).await? {
c
} else {
return Err(CustomError::BadRequest("Unknown code.".to_string()));
};
Ok(CoreUserInfoClaims::new(
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims::default(),
))
}

210
src/worker_lib.rs Normal file
View File

@ -0,0 +1,210 @@
use anyhow::anyhow;
use headers::{
self,
authorization::{Basic, Bearer, Credentials},
Authorization, Header, HeaderValue,
};
use rand::{distributions::Alphanumeric, Rng};
use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey};
use worker::*;
use super::db::CFClient;
use super::oidc::{self, CustomError, TokenForm, UserInfoPayload};
const BASE_URL_KEY: &str = "BASE_URL";
const RSA_PEM_KEY: &str = "RSA_PEM";
// https://github.com/cloudflare/workers-rs/issues/64
// #[global_allocator]
// static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
impl From<CustomError> for Result<Response> {
fn from(error: CustomError) -> Self {
match error {
CustomError::BadRequest(_) => Response::error(&error.to_string(), 400),
CustomError::BadRequestToken(e) => Response::from_json(&e).map(|r| r.with_status(400)),
CustomError::Unauthorized(_) => Response::error(&error.to_string(), 401),
CustomError::Redirect(uri) => Response::redirect(uri.parse().unwrap()),
CustomError::Other(_) => Response::error(&error.to_string(), 500),
}
}
}
pub async fn main(req: Request, env: Env) -> Result<Response> {
console_error_panic_hook::set_once();
// tracing_subscriber::fmt::init();
// console_log::init_with_level(log::Level::Info).expect("error initializing log");
let userinfo = |mut req: Request, ctx: RouteContext<()>| async move {
let bearer = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(Bearer::decode);
let payload = if bearer.is_none() {
match req.form_data().await {
Ok(f) => {
let access_token = if let Some(FormEntry::Field(a)) = f.get("access_token") {
Some(a)
} else {
return Response::error("Missing code", 400);
};
UserInfoPayload { access_token }
}
Err(_) => return Response::error("Bad request", 400),
}
} else {
UserInfoPayload { access_token: None }
};
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::userinfo(bearer, payload, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?),
Err(e) => e.into(),
}
};
let router = Router::new();
router
.get_async(oidc::METADATA_PATH, |_req, ctx| async move {
match oidc::metadata(ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap()) {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
.get_async(oidc::JWK_PATH, |_req, ctx| async move {
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap();
match oidc::jwks(private_key) {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
.post_async(oidc::TOKEN_PATH, |mut req, ctx| async move {
let form_data = req.form_data().await?;
let code = if let Some(FormEntry::Field(c)) = form_data.get("code") {
c
} else {
return Response::error("Missing code", 400);
};
let client_id = match form_data.get("client_id") {
Some(FormEntry::Field(c)) => Some(c),
None => None,
_ => return Response::error("Client ID not a field", 400),
};
let client_secret = match form_data.get("client_secret") {
Some(FormEntry::Field(c)) => Some(c),
None => None,
_ => return Response::error("Client secret not a field", 400),
};
let grant_type = if let Some(FormEntry::Field(c)) = form_data.get("code") {
if let Ok(cc) = serde_json::from_str(&format!("\"{}\"", c)) {
cc
} else {
return Response::error("Invalid grant type", 400);
}
} else {
return Response::error("Missing grant type", 400);
};
let secret = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(|b| {
if b.to_str().unwrap().starts_with("Bearer") {
Bearer::decode(b).map(|bb| bb.token().to_string())
} else {
Basic::decode(b).map(|bb| bb.password().to_string())
}
});
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap();
let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let url = req.url()?;
let db_client = CFClient { ctx, url };
let token_response = oidc::token(
TokenForm {
code,
client_id,
client_secret,
grant_type,
},
secret,
private_key,
base_url,
false,
&db_client,
)
.await;
match token_response {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
// TODO add browser session
.get_async(oidc::AUTHORIZE_PATH, |req, ctx| async move {
let base_url: Url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let url = req.url()?;
let query = url.query().unwrap_or_default();
let params = match serde_urlencoded::from_str(query) {
Ok(p) => p,
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
};
let nonce = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect();
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::authorize(params, nonce, &db_client).await {
Ok(url) => Response::redirect(base_url.join(&url).unwrap()),
Err(e) => match e {
CustomError::Redirect(url) => {
CustomError::Redirect(base_url.join(&url).unwrap().to_string())
}
c => c,
}
.into(),
}
})
.post_async(oidc::REGISTER_PATH, |mut req, ctx| async move {
let payload = req.json().await?;
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::register(payload, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?.with_status(201)),
Err(e) => e.into(),
}
})
.post_async(oidc::USERINFO_PATH, userinfo)
.get_async(oidc::USERINFO_PATH, userinfo)
.get_async(oidc::SIGNIN_PATH, |req, ctx| async move {
let url = req.url()?;
let query = url.query().unwrap_or_default();
let params = match serde_urlencoded::from_str(query) {
Ok(p) => p,
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
};
let cookies = req
.headers()
.get(headers::Cookie::name().as_str())?
.and_then(|c| HeaderValue::from_str(&c).ok())
.and_then(|c| headers::Cookie::decode(&mut [c].iter()).ok());
if cookies.is_none() {
return Response::error("Missing cookies", 400);
}
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::sign_in(params, None, cookies.unwrap(), &db_client).await {
Ok(url) => Response::redirect(url),
Err(e) => e.into(),
}
})
.run(req, env)
.await
}

35
wrangle_example.toml Normal file
View File

@ -0,0 +1,35 @@
name = "siwe_oidc"
type = "javascript"
account_id = ""
# zone_id = ""
workers_dev = false
compatibility_date = "2021-12-20"
kv_namespaces = [
{ binding = "SIWE-OIDC", id = "", preview_id = "" }
]
[vars]
WORKERS_RS_VERSION = "0.0.7"
BASE_URL = "https://siweoidc.spruceid.xyz"
[durable_objects]
bindings = [
{ name = "SIWE-OIDC-CODES", class_name = "DOCodes" }
]
[[migrations]]
tag = "v1"
new_classes = ["DOCodes"]
[build]
command = "cargo install -q worker-build && worker-build --release"
[build.upload]
dir = "build/worker"
format = "modules"
main = "./shim.mjs"
[[build.upload.rules]]
globs = ["**/*.wasm"]
type = "CompiledWasm"