revise port management

this needs mirage-nat at hannesm#fixes
This commit is contained in:
Hannes Mehnert 2022-10-07 20:54:49 +02:00
parent 8187096bfa
commit f2d3faf1da
9 changed files with 49 additions and 95 deletions

View File

@ -98,7 +98,7 @@ let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client dns_servers ~cl
(Ipaddr.V4.to_string client_ip) (Ipaddr.V4.to_string client_ip)
Fmt.(list ~sep:(any "@.") Pf_qubes.Parse_qubes.pp_rule) new_rules); Fmt.(list ~sep:(any "@.") Pf_qubes.Parse_qubes.pp_rule) new_rules);
(* empty NAT table if rules are updated: they might deny old connections *) (* empty NAT table if rules are updated: they might deny old connections *)
My_nat.remove_connections router.Router.nat router.Router.ports client_ip; My_nat.remove_connections router.Router.nat client_ip;
end); end);
update new_db new_rules update new_db new_rules
in in

View File

@ -47,7 +47,7 @@ let translate t packet =
let add_nat_and_forward_ipv4 t packet = let add_nat_and_forward_ipv4 t packet =
let open Router in let open Router in
let xl_host = t.uplink#my_ip in let xl_host = t.uplink#my_ip in
match My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host `NAT packet with match My_nat.add_nat_rule_and_translate t.nat ~xl_host `NAT packet with
| Ok packet -> forward_ipv4 t packet | Ok packet -> forward_ipv4 t packet
| Error e -> | Error e ->
Log.warn (fun f -> f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet); Log.warn (fun f -> f "Failed to add NAT rewrite rule: %s (%a)" e Nat_packet.pp packet);
@ -60,7 +60,7 @@ let nat_to t ~host ~port packet =
| Ipaddr.V6 _ -> Log.warn (fun f -> f "Cannot NAT with IPv6"); Lwt.return_unit | Ipaddr.V6 _ -> Log.warn (fun f -> f "Cannot NAT with IPv6"); Lwt.return_unit
| Ipaddr.V4 target -> | Ipaddr.V4 target ->
let xl_host = t.uplink#my_ip in let xl_host = t.uplink#my_ip in
match My_nat.add_nat_rule_and_translate t.nat t.ports ~xl_host (`Redirect (target, port)) packet with match My_nat.add_nat_rule_and_translate t.nat ~xl_host (`Redirect (target, port)) packet with
| Ok packet -> forward_ipv4 t packet | Ok packet -> forward_ipv4 t packet
| Error e -> | Error e ->
Log.warn (fun f -> f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp packet); Log.warn (fun f -> f "Failed to add NAT redirect rule: %s (%a)" e Nat_packet.pp packet);

View File

@ -35,12 +35,12 @@ module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (Time : Mirage_
let open My_nat in let open My_nat in
let dst, dst_port = ctx.nameserver in let dst, dst_port = ctx.nameserver in
let router, send_udp, answer = ctx.stack in let router, send_udp, answer = ctx.stack in
let src_port = Ports.pick_free_port ~consult:router.ports.nat_udp router.ports.dns_udp in let src_port = My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53 in
with_timeout ctx.timeout_ns with_timeout ctx.timeout_ns
((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function ((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function
| Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response) | Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response)
| Error _ as e -> Lwt.return e) >|= fun result -> | Error _ as e -> Lwt.return e) >|= fun result ->
router.ports.dns_udp := Ports.remove src_port !(router.ports.dns_udp); router.nat.udp_dns <- List.filter (fun p -> p <> src_port) router.nat.udp_dns;
result result
let close _ = Lwt.return_unit let close _ = Lwt.return_unit

View File

@ -11,31 +11,38 @@ type action = [
| `Redirect of Mirage_nat.endpoint | `Redirect of Mirage_nat.endpoint
] ]
type ports = {
nat_tcp : Ports.t ref;
nat_udp : Ports.t ref;
nat_icmp : Ports.t ref;
dns_udp : Ports.t ref;
}
let empty_ports () =
let nat_tcp = ref Ports.empty in
let nat_udp = ref Ports.empty in
let nat_icmp = ref Ports.empty in
let dns_udp = ref Ports.empty in
{ nat_tcp ; nat_udp ; nat_icmp ; dns_udp }
module Nat = Mirage_nat_lru module Nat = Mirage_nat_lru
type t = { type t = {
table : Nat.t; table : Nat.t;
mutable udp_dns : int list;
} }
let create ~max_entries = let create ~max_entries =
let tcp_size = 7 * max_entries / 8 in let tcp_size = 7 * max_entries / 8 in
let udp_size = max_entries - tcp_size in let udp_size = max_entries - tcp_size in
let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in
{ table } { table ; udp_dns = [] }
let pick_free_port t proto =
let rec go () =
let p = 1024 + Random.int (0xffff - 1024) in
match proto with
| `Udp when List.mem p t.udp_dns -> go ()
| _ -> p
in
go ()
let free_udp_port t ~src ~dst ~dst_port =
let rec go () =
let src_port = pick_free_port t `Udp in
if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin
t.udp_dns <- src_port :: t.udp_dns;
src_port
end else
go ()
in
go ()
let translate t packet = let translate t packet =
match Nat.translate t.table packet with match Nat.translate t.table packet with
@ -47,46 +54,19 @@ let translate t packet =
None None
| Ok packet -> Some packet | Ok packet -> Some packet
let pick_free_port ~nat_ports ~dns_ports = let remove_connections t ip =
Ports.pick_free_port ~consult:dns_ports nat_ports ignore (Nat.remove_connections t.table ip)
(* just clears the nat ports, dns ports stay as is *) let add_nat_rule_and_translate t ~xl_host action packet =
let reset t ports = let proto = match packet with
ports.nat_tcp := Ports.empty; | `IPv4 (_, `TCP _) -> `Tcp
ports.nat_udp := Ports.empty; | `IPv4 (_, `UDP _) -> `Udp
ports.nat_icmp := Ports.empty; | `IPv4 (_, `ICMP _) -> `Icmp
Nat.reset t.table
let remove_connections t ports ip =
let freed_ports = Nat.remove_connections t.table ip in
ports.nat_tcp := Ports.diff !(ports.nat_tcp) (Ports.of_list freed_ports.Mirage_nat.tcp);
ports.nat_udp := Ports.diff !(ports.nat_udp) (Ports.of_list freed_ports.Mirage_nat.udp);
ports.nat_icmp := Ports.diff !(ports.nat_icmp) (Ports.of_list freed_ports.Mirage_nat.icmp)
let add_nat_rule_and_translate t ports ~xl_host action packet =
let rec aux ~retries =
let nat_ports, dns_ports =
match packet with
| `IPv4 (_, `TCP _) -> ports.nat_tcp, ref Ports.empty
| `IPv4 (_, `UDP _) -> ports.nat_udp, ports.dns_udp
| `IPv4 (_, `ICMP _) -> ports.nat_icmp, ref Ports.empty
in
let xl_port = pick_free_port ~nat_ports ~dns_ports in
match Nat.add t.table packet xl_host (fun () -> xl_port) action with
| Error `Overlap when retries < 0 -> Error "Too many retries"
| Error `Overlap ->
if retries = 0 then (
Log.warn (fun f -> f "Failed to find a free port; resetting NAT table");
reset t ports;
aux ~retries:(retries - 1)
) else (
aux ~retries:(retries - 1)
)
| Error `Cannot_NAT ->
Error "Cannot NAT this packet"
| Ok () ->
Log.debug (fun f -> f "Updated NAT table: %a" Nat.pp_summary t.table);
Option.to_result ~none:"No NAT entry, even after adding one!"
(translate t packet)
in in
aux ~retries:100 match Nat.add t.table packet xl_host (fun () -> pick_free_port t proto) action with
| Error `Overlap -> Error "Too many retries"
| Error `Cannot_NAT -> Error "Cannot NAT this packet"
| Ok () ->
Log.debug (fun f -> f "Updated NAT table: %a" Nat.pp_summary t.table);
Option.to_result ~none:"No NAT entry, even after adding one!"
(translate t packet)

View File

@ -3,25 +3,19 @@
(* Abstract over NAT interface (todo: remove this) *) (* Abstract over NAT interface (todo: remove this) *)
type ports = private { type t = {
nat_tcp : Ports.t ref; table : Mirage_nat_lru.t;
nat_udp : Ports.t ref; mutable udp_dns : int list;
nat_icmp : Ports.t ref;
dns_udp : Ports.t ref;
} }
val empty_ports : unit -> ports
type t
type action = [ type action = [
| `NAT | `NAT
| `Redirect of Mirage_nat.endpoint | `Redirect of Mirage_nat.endpoint
] ]
val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int -> int
val create : max_entries:int -> t val create : max_entries:int -> t
val reset : t -> ports -> unit val remove_connections : t -> Ipaddr.V4.t -> unit
val remove_connections : t -> ports -> Ipaddr.V4.t -> unit
val translate : t -> Nat_packet.t -> Nat_packet.t option val translate : t -> Nat_packet.t -> Nat_packet.t option
val add_nat_rule_and_translate : t -> ports -> val add_nat_rule_and_translate : t ->
xl_host:Ipaddr.V4.t -> action -> Nat_packet.t -> (Nat_packet.t, string) result xl_host:Ipaddr.V4.t -> action -> Nat_packet.t -> (Nat_packet.t, string) result

View File

@ -1,16 +0,0 @@
module Set = Set.Make(struct
type t = int
let compare a b = compare a b
end)
include Set
let rec pick_free_port ?(retries = 10) ~consult add_to =
let p = 1024 + Random.int (0xffff - 1024) in
if (mem p !consult || mem p !add_to) && retries <> 0
then pick_free_port ~retries:(retries - 1) ~consult add_to
else
begin
add_to := add p !add_to;
p
end

View File

@ -9,13 +9,10 @@ type t = {
client_eth : Client_eth.t; client_eth : Client_eth.t;
nat : My_nat.t; nat : My_nat.t;
uplink : interface; uplink : interface;
(* NOTE: do not try to make this pure, it relies on mvars / side effects *)
ports : My_nat.ports;
} }
let create ~client_eth ~uplink ~nat = let create ~client_eth ~uplink ~nat =
let ports = My_nat.empty_ports () in { client_eth; nat; uplink }
{ client_eth; nat; uplink; ports }
let target t buf = let target t buf =
let dst_ip = buf.Ipv4_packet.dst in let dst_ip = buf.Ipv4_packet.dst in

View File

@ -9,7 +9,6 @@ type t = private {
client_eth : Client_eth.t; client_eth : Client_eth.t;
nat : My_nat.t; nat : My_nat.t;
uplink : interface; uplink : interface;
ports : My_nat.ports;
} }
val create : val create :

View File

@ -44,7 +44,7 @@ end
Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src); Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src);
match ip_packet with match ip_packet with
| `UDP (header, packet) when Ports.mem header.dst_port !(router.Router.ports.My_nat.dns_udp) -> | `UDP (header, packet) when List.mem header.dst_port router.Router.nat.My_nat.udp_dns ->
Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port); Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port);
Lwt_mvar.put dns_responses (header, packet) Lwt_mvar.put dns_responses (header, packet)
| _ -> | _ ->