package dream

  1. Overview
  2. Docs

Source file session.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
(* This file is part of Dream, released under the MIT license. See LICENSE.md
   for details, or visit https://github.com/aantron/dream.

   Copyright 2021 Anton Bachin *)



module Dream = Dream_pure
module Cookie = Dream__server.Cookie
module Session = Dream__server.Session



(* [opt |>? f] is [Option.bind opt f]: [f] is applied to the contents of
   [opt], and [None] short-circuits to [None]. *)
let ( |>? ) opt f =
  Option.bind opt f

(* Signature of a Caqti/Lwt database connection. Every query function below
   receives the connection as a first-class module, [(module Db : DB)]. *)
module type DB = Caqti_lwt.CONNECTION

(* Short aliases for Caqti's request and type-description modules. [T] is used
   to describe query parameter/row types below. NOTE(review): [R] does not
   appear to be referenced in this file (queries use [Caqti_request.Infix]
   directly) -- confirm before removing. *)
module R = Caqti_request
module T = Caqti_type

(* Encodes a session payload, an association list of string names to string
   values, as a JSON object whose members are all JSON strings. This is the
   format [find_opt] decodes when reading a session back. *)
let serialize_payload payload =
  let fields =
    List.map (fun (name, value) -> (name, `String value)) payload in
  Yojson.Basic.to_string (`Assoc fields)

(* Inserts [session] as a new row of [dream_session], storing its id, label,
   expiry time, and JSON-encoded payload. The query value is built once, when
   [insert] itself is defined, and shared by all calls. A database error
   rejects the returned promise via [Caqti_lwt.or_fail]; [create] relies on
   this to detect a failed insert. *)
let insert =
  let query =
    let open Caqti_request.Infix in
    (T.(t4 string string float string) ->. T.unit) {|
      INSERT INTO dream_session (id, label, expires_at, payload)
      VALUES ($1, $2, $3, $4)
    |} in

  fun (module Db : DB) (session : Session.session) ->
    let row =
      session.id,
      session.label,
      session.expires_at,
      serialize_payload session.payload
    in
    Lwt.bind (Db.exec query row) Caqti_lwt.or_fail

(* Looks up the session whose primary key is [id]. It resolves to [None] when
   no row matches. Otherwise it rebuilds the session record, decoding the
   payload stored by [serialize_payload]. Expiry is NOT checked here; callers
   decide what to do with an expired row.

   Database errors reject the promise via [Caqti_lwt.or_fail]. A malformed
   payload also rejects it. A non-object or a non-string member raises
   [Failure "Bad payload"]; text that is not valid JSON raises whatever
   [Yojson.Basic.from_string] raises. *)
let find_opt =
  let query =
    let open Caqti_request.Infix in
    (T.string ->? T.(t3 string float string))
      "SELECT label, expires_at, payload FROM dream_session WHERE id = $1" in

  (* Inverse of [serialize_payload]. *)
  let deserialize_payload text =
    let field = function
      | name, `String value -> name, value
      | _ -> failwith "Bad payload"
    in
    match Yojson.Basic.from_string text with
    | `Assoc fields -> List.map field fields
    | _ -> failwith "Bad payload"
  in

  fun (module Db : DB) id ->
    let%lwt result = Db.find_opt query id in
    let%lwt row = Caqti_lwt.or_fail result in
    match row with
    | None -> Lwt.return_none
    | Some (label, expires_at, encoded) ->
      (* TODO Mind exceptions! Decoding failures currently reject the
         promise rather than being reported as a missing session. *)
      let payload = deserialize_payload encoded in
      Lwt.return_some Session.{ id; label; expires_at; payload }

(* Writes [session.expires_at] back to the row identified by [session.id].
   Only the expiry column is updated. Database errors reject the promise via
   [Caqti_lwt.or_fail]. *)
let refresh =
  let query =
    let open Caqti_request.Infix in
    (T.(t2 float string) ->. T.unit)
      "UPDATE dream_session SET expires_at = $1 WHERE id = $2" in

  fun (module Db : DB) (session : Session.session) ->
    let params = session.expires_at, session.id in
    Lwt.bind (Db.exec query params) Caqti_lwt.or_fail

(* Overwrites the stored payload of the row identified by [session.id] with
   the JSON encoding of [session.payload]. Label and expiry are untouched.
   Database errors reject the promise via [Caqti_lwt.or_fail]. *)
let update =
  let query =
    let open Caqti_request.Infix in
    (T.(t2 string string) ->. T.unit)
      "UPDATE dream_session SET payload = $1 WHERE id = $2" in

  fun (module Db : DB) (session : Session.session) ->
    let%lwt outcome =
      Db.exec query (serialize_payload session.payload, session.id) in
    Caqti_lwt.or_fail outcome

(* Deletes the session row whose primary key is [id]. Database errors reject
   the promise via [Caqti_lwt.or_fail]. *)
let remove =
  let query =
    let open Caqti_request.Infix in
    (T.string ->. T.unit) "DELETE FROM dream_session WHERE id = $1" in

  fun (module Db : DB) id ->
    Lwt.bind (Db.exec query id) Caqti_lwt.or_fail

(* TODO Session sharing is greatly complicated by the backing store; is it ok to
   just work with snapshots? All kinds of race conditions may be possible,
   unless there is a generation value or the like. *)
(* TODO This can be largely addressed with a cache, which is desirable
   anyway. *)
(* TODO The in-memory sessions manager should actually be re-done in terms of
   the cache, just with no persistent backing store. *)

(* [create db expires_at attempt] inserts a new, empty session expiring at
   [expires_at] and resolves to it. Each attempt draws a fresh id and label.

   A failure surfacing as [Caqti_error.Exn] is assumed to be a PRIMARY KEY
   collision (extremely unlikely), and the insert is retried while
   [attempt <= 3]. Callers in this file start at [attempt = 1], so at most four
   inserts are tried before the exception propagates. Other exceptions
   propagate immediately. *)
let rec create db expires_at attempt =
  let candidate = Session.{
    id = Session.new_id ();
    label = Session.new_label ();
    expires_at;
    payload = [];
  } in
  match%lwt insert db candidate with
  | () -> Lwt.return candidate
  | exception Caqti_error.Exn _ when attempt <= 3 ->
    let next_attempt = attempt + 1 in
    create db expires_at next_attempt

(* Sets [name] to [value] in [session]'s payload and persists it. Any
   previous binding of [name] is removed, and the new pair goes to the front.
   The in-memory payload is mutated first; then the whole payload is written
   back with [update]. *)
let put request (session : Session.session) name value =
  let without_name = List.remove_assoc name session.payload in
  session.payload <- (name, value) :: without_name;
  Sql.sql request (fun db -> update db session)

(* Removes [name] from [session]'s payload, then persists the whole payload
   with [update]. *)
let drop request (session : Session.session) name =
  session.payload <- List.remove_assoc name session.payload;
  Sql.sql request (fun db -> update db session)

(* Replaces the current session with a brand-new empty one.
   1. The old row is deleted.
   2. A fresh session expiring [lifetime] seconds from now is created.
   3. The ref is pointed at the new session.
   4. [operations] is marked dirty so that [send] emits a cookie carrying the
      new id. *)
let invalidate request lifetime operations (session : Session.session ref) =
  Sql.sql request (fun db ->
    let old_id = !session.id in
    let%lwt () = remove db old_id in
    let expires_at = Unix.gettimeofday () +. lifetime in
    let%lwt replacement = create db expires_at 1 in
    session := replacement;
    operations.Session.dirty <- true;
    Lwt.return_unit)

(* Builds the [Session.operations] record for one request. [put] and [drop]
   dereference [session] each time they are called, so they act on whatever
   session the ref holds at that moment, including one installed by
   [invalidate]. The record is recursive because [invalidate] needs the
   record itself in order to set its [dirty] flag. *)
let operations request lifetime (session : Session.session ref) dirty =
  let put_value name value = put request !session name value in
  let drop_value name = drop request !session name in
  let rec record = {
    Session.put = put_value;
    drop = drop_value;
    invalidate = (fun () -> invalidate request lifetime record session);
    dirty;
  } in
  record

(* [load lifetime request] produces the session state for [request]. All
   database work runs inside one [Sql.sql] call.

   - The session id comes from the [Session.session_cookie] cookie (read with
     [~decrypt:false]) and is validated by [Session.read_session_id].
   - A stored session whose [expires_at] is not later than now is deleted and
     treated as absent.
   - A live session with more than [lifetime /. 2.] seconds remaining is used
     as is and is not marked dirty. Otherwise its expiry is reset to
     [now +. lifetime], saved with [refresh], and marked dirty.
   - With no usable session, a fresh one expiring at [now +. lifetime] is
     created and marked dirty.

   Resolves to the operations record, which carries the dirty flag that
   [send] checks, paired with a ref to the session. *)
let load lifetime request =
  Sql.sql request begin fun db ->
    let now = Unix.gettimeofday () in

    let%lwt valid_session =
      let cookie_id =
        Cookie.cookie request ~decrypt:false Session.session_cookie
        |>? Session.read_session_id
      in
      match cookie_id with
      | None -> Lwt.return_none
      | Some id ->
        match%lwt find_opt db id with
        | None -> Lwt.return_none
        | Some session when session.expires_at > now ->
          Lwt.return_some session
        | Some _ ->
          (* Expired: delete the stale row. *)
          let%lwt () = remove db id in
          Lwt.return_none
    in

    let%lwt dirty, session =
      match valid_session with
      | None ->
        let%lwt session = create db (now +. lifetime) 1 in
        Lwt.return (true, session)
      | Some session when session.expires_at -. now > lifetime /. 2. ->
        Lwt.return (false, session)
      | Some session ->
        (* At most half the lifetime remains: extend it and persist. *)
        session.expires_at <- now +. lifetime;
        let%lwt () = refresh db session in
        Lwt.return (true, session)
    in

    let session = ref session in
    Lwt.return (operations request lifetime session dirty, session)
  end

(* If the session state is dirty (new session, refreshed expiry, or
   invalidation), sets the session cookie on [response]. The cookie value is
   the versioned session id, it is not encrypted, and its [max_age] is the
   number of seconds until the session expires. Always resolves to
   [response]. *)
let send (operations, session) request response =
  if not operations.Session.dirty then
    Lwt.return response
  else begin
    let current = !session in
    let id = Session.version_session_id current.Session.id in
    let max_age = current.Session.expires_at -. Unix.gettimeofday () in
    Cookie.set_cookie
      response request Session.session_cookie id ~encrypt:false ~max_age;
    Lwt.return response
  end

(* The SQL-backed [Session.back_end]: [load] specialised to [lifetime],
   paired with [send]. *)
let back_end lifetime =
  let load_for_lifetime = load lifetime in
  { Session.load = load_for_lifetime; send }

(* Session middleware that stores sessions in the [dream_session] SQL table.
   [lifetime] defaults to [Session.two_weeks]; it is added to
   [Unix.gettimeofday ()] to compute expiry times. *)
let sql_sessions ?(lifetime = Session.two_weeks) =
  lifetime |> back_end |> Session.middleware
OCaml

Innovation. Community. Security.