package octez-libs

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file utils.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
(*****************************************************************************)
(*                                                                           *)
(* MIT License                                                               *)
(* Copyright (c) 2022 Nomadic Labs <contact@nomadic-labs.com>                *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

module S = Csir.Scalar

(* [alpha] is the gap between the scalar field order and the smallest power of
   two strictly greater than it, i.e. 2^(numbits order) - order. *)
let alpha =
  let order = S.order in
  let next_pow2 = Z.shift_left Z.one (Z.numbits order) in
  Z.sub next_pow2 order

(*
  Plompiler uses lists of bits where the lower positions are less significant,
  e.g. the least significant bit of [bs] is [List.nth bs 0].

  The endianness flag [le] of [bitlist] follows the semantics of
  [get_uint16_le] / [get_uint16_be], which can be tested in utop:

  Bytes.get_uint16_le (Bytes.of_string "\001\000") 0;;
  - : int = 1

  Bytes.get_uint16_be (Bytes.of_string "\001\000") 0;;
  - : int = 256
*)

let bitlist : le:bool -> bytes -> bool list =
 fun ~le b ->
  let nb_bytes = Bytes.length b in
  (* Bytes are emitted front-to-back when [le] and back-to-front otherwise;
     [byte_index k] is the position in [b] of the k-th emitted byte. *)
  let byte_index k = if le then k else nb_bytes - 1 - k in
  (* Walk the emitted sequence from its end, so that consing onto [acc]
     directly yields the final order without a closing [List.rev]. *)
  let rec emit acc k =
    if k < 0 then acc
    else
      let byte = Bytes.get_uint8 b (byte_index k) in
      (* Push bit 7 first so that bit 0 (the least significant one) ends up
         at the front of this byte's eight booleans. *)
      let rec push acc pos =
        if pos < 0 then acc
        else push ((byte land (1 lsl pos) <> 0) :: acc) (pos - 1)
      in
      emit (push acc 7) (k - 1)
  in
  emit [] (nb_bytes - 1)

(* [of_bitlist ~le bl] packs the booleans [bl] (typically from the Plompiler
   Bytes representation) into OCaml bytes. [bl] is read in groups of 8 where
   the first boolean of a group is the least significant bit of the byte.
   The groups are laid out in the same order as in [bl] when [le] is true and
   in reverse order otherwise, making this the inverse of [bitlist ~le].
   The length of [bl] must be a multiple of 8 (checked by an assertion). *)
let of_bitlist : le:bool -> bool list -> bytes =
 fun ~le bl ->
  assert (List.length bl mod 8 = 0) ;
  let bits = Array.of_list bl in
  let nb_bytes = Array.length bits / 8 in
  (* Integer value of the [i]-th group of 8 booleans, bit [j] weighing 2^j. *)
  let group_value i =
    let v = ref 0 in
    for j = 0 to 7 do
      if bits.((8 * i) + j) then v := !v lor (1 lsl j)
    done ;
    !v
  in
  Bytes.init nb_bytes (fun k ->
      let group = if le then k else nb_bytes - 1 - k in
      Char.chr (group_value group))

(* [bytes_of_hex hs] decodes the hexadecimal string [hs] into raw bytes using
   the [Hex] library. *)
let bytes_of_hex hs = Hex.to_bytes (`Hex hs)

(* [hex_of_bytes bs] is the hexadecimal encoding of [bs] as a plain string. *)
let hex_of_bytes bs = Hex.show (Hex.of_bytes bs)

(* [bool_list_to_scalar bits] interprets [bits] as a little-endian binary
   number (the head is the least significant bit) and returns it as an
   element of [S]: the [S]-sum of 2^i over every position i holding [true]. *)
let bool_list_to_scalar : bool list -> S.t =
 fun b_list ->
  let rec accumulate sum power = function
    | [] -> sum
    | bit :: remaining ->
        let sum = if bit then S.(sum + power) else sum in
        accumulate sum (S.double power) remaining
  in
  accumulate S.zero S.one b_list

(* [bool_list_to_z bits] interprets [bits] as an unsigned integer written in
   little-endian order: the head of the list is the least significant bit. *)
let bool_list_to_z : bool list -> Z.t =
 fun b_list ->
  let rec go total weight = function
    | [] -> total
    | bit :: rest ->
        let total = if bit then Z.add total weight else total in
        go total (Z.add weight weight) rest
  in
  go Z.zero Z.one b_list

(* [bool_list_of_z ?nb_bits z] returns the [nb_bits] lowest bits of [z] in
   little-endian order (the head of the list is the least significant bit).
   [nb_bits] defaults to [Z.numbits z]; a larger value pads the result with
   [false]. Bits are extracted with truncated remainder/division, so this is
   meant for non-negative [z] (NOTE(review): a negative [z] yields only
   [false] bits). *)
let bool_list_of_z : ?nb_bits:int -> Z.t -> bool list =
 fun ?nb_bits z ->
  let two = Z.of_int 2 in
  let wanted = match nb_bits with Some n -> n | None -> Z.numbits z in
  let rec extract remaining current acc =
    if remaining = 0 then List.rev acc
    else
      let low_bit = Z.equal (Z.rem current two) Z.one in
      extract (remaining - 1) (Z.div current two) (low_bit :: acc)
  in
  extract wanted z []

(* [split_exactly list size_chunk] cuts [list] into consecutive chunks of
   exactly [size_chunk] elements, preserving order.
   Fails with an assertion if the length of [list] is not a multiple of
   [size_chunk]; raises [Division_by_zero] if [size_chunk = 0]. *)
let split_exactly list size_chunk =
  let len = List.length list in
  let nb_chunks = len / size_chunk in
  assert (len = size_chunk * nb_chunks) ;
  (* Converted once: rebuilding the array for every chunk made this
     quadratic in the length of [list]. *)
  let array = Array.of_list list in
  List.init nb_chunks (fun i ->
      Array.to_list (Array.sub array (i * size_chunk) size_chunk))

(* [bool_list_change_endianness b] reverses the order of the 8-bit groups of
   [b] while keeping the bit order inside each group, i.e. it swaps the byte
   order of a bit list. The length of [b] must be a multiple of 8. *)
let bool_list_change_endianness b =
  assert (List.length b mod 8 = 0) ;
  let bytes_in_order = split_exactly b 8 in
  List.concat (List.rev bytes_in_order)

(* [limbs_of_bool_list ~nb_bits bl] cuts [bl] into chunks of [nb_bits] bits
   and converts each chunk to an [int], reading the chunk little-endian (its
   first bit is the least significant one). The length of [bl] must be a
   multiple of [nb_bits]. *)
let limbs_of_bool_list ~nb_bits bl =
  let chunk_value chunk =
    (* Horner evaluation, starting from the most significant bit. *)
    List.fold_left
      (fun value bit -> (value lsl 1) lor Bool.to_int bit)
      0
      (List.rev chunk)
  in
  List.map chunk_value (split_exactly bl nb_bits)

(* The integer module [Z] extended with [t], a [Repr] description of [Z.t]
   that goes through the byte string produced by [Z.to_bits] and read back by
   [Z.of_bits]. *)
module Z = struct
  include Z

  let t : t Repr.t =
    let decode bs = Z.of_bits (Bytes.unsafe_to_string bs) in
    let encode z = Bytes.of_string (Z.to_bits z) in
    Repr.map Repr.bytes decode encode
end

(* [a %! b] is [Z.rem a b], the remainder of truncated division; its sign
   follows [a], so callers wanting a value in [0, b) must adjust it. *)
let ( %! ) a b = Z.rem a b

(* [next_multiple_of k n] is the first multiple of [k : int] greater than
   or equal to [n : int], for a positive [k]. For example
   [next_multiple_of 4 5 = 8], [next_multiple_of 4 8 = 8],
   [next_multiple_of 4 0 = 0] and [next_multiple_of 4 (-5) = -4].
   @raise Division_by_zero if [k = 0]. *)
let next_multiple_of k n =
  (* [( / )] truncates towards zero, so for [n <= 0] the quotient [n / k] is
     already the ceiling of n/k; only positive [n] needs explicit rounding
     up. Using [1 + ((n - 1) / k)] for every [n] would wrongly return [k]
     for [n = 0]. *)
  let ceil_quotient = if n > 0 then 1 + ((n - 1) / k) else n / k in
  k * ceil_quotient

(* [is_power_of_2 n] returns [true] iff [n : Z.t] is a perfect power of 2,
   i.e. iff the floor and the ceiling of log2(n) coincide. [Z.log2] and
   [Z.log2up] expect a strictly positive argument. *)
let is_power_of_2 n =
  let ceil_log = Z.log2up n in
  let floor_log = Z.log2 n in
  Int.equal floor_log ceil_log

(* [min_nb_limbs ~modulus ~base] is the smallest integer k such that
   base^k >= modulus. Both [modulus] and [base] must be greater than 1
   (checked by assertions). *)
let min_nb_limbs ~modulus ~base =
  assert (Z.(modulus > one)) ;
  assert (Z.(base > one)) ;
  (* We want ceil(log_base(modulus)), but Z.t only offers log2 (not
     log_base), so we multiply until we reach [modulus]. The invariant is
     [power = base^k]. Comparisons use Zarith's operators rather than the
     polymorphic ones. *)
  let rec aux power k =
    if Z.(power >= modulus) then k else aux Z.(power * base) (k + 1)
  in
  aux base 1

(* [z_to_limbs ~len ~base n] takes an integer (n : Z.t) and returns a Z.t list
   of [len] elements encoding its big-endian representation in base [base]
   (most significant limb first, left-padded with zeros). [base] is expected
   to be at least 2.
   @raise Failure if [n < 0] or [n >= base^len]. *)
let z_to_limbs ~len ~base n =
  if Z.(n < zero) then
    failwith "z_to_limbs: n must be greater than or equal to zero" ;
  (* Significant base-[base] digits of [n], most significant first. Zero has
     no significant digit, so it yields [] and is padded below; this makes
     [z_to_limbs ~len:0 ~base Z.zero] return [] (0 < base^0) instead of
     failing. *)
  let rec digits acc n =
    if Z.(n = zero) then acc
    else
      let q, r = Z.div_rem n base in
      digits (r :: acc) q
  in
  let limbs = digits [] n in
  let nb_limbs = List.length limbs in
  if nb_limbs > len then
    failwith "z_to_limbs: n must be strictly lower than base^len"
  else List.init (len - nb_limbs) (Fun.const Z.zero) @ limbs

(* [z_of_limbs ~base limbs] reads [limbs] as the digits of a number in base
   [base], most significant digit first (big-endian), and returns that number.
   Limbs are not range-checked: values outside [0, base), e.g. negated limbs,
   are simply weighted by their position. *)
let z_of_limbs ~base limbs =
  let step acc limb = Z.add (Z.mul base acc) limb in
  List.fold_left step Z.zero limbs

(* [mod_add_limbs ~modulus ~base xs ys] returns the result of adding [xs]
   and [ys] modulo [modulus], where the inputs and the output are in big-endian
   form in base [base]. [xs] and [ys] must have the same number of limbs
   (checked by an assertion), which is also the length of the output. *)
let mod_add_limbs ~modulus ~base xs ys =
  let nb_limbs = List.length xs in
  assert (List.compare_length_with ys nb_limbs = 0) ;
  let x = z_of_limbs ~base xs in
  let y = z_of_limbs ~base ys in
  (* [%!] is a truncated remainder whose sign follows the dividend, so a
     negative sum (e.g. from [mod_sub_limbs], which negates [ys]) has to be
     shifted back into [0, modulus). *)
  let z = Z.((x + y) %! modulus) in
  let z = if Z.(z < zero) then Z.(z + modulus) else z in
  z_to_limbs ~len:nb_limbs ~base z

(* [mod_sub_limbs ~modulus ~base xs ys] returns [xs - ys] modulo [modulus],
   with inputs and output in big-endian base-[base] limbs. It is computed as
   the modular addition of [xs] and the limb-wise negation of [ys]. *)
let mod_sub_limbs ~modulus ~base xs ys =
  let neg_ys = List.map Z.neg ys in
  mod_add_limbs ~modulus ~base xs neg_ys

(* [mod_mul_limbs ~modulus ~base xs ys] returns the result of multiplying [xs]
   by [ys] modulo [modulus], where the inputs and the output are in big-endian
   form in base [base]. [xs] and [ys] must have the same number of limbs
   (checked by an assertion), which is also the length of the output. *)
let mod_mul_limbs ~modulus ~base xs ys =
  let nb_limbs = List.length xs in
  assert (List.compare_length_with ys nb_limbs = 0) ;
  let x = z_of_limbs ~base xs in
  let y = z_of_limbs ~base ys in
  (* [%!] keeps the sign of the dividend; normalize into [0, modulus). *)
  let z = Z.((x * y) %! modulus) in
  let z = if Z.(z < zero) then Z.(z + modulus) else z in
  z_to_limbs ~len:nb_limbs ~base z

(* [mod_div_limbs ~modulus ~base xs ys] returns the result of dividing [xs]
   by [ys] modulo [modulus], where the inputs and the output are in big-endian
   form in base [base]. Dividing [x] by [y] raises a failure if the division
   cannot be performed, i.e. if there does not exist [z] s.t.
   [x = z * y (mod modulus)].
   Note that division may be performed even if the divisor is not invertible
   with respect to the given modulus, e.g., dividing [10] by [2] will always
   be possible (and result in [5]), even if the modulus is even.
   @raise Failure if the division cannot be performed. *)
let mod_div_limbs ~modulus ~base xs ys =
  let nb_limbs = List.length xs in
  assert (List.compare_length_with ys nb_limbs = 0) ;
  let x = z_of_limbs ~base xs in
  let y = z_of_limbs ~base ys in
  (* [d = gcd(y, modulus) = y * y_inv + modulus * _v], hence
     [y * y_inv = d (mod modulus)]. A solution exists iff [d] divides [x];
     then [z = (x / d) * y_inv] satisfies [z * y = x (mod modulus)]. *)
  let d, y_inv, _v = Z.gcdext y modulus in
  if Z.(rem x d <> zero) then
    raise
    @@ Failure
         (Format.sprintf
            "mod_div_limbs: %s is not divisible by %s (modulo %s)"
            (Z.to_string x)
            (Z.to_string y)
            (Z.to_string modulus)) ;
  let z = Z.(divexact x d * y_inv %! modulus) in
  (* [%!] keeps the sign of the dividend; normalize into [0, modulus). *)
  let z = if Z.(z < zero) then Z.(z + modulus) else z in
  z_to_limbs ~len:nb_limbs ~base z

(* [transpose rows] turns a list of rows into the list of their columns. It
   stops as soon as the first row is exhausted; a row shorter than the first
   one makes it fail with [Failure]. *)
let rec transpose rows =
  match rows with
  | [] -> []
  | [] :: _ -> []
  | _ ->
      let tails = List.map List.tl rows in
      let column = List.map List.hd rows in
      column :: transpose tails

(* [of_bytes repr bs] decodes [bs] with the binary codec of [repr]; a decoding
   error surfaces as the exception raised by [Stdlib.Result.get_ok]. *)
let of_bytes repr bs =
  let decode = Repr.unstage (Repr.of_bin_string repr) in
  Bytes.unsafe_to_string bs |> decode |> Stdlib.Result.get_ok

(* [to_bytes repr e] encodes [e] with the binary codec of [repr]. *)
let to_bytes repr e =
  let encode = Repr.unstage (Repr.to_bin_string repr) in
  encode e |> Bytes.unsafe_of_string

(* [Repr] description of a pair (table names, constraint system), used to
   persist circuits on disk. *)
let tables_cs_encoding_t : (string list * Csir.CS.t) Repr.t =
  Repr.pair (Repr.list Repr.string) Csir.CS.t

(* [save_cs_to_file path tables cs] serializes [(tables, cs)] to JSON with
   [tables_cs_encoding_t] and writes it to [path], truncating any existing
   file. The output channel is closed even if writing fails. *)
let save_cs_to_file path tables cs =
  (* Serialize before opening the file, so that an encoding failure does not
     leave an empty file behind. *)
  let s = Repr.to_json_string tables_cs_encoding_t (tables, cs) in
  let outc = open_out path in
  Fun.protect
    ~finally:(fun () -> close_out_noerr outc)
    (fun () ->
      output_string outc s ;
      (* Explicit close on the normal path so flush errors are reported. *)
      close_out outc)

(* [load_cs_from_file path] reads back a pair (tables, constraint system)
   written by [save_cs_to_file].
   @raise Invalid_argument if [path] does not exist, or (from
   [Stdlib.Result.get_ok]) if its content cannot be decoded. *)
let load_cs_from_file path =
  if not (Sys.file_exists path) then
    raise
    @@ Invalid_argument
         (Printf.sprintf "load_cs_from_file: %s does not exist." path) ;
  (* Binary mode: [in_channel_length] counts raw bytes, which only matches
     what [really_input_string] can read when no newline translation
     happens. *)
  let inc = open_in_bin path in
  let content =
    Fun.protect
      ~finally:(fun () -> close_in_noerr inc)
      (fun () -> really_input_string inc (in_channel_length inc))
  in
  Repr.of_json_string tables_cs_encoding_t content |> Stdlib.Result.get_ok

(* [get_circuit_id cs] is the hex encoding of a Blake2b hash of the binary
   [Repr] serialization of [cs] (the [32] argument is presumably the digest
   length in bytes). *)
let get_circuit_id cs =
  let serialized = to_bytes Csir.CS.t cs in
  let digest = Hacl_star.Hacl.Blake2b_32.hash serialized 32 in
  Hex.show (Hex.of_bytes digest)

(* Directory used to store circuits: "$TMPDIR/plompiler" when the TMPDIR
   environment variable is set, "/tmp/plompiler" otherwise. Computed once,
   at module initialisation. *)
let circuit_dir =
  let tmp_root = Option.value ~default:"/tmp" (Sys.getenv_opt "TMPDIR") in
  tmp_root ^ "/plompiler"

(* [circuit_path s] is the path of the file [s] inside [circuit_dir]. It
   creates [circuit_dir] (permissions 0o755) if it does not exist yet; only
   that last directory component is created. *)
let circuit_path s =
  (if not (Sys.file_exists circuit_dir) then
     try Sys.mkdir circuit_dir 0o755
     with Sys_error _ when Sys.file_exists circuit_dir ->
       (* Another process created the directory between the existence check
          and [Sys.mkdir]: nothing left to do. *)
       ()) ;
  circuit_dir ^ "/" ^ s

(* [dump_label_traces path cs] writes to [path] (truncating it) one line per
   constraint of [cs]: the components of its [label] in reverse order joined
   by "; ", followed by " 1". The channel is closed even if writing fails. *)
let dump_label_traces path (cs : Csir.CS.t) =
  let outc = open_out path in
  Fun.protect
    ~finally:(fun () -> close_out_noerr outc)
    (fun () ->
      List.iter
        Csir.CS.(
          Array.iter (fun c ->
              Printf.fprintf outc "%s 1\n"
              @@ String.concat "; " (List.rev c.label)))
        cs ;
      close_out outc)

(* [dump_label_range_checks_traces path fg] writes to [path] (truncating it)
   one line per element [(label, nb)] of [fg]: the components of [label]
   joined by "; ", a space, then [nb]. The channel is closed even if writing
   fails. *)
let dump_label_range_checks_traces path fg =
  let outc = open_out path in
  Fun.protect
    ~finally:(fun () -> close_out_noerr outc)
    (fun () ->
      List.iter
        (fun (label, nb) ->
          Printf.fprintf outc "%s %d\n" (String.concat "; " label) nb)
        fg ;
      close_out outc)
OCaml

Innovation. Community. Security.