Use Lwt.Syntax and avoid some >>= fun () patterns

This commit is contained in:
Calascibetta Romain 2024-05-22 09:41:11 +02:00
parent 9058d25dcc
commit 412276cf4f
2 changed files with 30 additions and 41 deletions

60
dao.ml
View File

@ -65,43 +65,35 @@ let read_rules rules client_ip =
number = 0;})]
let vifs client domid =
let open Lwt.Syntax in
match int_of_string_opt domid with
| None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return []
| Some domid ->
let path = Printf.sprintf "backend/vif/%d" domid in
Xen_os.Xs.immediate client (fun handle ->
directory ~handle path >>=
Lwt_list.filter_map_p (fun device_id ->
match int_of_string_opt device_id with
| None -> Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); Lwt.return_none
| Some device_id ->
let vif = { ClientVif.domid; device_id } in
Lwt.try_bind
(fun () -> Xen_os.Xs.read handle (Printf.sprintf "%s/%d/ip" path device_id))
(fun client_ip ->
let client_ip' = match String.split_on_char ' ' client_ip with
| [] -> Log.err (fun m -> m "unexpected empty list"); ""
| [ ip ] -> ip
| ip::rest ->
Log.warn (fun m -> m "ignoring IPs %s from %a, we support one IP per client"
(String.concat " " rest) ClientVif.pp vif);
ip
in
match Ipaddr.V4.of_string client_ip' with
| Ok ip -> Lwt.return (Some (vif, ip))
| Error `Msg msg ->
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ClientVif.pp vif client_ip msg);
Lwt.return None
)
(function
| Xs_protocol.Enoent _ -> Lwt.return None
| ex ->
Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string ex));
Lwt.return None
)
))
let path = Fmt.str "backend/vif/%d" domid in
let fn handle =
let* entries = directory ~handle path in
let fn device_id = match int_of_string_opt device_id with
| None ->
Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid);
Lwt.return_none
| Some device_id ->
let vif = { ClientVif.domid; device_id } in
let fn () =
let* str = Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) in
let[@warning "-8"] client_ip :: _ = String.split_on_char ' ' str in
Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip) in
Lwt.catch fn @@ function
| Xs_protocol.Enoent _ -> Lwt.return_none
| Ipaddr.Parse_error (msg, client_ip) ->
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ClientVif.pp vif client_ip msg);
Lwt.return_none
| exn ->
Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string exn));
Lwt.return_none in
Lwt_list.filter_map_p fn entries in
Xen_os.Xs.immediate client fn
let watch_clients fn =
Xen_os.Xs.make () >>= fun xs ->

View File

@ -46,15 +46,12 @@ module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK)(Time : Mirage_tim
(* Main unikernel entry point (called from auto-generated main.ml). *)
let start _random _clock _time nat_table_size ipv4 ipv4_gw ipv4_dns ipv4_dns2 =
let open Lwt.Syntax in
let start_time = Clock.elapsed_ns () in
(* Start qrexec agent and QubesDB agent in parallel *)
let qrexec = RExec.connect ~domid:0 () in
let qubesDB = DB.connect ~domid:0 () in
(* Wait for clients to connect *)
qrexec >>= fun qrexec ->
let* qrexec = RExec.connect ~domid:0 () in
let agent_listener = RExec.listen qrexec Command.handler in
qubesDB >>= fun qubesDB ->
let* qubesDB = DB.connect ~domid:0 () in
let startup_time =
let (-) = Int64.sub in
let time_in_ns = Clock.elapsed_ns () - start_time in
@ -93,7 +90,7 @@ module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK)(Time : Mirage_tim
Dao.print_network_config config ;
(* Set up client-side networking *)
Client_eth.create config >>= fun clients ->
let* clients = Client_eth.create config in
(* Set up routing between networks and hosts *)
let router = Dispatcher.create