Use Lwt.Syntax and avoid some >>= fun () patterns

This commit is contained in:
Calascibetta Romain 2024-05-22 09:41:11 +02:00 committed by Hannes Mehnert
parent 8f739c610e
commit c7d8751b1c
2 changed files with 30 additions and 41 deletions

60
dao.ml
View File

@ -65,43 +65,35 @@ let read_rules rules client_ip =
number = 0;})] number = 0;})]
let vifs client domid = let vifs client domid =
let open Lwt.Syntax in
match int_of_string_opt domid with match int_of_string_opt domid with
| None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return [] | None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return []
| Some domid -> | Some domid ->
let path = Printf.sprintf "backend/vif/%d" domid in let path = Fmt.str "backend/vif/%d" domid in
Xen_os.Xs.immediate client (fun handle -> let fn handle =
directory ~handle path >>= let* entries = directory ~handle path in
Lwt_list.filter_map_p (fun device_id -> let fn device_id = match int_of_string_opt device_id with
match int_of_string_opt device_id with | None ->
| None -> Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); Lwt.return_none Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid);
| Some device_id -> Lwt.return_none
let vif = { ClientVif.domid; device_id } in | Some device_id ->
Lwt.try_bind let vif = { ClientVif.domid; device_id } in
(fun () -> Xen_os.Xs.read handle (Printf.sprintf "%s/%d/ip" path device_id)) let fn () =
(fun client_ip -> let* str = Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) in
let client_ip' = match String.split_on_char ' ' client_ip with let[@warning "-8"] client_ip :: _ = String.split_on_char ' ' str in
| [] -> Log.err (fun m -> m "unexpected empty list"); "" Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip) in
| [ ip ] -> ip Lwt.catch fn @@ function
| ip::rest -> | Xs_protocol.Enoent _ -> Lwt.return_none
Log.warn (fun m -> m "ignoring IPs %s from %a, we support one IP per client" | Ipaddr.Parse_error (msg, client_ip) ->
(String.concat " " rest) ClientVif.pp vif); Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ip ClientVif.pp vif client_ip msg);
in Lwt.return_none
match Ipaddr.V4.of_string client_ip' with | exn ->
| Ok ip -> Lwt.return (Some (vif, ip)) Log.err (fun f -> f "Error getting IP address of %a: %s"
| Error `Msg msg -> ClientVif.pp vif (Printexc.to_string exn));
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s" Lwt.return_none in
ClientVif.pp vif client_ip msg); Lwt_list.filter_map_p fn entries in
Lwt.return None Xen_os.Xs.immediate client fn
)
(function
| Xs_protocol.Enoent _ -> Lwt.return None
| ex ->
Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string ex));
Lwt.return None
)
))
let watch_clients fn = let watch_clients fn =
Xen_os.Xs.make () >>= fun xs -> Xen_os.Xs.make () >>= fun xs ->

View File

@ -46,15 +46,12 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :
(* Main unikernel entry point (called from auto-generated main.ml). *) (* Main unikernel entry point (called from auto-generated main.ml). *)
let start _random _clock _time = let start _random _clock _time =
let open Lwt.Syntax in
let start_time = Clock.elapsed_ns () in let start_time = Clock.elapsed_ns () in
(* Start qrexec agent and QubesDB agent in parallel *) (* Start qrexec agent and QubesDB agent in parallel *)
let qrexec = RExec.connect ~domid:0 () in let* qrexec = RExec.connect ~domid:0 () in
let qubesDB = DB.connect ~domid:0 () in
(* Wait for clients to connect *)
qrexec >>= fun qrexec ->
let agent_listener = RExec.listen qrexec Command.handler in let agent_listener = RExec.listen qrexec Command.handler in
qubesDB >>= fun qubesDB -> let* qubesDB = DB.connect ~domid:0 () in
let startup_time = let startup_time =
let (-) = Int64.sub in let (-) = Int64.sub in
let time_in_ns = Clock.elapsed_ns () - start_time in let time_in_ns = Clock.elapsed_ns () - start_time in
@ -93,7 +90,7 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :
Dao.print_network_config config ; Dao.print_network_config config ;
(* Set up client-side networking *) (* Set up client-side networking *)
Client_eth.create config >>= fun clients -> let* clients = Client_eth.create config in
(* Set up routing between networks and hosts *) (* Set up routing between networks and hosts *)
let router = Dispatcher.create let router = Dispatcher.create