diff --git a/client_net.ml b/client_net.ml index 31f3f2d..10d4412 100644 --- a/client_net.ml +++ b/client_net.ml @@ -59,7 +59,7 @@ let input_arp ~fixed_arp ~iface request = iface#writev `ARP (fun b -> Arp_packet.encode_into response b; Arp_packet.size) (** Handle an IPv4 packet from the client. *) -let input_ipv4 get_ts cache ~iface ~router packet = +let input_ipv4 get_ts cache ~iface ~router dns_client packet = let cache', r = Nat_packet.of_ipv4_packet !cache ~now:(get_ts ()) packet in cache := cache'; match r with @@ -70,7 +70,7 @@ let input_ipv4 get_ts cache ~iface ~router packet = | Ok (Some packet) -> let `IPv4 (ip, _) = packet in let src = ip.Ipv4_packet.src in - if src = iface#other_ip then Firewall.ipv4_from_client router ~src:iface packet + if src = iface#other_ip then Firewall.ipv4_from_client dns_client router ~src:iface packet else ( Log.warn (fun f -> f "Incorrect source IP %a in IP packet from %a (dropping)" Ipaddr.V4.pp src Ipaddr.V4.pp iface#other_ip); @@ -78,7 +78,7 @@ let input_ipv4 get_ts cache ~iface ~router packet = ) (** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *) -let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanup_tasks qubesDB = +let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client ~client_ip ~router ~cleanup_tasks qubesDB = Netback.make ~domid ~device_id >>= fun backend -> Log.info (fun f -> f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip)); ClientEth.connect backend >>= fun eth -> @@ -101,7 +101,7 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu (Ipaddr.V4.to_string client_ip) Fmt.(list ~sep:(unit "@.") Pf_qubes.Parse_qubes.pp_rule) new_rules); (* empty NAT table if rules are updated: they might deny old connections *) - My_nat.remove_connections router.Router.nat client_ip; + My_nat.remove_connections router.Router.nat router.Router.ports client_ip; end); update new_db new_rules in @@ -122,7 +122,7 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu | Ok (eth, payload) -> match eth.Ethernet_packet.ethertype with | `ARP -> input_arp ~fixed_arp ~iface payload - | `IPv4 -> input_ipv4 get_ts fragment_cache ~iface ~router payload + | `IPv4 -> input_ipv4 get_ts fragment_cache ~iface ~router dns_client payload | `IPv6 -> Lwt.return_unit (* TODO: oh no! *) ) >|= or_raise "Listen on client interface" Netback.pp_error) @@ -132,13 +132,13 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu Lwt.pick [ qubesdb_updater ; listener ] (** A new client VM has been found in XenStore. Find its interface and connect to it. *) -let add_client get_ts ~router vif client_ip qubesDB = +let add_client get_ts dns_client ~router vif client_ip qubesDB = let cleanup_tasks = Cleanup.create () in Log.info (fun f -> f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp client_ip); Lwt.async (fun () -> Lwt.catch (fun () -> - add_vif get_ts vif ~client_ip ~router ~cleanup_tasks qubesDB + add_vif get_ts vif dns_client ~client_ip ~router ~cleanup_tasks qubesDB ) (fun ex -> Log.warn (fun f -> f "Error with client %a: %s" @@ -149,7 +149,7 @@ let add_client get_ts ~router vif client_ip qubesDB = cleanup_tasks (** Watch XenStore for notifications of new clients. *) -let listen get_ts qubesDB router = +let listen get_ts dns_client qubesDB router = Dao.watch_clients (fun new_set -> (* Check for removed clients *) !clients |> Dao.VifMap.iter (fun key cleanup -> @@ -162,7 +162,7 @@ let listen get_ts qubesDB router = (* Check for added clients *) new_set |> Dao.VifMap.iter (fun key ip_addr -> if not (Dao.VifMap.mem key !clients) then ( - let cleanup = add_client get_ts ~router key ip_addr qubesDB in + let cleanup = add_client get_ts dns_client ~router key ip_addr qubesDB in Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key); clients := !clients |> Dao.VifMap.add key cleanup ) diff --git a/client_net.mli b/client_net.mli index 0bfbb01..fc1953a 100644 --- a/client_net.mli +++ b/client_net.mli @@ -3,8 +3,10 @@ (** Handling client VMs. *) -val listen : (unit -> int64) -> Qubes.DB.t -> Router.t -> 'a Lwt.t -(** [listen get_timestamp db router] is a thread that watches for clients being added to and +val listen : (unit -> int64) -> + ([ `host ] Domain_name.t -> (int32 * Dns.Rr_map.Ipv4_set.t, [> `Msg of string ]) result Lwt.t) -> + Qubes.DB.t -> Router.t -> 'a Lwt.t +(** [listen get_timestamp resolver db router] is a thread that watches for clients being added to and removed from XenStore. Clients are connected to the client network and packets are sent via [router]. We ensure the source IP address is correct before routing a packet. *) diff --git a/config.ml b/config.ml index 87ba926..3075006 100644 --- a/config.ml +++ b/config.ml @@ -34,6 +34,7 @@ let main = package "mirage-nat" ~min:"2.2.1"; package "mirage-logs"; package "mirage-xen" ~min:"5.0.0"; + package ~min:"4.5.0" "dns-client"; package "pf-qubes"; ] "Unikernel.Main" (random @-> mclock @-> job) diff --git a/firewall.ml b/firewall.ml index 48d4fe4..9b1587c 100644 --- a/firewall.ml +++ b/firewall.ml @@ -45,8 +45,9 @@ let translate t packet = (* Add a NAT rule for the endpoints in this frame, via a random port on the firewall. *) let add_nat_and_forward_ipv4 t packet = - let xl_host = t.Router.uplink#my_ip in - My_nat.add_nat_rule_and_translate t.Router.nat ~xl_host `NAT packet >>= function + let open Router in + let xl_host = t.uplink#my_ip in + My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host `NAT packet >>= function | Ok packet -> forward_ipv4 t packet | Error e -> Log.warn (fun f -> f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet); @@ -54,11 +55,12 @@ let add_nat_and_forward_ipv4 t packet = (* Add a NAT rule to redirect this conversation to [host:port] instead of us. *) let nat_to t ~host ~port packet = - match Router.resolve t host with + let open Router in + match resolve t host with | Ipaddr.V6 _ -> Log.warn (fun f -> f "Cannot NAT with IPv6"); Lwt.return_unit | Ipaddr.V4 target -> - let xl_host = t.Router.uplink#my_ip in - My_nat.add_nat_rule_and_translate t.Router.nat ~xl_host (`Redirect (target, port)) packet >>= function + let xl_host = t.uplink#my_ip in + My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host (`Redirect (target, port)) packet >>= function | Ok packet -> forward_ipv4 t packet | Error e -> Log.warn (fun f -> f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp packet); @@ -85,11 +87,11 @@ let handle_low_memory t = match Memory_pressure.status () with | `Memory_critical -> (* TODO: should happen before copying and async *) Log.warn (fun f -> f "Memory low - dropping packet and resetting NAT table"); - My_nat.reset t.Router.nat >|= fun () -> + My_nat.reset t.Router.nat t.Router.ports >|= fun () -> `Memory_critical | `Ok -> Lwt.return `Ok -let ipv4_from_client t ~src packet = +let ipv4_from_client resolver t ~src packet = handle_low_memory t >>= function | `Memory_critical -> Lwt.return_unit | `Ok -> @@ -102,7 +104,7 @@ let ipv4_from_client t ~src packet = let dst = Router.classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in match of_mirage_nat_packet ~src:(`Client src) ~dst packet with | None -> Lwt.return_unit - | Some firewall_packet -> apply_rules t Rules.from_client ~dst firewall_packet + | Some firewall_packet -> apply_rules t (Rules.from_client resolver) ~dst firewall_packet let ipv4_from_netvm t packet = handle_low_memory t >>= function diff --git a/firewall.mli b/firewall.mli index 9900f56..88f02ba 100644 --- a/firewall.mli +++ b/firewall.mli @@ -6,6 +6,8 @@ val ipv4_from_netvm : Router.t -> Nat_packet.t -> unit Lwt.t (** Handle a packet from the outside world (this module will validate the source IP). *) -val ipv4_from_client : Router.t -> src:Fw_utils.client_link -> Nat_packet.t -> unit Lwt.t +(* TODO the function type is a workaround, rework the module dependencies / functors to get rid of it *) +val ipv4_from_client : ([ `host ] Domain_name.t -> (int32 * Dns.Rr_map.Ipv4_set.t, [> `Msg of string ]) result Lwt.t) -> + Router.t -> src:Fw_utils.client_link -> Nat_packet.t -> unit Lwt.t (** Handle a packet from a client. Caller must check the source IP matches the client's before calling this. *) diff --git a/my_dns.ml b/my_dns.ml new file mode 100644 index 0000000..c94cbb1 --- /dev/null +++ b/my_dns.ml @@ -0,0 +1,57 @@ +open Lwt.Infix + +module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) = struct + type +'a io = 'a Lwt.t + type io_addr = Ipaddr.V4.t * int + type ns_addr = [ `TCP | `UDP ] * io_addr + type stack = Router.t * (src_port:int -> dst:Ipaddr.V4.t -> dst_port:int -> Cstruct.t -> (unit, [ `Msg of string ]) result Lwt.t) * (Udp_packet.t * Cstruct.t) Lwt_mvar.t + + type t = { + nameserver : ns_addr ; + stack : stack ; + timeout_ns : int64 ; + } + type context = { t : t ; timeout_ns : int64 ref; mutable src_port : int } + + let nameserver t = t.nameserver + let rng = R.generate ?g:None + let clock = C.elapsed_ns + + let create ?(nameserver = `UDP, (Ipaddr.V4.of_string_exn "91.239.100.100", 53)) ~timeout stack = + { nameserver ; stack ; timeout_ns = timeout } + + let with_timeout ctx f = + let timeout = OS.Time.sleep_ns !(ctx.timeout_ns) >|= fun () -> Error (`Msg "DNS request timeout") in + let start = clock () in + Lwt.pick [ f ; timeout ] >|= fun result -> + let stop = clock () in + ctx.timeout_ns := Int64.sub !(ctx.timeout_ns) (Int64.sub stop start); + result + + let connect ?nameserver:_ (t : t) = Lwt.return (Ok { t ; timeout_ns = ref t.timeout_ns ; src_port = 0 }) + + let send (ctx : context) buf : (unit, [> `Msg of string ]) result Lwt.t = + let open Router in + let open My_nat in + let dst, dst_port = snd ctx.t.nameserver in + let router, send_udp, _ = ctx.t.stack in + let src_port = Ports.pick_free_port ~consult:router.ports.nat_udp router.ports.dns_udp in + ctx.src_port <- src_port; + with_timeout ctx (send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) + + let recv ctx = + let open Router in + let open My_nat in + let router, _, answers = ctx.t.stack in + with_timeout ctx + (Lwt_mvar.take answers >|= fun (_, dns_response) -> Ok dns_response) >|= fun result -> + router.ports.dns_udp := Ports.remove ctx.src_port !(router.ports.dns_udp); + result + + let close _ = Lwt.return_unit + + let bind = Lwt.bind + + let lift = Lwt.return +end + diff --git a/my_nat.ml b/my_nat.ml index 9dfcf68..2652ff5 100644 --- a/my_nat.ml +++ b/my_nat.ml @@ -11,6 +11,20 @@ type action = [ | `Redirect of Mirage_nat.endpoint ] +type ports = { + nat_tcp : Ports.t ref; + nat_udp : Ports.t ref; + nat_icmp : Ports.t ref; + dns_udp : Ports.t ref; +} + +let empty_ports () = + let nat_tcp = ref Ports.empty in + let nat_udp = ref Ports.empty in + let nat_icmp = ref Ports.empty in + let dns_udp = ref Ports.empty in + { nat_tcp ; nat_udp ; nat_icmp ; dns_udp } + module Nat = Mirage_nat_lru type t = { @@ -33,17 +47,23 @@ let translate t packet = None | Ok packet -> Some packet -let random_user_port () = - 1024 + Random.int (0xffff - 1024) +let pick_free_port ~nat_ports ~dns_ports = + Ports.pick_free_port ~consult:dns_ports nat_ports -let reset t = +(* just clears the nat ports, dns ports stay as is *) +let reset t ports = + ports.nat_tcp := Ports.empty; + ports.nat_udp := Ports.empty; + ports.nat_icmp := Ports.empty; Nat.reset t.table -let remove_connections t ip = - let Mirage_nat.{ tcp ; udp } = Nat.remove_connections t.table ip in - ignore(tcp, udp) +let remove_connections t ports ip = + let freed_ports = Nat.remove_connections t.table ip in + ports.nat_tcp := Ports.diff !(ports.nat_tcp) (Ports.of_list freed_ports.Mirage_nat.tcp); + ports.nat_udp := Ports.diff !(ports.nat_udp) (Ports.of_list freed_ports.Mirage_nat.udp); + ports.nat_icmp := Ports.diff !(ports.nat_icmp) (Ports.of_list freed_ports.Mirage_nat.icmp) -let add_nat_rule_and_translate t ~xl_host action packet = +let add_nat_rule_and_translate t ports ~xl_host action packet = let apply_action xl_port = Lwt.catch (fun () -> Nat.add t.table packet (xl_host, xl_port) action @@ -54,19 +74,25 @@ let add_nat_rule_and_translate t ~xl_host action packet = ) in let rec aux ~retries = - let xl_port = random_user_port () in + let nat_ports, dns_ports = + match packet with + | `IPv4 (_, `TCP _) -> ports.nat_tcp, ref Ports.empty + | `IPv4 (_, `UDP _) -> ports.nat_udp, ports.dns_udp + | `IPv4 (_, `ICMP _) -> ports.nat_icmp, ref Ports.empty + in + let xl_port = pick_free_port ~nat_ports ~dns_ports in apply_action xl_port >>= function | Error `Out_of_memory -> (* Because hash tables resize in big steps, this can happen even if we have a fair chunk of free memory. *) Log.warn (fun f -> f "Out_of_memory adding NAT rule. Dropping NAT table..."); - reset t >>= fun () -> + reset t ports >>= fun () -> aux ~retries:(retries - 1) | Error `Overlap when retries < 0 -> Lwt.return (Error "Too many retries") | Error `Overlap -> if retries = 0 then ( Log.warn (fun f -> f "Failed to find a free port; resetting NAT table"); - reset t >>= fun () -> + reset t ports >>= fun () -> aux ~retries:(retries - 1) ) else ( aux ~retries:(retries - 1) diff --git a/my_nat.mli b/my_nat.mli index fc2049d..2ee21e0 100644 --- a/my_nat.mli +++ b/my_nat.mli @@ -3,6 +3,15 @@ (* Abstract over NAT interface (todo: remove this) *) +type ports = private { + nat_tcp : Ports.t ref; + nat_udp : Ports.t ref; + nat_icmp : Ports.t ref; + dns_udp : Ports.t ref; +} + +val empty_ports : unit -> ports + type t type action = [ @@ -11,8 +20,8 @@ type action = [ ] val create : max_entries:int -> t Lwt.t -val reset : t -> unit Lwt.t -val remove_connections : t -> Ipaddr.V4.t -> unit +val reset : t -> ports -> unit Lwt.t +val remove_connections : t -> ports -> Ipaddr.V4.t -> unit val translate : t -> Nat_packet.t -> Nat_packet.t option Lwt.t -val add_nat_rule_and_translate : t -> +val add_nat_rule_and_translate : t -> ports -> xl_host:Ipaddr.V4.t -> action -> Nat_packet.t -> (Nat_packet.t, string) result Lwt.t diff --git a/ports.ml b/ports.ml new file mode 100644 index 0000000..59d3205 --- /dev/null +++ b/ports.ml @@ -0,0 +1,16 @@ +module Set = Set.Make(struct + type t = int + let compare a b = compare a b +end) + +include Set + +let rec pick_free_port ?(retries = 10) ~consult add_to = + let p = 1024 + Random.int (0xffff - 1024) in + if (mem p !consult || mem p !add_to) && retries <> 0 + then pick_free_port ~retries:(retries - 1) ~consult add_to + else + begin + add_to := add p !add_to; + p + end diff --git a/router.ml b/router.ml index 4d7ed90..b91da74 100644 --- a/router.ml +++ b/router.ml @@ -9,10 +9,13 @@ type t = { client_eth : Client_eth.t; nat : My_nat.t; uplink : interface; + (* NOTE: do not try to make this pure, it relies on mvars / side effects *) + ports : My_nat.ports; } let create ~client_eth ~uplink ~nat = - { client_eth; nat; uplink } + let ports = My_nat.empty_ports () in + { client_eth; nat; uplink; ports } let target t buf = let dst_ip = buf.Ipv4_packet.dst in diff --git a/router.mli b/router.mli index 34fa86b..610bddd 100644 --- a/router.mli +++ b/router.mli @@ -9,6 +9,7 @@ type t = private { client_eth : Client_eth.t; nat : My_nat.t; uplink : interface; + ports : My_nat.ports; } val create : diff --git a/rules.ml b/rules.ml index cb6bb6f..da4706c 100644 --- a/rules.ml +++ b/rules.ml @@ -49,51 +49,60 @@ module Classifier = struct end | _, _ -> false - let matches_dest rule packet = + let matches_dest dns_client rule packet = let ip = packet.ipv4_header.Ipv4_packet.dst in match rule.Q.dst with | `any -> Lwt.return @@ `Match rule | `hosts subnet -> Lwt.return @@ if (Ipaddr.Prefix.mem Ipaddr.(V4 ip) subnet) then `Match rule else `No_match | `dnsname name -> - Log.warn (fun f -> f "Resolving %a" Domain_name.pp name); - Lwt.return @@ `No_match + Log.debug (fun f -> f "Resolving %a" Domain_name.pp name); + dns_client name >|= function + | Ok (_ttl, found_ips) -> + if Dns.Rr_map.Ipv4_set.mem ip found_ips + then `Match rule + else `No_match + | Error (`Msg m) -> + Log.warn (fun f -> f "Ignoring rule %a, could not resolve" Q.pp_rule rule); + Log.debug (fun f -> f "%s" m); + `No_match + | Error _ -> assert false (* TODO: fix type of dns_client so that this case can go *) end -let find_first_match packet acc rule = +let find_first_match dns_client packet acc rule = match acc with | `No_match -> if Classifier.matches_proto rule packet - then Classifier.matches_dest rule packet + then Classifier.matches_dest dns_client rule packet else Lwt.return `No_match | q -> Lwt.return q (* Does the packet match our rules? *) -let classify_client_packet (packet : ([`Client of Fw_utils.client_link], _) Packet.t) = +let classify_client_packet dns_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) = let (`Client client_link) = packet.src in let rules = client_link#get_rules in - Lwt_list.fold_left_s (find_first_match packet) `No_match rules >|= function + Lwt_list.fold_left_s (find_first_match dns_client packet) `No_match rules >|= function | `No_match -> `Drop "No matching rule; assuming default drop" | `Match {Q.action = Q.Accept; _} -> `Accept | `Match ({Q.action = Q.Drop; _} as rule) -> `Drop (Format.asprintf "rule number %a explicitly drops this packet" Q.pp_rule rule) -let translate_accepted_packets packet = - classify_client_packet packet >|= function +let translate_accepted_packets dns_client packet = + classify_client_packet dns_client packet >|= function | `Accept -> `NAT | `Drop s -> `Drop s (** Packets from the private interface that don't match any NAT table entry are being checked against the fw rules here *) -let from_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) : Packet.action Lwt.t = +let from_client dns_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) : Packet.action Lwt.t = match packet with | { dst = `Firewall; transport_header = `UDP header; _ } -> if header.Udp_packet.dst_port = dns_port then Lwt.return @@ `NAT_to (`NetVM, dns_port) else Lwt.return @@ `Drop "packet addressed to client gateway" - | { dst = `External _ ; _ } | { dst = `NetVM; _ } -> translate_accepted_packets packet + | { dst = `External _ ; _ } | { dst = `NetVM; _ } -> translate_accepted_packets dns_client packet | { dst = `Firewall ; _ } -> Lwt.return @@ `Drop "packet addressed to firewall itself" - | { dst = `Client _ ; _ } -> classify_client_packet packet + | { dst = `Client _ ; _ } -> classify_client_packet dns_client packet | _ -> Lwt.return @@ `Drop "could not classify packet" (** Packets from the outside world that don't match any NAT table entry are being dropped by default *) diff --git a/unikernel.ml b/unikernel.ml index 7a3b1d7..72f2c83 100644 --- a/unikernel.ml +++ b/unikernel.ml @@ -8,15 +8,18 @@ let src = Logs.Src.create "unikernel" ~doc:"Main unikernel code" module Log = (val Logs.src_log src : Logs.LOG) module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK) = struct + module Uplink = Uplink.Make(R)(Clock) + module Dns_transport = My_dns.Transport(R)(Clock) + module Dns_client = Dns_client.Make(Dns_transport) (* Set up networking and listen for incoming packets. *) - let network uplink qubesDB router = + let network dns_client dns_responses uplink qubesDB router = (* Report success *) Dao.set_iptables_error qubesDB "" >>= fun () -> (* Handle packets from both networks *) Lwt.choose [ - Client_net.listen Clock.elapsed_ns qubesDB router; - Uplink.listen uplink Clock.elapsed_ns router + Client_net.listen Clock.elapsed_ns dns_client qubesDB router; + Uplink.listen uplink Clock.elapsed_ns dns_responses router ] (* We don't use the GUI, but it's interesting to keep an eye on it. @@ -76,7 +79,11 @@ module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK) = struct ~nat in - let net_listener = network uplink qubesDB router in + let send_dns_query = Uplink.send_dns_client_query uplink in + let dns_mvar = Lwt_mvar.create_empty () in + let dns_client = Dns_client.create (router, send_dns_query, dns_mvar) in + + let net_listener = network (Dns_client.getaddrinfo dns_client Dns.Rr_map.A) dns_mvar uplink qubesDB router in (* Report memory usage to XenStore *) Memory_pressure.init (); diff --git a/uplink.ml b/uplink.ml index 343eef3..d4372b3 100644 --- a/uplink.ml +++ b/uplink.ml @@ -9,15 +9,20 @@ module Eth = Ethernet.Make(Netif) let src = Logs.Src.create "uplink" ~doc:"Network connection to NetVM" module Log = (val Logs.src_log src : Logs.LOG) -module Arp = Arp.Make(Eth)(OS.Time) +module Make (R:Mirage_random.S) (Clock : Mirage_clock.MCLOCK) = struct + module Arp = Arp.Make(Eth)(OS.Time) + module I = Static_ipv4.Make(R)(Clock)(Eth)(Arp) + module U = Udp.Make(I)(R) -type t = { - net : Netif.t; - eth : Eth.t; - arp : Arp.t; - interface : interface; - mutable fragments : Fragments.Cache.t; -} + type t = { + net : Netif.t; + eth : Eth.t; + arp : Arp.t; + interface : interface; + mutable fragments : Fragments.Cache.t; + ip : I.t; + udp: U.t; + } class netvm_iface eth mac ~my_ip ~other_ip : interface = object val queue = FrameQ.create (Ipaddr.V4.to_string other_ip) @@ -31,10 +36,26 @@ class netvm_iface eth mac ~my_ip ~other_ip : interface = object ) end -let listen t get_ts router = - Netif.listen t.net ~header_size:Ethernet_wire.sizeof_ethernet (fun frame -> - (* Handle one Ethernet frame from NetVM *) - Eth.input t.eth + let send_dns_client_query t ~src_port ~dst ~dst_port buf = + U.write ~src_port ~dst ~dst_port t.udp buf >|= function + | Error s -> Log.err (fun f -> f "error sending udp packet: %a" U.pp_error s); Error (`Msg "failure") + | Ok () -> Ok () + + let listen t get_ts dns_responses router = + let handle_packet ip_header ip_packet = + let open Udp_packet in + + Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src); + match ip_packet with + | `UDP (header, packet) when Ports.mem header.dst_port !(router.Router.ports.My_nat.dns_udp) -> + Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port); + Lwt_mvar.put dns_responses (header, packet) + | _ -> + Firewall.ipv4_from_netvm router (`IPv4 (ip_header, ip_packet)) + in + Netif.listen t.net ~header_size:Ethernet_wire.sizeof_ethernet (fun frame -> + (* Handle one Ethernet frame from NetVM *) + Eth.input t.eth ~arpv4:(Arp.input t.arp) ~ipv4:(fun ip -> let cache, r = @@ -42,30 +63,35 @@ let listen t get_ts router = in t.fragments <- cache; match r with - | Error e -> - Log.warn (fun f -> f "Ignored unknown IPv4 message from uplink: %a" Nat_packet.pp_error e); - Lwt.return_unit - | Ok None -> Lwt.return_unit - | Ok (Some packet) -> - Firewall.ipv4_from_netvm router packet - ) + | Error e -> + Log.warn (fun f -> f "Ignored unknown IPv4 message from uplink: %a" Nat_packet.pp_error e); + Lwt.return () + | Ok None -> Lwt.return_unit + | Ok (Some (`IPv4 (header, packet))) -> handle_packet header packet + ) ~ipv6:(fun _ip -> Lwt.return_unit) frame ) >|= or_raise "Uplink listen loop" Netif.pp_error + let interface t = t.interface let connect config = - let ip = config.Dao.uplink_our_ip in + let my_ip = config.Dao.uplink_our_ip in + let gateway = config.Dao.uplink_netvm_ip in Netif.connect "0" >>= fun net -> Eth.connect net >>= fun eth -> Arp.connect eth >>= fun arp -> - Arp.add_ip arp ip >>= fun () -> + Arp.add_ip arp my_ip >>= fun () -> + let network = Ipaddr.V4.Prefix.make 0 Ipaddr.V4.any in + I.connect ~ip:(network, my_ip) ~gateway eth arp >>= fun ip -> + U.connect ip >>= fun udp -> let netvm_mac = - Arp.query arp config.Dao.uplink_netvm_ip + Arp.query arp gateway >|= or_raise "Getting MAC of our NetVM" Arp.pp_error in let interface = new netvm_iface eth netvm_mac - ~my_ip:ip + ~my_ip ~other_ip:config.Dao.uplink_netvm_ip in let fragments = Fragments.Cache.empty (256 * 1024) in - Lwt.return { net; eth; arp; interface ; fragments } + Lwt.return { net; eth; arp; interface ; fragments ; ip ; udp } +end diff --git a/uplink.mli b/uplink.mli index 776b1a4..438e04a 100644 --- a/uplink.mli +++ b/uplink.mli @@ -5,13 +5,18 @@ open Fw_utils -type t +[@@@ocaml.warning "-67"] +module Make (R: Mirage_random.S)(Clock : Mirage_clock.MCLOCK) : sig + type t -val connect : Dao.network_config -> t Lwt.t -(** Connect to our NetVM (gateway). *) + val connect : Dao.network_config -> t Lwt.t + (** Connect to our NetVM (gateway). *) -val interface : t -> interface -(** The network interface to NetVM. *) + val interface : t -> interface + (** The network interface to NetVM. *) -val listen : t -> (unit -> int64) -> Router.t -> unit Lwt.t -(** Handle incoming frames from NetVM. *) + val listen : t -> (unit -> int64) -> (Udp_packet.t * Cstruct.t) Lwt_mvar.t -> Router.t -> unit Lwt.t + (** Handle incoming frames from NetVM. *) + + val send_dns_client_query: t -> src_port:int-> dst:Ipaddr.V4.t -> dst_port:int -> Cstruct.t -> (unit, [`Msg of string]) result Lwt.t +end