Merge pull request #197 from dinosaure/lint

Use Lwt.Syntax and avoid some >>= fun () patterns
This commit is contained in:
Pierre Alain 2024-10-16 14:19:17 +02:00 committed by GitHub
commit 56e66ca39a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 88 additions and 95 deletions

View File

@ -20,5 +20,5 @@ $builder build -t qubes-mirage-firewall .
echo Building Firewall... echo Building Firewall...
$builder run --rm -i -v `pwd`:/tmp/orb-build:Z qubes-mirage-firewall $builder run --rm -i -v `pwd`:/tmp/orb-build:Z qubes-mirage-firewall
echo "SHA2 of build: $(sha256sum ./dist/qubes-firewall.xen)" echo "SHA2 of build: $(sha256sum ./dist/qubes-firewall.xen)"
echo "SHA2 last known: 4b1f743bf4540bc8a9366cf8f23a78316e4f2d477af77962e50618753c4adf10" echo "SHA2 last known: 2392386d9056b17a648f26b0c5d1c72b93f8a197964c670b2b45e71707727317"
echo "(hashes should match for released versions)" echo "(hashes should match for released versions)"

View File

@ -8,7 +8,7 @@ let src = Logs.Src.create "client_eth" ~doc:"Ethernet networks for NetVM clients
module Log = (val Logs.src_log src : Logs.LOG) module Log = (val Logs.src_log src : Logs.LOG)
type t = { type t = {
mutable iface_of_ip : client_link IpMap.t; mutable iface_of_ip : client_link Ipaddr.V4.Map.t;
changed : unit Lwt_condition.t; (* Fires when [iface_of_ip] changes. *) changed : unit Lwt_condition.t; (* Fires when [iface_of_ip] changes. *)
my_ip : Ipaddr.V4.t; (* The IP that clients are given as their default gateway. *) my_ip : Ipaddr.V4.t; (* The IP that clients are given as their default gateway. *)
} }
@ -21,21 +21,21 @@ type host =
let create config = let create config =
let changed = Lwt_condition.create () in let changed = Lwt_condition.create () in
let my_ip = config.Dao.our_ip in let my_ip = config.Dao.our_ip in
Lwt.return { iface_of_ip = IpMap.empty; my_ip; changed } Lwt.return { iface_of_ip = Ipaddr.V4.Map.empty; my_ip; changed }
let client_gw t = t.my_ip let client_gw t = t.my_ip
let add_client t iface = let add_client t iface =
let ip = iface#other_ip in let ip = iface#other_ip in
let rec aux () = let rec aux () =
match IpMap.find ip t.iface_of_ip with match Ipaddr.V4.Map.find_opt ip t.iface_of_ip with
| Some old -> | Some old ->
(* Wait for old client to disappear before adding one with the same IP address. (* Wait for old client to disappear before adding one with the same IP address.
Otherwise, its [remove_client] call will remove the new client instead. *) Otherwise, its [remove_client] call will remove the new client instead. *)
Log.info (fun f -> f ~header:iface#log_header "Waiting for old client %s to go away before accepting new one" old#log_header); Log.info (fun f -> f ~header:iface#log_header "Waiting for old client %s to go away before accepting new one" old#log_header);
Lwt_condition.wait t.changed >>= aux Lwt_condition.wait t.changed >>= aux
| None -> | None ->
t.iface_of_ip <- t.iface_of_ip |> IpMap.add ip iface; t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.add ip iface;
Lwt_condition.broadcast t.changed (); Lwt_condition.broadcast t.changed ();
Lwt.return_unit Lwt.return_unit
in in
@ -43,11 +43,11 @@ let add_client t iface =
let remove_client t iface = let remove_client t iface =
let ip = iface#other_ip in let ip = iface#other_ip in
assert (IpMap.mem ip t.iface_of_ip); assert (Ipaddr.V4.Map.mem ip t.iface_of_ip);
t.iface_of_ip <- t.iface_of_ip |> IpMap.remove ip; t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.remove ip;
Lwt_condition.broadcast t.changed () Lwt_condition.broadcast t.changed ()
let lookup t ip = IpMap.find ip t.iface_of_ip let lookup t ip = Ipaddr.V4.Map.find_opt ip t.iface_of_ip
let classify t ip = let classify t ip =
match ip with match ip with
@ -79,7 +79,7 @@ module ARP = struct
(* We're now treating client networks as point-to-point links, (* We're now treating client networks as point-to-point links,
so we no longer respond on behalf of other clients. *) so we no longer respond on behalf of other clients. *)
(* (*
else match IpMap.find ip t.net.iface_of_ip with else match Ipaddr.V4.Map.find_opt ip t.net.iface_of_ip with
| Some client_iface -> Some client_iface#other_mac | Some client_iface -> Some client_iface#other_mac
| None -> None | None -> None
*) *)

55
dao.ml
View File

@ -65,43 +65,40 @@ let read_rules rules client_ip =
number = 0;})] number = 0;})]
let vifs client domid = let vifs client domid =
let open Lwt.Syntax in
match int_of_string_opt domid with match int_of_string_opt domid with
| None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return [] | None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return []
| Some domid -> | Some domid ->
let path = Printf.sprintf "backend/vif/%d" domid in let path = Fmt.str "backend/vif/%d" domid in
Xen_os.Xs.immediate client (fun handle -> let vifs_of_domain handle =
directory ~handle path >>= let* devices = directory ~handle path in
Lwt_list.filter_map_p (fun device_id -> let ip_of_vif device_id = match int_of_string_opt device_id with
match int_of_string_opt device_id with | None ->
| None -> Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); Lwt.return_none Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid);
Lwt.return_none
| Some device_id -> | Some device_id ->
let vif = { ClientVif.domid; device_id } in let vif = { ClientVif.domid; device_id } in
Lwt.try_bind let get_client_ip () =
(fun () -> Xen_os.Xs.read handle (Printf.sprintf "%s/%d/ip" path device_id)) let* str = Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) in
(fun client_ip -> let client_ip = List.hd (String.split_on_char ' ' str) in
let client_ip' = match String.split_on_char ' ' client_ip with (* NOTE(dinosaure): it's safe to use [List.hd] here,
| [] -> Log.err (fun m -> m "unexpected empty list"); "" [String.split_on_char] can not return an empty list. *)
| [ ip ] -> ip Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip)
| ip::rest ->
Log.warn (fun m -> m "ignoring IPs %s from %a, we support one IP per client"
(String.concat " " rest) ClientVif.pp vif);
ip
in in
match Ipaddr.V4.of_string client_ip' with Lwt.catch get_client_ip @@ function
| Ok ip -> Lwt.return (Some (vif, ip)) | Xs_protocol.Enoent _ -> Lwt.return_none
| Error `Msg msg -> | Ipaddr.Parse_error (msg, client_ip) ->
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s" Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ClientVif.pp vif client_ip msg); ClientVif.pp vif client_ip msg);
Lwt.return None Lwt.return_none
) | exn ->
(function
| Xs_protocol.Enoent _ -> Lwt.return None
| ex ->
Log.err (fun f -> f "Error getting IP address of %a: %s" Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string ex)); ClientVif.pp vif (Printexc.to_string exn));
Lwt.return None Lwt.return_none
) in
)) Lwt_list.filter_map_p ip_of_vif devices
in
Xen_os.Xs.immediate client vifs_of_domain
let watch_clients fn = let watch_clients fn =
Xen_os.Xs.make () >>= fun xs -> Xen_os.Xs.make () >>= fun xs ->
@ -116,7 +113,7 @@ let watch_clients fn =
end >>= fun items -> end >>= fun items ->
Xen_os.Xs.make () >>= fun xs -> Xen_os.Xs.make () >>= fun xs ->
Lwt_list.map_p (vifs xs) items >>= fun items -> Lwt_list.map_p (vifs xs) items >>= fun items ->
fn (List.concat items |> VifMap.of_list); fn (List.concat items |> VifMap.of_list) >>= fun () ->
(* Wait for further updates *) (* Wait for further updates *)
Lwt.fail Xs_protocol.Eagain Lwt.fail Xs_protocol.Eagain
) )

View File

@ -15,7 +15,7 @@ module VifMap : sig
val find : key -> 'a t -> 'a option val find : key -> 'a t -> 'a option
end end
val watch_clients : (Ipaddr.V4.t VifMap.t -> unit) -> 'a Lwt.t val watch_clients : (Ipaddr.V4.t VifMap.t -> unit Lwt.t) -> 'a Lwt.t
(** [watch_clients fn] calls [fn clients] with the list of backend clients (** [watch_clients fn] calls [fn clients] with the list of backend clients
in XenStore, and again each time XenStore updates. *) in XenStore, and again each time XenStore updates. *)

View File

@ -17,8 +17,6 @@ struct
module I = Static_ipv4.Make (R) (Clock) (UplinkEth) (Arp) module I = Static_ipv4.Make (R) (Clock) (UplinkEth) (Arp)
module U = Udp.Make (I) (R) module U = Udp.Make (I) (R)
let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty
class client_iface eth ~domid ~gateway_ip ~client_ip client_mac : client_link class client_iface eth ~domid ~gateway_ip ~client_ip client_mac : client_link
= =
let log_header = Fmt.str "dom%d:%a" domid Ipaddr.V4.pp client_ip in let log_header = Fmt.str "dom%d:%a" domid Ipaddr.V4.pp client_ip in
@ -344,11 +342,12 @@ struct
(** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *) (** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *)
let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client dns_servers let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client dns_servers
~client_ip ~router ~cleanup_tasks qubesDB = ~client_ip ~router ~cleanup_tasks qubesDB () =
Netback.make ~domid ~device_id >>= fun backend -> let open Lwt.Syntax in
let* backend = Netback.make ~domid ~device_id in
Log.info (fun f -> Log.info (fun f ->
f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip)); f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip));
ClientEth.connect backend >>= fun eth -> let* eth = ClientEth.connect backend in
let client_mac = Netback.frontend_mac backend in let client_mac = Netback.frontend_mac backend in
let client_eth = router.clients in let client_eth = router.clients in
let gateway_ip = Client_eth.client_gw client_eth in let gateway_ip = Client_eth.client_gw client_eth in
@ -404,46 +403,54 @@ struct
(function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e) (function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e)
in in
Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel listener); Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel listener);
Lwt.pick [ qubesdb_updater; listener ] (* NOTE(dinosaure): [qubes_updater] and [listener] can be forgotten, our [cleanup_task]
will cancel them if the client is disconnected. *)
Lwt.async (fun () -> Lwt.pick [ qubesdb_updater; listener ]);
Lwt.return_unit
(** A new client VM has been found in XenStore. Find its interface and connect to it. *) (** A new client VM has been found in XenStore. Find its interface and connect to it. *)
let add_client get_ts dns_client dns_servers ~router vif client_ip qubesDB = let add_client get_ts dns_client dns_servers ~router vif client_ip qubesDB =
let open Lwt.Syntax in
let cleanup_tasks = Cleanup.create () in let cleanup_tasks = Cleanup.create () in
Log.info (fun f -> Log.info (fun f ->
f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp
client_ip); client_ip);
Lwt.async (fun () -> let* () =
Lwt.catch Lwt.catch (add_vif get_ts vif dns_client dns_servers ~client_ip ~router
(fun () ->
add_vif get_ts vif dns_client dns_servers ~client_ip ~router
~cleanup_tasks qubesDB) ~cleanup_tasks qubesDB)
(fun ex -> @@ fun exn ->
Log.warn (fun f -> Log.warn (fun f ->
f "Error with client %a: %s" Dao.ClientVif.pp vif f "Error with client %a: %s" Dao.ClientVif.pp vif
(Printexc.to_string ex)); (Printexc.to_string exn));
Lwt.return_unit)); Lwt.return_unit
cleanup_tasks in
Lwt.return cleanup_tasks
(** Watch XenStore for notifications of new clients. *) (** Watch XenStore for notifications of new clients. *)
let wait_clients get_ts dns_client dns_servers qubesDB router = let wait_clients get_ts dns_client dns_servers qubesDB router =
Dao.watch_clients (fun new_set -> let open Lwt.Syntax in
let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty in
Dao.watch_clients @@ fun new_set ->
(* Check for removed clients *) (* Check for removed clients *)
!clients let clean_up_clients key cleanup =
|> Dao.VifMap.iter (fun key cleanup -> if not (Dao.VifMap.mem key new_set) then begin
if not (Dao.VifMap.mem key new_set) then (
clients := !clients |> Dao.VifMap.remove key; clients := !clients |> Dao.VifMap.remove key;
Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key); Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key);
Cleanup.cleanup cleanup)); Cleanup.cleanup cleanup
(* Check for added clients *) end
new_set
|> Dao.VifMap.iter (fun key ip_addr ->
if not (Dao.VifMap.mem key !clients) then (
let cleanup =
add_client get_ts dns_client dns_servers ~router key ip_addr
qubesDB
in in
Dao.VifMap.iter clean_up_clients !clients;
(* Check for added clients *)
let rec go seq = match Seq.uncons seq with
| None -> Lwt.return_unit
| Some ((key, ipaddr), seq) when not (Dao.VifMap.mem key !clients) ->
let* cleanup = add_client get_ts dns_client dns_servers ~router key ipaddr qubesDB in
Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key); Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key);
clients := !clients |> Dao.VifMap.add key cleanup))) clients := Dao.VifMap.add key cleanup !clients;
go seq
| Some (_, seq) -> go seq
in
go (Dao.VifMap.to_seq new_set)
let send_dns_client_query t ~src_port ~dst ~dst_port buf = let send_dns_client_query t ~src_port ~dst ~dst_port buf =
match t.uplink with match t.uplink with

View File

@ -3,14 +3,6 @@
(** General utility functions. *) (** General utility functions. *)
module IpMap = struct
include Map.Make(Ipaddr.V4)
let find x map =
try Some (find x map)
with Not_found -> None
| _ -> Logs.err( fun f -> f "uncaught exception in find...%!"); None
end
(** An Ethernet interface. *) (** An Ethernet interface. *)
class type interface = object class type interface = object
method my_mac : Macaddr.t method my_mac : Macaddr.t

View File

@ -46,15 +46,12 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :
(* Main unikernel entry point (called from auto-generated main.ml). *) (* Main unikernel entry point (called from auto-generated main.ml). *)
let start _random _clock _time = let start _random _clock _time =
let open Lwt.Syntax in
let start_time = Clock.elapsed_ns () in let start_time = Clock.elapsed_ns () in
(* Start qrexec agent and QubesDB agent in parallel *) (* Start qrexec agent and QubesDB agent in parallel *)
let qrexec = RExec.connect ~domid:0 () in let* qrexec = RExec.connect ~domid:0 () in
let qubesDB = DB.connect ~domid:0 () in
(* Wait for clients to connect *)
qrexec >>= fun qrexec ->
let agent_listener = RExec.listen qrexec Command.handler in let agent_listener = RExec.listen qrexec Command.handler in
qubesDB >>= fun qubesDB -> let* qubesDB = DB.connect ~domid:0 () in
let startup_time = let startup_time =
let (-) = Int64.sub in let (-) = Int64.sub in
let time_in_ns = Clock.elapsed_ns () - start_time in let time_in_ns = Clock.elapsed_ns () - start_time in
@ -93,7 +90,7 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :
Dao.print_network_config config ; Dao.print_network_config config ;
(* Set up client-side networking *) (* Set up client-side networking *)
Client_eth.create config >>= fun clients -> let* clients = Client_eth.create config in
(* Set up routing between networks and hosts *) (* Set up routing between networks and hosts *)
let router = Dispatcher.create let router = Dispatcher.create