mirror of
https://github.com/mirage/qubes-mirage-firewall.git
synced 2024-10-01 01:05:39 -04:00
Adapt to mirage-nat changes:
allow pick_free_port to fail reserve a special udp port for dns (as last resort)
This commit is contained in:
parent
f2d3faf1da
commit
93b92c041b
@ -35,12 +35,14 @@ module Transport (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (Time : Mirage_
|
|||||||
let open My_nat in
|
let open My_nat in
|
||||||
let dst, dst_port = ctx.nameserver in
|
let dst, dst_port = ctx.nameserver in
|
||||||
let router, send_udp, answer = ctx.stack in
|
let router, send_udp, answer = ctx.stack in
|
||||||
let src_port = My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53 in
|
let src_port, evict =
|
||||||
|
My_nat.free_udp_port router.nat ~src:router.uplink#my_ip ~dst ~dst_port:53
|
||||||
|
in
|
||||||
with_timeout ctx.timeout_ns
|
with_timeout ctx.timeout_ns
|
||||||
((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function
|
((send_udp ~src_port ~dst ~dst_port buf >|= Rresult.R.open_error_msg) >>= function
|
||||||
| Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response)
|
| Ok () -> (Lwt_mvar.take answer >|= fun (_, dns_response) -> Ok dns_response)
|
||||||
| Error _ as e -> Lwt.return e) >|= fun result ->
|
| Error _ as e -> Lwt.return e) >|= fun result ->
|
||||||
router.nat.udp_dns <- List.filter (fun p -> p <> src_port) router.nat.udp_dns;
|
evict ();
|
||||||
result
|
result
|
||||||
|
|
||||||
let close _ = Lwt.return_unit
|
let close _ = Lwt.return_unit
|
||||||
|
43
my_nat.ml
43
my_nat.ml
@ -13,37 +13,58 @@ type action = [
|
|||||||
|
|
||||||
module Nat = Mirage_nat_lru
|
module Nat = Mirage_nat_lru
|
||||||
|
|
||||||
|
module S =
|
||||||
|
Set.Make(struct type t = int let compare (a : int) (b : int) = compare a b end)
|
||||||
|
|
||||||
type t = {
|
type t = {
|
||||||
table : Nat.t;
|
table : Nat.t;
|
||||||
mutable udp_dns : int list;
|
mutable udp_dns : S.t;
|
||||||
|
last_resort_port : int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let pick_port () =
|
||||||
|
1024 + Random.int (0xffff - 1024)
|
||||||
|
|
||||||
let create ~max_entries =
|
let create ~max_entries =
|
||||||
let tcp_size = 7 * max_entries / 8 in
|
let tcp_size = 7 * max_entries / 8 in
|
||||||
let udp_size = max_entries - tcp_size in
|
let udp_size = max_entries - tcp_size in
|
||||||
let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in
|
let table = Nat.empty ~tcp_size ~udp_size ~icmp_size:100 in
|
||||||
{ table ; udp_dns = [] }
|
let last_resort_port = pick_port () in
|
||||||
|
{ table ; udp_dns = S.empty ; last_resort_port }
|
||||||
|
|
||||||
let pick_free_port t proto =
|
let pick_free_port t proto =
|
||||||
let rec go () =
|
let rec go retries =
|
||||||
let p = 1024 + Random.int (0xffff - 1024) in
|
if retries = 0 then
|
||||||
match proto with
|
None
|
||||||
| `Udp when List.mem p t.udp_dns -> go ()
|
else
|
||||||
| _ -> p
|
let p = 1024 + Random.int (0xffff - 1024) in
|
||||||
|
match proto with
|
||||||
|
| `Udp when S.mem p t.udp_dns || p = t.last_resort_port ->
|
||||||
|
go (retries - 1)
|
||||||
|
| _ -> Some p
|
||||||
in
|
in
|
||||||
go ()
|
go 10
|
||||||
|
|
||||||
let free_udp_port t ~src ~dst ~dst_port =
|
let free_udp_port t ~src ~dst ~dst_port =
|
||||||
let rec go () =
|
let rec go () =
|
||||||
let src_port = pick_free_port t `Udp in
|
let src_port =
|
||||||
|
Option.value ~default:t.last_resort_port (pick_free_port t `Udp)
|
||||||
|
in
|
||||||
if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin
|
if Nat.is_port_free t.table `Udp ~src ~dst ~src_port ~dst_port then begin
|
||||||
t.udp_dns <- src_port :: t.udp_dns;
|
let remove =
|
||||||
src_port
|
if src_port <> t.last_resort_port then begin
|
||||||
|
t.udp_dns <- S.add src_port t.udp_dns;
|
||||||
|
(fun () -> t.udp_dns <- S.remove src_port t.udp_dns)
|
||||||
|
end else Fun.id
|
||||||
|
in
|
||||||
|
src_port, remove
|
||||||
end else
|
end else
|
||||||
go ()
|
go ()
|
||||||
in
|
in
|
||||||
go ()
|
go ()
|
||||||
|
|
||||||
|
let dns_port t port = S.mem port t.udp_dns || port = t.last_resort_port
|
||||||
|
|
||||||
let translate t packet =
|
let translate t packet =
|
||||||
match Nat.translate t.table packet with
|
match Nat.translate t.table packet with
|
||||||
| Error (`Untranslated | `TTL_exceeded as e) ->
|
| Error (`Untranslated | `TTL_exceeded as e) ->
|
||||||
|
@ -3,17 +3,16 @@
|
|||||||
|
|
||||||
(* Abstract over NAT interface (todo: remove this) *)
|
(* Abstract over NAT interface (todo: remove this) *)
|
||||||
|
|
||||||
type t = {
|
type t
|
||||||
table : Mirage_nat_lru.t;
|
|
||||||
mutable udp_dns : int list;
|
|
||||||
}
|
|
||||||
|
|
||||||
type action = [
|
type action = [
|
||||||
| `NAT
|
| `NAT
|
||||||
| `Redirect of Mirage_nat.endpoint
|
| `Redirect of Mirage_nat.endpoint
|
||||||
]
|
]
|
||||||
|
|
||||||
val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int -> int
|
val free_udp_port : t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> dst_port:int ->
|
||||||
|
int * (unit -> unit)
|
||||||
|
val dns_port : t -> int -> bool
|
||||||
val create : max_entries:int -> t
|
val create : max_entries:int -> t
|
||||||
val remove_connections : t -> Ipaddr.V4.t -> unit
|
val remove_connections : t -> Ipaddr.V4.t -> unit
|
||||||
val translate : t -> Nat_packet.t -> Nat_packet.t option
|
val translate : t -> Nat_packet.t -> Nat_packet.t option
|
||||||
|
@ -44,7 +44,7 @@ end
|
|||||||
|
|
||||||
Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src);
|
Log.debug (fun f -> f "received ipv4 packet from %a on uplink" Ipaddr.V4.pp ip_header.Ipv4_packet.src);
|
||||||
match ip_packet with
|
match ip_packet with
|
||||||
| `UDP (header, packet) when List.mem header.dst_port router.Router.nat.My_nat.udp_dns ->
|
| `UDP (header, packet) when My_nat.dns_port router.Router.nat header.dst_port ->
|
||||||
Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port);
|
Log.debug (fun f -> f "found a DNS packet whose dst_port (%d) was in the list of dns_client ports" header.dst_port);
|
||||||
Lwt_mvar.put dns_responses (header, packet)
|
Lwt_mvar.put dns_responses (header, packet)
|
||||||
| _ ->
|
| _ ->
|
||||||
|
Loading…
Reference in New Issue
Block a user