From 4d89b858922556c13a5be2f9797fc6159a791263 Mon Sep 17 00:00:00 2001 From: "Automated ocamlformat GitHub action, developed by robur.coop" Date: Tue, 18 Mar 2025 08:16:13 +0000 Subject: [PATCH] formatted code --- cleanup.ml | 4 +- cleanup.mli | 4 +- client_eth.ml | 115 +++-- client_eth.mli | 36 +- command.ml | 20 +- config.ml | 38 +- dao.ml | 193 ++++---- dao.mli | 35 +- dispatcher.ml | 1128 ++++++++++++++++++++++--------------------- fw_utils.ml | 12 +- memory_pressure.ml | 8 +- memory_pressure.mli | 4 +- my_dns.ml | 127 ++--- my_nat.ml | 72 ++- my_nat.mli | 22 +- packet.ml | 46 +- packet.mli | 30 +- rules.ml | 120 +++-- test/config.ml | 36 +- test/unikernel.ml | 460 +++++++++++------- unikernel.ml | 159 +++--- 21 files changed, 1433 insertions(+), 1236 deletions(-) diff --git a/cleanup.ml b/cleanup.ml index cbe9ebc..ecd3c78 100644 --- a/cleanup.ml +++ b/cleanup.ml @@ -4,9 +4,7 @@ type t = (unit -> unit) list ref let create () = ref [] - -let on_cleanup t fn = - t := fn :: !t +let on_cleanup t fn = t := fn :: !t let cleanup t = let tasks = !t in diff --git a/cleanup.mli b/cleanup.mli index d43661b..1358c07 100644 --- a/cleanup.mli +++ b/cleanup.mli @@ -1,8 +1,8 @@ (* Copyright (C) 2015, Thomas Leonard See the README file for details. *) -(** Register actions to take when a resource is finished. - Like [Lwt_switch], but synchronous. *) +(** Register actions to take when a resource is finished. Like [Lwt_switch], but + synchronous. *) type t diff --git a/client_eth.ml b/client_eth.ml index fc0b01a..bd9d931 100644 --- a/client_eth.ml +++ b/client_eth.ml @@ -4,19 +4,19 @@ open Fw_utils open Lwt.Infix -let src = Logs.Src.create "client_eth" ~doc:"Ethernet networks for NetVM clients" +let src = + Logs.Src.create "client_eth" ~doc:"Ethernet networks for NetVM clients" + module Log = (val Logs.src_log src : Logs.LOG) type t = { mutable iface_of_ip : client_link Ipaddr.V4.Map.t; - changed : unit Lwt_condition.t; (* Fires when [iface_of_ip] changes. *) - my_ip : Ipaddr.V4.t; (* The IP that clients are given as their default gateway. *) + changed : unit Lwt_condition.t; (* Fires when [iface_of_ip] changes. *) + my_ip : Ipaddr.V4.t; + (* The IP that clients are given as their default gateway. *) } -type host = - [ `Client of client_link - | `Firewall - | `External of Ipaddr.t ] +type host = [ `Client of client_link | `Firewall | `External of Ipaddr.t ] let create config = let changed = Lwt_condition.create () in @@ -30,14 +30,17 @@ let add_client t iface = let rec aux () = match Ipaddr.V4.Map.find_opt ip t.iface_of_ip with | Some old -> - (* Wait for old client to disappear before adding one with the same IP address. + (* Wait for old client to disappear before adding one with the same IP address. Otherwise, its [remove_client] call will remove the new client instead. *) - Log.info (fun f -> f ~header:iface#log_header "Waiting for old client %s to go away before accepting new one" old#log_header); - Lwt_condition.wait t.changed >>= aux + Log.info (fun f -> + f ~header:iface#log_header + "Waiting for old client %s to go away before accepting new one" + old#log_header); + Lwt_condition.wait t.changed >>= aux | None -> - t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.add ip iface; - Lwt_condition.broadcast t.changed (); - Lwt.return_unit + t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.add ip iface; + Lwt_condition.broadcast t.changed (); + Lwt.return_unit in aux () @@ -52,11 +55,12 @@ let lookup t ip = Ipaddr.V4.Map.find_opt ip t.iface_of_ip let classify t ip = match ip with | Ipaddr.V6 _ -> `External ip - | Ipaddr.V4 ip4 -> - if ip4 = t.my_ip then `Firewall - else match lookup t ip4 with - | Some client_link -> `Client client_link - | None -> `External ip + | Ipaddr.V4 ip4 -> ( + if ip4 = t.my_ip then `Firewall + else + match lookup t ip4 with + | Some client_link -> `Client client_link + | None -> `External ip) let resolve t : host -> Ipaddr.t = function | `Client client_link -> Ipaddr.V4 client_link#other_ip @@ -64,50 +68,53 @@ let resolve t : host -> Ipaddr.t = function | `External addr -> addr module ARP = struct - type arp = { - net : t; - client_link : client_link; - } + type arp = { net : t; client_link : client_link } let lookup t ip = if ip = t.net.my_ip then Some t.client_link#my_mac else if (Ipaddr.V4.to_octets ip).[3] = '\x01' then ( - Log.info (fun f -> f ~header:t.client_link#log_header - "Request for %a is invalid, but pretending it's me (see Qubes issue #5022)" Ipaddr.V4.pp ip); - Some t.client_link#my_mac - ) else None + Log.info (fun f -> + f ~header:t.client_link#log_header + "Request for %a is invalid, but pretending it's me (see Qubes \ + issue #5022)" + Ipaddr.V4.pp ip); + Some t.client_link#my_mac) + else None (* We're now treating client networks as point-to-point links, so we no longer respond on behalf of other clients. *) - (* + (* else match Ipaddr.V4.Map.find_opt ip t.net.iface_of_ip with | Some client_iface -> Some client_iface#other_mac | None -> None *) - let create ~net client_link = {net; client_link} + let create ~net client_link = { net; client_link } let input_query t arp = let req_ipv4 = arp.Arp_packet.target_ip in let pf (f : ?header:string -> ?tags:_ -> _) fmt = - f ~header:t.client_link#log_header ("who-has %a? " ^^ fmt) Ipaddr.V4.pp req_ipv4 + f ~header:t.client_link#log_header ("who-has %a? " ^^ fmt) Ipaddr.V4.pp + req_ipv4 in if req_ipv4 = t.client_link#other_ip then ( Log.info (fun f -> pf f "ignoring request for client's own IP"); - None - ) else match lookup t req_ipv4 with + None) + else + match lookup t req_ipv4 with | None -> - Log.info (fun f -> pf f "unknown address; not responding"); - None + Log.info (fun f -> pf f "unknown address; not responding"); + None | Some req_mac -> - Log.info (fun f -> pf f "responding with %a" Macaddr.pp req_mac); - Some { Arp_packet. - operation = Arp_packet.Reply; - (* The Target Hardware Address and IP are copied from the request *) - target_ip = arp.Arp_packet.source_ip; - target_mac = arp.Arp_packet.source_mac; - source_ip = req_ipv4; - source_mac = req_mac; - } + Log.info (fun f -> pf f "responding with %a" Macaddr.pp req_mac); + Some + { + Arp_packet.operation = Arp_packet.Reply; + (* The Target Hardware Address and IP are copied from the request *) + target_ip = arp.Arp_packet.source_ip; + target_mac = arp.Arp_packet.source_mac; + source_ip = req_ipv4; + source_mac = req_mac; + } let input_gratuitous t arp = let source_ip = arp.Arp_packet.source_ip in @@ -115,18 +122,28 @@ module ARP = struct let header = t.client_link#log_header in match lookup t source_ip with | Some real_mac when Macaddr.compare source_mac real_mac = 0 -> - Log.info (fun f -> f ~header "client suggests updating %s -> %s (as expected)" - (Ipaddr.V4.to_string source_ip) (Macaddr.to_string source_mac)); + Log.info (fun f -> + f ~header "client suggests updating %s -> %s (as expected)" + (Ipaddr.V4.to_string source_ip) + (Macaddr.to_string source_mac)) | Some other_mac -> - Log.warn (fun f -> f ~header "client suggests incorrect update %s -> %s (should be %s)" - (Ipaddr.V4.to_string source_ip) (Macaddr.to_string source_mac) (Macaddr.to_string other_mac)); + Log.warn (fun f -> + f ~header "client suggests incorrect update %s -> %s (should be %s)" + (Ipaddr.V4.to_string source_ip) + (Macaddr.to_string source_mac) + (Macaddr.to_string other_mac)) | None -> - Log.warn (fun f -> f ~header "client suggests incorrect update %s -> %s (unexpected IP)" - (Ipaddr.V4.to_string source_ip) (Macaddr.to_string source_mac)) + Log.warn (fun f -> + f ~header + "client suggests incorrect update %s -> %s (unexpected IP)" + (Ipaddr.V4.to_string source_ip) + (Macaddr.to_string source_mac)) let input t arp = let op = arp.Arp_packet.operation in match op with | Arp_packet.Request -> input_query t arp - | Arp_packet.Reply -> input_gratuitous t arp; None + | Arp_packet.Reply -> + input_gratuitous t arp; + None end diff --git a/client_eth.mli b/client_eth.mli index 02ccee9..d7ecb55 100644 --- a/client_eth.mli +++ b/client_eth.mli @@ -1,34 +1,32 @@ (* Copyright (C) 2016, Thomas Leonard See the README file for details. *) -(** The ethernet networks connecting us to our client AppVMs. - Note: each AppVM is on a point-to-point link, each link being considered to be a separate Ethernet network. *) +(** The ethernet networks connecting us to our client AppVMs. Note: each AppVM + is on a point-to-point link, each link being considered to be a separate + Ethernet network. *) open Fw_utils type t (** A collection of clients. *) -type host = - [ `Client of client_link - | `Firewall - | `External of Ipaddr.t ] +type host = [ `Client of client_link | `Firewall | `External of Ipaddr.t ] (* Note: Qubes does not allow us to distinguish between an external address and a disconnected client. See: https://github.com/talex5/qubes-mirage-firewall/issues/9#issuecomment-246956850 *) val create : Dao.network_config -> t Lwt.t -(** [create ~client_gw] is a network of client machines. - Qubes will have configured the client machines to use [client_gw] as their default gateway. *) +(** [create ~client_gw] is a network of client machines. Qubes will have + configured the client machines to use [client_gw] as their default gateway. +*) val add_client : t -> client_link -> unit Lwt.t -(** [add_client t client] registers a new client. If a client with this IP address is already registered, - it waits for [remove_client] to be called on that before adding the new client and returning. *) +(** [add_client t client] registers a new client. If a client with this IP + address is already registered, it waits for [remove_client] to be called on + that before adding the new client and returning. *) val remove_client : t -> client_link -> unit - val client_gw : t -> Ipaddr.V4.t - val classify : t -> Ipaddr.t -> host val resolve : t -> host -> Ipaddr.t @@ -36,18 +34,18 @@ val lookup : t -> Ipaddr.V4.t -> client_link option (** [lookup t addr] is the client with IP address [addr], if connected. *) module ARP : sig - (** We already know the correct mapping of IP addresses to MAC addresses, so we never - allow clients to update it. We log a warning if a client attempts to set incorrect - information. *) + (** We already know the correct mapping of IP addresses to MAC addresses, so + we never allow clients to update it. We log a warning if a client attempts + to set incorrect information. *) type arp (** An ARP-responder for one client. *) val create : net:t -> client_link -> arp - (** [create ~net client_link] is an ARP responder for [client_link]. - It answers only for the client's gateway address. *) + (** [create ~net client_link] is an ARP responder for [client_link]. It + answers only for the client's gateway address. *) val input : arp -> Arp_packet.t -> Arp_packet.t option - (** Process one ethernet frame containing an ARP message. - Returns a response frame, if one is needed. *) + (** Process one ethernet frame containing an ARP message. Returns a response + frame, if one is needed. *) end diff --git a/command.ml b/command.ml index da70727..0661bfc 100644 --- a/command.ml +++ b/command.ml @@ -4,24 +4,30 @@ (** Commands we provide via qvm-run. *) open Lwt - module Flow = Qubes.RExec.Flow let src = Logs.Src.create "command" ~doc:"qrexec command handler" + module Log = (val Logs.src_log src : Logs.LOG) let set_date_time flow = Flow.read_line flow >|= function - | `Eof -> Log.warn (fun f -> f "EOF reading time from dom0"); 1 - | `Ok line -> Log.info (fun f -> f "TODO: set time to %S" line); 0 + | `Eof -> + Log.warn (fun f -> f "EOF reading time from dom0"); + 1 + | `Ok line -> + Log.info (fun f -> f "TODO: set time to %S" line); + 0 let handler ~user:_ cmd flow = (* Write a message to the client and return an exit status of 1. *) let error fmt = - fmt |> Printf.ksprintf @@ fun s -> - Log.warn (fun f -> f "<< %s" s); - Flow.ewritef flow "%s [while processing %S]" s cmd >|= fun () -> 1 in + fmt + |> Printf.ksprintf @@ fun s -> + Log.warn (fun f -> f "<< %s" s); + Flow.ewritef flow "%s [while processing %S]" s cmd >|= fun () -> 1 + in match cmd with | "QUBESRPC qubes.SetDateTime dom0" -> set_date_time flow - | "QUBESRPC qubes.WaitForSession none" -> return 0 (* Always ready! *) + | "QUBESRPC qubes.WaitForSession none" -> return 0 (* Always ready! *) | cmd -> error "Unknown command %S" cmd diff --git a/config.ml b/config.ml index 5c06a4b..b663813 100644 --- a/config.ml +++ b/config.ml @@ -7,24 +7,24 @@ open Mirage let main = - main - ~packages:[ - package "vchan" ~min:"4.0.2"; - package "cstruct"; - package "tcpip" ~min:"3.7.0"; - package ~min:"2.3.0" ~sublibs:["mirage"] "arp"; - package ~min:"3.0.0" "ethernet"; - package "shared-memory-ring" ~min:"3.0.0"; - package "mirage-net-xen" ~min:"2.1.4"; - package "ipaddr" ~min:"5.2.0"; - package "mirage-qubes" ~min:"0.9.1"; - package ~min:"3.0.1" "mirage-nat"; - package "mirage-logs"; - package "mirage-xen" ~min:"8.0.0"; - package ~min:"6.4.0" "dns-client"; - package "pf-qubes"; - ] + main + ~packages: + [ + package "vchan" ~min:"4.0.2"; + package "cstruct"; + package "tcpip" ~min:"3.7.0"; + package ~min:"2.3.0" ~sublibs:[ "mirage" ] "arp"; + package ~min:"3.0.0" "ethernet"; + package "shared-memory-ring" ~min:"3.0.0"; + package "mirage-net-xen" ~min:"2.1.4"; + package "ipaddr" ~min:"5.2.0"; + package "mirage-qubes" ~min:"0.9.1"; + package ~min:"3.0.1" "mirage-nat"; + package "mirage-logs"; + package "mirage-xen" ~min:"8.0.0"; + package ~min:"6.4.0" "dns-client"; + package "pf-qubes"; + ] "Unikernel" job -let () = - register "qubes-firewall" [main] +let () = register "qubes-firewall" [ main ] diff --git a/dao.ml b/dao.ml index 9344c1f..9219fa6 100644 --- a/dao.ml +++ b/dao.ml @@ -5,35 +5,34 @@ open Lwt.Infix open Qubes let src = Logs.Src.create "dao" ~doc:"QubesDB data access" + module Log = (val Logs.src_log src : Logs.LOG) module ClientVif = struct - type t = { - domid : int; - device_id : int; - } + type t = { domid : int; device_id : int } - let pp f { domid; device_id } = Fmt.pf f "{domid=%d;device_id=%d}" domid device_id + let pp f { domid; device_id } = + Fmt.pf f "{domid=%d;device_id=%d}" domid device_id let compare = compare end + module VifMap = struct - include Map.Make(ClientVif) + include Map.Make (ClientVif) + let rec of_list = function | [] -> empty | (k, v) :: rest -> add k v (of_list rest) - let find key t = - try Some (find key t) - with Not_found -> None + + let find key t = try Some (find key t) with Not_found -> None end let directory ~handle dir = Xen_os.Xs.directory handle dir >|= function - | [""] -> [] (* XenStore client bug *) + | [ "" ] -> [] (* XenStore client bug *) | items -> items -let db_root client_ip = - "/qubes-firewall/" ^ (Ipaddr.V4.to_string client_ip) +let db_root client_ip = "/qubes-firewall/" ^ Ipaddr.V4.to_string client_ip let read_rules rules client_ip = let root = db_root client_ip in @@ -42,86 +41,101 @@ let read_rules rules client_ip = Log.debug (fun f -> f "reading %s" pattern); match Qubes.DB.KeyMap.find_opt pattern rules with | None -> - Log.debug (fun f -> f "rule %d does not exist; won't look for more" n); - Ok (List.rev l) - | Some rule -> - Log.debug (fun f -> f "rule %d: %s" n rule); - match Pf_qubes.Parse_qubes.parse_qubes ~number:n rule with - | Error e -> Log.warn (fun f -> f "Error parsing rule %d: %s" n e); Error e - | Ok rule -> - Log.debug (fun f -> f "parsed rule: %a" Pf_qubes.Parse_qubes.pp_rule rule); - get_rule (n+1) (rule :: l) + Log.debug (fun f -> f "rule %d does not exist; won't look for more" n); + Ok (List.rev l) + | Some rule -> ( + Log.debug (fun f -> f "rule %d: %s" n rule); + match Pf_qubes.Parse_qubes.parse_qubes ~number:n rule with + | Error e -> + Log.warn (fun f -> f "Error parsing rule %d: %s" n e); + Error e + | Ok rule -> + Log.debug (fun f -> + f "parsed rule: %a" Pf_qubes.Parse_qubes.pp_rule rule); + get_rule (n + 1) (rule :: l)) in match get_rule 0 [] with | Ok l -> l | Error e -> - Log.warn (fun f -> f "Defaulting to deny-all because of rule parse failure (%s)" e); - [ Pf_qubes.Parse_qubes.({action = Drop; - proto = None; - specialtarget = None; - dst = `any; - dstports = None; - icmp_type = None; - number = 0;})] + Log.warn (fun f -> + f "Defaulting to deny-all because of rule parse failure (%s)" e); + [ + Pf_qubes.Parse_qubes. + { + action = Drop; + proto = None; + specialtarget = None; + dst = `any; + dstports = None; + icmp_type = None; + number = 0; + }; + ] let vifs client domid = let open Lwt.Syntax in match int_of_string_opt domid with - | None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return [] + | None -> + Log.err (fun f -> f "Invalid domid %S" domid); + Lwt.return [] | Some domid -> - let path = Fmt.str "backend/vif/%d" domid in - let vifs_of_domain handle = - let* devices = directory ~handle path in - let ip_of_vif device_id = match int_of_string_opt device_id with - | None -> - Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); - Lwt.return_none - | Some device_id -> - let vif = { ClientVif.domid; device_id } in - let get_client_ip () = - let* str = Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) in - let client_ip = List.hd (String.split_on_char ' ' str) in - (* NOTE(dinosaure): it's safe to use [List.hd] here, + let path = Fmt.str "backend/vif/%d" domid in + let vifs_of_domain handle = + let* devices = directory ~handle path in + let ip_of_vif device_id = + match int_of_string_opt device_id with + | None -> + Log.err (fun f -> + f "Invalid device ID %S for domid %d" device_id domid); + Lwt.return_none + | Some device_id -> ( + let vif = { ClientVif.domid; device_id } in + let get_client_ip () = + let* str = + Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) + in + let client_ip = List.hd (String.split_on_char ' ' str) in + (* NOTE(dinosaure): it's safe to use [List.hd] here, [String.split_on_char] can not return an empty list. *) - Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip) - in - Lwt.catch get_client_ip @@ function - | Xs_protocol.Enoent _ -> Lwt.return_none - | Ipaddr.Parse_error (msg, client_ip) -> - Log.err (fun f -> f "Error parsing IP address of %a from %s: %s" - ClientVif.pp vif client_ip msg); - Lwt.return_none - | exn -> - Log.err (fun f -> f "Error getting IP address of %a: %s" - ClientVif.pp vif (Printexc.to_string exn)); - Lwt.return_none + Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip) + in + Lwt.catch get_client_ip @@ function + | Xs_protocol.Enoent _ -> Lwt.return_none + | Ipaddr.Parse_error (msg, client_ip) -> + Log.err (fun f -> + f "Error parsing IP address of %a from %s: %s" + ClientVif.pp vif client_ip msg); + Lwt.return_none + | exn -> + Log.err (fun f -> + f "Error getting IP address of %a: %s" ClientVif.pp vif + (Printexc.to_string exn)); + Lwt.return_none) + in + Lwt_list.filter_map_p ip_of_vif devices in - Lwt_list.filter_map_p ip_of_vif devices - in - Xen_os.Xs.immediate client vifs_of_domain + Xen_os.Xs.immediate client vifs_of_domain let watch_clients fn = Xen_os.Xs.make () >>= fun xs -> let backend_vifs = "backend/vif" in Log.info (fun f -> f "Watching %s" backend_vifs); Xen_os.Xs.wait xs (fun handle -> - begin Lwt.catch - (fun () -> directory ~handle backend_vifs) - (function - | Xs_protocol.Enoent _ -> Lwt.return [] - | ex -> Lwt.fail ex) - end >>= fun items -> - Xen_os.Xs.make () >>= fun xs -> - Lwt_list.map_p (vifs xs) items >>= fun items -> - fn (List.concat items |> VifMap.of_list) >>= fun () -> - (* Wait for further updates *) - Lwt.fail Xs_protocol.Eagain - ) + Lwt.catch + (fun () -> directory ~handle backend_vifs) + (function Xs_protocol.Enoent _ -> Lwt.return [] | ex -> Lwt.fail ex) + >>= fun items -> + Xen_os.Xs.make () >>= fun xs -> + Lwt_list.map_p (vifs xs) items >>= fun items -> + fn (List.concat items |> VifMap.of_list) >>= fun () -> + (* Wait for further updates *) + Lwt.fail Xs_protocol.Eagain) type network_config = { - from_cmdline : bool; (* Specify if we have network configuration from command line or from qubesDB*) - netvm_ip : Ipaddr.V4.t; (* The IP address of NetVM (our gateway) *) - our_ip : Ipaddr.V4.t; (* The IP address of our interface to NetVM *) + from_cmdline : bool; + (* Specify if we have network configuration from command line or from qubesDB*) + netvm_ip : Ipaddr.V4.t; (* The IP address of NetVM (our gateway) *) + our_ip : Ipaddr.V4.t; (* The IP address of our interface to NetVM *) dns : Ipaddr.V4.t; dns2 : Ipaddr.V4.t; } @@ -132,31 +146,36 @@ let try_read_network_config db = let get name = match DB.KeyMap.find_opt name db with | None -> raise (Missing_key name) - | Some value -> Ipaddr.V4.of_string_exn value in - let our_ip = get "/qubes-ip" in (* - IP address for this VM (only when VM has netvm set) *) - let netvm_ip = get "/qubes-gateway" in (* - default gateway IP (only when VM has netvm set); VM should add host route to this address directly via eth0 (or whatever default interface name is) *) + | Some value -> Ipaddr.V4.of_string_exn value + in + let our_ip = get "/qubes-ip" in + (* - IP address for this VM (only when VM has netvm set) *) + let netvm_ip = get "/qubes-gateway" in + (* - default gateway IP (only when VM has netvm set); VM should add host route to this address directly via eth0 (or whatever default interface name is) *) let dns = get "/qubes-primary-dns" in let dns2 = get "/qubes-secondary-dns" in - { from_cmdline=false; netvm_ip ; our_ip ; dns ; dns2 } + { from_cmdline = false; netvm_ip; our_ip; dns; dns2 } let read_network_config qubesDB = let rec aux bindings = try Lwt.return (try_read_network_config bindings) with Missing_key key -> - Log.warn (fun f -> f "QubesDB key %S not (yet) present; waiting for QubesDB to change..." key); + Log.warn (fun f -> + f "QubesDB key %S not (yet) present; waiting for QubesDB to change..." + key); DB.after qubesDB bindings >>= aux in aux (DB.bindings qubesDB) let print_network_config config = - Log.info (fun f -> f "@[Current network configuration (QubesDB or command line):@,\ - NetVM IP on uplink network: %a@,\ - Our IP on client networks: %a@,\ - DNS primary resolver: %a@,\ - DNS secondary resolver: %a@]" - Ipaddr.V4.pp config.netvm_ip - Ipaddr.V4.pp config.our_ip - Ipaddr.V4.pp config.dns - Ipaddr.V4.pp config.dns2) + Log.info (fun f -> + f + "@[Current network configuration (QubesDB or command line):@,\ + NetVM IP on uplink network: %a@,\ + Our IP on client networks: %a@,\ + DNS primary resolver: %a@,\ + DNS secondary resolver: %a@]" + Ipaddr.V4.pp config.netvm_ip Ipaddr.V4.pp config.our_ip Ipaddr.V4.pp + config.dns Ipaddr.V4.pp config.dns2) let set_iptables_error db = Qubes.DB.write db "/qubes-iptables-error" diff --git a/dao.mli b/dao.mli index c278d16..85f8912 100644 --- a/dao.mli +++ b/dao.mli @@ -4,40 +4,43 @@ (** Wrapper for XenStore and QubesDB databases. *) module ClientVif : sig - type t = { - domid : int; - device_id : int; - } + type t = { domid : int; device_id : int } + val pp : t Fmt.t end + module VifMap : sig include Map.S with type key = ClientVif.t + val find : key -> 'a t -> 'a option end val watch_clients : (Ipaddr.V4.t VifMap.t -> unit Lwt.t) -> 'a Lwt.t -(** [watch_clients fn] calls [fn clients] with the list of backend clients - in XenStore, and again each time XenStore updates. *) +(** [watch_clients fn] calls [fn clients] with the list of backend clients in + XenStore, and again each time XenStore updates. *) type network_config = { - from_cmdline : bool; (* Specify if we have network configuration from command line or from qubesDB*) - netvm_ip : Ipaddr.V4.t; (* The IP address of NetVM (our gateway) *) - our_ip : Ipaddr.V4.t; (* The IP address of our interface to NetVM *) + from_cmdline : bool; + (* Specify if we have network configuration from command line or from qubesDB*) + netvm_ip : Ipaddr.V4.t; (* The IP address of NetVM (our gateway) *) + our_ip : Ipaddr.V4.t; (* The IP address of our interface to NetVM *) dns : Ipaddr.V4.t; dns2 : Ipaddr.V4.t; } val read_network_config : Qubes.DB.t -> network_config Lwt.t -(** [read_network_config db] fetches the configuration from QubesDB. - If it isn't there yet, it waits until it is. *) +(** [read_network_config db] fetches the configuration from QubesDB. If it isn't + there yet, it waits until it is. *) val db_root : Ipaddr.V4.t -> string -(** Returns the root path of the firewall rules in the QubesDB for a given IP address. *) +(** Returns the root path of the firewall rules in the QubesDB for a given IP + address. *) -val read_rules : string Qubes.DB.KeyMap.t -> Ipaddr.V4.t -> Pf_qubes.Parse_qubes.rule list -(** [read_rules bindings ip] extracts firewall rule information for [ip] from [bindings]. - If any rules fail to parse, it will return only one rule denying all traffic. *) +val read_rules : + string Qubes.DB.KeyMap.t -> Ipaddr.V4.t -> Pf_qubes.Parse_qubes.rule list +(** [read_rules bindings ip] extracts firewall rule information for [ip] from + [bindings]. If any rules fail to parse, it will return only one rule denying + all traffic. *) val print_network_config : network_config -> unit - val set_iptables_error : Qubes.DB.t -> string -> unit Lwt.t diff --git a/dispatcher.ml b/dispatcher.ml index 9f6db7f..9d67f88 100644 --- a/dispatcher.ml +++ b/dispatcher.ml @@ -7,158 +7,161 @@ module UplinkEth = Ethernet.Make (Netif) let src = Logs.Src.create "dispatcher" ~doc:"Networking dispatch" module Log = (val Logs.src_log src : Logs.LOG) +module Arp = Arp.Make (UplinkEth) +module I = Static_ipv4.Make (UplinkEth) (Arp) +module U = Udp.Make (I) - module Arp = Arp.Make (UplinkEth) - module I = Static_ipv4.Make (UplinkEth) (Arp) - module U = Udp.Make (I) +class client_iface eth ~domid ~gateway_ip ~client_ip client_mac : client_link = + let log_header = Fmt.str "dom%d:%a" domid Ipaddr.V4.pp client_ip in + object + val mutable rules = [] + method get_rules = rules + method set_rules new_db = rules <- Dao.read_rules new_db client_ip + method my_mac = ClientEth.mac eth + method other_mac = client_mac + method my_ip = gateway_ip + method other_ip = client_ip - class client_iface eth ~domid ~gateway_ip ~client_ip client_mac : client_link - = - let log_header = Fmt.str "dom%d:%a" domid Ipaddr.V4.pp client_ip in - object - val mutable rules = [] - method get_rules = rules - method set_rules new_db = rules <- Dao.read_rules new_db client_ip - method my_mac = ClientEth.mac eth - method other_mac = client_mac - method my_ip = gateway_ip - method other_ip = client_ip + method writev proto fillfn = + Lwt.catch + (fun () -> + ClientEth.write eth client_mac proto fillfn >|= function + | Ok () -> () + | Error e -> + Log.err (fun f -> + f "error trying to send to client: @[%a@]" ClientEth.pp_error + e)) + (fun ex -> + (* Usually Netback_shutdown, because the client disconnected *) + Log.err (fun f -> + f "uncaught exception trying to send to client: @[%s@]" + (Printexc.to_string ex)); + Lwt.return_unit) - method writev proto fillfn = - Lwt.catch - (fun () -> - ClientEth.write eth client_mac proto fillfn >|= function - | Ok () -> () - | Error e -> - Log.err (fun f -> - f "error trying to send to client: @[%a@]" - ClientEth.pp_error e)) - (fun ex -> - (* Usually Netback_shutdown, because the client disconnected *) - Log.err (fun f -> - f "uncaught exception trying to send to client: @[%s@]" - (Printexc.to_string ex)); - Lwt.return_unit) + method log_header = log_header + end - method log_header = log_header - end +class netvm_iface eth mac ~my_ip ~other_ip : interface = + object + method my_mac = UplinkEth.mac eth + method my_ip = my_ip + method other_ip = other_ip - class netvm_iface eth mac ~my_ip ~other_ip : interface = - object - method my_mac = UplinkEth.mac eth - method my_ip = my_ip - method other_ip = other_ip + method writev ethertype fillfn = + Lwt.catch + (fun () -> + mac >>= fun dst -> + UplinkEth.write eth dst ethertype fillfn + >|= or_raise "Write to uplink" UplinkEth.pp_error) + (fun ex -> + Log.err (fun f -> + f "uncaught exception trying to send to uplink: @[%s@]" + (Printexc.to_string ex)); + Lwt.return_unit) + end - method writev ethertype fillfn = - Lwt.catch - (fun () -> - mac >>= fun dst -> - UplinkEth.write eth dst ethertype fillfn - >|= or_raise "Write to uplink" UplinkEth.pp_error) - (fun ex -> - Log.err (fun f -> - f "uncaught exception trying to send to uplink: @[%s@]" - (Printexc.to_string ex)); - Lwt.return_unit) - end +type uplink = { + net : Netif.t; + eth : UplinkEth.t; + arp : Arp.t; + interface : interface; + mutable fragments : Fragments.Cache.t; + ip : I.t; + udp : U.t; +} - type uplink = { - net : Netif.t; - eth : UplinkEth.t; - arp : Arp.t; - interface : interface; - mutable fragments : Fragments.Cache.t; - ip : I.t; - udp : U.t; +type t = { + uplink_connected : unit Lwt_condition.t; + uplink_disconnect : unit Lwt_condition.t; + uplink_disconnected : unit Lwt_condition.t; + mutable config : Dao.network_config; + clients : Client_eth.t; + nat : My_nat.t; + mutable uplink : uplink option; +} + +let create ~config ~clients ~nat ~uplink = + { + uplink_connected = Lwt_condition.create (); + uplink_disconnect = Lwt_condition.create (); + uplink_disconnected = Lwt_condition.create (); + config; + clients; + nat; + uplink; } - type t = { - uplink_connected : unit Lwt_condition.t; - uplink_disconnect : unit Lwt_condition.t; - uplink_disconnected : unit Lwt_condition.t; - mutable config : Dao.network_config; - clients : Client_eth.t; - nat : My_nat.t; - mutable uplink : uplink option; - } +let update t ~config ~uplink = + t.config <- config; + t.uplink <- uplink; + Lwt.return_unit - let create ~config ~clients ~nat ~uplink = - { - uplink_connected = Lwt_condition.create (); - uplink_disconnect = Lwt_condition.create (); - uplink_disconnected = Lwt_condition.create (); - config; - clients; - nat; - uplink; - } - - let update t ~config ~uplink = - t.config <- config; - t.uplink <- uplink; - Lwt.return_unit - - let target t buf = - let dst_ip = buf.Ipv4_packet.dst in - match Client_eth.lookup t.clients dst_ip with - | Some client_link -> Some (client_link :> interface) - | None -> ( (* if dest is not a client, transfer it to our uplink *) - match t.uplink with - | None -> ( - match Client_eth.lookup t.clients t.config.netvm_ip with - | Some uplink -> - Some (uplink :> interface) - | None -> - Log.err (fun f -> f "We have a command line configuration %a but it's currently not connected to us (please check its netvm property)...%!" Ipaddr.V4.pp t.config.netvm_ip); +let target t buf = + let dst_ip = buf.Ipv4_packet.dst in + match Client_eth.lookup t.clients dst_ip with + | Some client_link -> Some (client_link :> interface) + | None -> ( + (* if dest is not a client, transfer it to our uplink *) + match t.uplink with + | None -> ( + match Client_eth.lookup t.clients t.config.netvm_ip with + | Some uplink -> Some (uplink :> interface) + | None -> + Log.err (fun f -> + f + "We have a command line configuration %a but it's \ + currently not connected to us (please check its netvm \ + property)...%!" + Ipaddr.V4.pp t.config.netvm_ip); None) - | Some uplink -> Some uplink.interface) + | Some uplink -> Some uplink.interface) - let add_client t = Client_eth.add_client t.clients - let remove_client t = Client_eth.remove_client t.clients +let add_client t = Client_eth.add_client t.clients +let remove_client t = Client_eth.remove_client t.clients - let classify t ip = - if ip = Ipaddr.V4 t.config.our_ip then `Firewall - else if ip = Ipaddr.V4 t.config.netvm_ip then `NetVM - else (Client_eth.classify t.clients ip :> Packet.host) +let classify t ip = + if ip = Ipaddr.V4 t.config.our_ip then `Firewall + else if ip = Ipaddr.V4 t.config.netvm_ip then `NetVM + else (Client_eth.classify t.clients ip :> Packet.host) - let resolve t = function - | `Firewall -> Ipaddr.V4 t.config.our_ip - | `NetVM -> Ipaddr.V4 t.config.netvm_ip - | #Client_eth.host as host -> Client_eth.resolve t.clients host +let resolve t = function + | `Firewall -> Ipaddr.V4 t.config.our_ip + | `NetVM -> Ipaddr.V4 t.config.netvm_ip + | #Client_eth.host as host -> Client_eth.resolve t.clients host - (* Transmission *) +(* Transmission *) - let transmit_ipv4 packet iface = - Lwt.catch - (fun () -> - let fragments = ref [] in - iface#writev `IPv4 (fun b -> - match Nat_packet.into_cstruct packet b with - | Error e -> - Log.warn (fun f -> - f "Failed to write packet to %a: %a" Ipaddr.V4.pp - iface#other_ip Nat_packet.pp_error e); - 0 - | Ok (n, frags) -> - fragments := frags; - n) - >>= fun () -> - Lwt_list.iter_s - (fun f -> - let size = Cstruct.length f in - iface#writev `IPv4 (fun b -> - Cstruct.blit f 0 b 0 size; - size)) - !fragments) - (fun ex -> - Log.warn (fun f -> - f "Failed to write packet to %a: %s" Ipaddr.V4.pp iface#other_ip - (Printexc.to_string ex)); - Lwt.return_unit) +let transmit_ipv4 packet iface = + Lwt.catch + (fun () -> + let fragments = ref [] in + iface#writev `IPv4 (fun b -> + match Nat_packet.into_cstruct packet b with + | Error e -> + Log.warn (fun f -> + f "Failed to write packet to %a: %a" Ipaddr.V4.pp + iface#other_ip Nat_packet.pp_error e); + 0 + | Ok (n, frags) -> + fragments := frags; + n) + >>= fun () -> + Lwt_list.iter_s + (fun f -> + let size = Cstruct.length f in + iface#writev `IPv4 (fun b -> + Cstruct.blit f 0 b 0 size; + size)) + !fragments) + (fun ex -> + Log.warn (fun f -> + f "Failed to write packet to %a: %s" Ipaddr.V4.pp iface#other_ip + (Printexc.to_string ex)); + Lwt.return_unit) - let forward_ipv4 t packet = - let (`IPv4 (ip, _)) = packet in - Lwt.catch +let forward_ipv4 t packet = + let (`IPv4 (ip, _)) = packet in + Lwt.catch (fun () -> match target t ip with | Some iface -> transmit_ipv4 packet iface @@ -170,460 +173,463 @@ module Log = (val Logs.src_log src : Logs.LOG) (Printexc.to_string ex)); Lwt.return_unit) - (* NAT *) +(* NAT *) - let translate t packet = My_nat.translate t.nat packet +let translate t packet = My_nat.translate t.nat packet - (* Add a NAT rule for the endpoints in this frame, via a random port on the firewall. *) - let add_nat_and_forward_ipv4 t packet = - let xl_host = t.config.our_ip in - match My_nat.add_nat_rule_and_translate t.nat ~xl_host `NAT packet with - | Ok packet -> forward_ipv4 t packet - | Error e -> - Log.warn (fun f -> - f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet); - Lwt.return_unit +(* Add a NAT rule for the endpoints in this frame, via a random port on the firewall. *) +let add_nat_and_forward_ipv4 t packet = + let xl_host = t.config.our_ip in + match My_nat.add_nat_rule_and_translate t.nat ~xl_host `NAT packet with + | Ok packet -> forward_ipv4 t packet + | Error e -> + Log.warn (fun f -> + f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet); + Lwt.return_unit - (* Add a NAT rule to redirect this conversation to [host:port] instead of us. *) - let nat_to t ~host ~port packet = - match resolve t host with - | Ipaddr.V6 _ -> - Log.warn (fun f -> f "Cannot NAT with IPv6"); - Lwt.return_unit - | Ipaddr.V4 target -> ( - let xl_host = t.config.our_ip in - match - My_nat.add_nat_rule_and_translate t.nat ~xl_host - (`Redirect (target, port)) - packet - with - | Ok packet -> forward_ipv4 t packet - | Error e -> - Log.warn (fun f -> - f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp - packet); - Lwt.return_unit) +(* Add a NAT rule to redirect this conversation to [host:port] instead of us. *) +let nat_to t ~host ~port packet = + match resolve t host with + | Ipaddr.V6 _ -> + Log.warn (fun f -> f "Cannot NAT with IPv6"); + Lwt.return_unit + | Ipaddr.V4 target -> ( + let xl_host = t.config.our_ip in + match + My_nat.add_nat_rule_and_translate t.nat ~xl_host + (`Redirect (target, port)) + packet + with + | Ok packet -> forward_ipv4 t packet + | Error e -> + Log.warn (fun f -> + f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp + packet); + Lwt.return_unit) - let apply_rules t (rules : ('a, 'b) Packet.t -> Packet.action Lwt.t) ~dst - (annotated_packet : ('a, 'b) Packet.t) : unit Lwt.t = - let packet = Packet.to_mirage_nat_packet annotated_packet in - rules annotated_packet >>= fun action -> - match (action, dst) with - | `Accept, `Client client_link -> transmit_ipv4 packet client_link - | `Accept, (`External _ | `NetVM) -> ( - match t.uplink with - | Some uplink -> transmit_ipv4 packet uplink.interface - | None -> ( - match Client_eth.lookup t.clients t.config.netvm_ip with - | Some iface -> transmit_ipv4 packet iface - | None -> - Log.warn (fun f -> - f "No output interface for %a : drop" Nat_packet.pp packet); - Lwt.return_unit)) - | `Accept, `Firewall -> - Log.warn (fun f -> - f "Bad rule: firewall can't accept packets %a" Nat_packet.pp packet); - Lwt.return_unit - | `NAT, _ -> - Log.debug (fun f -> f "adding NAT rule for %a" Nat_packet.pp packet); - add_nat_and_forward_ipv4 t packet - | `NAT_to (host, port), _ -> nat_to t packet ~host ~port - | `Drop reason, _ -> - Log.debug (fun f -> - f "Dropped packet (%s) %a" reason Nat_packet.pp packet); - Lwt.return_unit +let apply_rules t (rules : ('a, 'b) Packet.t -> Packet.action Lwt.t) ~dst + (annotated_packet : ('a, 'b) Packet.t) : unit Lwt.t = + let packet = Packet.to_mirage_nat_packet annotated_packet in + rules annotated_packet >>= fun action -> + match (action, dst) with + | `Accept, `Client client_link -> transmit_ipv4 packet client_link + | `Accept, (`External _ | `NetVM) -> ( + match t.uplink with + | Some uplink -> transmit_ipv4 packet uplink.interface + | None -> ( + match Client_eth.lookup t.clients t.config.netvm_ip with + | Some iface -> transmit_ipv4 packet iface + | None -> + Log.warn (fun f -> + f "No output interface for %a : drop" Nat_packet.pp packet); + Lwt.return_unit)) + | `Accept, `Firewall -> + Log.warn (fun f -> + f "Bad rule: firewall can't accept packets %a" Nat_packet.pp packet); + Lwt.return_unit + | `NAT, _ -> + Log.debug (fun f -> f "adding NAT rule for %a" Nat_packet.pp packet); + add_nat_and_forward_ipv4 t packet + | `NAT_to (host, port), _ -> nat_to t packet ~host ~port + | `Drop reason, _ -> + Log.debug (fun f -> + f "Dropped packet (%s) %a" reason Nat_packet.pp packet); + Lwt.return_unit - let ipv4_from_netvm t packet = - match Memory_pressure.status () with - | `Memory_critical -> Lwt.return_unit - | `Ok -> ( - let (`IPv4 (ip, _transport)) = packet in - let src = classify t (Ipaddr.V4 ip.Ipv4_packet.src) in - let dst = classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in - match Packet.of_mirage_nat_packet ~src ~dst packet with - | None -> Lwt.return_unit - | Some _ -> ( - match src with - | `Client _ | `Firewall -> - Log.warn (fun f -> - f "Frame from NetVM has internal source IP address! %a" - Nat_packet.pp packet); - Lwt.return_unit - | (`External _ | `NetVM) as src -> ( - match translate t packet with - | Some frame -> forward_ipv4 t frame - | None -> ( - match Packet.of_mirage_nat_packet ~src ~dst packet with - | None -> Lwt.return_unit - | Some packet -> apply_rules t Rules.from_netvm ~dst packet) - ))) +let ipv4_from_netvm t packet = + match Memory_pressure.status () with + | `Memory_critical -> Lwt.return_unit + | `Ok -> ( + let (`IPv4 (ip, _transport)) = packet in + let src = classify t (Ipaddr.V4 ip.Ipv4_packet.src) in + let dst = classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in + match Packet.of_mirage_nat_packet ~src ~dst packet with + | None -> Lwt.return_unit + | Some _ -> ( + match src with + | `Client _ | `Firewall -> + Log.warn (fun f -> + f "Frame from NetVM has internal source IP address! %a" + Nat_packet.pp packet); + Lwt.return_unit + | (`External _ | `NetVM) as src -> ( + match translate t packet with + | Some frame -> forward_ipv4 t frame + | None -> ( + match Packet.of_mirage_nat_packet ~src ~dst packet with + | None -> Lwt.return_unit + | Some packet -> apply_rules t Rules.from_netvm ~dst packet))) + ) - let ipv4_from_client resolver dns_servers t ~src packet = - match Memory_pressure.status () with - | `Memory_critical -> Lwt.return_unit - | `Ok -> ( - (* Check for existing NAT entry for this packet *) - match translate t packet with - | Some frame -> - forward_ipv4 t frame (* Some existing connection or redirect *) - | None -> ( - (* No existing NAT entry. Check the firewall rules. *) - let (`IPv4 (ip, _transport)) = packet in - match classify t (Ipaddr.V4 ip.Ipv4_packet.src) with - | `Client _ | `Firewall -> ( - let dst = classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in - match - Packet.of_mirage_nat_packet ~src:(`Client src) ~dst packet - with - | None -> Lwt.return_unit - | Some firewall_packet -> - apply_rules t - (Rules.from_client resolver dns_servers) - ~dst firewall_packet) - | `NetVM -> ipv4_from_netvm t packet - | `External _ -> - Log.warn (fun f -> - f "Frame from Inside has external source IP address! %a" - Nat_packet.pp packet); - Lwt.return_unit)) +let ipv4_from_client resolver dns_servers t ~src packet = + match Memory_pressure.status () with + | `Memory_critical -> Lwt.return_unit + | `Ok -> ( + (* Check for existing NAT entry for this packet *) + match translate t packet with + | Some frame -> + forward_ipv4 t frame (* Some existing connection or redirect *) + | None -> ( + (* No existing NAT entry. Check the firewall rules. *) + let (`IPv4 (ip, _transport)) = packet in + match classify t (Ipaddr.V4 ip.Ipv4_packet.src) with + | `Client _ | `Firewall -> ( + let dst = classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in + match + Packet.of_mirage_nat_packet ~src:(`Client src) ~dst packet + with + | None -> Lwt.return_unit + | Some firewall_packet -> + apply_rules t + (Rules.from_client resolver dns_servers) + ~dst firewall_packet) + | `NetVM -> ipv4_from_netvm t packet + | `External _ -> + Log.warn (fun f -> + f "Frame from Inside has external source IP address! %a" + Nat_packet.pp packet); + Lwt.return_unit)) - (** Handle an ARP message from the client. *) - let client_handle_arp ~fixed_arp ~iface request = - match Arp_packet.decode request with - | Error e -> - Log.warn (fun f -> - f "Ignored unknown ARP message: %a" Arp_packet.pp_error e); - Lwt.return_unit - | Ok arp -> ( - match Client_eth.ARP.input fixed_arp arp with - | None -> Lwt.return_unit - | Some response -> +(** Handle an ARP message from the client. *) +let client_handle_arp ~fixed_arp ~iface request = + match Arp_packet.decode request with + | Error e -> + Log.warn (fun f -> + f "Ignored unknown ARP message: %a" Arp_packet.pp_error e); + Lwt.return_unit + | Ok arp -> ( + match Client_eth.ARP.input fixed_arp arp with + | None -> Lwt.return_unit + | Some response -> Lwt.catch (fun () -> - iface#writev `ARP (fun b -> - Arp_packet.encode_into response b; - Arp_packet.size)) + iface#writev `ARP (fun b -> + Arp_packet.encode_into response b; + Arp_packet.size)) (fun ex -> Log.warn (fun f -> f "Failed to write APR to %a: %s" Ipaddr.V4.pp iface#other_ip (Printexc.to_string ex)); - Lwt.return_unit) - ) + Lwt.return_unit)) - (** Handle an IPv4 packet from the client. *) - let client_handle_ipv4 get_ts cache ~iface ~router dns_client dns_servers - packet = - let cache', r = Nat_packet.of_ipv4_packet !cache ~now:(get_ts ()) packet in - cache := cache'; - match r with - | Error e -> +(** Handle an IPv4 packet from the client. *) +let client_handle_ipv4 get_ts cache ~iface ~router dns_client dns_servers packet + = + let cache', r = Nat_packet.of_ipv4_packet !cache ~now:(get_ts ()) packet in + cache := cache'; + match r with + | Error e -> + Log.warn (fun f -> + f "Ignored unknown IPv4 message: %a" Nat_packet.pp_error e); + Lwt.return_unit + | Ok None -> Lwt.return_unit + | Ok (Some packet) -> + let (`IPv4 (ip, _)) = packet in + let src = ip.Ipv4_packet.src in + if src = iface#other_ip then + ipv4_from_client dns_client dns_servers router ~src:iface packet + else if iface#other_ip = router.config.netvm_ip then + (* This can occurs when used with *BSD as netvm (and a gateway is set) *) + ipv4_from_netvm router packet + else ( Log.warn (fun f -> - f "Ignored unknown IPv4 message: %a" Nat_packet.pp_error e); - Lwt.return_unit - | Ok None -> Lwt.return_unit - | Ok (Some packet) -> - let (`IPv4 (ip, _)) = packet in - let src = ip.Ipv4_packet.src in - if src = iface#other_ip then - ipv4_from_client dns_client dns_servers router ~src:iface packet - else if iface#other_ip = router.config.netvm_ip then - (* This can occurs when used with *BSD as netvm (and a gateway is set) *) - ipv4_from_netvm router packet - else ( - Log.warn (fun f -> - f "Incorrect source IP %a in IP packet from %a (dropping)" - Ipaddr.V4.pp src Ipaddr.V4.pp iface#other_ip); - Lwt.return_unit) + f "Incorrect source IP %a in IP packet from %a (dropping)" + Ipaddr.V4.pp src Ipaddr.V4.pp iface#other_ip); + Lwt.return_unit) - (** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *) - let conf_vif get_ts vif backend client_eth dns_client dns_servers - ~client_ip ~iface ~router ~cleanup_tasks qubesDB () = - let { Dao.ClientVif.domid; device_id } = vif in - Log.info (fun f -> - f "Client %d:%d (IP: %s) ready" domid device_id (Ipaddr.V4.to_string client_ip)); +(** Connect to a new client's interface and listen for incoming frames and + firewall rule changes. *) +let conf_vif get_ts vif backend client_eth dns_client dns_servers ~client_ip + ~iface ~router ~cleanup_tasks qubesDB () = + let { Dao.ClientVif.domid; device_id } = vif in + Log.info (fun f -> + f "Client %d:%d (IP: %s) ready" domid device_id + (Ipaddr.V4.to_string client_ip)); - (* update the rules whenever QubesDB notices a change for this IP *) - let qubesdb_updater = - Lwt.catch - (fun () -> - let rec update current_db current_rules = - Qubes.DB.got_new_commit qubesDB (Dao.db_root client_ip) current_db - >>= fun new_db -> - iface#set_rules new_db; - let new_rules = iface#get_rules in - if current_rules = new_rules then - Log.info (fun m -> - m "Rules did not change for %s" - (Ipaddr.V4.to_string client_ip)) - else ( - Log.info (fun m -> - m "New firewall rules for %s@.%a" - (Ipaddr.V4.to_string client_ip) - Fmt.(list ~sep:(any "@.") Pf_qubes.Parse_qubes.pp_rule) - new_rules); - (* empty NAT table if rules are updated: they might deny old connections *) - My_nat.remove_connections router.nat client_ip); - update new_db new_rules - in - update Qubes.DB.KeyMap.empty []) - (function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e) - in - Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel qubesdb_updater); + (* update the rules whenever QubesDB notices a change for this IP *) + let qubesdb_updater = + Lwt.catch + (fun () -> + let rec update current_db current_rules = + Qubes.DB.got_new_commit qubesDB (Dao.db_root client_ip) current_db + >>= fun new_db -> + iface#set_rules new_db; + let new_rules = iface#get_rules in + if current_rules = new_rules then + Log.info (fun m -> + m "Rules did not change for %s" (Ipaddr.V4.to_string client_ip)) + else ( + Log.info (fun m -> + m "New firewall rules for %s@.%a" + (Ipaddr.V4.to_string client_ip) + Fmt.(list ~sep:(any "@.") Pf_qubes.Parse_qubes.pp_rule) + new_rules); + (* empty NAT table if rules are updated: they might deny old connections *) + My_nat.remove_connections router.nat client_ip); + update new_db new_rules + in + update Qubes.DB.KeyMap.empty []) + (function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e) + in + Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel qubesdb_updater); - let fixed_arp = Client_eth.ARP.create ~net:client_eth iface in - let fragment_cache = ref (Fragments.Cache.empty (256 * 1024)) in - let listener = - Lwt.catch - (fun () -> - Netback.listen backend ~header_size:Ethernet.Packet.sizeof_ethernet - (fun frame -> - match Ethernet.Packet.of_cstruct frame with - | Error err -> - Log.warn (fun f -> f "Invalid Ethernet frame: %s" err); - Lwt.return_unit - | Ok (eth, payload) -> ( - match eth.Ethernet.Packet.ethertype with - | `ARP -> client_handle_arp ~fixed_arp ~iface payload - | `IPv4 -> - client_handle_ipv4 get_ts fragment_cache ~iface ~router - dns_client dns_servers payload - | `IPv6 -> Lwt.return_unit (* TODO: oh no! *))) - >|= or_raise "Listen on client interface" Netback.pp_error) - (function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e) - in - Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel listener); - (* NOTE(dinosaure): [qubes_updater] and [listener] can be forgotten, our [cleanup_task] + let fixed_arp = Client_eth.ARP.create ~net:client_eth iface in + let fragment_cache = ref (Fragments.Cache.empty (256 * 1024)) in + let listener = + Lwt.catch + (fun () -> + Netback.listen backend ~header_size:Ethernet.Packet.sizeof_ethernet + (fun frame -> + match Ethernet.Packet.of_cstruct frame with + | Error err -> + Log.warn (fun f -> f "Invalid Ethernet frame: %s" err); + Lwt.return_unit + | Ok (eth, payload) -> ( + match eth.Ethernet.Packet.ethertype with + | `ARP -> client_handle_arp ~fixed_arp ~iface payload + | `IPv4 -> + client_handle_ipv4 get_ts fragment_cache ~iface ~router + dns_client dns_servers payload + | `IPv6 -> Lwt.return_unit (* TODO: oh no! *))) + >|= or_raise "Listen on client interface" Netback.pp_error) + (function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e) + in + Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel listener); + (* NOTE(dinosaure): [qubes_updater] and [listener] can be forgotten, our [cleanup_task] will cancel them if the client is disconnected. *) - Lwt.async (fun () -> Lwt.pick [ qubesdb_updater; listener ]); + Lwt.async (fun () -> Lwt.pick [ qubesdb_updater; listener ]); + Lwt.return_unit + +(** A new client VM has been found in XenStore. Find its interface and connect + to it. *) +let add_client get_ts dns_client dns_servers ~router vif client_ip qubesDB = + let open Lwt.Syntax in + let cleanup_tasks = Cleanup.create () in + Log.info (fun f -> + f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp + client_ip); + let { Dao.ClientVif.domid; device_id } = vif in + let* backend = Netback.make ~domid ~device_id in + let* eth = ClientEth.connect backend in + let client_mac = Netback.frontend_mac backend in + let client_eth = router.clients in + let gateway_ip = Client_eth.client_gw client_eth in + let iface = new client_iface eth ~domid ~gateway_ip ~client_ip client_mac in + + Cleanup.on_cleanup cleanup_tasks (fun () -> remove_client router iface); + Lwt.async (fun () -> + Lwt.catch + (fun () -> add_client router iface) + (fun ex -> + Log.warn (fun f -> + f "Error with client %a: %s" Dao.ClientVif.pp vif + (Printexc.to_string ex)); + Lwt.return_unit)); + + let* () = + Lwt.catch + (conf_vif get_ts vif backend client_eth dns_client dns_servers ~client_ip + ~iface ~router ~cleanup_tasks qubesDB) + @@ fun exn -> + Log.warn (fun f -> + f "Error with client %a: %s" Dao.ClientVif.pp vif + (Printexc.to_string exn)); Lwt.return_unit + in + Lwt.return cleanup_tasks - (** A new client VM has been found in XenStore. Find its interface and connect to it. *) - let add_client get_ts dns_client dns_servers ~router vif client_ip qubesDB = - let open Lwt.Syntax in - let cleanup_tasks = Cleanup.create () in - Log.info (fun f -> - f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp - client_ip); - let { Dao.ClientVif.domid; device_id } = vif in - let* backend = Netback.make ~domid ~device_id in - let* eth = ClientEth.connect backend in - let client_mac = Netback.frontend_mac backend in - let client_eth = router.clients in - let gateway_ip = Client_eth.client_gw client_eth in - let iface = new client_iface eth ~domid ~gateway_ip ~client_ip client_mac in - - Cleanup.on_cleanup cleanup_tasks (fun () -> remove_client router iface); - Lwt.async (fun () -> - Lwt.catch - (fun () -> - add_client router iface) - (fun ex -> - Log.warn (fun f -> - f "Error with client %a: %s" Dao.ClientVif.pp vif - (Printexc.to_string ex)); - Lwt.return_unit)) ; - - let* () = - Lwt.catch ( - conf_vif get_ts vif backend client_eth dns_client dns_servers ~client_ip ~iface ~router - ~cleanup_tasks qubesDB) - @@ fun exn -> - Log.warn (fun f -> - f "Error with client %a: %s" Dao.ClientVif.pp vif - (Printexc.to_string exn)); - Lwt.return_unit - in - Lwt.return cleanup_tasks - - (** Watch XenStore for notifications of new clients. *) - let wait_clients get_ts dns_client dns_servers qubesDB router = - let open Lwt.Syntax in - let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty in - Dao.watch_clients @@ fun new_set -> - (* Check for removed clients *) - let clean_up_clients key cleanup = - if not (Dao.VifMap.mem key new_set) then begin - clients := !clients |> Dao.VifMap.remove key; - Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key); - Cleanup.cleanup cleanup - end - in - Dao.VifMap.iter clean_up_clients !clients; - (* Check for added clients *) - let rec go seq = match Seq.uncons seq with - | None -> Lwt.return_unit - | Some ((key, ipaddr), seq) when not (Dao.VifMap.mem key !clients) -> - let* cleanup = add_client get_ts dns_client dns_servers ~router key ipaddr qubesDB in +(** Watch XenStore for notifications of new clients. *) +let wait_clients get_ts dns_client dns_servers qubesDB router = + let open Lwt.Syntax in + let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty in + Dao.watch_clients @@ fun new_set -> + (* Check for removed clients *) + let clean_up_clients key cleanup = + if not (Dao.VifMap.mem key new_set) then ( + clients := !clients |> Dao.VifMap.remove key; + Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key); + Cleanup.cleanup cleanup) + in + Dao.VifMap.iter clean_up_clients !clients; + (* Check for added clients *) + let rec go seq = + match Seq.uncons seq with + | None -> Lwt.return_unit + | Some ((key, ipaddr), seq) when not (Dao.VifMap.mem key !clients) -> + let* cleanup = + add_client get_ts dns_client dns_servers ~router key ipaddr qubesDB + in Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key); clients := Dao.VifMap.add key cleanup !clients; go seq - | Some (_, seq) -> go seq - in - go (Dao.VifMap.to_seq new_set) + | Some (_, seq) -> go seq + in + go (Dao.VifMap.to_seq new_set) - let send_dns_client_query t ~src_port ~dst ~dst_port buf = - match t.uplink with - | None -> - Log.err (fun f -> f "No uplink interface"); - Lwt.return (Error (`Msg "failure")) - | Some uplink -> ( - Lwt.catch - (fun () -> - U.write ~src_port ~dst ~dst_port uplink.udp (Cstruct.of_string buf) >|= function - | Error s -> - Log.err (fun f -> f "error sending udp packet: %a" U.pp_error s); - Error (`Msg "failure") - | Ok () -> Ok ()) - (fun ex -> - Log.err (fun f -> - f "uncaught exception trying to send DNS request to uplink: @[%s@]" - (Printexc.to_string ex)); - Lwt.return (Error (`Msg "DNS request not sent")))) - - (** Wait for packet from our uplink (we must have an uplink here...). *) - let rec uplink_listen get_ts dns_responses router = - Lwt_condition.wait router.uplink_connected >>= fun () -> - match router.uplink with - | None -> +let send_dns_client_query t ~src_port ~dst ~dst_port buf = + match t.uplink with + | None -> + Log.err (fun f -> f "No uplink interface"); + Lwt.return (Error (`Msg "failure")) + | Some uplink -> + Lwt.catch + (fun () -> + U.write ~src_port ~dst ~dst_port uplink.udp (Cstruct.of_string buf) + >|= function + | Error s -> + Log.err (fun f -> f "error sending udp packet: %a" U.pp_error s); + Error (`Msg "failure") + | Ok () -> Ok ()) + (fun ex -> Log.err (fun f -> f - "Uplink is connected but not found in the router, retrying...%!"); - uplink_listen get_ts dns_responses router - | Some uplink -> - let listen = - Lwt.catch - (fun () -> - Netif.listen uplink.net ~header_size:Ethernet.Packet.sizeof_ethernet - (fun frame -> - (* Handle one Ethernet frame from NetVM *) - UplinkEth.input uplink.eth ~arpv4:(Arp.input uplink.arp) - ~ipv4:(fun ip -> - let cache, r = - Nat_packet.of_ipv4_packet uplink.fragments ~now:(get_ts ()) - ip - in - uplink.fragments <- cache; - begin match r with - | Error e -> - Log.warn (fun f -> - f "Ignored unknown IPv4 message from uplink: %a" - Nat_packet.pp_error e); - Lwt.return () - | Ok None -> Lwt.return_unit - | Ok (Some (`IPv4 (header, packet))) -> - let open Udp_packet in - Log.debug (fun f -> - f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp - header.Ipv4_packet.src); - begin match packet with - | `UDP (header, packet) when My_nat.dns_port router.nat header.dst_port -> - Log.debug (fun f -> - f - "found a DNS packet whose dst_port (%d) was in the list of \ - dns_client ports" - header.dst_port); - Lwt_mvar.put dns_responses (header, Cstruct.to_string packet) - | _ -> ipv4_from_netvm router (`IPv4 (header, packet)) - end - end) - ~ipv6:(fun _ip -> Lwt.return_unit) - frame) - >|= or_raise "Uplink listen loop" Netif.pp_error) - (function Lwt.Canceled -> - (* We can be cancelled if reconnect_uplink is achieved (via the Lwt_condition), so we need to disconnect and broadcast when it's done - currently we delay 1s as Netif.disconnect is non-blocking... (need to fix upstream?) *) - Log.info (fun f -> - f "disconnecting from our uplink"); - U.disconnect uplink.udp >>= fun () -> - I.disconnect uplink.ip >>= fun () -> - (* mutable fragments : Fragments.Cache.t; *) - (* interface : interface; *) - Arp.disconnect uplink.arp >>= fun () -> - UplinkEth.disconnect uplink.eth >>= fun () -> - Netif.disconnect uplink.net >>= fun () -> - Lwt_condition.broadcast router.uplink_disconnected (); - Lwt.return_unit - | e -> Lwt.fail e) - in - let reconnect_uplink = - Lwt_condition.wait router.uplink_disconnect >>= fun () -> - Log.info (fun f -> - f "we need to reconnect to the new uplink"); - Lwt.return_unit - in - Lwt.pick [ listen ; reconnect_uplink ] >>= fun () -> - uplink_listen get_ts dns_responses router + "uncaught exception trying to send DNS request to uplink: \ + @[%s@]" + (Printexc.to_string ex)); + Lwt.return (Error (`Msg "DNS request not sent"))) - (** Connect to our uplink backend (we must have an uplink here...). *) - let connect config = - let my_ip = config.Dao.our_ip in - let gateway = config.Dao.netvm_ip in - Netif.connect "0" >>= fun net -> - UplinkEth.connect net >>= fun eth -> - Arp.connect eth >>= fun arp -> - Arp.add_ip arp my_ip >>= fun () -> - let cidr = Ipaddr.V4.Prefix.make 0 my_ip in - I.connect ~cidr ~gateway eth arp >>= fun ip -> - U.connect ip >>= fun udp -> - let netvm_mac = - Arp.query arp gateway >>= function - | Error e -> - Log.err(fun f -> f "Getting MAC of our NetVM: %a" Arp.pp_error e); - (* This mac address is a special address used by Qubes when the device +(** Wait for packet from our uplink (we must have an uplink here...). *) +let rec uplink_listen get_ts dns_responses router = + Lwt_condition.wait router.uplink_connected >>= fun () -> + match router.uplink with + | None -> + Log.err (fun f -> + f "Uplink is connected but not found in the router, retrying...%!"); + uplink_listen get_ts dns_responses router + | Some uplink -> + let listen = + Lwt.catch + (fun () -> + Netif.listen uplink.net ~header_size:Ethernet.Packet.sizeof_ethernet + (fun frame -> + (* Handle one Ethernet frame from NetVM *) + UplinkEth.input uplink.eth ~arpv4:(Arp.input uplink.arp) + ~ipv4:(fun ip -> + let cache, r = + Nat_packet.of_ipv4_packet uplink.fragments + ~now:(get_ts ()) ip + in + uplink.fragments <- cache; + match r with + | Error e -> + Log.warn (fun f -> + f "Ignored unknown IPv4 message from uplink: %a" + Nat_packet.pp_error e); + Lwt.return () + | Ok None -> Lwt.return_unit + | Ok (Some (`IPv4 (header, packet))) -> ( + let open Udp_packet in + Log.debug (fun f -> + f "received ipv4 packet from %a on uplink" + Ipaddr.V4.pp header.Ipv4_packet.src); + match packet with + | `UDP (header, packet) + when My_nat.dns_port router.nat header.dst_port -> + Log.debug (fun f -> + f + "found a DNS packet whose dst_port (%d) was \ + in the list of dns_client ports" + header.dst_port); + Lwt_mvar.put dns_responses + (header, Cstruct.to_string packet) + | _ -> ipv4_from_netvm router (`IPv4 (header, packet)))) + ~ipv6:(fun _ip -> Lwt.return_unit) + frame) + >|= or_raise "Uplink listen loop" Netif.pp_error) + (function + | Lwt.Canceled -> + (* We can be cancelled if reconnect_uplink is achieved (via the Lwt_condition), so we need to disconnect and broadcast when it's done + currently we delay 1s as Netif.disconnect is non-blocking... (need to fix upstream?) *) + Log.info (fun f -> f "disconnecting from our uplink"); + U.disconnect uplink.udp >>= fun () -> + I.disconnect uplink.ip >>= fun () -> + (* mutable fragments : Fragments.Cache.t; *) + (* interface : interface; *) + Arp.disconnect uplink.arp >>= fun () -> + UplinkEth.disconnect uplink.eth >>= fun () -> + Netif.disconnect uplink.net >>= fun () -> + Lwt_condition.broadcast router.uplink_disconnected (); + Lwt.return_unit + | e -> Lwt.fail e) + in + let reconnect_uplink = + Lwt_condition.wait router.uplink_disconnect >>= fun () -> + Log.info (fun f -> f "we need to reconnect to the new uplink"); + Lwt.return_unit + in + Lwt.pick [ listen; reconnect_uplink ] >>= fun () -> + uplink_listen get_ts dns_responses router + +(** Connect to our uplink backend (we must have an uplink here...). *) +let connect config = + let my_ip = config.Dao.our_ip in + let gateway = config.Dao.netvm_ip in + Netif.connect "0" >>= fun net -> + UplinkEth.connect net >>= fun eth -> + Arp.connect eth >>= fun arp -> + Arp.add_ip arp my_ip >>= fun () -> + let cidr = Ipaddr.V4.Prefix.make 0 my_ip in + I.connect ~cidr ~gateway eth arp >>= fun ip -> + U.connect ip >>= fun udp -> + let netvm_mac = + Arp.query arp gateway >>= function + | Error e -> + Log.err (fun f -> f "Getting MAC of our NetVM: %a" Arp.pp_error e); + (* This mac address is a special address used by Qubes when the device is not managed by Qubes itself. This can occurs inside a service AppVM (e.g. VPN) when the service creates a new interface. *) - Lwt.return (Macaddr.of_string_exn "fe:ff:ff:ff:ff:ff") - | Ok mac -> Lwt.return mac - in - let interface = - new netvm_iface eth netvm_mac ~my_ip ~other_ip:config.Dao.netvm_ip - in - let fragments = Fragments.Cache.empty (256 * 1024) in - Lwt.return { net; eth; arp; interface; fragments; ip; udp } + Lwt.return (Macaddr.of_string_exn "fe:ff:ff:ff:ff:ff") + | Ok mac -> Lwt.return mac + in + let interface = + new netvm_iface eth netvm_mac ~my_ip ~other_ip:config.Dao.netvm_ip + in + let fragments = Fragments.Cache.empty (256 * 1024) in + Lwt.return { net; eth; arp; interface; fragments; ip; udp } - (** Wait Xenstore for our uplink changes (we must have an uplink here...). *) - let uplink_wait_update qubesDB router = - let rec aux current_db = - let netvm = "/qubes-gateway" in - Log.info (fun f -> f "Waiting for netvm changes to %S...%!" netvm); - Qubes.DB.after qubesDB current_db >>= fun new_db -> - (match (router.uplink, Qubes.DB.KeyMap.find_opt netvm new_db) with - | Some uplink, Some netvm - when not - (String.equal netvm - (Ipaddr.V4.to_string uplink.interface#other_ip)) -> - Log.info (fun f -> - f "Our netvm IP has changed, before it was %s, now it's: %s%!" - (Ipaddr.V4.to_string uplink.interface#other_ip) - netvm); - Lwt_condition.broadcast router.uplink_disconnect (); - (* wait for uplink disconnexion *) - Lwt_condition.wait router.uplink_disconnected >>= fun () -> - Dao.read_network_config qubesDB >>= fun config -> - Dao.print_network_config config; - connect config >>= fun uplink -> - update router ~config ~uplink:(Some uplink) >>= fun () -> - Lwt_condition.broadcast router.uplink_connected (); - Lwt.return_unit - | None, Some _ -> - (* a new interface is attributed to qubes-mirage-firewall *) - Log.info (fun f -> f "Going from netvm not connected to %s%!" netvm); - Dao.read_network_config qubesDB >>= fun config -> - Dao.print_network_config config; - connect config >>= fun uplink -> - update router ~config ~uplink:(Some uplink) >>= fun () -> - Lwt_condition.broadcast router.uplink_connected (); - Lwt.return_unit - | Some _, None -> - (* This currently is never triggered :( *) - Log.info (fun f -> - f "TODO: Our netvm disapeared, troubles are coming!%!"); - Lwt.return_unit - | Some _, Some _ (* The new netvm IP is unchanged (it's our old netvm IP) *) - | None, None -> - Log.info (fun f -> - f "QubesDB has changed but not the situation of our netvm!%!"); - Lwt.return_unit) - >>= fun () -> aux new_db - in - aux Qubes.DB.KeyMap.empty +(** Wait Xenstore for our uplink changes (we must have an uplink here...). *) +let uplink_wait_update qubesDB router = + let rec aux current_db = + let netvm = "/qubes-gateway" in + Log.info (fun f -> f "Waiting for netvm changes to %S...%!" netvm); + Qubes.DB.after qubesDB current_db >>= fun new_db -> + (match (router.uplink, Qubes.DB.KeyMap.find_opt netvm new_db) with + | Some uplink, Some netvm + when not + (String.equal netvm + (Ipaddr.V4.to_string uplink.interface#other_ip)) -> + Log.info (fun f -> + f "Our netvm IP has changed, before it was %s, now it's: %s%!" + (Ipaddr.V4.to_string uplink.interface#other_ip) + netvm); + Lwt_condition.broadcast router.uplink_disconnect (); + (* wait for uplink disconnexion *) + Lwt_condition.wait router.uplink_disconnected >>= fun () -> + Dao.read_network_config qubesDB >>= fun config -> + Dao.print_network_config config; + connect config >>= fun uplink -> + update router ~config ~uplink:(Some uplink) >>= fun () -> + Lwt_condition.broadcast router.uplink_connected (); + Lwt.return_unit + | None, Some _ -> + (* a new interface is attributed to qubes-mirage-firewall *) + Log.info (fun f -> f "Going from netvm not connected to %s%!" netvm); + Dao.read_network_config qubesDB >>= fun config -> + Dao.print_network_config config; + connect config >>= fun uplink -> + update router ~config ~uplink:(Some uplink) >>= fun () -> + Lwt_condition.broadcast router.uplink_connected (); + Lwt.return_unit + | Some _, None -> + (* This currently is never triggered :( *) + Log.info (fun f -> + f "TODO: Our netvm disapeared, troubles are coming!%!"); + Lwt.return_unit + | Some _, Some _ (* The new netvm IP is unchanged (it's our old netvm IP) *) + | None, None -> + Log.info (fun f -> + f "QubesDB has changed but not the situation of our netvm!%!"); + Lwt.return_unit) + >>= fun () -> aux new_db + in + aux Qubes.DB.KeyMap.empty diff --git a/fw_utils.ml b/fw_utils.ml index f20c63a..53fddb0 100644 --- a/fw_utils.ml +++ b/fw_utils.ml @@ -15,14 +15,16 @@ end class type client_link = object inherit interface method other_mac : Macaddr.t - method log_header : string (* For log messages *) - method get_rules: Pf_qubes.Parse_qubes.rule list - method set_rules: string Qubes.DB.KeyMap.t -> unit + method log_header : string (* For log messages *) + method get_rules : Pf_qubes.Parse_qubes.rule list + method set_rules : string Qubes.DB.KeyMap.t -> unit end -(** An Ethernet header from [src]'s MAC address to [dst]'s with an IPv4 payload. *) +(** An Ethernet header from [src]'s MAC address to [dst]'s with an IPv4 payload. +*) let eth_header ethertype ~src ~dst = - Ethernet.Packet.make_cstruct { Ethernet.Packet.source = src; destination = dst; ethertype } + Ethernet.Packet.make_cstruct + { Ethernet.Packet.source = src; destination = dst; ethertype } let error fmt = let err s = Failure s in diff --git a/memory_pressure.ml b/memory_pressure.ml index 667bd50..fe04bca 100644 --- a/memory_pressure.ml +++ b/memory_pressure.ml @@ -2,14 +2,14 @@ See the README file for details. *) let src = Logs.Src.create "memory_pressure" ~doc:"Memory pressure monitor" + module Log = (val Logs.src_log src : Logs.LOG) let fraction_free stats = let { Xen_os.Memory.free_words; heap_words; _ } = stats in float free_words /. float heap_words -let init () = - Gc.full_major () +let init () = Gc.full_major () let status () = let stats = Xen_os.Memory.quick_stat () in @@ -18,6 +18,4 @@ let status () = Gc.full_major (); Xen_os.Memory.trim (); let stats = Xen_os.Memory.quick_stat () in - if fraction_free stats < 0.6 then `Memory_critical - else `Ok - ) + if fraction_free stats < 0.6 then `Memory_critical else `Ok) diff --git a/memory_pressure.mli b/memory_pressure.mli index c0d9f49..f0d7df8 100644 --- a/memory_pressure.mli +++ b/memory_pressure.mli @@ -8,5 +8,5 @@ val status : unit -> [ `Ok | `Memory_critical ] (** Check the memory situation. If we're running low, do a GC (work-around for http://caml.inria.fr/mantis/view.php?id=7100 and OCaml GC needing to malloc extra space to run finalisers). Returns [`Memory_critical] if memory is - still low - caller should take action to reduce memory use. - After GC, updates meminfo in XenStore. *) + still low - caller should take action to reduce memory use. After GC, + updates meminfo in XenStore. *) diff --git a/my_dns.ml b/my_dns.ml index 6000e80..e3bb267 100644 --- a/my_dns.ml +++ b/my_dns.ml @@ -1,72 +1,81 @@ open Lwt.Infix - type +'a io = 'a Lwt.t - type io_addr = Ipaddr.V4.t * int - type stack = Dispatcher.t * - (src_port:int -> dst:Ipaddr.V4.t -> dst_port:int -> string -> (unit, [ `Msg of string ]) result Lwt.t) * - (Udp_packet.t * string) Lwt_mvar.t +type +'a io = 'a Lwt.t +type io_addr = Ipaddr.V4.t * int - module IM = Map.Make(Int) +type stack = + Dispatcher.t + * (src_port:int -> + dst:Ipaddr.V4.t -> + dst_port:int -> + string -> + (unit, [ `Msg of string ]) result Lwt.t) + * (Udp_packet.t * string) Lwt_mvar.t - type t = { - protocol : Dns.proto ; - nameserver : io_addr ; - stack : stack ; - timeout_ns : int64 ; - mutable requests : string Lwt_condition.t IM.t ; - } - type context = t +module IM = Map.Make (Int) - let nameservers { protocol ; nameserver ; _ } = protocol, [ nameserver ] - let rng = Mirage_crypto_rng.generate ?g:None - let clock = Mirage_mtime.elapsed_ns +type t = { + protocol : Dns.proto; + nameserver : io_addr; + stack : stack; + timeout_ns : int64; + mutable requests : string Lwt_condition.t IM.t; +} - let rec read t = - let _, _, answer = t.stack in - Lwt_mvar.take answer >>= fun (_, data) -> - if String.length data > 2 then begin - match IM.find_opt (String.get_uint16_be data 0) t.requests with - | Some cond -> Lwt_condition.broadcast cond data - | None -> () - end; - read t +type context = t - let create ?nameservers ~timeout stack = - let protocol, nameserver = match nameservers with - | None | Some (_, []) -> invalid_arg "no nameserver found" - | Some (proto, ns :: _) -> proto, ns - in - let t = - { protocol ; nameserver ; stack ; timeout_ns = timeout ; requests = IM.empty } - in - Lwt.async (fun () -> read t); - t +let nameservers { protocol; nameserver; _ } = (protocol, [ nameserver ]) +let rng = Mirage_crypto_rng.generate ?g:None +let clock = Mirage_mtime.elapsed_ns - let with_timeout timeout_ns f = - let timeout = Mirage_sleep.ns timeout_ns >|= fun () -> Error (`Msg "DNS request timeout") in - Lwt.pick [ f ; timeout ] +let rec read t = + let _, _, answer = t.stack in + Lwt_mvar.take answer >>= fun (_, data) -> + (if String.length data > 2 then + match IM.find_opt (String.get_uint16_be data 0) t.requests with + | Some cond -> Lwt_condition.broadcast cond data + | None -> ()); + read t - let connect (t : t) = Lwt.return (Ok (t.protocol, t)) +let create ?nameservers ~timeout stack = + let protocol, nameserver = + match nameservers with + | None | Some (_, []) -> invalid_arg "no nameserver found" + | Some (proto, ns :: _) -> (proto, ns) + in + let t = + { protocol; nameserver; stack; timeout_ns = timeout; requests = IM.empty } + in + Lwt.async (fun () -> read t); + t - let send_recv (ctx : context) buf : (string, [> `Msg of string ]) result Lwt.t = - let dst, dst_port = ctx.nameserver in - let router, send_udp, _ = ctx.stack in - let src_port, evict = - My_nat.free_udp_port router.nat ~src:router.config.our_ip ~dst ~dst_port:53 - in - let id = String.get_uint16_be buf 0 in - with_timeout ctx.timeout_ns - (let cond = Lwt_condition.create () in - ctx.requests <- IM.add id cond ctx.requests; - (send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function - | Ok () -> Lwt_condition.wait cond >|= fun dns_response -> Ok dns_response - | Error _ as e -> Lwt.return e) >|= fun result -> - ctx.requests <- IM.remove id ctx.requests; - evict (); - result +let with_timeout timeout_ns f = + let timeout = + Mirage_sleep.ns timeout_ns >|= fun () -> Error (`Msg "DNS request timeout") + in + Lwt.pick [ f; timeout ] - let close _ = Lwt.return_unit +let connect (t : t) = Lwt.return (Ok (t.protocol, t)) - let bind = Lwt.bind +let send_recv (ctx : context) buf : (string, [> `Msg of string ]) result Lwt.t = + let dst, dst_port = ctx.nameserver in + let router, send_udp, _ = ctx.stack in + let src_port, evict = + My_nat.free_udp_port router.nat ~src:router.config.our_ip ~dst ~dst_port:53 + in + let id = String.get_uint16_be buf 0 in + with_timeout ctx.timeout_ns + (let cond = Lwt_condition.create () in + ctx.requests <- IM.add id cond ctx.requests; + send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg + >>= function + | Ok () -> Lwt_condition.wait cond >|= fun dns_response -> Ok dns_response + | Error _ as e -> Lwt.return e) + >|= fun result -> + ctx.requests <- IM.remove id ctx.requests; + evict (); + result - let lift = Lwt.return +let close _ = Lwt.return_unit +let bind = Lwt.bind +let lift = Lwt.return diff --git a/my_nat.ml b/my_nat.ml index 1e86c2d..e6b70e6 100644 --- a/my_nat.ml +++ b/my_nat.ml @@ -2,65 +2,57 @@ See the README file for details. *) let src = Logs.Src.create "my-nat" ~doc:"NAT shim" + module Log = (val Logs.src_log src : Logs.LOG) -type action = [ - | `NAT - | `Redirect of Mirage_nat.endpoint -] +type action = [ `NAT | `Redirect of Mirage_nat.endpoint ] module Nat = Mirage_nat_lru -module S = - Set.Make(struct type t = int let compare (a : int) (b : int) = compare a b end) +module S = Set.Make (struct + type t = int -type t = { - table : Nat.t; - mutable udp_dns : S.t; - last_resort_port : int -} + let compare (a : int) (b : int) = compare a b +end) -let pick_port () = - 1024 + Random.int (0xffff - 1024) +type t = { table : Nat.t; mutable udp_dns : S.t; last_resort_port : int } + +let pick_port () = 1024 + Random.int (0xffff - 1024) let create ~max_entries = let tcp_size = 7 * max_entries / 8 in let udp_size = max_entries - tcp_size in let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in let last_resort_port = pick_port () in - { table ; udp_dns = S.empty ; last_resort_port } + { table; udp_dns = S.empty; last_resort_port } let pick_free_port t proto = let rec go retries = - if retries = 0 then - None + if retries = 0 then None else let p = 1024 + Random.int (0xffff - 1024) in match proto with - | `Udp when S.mem p t.udp_dns || p = t.last_resort_port -> - go (retries - 1) + | `Udp when S.mem p t.udp_dns || p = t.last_resort_port -> go (retries - 1) | _ -> Some p in go 10 let free_udp_port t ~src ~dst ~dst_port = let rec go retries = - if retries = 0 then - t.last_resort_port, Fun.id + if retries = 0 then (t.last_resort_port, Fun.id) else let src_port = Option.value ~default:t.last_resort_port (pick_free_port t `Udp) in - if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin + if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then let remove = - if src_port <> t.last_resort_port then begin + if src_port <> t.last_resort_port then ( t.udp_dns <- S.add src_port t.udp_dns; - (fun () -> t.udp_dns <- S.remove src_port t.udp_dns) - end else Fun.id + fun () -> t.udp_dns <- S.remove src_port t.udp_dns) + else Fun.id in - src_port, remove - end else - go (retries - 1) + (src_port, remove) + else go (retries - 1) in go 10 @@ -68,27 +60,27 @@ let dns_port t port = S.mem port t.udp_dns || port = t.last_resort_port let translate t packet = match Nat.translate t.table packet with - | Error (`Untranslated | `TTL_exceeded as e) -> - Log.debug (fun f -> f "Failed to NAT %a: %a" - Nat_packet.pp packet - Mirage_nat.pp_error e - ); - None + | Error ((`Untranslated | `TTL_exceeded) as e) -> + Log.debug (fun f -> + f "Failed to NAT %a: %a" Nat_packet.pp packet Mirage_nat.pp_error e); + None | Ok packet -> Some packet -let remove_connections t ip = - ignore (Nat.remove_connections t.table ip) +let remove_connections t ip = ignore (Nat.remove_connections t.table ip) let add_nat_rule_and_translate t ~xl_host action packet = - let proto = match packet with + let proto = + match packet with | `IPv4 (_, `TCP _) -> `Tcp | `IPv4 (_, `UDP _) -> `Udp | `IPv4 (_, `ICMP _) -> `Icmp in - match Nat.add t.table packet xl_host (fun () -> pick_free_port t proto) action with + match + Nat.add t.table packet xl_host (fun () -> pick_free_port t proto) action + with | Error `Overlap -> Error "Too many retries" | Error `Cannot_NAT -> Error "Cannot NAT this packet" | Ok () -> - Log.debug (fun f -> f "Updated NAT table: %a" Nat.pp_summary t.table); - Option.to_result ~none:"No NAT entry, even after adding one!" - (translate t packet) + Log.debug (fun f -> f "Updated NAT table: %a" Nat.pp_summary t.table); + Option.to_result ~none:"No NAT entry, even after adding one!" + (translate t packet) diff --git a/my_nat.mli b/my_nat.mli index eab1a34..a9d3829 100644 --- a/my_nat.mli +++ b/my_nat.mli @@ -4,17 +4,23 @@ (* Abstract over NAT interface (todo: remove this) *) type t +type action = [ `NAT | `Redirect of Mirage_nat.endpoint ] -type action = [ - | `NAT - | `Redirect of Mirage_nat.endpoint -] - -val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int -> +val free_udp_port : + t -> + src:Ipaddr.V4.t -> + dst:Ipaddr.V4.t -> + dst_port:int -> int * (unit -> unit) + val dns_port : t -> int -> bool val create : max_entries:int -> t val remove_connections : t -> Ipaddr.V4.t -> unit val translate : t -> Nat_packet.t -> Nat_packet.t option -val add_nat_rule_and_translate : t -> - xl_host:Ipaddr.V4.t -> action -> Nat_packet.t -> (Nat_packet.t, string) result + +val add_nat_rule_and_translate : + t -> + xl_host:Ipaddr.V4.t -> + action -> + Nat_packet.t -> + (Nat_packet.t, string) result diff --git a/packet.ml b/packet.ml index 7d8c3c4..d6d4f92 100644 --- a/packet.ml +++ b/packet.ml @@ -8,9 +8,8 @@ type port = int type host = [ `Client of client_link | `Firewall | `NetVM | `External of Ipaddr.t ] -type transport_header = [`TCP of Tcp.Tcp_packet.t - |`UDP of Udp_packet.t - |`ICMP of Icmpv4_packet.t] +type transport_header = + [ `TCP of Tcp.Tcp_packet.t | `UDP of Udp_packet.t | `ICMP of Icmpv4_packet.t ] type ('src, 'dst) t = { ipv4_header : Ipv4_packet.t; @@ -19,13 +18,14 @@ type ('src, 'dst) t = { src : 'src; dst : 'dst; } + let pp_transport_header f = function | `ICMP h -> Icmpv4_packet.pp f h | `TCP h -> Tcp.Tcp_packet.pp f h | `UDP h -> Udp_packet.pp f h let pp_host fmt = function - | `Client c -> Ipaddr.V4.pp fmt (c#other_ip) + | `Client c -> Ipaddr.V4.pp fmt c#other_ip | `Unknown_client ip -> Format.fprintf fmt "unknown-client(%a)" Ipaddr.pp ip | `NetVM -> Format.pp_print_string fmt "net-vm" | `External ip -> Format.fprintf fmt "external(%a)" Ipaddr.pp ip @@ -33,32 +33,28 @@ let pp_host fmt = function let to_mirage_nat_packet t : Nat_packet.t = match t.transport_header with - | `TCP h -> `IPv4 (t.ipv4_header, (`TCP (h, t.transport_payload))) - | `UDP h -> `IPv4 (t.ipv4_header, (`UDP (h, t.transport_payload))) - | `ICMP h -> `IPv4 (t.ipv4_header, (`ICMP (h, t.transport_payload))) + | `TCP h -> `IPv4 (t.ipv4_header, `TCP (h, t.transport_payload)) + | `UDP h -> `IPv4 (t.ipv4_header, `UDP (h, t.transport_payload)) + | `ICMP h -> `IPv4 (t.ipv4_header, `ICMP (h, t.transport_payload)) let of_mirage_nat_packet ~src ~dst packet : ('a, 'b) t option = - let `IPv4 (ipv4_header, ipv4_payload) = packet in - let transport_header, transport_payload = match ipv4_payload with - | `TCP (h, p) -> `TCP h, p - | `UDP (h, p) -> `UDP h, p - | `ICMP (h, p) -> `ICMP h, p + let (`IPv4 (ipv4_header, ipv4_payload)) = packet in + let transport_header, transport_payload = + match ipv4_payload with + | `TCP (h, p) -> (`TCP h, p) + | `UDP (h, p) -> (`UDP h, p) + | `ICMP (h, p) -> (`ICMP h, p) in - Some { - ipv4_header; - transport_header; - transport_payload; - src; - dst; - } + Some { ipv4_header; transport_header; transport_payload; src; dst } (* possible actions to take for a packet: *) -type action = [ - | `Accept (* Send to destination, unmodified. *) - | `NAT (* Rewrite source field to the firewall's IP, with a fresh source port. +type action = + [ `Accept (* Send to destination, unmodified. *) + | `NAT + (* Rewrite source field to the firewall's IP, with a fresh source port. Also, add translation rules for future traffic in both directions, between these hosts on these ports, and corresponding ICMP error traffic. *) - | `NAT_to of host * port (* As for [`NAT], but also rewrite the packet's + | `NAT_to of host * port + (* As for [`NAT], but also rewrite the packet's destination fields so it will be sent to [host:port]. *) - | `Drop of string (* Drop packet for this reason. *) -] + | `Drop of string (* Drop packet for this reason. *) ] diff --git a/packet.mli b/packet.mli index f7d2876..af8ee43 100644 --- a/packet.mli +++ b/packet.mli @@ -1,15 +1,13 @@ type port = int type host = - [ `Client of Fw_utils.client_link (** an IP address on the private network *) - | `Firewall (** the firewall's IP on the private network *) - | `NetVM (** the IP of the firewall's default route *) - | `External of Ipaddr.t (** an IP on the public network *) - ] + [ `Client of Fw_utils.client_link (** an IP address on the private network *) + | `Firewall (** the firewall's IP on the private network *) + | `NetVM (** the IP of the firewall's default route *) + | `External of Ipaddr.t (** an IP on the public network *) ] -type transport_header = [`TCP of Tcp.Tcp_packet.t - |`UDP of Udp_packet.t - |`ICMP of Icmpv4_packet.t] +type transport_header = + [ `TCP of Tcp.Tcp_packet.t | `UDP of Udp_packet.t | `ICMP of Icmpv4_packet.t ] type ('src, 'dst) t = { ipv4_header : Ipv4_packet.t; @@ -20,20 +18,18 @@ type ('src, 'dst) t = { } val pp_transport_header : Format.formatter -> transport_header -> unit - val pp_host : Format.formatter -> host -> unit - val to_mirage_nat_packet : ('a, 'b) t -> Nat_packet.t - val of_mirage_nat_packet : src:'a -> dst:'b -> Nat_packet.t -> ('a, 'b) t option (* possible actions to take for a packet: *) -type action = [ - | `Accept (* Send to destination, unmodified. *) - | `NAT (* Rewrite source field to the firewall's IP, with a fresh source port. +type action = + [ `Accept (* Send to destination, unmodified. *) + | `NAT + (* Rewrite source field to the firewall's IP, with a fresh source port. Also, add translation rules for future traffic in both directions, between these hosts on these ports, and corresponding ICMP error traffic. *) - | `NAT_to of host * port (* As for [`NAT], but also rewrite the packet's + | `NAT_to of host * port + (* As for [`NAT], but also rewrite the packet's destination fields so it will be sent to [host:port]. *) - | `Drop of string (* Drop packet for this reason. *) -] + | `Drop of string (* Drop packet for this reason. *) ] diff --git a/rules.ml b/rules.ml index 9210b47..c85a596 100644 --- a/rules.ml +++ b/rules.ml @@ -8,93 +8,115 @@ open Lwt.Infix module Q = Pf_qubes.Parse_qubes let src = Logs.Src.create "rules" ~doc:"Firewall rules" + module Log = (val Logs.src_log src : Logs.LOG) let dns_port = 53 module Classifier = struct - - let matches_port dstports (port : int) = match dstports with + let matches_port dstports (port : int) = + match dstports with | None -> true | Some (Q.Range_inclusive (min, max)) -> min <= port && port <= max - let matches_proto rule dns_servers packet = match rule.Q.proto, rule.Q.specialtarget with + let matches_proto rule dns_servers packet = + match (rule.Q.proto, rule.Q.specialtarget) with | None, None -> true - | None, Some `dns when List.mem packet.ipv4_header.Ipv4_packet.dst dns_servers -> begin - (* specialtarget=dns applies only to the specialtarget destination IPs, and + | None, Some `dns + when List.mem packet.ipv4_header.Ipv4_packet.dst dns_servers -> ( + (* specialtarget=dns applies only to the specialtarget destination IPs, and specialtarget=dns is also implicitly tcp/udp port 53 *) - match packet.transport_header with + match packet.transport_header with | `TCP header -> header.Tcp.Tcp_packet.dst_port = dns_port | `UDP header -> header.Udp_packet.dst_port = dns_port - | _ -> false - end - (* DNS rules can only match traffic headed to the specialtarget hosts, so any other destination + | _ -> false) + (* DNS rules can only match traffic headed to the specialtarget hosts, so any other destination isn't a match for DNS rules *) | None, Some `dns -> false - | Some rule_proto, _ -> match rule_proto, packet.transport_header with - | `tcp, `TCP header -> matches_port rule.Q.dstports header.Tcp.Tcp_packet.dst_port - | `udp, `UDP header -> matches_port rule.Q.dstports header.Udp_packet.dst_port - | `icmp, `ICMP header -> - begin - match rule.Q.icmp_type with - | None -> true - | Some rule_icmp_type -> - 0 = compare rule_icmp_type @@ Icmpv4_wire.ty_to_int header.Icmpv4_packet.ty - end - | _, _ -> false + | Some rule_proto, _ -> ( + match (rule_proto, packet.transport_header) with + | `tcp, `TCP header -> + matches_port rule.Q.dstports header.Tcp.Tcp_packet.dst_port + | `udp, `UDP header -> + matches_port rule.Q.dstports header.Udp_packet.dst_port + | `icmp, `ICMP header -> ( + match rule.Q.icmp_type with + | None -> true + | Some rule_icmp_type -> + 0 + = compare rule_icmp_type + @@ Icmpv4_wire.ty_to_int header.Icmpv4_packet.ty) + | _, _ -> false) let matches_dest dns_client rule packet = let ip = packet.ipv4_header.Ipv4_packet.dst in match rule.Q.dst with - | `any -> Lwt.return @@ `Match rule + | `any -> Lwt.return @@ `Match rule | `hosts subnet -> - Lwt.return @@ if (Ipaddr.Prefix.mem Ipaddr.(V4 ip) subnet) then `Match rule else `No_match - | `dnsname name -> - Log.debug (fun f -> f "Resolving %a" Domain_name.pp name); - dns_client name >|= function - | Ok (_ttl, found_ips) -> - if Ipaddr.V4.Set.mem ip found_ips - then `Match rule + Lwt.return + @@ + if Ipaddr.Prefix.mem Ipaddr.(V4 ip) subnet then `Match rule else `No_match - | Error (`Msg m) -> - Log.warn (fun f -> f "Ignoring rule %a, could not resolve" Q.pp_rule rule); - Log.debug (fun f -> f "%s" m); - `No_match - | Error _ -> assert false (* TODO: fix type of dns_client so that this case can go *) - + | `dnsname name -> ( + Log.debug (fun f -> f "Resolving %a" Domain_name.pp name); + dns_client name >|= function + | Ok (_ttl, found_ips) -> + if Ipaddr.V4.Set.mem ip found_ips then `Match rule else `No_match + | Error (`Msg m) -> + Log.warn (fun f -> + f "Ignoring rule %a, could not resolve" Q.pp_rule rule); + Log.debug (fun f -> f "%s" m); + `No_match + | Error _ -> + assert + false (* TODO: fix type of dns_client so that this case can go *)) end let find_first_match dns_client dns_servers packet acc rule = match acc with | `No_match -> - if Classifier.matches_proto rule dns_servers packet - then Classifier.matches_dest dns_client rule packet - else Lwt.return `No_match + if Classifier.matches_proto rule dns_servers packet then + Classifier.matches_dest dns_client rule packet + else Lwt.return `No_match | q -> Lwt.return q (* Does the packet match our rules? *) -let classify_client_packet dns_client dns_servers (packet : ([`Client of Fw_utils.client_link], _) Packet.t) = +let classify_client_packet dns_client dns_servers + (packet : ([ `Client of Fw_utils.client_link ], _) Packet.t) = let (`Client client_link) = packet.src in let rules = client_link#get_rules in - Lwt_list.fold_left_s (find_first_match dns_client dns_servers packet) `No_match rules >|= function + Lwt_list.fold_left_s + (find_first_match dns_client dns_servers packet) + `No_match rules + >|= function | `No_match -> `Drop "No matching rule; assuming default drop" - | `Match {Q.action = Q.Accept; _} -> `Accept - | `Match ({Q.action = Q.Drop; _} as rule) -> - `Drop (Format.asprintf "rule number %a explicitly drops this packet" Q.pp_rule rule) + | `Match { Q.action = Q.Accept; _ } -> `Accept + | `Match ({ Q.action = Q.Drop; _ } as rule) -> + `Drop + (Format.asprintf "rule number %a explicitly drops this packet" Q.pp_rule + rule) let translate_accepted_packets dns_client dns_servers packet = classify_client_packet dns_client dns_servers packet >|= function | `Accept -> `NAT | `Drop s -> `Drop s -(** Packets from the private interface that don't match any NAT table entry are being checked against the fw rules here *) -let from_client dns_client dns_servers (packet : ([`Client of Fw_utils.client_link], _) Packet.t) : Packet.action Lwt.t = +(** Packets from the private interface that don't match any NAT table entry are + being checked against the fw rules here *) +let from_client dns_client dns_servers + (packet : ([ `Client of Fw_utils.client_link ], _) Packet.t) : + Packet.action Lwt.t = match packet with - | { dst = `External _ ; _ } | { dst = `NetVM; _ } -> translate_accepted_packets dns_client dns_servers packet - | { dst = `Firewall ; _ } -> Lwt.return @@ `Drop "packet addressed to firewall itself" - | { dst = `Client _ ; _ } -> classify_client_packet dns_client dns_servers packet + | { dst = `External _; _ } | { dst = `NetVM; _ } -> + translate_accepted_packets dns_client dns_servers packet + | { dst = `Firewall; _ } -> + Lwt.return @@ `Drop "packet addressed to firewall itself" + | { dst = `Client _; _ } -> + classify_client_packet dns_client dns_servers packet | _ -> Lwt.return @@ `Drop "could not classify packet" -(** Packets from the outside world that don't match any NAT table entry are being dropped by default *) -let from_netvm (_packet : ([`NetVM | `External of _], _) Packet.t) : Packet.action Lwt.t = +(** Packets from the outside world that don't match any NAT table entry are + being dropped by default *) +let from_netvm (_packet : ([ `NetVM | `External of _ ], _) Packet.t) : + Packet.action Lwt.t = Lwt.return @@ `Drop "drop by default" diff --git a/test/config.ml b/test/config.ml index d8695e4..d5589d5 100644 --- a/test/config.ml +++ b/test/config.ml @@ -2,26 +2,32 @@ open Mirage let pin = "git+https://github.com/roburio/alcotest.git#mirage" -let packages = [ - package "ethernet"; - package "arp"; - package "arp-mirage"; - package "ipaddr"; - package "tcpip" ~sublibs:["stack-direct"; "icmpv4"; "ipv4"; "udp"; "tcp"]; - package "mirage-qubes"; - package "mirage-qubes-ipv4"; - package "dns-client" ~sublibs:["mirage"]; - package ~pin "alcotest"; - package ~pin "alcotest-mirage"; -] +let packages = + [ + package "ethernet"; + package "arp"; + package "arp-mirage"; + package "ipaddr"; + package "tcpip" ~sublibs:[ "stack-direct"; "icmpv4"; "ipv4"; "udp"; "tcp" ]; + package "mirage-qubes"; + package "mirage-qubes-ipv4"; + package "dns-client" ~sublibs:[ "mirage" ]; + package ~pin "alcotest"; + package ~pin "alcotest-mirage"; + ] let client = - foreign ~packages - "Unikernel.Client" @@ random @-> time @-> mclock @-> network @-> qubesdb @-> job + foreign ~packages "Unikernel.Client" + @@ random @-> time @-> mclock @-> network @-> qubesdb @-> job let db = default_qubesdb let network = default_network let () = - let job = [ client $ default_random $ default_time $ default_monotonic_clock $ network $ db ] in + let job = + [ + client $ default_random $ default_time $ default_monotonic_clock $ network + $ db; + ] + in register "http-fetch" job diff --git a/test/unikernel.ml b/test/unikernel.ml index 04f7d6a..2a0c23a 100644 --- a/test/unikernel.ml +++ b/test/unikernel.ml @@ -1,6 +1,8 @@ open Lwt.Infix + (* https://www.qubes-os.org/doc/vm-interface/#firewall-rules-in-4x *) let src = Logs.Src.create "firewall test" ~doc:"Firewalltest" + module Log = (val Logs.src_log src : Logs.LOG) (* TODO @@ -39,18 +41,24 @@ module Log = (val Logs.src_log src : Logs.LOG) (* Point-to-point links out of a netvm always have this IP TODO clarify with Marek *) let netvm = "10.137.0.5" + (* default "nameserver"s, which netvm redirects to whatever its real nameservers are *) -let nameserver_1, nameserver_2 = "10.139.1.1", "10.139.1.2" +let nameserver_1, nameserver_2 = ("10.139.1.1", "10.139.1.2") -module Client (R: Mirage_crypto_rng_mirage.S) (Time: Mirage_time.S) (Clock : Mirage_clock.MCLOCK) (NET: Mirage_net.S) (DB : Qubes.S.DB) = struct - module E = Ethernet.Make(NET) - module A = Arp.Make(E)(Time) - module I = Qubesdb_ipv4.Make(DB)(R)(Clock)(E)(A) - module Icmp = Icmpv4.Make(I) - module U = Udp.Make(I)(R) - module T = Tcp.Flow.Make(I)(Time)(Clock)(R) - - module Alcotest = Alcotest_mirage.Make(Clock) +module Client + (R : Mirage_crypto_rng_mirage.S) + (Time : Mirage_time.S) + (Clock : Mirage_clock.MCLOCK) + (NET : Mirage_net.S) + (DB : Qubes.S.DB) = +struct + module E = Ethernet.Make (NET) + module A = Arp.Make (E) (Time) + module I = Qubesdb_ipv4.Make (DB) (R) (Clock) (E) (A) + module Icmp = Icmpv4.Make (I) + module U = Udp.Make (I) (R) + module T = Tcp.Flow.Make (I) (Time) (Clock) (R) + module Alcotest = Alcotest_mirage.Make (Clock) module Stack = struct (* A Mirage_stack.V4 implementation which diverts DHCP messages to a DHCP @@ -66,67 +74,77 @@ module Client (R: Mirage_crypto_rng_mirage.S) (Time: Mirage_time.S) (Clock : Mir module IPV4 = I type t = { - net : NET.t ; eth : E.t ; arp : A.t ; - ip : I.t ; icmp : Icmp.t ; udp : U.t ; tcp : T.t ; - udp_listeners : (int, U.callback) Hashtbl.t ; - tcp_listeners : (int, T.listener) Hashtbl.t ; - mutable icmp_listener : (src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> Cstruct.t -> unit Lwt.t) option ; + net : NET.t; + eth : E.t; + arp : A.t; + ip : I.t; + icmp : Icmp.t; + udp : U.t; + tcp : T.t; + udp_listeners : (int, U.callback) Hashtbl.t; + tcp_listeners : (int, T.listener) Hashtbl.t; + mutable icmp_listener : + (src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> Cstruct.t -> unit Lwt.t) option; } - let ipv4 { ip ; _ } = ip - let udpv4 { udp ; _ } = udp - let tcpv4 { tcp ; _ } = tcp - let icmpv4 { icmp ; _ } = icmp - + let ipv4 { ip; _ } = ip + let udpv4 { udp; _ } = udp + let tcpv4 { tcp; _ } = tcp + let icmpv4 { icmp; _ } = icmp let listener h port = Hashtbl.find_opt h port let udp_listener h ~dst_port = listener h dst_port - let listen_udpv4 { udp_listeners ; _ } ~port cb = + let listen_udpv4 { udp_listeners; _ } ~port cb = Hashtbl.replace udp_listeners port cb - let stop_listen_udpv4 { udp_listeners ; _ } ~port = + let stop_listen_udpv4 { udp_listeners; _ } ~port = Hashtbl.remove udp_listeners port - let listen_tcpv4 ?keepalive { tcp_listeners ; _ } ~port cb = - Hashtbl.replace tcp_listeners port { T.process = cb ; T.keepalive } + let listen_tcpv4 ?keepalive { tcp_listeners; _ } ~port cb = + Hashtbl.replace tcp_listeners port { T.process = cb; T.keepalive } - let stop_listen_tcpv4 { tcp_listeners ; _ } ~port = + let stop_listen_tcpv4 { tcp_listeners; _ } ~port = Hashtbl.remove tcp_listeners port let listen_icmp t cb = t.icmp_listener <- cb let listen t = let ethif_listener = - E.input - ~arpv4:(A.input t.arp) - ~ipv4:( - I.input - ~tcp:(T.input t.tcp ~listeners:(listener t.tcp_listeners)) - ~udp:(U.input t.udp ~listeners:(udp_listener t.udp_listeners)) - ~default:(fun ~proto ~src ~dst buf -> - match proto with - | 1 -> - begin match t.icmp_listener with + E.input ~arpv4:(A.input t.arp) + ~ipv4: + (I.input + ~tcp:(T.input t.tcp ~listeners:(listener t.tcp_listeners)) + ~udp:(U.input t.udp ~listeners:(udp_listener t.udp_listeners)) + ~default:(fun ~proto ~src ~dst buf -> + match proto with + | 1 -> ( + match t.icmp_listener with | None -> Icmp.input t.icmp ~src ~dst buf - | Some cb -> cb ~src ~dst buf - end - | _ -> Lwt.return_unit) - t.ip) + | Some cb -> cb ~src ~dst buf) + | _ -> Lwt.return_unit) + t.ip) ~ipv6:(fun _ -> Lwt.return_unit) t.eth in NET.listen t.net ~header_size:Ethernet_wire.sizeof_ethernet ethif_listener >>= function | Error e -> - Logs.warn (fun p -> p "%a" NET.pp_error e) ; - Lwt.return_unit + Logs.warn (fun p -> p "%a" NET.pp_error e); + Lwt.return_unit | Ok _res -> Lwt.return_unit let connect net eth arp ip icmp udp tcp = - { net ; eth ; arp ; ip ; icmp ; udp ; tcp ; - udp_listeners = Hashtbl.create 2 ; - tcp_listeners = Hashtbl.create 2 ; - icmp_listener = None ; + { + net; + eth; + arp; + ip; + icmp; + udp; + tcp; + udp_listeners = Hashtbl.create 2; + tcp_listeners = Hashtbl.create 2; + icmp_listener = None; } let disconnect _ = @@ -134,31 +152,39 @@ module Client (R: Mirage_crypto_rng_mirage.S) (Time: Mirage_time.S) (Clock : Mir Lwt.return_unit end - module Dns = Dns_client_mirage.Make(R)(Time)(Clock)(Stack) + module Dns = Dns_client_mirage.Make (R) (Time) (Clock) (Stack) let make_ping_packet payload = - let echo_request = { Icmpv4_packet.code = 0; (* constant for echo request/reply *) - ty = Icmpv4_wire.Echo_request; - subheader = Icmpv4_packet.(Id_and_seq (0, 0)); } in + let echo_request = + { + Icmpv4_packet.code = 0; + (* constant for echo request/reply *) + ty = Icmpv4_wire.Echo_request; + subheader = Icmpv4_packet.(Id_and_seq (0, 0)); + } + in Icmpv4_packet.Marshal.make_cstruct echo_request ~payload let is_ping_reply src server packet = - 0 = Ipaddr.V4.(compare src @@ of_string_exn server) && - packet.Icmpv4_packet.code = 0 && - packet.Icmpv4_packet.ty = Icmpv4_wire.Echo_reply && - packet.Icmpv4_packet.subheader = Icmpv4_packet.(Id_and_seq (0, 0)) + (0 = Ipaddr.V4.(compare src @@ of_string_exn server)) + && packet.Icmpv4_packet.code = 0 + && packet.Icmpv4_packet.ty = Icmpv4_wire.Echo_reply + && packet.Icmpv4_packet.subheader = Icmpv4_packet.(Id_and_seq (0, 0)) let ping_denied_listener server resp_received stack = let icmp_listener ~src ~dst:_ buf = (* hopefully this is a reply to an ICMP echo request we sent *) - Log.info (fun f -> f "ping test: ICMP message received from %a: %a" I.pp_ipaddr src Cstruct.hexdump_pp buf); + Log.info (fun f -> + f "ping test: ICMP message received from %a: %a" I.pp_ipaddr src + Cstruct.hexdump_pp buf); match Icmpv4_packet.Unmarshal.of_cstruct buf with - | Error e -> Log.err (fun f -> f "couldn't parse ICMP packet: %s" e); - Lwt.return_unit + | Error e -> + Log.err (fun f -> f "couldn't parse ICMP packet: %s" e); + Lwt.return_unit | Ok (packet, _payload) -> - Log.info (fun f -> f "ICMP message: %a" Icmpv4_packet.pp packet); - if is_ping_reply src server packet then resp_received := true; - Lwt.return_unit + Log.info (fun f -> f "ICMP message: %a" Icmpv4_packet.pp packet); + if is_ping_reply src server packet then resp_received := true; + Lwt.return_unit in Stack.listen_icmp stack (Some icmp_listener) @@ -166,49 +192,68 @@ module Client (R: Mirage_crypto_rng_mirage.S) (Time: Mirage_time.S) (Clock : Mir let resp_received = ref false in Log.info (fun f -> f "Entering ping test: %s" server); ping_denied_listener server resp_received stack; - Icmp.write (Stack.icmpv4 stack) ~dst:(Ipaddr.V4.of_string_exn server) (make_ping_packet (Cstruct.of_string "hi")) >>= function - | Error e -> Log.err (fun f -> f "ping test: error sending ping: %a" Icmp.pp_error e); Lwt.return_unit + Icmp.write (Stack.icmpv4 stack) + ~dst:(Ipaddr.V4.of_string_exn server) + (make_ping_packet (Cstruct.of_string "hi")) + >>= function + | Error e -> + Log.err (fun f -> f "ping test: error sending ping: %a" Icmp.pp_error e); + Lwt.return_unit | Ok () -> - Log.info (fun f -> f "ping test: sent ping to %s" server); - Time.sleep_ns 2_000_000_000L >>= fun () -> - (if !resp_received then - Log.err (fun f -> f "ping test failed: server %s got a response, block expected :(" server) - else - Log.err (fun f -> f "ping test passed: successfully blocked :)") - ); - Stack.listen_icmp stack None; - Lwt.return_unit + Log.info (fun f -> f "ping test: sent ping to %s" server); + Time.sleep_ns 2_000_000_000L >>= fun () -> + if !resp_received then + Log.err (fun f -> + f "ping test failed: server %s got a response, block expected :(" + server) + else Log.err (fun f -> f "ping test passed: successfully blocked :)"); + Stack.listen_icmp stack None; + Lwt.return_unit let icmp_error_type stack () = let resp_correct = ref false in let echo_server = Ipaddr.V4.of_string_exn netvm in let icmp_callback ~src ~dst:_ buf = - if Ipaddr.V4.compare src echo_server = 0 then begin - (* TODO: check that packet is error packet *) - match Icmpv4_packet.Unmarshal.of_cstruct buf with - | Error e -> Log.err (fun f -> f "Error parsing icmp packet %s" e) - | Ok (packet, _) -> + (if Ipaddr.V4.compare src echo_server = 0 then + (* TODO: check that packet is error packet *) + match Icmpv4_packet.Unmarshal.of_cstruct buf with + | Error e -> Log.err (fun f -> f "Error parsing icmp packet %s" e) + | Ok (packet, _) -> (* TODO don't hardcode the numbers, make a datatype *) - if packet.Icmpv4_packet.code = 10 (* unreachable, admin prohibited *) + if + packet.Icmpv4_packet.code + = 10 (* unreachable, admin prohibited *) then resp_correct := true - else Log.debug (fun f -> f "Unrelated icmp packet %a" Icmpv4_packet.pp packet) - end; + else + Log.debug (fun f -> + f "Unrelated icmp packet %a" Icmpv4_packet.pp packet)); Lwt.return_unit in let content = Cstruct.of_string "important data" in Stack.listen_icmp stack (Some icmp_callback); - U.write ~src_port:1337 ~dst:echo_server ~dst_port:1338 (Stack.udpv4 stack) content >>= function - | Ok () -> (* .. listener: test with accept rule, if we get reply we're good *) - Time.sleep_ns 1_000_000_000L >>= fun () -> - if !resp_correct - then Log.info (fun m -> m "UDP fetch test to port %d succeeded :)" 1338) - else Log.err (fun f -> f "UDP fetch test to port %d: failed. :( correct response not received" 1338); - Stack.listen_icmp stack None; - Lwt.return_unit + U.write ~src_port:1337 ~dst:echo_server ~dst_port:1338 (Stack.udpv4 stack) + content + >>= function + | Ok () -> + (* .. listener: test with accept rule, if we get reply we're good *) + Time.sleep_ns 1_000_000_000L >>= fun () -> + if !resp_correct then + Log.info (fun m -> m "UDP fetch test to port %d succeeded :)" 1338) + else + Log.err (fun f -> + f + "UDP fetch test to port %d: failed. :( correct response not \ + received" + 1338); + Stack.listen_icmp stack None; + Lwt.return_unit | Error e -> - Log.err (fun f -> f "UDP fetch test to port %d failed: :( couldn't write the packet: %a" - 1338 U.pp_error e); - Lwt.return_unit + Log.err (fun f -> + f + "UDP fetch test to port %d failed: :( couldn't write the packet: \ + %a" + 1338 U.pp_error e); + Lwt.return_unit let tcp_connect msg server port tcp () = Log.info (fun f -> f "Entering tcp connect test: %s:%d" server port); @@ -216,98 +261,141 @@ module Client (R: Mirage_crypto_rng_mirage.S) (Time: Mirage_time.S) (Clock : Mir let msg' = Printf.sprintf "TCP connect test %s to %s:%d" msg server port in T.create_connection tcp (ip, port) >>= function | Ok flow -> - Log.info (fun f -> f "%s passed :)" msg'); - T.close flow - | Error e -> Log.err (fun f -> f "%s failed: Connection failed (%a) :(" msg' T.pp_error e); - Lwt.return_unit + Log.info (fun f -> f "%s passed :)" msg'); + T.close flow + | Error e -> + Log.err (fun f -> + f "%s failed: Connection failed (%a) :(" msg' T.pp_error e); + Lwt.return_unit let tcp_connect_denied msg server port tcp () = let ip = Ipaddr.V4.of_string_exn server in - let msg' = Printf.sprintf "TCP connect denied test %s to %s:%d" msg server port in - let connect = (T.create_connection tcp (ip, port) >>= function - | Ok flow -> - Log.err (fun f -> f "%s failed: Connection should be denied, but was not. :(" msg'); - T.close flow - | Error e -> Log.info (fun f -> f "%s passed (error text: %a) :)" msg' T.pp_error e); - Lwt.return_unit) + let msg' = + Printf.sprintf "TCP connect denied test %s to %s:%d" msg server port in - let timeout = ( + let connect = + T.create_connection tcp (ip, port) >>= function + | Ok flow -> + Log.err (fun f -> + f "%s failed: Connection should be denied, but was not. :(" msg'); + T.close flow + | Error e -> + Log.info (fun f -> + f "%s passed (error text: %a) :)" msg' T.pp_error e); + Lwt.return_unit + in + let timeout = Time.sleep_ns 1_000_000_000L >>= fun () -> Log.info (fun f -> f "%s passed :)" msg'); - Lwt.return_unit) + Lwt.return_unit in - Lwt.pick [ connect ; timeout ] + Lwt.pick [ connect; timeout ] let udp_fetch ~src_port ~echo_server_port stack () = - Log.info (fun f -> f "Entering udp fetch test: %d -> %s:%d" - src_port netvm echo_server_port); + Log.info (fun f -> + f "Entering udp fetch test: %d -> %s:%d" src_port netvm echo_server_port); let resp_correct = ref false in let echo_server = Ipaddr.V4.of_string_exn netvm in let content = Cstruct.of_string "important data" in - let udp_listener : U.callback = (fun ~src ~dst:_ ~src_port buf -> - Log.debug (fun f -> f "listen_udpv4 function invoked for packet: %a" Cstruct.hexdump_pp buf); - if ((0 = Ipaddr.V4.compare echo_server src) && src_port = echo_server_port) then - match Cstruct.equal buf content with - | true -> (* yay *) - Log.info (fun f -> f "UDP fetch test to port %d: passed :)" echo_server_port); + let udp_listener : U.callback = + fun ~src ~dst:_ ~src_port buf -> + Log.debug (fun f -> + f "listen_udpv4 function invoked for packet: %a" Cstruct.hexdump_pp + buf); + if 0 = Ipaddr.V4.compare echo_server src && src_port = echo_server_port + then ( + match Cstruct.equal buf content with + | true -> + (* yay *) + Log.info (fun f -> + f "UDP fetch test to port %d: passed :)" echo_server_port); resp_correct := true; Lwt.return_unit - | false -> (* oh no *) - Log.err (fun f -> f "UDP fetch test to port %d: failed. :( Packet corrupted; expected %a but got %a" - echo_server_port Cstruct.hexdump_pp content Cstruct.hexdump_pp buf); - Lwt.return_unit - else - begin - (* disregard this packet *) - Log.debug (fun f -> f "packet is not from the echo server or has the wrong source port (%d but we wanted %d)" - src_port echo_server_port); - (* don't cancel the listener, since we want to keep listening *) - Lwt.return_unit - end - ) + | false -> + (* oh no *) + Log.err (fun f -> + f + "UDP fetch test to port %d: failed. :( Packet corrupted; \ + expected %a but got %a" + echo_server_port Cstruct.hexdump_pp content Cstruct.hexdump_pp + buf); + Lwt.return_unit) + else ( + (* disregard this packet *) + Log.debug (fun f -> + f + "packet is not from the echo server or has the wrong source port \ + (%d but we wanted %d)" + src_port echo_server_port); + (* don't cancel the listener, since we want to keep listening *) + Lwt.return_unit) in Stack.listen_udpv4 stack ~port:src_port udp_listener; - U.write ~src_port ~dst:echo_server ~dst_port:echo_server_port (Stack.udpv4 stack) content >>= function - | Ok () -> (* .. listener: test with accept rule, if we get reply we're good *) - Time.sleep_ns 1_000_000_000L >>= fun () -> - Stack.stop_listen_udpv4 stack ~port:src_port; - if !resp_correct then Lwt.return_unit else begin - Log.err (fun f -> f "UDP fetch test to port %d: failed. :( correct response not received" echo_server_port); - Lwt.return_unit - end + U.write ~src_port ~dst:echo_server ~dst_port:echo_server_port + (Stack.udpv4 stack) content + >>= function + | Ok () -> + (* .. listener: test with accept rule, if we get reply we're good *) + Time.sleep_ns 1_000_000_000L >>= fun () -> + Stack.stop_listen_udpv4 stack ~port:src_port; + if !resp_correct then Lwt.return_unit + else ( + Log.err (fun f -> + f + "UDP fetch test to port %d: failed. :( correct response not \ + received" + echo_server_port); + Lwt.return_unit) | Error e -> - Log.err (fun f -> f "UDP fetch test to port %d failed: :( couldn't write the packet: %a" - echo_server_port U.pp_error e); - Lwt.return_unit + Log.err (fun f -> + f + "UDP fetch test to port %d failed: :( couldn't write the packet: \ + %a" + echo_server_port U.pp_error e); + Lwt.return_unit let dns_expect_failure ~nameserver ~hostname stack () = let lookup = Domain_name.(of_string_exn hostname |> host_exn) in - let nameserver' = `UDP, (Ipaddr.V4.of_string_exn nameserver, 53) in + let nameserver' = (`UDP, (Ipaddr.V4.of_string_exn nameserver, 53)) in let dns = Dns.create ~nameserver:nameserver' stack in Dns.gethostbyname dns lookup >>= function - | Error (`Msg s) when String.compare s "Truncated UDP response" <> 0 -> Log.debug (fun f -> f "DNS test to %s failed as expected: %s" - nameserver s); - Log.info (fun f -> f "DNS traffic to %s correctly blocked :)" nameserver); - Lwt.return_unit + | Error (`Msg s) when String.compare s "Truncated UDP response" <> 0 -> + Log.debug (fun f -> + f "DNS test to %s failed as expected: %s" nameserver s); + Log.info (fun f -> + f "DNS traffic to %s correctly blocked :)" nameserver); + Lwt.return_unit | Error (`Msg s) -> - Log.debug (fun f -> f "DNS test to %s failed unexpectedly (truncated response): %s :(" - nameserver s); - Lwt.return_unit - | Ok addr -> Log.err (fun f -> f "DNS test to %s should have been blocked, but looked up %s:%a" nameserver hostname Ipaddr.V4.pp addr); - Lwt.return_unit + Log.debug (fun f -> + f "DNS test to %s failed unexpectedly (truncated response): %s :(" + nameserver s); + Lwt.return_unit + | Ok addr -> + Log.err (fun f -> + f "DNS test to %s should have been blocked, but looked up %s:%a" + nameserver hostname Ipaddr.V4.pp addr); + Lwt.return_unit let dns_then_tcp_denied server stack () = let parsed_server = Domain_name.(of_string_exn server |> host_exn) in (* ask dns about server *) - Log.debug (fun f -> f "going to make a dns thing using nameserver %s" nameserver_1); - let dns = Dns.create ~nameserver:(`UDP, ((Ipaddr.V4.of_string_exn nameserver_1), 53)) stack in + Log.debug (fun f -> + f "going to make a dns thing using nameserver %s" nameserver_1); + let dns = + Dns.create + ~nameserver:(`UDP, (Ipaddr.V4.of_string_exn nameserver_1, 53)) + stack + in Log.debug (fun f -> f "OK, going to look up %s now" server); Dns.gethostbyname dns parsed_server >>= function - | Error (`Msg s) -> Log.err (fun f -> f "couldn't look up ip for %s: %s" server s); Lwt.return_unit + | Error (`Msg s) -> + Log.err (fun f -> f "couldn't look up ip for %s: %s" server s); + Lwt.return_unit | Ok addr -> - Log.debug (fun f -> f "looked up ip for %s: %a" server Ipaddr.V4.pp addr); - Log.err (fun f -> f "Do more stuff here!!!! :("); - Lwt.return_unit + Log.debug (fun f -> + f "looked up ip for %s: %a" server Ipaddr.V4.pp addr); + Log.err (fun f -> f "Do more stuff here!!!! :("); + Lwt.return_unit let start _random _time _clock network db = E.connect network >>= fun ethernet -> @@ -316,42 +404,64 @@ module Client (R: Mirage_crypto_rng_mirage.S) (Time: Mirage_time.S) (Clock : Mir Icmp.connect ipv4 >>= fun icmp -> U.connect ipv4 >>= fun udp -> T.connect ipv4 >>= fun tcp -> - - let stack = Stack.connect network ethernet arp ipv4 icmp udp tcp in + let stack = Stack.connect network ethernet arp ipv4 icmp udp tcp in Lwt.async (fun () -> Stack.listen stack); (* put this first because tcp_connect_denied tests also generate icmp messages *) - let general_tests : unit Alcotest.test = ("firewall tests", [ - ("UDP fetch", `Quick, udp_fetch ~src_port:9090 ~echo_server_port:1235 stack); - ("Ping expect failure", `Quick, ping_expect_failure "8.8.8.8" stack ); - (* TODO: ping_expect_success to the netvm, for which we have an icmptype rule in update-firewall.sh *) - ("ICMP error type", `Quick, icmp_error_type stack) - ] ) in + let general_tests : unit Alcotest.test = + ( "firewall tests", + [ + ( "UDP fetch", + `Quick, + udp_fetch ~src_port:9090 ~echo_server_port:1235 stack ); + ("Ping expect failure", `Quick, ping_expect_failure "8.8.8.8" stack); + (* TODO: ping_expect_success to the netvm, for which we have an icmptype rule in update-firewall.sh *) + ("ICMP error type", `Quick, icmp_error_type stack); + ] ) + in Alcotest.run ~and_exit:false "name" [ general_tests ] >>= fun () -> - let tcp_tests : unit Alcotest.test = ("tcp tests", [ - (* this test fails on 4.0R3 + let tcp_tests : unit Alcotest.test = + ( "tcp tests", + [ + (* this test fails on 4.0R3 ("TCP connect", `Quick, tcp_connect "when trying specialtarget" nameserver_1 53 tcp); *) - ("TCP connect", `Quick, tcp_connect_denied "" netvm 53 tcp); - ("TCP connect", `Quick, tcp_connect_denied "when trying below range" netvm 6667 tcp); - ("TCP connect", `Quick, tcp_connect "when trying lower bound in range" netvm 6668 tcp); - ("TCP connect", `Quick, tcp_connect "when trying upper bound in range" netvm 6670 tcp); - ("TCP connect", `Quick, tcp_connect_denied "when trying above range" netvm 6671 tcp); - ("TCP connect", `Quick, tcp_connect_denied "" netvm 8082 tcp); - ] ) in + ("TCP connect", `Quick, tcp_connect_denied "" netvm 53 tcp); + ( "TCP connect", + `Quick, + tcp_connect_denied "when trying below range" netvm 6667 tcp ); + ( "TCP connect", + `Quick, + tcp_connect "when trying lower bound in range" netvm 6668 tcp ); + ( "TCP connect", + `Quick, + tcp_connect "when trying upper bound in range" netvm 6670 tcp ); + ( "TCP connect", + `Quick, + tcp_connect_denied "when trying above range" netvm 6671 tcp ); + ("TCP connect", `Quick, tcp_connect_denied "" netvm 8082 tcp); + ] ) + in (* replace the udp-related listeners with the right one for tcp *) Alcotest.run "name" [ tcp_tests ] >>= fun () -> (* use the stack abstraction only after the other tests have run, since it's not friendly with outside use of its modules *) - let stack_tests = "stack tests", [ - ("DNS expect failure", `Quick, dns_expect_failure ~nameserver:"8.8.8.8" ~hostname:"mirage.io" stack); - - (* the test below won't work on @linse's internet, + let stack_tests = + ( "stack tests", + [ + ( "DNS expect failure", + `Quick, + dns_expect_failure ~nameserver:"8.8.8.8" ~hostname:"mirage.io" stack + ); + (* the test below won't work on @linse's internet, * because the nameserver there doesn't answer on TCP port 53, * only UDP port 53. Dns_mirage_client.ml disregards our request * to use UDP and uses TCP anyway, so this request can never work there. *) - (* If we can figure out a way to have this test unikernel do a UDP lookup with minimal pain, + (* If we can figure out a way to have this test unikernel do a UDP lookup with minimal pain, * we should re-enable this test. *) - ("DNS lookup + TCP connect", `Quick, dns_then_tcp_denied "google.com" stack); - ] in + ( "DNS lookup + TCP connect", + `Quick, + dns_then_tcp_denied "google.com" stack ); + ] ) + in Alcotest.run "name" [ stack_tests ] end diff --git a/unikernel.ml b/unikernel.ml index 28115d1..51841ae 100644 --- a/unikernel.ml +++ b/unikernel.ml @@ -6,10 +6,13 @@ open Qubes open Cmdliner let src = Logs.Src.create "unikernel" ~doc:"Main unikernel code" + module Log = (val Logs.src_log src : Logs.LOG) let nat_table_size = - let doc = Arg.info ~doc:"The number of NAT entries to allocate." [ "nat-table-size" ] in + let doc = + Arg.info ~doc:"The number of NAT entries to allocate." [ "nat-table-size" ] + in Mirage_runtime.register_arg Arg.(value & opt int 5_000 doc) let ipv4 = @@ -28,86 +31,96 @@ let ipv4_dns2 = let doc = Arg.info ~doc:"Manual Second DNS IP setting." [ "ipv4-dns2" ] in Mirage_runtime.register_arg Arg.(value & opt string "10.139.1.2" doc) - module Dns_client = Dns_client.Make(My_dns) +module Dns_client = Dns_client.Make (My_dns) - (* Set up networking and listen for incoming packets. *) - let network dns_client dns_responses dns_servers qubesDB router = - (* Report success *) - Dao.set_iptables_error qubesDB "" >>= fun () -> - (* Handle packets from both networks *) - Lwt.choose [ - Dispatcher.wait_clients Mirage_mtime.elapsed_ns dns_client dns_servers qubesDB router ; - Dispatcher.uplink_wait_update qubesDB router ; - Dispatcher.uplink_listen Mirage_mtime.elapsed_ns dns_responses router +(* Set up networking and listen for incoming packets. *) +let network dns_client dns_responses dns_servers qubesDB router = + (* Report success *) + Dao.set_iptables_error qubesDB "" >>= fun () -> + (* Handle packets from both networks *) + Lwt.choose + [ + Dispatcher.wait_clients Mirage_mtime.elapsed_ns dns_client dns_servers + qubesDB router; + Dispatcher.uplink_wait_update qubesDB router; + Dispatcher.uplink_listen Mirage_mtime.elapsed_ns dns_responses router; ] - (* Main unikernel entry point (called from auto-generated main.ml). *) - let start () = - let open Lwt.Syntax in - let start_time = Mirage_mtime.elapsed_ns () in - (* Start qrexec agent and QubesDB agent in parallel *) - let* qrexec = RExec.connect ~domid:0 () in - let agent_listener = RExec.listen qrexec Command.handler in - let* qubesDB = DB.connect ~domid:0 () in - let startup_time = - let (-) = Int64.sub in - let time_in_ns = Mirage_mtime.elapsed_ns () - start_time in - Int64.to_float time_in_ns /. 1e9 - in - Log.info (fun f -> f "QubesDB and qrexec agents connected in %.3f s" startup_time); - (* Watch for shutdown requests from Qubes *) - let shutdown_rq = - Xen_os.Lifecycle.await_shutdown_request () >>= fun (`Poweroff | `Reboot) -> - Lwt.return_unit in - (* Set up networking *) - let nat = My_nat.create ~max_entries:(nat_table_size ()) in +(* Main unikernel entry point (called from auto-generated main.ml). *) +let start () = + let open Lwt.Syntax in + let start_time = Mirage_mtime.elapsed_ns () in + (* Start qrexec agent and QubesDB agent in parallel *) + let* qrexec = RExec.connect ~domid:0 () in + let agent_listener = RExec.listen qrexec Command.handler in + let* qubesDB = DB.connect ~domid:0 () in + let startup_time = + let ( - ) = Int64.sub in + let time_in_ns = Mirage_mtime.elapsed_ns () - start_time in + Int64.to_float time_in_ns /. 1e9 + in + Log.info (fun f -> + f "QubesDB and qrexec agents connected in %.3f s" startup_time); + (* Watch for shutdown requests from Qubes *) + let shutdown_rq = + Xen_os.Lifecycle.await_shutdown_request () >>= fun (`Poweroff | `Reboot) -> + Lwt.return_unit + in + (* Set up networking *) + let nat = My_nat.create ~max_entries:(nat_table_size ()) in - let netvm_ip = Ipaddr.V4.of_string_exn (ipv4_gw ()) in - let our_ip = Ipaddr.V4.of_string_exn (ipv4 ()) in - let dns = Ipaddr.V4.of_string_exn (ipv4_dns ()) in - let dns2 = Ipaddr.V4.of_string_exn (ipv4_dns2 ()) in + let netvm_ip = Ipaddr.V4.of_string_exn (ipv4_gw ()) in + let our_ip = Ipaddr.V4.of_string_exn (ipv4 ()) in + let dns = Ipaddr.V4.of_string_exn (ipv4_dns ()) in + let dns2 = Ipaddr.V4.of_string_exn (ipv4_dns2 ()) in - let zero_ip = Ipaddr.V4.any in + let zero_ip = Ipaddr.V4.any in - let network_config = - if (netvm_ip = zero_ip && our_ip = zero_ip) then (* Read network configuration from QubesDB *) - Dao.read_network_config qubesDB >>= fun config -> - if config.netvm_ip = zero_ip || config.our_ip = zero_ip then - Log.info (fun f -> f "We currently have no netvm nor command line for setting it up, aborting..."); - assert (config.netvm_ip <> zero_ip && config.our_ip <> zero_ip); - Lwt.return config - else begin - let config:Dao.network_config = {from_cmdline=true; netvm_ip; our_ip; dns; dns2} in - Lwt.return config - end - in - network_config >>= fun config -> + let network_config = + if netvm_ip = zero_ip && our_ip = zero_ip then ( + (* Read network configuration from QubesDB *) + Dao.read_network_config qubesDB + >>= fun config -> + if config.netvm_ip = zero_ip || config.our_ip = zero_ip then + Log.info (fun f -> + f + "We currently have no netvm nor command line for setting it up, \ + aborting..."); + assert (config.netvm_ip <> zero_ip && config.our_ip <> zero_ip); + Lwt.return config) + else + let config : Dao.network_config = + { from_cmdline = true; netvm_ip; our_ip; dns; dns2 } + in + Lwt.return config + in + network_config >>= fun config -> + (* We now must have a valid netvm IP address and our IP address or crash *) + Dao.print_network_config config; - (* We now must have a valid netvm IP address and our IP address or crash *) - Dao.print_network_config config ; + (* Set up client-side networking *) + let* clients = Client_eth.create config in - (* Set up client-side networking *) - let* clients = Client_eth.create config in + (* Set up routing between networks and hosts *) + let router = Dispatcher.create ~config ~clients ~nat ~uplink:None in - (* Set up routing between networks and hosts *) - let router = Dispatcher.create - ~config - ~clients - ~nat - ~uplink:None - in + let send_dns_query = Dispatcher.send_dns_client_query router in + let dns_mvar = Lwt_mvar.create_empty () in + let nameservers = (`Udp, [ (config.Dao.dns, 53); (config.Dao.dns2, 53) ]) in + let dns_client = + Dns_client.create ~nameservers (router, send_dns_query, dns_mvar) + in - let send_dns_query = Dispatcher.send_dns_client_query router in - let dns_mvar = Lwt_mvar.create_empty () in - let nameservers = `Udp, [ config.Dao.dns, 53 ; config.Dao.dns2, 53 ] in - let dns_client = Dns_client.create ~nameservers (router, send_dns_query, dns_mvar) in + let dns_servers = [ config.Dao.dns; config.Dao.dns2 ] in + let net_listener = + network + (Dns_client.getaddrinfo dns_client Dns.Rr_map.A) + dns_mvar dns_servers qubesDB router + in - let dns_servers = [ config.Dao.dns ; config.Dao.dns2 ] in - let net_listener = network (Dns_client.getaddrinfo dns_client Dns.Rr_map.A) dns_mvar dns_servers qubesDB router in - - (* Report memory usage to XenStore *) - Memory_pressure.init (); - (* Run until something fails or we get a shutdown request. *) - Lwt.choose [agent_listener; net_listener; shutdown_rq] >>= fun () -> - (* Give the console daemon time to show any final log messages. *) - Mirage_sleep.ns (1.0 *. 1e9 |> Int64.of_float) + (* Report memory usage to XenStore *) + Memory_pressure.init (); + (* Run until something fails or we get a shutdown request. *) + Lwt.choose [ agent_listener; net_listener; shutdown_rq ] >>= fun () -> + (* Give the console daemon time to show any final log messages. *) + Mirage_sleep.ns (1.0 *. 1e9 |> Int64.of_float)