package moonpool

  1. Overview
  2. Docs

Source file moonpool_forkjoin.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
module A = Moonpool.Atomic
module Domain_ = Moonpool_private.Domain_

(* Shared join state used by [both].

   The left computation (run as a background task by [both]) publishes its
   result with [set_left_]; the right computation (run on the calling task)
   publishes with [set_right_], which suspends on a trigger when it finishes
   first. Once both sides have published, [get_exn_] extracts the pair. *)
module State_ = struct
  (* An exception together with the backtrace captured where it was caught. *)
  type error = exn * Printexc.raw_backtrace
  type 'a or_error = ('a, error) result

  (* Join progress. Possible transitions:
       Init -> Left_solved  -> Both_solved
       Init -> Right_solved -> Both_solved
     [Right_solved] carries the trigger the right side waits on, so that
     [set_left_] can wake it up once the left result is in. *)
  type ('a, 'b) t =
    | Init
    | Left_solved of 'a or_error
    | Right_solved of 'b or_error * Trigger.t
    | Both_solved of 'a or_error * 'b or_error

  (* Return both results, or re-raise an error with its original backtrace
     if either side failed. When both sides failed, the left error is the
     one raised (the or-pattern tries the left alternative first).
     Must only be called once the state is [Both_solved]. *)
  let get_exn_ (self : _ t A.t) =
    match A.get self with
    | Both_solved (Ok a, Ok b) -> a, b
    | Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) ->
      Printexc.raise_with_backtrace exn bt
    | _ -> assert false

  (* Publish the left result.
     - From [Init]: move to [Left_solved]; the right side will pick it up.
     - From [Right_solved]: the right side is (or is about to be) suspended
       on its trigger, so move to [Both_solved] and signal that trigger.
     A failed compare-and-set means the state changed concurrently; back off
     with [Domain_.relax] and retry against the new state.
     Publishing the left result twice is a bug (assertion failure). *)
  let rec set_left_ (self : _ t A.t) (left : _ or_error) =
    let old_st = A.get self in
    match old_st with
    | Init ->
      let new_st = Left_solved left in
      if not (A.compare_and_set self old_st new_st) then (
        Domain_.relax ();
        set_left_ self left
      )
    | Right_solved (right, tr) ->
      let new_st = Both_solved (left, right) in
      if not (A.compare_and_set self old_st new_st) then (
        Domain_.relax ();
        set_left_ self left
      ) else
        (* only wake the right side after [Both_solved] is visible *)
        Trigger.signal tr
    | Left_solved _ | Both_solved _ -> assert false

  (* Publish the right result, then wait until the left result is also
     available.
     - From [Left_solved]: move straight to [Both_solved] and return.
     - From [Init]: install a fresh trigger in [Right_solved] and suspend on
       it until [set_left_] signals it. If the left result shows up before
       the trigger is installed, complete the join directly instead.
     An error returned by [Trigger.await] is raised.
     Publishing the right result twice is a bug (assertion failure). *)
  let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
    let old_st = A.get self in
    match old_st with
    | Left_solved left ->
      let new_st = Both_solved (left, right) in
      (* retry on contention (note: no [Domain_.relax] here) *)
      if not (A.compare_and_set self old_st new_st) then set_right_ self right
    | Init ->
      (* we are first arrived, we suspend until the left computation is done *)
      let trigger = Trigger.create () in
      (* set to [false] if the left side finished before the trigger was
         installed, in which case there is nothing to wait for *)
      let must_await = ref true in

      (* loop until either the trigger is installed or the left result is
         observed; the loop condition returns [true] to retry *)
      while
        let old_st = A.get self in
        match old_st with
        | Init ->
          (* setup trigger so that left computation will wake us up *)
          not (A.compare_and_set self old_st (Right_solved (right, trigger)))
        | Left_solved left ->
          (* the left side publishes at most once ([set_left_] asserts on a
             second call), so no other writer can race with this plain set *)
          A.set self (Both_solved (left, right));
          must_await := false;
          false
        | Right_solved _ | Both_solved _ -> assert false
      do
        ()
      done;

      (* wait for the other computation to be done *)
      if !must_await then Trigger.await trigger |> Option.iter Exn_bt.raise
    | Right_solved _ | Both_solved _ -> assert false
end

(** [both f g] computes [f ()] and [g ()], potentially in parallel, and
    returns both results.

    [f ()] is submitted as a background task on the current runner, while
    [g ()] runs on the calling task; the call returns only once both have
    finished.

    @raise Invalid_argument if not called from within a runner.
    If either computation raises, its exception is re-raised with its
    original backtrace after both have finished; if both raise, the
    exception from [f] is the one re-raised. *)
let both f g : _ * _ =
  let module ST = State_ in
  let st = A.make ST.Init in

  let runner =
    match Runner.get_current_runner () with
    | None -> invalid_arg "Fork_join.both must be run from within a runner"
    | Some r -> r
  in

  (* Run a thunk and capture its outcome, with the backtrace on failure.
     Only the thunk itself is covered by the handler: an exception raised
     while publishing the result must not be recorded as the thunk's own
     failure (which would also publish a second time and trip an assert). *)
  let capture h =
    match h () with
    | res -> Ok res
    | exception exn ->
      let bt = Printexc.get_raw_backtrace () in
      Error (exn, bt)
  in

  (* start computing [f] in the background *)
  Runner.run_async runner (fun () -> ST.set_left_ st (capture f));

  (* compute [g] on the current task, then wait for [f] to be done *)
  ST.set_right_ st (capture g);
  ST.get_exn_ st

(** [both_ignore f g] behaves like {!both} but discards both results.
    Exceptions propagate exactly as they do from {!both}. *)
let both_ignore f g =
  let (_ : _ * _) = both f g in
  ()

(** [for_ ?chunk_size n f] splits the indices [0 .. n-1] into consecutive
    chunks and calls [f low high] (both bounds inclusive) once per chunk,
    each call running as its own task on the current runner. It returns once
    every chunk has completed. Does nothing when [n <= 0].

    @param chunk_size maximum number of indices per chunk, clamped to
      [1 .. n]. By default about [n / num_domains] indices per chunk, i.e.
      roughly one task per domain.
    @raise Failure if [n > 0] and this is not called from within a runner.
    If some call to [f] raises, the first recorded exception is re-raised
    with its backtrace as soon as it is reported; other chunks may still be
    running at that point. *)
let for_ ?chunk_size n (f : int -> int -> unit) : unit =
  if n > 0 then (
    let runner =
      match Runner.get_current_runner () with
      | None -> failwith "forkjoin.for_: must be run inside a moonpool runner."
      | Some r -> r
    in
    (* first exception raised by a chunk, if any *)
    let failure = A.make None in
    (* number of indices whose chunk has not yet completed successfully *)
    let missing = A.make n in

    let chunk_size =
      match chunk_size with
      | Some cs -> max 1 (min n cs)
      | None ->
        (* guess: try to have roughly one task per core *)
        max 1 (1 + (n / Moonpool.Private.num_domains ()))
    in

    (* signalled by the last successful chunk, or by the first failure *)
    let trigger = Trigger.create () in

    let task_for ~offset ~len_range =
      match f offset (offset + len_range - 1) with
      | () ->
        if A.fetch_and_add missing (-len_range) = len_range then
          (* this chunk brought [missing] to 0: all chunks done successfully *)
          Trigger.signal trigger
      | exception exn ->
        let bt = Printexc.get_raw_backtrace () in
        (* Keep only the first failure, so that the exception re-raised to
           the caller is the one that woke it up. A failing chunk never
           decrements [missing], so [missing] can no longer reach 0 and the
           success branch above cannot signal as well. *)
        if A.compare_and_set failure None (Some (Exn_bt.make exn bt)) then
          Trigger.signal trigger
    in

    let i = ref 0 in
    while !i < n do
      let offset = !i in

      let len_range = min chunk_size (n - offset) in
      assert (offset + len_range <= n);

      Runner.run_async runner (fun () -> task_for ~offset ~len_range);
      i := !i + len_range
    done;

    (* an error returned by the await itself is raised as well *)
    Trigger.await trigger |> Option.iter Exn_bt.raise;
    Option.iter Exn_bt.raise @@ A.get failure;
    ()
  )

(** [all_array ?chunk_size fs] calls every thunk of [fs] exactly once, in
    parallel chunks scheduled by {!for_}, and returns their results in the
    same order as [fs].

    @param chunk_size forwarded to {!for_}.
    Exceptions raised by the thunks, or by {!for_} itself, propagate. *)
let all_array ?chunk_size (fs : _ array) : _ array =
  let count = Array.length fs in
  let slots = Array.make count None in
  let run_range lo hi =
    for k = lo to hi do
      slots.(k) <- Some (fs.(k) ())
    done
  in
  for_ ?chunk_size count run_range;
  (* [for_] returned normally, so every slot is expected to be filled *)
  Array.init count (fun k ->
      match slots.(k) with
      | Some v -> v
      | None -> assert false)

(** [all_list ?chunk_size fs] is the list counterpart of {!all_array}: it
    calls every thunk of [fs] in parallel chunks and returns their results
    in order. *)
let all_list ?chunk_size fs : _ list =
  fs |> Array.of_list |> all_array ?chunk_size |> Array.to_list

(** [all_init ?chunk_size n f] returns [[f 0; f 1; ...; f (n-1)]], with the
    calls to [f] spread over parallel chunks by {!for_}.

    @param chunk_size forwarded to {!for_}.
    @raise Invalid_argument if [n < 0] (from [Array.make]).
    Exceptions raised by [f], or by {!for_} itself, propagate. *)
let all_init ?chunk_size n f : _ list =
  let slots = Array.make n None in
  let fill lo hi =
    for k = lo to hi do
      slots.(k) <- Some (f k)
    done
  in
  for_ ?chunk_size n fill;
  (* [for_] returned normally, so every slot is expected to be filled *)
  Array.to_list
    (Array.map
       (function
         | Some v -> v
         | None -> assert false)
       slots)

(** [map_array ?chunk_size f arr] applies [f] to every element of [arr] in
    parallel chunks scheduled by {!for_} and returns the results in the same
    order as [arr].

    @param chunk_size forwarded to {!for_}.
    Exceptions raised by [f], or by {!for_} itself, propagate. *)
let map_array ?chunk_size f arr : _ array =
  let len = Array.length arr in
  let out = Array.make len None in
  for_ ?chunk_size len (fun lo hi ->
      for k = lo to hi do
        let y = f arr.(k) in
        out.(k) <- Some y
      done);
  (* [for_] returned normally, so every slot is expected to be filled *)
  Array.init len (fun k ->
      match out.(k) with
      | Some y -> y
      | None -> assert false)

(** [map_list ?chunk_size f l] is the list counterpart of {!map_array}: it
    applies [f] to every element of [l] in parallel chunks and returns the
    results in order.

    @param chunk_size forwarded to {!for_} (through {!map_array}). *)
let map_list ?chunk_size f (l : _ list) : _ list =
  let mapped = map_array ?chunk_size f (Array.of_list l) in
  Array.to_list mapped
OCaml

Innovation. Community. Security.