package tiny_httpd

  1. Overview
  2. Docs

Source file tiny_httpd_ws.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
open Common_ws_

(** A websocket connection handler. It receives the request that triggered
    the upgrade (a [unit Request.t]), an input channel and an output channel.
    [Make_upgrade_handler.handle_connection] calls it with the pair of
    channels produced by [upgrade]. *)
type handler = unit Request.t -> IO.Input.t -> IO.Output.t -> unit

module Frame_type = struct
  (** Frame type (opcode). The reader extracts it from the low 4 bits of the
      first byte of a frame header. *)
  type t = int

  let continuation : t = 0
  let text : t = 1
  let binary : t = 2
  let close : t = 8
  let ping : t = 9
  let pong : t = 10

  (** Human readable name of a frame type, for logs. Unknown types are
      rendered with their value in hexadecimal. *)
  let show : t -> string = function
    | 0 -> "continuation"
    | 1 -> "text"
    | 2 -> "binary"
    | 8 -> "close"
    | 9 -> "ping"
    | 10 -> "pong"
    | ty ->
      (* the format used to be "%xd", which printed a stray trailing [d]
         right after the hexadecimal digits *)
      Printf.sprintf "unknown frame type %x" ty
end

module Header = struct
  (** Frame header. The record is mutable so that a single value can be
      reused for every frame that is read or written. *)
  type t = {
    mutable fin: bool;  (** FIN flag of the frame *)
    mutable ty: Frame_type.t;  (** frame type *)
    mutable payload_len: int;  (** size of the payload, in bytes *)
    mutable mask: bool;  (** is the payload masked with [mask_key]? *)
    mask_key: bytes;  (** len = 4 *)
  }

  (** Build a blank header: flags cleared, type 0, empty payload, and a
      freshly allocated (uninitialized) 4 bytes masking key. *)
  let create () : t =
    let mask_key = Bytes.create 4 in
    { fin = false; ty = 0; payload_len = 0; mask = false; mask_key }
end

exception Close_connection
(** Raised to close the connection.

    In this file it is raised by the reader when a frame is invalid
    (unknown extension or frame type, oversized fragment, unmasked payload,
    unexpected continuation) and by [Writer.flush] once the writer is closed.
    [Reader.read] catches it and reports end of input;
    [Make_upgrade_handler.handle_connection] catches it around the user
    handler. *)

module Writer = struct
  (** Buffering frame writer. Bytes are accumulated in [buf] and sent to [oc]
      as binary frames, either when [buf] is full or on {!flush}. *)
  type t = {
    header: Header.t;  (** scratch header describing the frame being sent *)
    header_buf: bytes;  (** scratch buffer in which [header] is serialized *)
    buf: bytes;  (** bufferize writes *)
    mutable offset: int;  (** number of bytes already in [buf] *)
    oc: IO.Output.t;  (** underlying output, receives the encoded frames *)
    mutable closed: bool;  (** set by {!close}; makes {!flush} raise *)
    mutex: Mutex.t;
        (** taken by [send_pong], [output_char], [output] and [flush] *)
  }

  (** [create ?buf_size ~oc ()] makes a writer emitting frames on [oc].
      @param buf_size size of the payload buffer (default 16 kB), which is
      also the largest payload of a frame emitted by this writer. *)
  let create ?(buf_size = 16 * 1024) ~oc () : t =
    {
      header = Header.create ();
      (* largest header is 2 + 8 (extended length) + 4 (mask key) = 14 bytes *)
      header_buf = Bytes.create 16;
      buf = Bytes.create buf_size;
      offset = 0;
      oc;
      closed = false;
      mutex = Mutex.create ();
    }

  (** [with_mutex_ self f] runs [f ()] while holding [self.mutex]. The mutex
      is released whether [f] returns or raises. *)
  let[@inline] with_mutex_ (self : t) f =
    Mutex.lock self.mutex;
    try
      let x = f () in
      Mutex.unlock self.mutex;
      x
    with e ->
      Mutex.unlock self.mutex;
      raise e

  (** Mark the writer as closed. Subsequent calls to {!flush} raise
      [Close_connection]. *)
  let[@inline] close self = self.closed <- true

  (** [int_of_bool b] is 1 if [b] is true, 0 otherwise. *)
  let int_of_bool : bool -> int = Bool.to_int

  (** Write the frame header to [self.oc] *)
  let write_header_ (self : t) : unit =
    let header_len = ref 2 in
    (* first byte: FIN flag in the high bit, frame type in the low bits *)
    let b0 =
      Char.chr ((int_of_bool self.header.fin lsl 7) lor self.header.ty)
    in
    Bytes.unsafe_set self.header_buf 0 b0;

    (* payload length: inline if below 126, otherwise marker 126 followed by
       a 16 bits length, or marker 127 followed by a 64 bits length *)
    let payload_len = self.header.payload_len in
    let payload_first_byte =
      if payload_len < 126 then
        payload_len
      else if payload_len < 1 lsl 16 then (
        Bytes.set_uint16_be self.header_buf 2 payload_len;
        header_len := 4;
        126
      ) else (
        Bytes.set_int64_be self.header_buf 2 (Int64.of_int payload_len);
        header_len := 10;
        127
      )
    in

    (* second byte: MASK flag in the high bit, then the length marker *)
    let b1 =
      Char.chr @@ ((int_of_bool self.header.mask lsl 7) lor payload_first_byte)
    in
    Bytes.unsafe_set self.header_buf 1 b1;

    if self.header.mask then (
      (* append the masking key to the header. [Bytes.blit] takes the source
         first: the arguments used to be swapped, which overwrote the key
         with stale bytes of [header_buf] instead of emitting it. *)
      Bytes.blit self.header.mask_key 0 self.header_buf !header_len 4;
      header_len := !header_len + 4
    );

    (*Log.debug (fun k ->
        k "websocket: write header ty=%s (%d B)"
          (Frame_type.show self.header.ty)
          !header_len);*)
    IO.Output.output self.oc self.header_buf 0 !header_len;
    ()

  (** Max fragment size: send 16 kB at a time *)
  let max_fragment_size = 16 * 1024

  (** Send the content of [self.buf] as one final, unmasked, binary frame,
      flush [self.oc] and reset the buffer. *)
  let[@inline never] really_output_buf_ (self : t) =
    self.header.fin <- true;
    self.header.ty <- Frame_type.binary;
    self.header.payload_len <- self.offset;
    self.header.mask <- false;
    write_header_ self;

    IO.Output.output self.oc self.buf 0 self.offset;
    IO.Output.flush self.oc;
    self.offset <- 0

  (** Send buffered bytes, if any.
      @raise Close_connection if the writer was closed. *)
  let flush_ (self : t) =
    if self.closed then raise Close_connection;
    if self.offset > 0 then really_output_buf_ self

  (** Send the buffer as a frame if there is no room left in it. *)
  let[@inline] flush_if_full (self : t) : unit =
    if self.offset = Bytes.length self.buf then really_output_buf_ self

  (** Send an empty pong frame. Bytes pending in [self.buf] are left
      untouched: they are not framed yet, so the pong cannot end up in the
      middle of a data frame. *)
  let send_pong (self : t) : unit =
    let@ () = with_mutex_ self in
    self.header.fin <- true;
    self.header.ty <- Frame_type.pong;
    self.header.payload_len <- 0;
    self.header.mask <- false;
    (* only write a header, we don't send a payload at all *)
    write_header_ self;
    (* push the pong out now rather than whenever the next data frame is
       flushed, otherwise the peer might wait for it indefinitely *)
    IO.Output.flush self.oc

  (** Append one char to the buffer, sending a frame when the buffer is (or
      becomes) full. *)
  let output_char (self : t) c : unit =
    let@ () = with_mutex_ self in
    let cap = Bytes.length self.buf - self.offset in
    (* make room for [c] *)
    if cap = 0 then really_output_buf_ self;
    Bytes.set self.buf self.offset c;
    self.offset <- self.offset + 1;
    (* if [c] made the buffer full, then flush it *)
    if cap = 1 then really_output_buf_ self

  (** [output self buf i len] appends [len] bytes of [buf], starting at
      offset [i], to the buffer. A frame is sent every time the buffer fills
      up. *)
  let output (self : t) buf i len : unit =
    let@ () = with_mutex_ self in
    let i = ref i in
    let len = ref len in
    while !len > 0 do
      flush_if_full self;

      (* copy as much as fits in the remaining space *)
      let n = min !len (Bytes.length self.buf - self.offset) in
      assert (n > 0);

      Bytes.blit buf !i self.buf self.offset n;
      self.offset <- self.offset + n;

      i := !i + n;
      len := !len - n
    done;
    flush_if_full self

  (** Send buffered bytes, if any, as a frame.
      @raise Close_connection if the writer was closed. *)
  let flush self : unit =
    let@ () = with_mutex_ self in
    flush_ self
end

module Reader = struct
  type state =
    | Begin  (** At the beginning of a frame *)
    | Reading_frame of { mutable remaining_bytes: int; mutable num_read: int }
        (** Currently reading the payload of a frame with [remaining_bytes]
        left to read from the underlying [ic]. [num_read] is the number of
        payload bytes already consumed; it is passed as the offset into the
        masking key. *)
    | Close  (** Connection closed, reading returns 0 *)

  type t = {
    ic: IO.Input.t;
    writer: Writer.t;  (** Writer, to send "pong" *)
    header_buf: bytes;  (** small buffer to read frame headers *)
    small_buf: bytes;  (** Used for control frames *)
    header: Header.t;
    mutable last_ty: Frame_type.t;
        (** Type of the last text/binary frame, used for continuation *)
    mutable state: state;
  }

  (** [create ~ic ~writer ()] makes a reader decoding the frames found on
      [ic]. [writer] is used to answer pings. *)
  let create ~ic ~(writer : Writer.t) () : t =
    {
      ic;
      header_buf = Bytes.create 8;
      small_buf = Bytes.create 128;
      writer;
      state = Begin;
      last_ty = 0;
      header = Header.create ();
    }

  (** limitation: we only accept frames that are 2^30 bytes long or less *)
  let max_fragment_size = 1 lsl 30

  (** Largest payload accepted for a control frame (close, ping, pong).
      RFC 6455 section 5.5 caps it at 125 bytes. *)
  let max_control_payload_len = 125

  (** Read next frame header into [self.header]
      @raise Close_connection if a reserved bit is set or if the announced
      payload is larger than [max_fragment_size]. *)
  let read_frame_header (self : t) : unit =
    (* read header *)
    IO.Input.really_input self.ic self.header_buf 0 2;

    let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in
    let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in

    (* FIN is the high bit of the first byte (the writer encodes it with
       [lsl 7]); bit 0 belongs to the frame type *)
    self.header.fin <- b0 land 0b1000_0000 <> 0;
    let ext = (b0 lsr 4) land 0b0111 in
    if ext <> 0 then (
      Log.error (fun k -> k "websocket: unknown extension %d, closing" ext);
      raise Close_connection
    );

    self.header.ty <- b0 land 0b0000_1111;
    self.header.mask <- b1 land 0b1000_0000 <> 0;

    let payload_len : int =
      let len = b1 land 0b0111_1111 in
      if len = 126 then (
        IO.Input.really_input self.ic self.header_buf 0 2;
        Bytes.get_uint16_be self.header_buf 0
      ) else if len = 127 then (
        IO.Input.really_input self.ic self.header_buf 0 8;
        let len64 = Bytes.get_int64_be self.header_buf 0 in
        (* unsigned comparison: a length with its top bit set is negative as
           an [int64] and would otherwise slip under the limit *)
        if Int64.unsigned_compare len64 (Int64.of_int max_fragment_size) > 0
        then (
          Log.error (fun k ->
              k "websocket: maximum frame fragment exceeded (%Lu > %d)" len64
                max_fragment_size);
          raise Close_connection
        );

        Int64.to_int len64
      ) else
        len
    in
    self.header.payload_len <- payload_len;

    if self.header.mask then
      IO.Input.really_input self.ic self.header.mask_key 0 4;

    (*Log.debug (fun k ->
        k "websocket: read frame header type=%s payload_len=%d mask=%b"
          (Frame_type.show self.header.ty)
          self.header.payload_len self.header.mask);*)
    ()

  external apply_masking_ :
    key:bytes -> key_offset:int -> buf:bytes -> int -> int -> unit
    = "tiny_httpd_ws_apply_masking"
    [@@noalloc]
  (** Apply masking to the parsed data *)

  (** [apply_masking ~mask_key ~mask_offset buf off len] applies the 4 bytes
      key [mask_key] in place to the slice [off, off+len) of [buf].
      [mask_offset] is forwarded to the C stub as the key offset; it is
      presumably the position of the first byte of the slice within the
      payload (TODO confirm against the stub). *)
  let[@inline] apply_masking ~mask_key ~mask_offset (buf : bytes) off len : unit
      =
    assert (
      Bytes.length mask_key = 4
      && mask_offset >= 0 && off >= 0 && len >= 0
      && off + len <= Bytes.length buf);
    apply_masking_ ~key:mask_key ~key_offset:mask_offset ~buf off len

  (** Read the whole payload of the current frame, unmasking it if needed. *)
  let read_body_to_string (self : t) : string =
    let len = self.header.payload_len in
    let buf = Bytes.create len in
    IO.Input.really_input self.ic buf 0 len;
    if self.header.mask then
      apply_masking ~mask_key:self.header.mask_key ~mask_offset:0 buf 0 len;
    Bytes.unsafe_to_string buf

  (** Skip bytes of the body *)
  let skip_body (self : t) : unit =
    let len = ref self.header.payload_len in
    while !len > 0 do
      let n = min !len (Bytes.length self.small_buf) in
      IO.Input.really_input self.ic self.small_buf 0 n;
      len := !len - n
    done

  (** Reject control frames whose payload is larger than
      [max_control_payload_len]. Without this check a close frame could make
      [read_body_to_string] allocate up to [max_fragment_size] bytes.
      @raise Close_connection if the payload is too large. *)
  let check_control_frame_ (self : t) : unit =
    if self.header.payload_len > max_control_payload_len then (
      Log.error (fun k ->
          k "websocket: control frame of type %s is too large (%d > %d)"
            (Frame_type.show self.header.ty)
            self.header.payload_len max_control_payload_len);
      raise Close_connection
    )

  (** State machine that reads [len] bytes into [buf]
      @return the number of bytes written at offset [i] of [buf]; 0 once the
      connection is closed.
      @raise Close_connection on invalid frames. *)
  let rec read_rec (self : t) buf i len : int =
    match self.state with
    | Close -> 0
    | Reading_frame r when r.remaining_bytes = 0 ->
      self.state <- Begin;
      read_rec self buf i len
    | Reading_frame r ->
      let len = min len r.remaining_bytes in
      let n = IO.Input.input self.ic buf i len in

      (* apply masking *)
      if self.header.mask then
        apply_masking ~mask_key:self.header.mask_key ~mask_offset:r.num_read buf
          i n
      else (
        Log.error (fun k -> k "websocket: client's frames must be masked");
        raise Close_connection
      );

      (* update state *)
      r.remaining_bytes <- r.remaining_bytes - n;
      r.num_read <- r.num_read + n;
      if r.remaining_bytes = 0 then self.state <- Begin;
      n
    | Begin ->
      read_frame_header self;
      Log.debug (fun k ->
          k "websocket: read frame of type=%s payload_len=%d key=%S"
            (Frame_type.show self.header.ty)
            self.header.payload_len
            (Bytes.unsafe_to_string self.header.mask_key));

      (* switch to reading the payload of the frame we just parsed *)
      let start_payload () : int =
        self.state <-
          Reading_frame
            { remaining_bytes = self.header.payload_len; num_read = 0 };
        read_rec self buf i len
      in

      (match self.header.ty with
      | 0 ->
        (* continuation *)
        if self.last_ty = Frame_type.text || self.last_ty = Frame_type.binary
        then
          start_payload ()
        else (
          Log.error (fun k ->
              k "continuation frame coming after frame of type %s"
                (Frame_type.show self.last_ty));
          raise Close_connection
        )
      | (1 | 2) as ty ->
        (* text or binary: remember the type so that continuation frames
           following this one are accepted *)
        self.last_ty <- ty;
        start_payload ()
      | 8 ->
        (* close frame *)
        self.state <- Close;
        check_control_frame_ self;
        let body = read_body_to_string self in
        if String.length body >= 2 then (
          (* the first two bytes are the (unsigned) status code *)
          let errcode = Bytes.get_uint16_be (Bytes.unsafe_of_string body) 0 in
          Log.info (fun k ->
              k "client send 'close' with errcode=%d, message=%S" errcode
                (String.sub body 2 (String.length body - 2)))
        );
        0
      | 9 ->
        (* ping, reply *)
        check_control_frame_ self;
        (* NOTE(review): RFC 6455 section 5.5.3 asks for the ping payload to
           be echoed in the pong; here it is discarded and an empty pong is
           sent *)
        skip_body self;
        Writer.send_pong self.writer;
        read_rec self buf i len
      | 10 ->
        (* pong, just ignore *)
        check_control_frame_ self;
        skip_body self;
        read_rec self buf i len
      | ty ->
        Log.error (fun k -> k "unknown frame type: %x" ty);
        raise Close_connection)

  (** [read self buf i len] reads at most [len] payload bytes into [buf] at
      offset [i]. Returns 0 when the connection is closed, including when it
      is closed because of an invalid frame. *)
  let read self buf i len =
    try read_rec self buf i len
    with Close_connection ->
      self.state <- Close;
      0

  (** Close the reader; subsequent reads return 0. *)
  let close self : unit =
    match self.state with
    | Close -> ()
    | Begin | Reading_frame _ ->
      Log.debug (fun k -> k "websocket: close connection from server side");
      self.state <- Close
end

(** [upgrade ic oc] wraps the raw channels of a connection into a pair of
    channels speaking the websocket framing: reading from the first one
    yields the payload of incoming frames, writing to the second one emits
    frames on [oc]. *)
let upgrade ic oc : _ * _ =
  let writer = Writer.create ~oc () in
  let reader = Reader.create ~ic ~writer () in

  (* fill [slice] from its start with decoded payload bytes *)
  let refill_from_frames (slice : IO.Slice.t) : unit =
    let capacity = Bytes.length slice.bytes in
    slice.off <- 0;
    slice.len <- Reader.read reader slice.bytes 0 capacity
  in

  let ws_in : IO.Input.t =
    object
      inherit IO.Input.t_from_refill ~bytes:(Bytes.create 4_096) ()
      method private refill (slice : IO.Slice.t) = refill_from_frames slice
      method close () = Reader.close reader
    end
  in
  let ws_out : IO.Output.t =
    object
      method output_char c = Writer.output_char writer c
      method output bs i len = Writer.output writer bs i len
      method flush () = Writer.flush writer
      method close () = Writer.close writer
    end
  in
  ws_in, ws_out

(** Functor building a websocket upgrade handler out of a regular connection
    handler (provided by the user).

    [X.accept_ws_protocol] decides whether the value of the
    [sec-websocket-protocol] header, if any, is acceptable; [X.handler] is
    run on the upgraded connection. *)
module Make_upgrade_handler (X : sig
  val accept_ws_protocol : string -> bool
  val handler : handler
end) : Server.UPGRADE_HANDLER with type handshake_state = unit Request.t =
struct
  type handshake_state = unit Request.t

  let name = "websocket"

  open struct
    (* local exception used to abort the handshake with an error message *)
    exception Bad_req of string

    let bad_req msg = raise (Bad_req msg)
    let bad_reqf fmt = Printf.ksprintf bad_req fmt
  end

  (** Validate the upgrade request and compute the response headers.
      Raises [Bad_req] if the protocol is rejected or the key is missing. *)
  let handshake_ (req : unit Request.t) =
    (* the sub-protocol, when present, must be accepted by the user *)
    let check_protocol proto =
      if not (X.accept_ws_protocol proto) then
        bad_reqf "handler rejected websocket protocol %S" proto
    in
    Option.iter check_protocol
      (Request.get_header req "sec-websocket-protocol");

    let key =
      match Request.get_header req "sec-websocket-key" with
      | Some k -> k
      | None -> bad_req "need sec-websocket-key"
    in

    (* TODO: "origin" header *)

    (* produce the accept key: base64 of the SHA1 of the client key followed
       by a fixed GUID. yes, SHA1 is broken. It's also part of the spec for
       websockets. *)
    let digest = Utils_.sha_1 (key ^ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") in
    let accept = Utils_.B64.encode ~url:false digest in

    Log.debug (fun k ->
        k "websocket: upgrade successful, accept key is %S" accept);
    [ "sec-websocket-accept", accept ], req

  (** Handshake entry point: [Ok (headers, req)] on success, [Error msg] if
      the request was rejected. *)
  let handshake _addr req : _ result =
    match handshake_ req with
    | res -> Ok res
    | exception Bad_req msg -> Error msg

  (** Run the user handler on the websocket channels built from [ic] and
      [oc]. [Close_connection] escaping from the handler is logged and
      swallowed. *)
  let handle_connection req ic oc =
    let ws_ic, ws_oc = upgrade ic oc in
    match X.handler req ws_ic ws_oc with
    | () -> ()
    | exception Close_connection ->
      Log.debug (fun k -> k "websocket: requested to close the connection")
end

(** [add_route_handler server route f] registers [f] on [server] as the
    websocket handler for [route].
    @param accept forwarded as is to [Server.add_upgrade_handler].
    @param accept_ws_protocol predicate on the value of the
    [sec-websocket-protocol] header; accepts everything by default.
    @param middlewares forwarded as is to [Server.add_upgrade_handler]. *)
let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true) ?middlewares
    (server : Server.t) route (f : handler) : unit =
  let module Arg = struct
    let accept_ws_protocol = accept_ws_protocol
    let handler = f
  end in
  let module Upgrade = Make_upgrade_handler (Arg) in
  let packed : Server.upgrade_handler = (module Upgrade) in
  Server.add_upgrade_handler ?accept ?middlewares server route packed

module Private_ = struct
  (** Same function as [Reader.apply_masking], re-exported under the
      [Private_] module. *)
  let apply_masking ~mask_key ~mask_offset buf off len : unit =
    Reader.apply_masking ~mask_key ~mask_offset buf off len
end
OCaml

Innovation. Community. Security.