package ppx_rapper

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file ppx_rapper.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
open Base
open Ppxlib
module Buildef = Ast_builder.Default

(** Parse the trailing flags of [[%rapper action "SQL" flag ...]].

    The recognised flags are [record_in], [record_out] and [syntax_off], in
    any order. Returns [Ok (record_in, record_out, syntax_off)], where each
    component tells whether that flag was present. If any flag is not
    recognised, returns [Error msg] naming the first such flag in [args]. *)
let parse_args args =
  let is_known flag =
    List.mem [ "record_in"; "record_out"; "syntax_off" ] flag
      ~equal:String.equal
  in
  let present flag = List.mem args flag ~equal:String.equal in
  match List.find args ~f:(fun flag -> not (is_known flag)) with
  | None -> Ok (present "record_in", present "record_out", present "syntax_off")
  | Some unknown ->
      Error (Printf.sprintf "Unknown rapper argument '%s'" unknown)

(** Build the sub-expressions shared by the generated code for a query that
    has no list parameter.

    Returns the triple [(inputs_caqti_type, outputs_caqti_type, parsed_sql)]:
    the Caqti type expression for the query's input parameters, the one for
    its output parameters (both from [Codegen.make_caqti_type_tup]), and the
    SQL text as a string-literal expression. *)
let component_expressions ~loc parsed_query =
  let { Query.in_params; out_params; sql; _ } = parsed_query in
  (* Bound in sequence so inputs are converted before outputs. *)
  let inputs = Codegen.make_caqti_type_tup ~loc in_params in
  let outputs = Codegen.make_caqti_type_tup ~loc out_params in
  let sql_literal = Buildef.estring ~loc sql in
  (inputs, outputs, sql_literal)

(** Make a function [expand_get] to produce the expressions for [get_one],
    [get_opt] and [get_many], and a similar [expand_exec] for [execute].

    Both are called as [expand_x caqti_request_function_expr make_function]:
    - [caqti_request_function_expr] names the [Caqti_request] constructor to
      use, e.g. [[%expr find]];
    - [make_function] is one of the [Codegen] function builders.

    They return [Ok expr], or [Error msg] when [Codegen] raises
    [Codegen.Error msg] while the body is being generated.

    If the query has a list parameter ([list_params = Some _]), the generated
    code builds the SQL at runtime:
    - [subsql] is repeated once per list element, joined with commas, and
      spliced into [sql] at [string_index];
    - the list values are packed with [Ppx_rapper_runtime.Dynparam], and the
      packed type is placed in the input Caqti type at position
      [param_index];
    - an empty list is rejected with an [encode_rejected] error.

    Because that SQL text changes between calls, both the get and the exec
    requests are created with [~oneshot:true], so Caqti does not cache a
    prepared statement for every one-off request. *)
let make_expand_get_and_exec_expression ~loc parsed_query record_in record_out =
  let { Query.sql; in_params; out_params; list_params } = parsed_query in
  match list_params with
  | Some { subsql; string_index; param_index; params } ->
      (* Only a single list parameter per query is supported. *)
      assert (List.length params = 1);
      let subsql_expr = Buildef.estring ~loc subsql in
      let sql_before =
        Buildef.estring ~loc @@ String.sub sql ~pos:0 ~len:string_index
      in
      let sql_after =
        Buildef.estring ~loc
        @@ String.sub sql ~pos:string_index
             ~len:(String.length sql - string_index)
      in
      let params_before, params_after = List.split_n in_params param_index in
      let expression_contents =
        {
          Codegen.in_params = params_before @ params @ params_after;
          out_params;
          record_in;
          record_out;
        }
      in
      (* Input type is: fixed params before the list, the packed list, then
         fixed params after the list. Empty groups are left out. *)
      let caqti_input_type =
        match (List.is_empty params_before, List.is_empty params_after) with
        | true, true -> [%expr packed_list_type]
        | true, false ->
            (* Nothing precedes the list: it comes first, then the params
               that follow it. *)
            let params_after = Codegen.make_caqti_type_tup ~loc params_after in
            [%expr Caqti_type.(tup2 packed_list_type [%e params_after])]
        | false, true ->
            let params_before =
              Codegen.make_caqti_type_tup ~loc params_before
            in
            [%expr Caqti_type.(tup2 [%e params_before] packed_list_type)]
        | false, false ->
            let params_before =
              Codegen.make_caqti_type_tup ~loc params_before
            in
            let params_after = Codegen.make_caqti_type_tup ~loc params_after in
            [%expr
              Caqti_type.(
                tup3 [%e params_before] packed_list_type [%e params_after])]
      in
      let outputs_caqti_type = Codegen.make_caqti_type_tup ~loc out_params in
      let list_param = List.hd_exn params in
      let make_generic make_function query_expr =
        (* Wraps the generated [body] so that, at runtime, the list argument
           is checked for emptiness and the SQL string, packed parameter type
           and request ([query]) are built before [body] runs. *)
        let body_fn body =
          [%expr
            match
              [%e
                Buildef.pexp_ident ~loc
                  (Codegen.lident_of_param ~loc list_param)]
            with
            | [] ->
                Lwt_result.fail
                  Caqti_error.(
                    encode_rejected ~uri:Uri.empty ~typ:Caqti_type.unit
                      (Msg "Empty list"))
            | elems ->
                let subsqls =
                  Stdlib.List.map (fun _ -> [%e subsql_expr]) elems
                in
                let patch = Stdlib.String.concat ", " subsqls in
                let sql = [%e sql_before] ^ patch ^ [%e sql_after] in
                let open Ppx_rapper_runtime in
                let Dynparam.Pack
                      (packed_list_type, [%p Codegen.ppat_of_param ~loc list_param])
                    =
                  Stdlib.List.fold_left
                    (fun pack item ->
                      Dynparam.add
                        Caqti_type.(
                          [%e Codegen.make_caqti_type_tup ~loc [ list_param ]])
                        item pack)
                    Dynparam.empty elems
                in
                let query = [%e query_expr] in
                [%e body]]
        in
        [%expr
          let wrapped (module Db : Caqti_lwt.CONNECTION) =
            [%e make_function ~body_fn ~loc expression_contents]
          in
          wrapped]
      in
      let expand_get caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   ~oneshot:true
                   ([%e caqti_input_type] [@ocaml.warning "-33"])
                   (Caqti_type.([%e outputs_caqti_type]) [@ocaml.warning "-33"])
                   sql])
        with Codegen.Error s -> Error s
      in
      let expand_exec caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   ~oneshot:true
                   ([%e caqti_input_type] [@ocaml.warning "-33"])
                   sql])
        with Codegen.Error s -> Error s
      in
      (expand_get, expand_exec)
  | None ->
      let inputs_caqti_type, outputs_caqti_type, parsed_sql =
        component_expressions ~loc parsed_query
      in
      let expression_contents =
        { Codegen.in_params; out_params; record_in; record_out }
      in
      (* Static SQL: the request is built once, outside the connection
         closure. *)
      let make_generic make_function query_expr =
        [%expr
          let query = [%e query_expr] in
          let wrapped (module Db : Caqti_lwt.CONNECTION) =
            [%e make_function ~body_fn:(fun x -> x) ~loc expression_contents]
          in
          wrapped]
      in
      let expand_get caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   (Caqti_type.([%e inputs_caqti_type]) [@ocaml.warning "-33"])
                   (Caqti_type.([%e outputs_caqti_type]) [@ocaml.warning "-33"])
                   [%e parsed_sql]])
        with Codegen.Error s -> Error s
      in
      let expand_exec caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 Caqti_request.([%e caqti_request_function_expr])
                   (Caqti_type.([%e inputs_caqti_type]) [@ocaml.warning "-33"])
                   [%e parsed_sql]])
        with Codegen.Error s -> Error s
      in
      (expand_get, expand_exec)

(** Expander for [[%rapper action "SQL" flags...]].

    Steps:
    1. Validate [args] with {!parse_args}.
    2. Parse [query] with [Query.parse].
    3. Unless [syntax_off] was given, check the SQL with [Pg_query.parse].
       For list queries, the list's sub-SQL is spliced in once first.
    4. Generate the code for [action]: one of [execute], [get_one],
       [get_opt] or [get_many].

    Any failure raises [Location.Error] at [loc], with the message prefixed
    by ["Error in ppx_rapper: "]. *)
let expand ~loc ~path:_ action query args =
  let fail msg =
    raise
      (Location.Error
         (Location.Error.createf ~loc "Error in ppx_rapper: %s" msg))
  in
  match parse_args args with
  | Error msg -> fail msg
  | Ok (record_in, record_out, syntax_off) -> (
      match Query.parse query with
      | Error error -> fail (Query.explain_error error)
      | Ok parsed_query ->
          if not syntax_off then begin
            let query_sql =
              match parsed_query.list_params with
              | None -> parsed_query.sql
              | Some { subsql; string_index; _ } ->
                  let sql = parsed_query.sql in
                  let head = String.sub sql ~pos:0 ~len:string_index in
                  let tail =
                    String.sub sql ~pos:string_index
                      ~len:(String.length sql - string_index)
                  in
                  head ^ subsql ^ tail
            in
            match Pg_query.parse query_sql with
            | Ok _ -> ()
            | Error msg ->
                fail (Printf.sprintf "Syntax error in SQL: '%s'" msg)
          end;
          let expand_get, expand_exec =
            make_expand_get_and_exec_expression ~loc parsed_query record_in
              record_out
          in
          let generated =
            match action with
            (* execute is special: it has no output Caqti_type *)
            | "execute" ->
                if record_out then
                  Error "record_out is not a valid argument for execute"
                else expand_exec [%expr exec] Codegen.exec_function
            | "get_one" -> expand_get [%expr find] Codegen.find_function
            | "get_opt" ->
                expand_get [%expr find_opt] Codegen.find_opt_function
            | "get_many" ->
                expand_get [%expr collect] Codegen.collect_list_function
            | _ ->
                Error
                  "Supported actions are execute, get_one, get_opt and \
                   get_many"
          in
          (match generated with
           | Ok expr -> expr
           | Error msg -> fail msg))

(** Payload pattern for [[%rapper get_one "SELECT id FROM things WHERE condition"]].

    Matches an application whose function is a bare action identifier. Its
    first argument is an unlabelled string literal (the SQL), followed by
    zero or more unlabelled bare identifiers (the flags). The pattern binds,
    in order: the action name, the SQL string, and the list of flag names. *)
let pattern =
  let open Ast_pattern in
  let unlabelled p = pair nolabel p in
  let action = pexp_ident (lident __) in
  let sql = unlabelled (estring __) in
  let flag = unlabelled (pexp_ident (lident __)) in
  pexp_apply action (sql ^:: many flag)

(** Name of the extension node, used as [[%rapper ...]]. *)
let name = "rapper"

(** Expression-context extension whose single-expression payload must match
    {!pattern}, expanded by {!expand}. *)
let ext =
  let payload = Ast_pattern.single_expr_payload pattern in
  Extension.declare name Extension.Context.expression payload expand

(* Register the extension with the ppxlib driver under the same name. *)
let () = Driver.register_transformation ~extensions:[ ext ] name
OCaml

Innovation. Community. Security.