package websocket-lwt

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file websocket_lwt.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
(*
 * Copyright (c) 2012-2016 Vincent Bernardoff <vb@luminar.eu.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 *)

include Websocket

open Astring
open Lwt.Infix

(* Frame-level I/O helpers specialised to the Cohttp Lwt/Unix channel types.
   NOTE(review): [IO] is presumably the functor exported by the [Websocket]
   module included above, and opening the result is presumably what brings
   [make_read_frame] and [write_frame_to_buf] into scope -- confirm. *)
module Lwt_IO = IO(Cohttp_lwt_unix.IO)
open Lwt_IO

(* HTTP request/response readers and writers over the same channel types;
   used below for both the client-side and the server-side handshake. *)
module Request = Cohttp.Request.Make(Cohttp_lwt_unix.IO)
module Response = Cohttp.Response.Make(Cohttp_lwt_unix.IO)

(* Log section under which every message emitted by this file is filed. *)
let section = Lwt_log.Section.make "websocket_lwt"

(* Raised by [with_connection] when the server answers the handshake with an
   error status (payload: the status rendered as a string), and by
   [establish_server] when the incoming request cannot be parsed (payload:
   the parser's reason). *)
exception HTTP_Error of string

(** Server-side handle on one accepted websocket connection. *)
module Connected_client = struct
  type t = {
    buffer: Buffer.t;
    (** Scratch buffer into which outgoing frames are serialised. *)
    flow: Conduit_lwt_unix.flow;
    (** Underlying transport, as handed over by Conduit. *)
    ic: Request.IO.ic;
    (** Input channel of the connection. *)
    oc: Request.IO.oc;
    (** Output channel of the connection. *)
    http_request: Cohttp.Request.t;
    (** The HTTP upgrade request that opened the connection. *)
    standard_frame_replies: bool;
    (** When [true], {!recv} answers Ping and Close frames by itself. *)
    read_frame: unit -> Frame.t Lwt.t;
    (** Reads the next frame from [ic]. *)
  }

  (** [create ?read_buf ?write_buf http_request flow ic oc] wraps an
      established connection. [read_buf] is forwarded to [make_read_frame];
      [write_buf] (default: a fresh 128-byte buffer) is used by {!send} and
      {!send_multiple}. Automatic Ping/Close replies are initially off; see
      {!make_standard}. *)
  let create
      ?read_buf
      ?(write_buf=Buffer.create 128)
      http_request flow ic oc =
    let read_frame = make_read_frame ?buf:read_buf ~mode:Server ic oc in
    {
      buffer = write_buf;
      flow;
      ic;
      oc;
      http_request;
      standard_frame_replies = false;
      read_frame;
    }

  (** [send t frame] serialises [frame] in server mode and writes it to the
      output channel of [t]. *)
  let send { buffer; oc; _ } frame =
    Buffer.clear buffer;
    write_frame_to_buf ~mode:Server buffer frame;
    Lwt_io.write oc @@ Buffer.contents buffer

  (** [send_multiple t frames] serialises all of [frames], in order, and
      hands them to the output channel in a single write. *)
  let send_multiple { buffer; oc; _ } frames =
    Buffer.clear buffer;
    List.iter (write_frame_to_buf ~mode:Server buffer) frames;
    Lwt_io.write oc @@ Buffer.contents buffer

  (** [standard_recv t] reads one frame and, before returning it to the
      caller, sends the standard reply for control frames:
      - a Ping is answered with a Pong carrying the same payload;
      - a Close is echoed with the first two bytes of its payload, or with
        code 1000 when the payload is shorter than two bytes.
      Every other frame is returned untouched. *)
  let standard_recv t =
    t.read_frame () >>= fun fr ->
    match fr.Frame.opcode with
    | Frame.Opcode.Ping ->
        (* RFC 6455, section 5.5.3: a Pong must carry the same application
           data as the Ping it answers, so echo the payload back. *)
        send t @@ Frame.create
          ~opcode:Frame.Opcode.Pong
          ~content:fr.Frame.content () >|= fun () -> fr
    | Frame.Opcode.Close ->
        (* Immediately echo and pass this last message to the user *)
        (if String.length fr.Frame.content >= 2 then
           send t @@ Frame.create
             ~opcode:Frame.Opcode.Close
             ~content:(String.(sub ~start:0 ~stop:2 fr.Frame.content |> Sub.to_string)) ()
         else send t @@ Frame.close 1000
        ) >|= fun () -> fr
    | _ -> Lwt.return fr

  (** [recv t] returns the next frame, going through {!standard_recv} when
      [t] was produced by {!make_standard} and reading the raw frame
      otherwise. *)
  let recv t =
    if t.standard_frame_replies then
      standard_recv t
    else
      t.read_frame ()

  (** The HTTP request that initiated the websocket handshake. *)
  let http_request { http_request; _ } = http_request

  (** Where a client is connected from. *)
  type source =
    | TCP of Ipaddr.t * int
    | Domain_socket of string
    | Vchan of Conduit_lwt_unix.vchan_flow

  (** [source t] projects the transport of [t] onto {!type-source}: address
      and port for TCP, path for a domain socket, the flow itself for
      Vchan. *)
  let source { flow; _ } : source =
    match flow with
    | Conduit_lwt_unix.TCP tcp_flow ->
      TCP (tcp_flow.Conduit_lwt_unix.ip, tcp_flow.Conduit_lwt_unix.port)
    | Conduit_lwt_unix.Domain_socket { path; _ } ->
      Domain_socket path
    | Conduit_lwt_unix.Vchan flow ->
      Vchan flow

  (** [make_standard t] is [t] with automatic Ping/Close replies enabled in
      {!recv}. *)
  let make_standard t = { t with standard_frame_replies = true }
end

(** [set_tcp_nodelay flow] turns on [TCP_NODELAY] on the socket behind a TCP
    flow. Any other kind of flow is left untouched. *)
let set_tcp_nodelay = function
  | Conduit_lwt_unix.TCP { Conduit_lwt_unix.fd; _ } ->
    Lwt_unix.setsockopt fd Lwt_unix.TCP_NODELAY true
  | _ -> ()

(** [check_origin ?origin_mandatory ~hosts request] decides whether the
    [origin] header of [request] designates one of [hosts].

    - No [origin] header: accepted unless [origin_mandatory] is [true]
      (default [false]).
    - [origin] header without a host component: rejected.
    - Otherwise: accepted iff the origin host equals the ASCII-lowercased
      form of some element of [hosts]. *)
let check_origin ?(origin_mandatory=false) ~hosts =
  let is_allowed candidate =
    List.exists
      (fun allowed -> String.Ascii.lowercase allowed = candidate)
      hosts
  in
  fun request ->
    match Cohttp.Header.get request.Cohttp.Request.headers "origin" with
    | None -> not origin_mandatory
    | Some raw_origin ->
      (* host is already lowercased by Uri *)
      (match Uri.host (Uri.of_string raw_origin) with
       | Some origin_host -> is_allowed origin_host
       | None -> false)

(** [check_origin_with_host request] accepts [request] when its [origin]
    header (if any) names the same host as its [host] header, the port being
    ignored. A request without an [origin] header is accepted.

    @raise Failure if [request] has no [host] header. *)
let check_origin_with_host request =
  let headers = request.Cohttp.Request.headers in
  match Cohttp.Header.get headers "host" with
  | None -> failwith "Missing host header" (* mandatory in http/1.1 *)
  | Some host ->
      (* Remove the port. The header value is parsed as a URI authority with
         the same parser [check_origin] applies to the origin header, so
         both sides are normalised identically. Cutting at the first colon,
         as was done before, turned an IPv6 literal such as [::1]:8080 into
         a lone opening bracket, which could never match any origin. *)
      let hostname =
        match Uri.host (Uri.of_string ("//" ^ host)) with
        | Some "" | None ->
            (* Unparseable authority: fall back to the plain textual cut. *)
            Option.value_map ~default:host ~f:fst (String.cut ~sep:":" host)
        | Some parsed -> parsed
      in
      check_origin ~hosts:[hostname] request

(** [with_connection ?extra_headers ?random_string ?ctx client uri] connects
    to [client] through Conduit, performs the websocket opening handshake
    for [uri], and resolves to a pair [(read_frame, write_frame)] operating
    on the established connection.

    @param extra_headers headers to which the handshake headers are added
      (default: empty).
    @param random_string generator used for the 16-byte handshake nonce and
      passed on as the [Client] mode argument of the frame reader and writer
      (default: [Rng.init ()]).
    @param ctx Conduit context (default: [Conduit_lwt_unix.default_ctx]).

    The returned promise is rejected with [HTTP_Error] when the server
    answers with an error status, with [Protocol_error] when the response is
    not a valid websocket upgrade, with [End_of_file] when the connection is
    closed before a response arrives, and with [Failure] when the response
    cannot be parsed. In each of these cases the input channel is closed
    first. *)
let with_connection
    ?(extra_headers = Cohttp.Header.init ())
    ?(random_string=Rng.init ())
    ?(ctx=Conduit_lwt_unix.default_ctx)
    client uri =
  let connect () =
    let module C = Cohttp in
    (* Base64 encoding of 16 random bytes, sent as Sec-WebSocket-Key. *)
    let nonce = random_string 16 |> B64.encode ~pad:true in
    let headers = C.Header.add_list extra_headers
        ["Upgrade"               , "websocket";
         "Connection"            , "Upgrade";
         "Sec-WebSocket-Key"     , nonce;
         "Sec-WebSocket-Version" , "13"] in
    let req = C.Request.make ~headers uri in
    Conduit_lwt_unix.connect ~ctx client >>= fun (flow, ic, oc) ->
    set_tcp_nodelay flow;
    (* Send the upgrade request (no body), then read and validate the
       server's response. *)
    let drain_handshake () =
      Request.write (fun _writer -> Lwt.return_unit) req oc >>= fun () ->
      Response.read ic >>= (function
          | `Ok r -> Lwt.return r
          | `Eof -> Lwt.fail End_of_file
          | `Invalid s -> Lwt.fail @@ Failure s) >>= fun response ->
      let status = C.Response.status response in
      let headers = C.Response.headers response in
      if C.Code.(is_error @@ code_of_status status)
      then Lwt.fail @@ HTTP_Error C.Code.(string_of_status status)
      (* A valid upgrade is an HTTP/1.1 Switching Protocols response whose
         upgrade header is websocket (compared case-insensitively), for
         which [upgrade_present] holds, and whose accept header equals
         [b64_encoded_sha1sum] applied to the nonce followed by
         [websocket_uuid]. *)
      else if not (C.Response.version response = `HTTP_1_1
                   && status = `Switching_protocols
                   && Option.map ~f:String.Ascii.lowercase @@
                   C.Header.get headers "upgrade" = Some "websocket"
                   && upgrade_present headers
                   && C.Header.get headers "sec-websocket-accept" =
                      Some (nonce ^ websocket_uuid |> b64_encoded_sha1sum)
                  )
      then Lwt.fail (Protocol_error "Bad headers")
      else Lwt_log.info_f ~section "Connected to %s" (Uri.to_string uri)
    in
    (* On any handshake failure, close the input channel before propagating
       the error. *)
    Lwt.catch drain_handshake begin fun exn ->
      Lwt_io.close ic >>= fun () ->
      Lwt.fail exn
    end >>= fun () ->
    Lwt.return (ic, oc)
  in
  connect () >|= fun (ic, oc) ->
  let read_frame = make_read_frame ~mode:(Client random_string) ic oc in
  (* Going through [Lwt.catch] means an exception raised by [read_frame] is
     delivered as a rejected promise rather than raised at the call site. *)
  let read_frame () = Lwt.catch read_frame (fun exn -> Lwt.fail exn) in
  (* A single buffer is reused for every outgoing frame. *)
  let buf = Buffer.create 128 in
  let write_frame frame =
    Buffer.clear buf;
    (* [Lwt.wrap2] likewise turns an exception raised while serialising the
       frame into a rejected promise. *)
    Lwt.wrap2
      (write_frame_to_buf ~mode:(Client random_string)) buf frame >>= fun () ->
    Lwt_io.write oc @@ Buffer.contents buf in
  read_frame, write_frame

(** [write_failed_response oc] writes a 403 response with a short
    fixed-length plain body to [oc] and flushes it. *)
let write_failed_response oc =
  let body = "403 Forbidden" in
  let encoding =
    Cohttp.Transfer.Fixed (Int64.of_int (String.length body)) in
  let response = Cohttp.Response.make ~status:`Forbidden ~encoding () in
  Response.write
    ~flush:true
    (fun writer -> Response.write_body writer body)
    response oc

(** [establish_server ?read_buf ?write_buf ?timeout ?stop ?on_exn
    ?check_request ?ctx ~mode react] serves websocket connections through
    [Conduit_lwt_unix.serve].

    For every incoming connection the HTTP request is read and validated: it
    must be an HTTP/1.1 GET whose upgrade header is websocket (compared
    case-insensitively), it must carry a sec-websocket-key header, satisfy
    [upgrade_present], and finally satisfy [check_request] (default:
    [check_origin_with_host]). A valid request is answered with a Switching
    Protocols response and [react] is applied to the resulting
    [Connected_client.t]. An invalid one gets a 403 response and the
    handler fails with [Protocol_error].

    [read_buf] and [write_buf] are forwarded to [Connected_client.create];
    [timeout], [stop], [on_exn], [ctx] and [mode] to
    [Conduit_lwt_unix.serve]. *)
let establish_server
    ?read_buf ?write_buf
    ?timeout ?stop
    ?on_exn
    ?(check_request=check_origin_with_host)
    ?(ctx=Conduit_lwt_unix.default_ctx) ~mode react =
  let module C = Cohttp in
  (* Read the upgrade request, logging and failing when none can be had. *)
  let read_request ic =
    Request.read ic >>= function
    | `Ok request -> Lwt.return request
    | `Eof ->
      (* Remote endpoint closed connection. No further action necessary here. *)
      Lwt_log.info ~section "Remote endpoint closed connection" >>= fun () ->
      Lwt.fail End_of_file
    | `Invalid reason ->
      Lwt_log.info_f ~section
        "Invalid input from remote endpoint: %s" reason >>= fun () ->
      Lwt.fail (HTTP_Error reason)
  in
  (* Refuse the upgrade: 403 to the peer, failure to the caller. *)
  let reject oc =
    write_failed_response oc >>= fun () ->
    Lwt.fail (Protocol_error "Bad headers")
  in
  (* Complete the handshake for [key] and hand the client over to [react]. *)
  let accept request key flow ic oc =
    let accept_token = key ^ websocket_uuid |> b64_encoded_sha1sum in
    let response =
      C.Response.make
        ~status:`Switching_protocols
        ~encoding:C.Transfer.Unknown
        ~headers:(C.Header.of_list
                    ["Upgrade", "websocket";
                     "Connection", "Upgrade";
                     "Sec-WebSocket-Accept", accept_token])
        ()
    in
    Response.write (fun _writer -> Lwt.return_unit) response oc >>= fun () ->
    react (Connected_client.create ?read_buf ?write_buf request flow ic oc)
  in
  let handle flow ic oc =
    read_request ic >>= fun request ->
    let headers = C.Request.headers request in
    let key = C.Header.get headers "sec-websocket-key" in
    let requested_upgrade =
      Option.map ~f:String.Ascii.lowercase (C.Header.get headers "upgrade") in
    (* [check_request] is deliberately last so that it only ever sees
       requests that already look like a websocket upgrade. *)
    let acceptable =
      C.Request.version request = `HTTP_1_1
      && C.Request.meth request = `GET
      && requested_upgrade = Some "websocket"
      && key <> None
      && upgrade_present headers
      && check_request request
    in
    match acceptable, key with
    | true, Some key -> accept request key flow ic oc
    | _ -> reject oc
  in
  Conduit_lwt_unix.serve ?on_exn ?timeout ?stop ~ctx ~mode
    (fun flow ic oc ->
       set_tcp_nodelay flow;
       handle flow ic oc)

(** [mk_frame_stream recv] builds a stream out of repeated calls to [recv].
    The stream yields each received frame and ends as soon as a Close frame
    is received; the Close frame itself is not yielded. *)
let mk_frame_stream recv =
  let next () =
    recv () >|= fun frame ->
    match frame.Frame.opcode with
    | Frame.Opcode.Close -> None
    | _ -> Some frame
  in
  Lwt_stream.from next

(** Same as [establish_server], except that the client handed to [react] has
    first gone through [Connected_client.make_standard], so that its [recv]
    replies to Ping and Close frames by itself. All other arguments are
    forwarded unchanged. *)
let establish_standard_server
    ?read_buf ?write_buf
    ?timeout ?stop
    ?on_exn ?check_request ?(ctx=Conduit_lwt_unix.default_ctx) ~mode react =
  establish_server ?read_buf ?write_buf ?timeout ?stop
    ?on_exn ?check_request ~ctx ~mode
    (fun client -> client |> Connected_client.make_standard |> react)
OCaml

Innovation. Community. Security.