package albatross

  1. Overview
  2. Docs

Source file vmm_lwt.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
(* (c) 2017 Hannes Mehnert, all rights reserved *)

open Lwt.Infix

(** [pp_sockaddr ppf sockaddr] prints [sockaddr] for log messages: the path
    of a Unix domain socket, or host and port of an internet socket. *)
let pp_sockaddr ppf sockaddr =
  match sockaddr with
  | Lwt_unix.ADDR_UNIX path -> Fmt.pf ppf "unix domain socket %s" path
  | Lwt_unix.ADDR_INET (inet, port) ->
    let host = Unix.string_of_inet_addr inet in
    Fmt.pf ppf "TCP %s:%d" host port

(** [safe_close fd] closes [fd] on a best-effort basis: any exception raised
    by [Lwt_unix.close] (e.g. an already closed descriptor) is deliberately
    ignored, and the promise always resolves to [()]. *)
let safe_close fd =
  let close () = Lwt_unix.close fd in
  let ignore_failure (_ : exn) = Lwt.return_unit in
  Lwt.catch close ignore_failure

(** [port_socket ip port] creates a close-on-exec TCP stream socket bound to
    [ip]:[port] and listening with a backlog of 10. [SO_REUSEADDR] is set;
    for an IPv6 address [IPV6_ONLY] is disabled (dual-stack).

    If any setup step ([setsockopt], [bind], [listen]) fails, the socket is
    closed before the exception is propagated as a rejected promise, so no
    file descriptor is leaked. *)
let port_socket ip port =
  let open Lwt_unix in
  let pf, addr, sockopt =
    match ip with
    | Ipaddr.V4 v4 ->
      PF_INET, ADDR_INET (Ipaddr_unix.V4.to_inet_addr v4, port), fun _s -> ()
    | Ipaddr.V6 v6 ->
      PF_INET6, ADDR_INET (Ipaddr_unix.V6.to_inet_addr v6, port),
      fun s -> setsockopt s IPV6_ONLY false
  in
  let s = socket pf SOCK_STREAM 0 in
  Lwt.catch
    (fun () ->
       set_close_on_exec s ;
       setsockopt s SO_REUSEADDR true ;
       sockopt s ;
       bind s addr >|= fun () ->
       listen s 10 ;
       s)
    (fun e ->
       (* release the descriptor, then re-raise the original error *)
       safe_close s >>= fun () -> raise e)

(** [systemd_socket ()] returns the single socket passed to this process via
    systemd socket activation, as reported by [Vmm_unix.sd_listen_fds].

    @raise Failure if [Vmm_unix.sd_listen_fds] does not report exactly one
    descriptor; the message says whether none or several were reported. *)
let systemd_socket () =
  match Vmm_unix.sd_listen_fds () with
  | Some [ fd ] -> Lwt.return (Lwt_unix.of_unix_file_descr fd)
  | None ->
    failwith "Systemd socket activation error: no file descriptors received"
  | Some fds ->
    failwith
      (Printf.sprintf
         "Systemd socket activation error: expected exactly one socket, got %d"
         (List.length fds))

(** [service_socket sock] creates a close-on-exec Unix domain stream socket
    listening (backlog 1) at [Vmm_core.socket_path sock]. A pre-existing
    file at that path is removed first.

    While binding, the process umask is temporarily set to
    [old_umask land 0o707], so group permission bits are never masked
    (presumably to make the socket file group-accessible). The previous
    umask is restored whether or not [bind] succeeds. On failure the socket
    is closed.

    The returned promise is rejected with [Failure] (including a hint about
    file permissions) when unlinking the old file or binding fails with
    [EACCES]. Other exceptions are propagated unchanged. *)
let service_socket sock =
  let permissions = {|, are the permissions of the socket file correct? Is the user to execute this daemon the right one?|} in
  let name = Vmm_core.socket_path sock in
  (Lwt_unix.file_exists name >>= function
    | true ->
      Lwt.catch (fun () -> Lwt_unix.unlink name)
        (function
          | Unix.Unix_error (Unix.ENOENT, _, _) ->
            (* removed by someone else since [file_exists]; nothing to do *)
            Lwt.return_unit
          | Unix.Unix_error (Unix.EACCES, _, _) ->
            failwith ("Couldn't unlink old socket file (EACCES)" ^ permissions)
          | e -> raise e)
    | false -> Lwt.return_unit)
  >>= fun () ->
  let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
  Lwt_unix.set_close_on_exec s ;
  let old_umask = Unix.umask 0 in
  ignore (Unix.umask (old_umask land 0o707) : int) ;
  let restore_umask () =
    ignore (Unix.umask old_umask : int) ;
    Lwt.return_unit
  in
  Lwt.catch
    (fun () ->
       (* the umask is process-wide: restore it on success and on failure *)
       Lwt.finalize
         (fun () -> Lwt_unix.(bind s (ADDR_UNIX name)))
         restore_umask >|= fun () ->
       Lwt_unix.listen s 1)
    (fun e ->
       safe_close s >>= fun () ->
       match e with
       | Unix.Unix_error (Unix.EACCES, _, _) ->
         failwith ("Couldn't bind socket " ^ name ^ " (EACCES)" ^ permissions)
       | e -> raise e)
  >|= fun () ->
  Logs.app (fun m -> m "listening on %s" name) ;
  s

(** [connect addrtype sockaddr] opens a close-on-exec stream socket of family
    [addrtype] and connects it to [sockaddr].

    Resolves to [Some fd] on success. If connecting raises, a warning is
    logged, the socket is closed and the promise resolves to [None]. *)
let connect addrtype sockaddr =
  let fd = Lwt_unix.socket addrtype Lwt_unix.SOCK_STREAM 0 in
  Lwt_unix.set_close_on_exec fd ;
  let on_error exn =
    Logs.warn (fun m ->
        m "error %s connecting to socket %a"
          (Printexc.to_string exn) pp_sockaddr sockaddr) ;
    safe_close fd >|= fun () -> None
  in
  Lwt.catch
    (fun () -> Lwt_unix.connect fd sockaddr >|= fun () -> Some fd)
    on_error

(** [pp_process_status ppf status] prints a human-readable description of a
    [Unix.process_status]: the exit code, or the signal that killed or
    stopped the process. *)
let pp_process_status ppf status =
  let pp_sig = Fmt.Dump.signal in
  match status with
  | Unix.WEXITED code -> Fmt.pf ppf "exited with %d" code
  | Unix.WSIGNALED signal -> Fmt.pf ppf "killed by signal %a" pp_sig signal
  | Unix.WSTOPPED signal -> Fmt.pf ppf "stopped by signal %a" pp_sig signal

(** [ret status] converts a [Unix.process_status] into a polymorphic variant
    carrying the exit code ([`Exit]) or the signal number ([`Signal] when
    killed, [`Stop] when stopped). *)
let ret status =
  match status with
  | Unix.WEXITED code -> `Exit code
  | Unix.WSIGNALED signal -> `Signal signal
  | Unix.WSTOPPED signal -> `Stop signal

(** [waitpid pid] waits (blocking, no flags) for [pid] to change state via
    [Lwt_unix.waitpid]. An interrupted wait ([EINTR]) is retried. Any other
    exception is logged and mapped to [Error ()]. *)
let rec waitpid pid =
  let on_error = function
    | Unix.Unix_error (Unix.EINTR, _, _) ->
      Logs.debug (fun m -> m "EINTR in waitpid(), %d retrying" pid) ;
      waitpid pid
    | exn ->
      Logs.err (fun m ->
          m "error %s in waitpid() %d" (Printexc.to_string exn) pid) ;
      Lwt.return (Error ())
  in
  Lwt.catch (fun () -> Lwt_unix.waitpid [] pid >|= Result.ok) on_error

(** [wait_and_clear pid] waits for [pid] and converts its final status with
    [ret]. If waiting fails, the error is logged and [`Exit 23] is returned
    as a sentinel. *)
let wait_and_clear pid =
  Logs.debug (fun m -> m "waitpid() for pid %d" pid) ;
  waitpid pid >|= function
  | Ok (_, status) ->
    Logs.debug (fun m -> m "pid %d exited: %a" pid pp_process_status status) ;
    ret status
  | Error () ->
    Logs.err (fun m -> m "waitpid() for %d returned error" pid) ;
    `Exit 23

(** [read_wire s] reads one length-prefixed message from [s]: a 4-byte
    big-endian length followed by that many payload bytes. The payload is
    decoded with [Vmm_asn.wire_of_str].

    Results:
    - [Ok w] with the decoded wire. If its header version is not current, a
      warning is logged but the message is still returned.
    - [Error `Eof] if a read returns 0 (peer closed the connection) or the
      length header is not positive.
    - [Error `Exception] if a read raises (then [s] is closed) or the payload
      cannot be decoded (then [s] is left open).
    - [Error `Toomuch] if a read reports more bytes than were requested.

    NOTE(review): the length comes from the peer and is only checked to be
    positive, so one message can make this allocate up to 2 GiB. *)
let read_wire s =
  let buf = Bytes.create 4 in
  (* [r b i l] reads exactly [l] bytes from [s] into [b] at offset [i],
     looping over short reads. *)
  let rec r b i l =
    Lwt.catch (fun () ->
        Lwt_unix.read s b i l >>= function
        | 0 ->
          Logs.debug (fun m -> m "end of file while reading") ;
          Lwt.return (Error `Eof)
        | n when n = l -> Lwt.return (Ok ())
        | n when n < l -> r b (i + n) (l - n)
        | _ ->
          Logs.err (fun m -> m "read too much, shouldn't happen") ;
          Lwt.return (Error `Toomuch))
      (fun e ->
         let err = Printexc.to_string e in
         Logs.err (fun m -> m "exception %s while reading" err) ;
         safe_close s >|= fun () ->
         Error `Exception)
  in
  r buf 0 4 >>= function
  | Error e -> Lwt.return (Error e)
  | Ok () ->
    let len = Bytes.get_int32_be buf 0 in
    if len > 0l then begin
      let size = Int32.to_int len in
      let b = Bytes.create size in
      r b 0 size >|= function
      | Error e -> Error e
      | Ok () ->
        (*          Logs.debug (fun m -> m "read hdr %a, body %a"
                         (Ohex.pp_hexdump ()) (Bytes.unsafe_to_string buf)
                         (Ohex.pp_hexdump ()) (Bytes.unsafe_to_string b)) ; *)
        match Vmm_asn.wire_of_str (Bytes.unsafe_to_string b) with
        | Error (`Msg msg) ->
          Logs.err (fun m -> m "error %s while parsing data" msg) ;
          Error `Exception
        | (Ok (hdr, _)) as w ->
          if not Vmm_commands.(is_current hdr.version) then
            Logs.warn (fun m -> m "version mismatch, received %a current %a"
                          Vmm_commands.pp_version hdr.Vmm_commands.version
                          Vmm_commands.pp_version Vmm_commands.current);
          w
    end else begin
      (* a zero or negative length is treated like end of stream *)
      Lwt.return (Error `Eof)
    end

(** [write_raw s buf] sends all of [buf] on [s], repeating [Lwt_unix.send]
    for short writes until every byte has been sent. If a send raises, the
    exception is logged, [s] is closed and [Error `Exception] is returned. *)
let write_raw s buf =
  let rec send_from off remaining =
    let attempt () =
      Lwt_unix.send s buf off remaining [] >>= fun sent ->
      if sent <> remaining then send_from (off + sent) (remaining - sent)
      else Lwt.return (Ok ())
    in
    let on_exn exn =
      Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string exn)) ;
      safe_close s >|= fun () ->
      Error `Exception
    in
    Lwt.catch attempt on_exn
  in
  (*  Logs.debug (fun m -> m "writing %a" (Ohex.pp_hexdump ()) (Bytes.unsafe_to_string buf)) ; *)
  send_from 0 (Bytes.length buf)

(** [write_wire s wire] encodes [wire] with [Vmm_asn.wire_to_str] and sends
    it on [s] as a frame: a 4-byte big-endian payload length followed by
    the payload. *)
let write_wire s wire =
  let payload = Vmm_asn.wire_to_str wire in
  let payload_len = String.length payload in
  let frame = Bytes.create (4 + payload_len) in
  Bytes.set_int32_be frame 0 (Int32.of_int payload_len) ;
  Bytes.blit_string payload 0 frame 4 payload_len ;
  write_raw s frame
OCaml

Innovation. Community. Security.