package octez-libs

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file poseidon_core.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
(** Constants describing one instance of the permutation. They are consumed by
    [Make], which parses the string constants with [Scalar.of_string]. *)
module type PARAMETERS = sig
  (** Number of scalars in the state. [Make] asserts that [mds_matrix] has
      [width] rows of [width] entries each. *)
  val width : int

  (** Number of full rounds. The permutation of [Make] runs
      [full_rounds / 2] (integer division) full rounds before the partial
      rounds and [full_rounds / 2] after them. *)
  val full_rounds : int

  (** Number of partial rounds, run between the two groups of full rounds. *)
  val partial_rounds : int

  (** Round constants, as strings accepted by [Scalar.of_string]. The
      permutation of [Make] reads them in order starting from index 0, one
      constant per state element in every round. *)
  val round_constants : string array

  (** Matrix applied to the state at the end of every round, as strings
      accepted by [Scalar.of_string]. Indexed as [mds_matrix.(row).(column)]. *)
  val mds_matrix : string array array

  (** Index of the single state element the s-box is applied to during a
      partial round. *)
  val partial_round_idx_to_permute : int
end

(** Interface of a permutation working on a mutable state. *)
module type STRATEGY = sig
  (** Type of the elements stored in the state. *)
  type scalar

  (** Type of the permutation state. *)
  type state

  (** [init ?input_length a] creates a state from the scalars of [a],
      remembering the optional [input_length]. The implementation in [Make]
      copies [a], so later changes to [a] do not affect the state. *)
  val init : ?input_length:int -> scalar array -> state

  (** [apply_perm s] applies the permutation to [s]. The result is stored in
      [s] itself. *)
  val apply_perm : state -> unit

  (** [get s] returns the scalars currently held by [s]. The implementation in
      [Make] returns a fresh copy. *)
  val get : state -> scalar array

  (** [input_length s] returns the [input_length] given to [init], if any. *)
  val input_length : state -> int option
end

(** Interface of a hash function over arrays of scalars. *)
module type HASH = sig
  (** Type of the hashed elements and of the resulting digest. *)
  type scalar

  (** Type of the hashing context. *)
  type ctxt

  (** [init ?input_length ()] creates a fresh context. In the implementation
      of [Make], giving [input_length] fixes the length of the data accepted
      by [digest] and disables padding. *)
  val init : ?input_length:int -> unit -> ctxt

  (** [digest ctxt data] absorbs [data] and returns the resulting context.
      The implementation in [Make] updates [ctxt] in place and returns it,
      and raises [Invalid_argument] when the context was created with an
      [input_length] different from the length of [data]. *)
  val digest : ctxt -> scalar array -> ctxt

  (** [get ctxt] returns the digest held by [ctxt]. *)
  val get : ctxt -> scalar
end

(** [Make (C) (Scalar)] builds, for the parameters [C] over the field
    [Scalar], an in-place permutation ([Strategy]) and a hash function
    absorbing up to [C.width - 1] scalars per permutation call ([Hash]).

    Applying the functor parses [C.round_constants] and [C.mds_matrix] with
    [Scalar.of_string] and asserts that [C.mds_matrix] has [C.width] rows of
    [C.width] entries each. *)
module Make (C : PARAMETERS) (Scalar : Bls12_381.Ff_sig.PRIME) = struct
  open C

  (* Verify the constants are consistent *)
  let () =
    assert (Array.length mds_matrix = width) ;
    assert (Array.for_all (fun line -> Array.length line = width) mds_matrix)

  (* From here on, [mds_matrix] and [round_constants] shadow the string
     constants of [C] with their parsed field elements. *)
  let mds_matrix = Array.map (Array.map Scalar.of_string) mds_matrix

  let round_constants = Array.map Scalar.of_string round_constants

  (* Scratch buffer for the MDS matrix multiplication. It is allocated only
     once and shared by every state created from this functor instance:
     [Strategy.apply_eval_matrix] expects it to contain only zeros on entry
     and resets it to zeros before returning. *)
  let res = Array.make width Scalar.zero

  module Strategy = struct
    type scalar = Scalar.t

    (* [i_round_key] is the index in [round_constants] of the next constant
       to consume, [state] holds the scalars being permuted, and
       [input_length] is the optional value given to [init]. *)
    type state = {
      mutable i_round_key : int;
      state : Scalar.t array;
      input_length : int option;
    }

    (** [init ?input_length state] creates a permutation state holding a copy
        of [state]; the caller's array is never modified. *)
    let init ?input_length state =
      {i_round_key = 0; state = Array.copy state; input_length}

    (* Returns the next round constant and moves the cursor forward. *)
    let get_next_round_key s =
      let v = round_constants.(s.i_round_key) in
      s.i_round_key <- s.i_round_key + 1 ;
      v

    (* [s_box x] is [x^5], computed as [(x^2)^2 * x]. *)
    let s_box x = Scalar.(square (square x) * x)

    (* Functions prefixed with apply_ are modifying the state given in
       parameters
    *)

    (* Adds one fresh round constant to every element of the state. *)
    let apply_round_key s =
      let state = s.state in
      for i = 0 to Array.length state - 1 do
        state.(i) <- Scalar.(get_next_round_key s + state.(i))
      done

    (* Applies the s-box to the element at [partial_round_idx_to_permute]
       only. *)
    let apply_s_box_last_elem s =
      let state = s.state in
      state.(partial_round_idx_to_permute) <-
        s_box state.(partial_round_idx_to_permute)

    (* Applies the s-box to every element of the state. *)
    let apply_s_box s =
      let state = s.state in
      for i = 0 to Array.length state - 1 do
        state.(i) <- s_box state.(i)
      done

    (* Replaces the state of [v] by the matrix-vector product [m * v]. The
       product is accumulated in the shared buffer [res], then copied back
       into the state while [res] is reset to zeros for the next call. *)
    let apply_eval_matrix m v =
      let v = v.state in
      for j = 0 to width - 1 do
        for k = 0 to width - 1 do
          res.(k) <- Scalar.(res.(k) + (m.(k).(j) * v.(j)))
        done
      done ;
      for j = 0 to width - 1 do
        v.(j) <- res.(j) ;
        res.(j) <- Scalar.zero
      done

    (* Partial round: round constants, s-box on a single element, MDS
       matrix. *)
    let apply_partial_round s =
      apply_round_key s ;
      apply_s_box_last_elem s ;
      apply_eval_matrix mds_matrix s

    (* Full round: round constants, s-box on every element, MDS matrix. *)
    let apply_full_round s =
      apply_round_key s ;
      apply_s_box s ;
      apply_eval_matrix mds_matrix s

    (** [apply_perm s] permutes [s] in place. The round constant cursor is
        reset first, so every call reads the constants from index 0. The
        permutation runs [full_rounds / 2] full rounds, then
        [partial_rounds] partial rounds, then [full_rounds / 2] full rounds
        again. *)
    let apply_perm s =
      s.i_round_key <- 0 ;
      for _i = 0 to (full_rounds / 2) - 1 do
        apply_full_round s
      done ;
      for _i = 0 to partial_rounds - 1 do
        apply_partial_round s
      done ;
      for _i = 0 to (full_rounds / 2) - 1 do
        apply_full_round s
      done

    (** [get s] returns a copy of the scalars held by [s]. *)
    let get s = Array.copy s.state

    (** [add_cst s idx v] adds [v] in place to the element of [s] at index
        [idx]. [idx] must satisfy [0 <= idx < width]; this is checked by an
        assertion. *)
    let add_cst s idx v =
      (* Valid indices stop at [width - 1]: [idx = width] must be rejected
         here rather than by the array access below. *)
      assert (0 <= idx && idx < width) ;
      s.state.(idx) <- Scalar.(s.state.(idx) + v)

    (** [input_length s] returns the [input_length] given to [init], if
        any. *)
    let input_length s = s.input_length
  end

  module Hash = struct
    type scalar = Scalar.t

    type ctxt = Strategy.state

    (** [init ?input_length ()] creates a context whose [width] scalars are
        all zero. When [input_length] is given, [digest] only accepts data of
        exactly that length and does not pad it. *)
    let init ?input_length () =
      Strategy.init ?input_length (Array.make width Scalar.zero)

    (** [digest state data] absorbs [data] into [state], [width - 1] scalars
        at a time, applying the permutation after each chunk. Index 0 of the
        state is never written directly by this function. [state] is updated
        in place and returned. At least one permutation is always applied,
        even when [data] is empty.

        @raise Invalid_argument
          if [state] was created with an [input_length] different from the
          length of [data]. *)
    let digest state data =
      let l = Array.length data in
      let input_length_opt = Strategy.input_length state in
      (* The error message is only built when the check actually fails. *)
      Option.iter
        (fun expected ->
          if l <> expected then
            raise
              (Invalid_argument
                 (Format.sprintf
                    "digest expects data of length %d, %d given"
                    expected
                    l)))
        input_length_opt ;
      let with_padding = Option.is_none input_length_opt in
      let chunk_size = width - 1 in
      (* With padding, the last block always has room for the padding
         scalar, so it holds at most [chunk_size - 1] data scalars. Without
         padding, the last block may be completely filled with data, hence
         the [l - 1]. *)
      let nb_full_chunk =
        if with_padding then l / chunk_size else (l - 1) / chunk_size
      in
      (* we process first all the full chunks *)
      for i = 0 to nb_full_chunk - 1 do
        let ofs = i * chunk_size in
        for j = 0 to chunk_size - 1 do
          Strategy.add_cst state (1 + j) data.(ofs + j)
        done ;
        Strategy.apply_perm state
      done ;
      (* we add the last partial chunk, and pad with one *)
      let ofs = nb_full_chunk * chunk_size in
      (* Number of scalars not consumed by the loop above. This equals
         [l mod chunk_size] when padding is enabled. *)
      let r = l - ofs in
      for j = 0 to r - 1 do
        Strategy.add_cst state (1 + j) data.(ofs + j)
      done ;
      if with_padding then Strategy.add_cst state (r + 1) Scalar.one ;
      Strategy.apply_perm state ;
      state

    (** [get ctxt] returns the digest, i.e. the scalar at index 1 of the
        state. *)
    let get (ctxt : ctxt) = ctxt.state.(1)
  end
end
OCaml

Innovation. Community. Security.