mirror of
https://github.com/mirage/qubes-mirage-firewall.git
synced 2024-10-01 01:05:39 -04:00
Adapt to mirage-nat changes:
allow pick_free_port to fail reserve a special udp port for dns (as last resort)
This commit is contained in:
parent
f2d3faf1da
commit
93b92c041b
@ -35,12 +35,14 @@ module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (Time : Mirage_
|
||||
let open My_nat in
|
||||
let dst, dst_port = ctx.nameserver in
|
||||
let router, send_udp, answer = ctx.stack in
|
||||
let src_port = My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53 in
|
||||
let src_port, evict =
|
||||
My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53
|
||||
in
|
||||
with_timeout ctx.timeout_ns
|
||||
((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function
|
||||
| Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response)
|
||||
| Error _ as e -> Lwt.return e) >|= fun result ->
|
||||
router.nat.udp_dns <- List.filter (fun p -> p <> src_port) router.nat.udp_dns;
|
||||
evict ();
|
||||
result
|
||||
|
||||
let close _ = Lwt.return_unit
|
||||
|
39
my_nat.ml
39
my_nat.ml
@ -13,37 +13,58 @@ type action = [
|
||||
|
||||
module Nat = Mirage_nat_lru
|
||||
|
||||
module S =
|
||||
Set.Make(struct type t = int let compare (a : int) (b : int) = compare a b end)
|
||||
|
||||
type t = {
|
||||
table : Nat.t;
|
||||
mutable udp_dns : int list;
|
||||
mutable udp_dns : S.t;
|
||||
last_resort_port : int
|
||||
}
|
||||
|
||||
let pick_port () =
|
||||
1024 + Random.int (0xffff - 1024)
|
||||
|
||||
let create ~max_entries =
|
||||
let tcp_size = 7 * max_entries / 8 in
|
||||
let udp_size = max_entries - tcp_size in
|
||||
let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in
|
||||
{ table ; udp_dns = [] }
|
||||
let last_resort_port = pick_port () in
|
||||
{ table ; udp_dns = S.empty ; last_resort_port }
|
||||
|
||||
let pick_free_port t proto =
|
||||
let rec go () =
|
||||
let rec go retries =
|
||||
if retries = 0 then
|
||||
None
|
||||
else
|
||||
let p = 1024 + Random.int (0xffff - 1024) in
|
||||
match proto with
|
||||
| `Udp when List.mem p t.udp_dns -> go ()
|
||||
| _ -> p
|
||||
| `Udp when S.mem p t.udp_dns || p = t.last_resort_port ->
|
||||
go (retries - 1)
|
||||
| _ -> Some p
|
||||
in
|
||||
go ()
|
||||
go 10
|
||||
|
||||
let free_udp_port t ~src ~dst ~dst_port =
|
||||
let rec go () =
|
||||
let src_port = pick_free_port t `Udp in
|
||||
let src_port =
|
||||
Option.value ~default:t.last_resort_port (pick_free_port t `Udp)
|
||||
in
|
||||
if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin
|
||||
t.udp_dns <- src_port :: t.udp_dns;
|
||||
src_port
|
||||
let remove =
|
||||
if src_port <> t.last_resort_port then begin
|
||||
t.udp_dns <- S.add src_port t.udp_dns;
|
||||
(fun () -> t.udp_dns <- S.remove src_port t.udp_dns)
|
||||
end else Fun.id
|
||||
in
|
||||
src_port, remove
|
||||
end else
|
||||
go ()
|
||||
in
|
||||
go ()
|
||||
|
||||
let dns_port t port = S.mem port t.udp_dns || port = t.last_resort_port
|
||||
|
||||
let translate t packet =
|
||||
match Nat.translate t.table packet with
|
||||
| Error (`Untranslated | `TTL_exceeded as e) ->
|
||||
|
@ -3,17 +3,16 @@
|
||||
|
||||
(* Abstract over NAT interface (todo: remove this) *)
|
||||
|
||||
type t = {
|
||||
table : Mirage_nat_lru.t;
|
||||
mutable udp_dns : int list;
|
||||
}
|
||||
type t
|
||||
|
||||
type action = [
|
||||
| `NAT
|
||||
| `Redirect of Mirage_nat.endpoint
|
||||
]
|
||||
|
||||
val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int -> int
|
||||
val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int ->
|
||||
int * (unit -> unit)
|
||||
val dns_port : t -> int -> bool
|
||||
val create : max_entries:int -> t
|
||||
val remove_connections : t -> Ipaddr.V4.t -> unit
|
||||
val translate : t -> Nat_packet.t -> Nat_packet.t option
|
||||
|
@ -44,7 +44,7 @@ end
|
||||
|
||||
Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src);
|
||||
match ip_packet with
|
||||
| `UDP (header, packet) when List.mem header.dst_port router.Router.nat.My_nat.udp_dns ->
|
||||
| `UDP (header, packet) when My_nat.dns_port router.Router.nat header.dst_port ->
|
||||
Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port);
|
||||
Lwt_mvar.put dns_responses (header, packet)
|
||||
| _ ->
|
||||
|
Loading…
Reference in New Issue
Block a user