package GT

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file compare.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
(*
 * GT: compare plugin
 * Copyright (C) 2016-2022
 *   Dmitrii Kosarev a.k.a. Kakadu
 * St.Petersburg University, JetBrains Research
 *)

(** {i Compare} plugin: receive another value as inherited attribute and compare.

    For type declaration [type ('a,'b,...) typ = ...] it will create a transformation
    function with type

    [('a -> 'a -> GT.comparison) ->
     ('b -> 'b -> GT.comparison) -> ... -> ('a,'b,...) typ -> GT.comparison ]

    Inherited attribute' is the same as argument, synthetized attribute is {!GT.comparison}.
*)

open Ppxlib
open Stdppx
open Printf
open GTCommon
open HelpersBase

let trait_name = "compare"

(* Compare plugin where we pass another value of the same type as 'inh
 * and return GT.comparison as 'syn
 *)
module Make (AstHelpers : GTHELPERS_sig.S) = struct
  module P = Plugin.Make (AstHelpers)
  open AstHelpers

  let trait_name = trait_name
  let access_GT s = Ldot (Lident "GT", s)

  class g initial_args tdecls =
    object (self : 'self)
      inherit P.with_inherited_attr initial_args tdecls
      method trait_name = trait_name

      method inh_of_main ~loc tdecl =
        let ans = Typ.use_tdecl tdecl in
        if is_polyvariant_tdecl tdecl
        then
          Typ.alias ~loc (Typ.variant_of_t ~loc ans)
          @@ Naming.make_extra_param tdecl.ptype_name.txt
        else ans

      method syn_of_param ~loc _s =
        Typ.of_longident ~loc (Ldot (Lident "GT", "comparison"))

      method syn_of_main ~loc ?in_class tdecl = self#syn_of_param ~loc "dummy"
      method inh_of_param ~loc _tdecl name = Typ.var ~loc name

      method plugin_class_params ~loc (typs : Ppxlib.core_type list) ~typname =
        (* the same as in 'show' plugin *)
        List.map typs ~f:Typ.from_caml
        @ [ Typ.var ~loc @@ Naming.make_extra_param typname ]

      method! make_typ_of_class_argument
        : 'a.
          loc:loc
          -> type_declaration
          -> (Typ.t -> 'a -> 'a)
          -> string
          -> (('a -> 'a) -> 'a -> 'a)
          -> 'a
          -> 'a =
        fun ~loc tdecl chain name k ->
          let subj_t = Typ.var ~loc name in
          let syn_t = self#syn_of_param ~loc name in
          let inh_t = subj_t in
          k @@ chain (Typ.arrow ~loc inh_t @@ Typ.arrow ~loc subj_t syn_t)

      method chain_exprs ~loc e1 e2 =
        Exp.app_list
          ~loc
          (Exp.of_longident ~loc (access_GT "chain_compare"))
          [ e1; Exp.fun_ ~loc (Pat.unit ~loc) e2 ]
      (* [%expr GT.chain_compare [%e e1] (fun () -> [%e e2]) ] *)

      method chain_init ~loc = Exp.construct ~loc (access_GT "EQ") []

      method on_different_constructors ~loc is_poly other_name cname arg_typs =
        assert (not @@ String.is_empty cname);
        (* Format.printf "%s %s %d\n" cname __FILE__ __LINE__; *)
        let (_
              : [ `Record of (string * string * core_type) list
                | `Tuples of (string * core_type) list
                ])
          =
          arg_typs
        in
        Exp.app_list
          ~loc
          (Exp.of_longident ~loc (access_GT "compare_vari"))
          [ Exp.ident ~loc other_name
          ; ((if is_poly
             then Exp.variant ~loc cname
             else Exp.construct ~loc (lident cname))
            @@
            match arg_typs with
            | `Tuples ts -> List.map ts ~f:(fun _ -> Exp.objmagic_unit ~loc)
            | `Record rs ->
              [ Exp.record ~loc
                @@ List.map rs ~f:(fun (_, l, _) -> lident l, Exp.objmagic_unit ~loc)
              ])
            (* List.map arg_typs ~f:(fun _ -> Exp.objmagic_unit ~loc) *)
            (* It's annoying to use magic here but need to do this first:
           https://caml.inria.fr/mantis/print_bug_page.php?bug_id=4751
        *)
          ]

      method on_tuple_constr ~loc ~is_self_rec ~mutual_decls ~inhe tdecl constr_info args
          =
        let is_poly, cname =
          match constr_info with
          | Some (`Normal s) -> false, s
          | Some (`Poly s) -> true, s
          | None -> false, ""
        in
        let main_case =
          let pat_names = List.map args ~f:(fun _ -> gen_symbol ()) in
          let lhs =
            let arg_pats =
              match pat_names with
              | [] -> []
              | [ s ] -> [ Pat.var ~loc s ]
              | __ -> List.map pat_names ~f:(Pat.var ~loc)
            in
            match constr_info with
            | Some (`Normal s) -> Pat.constr ~loc s arg_pats
            | Some (`Poly s) -> Pat.variant ~loc s arg_pats
            | None -> Pat.tuple ~loc arg_pats
          in
          let rhs =
            (* TODO: rewrite with fold2_exn *)
            List.fold_left
              ~init:(self#chain_init ~loc)
              (List.map2_exn pat_names args ~f:(fun a (b, c) -> a, b, c))
              ~f:(fun acc (pname, name, typ) ->
                self#chain_exprs
                  ~loc
                  acc
                  (self#app_transformation_expr
                     ~loc
                     (self#do_typ_gen ~loc ~is_self_rec ~mutual_decls tdecl typ)
                     (Exp.ident ~loc pname)
                     (Exp.ident ~loc name)))
          in
          case ~lhs ~rhs
        in
        let all_cases =
          if has_many_constructors_tdecl tdecl
          then (
            let other_cases =
              let other_name = gen_symbol ~prefix:"other" () in
              let lhs = Pat.var ~loc other_name in
              match constr_info with
              | Some (`Normal s) ->
                assert (not (String.equal "" s));
                let rhs =
                  self#on_different_constructors
                    ~loc
                    false
                    other_name
                    cname
                    (`Tuples args)
                in
                [ case ~lhs ~rhs ]
              | Some (`Poly s) ->
                assert (not (String.equal "" s));
                let rhs =
                  self#on_different_constructors ~loc true other_name cname (`Tuples args)
                in
                [ case ~lhs ~rhs ]
              | None -> []
            in
            main_case :: other_cases)
          else [ main_case ]
        in
        Exp.match_ ~loc inhe all_cases

      method app_transformation_expr ~loc trf inh subj =
        Exp.app_list ~loc trf [ inh; subj ]

      method abstract_trf ~loc k =
        Exp.fun_list ~loc [ Pat.sprintf ~loc "inh"; Pat.sprintf ~loc "subj" ]
        @@ k (Exp.sprintf ~loc "inh") (Exp.sprintf ~loc "subj")

      method on_record_declaration ~loc ~is_self_rec ~mutual_decls tdecl labs =
        assert (Int.(List.length labs > 0));
        let pat =
          Pat.record ~loc
          @@ List.map labs ~f:(fun { pld_name } ->
               Lident pld_name.txt, Pat.var ~loc pld_name.txt)
        in
        let methname = sprintf "do_%s" tdecl.ptype_name.txt in
        [ (Cf.method_concrete ~loc methname
          (* TODO: maybe use abstract_transformation_expr here *)
          @@ Exp.fun_list ~loc [ Pat.sprintf ~loc "inh"; pat ]
          @@
          let wrap lab =
            self#app_transformation_expr
              ~loc
              (self#do_typ_gen ~loc ~is_self_rec ~mutual_decls tdecl lab.pld_type)
              (Exp.field ~loc (Exp.ident ~loc "inh") (Lident lab.pld_name.txt))
              (Exp.ident ~loc lab.pld_name.txt)
          in
          let init = self#chain_init ~loc in
          List.fold_left ~init labs ~f:(fun acc lab ->
            self#chain_exprs ~loc acc (wrap lab)))
        ]

      method! on_record_constr
        ~loc
        ~is_self_rec
        ~mutual_decls
        ~inhe
        tdecl
        info
        bindings
        labs =
        assert (Int.(List.length labs > 0));
        let is_poly, cname =
          match info with
          | `Normal s -> false, s
          | `Poly s -> true, s
        in
        let main_case =
          let pat_names = List.map labs ~f:(fun _ -> gen_symbol ()) in
          let lhs =
            Pat.constr
              ~loc
              cname
              [ Pat.record ~loc
                @@ List.map2_exn labs pat_names ~f:(fun l name ->
                     lident l.pld_name.txt, Pat.var ~loc name)
              ]
          in
          let rhs =
            List.fold_left
              ~init:(self#chain_init ~loc)
              (List.map2_exn bindings pat_names ~f:(fun (name1, _, typ) iname ->
                 name1, iname, typ))
              ~f:(fun acc (sname, iname, typ) ->
                self#chain_exprs
                  ~loc
                  acc
                  (self#app_transformation_expr
                     ~loc
                     (self#do_typ_gen ~loc ~is_self_rec ~mutual_decls tdecl typ)
                     (Exp.ident ~loc iname)
                     (Exp.ident ~loc sname)))
          in
          case ~lhs ~rhs
        in
        let all_cases =
          if has_many_constructors_tdecl tdecl
          then (
            let other_cases =
              let other_name = "other" in
              let lhs = Pat.var ~loc other_name in
              let rhs =
                self#on_different_constructors ~loc is_poly other_name cname
                @@ `Record bindings
              in
              case ~lhs ~rhs
            in
            [ main_case; other_cases ])
          else [ main_case ]
        in
        Exp.match_ ~loc inhe all_cases
    end

  let create = (new g :> P.plugin_constructor)
end

let register () = Expander.register_plugin trait_name (module Make : Plugin_intf.MAKE)
let () = register ()
OCaml

Innovation. Community. Security.