diff --git a/client_net.ml b/client_net.ml index 5de5fa2..ca39938 100644 --- a/client_net.ml +++ b/client_net.ml @@ -33,7 +33,7 @@ class client_iface eth ~gateway_ip ~client_ip client_mac : client_link = object ) end -let clients : Cleanup.t IntMap.t ref = ref IntMap.empty +let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty (** Handle an ARP message from the client. *) let input_arp ~fixed_arp ~eth request = @@ -52,7 +52,7 @@ let input_ipv4 ~client_ip ~router frame packet = ) (** Connect to a new client's interface and listen for incoming frames. *) -let add_vif { Dao.domid; device_id; client_ip } ~router ~cleanup_tasks = +let add_vif { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanup_tasks = Netback.make ~domid ~device_id >>= fun backend -> Log.info (fun f -> f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip)); ClientEth.connect backend >>= or_fail "Can't make Ethernet device" >>= fun eth -> @@ -75,45 +75,37 @@ let add_vif { Dao.domid; device_id; client_ip } ~router ~cleanup_tasks = ) (** A new client VM has been found in XenStore. Find its interface and connect to it. *) -let add_client ~router domid = +let add_client ~router vif client_ip = let cleanup_tasks = Cleanup.create () in - Log.info (fun f -> f "add client domain %d" domid); + Log.info (fun f -> f "add client vif %a" Dao.ClientVif.pp vif); Lwt.async (fun () -> - Lwt.catch (fun () -> - Dao.client_vifs domid >>= function - | [] -> - Log.warn (fun f -> f "Client has no interfaces"); - return () - | vif :: others -> - if others <> [] then Log.warn (fun f -> f "Client has multiple interfaces; using first"); - add_vif vif ~router ~cleanup_tasks - ) - (fun ex -> - Log.warn (fun f -> f "Error connecting client domain %d: %s" - domid (Printexc.to_string ex)); - return () - ) - ); + Lwt.catch (fun () -> + add_vif vif ~client_ip ~router ~cleanup_tasks + ) + (fun ex -> + Log.warn (fun f -> f "Error connecting client %a: %s" + Dao.ClientVif.pp vif (Printexc.to_string ex)); + return () + ) + ); cleanup_tasks (** Watch XenStore for notifications of new clients. *) let listen router = - let backend_vifs = "backend/vif" in - Log.info (fun f -> f "Watching %s" backend_vifs); Dao.watch_clients (fun new_set -> (* Check for removed clients *) - !clients |> IntMap.iter (fun key cleanup -> - if not (IntSet.mem key new_set) then ( - clients := !clients |> IntMap.remove key; - Log.info (fun f -> f "client %d has gone" key); + !clients |> Dao.VifMap.iter (fun key cleanup -> + if not (Dao.VifMap.mem key new_set) then ( + clients := !clients |> Dao.VifMap.remove key; + Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key); Cleanup.cleanup cleanup ) ); (* Check for added clients *) - new_set |> IntSet.iter (fun key -> - if not (IntMap.mem key !clients) then ( - let cleanup = add_client ~router key in - clients := !clients |> IntMap.add key cleanup + new_set |> Dao.VifMap.iter (fun key ip_addr -> + if not (Dao.VifMap.mem key !clients) then ( + let cleanup = add_client ~router key ip_addr in + clients := !clients |> Dao.VifMap.add key cleanup ) ) ) diff --git a/dao.ml b/dao.ml index f0ab65b..dd22735 100644 --- a/dao.ml +++ b/dao.ml @@ -4,38 +4,75 @@ open Lwt.Infix open Utils open Qubes +open Astring -type client_vif = { - domid : int; - device_id : int; - client_ip : Ipaddr.V4.t; -} +let src = Logs.Src.create "dao" ~doc:"QubesDB data access" +module Log = (val Logs.src_log src : Logs.LOG) -let client_vifs domid = - let path = Printf.sprintf "backend/vif/%d" domid in - OS.Xs.make () >>= fun xs -> - OS.Xs.immediate xs (fun h -> - OS.Xs.directory h path >>= - Lwt_list.map_p (fun device_id -> - let device_id = int_of_string device_id in - OS.Xs.read h (Printf.sprintf "%s/%d/ip" path device_id) >|= fun client_ip -> - let client_ip = Ipaddr.V4.of_string_exn client_ip in - { domid; device_id; client_ip } - ) - ) +module ClientVif = struct + type t = { + domid : int; + device_id : int; + } + + let pp f { domid; device_id } = Fmt.pf f "{domid=%d;device_id=%d}" domid device_id + + let compare = compare +end +module VifMap = struct + include Map.Make(ClientVif) + let rec of_list = function + | [] -> empty + | (k, v) :: rest -> add k v (of_list rest) + let find key t = + try Some (find key t) + with Not_found -> None +end + +let directory ~handle dir = + OS.Xs.directory handle dir >|= function + | [""] -> [] (* XenStore client bug *) + | items -> items + +let vifs ~handle domid = + match String.to_int domid with + | None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return [] + | Some domid -> + let path = Printf.sprintf "backend/vif/%d" domid in + directory ~handle path >>= + Lwt_list.filter_map_p (fun device_id -> + match String.to_int device_id with + | None -> Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); Lwt.return_none + | Some device_id -> + let vif = { ClientVif.domid; device_id } in + Lwt.try_bind + (fun () -> OS.Xs.read handle (Printf.sprintf "%s/%d/ip" path device_id)) + (fun client_ip -> + let client_ip = Ipaddr.V4.of_string_exn client_ip in + Lwt.return (Some (vif, client_ip)) + ) + (function + | Xs_protocol.Enoent _ -> Lwt.return None + | ex -> + Log.err (fun f -> f "Error getting IP address of %a: %s" + ClientVif.pp vif (Printexc.to_string ex)); + Lwt.return None + ) + ) let watch_clients fn = OS.Xs.make () >>= fun xs -> let backend_vifs = "backend/vif" in + Log.info (fun f -> f "Watching %s" backend_vifs); OS.Xs.wait xs (fun handle -> begin Lwt.catch - (fun () -> OS.Xs.directory handle backend_vifs) + (fun () -> directory ~handle backend_vifs) (function | Xs_protocol.Enoent _ -> return [] | ex -> fail ex) end >>= fun items -> - let items = items |> List.fold_left (fun acc key -> IntSet.add (int_of_string key) acc) IntSet.empty in - fn items; + Lwt_list.map_p (vifs ~handle) items >>= fun items -> + fn (List.concat items |> VifMap.of_list); (* Wait for further updates *) fail Xs_protocol.Eagain ) diff --git a/dao.mli b/dao.mli index c0f2862..e1b96c6 100644 --- a/dao.mli +++ b/dao.mli @@ -3,20 +3,21 @@ (** Wrapper for XenStore and QubesDB databases. *) -open Utils +module ClientVif : sig + type t = { + domid : int; + device_id : int; + } + val pp : t Fmt.t +end +module VifMap : sig + include Map.S with type key = ClientVif.t + val find : key -> 'a t -> 'a option +end -type client_vif = { - domid : int; - device_id : int; - client_ip : Ipaddr.V4.t; -} - -val watch_clients : (IntSet.t -> unit) -> 'a Lwt.t -(** [watch_clients fn] calls [fn clients] with the current set of backend client domain IDs - in XenStore, and again each time the set changes. *) - -val client_vifs : int -> client_vif list Lwt.t -(** [client_vif domid] is the list of network interfaces to the client VM [domid]. *) +val watch_clients : (Ipaddr.V4.t VifMap.t -> unit) -> 'a Lwt.t +(** [watch_clients fn] calls [fn clients] with the list of backend clients + in XenStore, and again each time XenStore updates. *) type network_config = { uplink_netvm_ip : Ipaddr.V4.t; (* The IP address of NetVM (our gateway) *)