Adapt to mirage-nat changes:

allow pick_free_port to fail
reserve a special udp port for dns (as last resort)
This commit is contained in:
Hannes Mehnert 2022-10-08 10:50:29 +02:00
parent f2d3faf1da
commit 93b92c041b
4 changed files with 41 additions and 19 deletions

View File

@ -35,12 +35,14 @@ module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (Time : Mirage_
let open My_nat in let open My_nat in
let dst, dst_port = ctx.nameserver in let dst, dst_port = ctx.nameserver in
let router, send_udp, answer = ctx.stack in let router, send_udp, answer = ctx.stack in
let src_port = My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53 in let src_port, evict =
My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53
in
with_timeout ctx.timeout_ns with_timeout ctx.timeout_ns
((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function ((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function
| Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response) | Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response)
| Error _ as e -> Lwt.return e) >|= fun result -> | Error _ as e -> Lwt.return e) >|= fun result ->
router.nat.udp_dns <- List.filter (fun p -> p <> src_port) router.nat.udp_dns; evict ();
result result
let close _ = Lwt.return_unit let close _ = Lwt.return_unit

View File

@ -13,37 +13,58 @@ type action = [
module Nat = Mirage_nat_lru module Nat = Mirage_nat_lru
module S =
Set.Make(struct type t = int let compare (a : int) (b : int) = compare a b end)
type t = { type t = {
table : Nat.t; table : Nat.t;
mutable udp_dns : int list; mutable udp_dns : S.t;
last_resort_port : int
} }
let pick_port () =
1024 + Random.int (0xffff - 1024)
let create ~max_entries = let create ~max_entries =
let tcp_size = 7 * max_entries / 8 in let tcp_size = 7 * max_entries / 8 in
let udp_size = max_entries - tcp_size in let udp_size = max_entries - tcp_size in
let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in
{ table ; udp_dns = [] } let last_resort_port = pick_port () in
{ table ; udp_dns = S.empty ; last_resort_port }
let pick_free_port t proto = let pick_free_port t proto =
let rec go () = let rec go retries =
let p = 1024 + Random.int (0xffff - 1024) in if retries = 0 then
match proto with None
| `Udp when List.mem p t.udp_dns -> go () else
| _ -> p let p = 1024 + Random.int (0xffff - 1024) in
match proto with
| `Udp when S.mem p t.udp_dns || p = t.last_resort_port ->
go (retries - 1)
| _ -> Some p
in in
go () go 10
let free_udp_port t ~src ~dst ~dst_port = let free_udp_port t ~src ~dst ~dst_port =
let rec go () = let rec go () =
let src_port = pick_free_port t `Udp in let src_port =
Option.value ~default:t.last_resort_port (pick_free_port t `Udp)
in
if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin
t.udp_dns <- src_port :: t.udp_dns; let remove =
src_port if src_port <> t.last_resort_port then begin
t.udp_dns <- S.add src_port t.udp_dns;
(fun () -> t.udp_dns <- S.remove src_port t.udp_dns)
end else Fun.id
in
src_port, remove
end else end else
go () go ()
in in
go () go ()
let dns_port t port = S.mem port t.udp_dns || port = t.last_resort_port
let translate t packet = let translate t packet =
match Nat.translate t.table packet with match Nat.translate t.table packet with
| Error (`Untranslated | `TTL_exceeded as e) -> | Error (`Untranslated | `TTL_exceeded as e) ->

View File

@ -3,17 +3,16 @@
(* Abstract over NAT interface (todo: remove this) *) (* Abstract over NAT interface (todo: remove this) *)
type t = { type t
table : Mirage_nat_lru.t;
mutable udp_dns : int list;
}
type action = [ type action = [
| `NAT | `NAT
| `Redirect of Mirage_nat.endpoint | `Redirect of Mirage_nat.endpoint
] ]
val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int -> int val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int ->
int * (unit -> unit)
val dns_port : t -> int -> bool
val create : max_entries:int -> t val create : max_entries:int -> t
val remove_connections : t -> Ipaddr.V4.t -> unit val remove_connections : t -> Ipaddr.V4.t -> unit
val translate : t -> Nat_packet.t -> Nat_packet.t option val translate : t -> Nat_packet.t -> Nat_packet.t option

View File

@ -44,7 +44,7 @@ end
Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src); Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src);
match ip_packet with match ip_packet with
| `UDP (header, packet) when List.mem header.dst_port router.Router.nat.My_nat.udp_dns -> | `UDP (header, packet) when My_nat.dns_port router.Router.nat header.dst_port ->
Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port); Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port);
Lwt_mvar.put dns_responses (header, packet) Lwt_mvar.put dns_responses (header, packet)
| _ -> | _ ->