mirror of
https://github.com/mirage/qubes-mirage-firewall.git
synced 2025-02-05 09:55:30 -05:00
Support firewall rules with hostnames.
Co-Authored-By: Mindy Preston <yomimono@users.noreply.github.com> Co-Authored-By: Olle Jonsson <olle.jonsson@gmail.com> Co-Authored-By: hannes <hannes@mehnert.org> Co-Authored-By: cfcs <cfcs@users.noreply.github.com>
This commit is contained in:
parent
87df5bdcc0
commit
2d78d47591
@ -59,7 +59,7 @@ let input_arp ~fixed_arp ~iface request =
|
||||
iface#writev `ARP (fun b -> Arp_packet.encode_into response b; Arp_packet.size)
|
||||
|
||||
(** Handle an IPv4 packet from the client. *)
|
||||
let input_ipv4 get_ts cache ~iface ~router packet =
|
||||
let input_ipv4 get_ts cache ~iface ~router dns_client packet =
|
||||
let cache', r = Nat_packet.of_ipv4_packet !cache ~now:(get_ts ()) packet in
|
||||
cache := cache';
|
||||
match r with
|
||||
@ -70,7 +70,7 @@ let input_ipv4 get_ts cache ~iface ~router packet =
|
||||
| Ok (Some packet) ->
|
||||
let `IPv4 (ip, _) = packet in
|
||||
let src = ip.Ipv4_packet.src in
|
||||
if src = iface#other_ip then Firewall.ipv4_from_client router ~src:iface packet
|
||||
if src = iface#other_ip then Firewall.ipv4_from_client dns_client router ~src:iface packet
|
||||
else (
|
||||
Log.warn (fun f -> f "Incorrect source IP %a in IP packet from %a (dropping)"
|
||||
Ipaddr.V4.pp src Ipaddr.V4.pp iface#other_ip);
|
||||
@ -78,7 +78,7 @@ let input_ipv4 get_ts cache ~iface ~router packet =
|
||||
)
|
||||
|
||||
(** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *)
|
||||
let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanup_tasks qubesDB =
|
||||
let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client ~client_ip ~router ~cleanup_tasks qubesDB =
|
||||
Netback.make ~domid ~device_id >>= fun backend ->
|
||||
Log.info (fun f -> f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip));
|
||||
ClientEth.connect backend >>= fun eth ->
|
||||
@ -101,7 +101,7 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu
|
||||
(Ipaddr.V4.to_string client_ip)
|
||||
Fmt.(list ~sep:(unit "@.") Pf_qubes.Parse_qubes.pp_rule) new_rules);
|
||||
(* empty NAT table if rules are updated: they might deny old connections *)
|
||||
My_nat.remove_connections router.Router.nat client_ip;
|
||||
My_nat.remove_connections router.Router.nat router.Router.ports client_ip;
|
||||
end);
|
||||
update new_db new_rules
|
||||
in
|
||||
@ -122,7 +122,7 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu
|
||||
| Ok (eth, payload) ->
|
||||
match eth.Ethernet_packet.ethertype with
|
||||
| `ARP -> input_arp ~fixed_arp ~iface payload
|
||||
| `IPv4 -> input_ipv4 get_ts fragment_cache ~iface ~router payload
|
||||
| `IPv4 -> input_ipv4 get_ts fragment_cache ~iface ~router dns_client payload
|
||||
| `IPv6 -> Lwt.return_unit (* TODO: oh no! *)
|
||||
)
|
||||
>|= or_raise "Listen on client interface" Netback.pp_error)
|
||||
@ -132,13 +132,13 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } ~client_ip ~router ~cleanu
|
||||
Lwt.pick [ qubesdb_updater ; listener ]
|
||||
|
||||
(** A new client VM has been found in XenStore. Find its interface and connect to it. *)
|
||||
let add_client get_ts ~router vif client_ip qubesDB =
|
||||
let add_client get_ts dns_client ~router vif client_ip qubesDB =
|
||||
let cleanup_tasks = Cleanup.create () in
|
||||
Log.info (fun f -> f "add client vif %a with IP %a"
|
||||
Dao.ClientVif.pp vif Ipaddr.V4.pp client_ip);
|
||||
Lwt.async (fun () ->
|
||||
Lwt.catch (fun () ->
|
||||
add_vif get_ts vif ~client_ip ~router ~cleanup_tasks qubesDB
|
||||
add_vif get_ts vif dns_client ~client_ip ~router ~cleanup_tasks qubesDB
|
||||
)
|
||||
(fun ex ->
|
||||
Log.warn (fun f -> f "Error with client %a: %s"
|
||||
@ -149,7 +149,7 @@ let add_client get_ts ~router vif client_ip qubesDB =
|
||||
cleanup_tasks
|
||||
|
||||
(** Watch XenStore for notifications of new clients. *)
|
||||
let listen get_ts qubesDB router =
|
||||
let listen get_ts dns_client qubesDB router =
|
||||
Dao.watch_clients (fun new_set ->
|
||||
(* Check for removed clients *)
|
||||
!clients |> Dao.VifMap.iter (fun key cleanup ->
|
||||
@ -162,7 +162,7 @@ let listen get_ts qubesDB router =
|
||||
(* Check for added clients *)
|
||||
new_set |> Dao.VifMap.iter (fun key ip_addr ->
|
||||
if not (Dao.VifMap.mem key !clients) then (
|
||||
let cleanup = add_client get_ts ~router key ip_addr qubesDB in
|
||||
let cleanup = add_client get_ts dns_client ~router key ip_addr qubesDB in
|
||||
Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key);
|
||||
clients := !clients |> Dao.VifMap.add key cleanup
|
||||
)
|
||||
|
@ -3,8 +3,10 @@
|
||||
|
||||
(** Handling client VMs. *)
|
||||
|
||||
val listen : (unit -> int64) -> Qubes.DB.t -> Router.t -> 'a Lwt.t
|
||||
(** [listen get_timestamp db router] is a thread that watches for clients being added to and
|
||||
val listen : (unit -> int64) ->
|
||||
([ `host ] Domain_name.t -> (int32 * Dns.Rr_map.Ipv4_set.t, [> `Msg of string ]) result Lwt.t) ->
|
||||
Qubes.DB.t -> Router.t -> 'a Lwt.t
|
||||
(** [listen get_timestamp resolver db router] is a thread that watches for clients being added to and
|
||||
removed from XenStore. Clients are connected to the client network and
|
||||
packets are sent via [router]. We ensure the source IP address is correct
|
||||
before routing a packet. *)
|
||||
|
@ -34,6 +34,7 @@ let main =
|
||||
package "mirage-nat" ~min:"2.2.1";
|
||||
package "mirage-logs";
|
||||
package "mirage-xen" ~min:"5.0.0";
|
||||
package ~min:"4.5.0" "dns-client";
|
||||
package "pf-qubes";
|
||||
]
|
||||
"Unikernel.Main" (random @-> mclock @-> job)
|
||||
|
18
firewall.ml
18
firewall.ml
@ -45,8 +45,9 @@ let translate t packet =
|
||||
|
||||
(* Add a NAT rule for the endpoints in this frame, via a random port on the firewall. *)
|
||||
let add_nat_and_forward_ipv4 t packet =
|
||||
let xl_host = t.Router.uplink#my_ip in
|
||||
My_nat.add_nat_rule_and_translate t.Router.nat ~xl_host `NAT packet >>= function
|
||||
let open Router in
|
||||
let xl_host = t.uplink#my_ip in
|
||||
My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host `NAT packet >>= function
|
||||
| Ok packet -> forward_ipv4 t packet
|
||||
| Error e ->
|
||||
Log.warn (fun f -> f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet);
|
||||
@ -54,11 +55,12 @@ let add_nat_and_forward_ipv4 t packet =
|
||||
|
||||
(* Add a NAT rule to redirect this conversation to [host:port] instead of us. *)
|
||||
let nat_to t ~host ~port packet =
|
||||
match Router.resolve t host with
|
||||
let open Router in
|
||||
match resolve t host with
|
||||
| Ipaddr.V6 _ -> Log.warn (fun f -> f "Cannot NAT with IPv6"); Lwt.return_unit
|
||||
| Ipaddr.V4 target ->
|
||||
let xl_host = t.Router.uplink#my_ip in
|
||||
My_nat.add_nat_rule_and_translate t.Router.nat ~xl_host (`Redirect (target, port)) packet >>= function
|
||||
let xl_host = t.uplink#my_ip in
|
||||
My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host (`Redirect (target, port)) packet >>= function
|
||||
| Ok packet -> forward_ipv4 t packet
|
||||
| Error e ->
|
||||
Log.warn (fun f -> f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp packet);
|
||||
@ -85,11 +87,11 @@ let handle_low_memory t =
|
||||
match Memory_pressure.status () with
|
||||
| `Memory_critical -> (* TODO: should happen before copying and async *)
|
||||
Log.warn (fun f -> f "Memory low - dropping packet and resetting NAT table");
|
||||
My_nat.reset t.Router.nat >|= fun () ->
|
||||
My_nat.reset t.Router.nat t.Router.ports >|= fun () ->
|
||||
`Memory_critical
|
||||
| `Ok -> Lwt.return `Ok
|
||||
|
||||
let ipv4_from_client t ~src packet =
|
||||
let ipv4_from_client resolver t ~src packet =
|
||||
handle_low_memory t >>= function
|
||||
| `Memory_critical -> Lwt.return_unit
|
||||
| `Ok ->
|
||||
@ -102,7 +104,7 @@ let ipv4_from_client t ~src packet =
|
||||
let dst = Router.classify t (Ipaddr.V4 ip.Ipv4_packet.dst) in
|
||||
match of_mirage_nat_packet ~src:(`Client src) ~dst packet with
|
||||
| None -> Lwt.return_unit
|
||||
| Some firewall_packet -> apply_rules t Rules.from_client ~dst firewall_packet
|
||||
| Some firewall_packet -> apply_rules t (Rules.from_client resolver) ~dst firewall_packet
|
||||
|
||||
let ipv4_from_netvm t packet =
|
||||
handle_low_memory t >>= function
|
||||
|
@ -6,6 +6,8 @@
|
||||
val ipv4_from_netvm : Router.t -> Nat_packet.t -> unit Lwt.t
|
||||
(** Handle a packet from the outside world (this module will validate the source IP). *)
|
||||
|
||||
val ipv4_from_client : Router.t -> src:Fw_utils.client_link -> Nat_packet.t -> unit Lwt.t
|
||||
(* TODO the function type is a workaround, rework the module dependencies / functors to get rid of it *)
|
||||
val ipv4_from_client : ([ `host ] Domain_name.t -> (int32 * Dns.Rr_map.Ipv4_set.t, [> `Msg of string ]) result Lwt.t) ->
|
||||
Router.t -> src:Fw_utils.client_link -> Nat_packet.t -> unit Lwt.t
|
||||
(** Handle a packet from a client. Caller must check the source IP matches the client's
|
||||
before calling this. *)
|
||||
|
57
my_dns.ml
Normal file
57
my_dns.ml
Normal file
@ -0,0 +1,57 @@
|
||||
open Lwt.Infix
|
||||
|
||||
module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) = struct
|
||||
type +'a io = 'a Lwt.t
|
||||
type io_addr = Ipaddr.V4.t * int
|
||||
type ns_addr = [ `TCP | `UDP ] * io_addr
|
||||
type stack = Router.t * (src_port:int -> dst:Ipaddr.V4.t -> dst_port:int -> Cstruct.t -> (unit, [ `Msg of string ]) result Lwt.t) * (Udp_packet.t * Cstruct.t) Lwt_mvar.t
|
||||
|
||||
type t = {
|
||||
nameserver : ns_addr ;
|
||||
stack : stack ;
|
||||
timeout_ns : int64 ;
|
||||
}
|
||||
type context = { t : t ; timeout_ns : int64 ref; mutable src_port : int }
|
||||
|
||||
let nameserver t = t.nameserver
|
||||
let rng = R.generate ?g:None
|
||||
let clock = C.elapsed_ns
|
||||
|
||||
let create ?(nameserver = `UDP, (Ipaddr.V4.of_string_exn "91.239.100.100", 53)) ~timeout stack =
|
||||
{ nameserver ; stack ; timeout_ns = timeout }
|
||||
|
||||
let with_timeout ctx f =
|
||||
let timeout = OS.Time.sleep_ns !(ctx.timeout_ns) >|= fun () -> Error (`Msg "DNS request timeout") in
|
||||
let start = clock () in
|
||||
Lwt.pick [ f ; timeout ] >|= fun result ->
|
||||
let stop = clock () in
|
||||
ctx.timeout_ns := Int64.sub !(ctx.timeout_ns) (Int64.sub stop start);
|
||||
result
|
||||
|
||||
let connect ?nameserver:_ (t : t) = Lwt.return (Ok { t ; timeout_ns = ref t.timeout_ns ; src_port = 0 })
|
||||
|
||||
let send (ctx : context) buf : (unit, [> `Msg of string ]) result Lwt.t =
|
||||
let open Router in
|
||||
let open My_nat in
|
||||
let dst, dst_port = snd ctx.t.nameserver in
|
||||
let router, send_udp, _ = ctx.t.stack in
|
||||
let src_port = Ports.pick_free_port ~consult:router.ports.nat_udp router.ports.dns_udp in
|
||||
ctx.src_port <- src_port;
|
||||
with_timeout ctx (send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg)
|
||||
|
||||
let recv ctx =
|
||||
let open Router in
|
||||
let open My_nat in
|
||||
let router, _, answers = ctx.t.stack in
|
||||
with_timeout ctx
|
||||
(Lwt_mvar.take answers >|= fun (_, dns_response) -> Ok dns_response) >|= fun result ->
|
||||
router.ports.dns_udp := Ports.remove ctx.src_port !(router.ports.dns_udp);
|
||||
result
|
||||
|
||||
let close _ = Lwt.return_unit
|
||||
|
||||
let bind = Lwt.bind
|
||||
|
||||
let lift = Lwt.return
|
||||
end
|
||||
|
46
my_nat.ml
46
my_nat.ml
@ -11,6 +11,20 @@ type action = [
|
||||
| `Redirect of Mirage_nat.endpoint
|
||||
]
|
||||
|
||||
type ports = {
|
||||
nat_tcp : Ports.t ref;
|
||||
nat_udp : Ports.t ref;
|
||||
nat_icmp : Ports.t ref;
|
||||
dns_udp : Ports.t ref;
|
||||
}
|
||||
|
||||
let empty_ports () =
|
||||
let nat_tcp = ref Ports.empty in
|
||||
let nat_udp = ref Ports.empty in
|
||||
let nat_icmp = ref Ports.empty in
|
||||
let dns_udp = ref Ports.empty in
|
||||
{ nat_tcp ; nat_udp ; nat_icmp ; dns_udp }
|
||||
|
||||
module Nat = Mirage_nat_lru
|
||||
|
||||
type t = {
|
||||
@ -33,17 +47,23 @@ let translate t packet =
|
||||
None
|
||||
| Ok packet -> Some packet
|
||||
|
||||
let random_user_port () =
|
||||
1024 + Random.int (0xffff - 1024)
|
||||
let pick_free_port ~nat_ports ~dns_ports =
|
||||
Ports.pick_free_port ~consult:dns_ports nat_ports
|
||||
|
||||
let reset t =
|
||||
(* just clears the nat ports, dns ports stay as is *)
|
||||
let reset t ports =
|
||||
ports.nat_tcp := Ports.empty;
|
||||
ports.nat_udp := Ports.empty;
|
||||
ports.nat_icmp := Ports.empty;
|
||||
Nat.reset t.table
|
||||
|
||||
let remove_connections t ip =
|
||||
let Mirage_nat.{ tcp ; udp } = Nat.remove_connections t.table ip in
|
||||
ignore(tcp, udp)
|
||||
let remove_connections t ports ip =
|
||||
let freed_ports = Nat.remove_connections t.table ip in
|
||||
ports.nat_tcp := Ports.diff !(ports.nat_tcp) (Ports.of_list freed_ports.Mirage_nat.tcp);
|
||||
ports.nat_udp := Ports.diff !(ports.nat_udp) (Ports.of_list freed_ports.Mirage_nat.udp);
|
||||
ports.nat_icmp := Ports.diff !(ports.nat_icmp) (Ports.of_list freed_ports.Mirage_nat.icmp)
|
||||
|
||||
let add_nat_rule_and_translate t ~xl_host action packet =
|
||||
let add_nat_rule_and_translate t ports ~xl_host action packet =
|
||||
let apply_action xl_port =
|
||||
Lwt.catch (fun () ->
|
||||
Nat.add t.table packet (xl_host, xl_port) action
|
||||
@ -54,19 +74,25 @@ let add_nat_rule_and_translate t ~xl_host action packet =
|
||||
)
|
||||
in
|
||||
let rec aux ~retries =
|
||||
let xl_port = random_user_port () in
|
||||
let nat_ports, dns_ports =
|
||||
match packet with
|
||||
| `IPv4 (_, `TCP _) -> ports.nat_tcp, ref Ports.empty
|
||||
| `IPv4 (_, `UDP _) -> ports.nat_udp, ports.dns_udp
|
||||
| `IPv4 (_, `ICMP _) -> ports.nat_icmp, ref Ports.empty
|
||||
in
|
||||
let xl_port = pick_free_port ~nat_ports ~dns_ports in
|
||||
apply_action xl_port >>= function
|
||||
| Error `Out_of_memory ->
|
||||
(* Because hash tables resize in big steps, this can happen even if we have a fair
|
||||
chunk of free memory. *)
|
||||
Log.warn (fun f -> f "Out_of_memory adding NAT rule. Dropping NAT table...");
|
||||
reset t >>= fun () ->
|
||||
reset t ports >>= fun () ->
|
||||
aux ~retries:(retries - 1)
|
||||
| Error `Overlap when retries < 0 -> Lwt.return (Error "Too many retries")
|
||||
| Error `Overlap ->
|
||||
if retries = 0 then (
|
||||
Log.warn (fun f -> f "Failed to find a free port; resetting NAT table");
|
||||
reset t >>= fun () ->
|
||||
reset t ports >>= fun () ->
|
||||
aux ~retries:(retries - 1)
|
||||
) else (
|
||||
aux ~retries:(retries - 1)
|
||||
|
15
my_nat.mli
15
my_nat.mli
@ -3,6 +3,15 @@
|
||||
|
||||
(* Abstract over NAT interface (todo: remove this) *)
|
||||
|
||||
type ports = private {
|
||||
nat_tcp : Ports.t ref;
|
||||
nat_udp : Ports.t ref;
|
||||
nat_icmp : Ports.t ref;
|
||||
dns_udp : Ports.t ref;
|
||||
}
|
||||
|
||||
val empty_ports : unit -> ports
|
||||
|
||||
type t
|
||||
|
||||
type action = [
|
||||
@ -11,8 +20,8 @@ type action = [
|
||||
]
|
||||
|
||||
val create : max_entries:int -> t Lwt.t
|
||||
val reset : t -> unit Lwt.t
|
||||
val remove_connections : t -> Ipaddr.V4.t -> unit
|
||||
val reset : t -> ports -> unit Lwt.t
|
||||
val remove_connections : t -> ports -> Ipaddr.V4.t -> unit
|
||||
val translate : t -> Nat_packet.t -> Nat_packet.t option Lwt.t
|
||||
val add_nat_rule_and_translate : t ->
|
||||
val add_nat_rule_and_translate : t -> ports ->
|
||||
xl_host:Ipaddr.V4.t -> action -> Nat_packet.t -> (Nat_packet.t, string) result Lwt.t
|
||||
|
16
ports.ml
Normal file
16
ports.ml
Normal file
@ -0,0 +1,16 @@
|
||||
module Set = Set.Make(struct
|
||||
type t = int
|
||||
let compare a b = compare a b
|
||||
end)
|
||||
|
||||
include Set
|
||||
|
||||
let rec pick_free_port ?(retries = 10) ~consult add_to =
|
||||
let p = 1024 + Random.int (0xffff - 1024) in
|
||||
if (mem p !consult || mem p !add_to) && retries <> 0
|
||||
then pick_free_port ~retries:(retries - 1) ~consult add_to
|
||||
else
|
||||
begin
|
||||
add_to := add p !add_to;
|
||||
p
|
||||
end
|
@ -9,10 +9,13 @@ type t = {
|
||||
client_eth : Client_eth.t;
|
||||
nat : My_nat.t;
|
||||
uplink : interface;
|
||||
(* NOTE: do not try to make this pure, it relies on mvars / side effects *)
|
||||
ports : My_nat.ports;
|
||||
}
|
||||
|
||||
let create ~client_eth ~uplink ~nat =
|
||||
{ client_eth; nat; uplink }
|
||||
let ports = My_nat.empty_ports () in
|
||||
{ client_eth; nat; uplink; ports }
|
||||
|
||||
let target t buf =
|
||||
let dst_ip = buf.Ipv4_packet.dst in
|
||||
|
@ -9,6 +9,7 @@ type t = private {
|
||||
client_eth : Client_eth.t;
|
||||
nat : My_nat.t;
|
||||
uplink : interface;
|
||||
ports : My_nat.ports;
|
||||
}
|
||||
|
||||
val create :
|
||||
|
33
rules.ml
33
rules.ml
@ -49,51 +49,60 @@ module Classifier = struct
|
||||
end
|
||||
| _, _ -> false
|
||||
|
||||
let matches_dest rule packet =
|
||||
let matches_dest dns_client rule packet =
|
||||
let ip = packet.ipv4_header.Ipv4_packet.dst in
|
||||
match rule.Q.dst with
|
||||
| `any -> Lwt.return @@ `Match rule
|
||||
| `hosts subnet ->
|
||||
Lwt.return @@ if (Ipaddr.Prefix.mem Ipaddr.(V4 ip) subnet) then `Match rule else `No_match
|
||||
| `dnsname name ->
|
||||
Log.warn (fun f -> f "Resolving %a" Domain_name.pp name);
|
||||
Lwt.return @@ `No_match
|
||||
Log.debug (fun f -> f "Resolving %a" Domain_name.pp name);
|
||||
dns_client name >|= function
|
||||
| Ok (_ttl, found_ips) ->
|
||||
if Dns.Rr_map.Ipv4_set.mem ip found_ips
|
||||
then `Match rule
|
||||
else `No_match
|
||||
| Error (`Msg m) ->
|
||||
Log.warn (fun f -> f "Ignoring rule %a, could not resolve" Q.pp_rule rule);
|
||||
Log.debug (fun f -> f "%s" m);
|
||||
`No_match
|
||||
| Error _ -> assert false (* TODO: fix type of dns_client so that this case can go *)
|
||||
|
||||
end
|
||||
|
||||
let find_first_match packet acc rule =
|
||||
let find_first_match dns_client packet acc rule =
|
||||
match acc with
|
||||
| `No_match ->
|
||||
if Classifier.matches_proto rule packet
|
||||
then Classifier.matches_dest rule packet
|
||||
then Classifier.matches_dest dns_client rule packet
|
||||
else Lwt.return `No_match
|
||||
| q -> Lwt.return q
|
||||
|
||||
(* Does the packet match our rules? *)
|
||||
let classify_client_packet (packet : ([`Client of Fw_utils.client_link], _) Packet.t) =
|
||||
let classify_client_packet dns_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) =
|
||||
let (`Client client_link) = packet.src in
|
||||
let rules = client_link#get_rules in
|
||||
Lwt_list.fold_left_s (find_first_match packet) `No_match rules >|= function
|
||||
Lwt_list.fold_left_s (find_first_match dns_client packet) `No_match rules >|= function
|
||||
| `No_match -> `Drop "No matching rule; assuming default drop"
|
||||
| `Match {Q.action = Q.Accept; _} -> `Accept
|
||||
| `Match ({Q.action = Q.Drop; _} as rule) ->
|
||||
`Drop (Format.asprintf "rule number %a explicitly drops this packet" Q.pp_rule rule)
|
||||
|
||||
let translate_accepted_packets packet =
|
||||
classify_client_packet packet >|= function
|
||||
let translate_accepted_packets dns_client packet =
|
||||
classify_client_packet dns_client packet >|= function
|
||||
| `Accept -> `NAT
|
||||
| `Drop s -> `Drop s
|
||||
|
||||
(** Packets from the private interface that don't match any NAT table entry are being checked against the fw rules here *)
|
||||
let from_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) : Packet.action Lwt.t =
|
||||
let from_client dns_client (packet : ([`Client of Fw_utils.client_link], _) Packet.t) : Packet.action Lwt.t =
|
||||
match packet with
|
||||
| { dst = `Firewall; transport_header = `UDP header; _ } ->
|
||||
if header.Udp_packet.dst_port = dns_port
|
||||
then Lwt.return @@ `NAT_to (`NetVM, dns_port)
|
||||
else Lwt.return @@ `Drop "packet addressed to client gateway"
|
||||
| { dst = `External _ ; _ } | { dst = `NetVM; _ } -> translate_accepted_packets packet
|
||||
| { dst = `External _ ; _ } | { dst = `NetVM; _ } -> translate_accepted_packets dns_client packet
|
||||
| { dst = `Firewall ; _ } -> Lwt.return @@ `Drop "packet addressed to firewall itself"
|
||||
| { dst = `Client _ ; _ } -> classify_client_packet packet
|
||||
| { dst = `Client _ ; _ } -> classify_client_packet dns_client packet
|
||||
| _ -> Lwt.return @@ `Drop "could not classify packet"
|
||||
|
||||
(** Packets from the outside world that don't match any NAT table entry are being dropped by default *)
|
||||
|
15
unikernel.ml
15
unikernel.ml
@ -8,15 +8,18 @@ let src = Logs.Src.create "unikernel" ~doc:"Main unikernel code"
|
||||
module Log = (val Logs.src_log src : Logs.LOG)
|
||||
|
||||
module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK) = struct
|
||||
module Uplink = Uplink.Make(R)(Clock)
|
||||
module Dns_transport = My_dns.Transport(R)(Clock)
|
||||
module Dns_client = Dns_client.Make(Dns_transport)
|
||||
|
||||
(* Set up networking and listen for incoming packets. *)
|
||||
let network uplink qubesDB router =
|
||||
let network dns_client dns_responses uplink qubesDB router =
|
||||
(* Report success *)
|
||||
Dao.set_iptables_error qubesDB "" >>= fun () ->
|
||||
(* Handle packets from both networks *)
|
||||
Lwt.choose [
|
||||
Client_net.listen Clock.elapsed_ns qubesDB router;
|
||||
Uplink.listen uplink Clock.elapsed_ns router
|
||||
Client_net.listen Clock.elapsed_ns dns_client qubesDB router;
|
||||
Uplink.listen uplink Clock.elapsed_ns dns_responses router
|
||||
]
|
||||
|
||||
(* We don't use the GUI, but it's interesting to keep an eye on it.
|
||||
@ -76,7 +79,11 @@ module Main (R : Mirage_random.S)(Clock : Mirage_clock.MCLOCK) = struct
|
||||
~nat
|
||||
in
|
||||
|
||||
let net_listener = network uplink qubesDB router in
|
||||
let send_dns_query = Uplink.send_dns_client_query uplink in
|
||||
let dns_mvar = Lwt_mvar.create_empty () in
|
||||
let dns_client = Dns_client.create (router, send_dns_query, dns_mvar) in
|
||||
|
||||
let net_listener = network (Dns_client.getaddrinfo dns_client Dns.Rr_map.A) dns_mvar uplink qubesDB router in
|
||||
|
||||
(* Report memory usage to XenStore *)
|
||||
Memory_pressure.init ();
|
||||
|
74
uplink.ml
74
uplink.ml
@ -9,15 +9,20 @@ module Eth = Ethernet.Make(Netif)
|
||||
let src = Logs.Src.create "uplink" ~doc:"Network connection to NetVM"
|
||||
module Log = (val Logs.src_log src : Logs.LOG)
|
||||
|
||||
module Arp = Arp.Make(Eth)(OS.Time)
|
||||
module Make (R:Mirage_random.S) (Clock : Mirage_clock.MCLOCK) = struct
|
||||
module Arp = Arp.Make(Eth)(OS.Time)
|
||||
module I = Static_ipv4.Make(R)(Clock)(Eth)(Arp)
|
||||
module U = Udp.Make(I)(R)
|
||||
|
||||
type t = {
|
||||
net : Netif.t;
|
||||
eth : Eth.t;
|
||||
arp : Arp.t;
|
||||
interface : interface;
|
||||
mutable fragments : Fragments.Cache.t;
|
||||
}
|
||||
type t = {
|
||||
net : Netif.t;
|
||||
eth : Eth.t;
|
||||
arp : Arp.t;
|
||||
interface : interface;
|
||||
mutable fragments : Fragments.Cache.t;
|
||||
ip : I.t;
|
||||
udp: U.t;
|
||||
}
|
||||
|
||||
class netvm_iface eth mac ~my_ip ~other_ip : interface = object
|
||||
val queue = FrameQ.create (Ipaddr.V4.to_string other_ip)
|
||||
@ -31,10 +36,26 @@ class netvm_iface eth mac ~my_ip ~other_ip : interface = object
|
||||
)
|
||||
end
|
||||
|
||||
let listen t get_ts router =
|
||||
Netif.listen t.net ~header_size:Ethernet_wire.sizeof_ethernet (fun frame ->
|
||||
(* Handle one Ethernet frame from NetVM *)
|
||||
Eth.input t.eth
|
||||
let send_dns_client_query t ~src_port ~dst ~dst_port buf =
|
||||
U.write ~src_port ~dst ~dst_port t.udp buf >|= function
|
||||
| Error s -> Log.err (fun f -> f "error sending udp packet: %a" U.pp_error s); Error (`Msg "failure")
|
||||
| Ok () -> Ok ()
|
||||
|
||||
let listen t get_ts dns_responses router =
|
||||
let handle_packet ip_header ip_packet =
|
||||
let open Udp_packet in
|
||||
|
||||
Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src);
|
||||
match ip_packet with
|
||||
| `UDP (header, packet) when Ports.mem header.dst_port !(router.Router.ports.My_nat.dns_udp) ->
|
||||
Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port);
|
||||
Lwt_mvar.put dns_responses (header, packet)
|
||||
| _ ->
|
||||
Firewall.ipv4_from_netvm router (`IPv4 (ip_header, ip_packet))
|
||||
in
|
||||
Netif.listen t.net ~header_size:Ethernet_wire.sizeof_ethernet (fun frame ->
|
||||
(* Handle one Ethernet frame from NetVM *)
|
||||
Eth.input t.eth
|
||||
~arpv4:(Arp.input t.arp)
|
||||
~ipv4:(fun ip ->
|
||||
let cache, r =
|
||||
@ -42,30 +63,35 @@ let listen t get_ts router =
|
||||
in
|
||||
t.fragments <- cache;
|
||||
match r with
|
||||
| Error e ->
|
||||
Log.warn (fun f -> f "Ignored unknown IPv4 message from uplink: %a" Nat_packet.pp_error e);
|
||||
Lwt.return_unit
|
||||
| Ok None -> Lwt.return_unit
|
||||
| Ok (Some packet) ->
|
||||
Firewall.ipv4_from_netvm router packet
|
||||
)
|
||||
| Error e ->
|
||||
Log.warn (fun f -> f "Ignored unknown IPv4 message from uplink: %a" Nat_packet.pp_error e);
|
||||
Lwt.return ()
|
||||
| Ok None -> Lwt.return_unit
|
||||
| Ok (Some (`IPv4 (header, packet))) -> handle_packet header packet
|
||||
)
|
||||
~ipv6:(fun _ip -> Lwt.return_unit)
|
||||
frame
|
||||
) >|= or_raise "Uplink listen loop" Netif.pp_error
|
||||
|
||||
|
||||
let interface t = t.interface
|
||||
|
||||
let connect config =
|
||||
let ip = config.Dao.uplink_our_ip in
|
||||
let my_ip = config.Dao.uplink_our_ip in
|
||||
let gateway = config.Dao.uplink_netvm_ip in
|
||||
Netif.connect "0" >>= fun net ->
|
||||
Eth.connect net >>= fun eth ->
|
||||
Arp.connect eth >>= fun arp ->
|
||||
Arp.add_ip arp ip >>= fun () ->
|
||||
Arp.add_ip arp my_ip >>= fun () ->
|
||||
let network = Ipaddr.V4.Prefix.make 0 Ipaddr.V4.any in
|
||||
I.connect ~ip:(network, my_ip) ~gateway eth arp >>= fun ip ->
|
||||
U.connect ip >>= fun udp ->
|
||||
let netvm_mac =
|
||||
Arp.query arp config.Dao.uplink_netvm_ip
|
||||
Arp.query arp gateway
|
||||
>|= or_raise "Getting MAC of our NetVM" Arp.pp_error in
|
||||
let interface = new netvm_iface eth netvm_mac
|
||||
~my_ip:ip
|
||||
~my_ip
|
||||
~other_ip:config.Dao.uplink_netvm_ip in
|
||||
let fragments = Fragments.Cache.empty (256 * 1024) in
|
||||
Lwt.return { net; eth; arp; interface ; fragments }
|
||||
Lwt.return { net; eth; arp; interface ; fragments ; ip ; udp }
|
||||
end
|
||||
|
19
uplink.mli
19
uplink.mli
@ -5,13 +5,18 @@
|
||||
|
||||
open Fw_utils
|
||||
|
||||
type t
|
||||
[@@@ocaml.warning "-67"]
|
||||
module Make (R: Mirage_random.S)(Clock : Mirage_clock.MCLOCK) : sig
|
||||
type t
|
||||
|
||||
val connect : Dao.network_config -> t Lwt.t
|
||||
(** Connect to our NetVM (gateway). *)
|
||||
val connect : Dao.network_config -> t Lwt.t
|
||||
(** Connect to our NetVM (gateway). *)
|
||||
|
||||
val interface : t -> interface
|
||||
(** The network interface to NetVM. *)
|
||||
val interface : t -> interface
|
||||
(** The network interface to NetVM. *)
|
||||
|
||||
val listen : t -> (unit -> int64) -> Router.t -> unit Lwt.t
|
||||
(** Handle incoming frames from NetVM. *)
|
||||
val listen : t -> (unit -> int64) -> (Udp_packet.t * Cstruct.t) Lwt_mvar.t -> Router.t -> unit Lwt.t
|
||||
(** Handle incoming frames from NetVM. *)
|
||||
|
||||
val send_dns_client_query: t -> src_port:int-> dst:Ipaddr.V4.t -> dst_port:int -> Cstruct.t -> (unit, [`Msg of string]) result Lwt.t
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user