Source file infer.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
(** Type inference and checking *)
open Common
open Lplib
open Term
open Timed
open Print
(** Logger of this module, registered under key ['i'] and name [infr].
    It is kept as a separate value so that projecting its [pp] field below
    is applied to an identifier, which keeps [log] polymorphic. *)
let logger = Logger.make 'i' "infr" "type inference/checking"

(** Logging function for typing. *)
let log = logger.pp
(** A typing context carried in two forms: the plain context (read with
    {!val:classic}) and its boxed counterpart (read with {!val:boxed}). *)
type octxt = ctxt * bctxt

(** [boxed c] is the boxed component of the pair [c]. *)
let boxed (_, bctx) = bctx

(** [classic c] is the plain component of the pair [c]. *)
let classic (cctx, _) = cctx

(** [extend c v ?def ty] pushes variable [v] of type [ty], with optional
    definition [def], onto both components of [c]. The boxed component
    receives lifted copies of [ty] and [def]. *)
let extend (cctx, bctx) v ?def ty =
  let boxed_entry = (v, lift ty, Option.map lift def) in
  ((v, ty, def) :: cctx, boxed_entry :: bctx)

(** Shorthand for [Bindlib.unbox]. *)
let unbox b = Bindlib.unbox b
(** Exception that may be raised by type inference when a term cannot be
    typed. For instance, it is raised on a let-binding whose body has type
    [Kind]. The [*_noexn] entry points of this module catch it (through
    {!val:noexn}) and return [None] instead. *)
exception NotTypable
(** [unif pb c a b] records the constraint [c |- a==b] in problem [pb],
    unless [a] and [b] are already convertible in the classic component of
    context [c]. *)
let unif : problem -> octxt -> term -> term -> unit = fun pb c a b ->
  let ctx = classic c in
  if Eval.pure_eq_modulo ctx a b then () else begin
    let cstr = (ctx, a, b) in
    if Logger.log_enabled () then
      log (Color.yel "add constraint %a") constr cstr;
    pb := {!pb with to_solve = cstr :: !pb.to_solve}
  end
(** {1 Handling coercions} *)
(** [reduce_coercions c t] attempts to eliminate every coercion occurring in
    [t], working bottom up: subterms are simplified first, then the enclosing
    term. The result is [Some t'], with [t'] the simplified term, when all
    coercions could be reduced away, and [None] as soon as one coercion is
    still stuck after evaluation. *)
let rec reduce_coercions : octxt -> term -> term option = fun c t ->
  let open Option.Monad in
  let is_coercion h =
    match h with
    | Symb s -> s == Coercion.coerce
    | _ -> false
  in
  (* Simplify the body of a binder and rebind it on the same variable. *)
  let under_binder bnd =
    let (v, body) = Bindlib.unbind bnd in
    let* body = reduce_coercions c body in
    return Bindlib.(unbox (bind_var v (Term.lift body)))
  in
  (* Rebuild [mk (a, b)] after simplifying the domain [a], then the binder
     [b]. *)
  let binding mk a b =
    let* a = reduce_coercions c a in
    let* b = under_binder b in
    return (mk (a, b))
  in
  let (head, spine) = get_args t in
  if is_coercion head then
    (* Simplify the arguments first, then let evaluation fire the coercion
       rules; a coercion still in head position means it is stuck. *)
    let* spine = List.map (reduce_coercions c) spine |> List.sequence_opt in
    let reduct = Eval.whnf (classic c) (add_args head spine) in
    let (head', spine') = get_args reduct in
    if is_coercion head' then None
    else reduce_coercions c (add_args head' spine')
  else
    match unfold t with
    | Patt _ | Wild | TEnv _ | TRef _ -> assert false
    | Plac _ | Kind | Type | Vari _ | Symb _ | Meta _ -> return t
    | Appl (f, a) ->
        let* f = reduce_coercions c f in
        let* a = reduce_coercions c a in
        return (mk_Appl (f, a))
    | Abst (a, b) -> binding mk_Abst a b
    | Prod (a, b) -> binding mk_Prod a b
    | LLet (a, e, b) ->
        let* a = reduce_coercions c a in
        let* e = reduce_coercions c e in
        let* b = under_binder b in
        return (mk_LLet (a, e, b))
(** [coerce pb c t a b] coerces term [t] from type [a] to type [b] in context
[c] and problem [pb]. It returns a pair [(t', cu)] where [cu] is true iff a
coercion was actually inserted:
- if [a] and [b] are convertible, [t] is returned unchanged;
- if the coercion built by [Coercion.apply] cannot be fully reduced by
  {!val:reduce_coercions}, [t] is returned unchanged and [a], [b] are passed
  to {!val:unif}, which records the constraint [c |- a==b];
- otherwise the reduced coercion is type checked with {!val:infer} and the
  refined term is returned. *)
let rec coerce : problem -> octxt -> term -> term -> term -> term * bool =
fun pb c t a b ->
if Eval.pure_eq_modulo (classic c) a b then (t, false) else
match Coercion.apply a b t |> reduce_coercions c with
| None -> unif pb c a b; (t, false)
| Some u ->
if Logger.log_enabled () then
log "Coerced [%a : %a <: %a : %a]" term t term a term u term b;
(* Type check the coerced term to refine it; the type inferred for it is
   discarded here. *)
let u, _, _ = infer pb c u in
(u, true)
(** {1 Other rules} *)
(** NOTE: functions {!val:coerce}, {!val:type_enforce}, {!val:force} and
{!val:infer} return a boolean which is true iff the typechecked term has
been modified. This allows bypassing the reconstruction of some Bindlib
terms (which calls [lift |> bind_var x |> unbox]). It reduces the type
checking time of Holide by 21%. *)
(** [type_enforce pb c a] returns a triple [(a', s, cu)] where [a'] is the
refinement of term [a], [s] is a sort ([Type] or [Kind]) such that [a'] has
type [s], and [cu] is true iff [a] has been modified. [s] is [Kind] if the
inferred type of [a] unfolds to [Kind], and [Type] otherwise; [a] is then
coerced to [s] with {!val:coerce}. *)
and type_enforce : problem -> octxt -> term -> term * term * bool =
fun pb c a ->
if Logger.log_enabled () then log "Type enforce [%a]" term a;
let a, s, cui = infer pb c a in
(* Any inferred type that does not unfold to [Kind] or [Type] (e.g. a
   meta-variable) is expected to become [Type], either through a coercion
   or through a constraint recorded by {!val:coerce}. *)
let sort =
match unfold s with
| Kind -> mk_Kind
| Type -> mk_Type
| _ -> mk_Type
in
let a, cu = coerce pb c a s sort in
(a, sort, cui || cu)
(** [force pb c t a] returns a pair [(t', cu)] where [t'] is the refinement
of [t] and has type [a], and [cu] is true iff [t] has been modified.
Placeholders are replaced by meta-variables built with [LibMeta.bmake];
any other term has its type inferred and is then coerced to [a]. *)
and force : problem -> octxt -> term -> term -> term * bool =
fun pb c te ty ->
if Logger.log_enabled () then
log "Force [%a] of [%a]" term te term ty;
match unfold te with
| Plac true ->
(* This placeholder has type [Type] (as in {!val:infer_aux}), so the
   expected type [ty] is unified with [Type]. *)
unif pb c ty mk_Type;
(unbox (LibMeta.bmake pb (boxed c) _Type), true)
| Plac false ->
(* Any other placeholder becomes a meta-variable built from [ty]. *)
(unbox (LibMeta.bmake pb (boxed c) (lift ty)), true)
| _ ->
let (t, a, cui) = infer pb c te in
let t, cu = coerce pb c t a ty in
(t, cu || cui)
(** [infer_aux pb c t] implements the inference rules behind {!val:infer}.
It returns a triple [(t', ty, cu)] where [t'] is the refinement of [t], [ty]
is the inferred type of [t'] and [cu] is true iff [t] has been modified.
@raise NotTypable if a let-binding whose body has type [Kind] is met. *)
and infer_aux : problem -> octxt -> term -> term * term * bool =
fun pb c t ->
match unfold t with
(* The following constructors are not expected to reach type inference. *)
| Patt _ -> assert false
| TEnv _ -> assert false
| Kind -> assert false
| Wild -> assert false
| TRef _ -> assert false
| Type -> (mk_Type, mk_Kind, false)
| Vari x ->
(* Variables are typed by the classic context; an unbound variable is
   considered impossible here. *)
let a = try Ctxt.type_of x (classic c) with Not_found -> assert false in
(t, a, false)
| Symb s -> (t, !(s.sym_type), false)
| Plac true ->
(* A type placeholder becomes a meta-variable whose type is [Type]. *)
let m = LibMeta.bmake pb (boxed c) _Type in
(unbox m, mk_Type, true)
| Plac false ->
(* Any other placeholder: a meta-variable [mt] (built from [Type]) is
   created for its type, then a meta-variable built from [mt] for the
   term itself. *)
let mt = LibMeta.bmake pb (boxed c) _Type in
let m = LibMeta.bmake pb (boxed c) mt in
(unbox m, unbox mt, true)
| (Meta (m, ts)) as t ->
(* The type of [m] is instantiated with the explicit substitution [ts].
   Each argument is refined against the corresponding domain, and [ts] is
   updated in place with the refined arguments, so [t] itself is
   returned. *)
let cu = Stdlib.ref false in
let rec ref_esubst i range =
if i >= Array.length ts then range else
match unfold range with
| Prod(ai, b) ->
let (tsi, cuf) = force pb c ts.(i) ai in
ts.(i) <- tsi;
Stdlib.(cu := !cu || cuf);
ref_esubst (i+1) (Bindlib.subst b tsi)
| LLet(_,d,b) ->
(* For a let-binding in the type of [m], the argument is unified with
   the definition, which is what gets substituted. *)
unif pb c ts.(i) d;
ref_esubst (i+1) (Bindlib.subst b d)
| _ ->
(* The type of [m] is assumed to start with at least as many products
   or let-bindings as [ts] has elements. *)
assert false
in
let range = ref_esubst 0 !(m.meta_type) in
(t, range, Stdlib.(!cu))
| LLet (t_ty, t, u) as top ->
(* Check the annotation [t_ty], then the definition [t] against it, and
   infer the body [u] in the context extended with [x : t_ty := t]. *)
let t_ty, _, cu_t_ty = type_enforce pb c t_ty in
let t, cu_t = force pb c t t_ty in
let (x, u) = Bindlib.unbind u in
let c = extend c x ~def:t t_ty in
let u, u_ty, cu_u = infer pb c u in
(* A let-binding whose body has type [Kind] is rejected. *)
( match unfold u_ty with
| Kind ->
Error.fatal_msg "Let bindings cannot have a body of type Kind.";
Error.fatal_msg "Body of let binding [%a] has type Kind."
term u;
raise NotTypable
| _ -> () );
(* The type of the let-binding is the same let-binding around [u_ty]. *)
let u_ty = Bindlib.(u_ty |> lift |> bind_var x |> unbox) in
let top_ty = mk_LLet (t_ty, t, u_ty) in
let cu = cu_t_ty || cu_t || cu_u in
(* The term is only rebuilt when one of its components was modified. *)
let top =
if cu then
let u = Bindlib.(u |> lift |> bind_var x |> unbox) in
mk_LLet(t_ty, t, u)
else top
in
(top, top_ty, cu)
| Abst (dom, b) as top ->
(* The domain must have type [Type]; the type of the abstraction is the
   product of [dom] and the type inferred for its body. *)
let dom, cu_dom = force pb c dom mk_Type in
let (x, b) = Bindlib.unbind b in
let c = extend c x dom in
let b, range, cu_b = infer pb c b in
let range = Bindlib.(lift range |> bind_var x |> unbox) in
let top_ty = mk_Prod (dom, range) in
let cu = cu_b || cu_dom in
let top =
if cu then
let b = Bindlib.(lift b |> bind_var x |> unbox) in
mk_Abst (dom, b)
else top
in
(top, top_ty, cu)
| Prod (dom, b) as top ->
(* The domain must have type [Type]; the type of the product is the sort
   of its body, as computed by {!val:type_enforce}. *)
let dom, cu_dom = force pb c dom mk_Type in
let (x, b) = Bindlib.unbind b in
let c = extend c x dom in
let b, b_s, cu_b = type_enforce pb c b in
let cu = cu_b || cu_dom in
let top =
if cu then
let b = Bindlib.(lift b |> bind_var x |> unbox) in
mk_Prod (dom, b)
else top
in
(top, b_s, cu)
| Appl (t, u) as top -> (
let t, t_ty, cu_t = infer pb c t in
(* [return m t u range] builds the result for the application [t u], where
   [range] is the codomain binder of the type of [t] and [m] tells whether
   [u] (or [t], through a coercion) was modified; [top] is reused when
   nothing changed. *)
let return m t u range =
let ty = Bindlib.subst range u and cu = cu_t || m in
if cu then (mk_Appl (t, u), ty, cu) else (top, ty, cu)
in
match Eval.whnf (classic c) t_ty with
| Prod (dom, range) ->
(* The type of [t] is a product: check [u] against its domain. *)
if Logger.log_enabled () then
log "Appl-prod arg [%a]" term u;
let u, cu_u = force pb c u dom in
return cu_u t u range
| Meta (_, _) ->
(* The type of [t] is a meta-variable: infer the type of [u], build a
   codomain with [LibMeta.bmake_codomain] and unify [t_ty] with the
   resulting product. *)
let u, u_ty, cu_u = infer pb c u in
let range =
unbox (LibMeta.bmake_codomain pb (boxed c) (lift u_ty))
in
unif pb c t_ty (mk_Prod (u_ty, range));
return cu_u t u range
| t_ty ->
(* Otherwise, coerce [t] to a product of fresh meta-variables, then
   check [u] against its domain. *)
let domain = LibMeta.bmake pb (boxed c) _Type in
let range = LibMeta.bmake_codomain pb (boxed c) domain in
let domain = unbox domain
and range = unbox range in
let t, cu_t' = coerce pb c t t_ty (mk_Prod (domain, range)) in
if Logger.log_enabled () then
log "Appl-default arg [%a]" term u;
let u, cu_u = force pb c u domain in
return (cu_t' || cu_u) t u range )
(** [infer pb c t] returns a triple [(t', ty, cu)] where [t'] is the
refinement of [t], [ty] is its inferred type and [cu] is true iff [t] has
been modified. It wraps {!val:infer_aux} with logging. *)
and infer : problem -> octxt -> term -> term * term * bool = fun pb c t ->
if Logger.log_enabled () then log "Infer [%a]" term t;
let t, t_ty, cu = infer_aux pb c t in
if Logger.log_enabled () then log "Inferred [%a:@ %a]" term t term t_ty;
(t, t_ty, cu)
(** {b NOTE} When unbinding a binder [b] (e.g. when inferring the type of an
abstraction [λ x, e]) in context [c], [c] is always extended, even if
binder [b] is constant. This is because, during type checking, the context
must contain every variable traversed so that the appropriate
meta-variables can be built. Otherwise, the term [λ a: _, λ b: _, b] would
be transformed into [λ _: ?1, λ b: ?2, b], whereas it should be
[λ a: ?1.[], λ b: ?2.[a], b]. *)
(** [noexn f pb c args] calls [f pb (c, bc) args], where [bc] is the boxed
    form of context [c] computed by [Ctxt.box_context], and returns
    [Some r] where [r] is the result of that call. If [f] (or the boxing of
    [c]) raises {!exception:NotTypable}, [None] is returned instead; any
    other exception is propagated. Nothing else is returned: [f] is
    expected to record any constraints it generates in [pb], as
    {!val:unif} does. *)
let noexn :
  (problem -> octxt -> 'a -> 'b) -> problem -> ctxt -> 'a -> 'b option =
  fun f pb c args ->
  match f pb (c, Ctxt.box_context c) args with
  | r -> Some r
  | exception NotTypable -> None
(** [infer_noexn pb c t] infers the type of [t] in context [c], returning
    [Some (t', ty)] with the refined term and its type, or [None] if
    inference raised [NotTypable]. *)
let infer_noexn pb c t : (term * term) option =
  if Logger.log_enabled () then log "Top infer %a%a" ctxt c term t;
  noexn infer pb c t |> Option.map (fun (t', t_ty, _) -> (t', t_ty))
(** [check_noexn pb c t a] checks [t] against type [a] in context [c],
    returning [Some t'] with the refined term, or [None] if checking raised
    [NotTypable]. *)
let check_noexn pb c t a : term option =
  if Logger.log_enabled () then log "Top check \"%a\"" typing (c, t, a);
  let run pb c (t, a) = force pb c t a in
  Option.map fst (noexn run pb c (t, a))
(** [check_sort_noexn pb c t] checks that [t] has a sort in context [c],
    returning [Some (t', s)] with the refined term and its sort, or [None]
    if checking raised [NotTypable]. *)
let check_sort_noexn pb c t : (term * term) option =
  if Logger.log_enabled () then
    log "Top check sort %a%a" ctxt c term t;
  match noexn type_enforce pb c t with
  | Some (t', s, _) -> Some (t', s)
  | None -> None