Source file lambda_lifting.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
open! Stdlib
open Code
(* Debug flag [lifting]. When it is enabled, [traverse] reports each lifted
   closure on stderr: a LIFT line giving the wrapper name, the nesting depth,
   the number of free variables turned into parameters, and the inner
   closure depth. *)
let debug = Debug.find "lifting"
(* Closure-nesting depth below [pc].

   The result is 0 when no block reachable from [pc] through
   [Code.fold_children] defines a closure. Otherwise it is one more than the
   deepest closure body found, and closure bodies are measured recursively. *)
let rec compute_depth program pc =
  let visit_block pc acc =
    let body = (Code.Addr.Map.find pc program.blocks).body in
    List.fold_left body ~init:acc ~f:(fun acc -> function
      | Let (_, Closure (_, (inner_pc, _))), _ ->
        max acc (compute_depth program inner_pc + 1)
      | _ -> acc)
  in
  Code.preorder_traverse { fold = Code.fold_children } visit_block pc program.blocks 0
(* Free variables of the code reachable from [pc] whose recorded binding
   depth lies strictly between the lifting baseline and [depth].

   - The search covers the bodies of nested closures as well.
   - Variables whose index falls outside [var_depth] are skipped. Presumably
     these were created after the table was allocated.
   - Every other variable must already have a depth recorded (>= 0); the
     assertion checks this. *)
let collect_free_vars program var_depth depth pc =
  let baseline = Config.Param.lambda_lifting_baseline () in
  let tracked = Array.length var_depth in
  let found = ref Var.Set.empty in
  let note_free x =
    let i = Var.idx x in
    if i < tracked
    then (
      let d = var_depth.(i) in
      assert (d >= 0);
      if baseline < d && d < depth then found := Var.Set.add x !found)
  in
  let rec visit pc =
    Code.preorder_traverse
      { fold = Code.fold_children }
      (fun pc () ->
        let block = Code.Addr.Map.find pc program.blocks in
        Freevars.iter_block_free_vars note_free block;
        List.iter block.body ~f:(function
          | Let (_, Closure (_, (inner_pc, _))), _ -> visit inner_pc
          | _ -> ()))
      pc
      program.blocks
      ()
  in
  visit pc;
  !found
(* Record binding depths in [var_depth] for one block:
   - every variable bound by [block] gets [depth];
   - the parameters of each closure defined in [block] get [depth + 1]. *)
let mark_bound_variables var_depth block depth =
  let set_level level x = var_depth.(Var.idx x) <- level in
  Freevars.iter_block_bound_vars (set_level depth) block;
  List.iter block.body ~f:(function
    | Let (_, Closure (params, _)), _ -> List.iter params ~f:(set_level (depth + 1))
    | _ -> ())
(* [traverse var_depth (program, functions) pc depth limit] walks the blocks
   reachable from [pc] (through [Code.fold_children]) at closure-nesting level
   [depth]. It threads two things:
   - the program being rewritten;
   - the list of lifted function definitions that have not been placed yet.

   Each visited block first has its bound variables recorded in [var_depth].
   The block is then handled according to [depth]:
   - [depth = baseline]: each closure in the body is traversed one level
     deeper, and the definitions lifted out of it are inserted just before
     that closure. The returned list of pending definitions is empty.
   - [depth < limit]: closures are traversed one level deeper, and lifted
     definitions keep accumulating in [functions].
   - otherwise: [rewrite_body] may lift closures out of the body (see below).

   Returns the updated program and the pending lifted definitions. *)
let rec traverse var_depth (program, functions) pc depth limit =
let baseline = Config.Param.lambda_lifting_baseline () in
Code.preorder_traverse
{ fold = Code.fold_children }
(fun pc (program, functions) ->
let block = Code.Addr.Map.find pc program.blocks in
mark_bound_variables var_depth block depth;
if depth = baseline
then (
(* At the baseline, nothing may still be waiting to be placed. *)
assert (List.is_empty functions);
let program, body =
List.fold_right block.body ~init:(program, []) ~f:(fun i (program, rem) ->
match i with
| (Let (_, Closure (_, (pc', _))), _loc) as i ->
let program, functions =
traverse var_depth (program, []) pc' (depth + 1) limit
in
(* Definitions lifted out of this closure go right before it. The list is
   built newest first, so it is reversed into definition order. *)
program, List.rev_append functions (i :: rem)
| i -> program, i :: rem)
in
{ program with blocks = Addr.Map.add pc { block with body } program.blocks }, [])
else if depth < limit
then
List.fold_left block.body ~init:(program, functions) ~f:(fun st i ->
match i with
| Let (_, Closure (_, (pc', _))), _ ->
traverse var_depth st pc' (depth + 1) limit
| _ -> st)
else
let does_not_start_with_closure l =
match l with
| (Let (_, Closure _), _) :: _ -> false
| _ -> true
in
(* [rewrite_body first st l] rewrites the instruction list [l].
   - [first] is false right after a closure.
   - A closure is a lifting candidate only when no neighbouring instruction
     is a closure: [first] holds and the next instruction is not a closure.
   NOTE(review): presumably runs of consecutive closures are mutually
   recursive groups that this pass leaves in place. Confirm. *)
let rec rewrite_body first st l =
match l with
| ((Let (f, (Closure (_, (pc', _)) as cl)), loc) as i) :: rem
when first && does_not_start_with_closure rem ->
(* Candidate: first traverse its body with a fresh limit of
   [depth + threshold], so that deeper closures can be lifted out of it.
   Lift it only if the remaining nesting depth of its body, plus one,
   still reaches [threshold]. *)
let threshold = Config.Param.lambda_lifting_threshold () in
let program, functions =
traverse var_depth st pc' (depth + 1) (depth + threshold)
in
if compute_depth program pc' + 1 >= threshold
then (
(* Free variables of the body that are bound deeper than the baseline
   and no deeper than [depth] become parameters. Inside the body each
   one is renamed to a fresh copy. *)
let free_vars = collect_free_vars program var_depth (depth + 1) pc' in
let s =
Var.Set.fold
(fun x m -> Var.Map.add x (Var.fork x) m)
free_vars
Var.Map.empty
in
let program = Subst.cont (Subst.from_map s) pc' program in
(* [f'] names the closure inside its new block. If [f] itself was
   collected (the body mentions [f]), its renamed copy is reused;
   otherwise a new fork is made. [f] is then removed from the
   parameter list. [f''] names the wrapper. *)
let f' = try Var.Map.find f s with Not_found -> Var.fork f in
let s = Var.Map.bindings (Var.Map.remove f s) in
let f'' = Var.fork f in
if debug ()
then
Format.eprintf
"LIFT %s (depth:%d free_vars:%d inner_depth:%d)@."
(Code.Var.to_string f'')
depth
(Var.Set.cardinal free_vars)
(compute_depth program pc');
(* New block [pc''] binds the closure as [f'] and returns it. The
   wrapper [f''] is a closure over that block whose parameters are the
   renamed free variables. The wrapper is queued in [functions] and is
   placed at the baseline level. *)
let pc'' = program.free_pc in
let bl =
{ params = []
; body = [ Let (f', cl), noloc ]
; branch = Return f', noloc
}
in
let program =
{ program with
free_pc = pc'' + 1
; blocks = Addr.Map.add pc'' bl program.blocks
}
in
let functions =
(Let (f'', Closure (List.map s ~f:snd, (pc'', []))), loc) :: functions
in
let rem', st = rewrite_body false (program, functions) rem in
(* The original binding of [f] becomes a call of the wrapper on the
   original (un-renamed) free variables. *)
( (Let (f, Apply { f = f''; args = List.map ~f:fst s; exact = true }), loc)
:: rem'
, st ))
else
(* Not nested deeply enough: the closure stays where it is. *)
let rem', st = rewrite_body false (program, functions) rem in
i :: rem', st
| ((Let (_, Closure (_, (pc', _))), _) as i) :: rem ->
(* This closure is next to another closure, so it is not a candidate.
   Its body is still traversed with the current [limit]. *)
let st = traverse var_depth st pc' (depth + 1) limit in
let rem', st = rewrite_body false st rem in
i :: rem', st
| i :: rem ->
(* Any other instruction. Every closure head is matched by the arms
   above, so [l] starts with a non-closure here and [first] becomes true
   for the next instruction. *)
let rem', st = rewrite_body (does_not_start_with_closure l) st rem in
i :: rem', st
| [] -> [], st
in
let body, (program, functions) =
rewrite_body true (program, functions) block.body
in
( { program with blocks = Addr.Map.add pc { block with body } program.blocks }
, functions ))
pc
program.blocks
(program, functions)
(* Lambda-lift [program].

   Binding depths are tracked in an array indexed by variable, sized for
   every variable that exists on entry. The traversal starts from the entry
   block at depth 0 with a limit of [baseline + threshold]. Every lifted
   definition must have been placed by the time the traversal returns, which
   the assertion checks. Timing is printed when the [times] debug flag is
   set. *)
let f program =
  let timer = Timer.make () in
  let var_depth = Array.make (Var.count ()) (-1) in
  let threshold = Config.Param.lambda_lifting_threshold () in
  let baseline = Config.Param.lambda_lifting_baseline () in
  let lifted, pending =
    traverse var_depth (program, []) program.start 0 (baseline + threshold)
  in
  assert (List.is_empty pending);
  if Debug.find "times" () then Format.eprintf " lambda lifting: %a@." Timer.print timer;
  lifted