diff --git a/my_dns.ml b/my_dns.ml index 8cb169d..80f5ab0 100644 --- a/my_dns.ml +++ b/my_dns.ml @@ -35,12 +35,14 @@ module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (Time : Mirage_ let open My_nat in let dst, dst_port = ctx.nameserver in let router, send_udp, answer = ctx.stack in - let src_port = My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53 in + let src_port, evict = + My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53 + in with_timeout ctx.timeout_ns ((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function | Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response) | Error _ as e -> Lwt.return e) >|= fun result -> - router.nat.udp_dns <- List.filter (fun p -> p <> src_port) router.nat.udp_dns; + evict (); result let close _ = Lwt.return_unit diff --git a/my_nat.ml b/my_nat.ml index 2591483..209a562 100644 --- a/my_nat.ml +++ b/my_nat.ml @@ -13,37 +13,58 @@ type action = [ module Nat = Mirage_nat_lru +module S = + Set.Make(struct type t = int let compare (a : int) (b : int) = compare a b end) + type t = { table : Nat.t; - mutable udp_dns : int list; + mutable udp_dns : S.t; + last_resort_port : int } +let pick_port () = + 1024 + Random.int (0xffff - 1024) + let create ~max_entries = let tcp_size = 7 * max_entries / 8 in let udp_size = max_entries - tcp_size in let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in - { table ; udp_dns = [] } + let last_resort_port = pick_port () in + { table ; udp_dns = S.empty ; last_resort_port } let pick_free_port t proto = - let rec go () = - let p = 1024 + Random.int (0xffff - 1024) in - match proto with - | `Udp when List.mem p t.udp_dns -> go () - | _ -> p + let rec go retries = + if retries = 0 then + None + else + let p = 1024 + Random.int (0xffff - 1024) in + match proto with + | `Udp when S.mem p t.udp_dns || p = t.last_resort_port -> + go (retries - 1) + | _ -> Some p in - go () + go 10 let free_udp_port t ~src ~dst ~dst_port = let rec go () = - let src_port = pick_free_port t `Udp in + let src_port = + Option.value ~default:t.last_resort_port (pick_free_port t `Udp) + in if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin - t.udp_dns <- src_port :: t.udp_dns; - src_port + let remove = + if src_port <> t.last_resort_port then begin + t.udp_dns <- S.add src_port t.udp_dns; + (fun () -> t.udp_dns <- S.remove src_port t.udp_dns) + end else Fun.id + in + src_port, remove end else go () in go () +let dns_port t port = S.mem port t.udp_dns || port = t.last_resort_port + let translate t packet = match Nat.translate t.table packet with | Error (`Untranslated | `TTL_exceeded as e) -> diff --git a/my_nat.mli b/my_nat.mli index 1a9c1e7..eab1a34 100644 --- a/my_nat.mli +++ b/my_nat.mli @@ -3,17 +3,16 @@ (* Abstract over NAT interface (todo: remove this) *) -type t = { - table : Mirage_nat_lru.t; - mutable udp_dns : int list; -} +type t type action = [ | `NAT | `Redirect of Mirage_nat.endpoint ] -val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int -> int +val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int -> + int * (unit -> unit) +val dns_port : t -> int -> bool val create : max_entries:int -> t val remove_connections : t -> Ipaddr.V4.t -> unit val translate : t -> Nat_packet.t -> Nat_packet.t option diff --git a/uplink.ml b/uplink.ml index 8ff4c10..b74d1df 100644 --- a/uplink.ml +++ b/uplink.ml @@ -44,7 +44,7 @@ end Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src); match ip_packet with - | `UDP (header, packet) when List.mem header.dst_port router.Router.nat.My_nat.udp_dns -> + | `UDP (header, packet) when My_nat.dns_port router.Router.nat header.dst_port -> Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port); Lwt_mvar.put dns_responses (header, packet) | _ ->