package prbnmcn-dagger

  1. Overview
  2. Docs

Source file resampling.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
(** A population of scored particles, as consumed by the resampling
    strategies of this file. Strategies receive it packed as a first-class
    module (see the [particles] type). *)
module type Particles = sig
  (** Type of particles. *)
  type p

  (** Type of the output a particle may carry (see [get_output]). *)
  type o

  (** Type of particle scores. The resamplers in this file instantiate it with
      a field type [F.t] and accumulate scores with [F.add]. *)
  type r

  (** [iter f] calls [f] on each particle together with its score (the value
      the resamplers add to their running cumulative score). *)
  val iter : (p -> r -> unit) -> unit

  (** [fold f acc] folds [f] over the particles and their scores, starting
      from [acc]. *)
  val fold : ('acc -> p -> r -> 'acc) -> 'acc -> 'acc

  (** Output attached to a particle, if any. *)
  val get_output : p -> o option

  (** Score of a particle. *)
  val get_score : p -> r

  (** [append p w] is called by the resamplers once per selected copy of [p],
      with [w] the weight given to that copy.
      NOTE(review): the resamplers call this from inside [iter], so it
      presumably writes to a buffer distinct from the population being
      iterated; confirm against the implementations. *)
  val append : p -> r -> unit

  (** Total score of the population. The resamplers use it as the upper end of
      the cumulative-score axis (presumably the sum of all scores; verify). *)
  val total : unit -> r

  (** Number of particles in the population. *)
  val size : unit -> int

  (** Effective sample size. The resamplers divide it by [size ()] and compare
      the result with a threshold taken in [0;1]. *)
  val ess : unit -> r
end

(** A particle population packed as a first-class module, exposing its output
    type ['o] and score type ['r] while keeping the particle type abstract. *)
type ('o, 'r) particles = (module Particles with type o = 'o and type r = 'r)

(** A resampling strategy. Its arguments are:
    - the desired population size ([target_size]);
    - the current particles;
    - a strategy-specific state;
    - a random state.

    It returns the (possibly updated) strategy state. The predefined
    strategies in this file emit resampled particles through
    [Particles.append] and return their state argument unchanged. *)
type ('o, 'r, 'resampling_state) strategy =
  target_size:int ->
  ('o, 'r) particles ->
  'resampling_state ->
  RNG.t ->
  'resampling_state

(** Predefined resampling strategies over particles scored in the field [F].
    [Sampler.uniform bound rng] is used to draw the random offsets; it is
    presumably uniform on [0, bound) (exact end-point convention not visible
    here; confirm). *)
module Make_predefined
    (F : Intf.Field) (Sampler : sig
      val uniform : F.t -> RNG.t -> F.t
    end) =
struct
  (** [resampling_generic_iterative f particles rng_state] runs one pass of
      threshold-based resampling.

      [f i rng_state] must return the [i]-th threshold on the cumulative-score
      axis [0, P.total ()]. Indices start at [0], and thresholds are expected
      to be non-decreasing in [i].

      The pass walks the particles in [P.iter] order, accumulating their
      scores. While the pending threshold is strictly below the running
      cumulative score, the current particle is appended (via [P.append]) with
      weight [P.total () / P.size ()], and the next threshold is requested. *)
  let resampling_generic_iterative (type o) f ((module P) : (o, F.t) particles)
      (rng_state : RNG.t) =
    let cumulative = ref F.zero in
    (* Index of the threshold currently held in [next]. Thresholds are
       0-based so that the first one lands in the first stratum
       [0, tot / N). Evaluating [f] with index 1 first, and then again with
       the same index, would skip that stratum and duplicate the second. *)
    let index = ref 0 in
    let next = ref (f !index rng_state) in
    let tot = P.total () in
    (* Every resampled copy gets the same weight.
       NOTE(review): this divides by the current population size, not by the
       number of appended copies; the two differ when
       [target_size <> P.size ()]. Confirm the intended normalisation. *)
    let res_w = F.div tot (F.of_int (P.size ())) in
    P.iter (fun particle w ->
        cumulative := F.add !cumulative w ;
        (* One copy of [particle] per threshold lying below the new running
           cumulative score (and, by monotonicity, above the previous one). *)
        while F.(!next < !cumulative) do
          P.append particle res_w ;
          incr index ;
          next := f !index rng_state
        done)

  (** [in_unit_interval r] holds iff [0 <= r <= 1] in [F]. *)
  let in_unit_interval r = F.(zero <= r && r <= one)

  (** Preconditions shared by the predefined strategies. [name] prefixes the
      error message and [size] is the current population size.
      @raise Invalid_argument if [size < 2], [target_size < 2], or
      [ess_threshold] is not in [0;1]. *)
  let check_arguments name ~ess_threshold ~target_size size =
    if size < 2 then invalid_arg (name ^ " (P.size < 2)") ;
    if target_size < 2 then invalid_arg (name ^ " (target_size < 2)") ;
    if not (in_unit_interval ess_threshold) then
      invalid_arg (name ^ " (ess_threshold not in [0;1])")

  (** Stratified resampling.

      When the normalised effective sample size [P.ess () / P.size ()] is
      strictly below [ess_threshold], [particles] is resampled. The
      cumulative-score axis [0, P.total ()] is cut into [target_size] strata
      of equal width, and one point is drawn independently in each stratum.
      Otherwise nothing is appended. [resampling_state] is returned unchanged.
      @raise Invalid_argument if [P.size () < 2], [target_size < 2], or
      [ess_threshold] is not in [0;1]. *)
  let stratified_resampling (type o) ~ess_threshold ~target_size
      ((module P) as particles : (o, F.t) particles) resampling_state rng_state
      =
    check_arguments "stratified_resampling" ~ess_threshold ~target_size
      (P.size ()) ;
    let ess = F.div (P.ess ()) (F.of_int (P.size ())) in
    (if F.(ess < ess_threshold) then
       let tot = P.total () in
       let inv = F.div tot (F.of_int target_size) in
       resampling_generic_iterative
         (fun i rng_state ->
           (* Start of stratum [i], jittered by a fresh draw for each
              stratum. *)
           let rand = Sampler.uniform inv rng_state in
           F.add F.(div (mul tot (of_int i)) (of_int target_size)) rand)
         particles
         rng_state) ;
    resampling_state

  (** Systematic resampling.

      It is triggered exactly like [stratified_resampling]. The difference is
      that a single random offset is drawn and shared by all [target_size]
      strata, so the selection points are evenly spaced.
      [resampling_state] is returned unchanged.
      @raise Invalid_argument if [P.size () < 2], [target_size < 2], or
      [ess_threshold] is not in [0;1]. *)
  let systematic_resampling (type o) ~ess_threshold ~target_size
      ((module P) as particles : (o, F.t) particles) resampling_state rng_state
      =
    check_arguments "systematic_resampling" ~ess_threshold ~target_size
      (P.size ()) ;
    let ess = F.div (P.ess ()) (F.of_int (P.size ())) in
    (if F.(ess < ess_threshold) then
       let tot = P.total () in
       let inv = F.div tot (F.of_int target_size) in
       (* Drawn once: every threshold reuses the same offset. *)
       let rand = Sampler.uniform inv rng_state in
       resampling_generic_iterative
         (fun i _rng_state ->
           F.add F.(div (mul tot (of_int i)) (of_int target_size)) rand)
         particles
         rng_state) ;
    resampling_state
end
[@@inline]
OCaml

Innovation. Community. Security.