Merge pull request #197 from dinosaure/lint

Use Lwt.Syntax and avoid some >>= fun () patterns
This commit is contained in:
Pierre Alain 2024-10-16 14:19:17 +02:00 committed by GitHub
commit 56e66ca39a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 88 additions and 95 deletions

View File

@ -20,5 +20,5 @@ $builder build -t qubes-mirage-firewall .
echo Building Firewall...
$builder run --rm -i -v `pwd`:/tmp/orb-build:Z qubes-mirage-firewall
echo "SHA2 of build: $(sha256sum ./dist/qubes-firewall.xen)"
echo "SHA2 last known: 4b1f743bf4540bc8a9366cf8f23a78316e4f2d477af77962e50618753c4adf10"
echo "SHA2 last known: 2392386d9056b17a648f26b0c5d1c72b93f8a197964c670b2b45e71707727317"
echo "(hashes should match for released versions)"

View File

@ -8,7 +8,7 @@ let src = Logs.Src.create "client_eth" ~doc:"Ethernet networks for NetVM clients
module Log = (val Logs.src_log src : Logs.LOG)
type t = {
mutable iface_of_ip : client_link IpMap.t;
mutable iface_of_ip : client_link Ipaddr.V4.Map.t;
changed : unit Lwt_condition.t; (* Fires when [iface_of_ip] changes. *)
my_ip : Ipaddr.V4.t; (* The IP that clients are given as their default gateway. *)
}
@ -21,21 +21,21 @@ type host =
let create config =
let changed = Lwt_condition.create () in
let my_ip = config.Dao.our_ip in
Lwt.return { iface_of_ip = IpMap.empty; my_ip; changed }
Lwt.return { iface_of_ip = Ipaddr.V4.Map.empty; my_ip; changed }
let client_gw t = t.my_ip
let add_client t iface =
let ip = iface#other_ip in
let rec aux () =
match IpMap.find ip t.iface_of_ip with
match Ipaddr.V4.Map.find_opt ip t.iface_of_ip with
| Some old ->
(* Wait for old client to disappear before adding one with the same IP address.
Otherwise, its [remove_client] call will remove the new client instead. *)
Log.info (fun f -> f ~header:iface#log_header "Waiting for old client %s to go away before accepting new one" old#log_header);
Lwt_condition.wait t.changed >>= aux
| None ->
t.iface_of_ip <- t.iface_of_ip |> IpMap.add ip iface;
t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.add ip iface;
Lwt_condition.broadcast t.changed ();
Lwt.return_unit
in
@ -43,11 +43,11 @@ let add_client t iface =
let remove_client t iface =
let ip = iface#other_ip in
assert (IpMap.mem ip t.iface_of_ip);
t.iface_of_ip <- t.iface_of_ip |> IpMap.remove ip;
assert (Ipaddr.V4.Map.mem ip t.iface_of_ip);
t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.remove ip;
Lwt_condition.broadcast t.changed ()
let lookup t ip = IpMap.find ip t.iface_of_ip
let lookup t ip = Ipaddr.V4.Map.find_opt ip t.iface_of_ip
let classify t ip =
match ip with
@ -79,7 +79,7 @@ module ARP = struct
(* We're now treating client networks as point-to-point links,
so we no longer respond on behalf of other clients. *)
(*
else match IpMap.find ip t.net.iface_of_ip with
else match Ipaddr.V4.Map.find_opt ip t.net.iface_of_ip with
| Some client_iface -> Some client_iface#other_mac
| None -> None
*)

67
dao.ml
View File

@ -65,43 +65,40 @@ let read_rules rules client_ip =
number = 0;})]
let vifs client domid =
let open Lwt.Syntax in
match int_of_string_opt domid with
| None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return []
| Some domid ->
let path = Printf.sprintf "backend/vif/%d" domid in
Xen_os.Xs.immediate client (fun handle ->
directory ~handle path >>=
Lwt_list.filter_map_p (fun device_id ->
match int_of_string_opt device_id with
| None -> Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); Lwt.return_none
| Some device_id ->
let vif = { ClientVif.domid; device_id } in
Lwt.try_bind
(fun () -> Xen_os.Xs.read handle (Printf.sprintf "%s/%d/ip" path device_id))
(fun client_ip ->
let client_ip' = match String.split_on_char ' ' client_ip with
| [] -> Log.err (fun m -> m "unexpected empty list"); ""
| [ ip ] -> ip
| ip::rest ->
Log.warn (fun m -> m "ignoring IPs %s from %a, we support one IP per client"
(String.concat " " rest) ClientVif.pp vif);
ip
in
match Ipaddr.V4.of_string client_ip' with
| Ok ip -> Lwt.return (Some (vif, ip))
| Error `Msg msg ->
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ClientVif.pp vif client_ip msg);
Lwt.return None
)
(function
| Xs_protocol.Enoent _ -> Lwt.return None
| ex ->
Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string ex));
Lwt.return None
)
))
let path = Fmt.str "backend/vif/%d" domid in
let vifs_of_domain handle =
let* devices = directory ~handle path in
let ip_of_vif device_id = match int_of_string_opt device_id with
| None ->
Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid);
Lwt.return_none
| Some device_id ->
let vif = { ClientVif.domid; device_id } in
let get_client_ip () =
let* str = Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) in
let client_ip = List.hd (String.split_on_char ' ' str) in
(* NOTE(dinosaure): it's safe to use [List.hd] here,
[String.split_on_char] can not return an empty list. *)
Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip)
in
Lwt.catch get_client_ip @@ function
| Xs_protocol.Enoent _ -> Lwt.return_none
| Ipaddr.Parse_error (msg, client_ip) ->
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ClientVif.pp vif client_ip msg);
Lwt.return_none
| exn ->
Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string exn));
Lwt.return_none
in
Lwt_list.filter_map_p ip_of_vif devices
in
Xen_os.Xs.immediate client vifs_of_domain
let watch_clients fn =
Xen_os.Xs.make () >>= fun xs ->
@ -116,7 +113,7 @@ let watch_clients fn =
end >>= fun items ->
Xen_os.Xs.make () >>= fun xs ->
Lwt_list.map_p (vifs xs) items >>= fun items ->
fn (List.concat items |> VifMap.of_list);
fn (List.concat items |> VifMap.of_list) >>= fun () ->
(* Wait for further updates *)
Lwt.fail Xs_protocol.Eagain
)

View File

@ -15,7 +15,7 @@ module VifMap : sig
val find : key -> 'a t -> 'a option
end
val watch_clients : (Ipaddr.V4.t VifMap.t -> unit) -> 'a Lwt.t
val watch_clients : (Ipaddr.V4.t VifMap.t -> unit Lwt.t) -> 'a Lwt.t
(** [watch_clients fn] calls [fn clients] with the list of backend clients
in XenStore, and again each time XenStore updates. *)

View File

@ -17,8 +17,6 @@ struct
module I = Static_ipv4.Make (R) (Clock) (UplinkEth) (Arp)
module U = Udp.Make (I) (R)
let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty
class client_iface eth ~domid ~gateway_ip ~client_ip client_mac : client_link
=
let log_header = Fmt.str "dom%d:%a" domid Ipaddr.V4.pp client_ip in
@ -344,11 +342,12 @@ struct
(** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *)
let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client dns_servers
~client_ip ~router ~cleanup_tasks qubesDB =
Netback.make ~domid ~device_id >>= fun backend ->
~client_ip ~router ~cleanup_tasks qubesDB () =
let open Lwt.Syntax in
let* backend = Netback.make ~domid ~device_id in
Log.info (fun f ->
f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip));
ClientEth.connect backend >>= fun eth ->
let* eth = ClientEth.connect backend in
let client_mac = Netback.frontend_mac backend in
let client_eth = router.clients in
let gateway_ip = Client_eth.client_gw client_eth in
@ -404,46 +403,54 @@ struct
(function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e)
in
Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel listener);
Lwt.pick [ qubesdb_updater; listener ]
(* NOTE(dinosaure): [qubes_updater] and [listener] can be forgotten, our [cleanup_task]
will cancel them if the client is disconnected. *)
Lwt.async (fun () -> Lwt.pick [ qubesdb_updater; listener ]);
Lwt.return_unit
(** A new client VM has been found in XenStore. Find its interface and connect to it. *)
let add_client get_ts dns_client dns_servers ~router vif client_ip qubesDB =
let open Lwt.Syntax in
let cleanup_tasks = Cleanup.create () in
Log.info (fun f ->
f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp
client_ip);
Lwt.async (fun () ->
Lwt.catch
(fun () ->
add_vif get_ts vif dns_client dns_servers ~client_ip ~router
~cleanup_tasks qubesDB)
(fun ex ->
Log.warn (fun f ->
f "Error with client %a: %s" Dao.ClientVif.pp vif
(Printexc.to_string ex));
Lwt.return_unit));
cleanup_tasks
let* () =
Lwt.catch (add_vif get_ts vif dns_client dns_servers ~client_ip ~router
~cleanup_tasks qubesDB)
@@ fun exn ->
Log.warn (fun f ->
f "Error with client %a: %s" Dao.ClientVif.pp vif
(Printexc.to_string exn));
Lwt.return_unit
in
Lwt.return cleanup_tasks
(** Watch XenStore for notifications of new clients. *)
let wait_clients get_ts dns_client dns_servers qubesDB router =
Dao.watch_clients (fun new_set ->
(* Check for removed clients *)
!clients
|> Dao.VifMap.iter (fun key cleanup ->
if not (Dao.VifMap.mem key new_set) then (
clients := !clients |> Dao.VifMap.remove key;
Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key);
Cleanup.cleanup cleanup));
(* Check for added clients *)
new_set
|> Dao.VifMap.iter (fun key ip_addr ->
if not (Dao.VifMap.mem key !clients) then (
let cleanup =
add_client get_ts dns_client dns_servers ~router key ip_addr
qubesDB
in
Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key);
clients := !clients |> Dao.VifMap.add key cleanup)))
let open Lwt.Syntax in
let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty in
Dao.watch_clients @@ fun new_set ->
(* Check for removed clients *)
let clean_up_clients key cleanup =
if not (Dao.VifMap.mem key new_set) then begin
clients := !clients |> Dao.VifMap.remove key;
Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key);
Cleanup.cleanup cleanup
end
in
Dao.VifMap.iter clean_up_clients !clients;
(* Check for added clients *)
let rec go seq = match Seq.uncons seq with
| None -> Lwt.return_unit
| Some ((key, ipaddr), seq) when not (Dao.VifMap.mem key !clients) ->
let* cleanup = add_client get_ts dns_client dns_servers ~router key ipaddr qubesDB in
Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key);
clients := Dao.VifMap.add key cleanup !clients;
go seq
| Some (_, seq) -> go seq
in
go (Dao.VifMap.to_seq new_set)
let send_dns_client_query t ~src_port ~dst ~dst_port buf =
match t.uplink with

View File

@ -3,14 +3,6 @@
(** General utility functions. *)
module IpMap = struct
include Map.Make(Ipaddr.V4)
let find x map =
try Some (find x map)
with Not_found -> None
| _ -> Logs.err( fun f -> f "uncaught exception in find...%!"); None
end
(** An Ethernet interface. *)
class type interface = object
method my_mac : Macaddr.t

View File

@ -46,15 +46,12 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :
(* Main unikernel entry point (called from auto-generated main.ml). *)
let start _random _clock _time =
let open Lwt.Syntax in
let start_time = Clock.elapsed_ns () in
(* Start qrexec agent and QubesDB agent in parallel *)
let qrexec = RExec.connect ~domid:0 () in
let qubesDB = DB.connect ~domid:0 () in
(* Wait for clients to connect *)
qrexec >>= fun qrexec ->
let* qrexec = RExec.connect ~domid:0 () in
let agent_listener = RExec.listen qrexec Command.handler in
qubesDB >>= fun qubesDB ->
let* qubesDB = DB.connect ~domid:0 () in
let startup_time =
let (-) = Int64.sub in
let time_in_ns = Clock.elapsed_ns () - start_time in
@ -93,7 +90,7 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :
Dao.print_network_config config ;
(* Set up client-side networking *)
Client_eth.create config >>= fun clients ->
let* clients = Client_eth.create config in
(* Set up routing between networks and hosts *)
let router = Dispatcher.create