Support firewall rules with hostnames.

Co-Authored-By: Mindy Preston <yomimono@users.noreply.github.com>
Co-Authored-By: Olle Jonsson <olle.jonsson@gmail.com>
Co-Authored-By: hannes <hannes@mehnert.org>
Co-Authored-By: cfcs <cfcs@users.noreply.github.com>
This commit is contained in:
linse 2020-04-29 16:06:48 +02:00
parent 87df5bdcc0
commit 2d78d47591
15 changed files with 247 additions and 81 deletions

View File

@ -59,7 +59,7 @@ let input_arp ~fixed_arp ~iface request =
iface#writev `ARP (fun b -> Arp_packet.encode_into response b; Arp_packet.size) iface#writev `ARP (fun b -> Arp_packet.encode_into response b; Arp_packet.size)
(** Handle an IPv4 packet from the client. *) (** Handle an IPv4 packet from the client. *)
let input_ipv4 get_ts cache ~iface ~router packet = let input_ipv4 get_ts cache ~iface ~router dns_client packet =
let cache', r = Nat_packet.of_ipv4_packet !cache ~now:(get_ts ()) packet in let cache', r = Nat_packet.of_ipv4_packet !cache ~now:(get_ts ()) packet in
cache := cache'; cache := cache';
match r with match r with
@ -70,7 +70,7 @@ let input_ipv4 get_ts cache ~iface ~router packet =
| Ok (Some packet) -> | Ok (Some packet) ->
let `IPv4 (ip, _) = packet in let `IPv4 (ip, _) = packet in
let src = ip.Ipv4_packet.src in let src = ip.Ipv4_packet.src in
if src = iface#other_ip then Firewall.ipv4_from_client router ~src:iface packet if src = iface#other_ip then Firewall.ipv4_from_client dns_client router ~src:iface packet
else ( else (
Log.warn (fun f -> f "Incorrect source IP %a in IP packet from %a (dropping)" Log.warn (fun f -> f "Incorrect source IP %a in IP packet from %a (dropping)"
Ipaddr.V4.pp src Ipaddr.V4.pp iface#other_ip); Ipaddr.V4.pp src Ipaddr.V4.pp iface#other_ip);
@ -78,7 +78,7 @@ let input_ipv4 get_ts cache ~iface ~router packet =
) )
(** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *) (** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *)
let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanup_tasks qubesDB = let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client ~client_ip ~router ~cleanup_tasks qubesDB =
Netback.make ~domid ~device_id >>= fun backend -> Netback.make ~domid ~device_id >>= fun backend ->
Log.info (fun f -> f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip)); Log.info (fun f -> f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip));
ClientEth.connect backend >>= fun eth -> ClientEth.connect backend >>= fun eth ->
@ -101,7 +101,7 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu
(Ipaddr.V4.to_string client_ip) (Ipaddr.V4.to_string client_ip)
Fmt.(list ~sep:(unit "@.") Pf_qubes.Parse_qubes.pp_rule) new_rules); Fmt.(list ~sep:(unit "@.") Pf_qubes.Parse_qubes.pp_rule) new_rules);
(* empty NAT table if rules are updated: they might deny old connections *) (* empty NAT table if rules are updated: they might deny old connections *)
My_nat.remove_connections router.Router.nat client_ip; My_nat.remove_connections router.Router.nat router.Router.ports client_ip;
end); end);
update new_db new_rules update new_db new_rules
in in
@ -122,7 +122,7 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu
| Ok (eth, payload) -> | Ok (eth, payload) ->
match eth.Ethernet_packet.ethertype with match eth.Ethernet_packet.ethertype with
| `ARP -> input_arp ~fixed_arp ~iface payload | `ARP -> input_arp ~fixed_arp ~iface payload
| `IPv4 -> input_ipv4 get_ts fragment_cache ~iface ~router payload | `IPv4 -> input_ipv4 get_ts fragment_cache ~iface ~router dns_client payload
| `IPv6 -> Lwt.return_unit (* TODO: oh no! *) | `IPv6 -> Lwt.return_unit (* TODO: oh no! *)
) )
>|= or_raise "Listen on client interface" Netback.pp_error) >|= or_raise "Listen on client interface" Netback.pp_error)
@ -132,13 +132,13 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu
Lwt.pick [ qubesdb_updater ; listener ] Lwt.pick [ qubesdb_updater ; listener ]
(** A new client VM has been found in XenStore. Find its interface and connect to it. *) (** A new client VM has been found in XenStore. Find its interface and connect to it. *)
let add_client get_ts ~router vif client_ip qubesDB = let add_client get_ts dns_client ~router vif client_ip qubesDB =
let cleanup_tasks = Cleanup.create () in let cleanup_tasks = Cleanup.create () in
Log.info (fun f -> f "add client vif %a with IP %a" Log.info (fun f -> f "add client vif %a with IP %a"
Dao.ClientVif.pp vif Ipaddr.V4.pp client_ip); Dao.ClientVif.pp vif Ipaddr.V4.pp client_ip);
Lwt.async (fun () -> Lwt.async (fun () ->
Lwt.catch (fun () -> Lwt.catch (fun () ->
add_vif get_ts vif ~client_ip ~router ~cleanup_tasks qubesDB add_vif get_ts vif dns_client ~client_ip ~router ~cleanup_tasks qubesDB
) )
(fun ex -> (fun ex ->
Log.warn (fun f -> f "Error with client %a: %s" Log.warn (fun f -> f "Error with client %a: %s"
@ -149,7 +149,7 @@ let add_client get_ts ~router vif client_ip qubesDB =
cleanup_tasks cleanup_tasks
(** Watch XenStore for notifications of new clients. *) (** Watch XenStore for notifications of new clients. *)
let listen get_ts qubesDB router = let listen get_ts dns_client qubesDB router =
Dao.watch_clients (fun new_set -> Dao.watch_clients (fun new_set ->
(* Check for removed clients *) (* Check for removed clients *)
!clients |> Dao.VifMap.iter (fun key cleanup -> !clients |> Dao.VifMap.iter (fun key cleanup ->
@ -162,7 +162,7 @@ let listen get_ts qubesDB router =
(* Check for added clients *) (* Check for added clients *)
new_set |> Dao.VifMap.iter (fun key ip_addr -> new_set |> Dao.VifMap.iter (fun key ip_addr ->
if not (Dao.VifMap.mem key !clients) then ( if not (Dao.VifMap.mem key !clients) then (
let cleanup = add_client get_ts ~router key ip_addr qubesDB in let cleanup = add_client get_ts dns_client ~router key ip_addr qubesDB in
Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key); Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key);
clients := !clients |> Dao.VifMap.add key cleanup clients := !clients |> Dao.VifMap.add key cleanup
) )

View File

@ -3,8 +3,10 @@
(** Handling client VMs. *) (** Handling client VMs. *)
val listen : (unit -> int64) -> Qubes.DB.t -> Router.t -> 'a Lwt.t val listen : (unit -> int64) ->
(** [listen get_timestamp db router] is a thread that watches for clients being added to and ([ `host ] Domain_name.t -> (int32 * Dns.Rr_map.Ipv4_set.t, [> `Msg of string ]) result Lwt.t) ->
Qubes.DB.t -> Router.t -> 'a Lwt.t
(** [listen get_timestamp resolver db router] is a thread that watches for clients being added to and
removed from XenStore. Clients are connected to the client network and removed from XenStore. Clients are connected to the client network and
packets are sent via [router]. We ensure the source IP address is correct packets are sent via [router]. We ensure the source IP address is correct
before routing a packet. *) before routing a packet. *)

View File

@ -34,6 +34,7 @@ let main =
package "mirage-nat" ~min:"2.2.1"; package "mirage-nat" ~min:"2.2.1";
package "mirage-logs"; package "mirage-logs";
package "mirage-xen" ~min:"5.0.0"; package "mirage-xen" ~min:"5.0.0";
package ~min:"4.5.0" "dns-client";
package "pf-qubes"; package "pf-qubes";
] ]
"Unikernel.Main" (random @-> mclock @-> job) "Unikernel.Main" (random @-> mclock @-> job)

View File

@ -45,8 +45,9 @@ let translate t packet =
(* Add a NAT rule for the endpoints in this frame, via a random port on the firewall. *) (* Add a NAT rule for the endpoints in this frame, via a random port on the firewall. *)
let add_nat_and_forward_ipv4 t packet = let add_nat_and_forward_ipv4 t packet =
let xl_host = t.Router.uplink#my_ip in let open Router in
My_nat.add_nat_rule_and_translate t.Router.nat ~xl_host `NAT packet >>= function let xl_host = t.uplink#my_ip in
My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host `NAT packet >>= function
| Ok packet -> forward_ipv4 t packet | Ok packet -> forward_ipv4 t packet
| Error e -> | Error e ->
Log.warn (fun f -> f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet); Log.warn (fun f -> f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet);
@ -54,11 +55,12 @@ let add_nat_and_forward_ipv4 t packet =
(* Add a NAT rule to redirect this conversation to [host:port] instead of us. *) (* Add a NAT rule to redirect this conversation to [host:port] instead of us. *)
let nat_to t ~host ~port packet = let nat_to t ~host ~port packet =
match Router.resolve t host with let open Router in
match resolve t host with
| Ipaddr.V6 _ -> Log.warn (fun f -> f "Cannot NAT with IPv6"); Lwt.return_unit | Ipaddr.V6 _ -> Log.warn (fun f -> f "Cannot NAT with IPv6"); Lwt.return_unit
| Ipaddr.V4 target -> | Ipaddr.V4 target ->
let xl_host = t.Router.uplink#my_ip in let xl_host = t.uplink#my_ip in
My_nat.add_nat_rule_and_translate t.Router.nat ~xl_host (`Redirect (target, port)) packet >>= function My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host (`Redirect (target, port)) packet >>= function
| Ok packet -> forward_ipv4 t packet | Ok packet -> forward_ipv4 t packet
| Error e -> | Error e ->
Log.warn (fun f -> f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp packet); Log.warn (fun f -> f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp packet);
@ -85,11 +87,11 @@ let handle_low_memory t =
match Memory_pressure.status () with match Memory_pressure.status () with
| `Memory_critical -> (* TODO: should happen before copying and async *) | `Memory_critical -> (* TODO: should happen before copying and async *)
Log.warn (fun f -> f "Memory low - dropping packet and resetting NAT table"); Log.warn (fun f -> f "Memory low - dropping packet and resetting NAT table");
My_nat.reset t.Router.nat >|= fun () -> My_nat.reset t.Router.nat t.Router.ports >|= fun () ->
`Memory_critical `Memory_critical
| `Ok -> Lwt.return `Ok | `Ok -> Lwt.return `Ok
let ipv4_from_client t ~src packet = let ipv4_from_client resolver t ~src packet =
handle_low_memory t >>= function handle_low_memory t >>= function
| `Memory_critical -> Lwt.return_unit | `Memory_critical -> Lwt.return_unit
| `Ok -> | `Ok ->
@ -102,7 +104,7 @@ let ipv4_from_client t ~src packet =
let dst = Router.classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in let dst = Router.classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in
match of_mirage_nat_packet ~src:(`Client src) ~dst packet with match of_mirage_nat_packet ~src:(`Client src) ~dst packet with
| None -> Lwt.return_unit | None -> Lwt.return_unit
| Some firewall_packet -> apply_rules t Rules.from_client ~dst firewall_packet | Some firewall_packet -> apply_rules t (Rules.from_client resolver) ~dst firewall_packet
let ipv4_from_netvm t packet = let ipv4_from_netvm t packet =
handle_low_memory t >>= function handle_low_memory t >>= function

View File

@ -6,6 +6,8 @@
val ipv4_from_netvm : Router.t -> Nat_packet.t -> unit Lwt.t val ipv4_from_netvm : Router.t -> Nat_packet.t -> unit Lwt.t
(** Handle a packet from the outside world (this module will validate the source IP). *) (** Handle a packet from the outside world (this module will validate the source IP). *)
val ipv4_from_client : Router.t -> src:Fw_utils.client_link -> Nat_packet.t -> unit Lwt.t (* TODO the function type is a workaround, rework the module dependencies / functors to get rid of it *)
val ipv4_from_client : ([ `host ] Domain_name.t -> (int32 * Dns.Rr_map.Ipv4_set.t, [> `Msg of string ]) result Lwt.t) ->
Router.t -> src:Fw_utils.client_link -> Nat_packet.t -> unit Lwt.t
(** Handle a packet from a client. Caller must check the source IP matches the client's (** Handle a packet from a client. Caller must check the source IP matches the client's
before calling this. *) before calling this. *)

57
my_dns.ml Normal file
View File

@ -0,0 +1,57 @@
open Lwt.Infix
module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) = struct
type +'a io = 'a Lwt.t
type io_addr = Ipaddr.V4.t * int
type ns_addr = [ `TCP | `UDP ] * io_addr
type stack = Router.t * (src_port:int -> dst:Ipaddr.V4.t -> dst_port:int -> Cstruct.t -> (unit, [ `Msg of string ]) result Lwt.t) * (Udp_packet.t * Cstruct.t) Lwt_mvar.t
type t = {
nameserver : ns_addr ;
stack : stack ;
timeout_ns : int64 ;
}
type context = { t : t ; timeout_ns : int64 ref; mutable src_port : int }
let nameserver t = t.nameserver
let rng = R.generate ?g:None
let clock = C.elapsed_ns
let create ?(nameserver = `UDP, (Ipaddr.V4.of_string_exn "91.239.100.100", 53)) ~timeout stack =
{ nameserver ; stack ; timeout_ns = timeout }
let with_timeout ctx f =
let timeout = OS.Time.sleep_ns !(ctx.timeout_ns) >|= fun () -> Error (`Msg "DNS request timeout") in
let start = clock () in
Lwt.pick [ f ; timeout ] >|= fun result ->
let stop = clock () in
ctx.timeout_ns := Int64.sub !(ctx.timeout_ns) (Int64.sub stop start);
result
let connect ?nameserver:_ (t : t) = Lwt.return (Ok { t ; timeout_ns = ref t.timeout_ns ; src_port = 0 })
let send (ctx : context) buf : (unit, [> `Msg of string ]) result Lwt.t =
let open Router in
let open My_nat in
let dst, dst_port = snd ctx.t.nameserver in
let router, send_udp, _ = ctx.t.stack in
let src_port = Ports.pick_free_port ~consult:router.ports.nat_udp router.ports.dns_udp in
ctx.src_port <- src_port;
with_timeout ctx (send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg)
let recv ctx =
let open Router in
let open My_nat in
let router, _, answers = ctx.t.stack in
with_timeout ctx
(Lwt_mvar.take answers >|= fun (_, dns_response) -> Ok dns_response) >|= fun result ->
router.ports.dns_udp := Ports.remove ctx.src_port !(router.ports.dns_udp);
result
let close _ = Lwt.return_unit
let bind = Lwt.bind
let lift = Lwt.return
end

View File

@ -11,6 +11,20 @@ type action = [
| `Redirect of Mirage_nat.endpoint | `Redirect of Mirage_nat.endpoint
] ]
type ports = {
nat_tcp : Ports.t ref;
nat_udp : Ports.t ref;
nat_icmp : Ports.t ref;
dns_udp : Ports.t ref;
}
let empty_ports () =
let nat_tcp = ref Ports.empty in
let nat_udp = ref Ports.empty in
let nat_icmp = ref Ports.empty in
let dns_udp = ref Ports.empty in
{ nat_tcp ; nat_udp ; nat_icmp ; dns_udp }
module Nat = Mirage_nat_lru module Nat = Mirage_nat_lru
type t = { type t = {
@ -33,17 +47,23 @@ let translate t packet =
None None
| Ok packet -> Some packet | Ok packet -> Some packet
let random_user_port () = let pick_free_port ~nat_ports ~dns_ports =
1024 + Random.int (0xffff - 1024) Ports.pick_free_port ~consult:dns_ports nat_ports
let reset t = (* just clears the nat ports, dns ports stay as is *)
let reset t ports =
ports.nat_tcp := Ports.empty;
ports.nat_udp := Ports.empty;
ports.nat_icmp := Ports.empty;
Nat.reset t.table Nat.reset t.table
let remove_connections t ip = let remove_connections t ports ip =
let Mirage_nat.{ tcp ; udp } = Nat.remove_connections t.table ip in let freed_ports = Nat.remove_connections t.table ip in
ignore(tcp, udp) ports.nat_tcp := Ports.diff !(ports.nat_tcp) (Ports.of_list freed_ports.Mirage_nat.tcp);
ports.nat_udp := Ports.diff !(ports.nat_udp) (Ports.of_list freed_ports.Mirage_nat.udp);
ports.nat_icmp := Ports.diff !(ports.nat_icmp) (Ports.of_list freed_ports.Mirage_nat.icmp)
let add_nat_rule_and_translate t ~xl_host action packet = let add_nat_rule_and_translate t ports ~xl_host action packet =
let apply_action xl_port = let apply_action xl_port =
Lwt.catch (fun () -> Lwt.catch (fun () ->
Nat.add t.table packet (xl_host, xl_port) action Nat.add t.table packet (xl_host, xl_port) action
@ -54,19 +74,25 @@ let add_nat_rule_and_translate t ~xl_host action packet =
) )
in in
let rec aux ~retries = let rec aux ~retries =
let xl_port = random_user_port () in let nat_ports, dns_ports =
match packet with
| `IPv4 (_, `TCP _) -> ports.nat_tcp, ref Ports.empty
| `IPv4 (_, `UDP _) -> ports.nat_udp, ports.dns_udp
| `IPv4 (_, `ICMP _) -> ports.nat_icmp, ref Ports.empty
in
let xl_port = pick_free_port ~nat_ports ~dns_ports in
apply_action xl_port >>= function apply_action xl_port >>= function
| Error `Out_of_memory -> | Error `Out_of_memory ->
(* Because hash tables resize in big steps, this can happen even if we have a fair (* Because hash tables resize in big steps, this can happen even if we have a fair
chunk of free memory. *) chunk of free memory. *)
Log.warn (fun f -> f "Out_of_memory adding NAT rule. Dropping NAT table..."); Log.warn (fun f -> f "Out_of_memory adding NAT rule. Dropping NAT table...");
reset t >>= fun () -> reset t ports >>= fun () ->
aux ~retries:(retries - 1) aux ~retries:(retries - 1)
| Error `Overlap when retries < 0 -> Lwt.return (Error "Too many retries") | Error `Overlap when retries < 0 -> Lwt.return (Error "Too many retries")
| Error `Overlap -> | Error `Overlap ->
if retries = 0 then ( if retries = 0 then (
Log.warn (fun f -> f "Failed to find a free port; resetting NAT table"); Log.warn (fun f -> f "Failed to find a free port; resetting NAT table");
reset t >>= fun () -> reset t ports >>= fun () ->
aux ~retries:(retries - 1) aux ~retries:(retries - 1)
) else ( ) else (
aux ~retries:(retries - 1) aux ~retries:(retries - 1)

View File

@ -3,6 +3,15 @@
(* Abstract over NAT interface (todo: remove this) *) (* Abstract over NAT interface (todo: remove this) *)
type ports = private {
nat_tcp : Ports.t ref;
nat_udp : Ports.t ref;
nat_icmp : Ports.t ref;
dns_udp : Ports.t ref;
}
val empty_ports : unit -> ports
type t type t
type action = [ type action = [
@ -11,8 +20,8 @@ type action = [
] ]
val create : max_entries:int -> t Lwt.t val create : max_entries:int -> t Lwt.t
val reset : t -> unit Lwt.t val reset : t -> ports -> unit Lwt.t
val remove_connections : t -> Ipaddr.V4.t -> unit val remove_connections : t -> ports -> Ipaddr.V4.t -> unit
val translate : t -> Nat_packet.t -> Nat_packet.t option Lwt.t val translate : t -> Nat_packet.t -> Nat_packet.t option Lwt.t
val add_nat_rule_and_translate : t -> val add_nat_rule_and_translate : t -> ports ->
xl_host:Ipaddr.V4.t -> action -> Nat_packet.t -> (Nat_packet.t, string) result Lwt.t xl_host:Ipaddr.V4.t -> action -> Nat_packet.t -> (Nat_packet.t, string) result Lwt.t

16
ports.ml Normal file
View File

@ -0,0 +1,16 @@
module Set = Set.Make(struct
type t = int
let compare a b = compare a b
end)
include Set
let rec pick_free_port ?(retries = 10) ~consult add_to =
let p = 1024 + Random.int (0xffff - 1024) in
if (mem p !consult || mem p !add_to) && retries <> 0
then pick_free_port ~retries:(retries - 1) ~consult add_to
else
begin
add_to := add p !add_to;
p
end

View File

@ -9,10 +9,13 @@ type t = {
client_eth : Client_eth.t; client_eth : Client_eth.t;
nat : My_nat.t; nat : My_nat.t;
uplink : interface; uplink : interface;
(* NOTE: do not try to make this pure, it relies on mvars / side effects *)
ports : My_nat.ports;
} }
let create ~client_eth ~uplink ~nat = let create ~client_eth ~uplink ~nat =
{ client_eth; nat; uplink } let ports = My_nat.empty_ports () in
{ client_eth; nat; uplink; ports }
let target t buf = let target t buf =
let dst_ip = buf.Ipv4_packet.dst in let dst_ip = buf.Ipv4_packet.dst in

View File

@ -9,6 +9,7 @@ type t = private {
client_eth : Client_eth.t; client_eth : Client_eth.t;
nat : My_nat.t; nat : My_nat.t;
uplink : interface; uplink : interface;
ports : My_nat.ports;
} }
val create : val create :

View File

@ -49,51 +49,60 @@ module Classifier = struct
end end
| _, _ -> false | _, _ -> false
let matches_dest rule packet = let matches_dest dns_client rule packet =
let ip = packet.ipv4_header.Ipv4_packet.dst in let ip = packet.ipv4_header.Ipv4_packet.dst in
match rule.Q.dst with match rule.Q.dst with
| `any -> Lwt.return @@ `Match rule | `any -> Lwt.return @@ `Match rule
| `hosts subnet -> | `hosts subnet ->
Lwt.return @@ if (Ipaddr.Prefix.mem Ipaddr.(V4 ip) subnet) then `Match rule else `No_match Lwt.return @@ if (Ipaddr.Prefix.mem Ipaddr.(V4 ip) subnet) then `Match rule else `No_match
| `dnsname name -> | `dnsname name ->
Log.warn (fun f -> f "Resolving %a" Domain_name.pp name); Log.debug (fun f -> f "Resolving %a" Domain_name.pp name);
Lwt.return @@ `No_match dns_client name >|= function
| Ok (_ttl, found_ips) ->
if Dns.Rr_map.Ipv4_set.mem ip found_ips
then `Match rule
else `No_match
| Error (`Msg m) ->
Log.warn (fun f -> f "Ignoring rule %a, could not resolve" Q.pp_rule rule);
Log.debug (fun f -> f "%s" m);
`No_match
| Error _ -> assert false (* TODO: fix type of dns_client so that this case can go *)
end end
let find_first_match packet acc rule = let find_first_match dns_client packet acc rule =
match acc with match acc with
| `No_match -> | `No_match ->
if Classifier.matches_proto rule packet if Classifier.matches_proto rule packet
then Classifier.matches_dest rule packet then Classifier.matches_dest dns_client rule packet
else Lwt.return `No_match else Lwt.return `No_match
| q -> Lwt.return q | q -> Lwt.return q
(* Does the packet match our rules? *) (* Does the packet match our rules? *)
let classify_client_packet (packet : ([`Client of Fw_utils.client_link], _) Packet.t) = let classify_client_packet dns_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) =
let (`Client client_link) = packet.src in let (`Client client_link) = packet.src in
let rules = client_link#get_rules in let rules = client_link#get_rules in
Lwt_list.fold_left_s (find_first_match packet) `No_match rules >|= function Lwt_list.fold_left_s (find_first_match dns_client packet) `No_match rules >|= function
| `No_match -> `Drop "No matching rule; assuming default drop" | `No_match -> `Drop "No matching rule; assuming default drop"
| `Match {Q.action = Q.Accept; _} -> `Accept | `Match {Q.action = Q.Accept; _} -> `Accept
| `Match ({Q.action = Q.Drop; _} as rule) -> | `Match ({Q.action = Q.Drop; _} as rule) ->
`Drop (Format.asprintf "rule number %a explicitly drops this packet" Q.pp_rule rule) `Drop (Format.asprintf "rule number %a explicitly drops this packet" Q.pp_rule rule)
let translate_accepted_packets packet = let translate_accepted_packets dns_client packet =
classify_client_packet packet >|= function classify_client_packet dns_client packet >|= function
| `Accept -> `NAT | `Accept -> `NAT
| `Drop s -> `Drop s | `Drop s -> `Drop s
(** Packets from the private interface that don't match any NAT table entry are being checked against the fw rules here *) (** Packets from the private interface that don't match any NAT table entry are being checked against the fw rules here *)
let from_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) : Packet.action Lwt.t = let from_client dns_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) : Packet.action Lwt.t =
match packet with match packet with
| { dst = `Firewall; transport_header = `UDP header; _ } -> | { dst = `Firewall; transport_header = `UDP header; _ } ->
if header.Udp_packet.dst_port = dns_port if header.Udp_packet.dst_port = dns_port
then Lwt.return @@ `NAT_to (`NetVM, dns_port) then Lwt.return @@ `NAT_to (`NetVM, dns_port)
else Lwt.return @@ `Drop "packet addressed to client gateway" else Lwt.return @@ `Drop "packet addressed to client gateway"
| { dst = `External _ ; _ } | { dst = `NetVM; _ } -> translate_accepted_packets packet | { dst = `External _ ; _ } | { dst = `NetVM; _ } -> translate_accepted_packets dns_client packet
| { dst = `Firewall ; _ } -> Lwt.return @@ `Drop "packet addressed to firewall itself" | { dst = `Firewall ; _ } -> Lwt.return @@ `Drop "packet addressed to firewall itself"
| { dst = `Client _ ; _ } -> classify_client_packet packet | { dst = `Client _ ; _ } -> classify_client_packet dns_client packet
| _ -> Lwt.return @@ `Drop "could not classify packet" | _ -> Lwt.return @@ `Drop "could not classify packet"
(** Packets from the outside world that don't match any NAT table entry are being dropped by default *) (** Packets from the outside world that don't match any NAT table entry are being dropped by default *)

View File

@ -8,15 +8,18 @@ let src = Logs.Src.create "unikernel" ~doc:"Main unikernel code"
module Log = (val Logs.src_log src : Logs.LOG) module Log = (val Logs.src_log src : Logs.LOG)
module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK) = struct module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK) = struct
module Uplink = Uplink.Make(R)(Clock)
module Dns_transport = My_dns.Transport(R)(Clock)
module Dns_client = Dns_client.Make(Dns_transport)
(* Set up networking and listen for incoming packets. *) (* Set up networking and listen for incoming packets. *)
let network uplink qubesDB router = let network dns_client dns_responses uplink qubesDB router =
(* Report success *) (* Report success *)
Dao.set_iptables_error qubesDB "" >>= fun () -> Dao.set_iptables_error qubesDB "" >>= fun () ->
(* Handle packets from both networks *) (* Handle packets from both networks *)
Lwt.choose [ Lwt.choose [
Client_net.listen Clock.elapsed_ns qubesDB router; Client_net.listen Clock.elapsed_ns dns_client qubesDB router;
Uplink.listen uplink Clock.elapsed_ns router Uplink.listen uplink Clock.elapsed_ns dns_responses router
] ]
(* We don't use the GUI, but it's interesting to keep an eye on it. (* We don't use the GUI, but it's interesting to keep an eye on it.
@ -76,7 +79,11 @@ module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK) = struct
~nat ~nat
in in
let net_listener = network uplink qubesDB router in let send_dns_query = Uplink.send_dns_client_query uplink in
let dns_mvar = Lwt_mvar.create_empty () in
let dns_client = Dns_client.create (router, send_dns_query, dns_mvar) in
let net_listener = network (Dns_client.getaddrinfo dns_client Dns.Rr_map.A) dns_mvar uplink qubesDB router in
(* Report memory usage to XenStore *) (* Report memory usage to XenStore *)
Memory_pressure.init (); Memory_pressure.init ();

View File

@ -9,15 +9,20 @@ module Eth = Ethernet.Make(Netif)
let src = Logs.Src.create "uplink" ~doc:"Network connection to NetVM" let src = Logs.Src.create "uplink" ~doc:"Network connection to NetVM"
module Log = (val Logs.src_log src : Logs.LOG) module Log = (val Logs.src_log src : Logs.LOG)
module Arp = Arp.Make(Eth)(OS.Time) module Make (R:Mirage_random.S) (Clock : Mirage_clock.MCLOCK) = struct
module Arp = Arp.Make(Eth)(OS.Time)
module I = Static_ipv4.Make(R)(Clock)(Eth)(Arp)
module U = Udp.Make(I)(R)
type t = { type t = {
net : Netif.t; net : Netif.t;
eth : Eth.t; eth : Eth.t;
arp : Arp.t; arp : Arp.t;
interface : interface; interface : interface;
mutable fragments : Fragments.Cache.t; mutable fragments : Fragments.Cache.t;
} ip : I.t;
udp: U.t;
}
class netvm_iface eth mac ~my_ip ~other_ip : interface = object class netvm_iface eth mac ~my_ip ~other_ip : interface = object
val queue = FrameQ.create (Ipaddr.V4.to_string other_ip) val queue = FrameQ.create (Ipaddr.V4.to_string other_ip)
@ -31,10 +36,26 @@ class netvm_iface eth mac ~my_ip ~other_ip : interface = object
) )
end end
let listen t get_ts router = let send_dns_client_query t ~src_port ~dst ~dst_port buf =
Netif.listen t.net ~header_size:Ethernet_wire.sizeof_ethernet (fun frame -> U.write ~src_port ~dst ~dst_port t.udp buf >|= function
(* Handle one Ethernet frame from NetVM *) | Error s -> Log.err (fun f -> f "error sending udp packet: %a" U.pp_error s); Error (`Msg "failure")
Eth.input t.eth | Ok () -> Ok ()
let listen t get_ts dns_responses router =
let handle_packet ip_header ip_packet =
let open Udp_packet in
Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src);
match ip_packet with
| `UDP (header, packet) when Ports.mem header.dst_port !(router.Router.ports.My_nat.dns_udp) ->
Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port);
Lwt_mvar.put dns_responses (header, packet)
| _ ->
Firewall.ipv4_from_netvm router (`IPv4 (ip_header, ip_packet))
in
Netif.listen t.net ~header_size:Ethernet_wire.sizeof_ethernet (fun frame ->
(* Handle one Ethernet frame from NetVM *)
Eth.input t.eth
~arpv4:(Arp.input t.arp) ~arpv4:(Arp.input t.arp)
~ipv4:(fun ip -> ~ipv4:(fun ip ->
let cache, r = let cache, r =
@ -42,30 +63,35 @@ let listen t get_ts router =
in in
t.fragments <- cache; t.fragments <- cache;
match r with match r with
| Error e -> | Error e ->
Log.warn (fun f -> f "Ignored unknown IPv4 message from uplink: %a" Nat_packet.pp_error e); Log.warn (fun f -> f "Ignored unknown IPv4 message from uplink: %a" Nat_packet.pp_error e);
Lwt.return_unit Lwt.return ()
| Ok None -> Lwt.return_unit | Ok None -> Lwt.return_unit
| Ok (Some packet) -> | Ok (Some (`IPv4 (header, packet))) -> handle_packet header packet
Firewall.ipv4_from_netvm router packet )
)
~ipv6:(fun _ip -> Lwt.return_unit) ~ipv6:(fun _ip -> Lwt.return_unit)
frame frame
) >|= or_raise "Uplink listen loop" Netif.pp_error ) >|= or_raise "Uplink listen loop" Netif.pp_error
let interface t = t.interface let interface t = t.interface
let connect config = let connect config =
let ip = config.Dao.uplink_our_ip in let my_ip = config.Dao.uplink_our_ip in
let gateway = config.Dao.uplink_netvm_ip in
Netif.connect "0" >>= fun net -> Netif.connect "0" >>= fun net ->
Eth.connect net >>= fun eth -> Eth.connect net >>= fun eth ->
Arp.connect eth >>= fun arp -> Arp.connect eth >>= fun arp ->
Arp.add_ip arp ip >>= fun () -> Arp.add_ip arp my_ip >>= fun () ->
let network = Ipaddr.V4.Prefix.make 0 Ipaddr.V4.any in
I.connect ~ip:(network, my_ip) ~gateway eth arp >>= fun ip ->
U.connect ip >>= fun udp ->
let netvm_mac = let netvm_mac =
Arp.query arp config.Dao.uplink_netvm_ip Arp.query arp gateway
>|= or_raise "Getting MAC of our NetVM" Arp.pp_error in >|= or_raise "Getting MAC of our NetVM" Arp.pp_error in
let interface = new netvm_iface eth netvm_mac let interface = new netvm_iface eth netvm_mac
~my_ip:ip ~my_ip
~other_ip:config.Dao.uplink_netvm_ip in ~other_ip:config.Dao.uplink_netvm_ip in
let fragments = Fragments.Cache.empty (256 * 1024) in let fragments = Fragments.Cache.empty (256 * 1024) in
Lwt.return { net; eth; arp; interface ; fragments } Lwt.return { net; eth; arp; interface ; fragments ; ip ; udp }
end

View File

@ -5,13 +5,18 @@
open Fw_utils open Fw_utils
type t [@@@ocaml.warning "-67"]
module Make (R: Mirage_random.S)(Clock : Mirage_clock.MCLOCK) : sig
type t
val connect : Dao.network_config -> t Lwt.t val connect : Dao.network_config -> t Lwt.t
(** Connect to our NetVM (gateway). *) (** Connect to our NetVM (gateway). *)
val interface : t -> interface val interface : t -> interface
(** The network interface to NetVM. *) (** The network interface to NetVM. *)
val listen : t -> (unit -> int64) -> Router.t -> unit Lwt.t val listen : t -> (unit -> int64) -> (Udp_packet.t * Cstruct.t) Lwt_mvar.t -> Router.t -> unit Lwt.t
(** Handle incoming frames from NetVM. *) (** Handle incoming frames from NetVM. *)
val send_dns_client_query: t -> src_port:int-> dst:Ipaddr.V4.t -> dst_port:int -> Cstruct.t -> (unit, [`Msg of string]) result Lwt.t
end