package ppx_rapper

  1. Overview
  2. Docs

Source file ppx_rapper.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
open Base
open Ppxlib
module Buildef = Ast_builder.Default

(** Handle 'record_in' etc. in [%rapper "SELECT * FROM USERS" record_in record_out]

    [args] are the bare identifiers following the SQL string. Returns
    [Ok (input_kind, output_kind, syntax_off)] where [input_kind] is [`Record]
    or [`Labelled_args] and [output_kind] is [`Tuple], [`Record] or
    [`Function]. Returns [Error msg] for an unknown argument, or when
    [record_out] and [function_out] are both given (previously this hit an
    [assert false] and crashed the rewriter instead of reporting an error). *)
let parse_args args =
  let allowed_args =
    [ "record_in"; "record_out"; "function_out"; "syntax_off" ]
  in
  let has arg = List.mem args arg ~equal:String.equal in
  match
    List.find
      ~f:(fun a -> not (List.mem ~equal:String.equal allowed_args a))
      args
  with
  | Some unknown ->
      Error (Printf.sprintf "Unknown rapper argument '%s'" unknown)
  | None -> (
      let input_kind = if has "record_in" then `Record else `Labelled_args in
      let syntax_off = has "syntax_off" in
      match (has "record_out", has "function_out") with
      | true, true ->
          Error "record_out and function_out cannot be used together"
      | true, false -> Ok (input_kind, `Record, syntax_off)
      | false, true -> Ok (input_kind, `Function, syntax_off)
      | false, false -> Ok (input_kind, `Tuple, syntax_off))

(** Build the sub-expressions shared by the generated request code.

    Returns a triple: the Caqti type expression for the query's input
    parameters, the one for its output parameters, and the SQL text as a
    string literal expression. *)
let component_expressions ~loc parsed_query =
  let { Query.in_params; out_params; sql; _ } = parsed_query in
  let caqti_tup params = Codegen.make_caqti_type_tup ~loc params in
  (* Inputs are built before outputs, as before, so the same Codegen error
     surfaces first if both parameter lists are problematic. *)
  let input_type = caqti_tup in_params in
  let output_type = caqti_tup out_params in
  let sql_literal = Buildef.estring ~loc sql in
  (input_type, output_type, sql_literal)

(** Make a function [expand_get] to produce the expressions for [get_one],
    [get_opt] and [get_many], and a similar [expand_exec] for [execute].

    Both take an expression naming a [Caqti_request] function (e.g. [find],
    [exec]) and a [Codegen] builder ([make_function]). They return [Ok expr],
    or [Error msg] when [Codegen.Error msg] is raised while the request
    expression is generated.

    When the query has a list parameter, the SQL text is assembled at runtime.
    The sub-query is repeated once per list element, and the elements are
    packed with [Dynparam]. Because the text can differ between calls, both the
    get and the exec requests are created with [~oneshot:true]. An empty list
    makes the generated code fail with an ["Empty list"] encode error. *)
let make_expand_get_and_exec_expression ~loc parsed_query input_kind output_kind
    =
  let { Query.sql; in_params; out_params; list_params } = parsed_query in
  match list_params with
  | Some { subsql; string_index; param_index; params } ->
      (* Only a single list parameter per query is supported. *)
      assert (List.length params = 1);
      let subsql_expr = Buildef.estring ~loc subsql in
      let sql_before =
        Buildef.estring ~loc @@ String.sub sql ~pos:0 ~len:string_index
      in
      let sql_after =
        Buildef.estring ~loc
        @@ String.sub sql ~pos:string_index
             ~len:(String.length sql - string_index)
      in
      let params_before, params_after = List.split_n in_params param_index in
      let expression_contents =
        {
          Codegen.in_params = params_before @ params @ params_after;
          out_params;
          input_kind;
          output_kind;
        }
      in
      (* Caqti input type: the ordinary parameters around the list, with the
         list itself represented by the runtime-packed [packed_list_type]. *)
      let caqti_input_type =
        let exprs_before =
          List.map ~f:(Codegen.caqti_type_of_param ~loc) params_before
        in
        let exprs_after =
          List.map ~f:(Codegen.caqti_type_of_param ~loc) params_after
        in
        if List.is_empty params_before && List.is_empty params_after then
          [%expr packed_list_type]
        else
          let expression =
            Codegen.caqti_type_tup_of_expressions ~loc
              (exprs_before @ [ [%expr packed_list_type] ] @ exprs_after)
          in
          [%expr Caqti_type.([%e expression])]
      in
      let outputs_caqti_type = Codegen.make_caqti_type_tup ~loc out_params in
      let list_param = List.hd_exn params in
      (* Wrap the generated function so that, at runtime, it builds the SQL
         text and the packed list parameter before creating [query]. *)
      let make_generic make_function query_expr =
        let body_fn body =
          let base =
            [%expr
              match
                [%e
                  Buildef.pexp_ident ~loc
                    (Codegen.lident_of_param ~loc list_param)]
              with
              | [] ->
                  Rapper_helper.fail
                    Caqti_error.(
                      encode_rejected ~uri:Uri.empty ~typ:Caqti_type.unit
                        (Msg "Empty list"))
              | elems ->
                  let subsqls =
                    Stdlib.List.map (fun _ -> [%e subsql_expr]) elems
                  in
                  let patch = Stdlib.String.concat ", " subsqls in
                  let sql = [%e sql_before] ^ patch ^ [%e sql_after] in
                  let open Rapper.Internal in
                  let (Dynparam.Pack
                        ( packed_list_type,
                          [%p Codegen.ppat_of_param ~loc list_param] )) =
                    Stdlib.List.fold_left
                      (fun pack item ->
                        Dynparam.add
                          (Caqti_type.(
                             [%e
                               Codegen.make_caqti_type_tup ~loc [ list_param ]])
                          [@ocaml.warning "-33"])
                          item pack)
                      Dynparam.empty elems
                  in
                  let query = [%e query_expr] in
                  [%e body]]
          in
          match output_kind with
          | `Function -> [%expr fun loaders -> [%e base]]
          | _ -> base
        in
        match output_kind with
        | `Function ->
            [%expr
              let wrapped =
                [%e make_function ~body_fn ~loc expression_contents]
              in
              wrapped loaders]
        | _ ->
            [%expr
              let wrapped =
                [%e make_function ~body_fn ~loc expression_contents]
              in
              wrapped]
      in
      let expand_get caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   ~oneshot:true ([%e caqti_input_type] [@ocaml.warning "-33"])
                   (Caqti_type.([%e outputs_caqti_type]) [@ocaml.warning "-33"])
                   sql])
        with Codegen.Error s -> Error s
      in

      let expand_exec caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               (* [~oneshot:true] as in [expand_get]: the SQL text is rebuilt
                  on every call, so the request must not be treated as
                  reusable. *)
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   ~oneshot:true [%e caqti_input_type] sql])
        with Codegen.Error s -> Error s
      in
      (expand_get, expand_exec)
  | None ->
      let inputs_caqti_type, outputs_caqti_type, parsed_sql =
        component_expressions ~loc parsed_query
      in
      let expression_contents =
        Codegen.
          {
            in_params = parsed_query.in_params;
            out_params = parsed_query.out_params;
            input_kind;
            output_kind;
          }
      in
      (* Static SQL: the request is built once, outside the wrapped function. *)
      let make_generic make_function query_expr =
        match output_kind with
        | `Function ->
            [%expr
              fun loaders ->
                let query = [%e query_expr] in
                let wrapped =
                  [%e
                    make_function ~body_fn:(fun x -> x) ~loc expression_contents]
                in
                wrapped loaders]
        | _ ->
            [%expr
              let query = [%e query_expr] in
              let wrapped =
                [%e
                  make_function ~body_fn:(fun x -> x) ~loc expression_contents]
              in
              wrapped]
      in
      let expand_get caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   (Caqti_type.([%e inputs_caqti_type]) [@ocaml.warning "-33"])
                   (Caqti_type.([%e outputs_caqti_type]) [@ocaml.warning "-33"])
                   [%e parsed_sql]])
        with Codegen.Error s -> Error s
      in

      let expand_exec caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   (Caqti_type.([%e inputs_caqti_type]) [@ocaml.warning "-33"])
                   [%e parsed_sql]])
        with Codegen.Error s -> Error s
      in
      (expand_get, expand_exec)

(** Expand a [[%rapper <action> "<sql>" <args>...]] node into generated code.

    [action] selects the Caqti request function ([execute], [get_one],
    [get_opt] or [get_many]). [query] is the SQL string literal, and [args] are
    the bare identifiers that follow it (see [parse_args]).

    Every failure is raised as a located [Location.Error] prefixed with
    "Error in ppx_rapper". This covers an unknown argument, a query parse
    error, an SQL syntax error (unless [syntax_off]), an unsupported action or
    argument combination, and a [Codegen.Error] raised while generating
    code. *)
let expand ~loc ~path:_ action query args =
  (* SQL used for the syntax check: for list parameters, one copy of the
     repeated sub-query is spliced back in at its recorded position. *)
  let sql_for_syntax_check parsed_query =
    let sql = parsed_query.Query.sql in
    match parsed_query.Query.list_params with
    | Some { subsql; string_index; _ } ->
        let sql_before = String.sub sql ~pos:0 ~len:string_index in
        let sql_after =
          String.sub sql ~pos:string_index
            ~len:(String.length sql - string_index)
        in
        sql_before ^ subsql ^ sql_after
    | None -> sql
  in
  let check_syntax ~syntax_off parsed_query =
    if syntax_off then Ok ()
    else
      match Pg_query.parse (sql_for_syntax_check parsed_query) with
      | Ok _ -> Ok ()
      | Error msg -> Error (Printf.sprintf "Syntax error in SQL: '%s'" msg)
  in
  let generate ~input_kind ~output_kind parsed_query =
    match
      make_expand_get_and_exec_expression ~loc parsed_query input_kind
        output_kind
    with
    (* Part of the code generation runs eagerly, outside the [try] blocks in
       make_expand_get_and_exec_expression, so Codegen errors are caught here
       too and turned into a located error instead of escaping. *)
    | exception Codegen.Error msg -> Error msg
    | expand_get, expand_exec -> (
        match action with
        (* execute is special case because there is no output Caqti_type *)
        | "execute" -> (
            match output_kind with
            | `Record -> Error "record_out is not a valid argument for execute"
            (* TODO - could implement this *)
            | `Function ->
                Error "function_out is not a valid argument for execute"
            | `Tuple -> expand_exec [%expr exec] Codegen.exec_function)
        | "get_one" -> expand_get [%expr find] Codegen.find_function
        | "get_opt" -> expand_get [%expr find_opt] Codegen.find_opt_function
        | "get_many" ->
            expand_get [%expr collect] Codegen.collect_list_function
        | _ ->
            Error
              "Supported actions are execute, get_one, get_opt and \
               get_many")
  in
  let expression_result =
    Result.bind (parse_args args)
      ~f:(fun (input_kind, output_kind, syntax_off) ->
        Result.bind
          (Result.map_error (Query.parse query) ~f:Query.explain_error)
          ~f:(fun parsed_query ->
            Result.bind (check_syntax ~syntax_off parsed_query) ~f:(fun () ->
                generate ~input_kind ~output_kind parsed_query)))
  in
  match expression_result with
  | Ok expr -> expr
  | Error msg ->
      raise
        (Location.Error
           (Location.Error.createf ~loc "Error in ppx_rapper: %s" msg))

(** Captures [[%rapper get_one "SELECT id FROM things WHERE condition"]].

    Matches an identifier (the action) applied to an unlabelled string literal
    (the SQL), followed by zero or more unlabelled bare identifiers (the rapper
    arguments). It captures, in order: the action name, the SQL string, and
    the list of argument names. *)
let pattern =
  let open Ast_pattern in
  let unlabelled p = pair nolabel p in
  let action = pexp_ident (lident __) in
  let sql = unlabelled (estring __) in
  let extra_arg = unlabelled (pexp_ident (lident __)) in
  pexp_apply action (sql ^:: many extra_arg)

(** Extension name, used as [[%rapper ...]]. *)
let name = "rapper"

(** The [[%rapper ...]] expression extension: its payload must be a single
    expression matching [pattern], which is rewritten by [expand]. *)
let ext =
  let payload = Ast_pattern.(single_expr_payload pattern) in
  Extension.declare name Extension.Context.expression payload expand

(* Register the rewriter with the ppxlib driver. *)
let () = Driver.register_transformation name ~extensions:[ ext ]
OCaml

Innovation. Community. Security.