package octez-libs

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file linear_algebra.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
(*****************************************************************************)
(*                                                                           *)
(* MIT License                                                               *)
(* Copyright (c) 2022 Nomadic Labs <contact@nomadic-labs.com>                *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

(** Signature of a ring whose elements have type [t]. The ring axioms are not
    enforced by the type system; implementations are trusted to satisfy them. *)
module type Ring_sig = sig
  (** Elements of the ring. *)
  type t

  (** Additive identity. *)
  val zero : t

  (** Multiplicative identity. *)
  val one : t

  (** [add a b] is the sum of [a] and [b]. *)
  val add : t -> t -> t

  (** [negate a] is the additive inverse of [a]. *)
  val negate : t -> t

  (** [mul a b] is the product of [a] and [b]. *)
  val mul : t -> t -> t

  (** [eq a b] tests whether [a] and [b] denote the same element. *)
  val eq : t -> t -> bool
end

(** Signature of a field: a ring in which non-zero elements can be inverted. *)
module type Field_sig = sig
  include Ring_sig

  (** [inverse_exn a] is the multiplicative inverse of [a]. What happens on
      [zero] is left to the implementation (the [_exn] suffix suggests it
      raises — confirm against the concrete field used). *)
  val inverse_exn : t -> t
end

(** Matrices with entries in a ring.

    In mathematics, a "module" generalizes a vector space by letting the
    scalars come from a ring instead of a field; this signature gathers the
    matrix operations that only need ring arithmetic. *)
module type Module_sig = sig
  (** Type of the matrix entries (the scalars). *)
  type t

  (** Matrices are represented as arrays of rows. *)
  type matrix = t array array

  (** {2 Construction} *)

  (** [zeros r c] is the matrix with [r] rows and [c] columns whose entries
      are all zero. *)
  val zeros : int -> int -> matrix

  (** [identity n] is the [n] x [n] identity matrix. *)
  val identity : int -> matrix

  (** {2 Comparison and arithmetic} *)

  (** [equal m1 m2] tests whether [m1] and [m2] have the same entries. *)
  val equal : matrix -> matrix -> bool

  (** [add m1 m2] is the entry-wise sum of [m1] and [m2]. *)
  val add : matrix -> matrix -> matrix

  (** [mul m1 m2] is the matrix product of [m1] by [m2]. *)
  val mul : matrix -> matrix -> matrix

  (** [transpose m] exchanges the rows and columns of [m]. *)
  val transpose : matrix -> matrix

  (** {2 In-place elementary row operations} *)

  (** [row_add ?coeff i j m] adds [coeff] times row [j] to row [i] of [m]. *)
  val row_add : ?coeff:t -> int -> int -> matrix -> unit

  (** [row_swap i j m] exchanges rows [i] and [j] of [m]. *)
  val row_swap : int -> int -> matrix -> unit

  (** [row_mul coeff i m] multiplies row [i] of [m] by [coeff]. *)
  val row_mul : t -> int -> matrix -> unit

  (** {2 Column selection} *)

  (** [filter_cols f m] keeps only the columns of [m] whose index satisfies
      [f]. *)
  val filter_cols : (int -> bool) -> matrix -> matrix

  (** [split_n n m] is [(left, right)] where [left] holds the first [n]
      columns of [m] and [right] the remaining ones. *)
  val split_n : int -> matrix -> matrix * matrix
end

(** Matrices with entries in a field: {!Module_sig} extended with the
    operations that need multiplicative inverses of the scalars. *)
module type VectorSpace_sig = sig
  include Module_sig

  (** [reduced_row_echelon_form m] is the reduced row echelon form of [m]:
      every non-zero row starts with a leading [one] that is the only non-zero
      entry of its column, leading entries move strictly right from one row to
      the next, and all-zero rows come last. *)
  val reduced_row_echelon_form : matrix -> matrix

  (** [inverse m] is the inverse of the square matrix [m].

      @raise Invalid_argument if [m] is not invertible *)
  val inverse : matrix -> matrix
end

(** [Make_Module (Ring)] implements [Module_sig] for matrices whose entries
    are elements of [Ring]; a matrix is an array of rows. *)
module Make_Module (Ring : Ring_sig) : Module_sig with type t = Ring.t = struct
  type t = Ring.t

  type matrix = t array array

  let zeros r c = Array.make_matrix r c Ring.zero

  let identity n =
    Array.(init n (fun i -> init n Ring.(fun j -> if i = j then one else zero)))

  (* Dimensions are compared before the entries: [Array.for_all2] raises
     [Invalid_argument] on arrays of different lengths, whereas matrices of
     different shapes should simply compare as unequal. *)
  let equal m1 m2 =
    let equal_rows r1 r2 =
      Array.length r1 = Array.length r2 && Array.for_all2 Ring.eq r1 r2
    in
    Array.length m1 = Array.length m2 && Array.for_all2 equal_rows m1 m2

  (* Entry-wise sum; [Array.map2] raises [Invalid_argument] when the shapes
     differ. *)
  let add = Array.(map2 (map2 Ring.add))

  (* Schoolbook product. Both operands must be non-empty (their first rows
     are read to get the dimensions), and the number of columns of [m1] must
     equal the number of rows of [m2], otherwise [Assert_failure] is raised. *)
  let mul m1 m2 =
    let nb_rows = Array.length m1 in
    let nb_cols = Array.length m2.(0) in
    let n = Array.length m1.(0) in
    assert (Array.length m2 = n) ;
    Array.init nb_rows (fun i ->
        Array.init nb_cols (fun j ->
            (* Accumulate the dot product locally, summing in the same order
               as a zero-initialised output cell would. *)
            let acc = ref Ring.zero in
            for k = 0 to n - 1 do
              acc := Ring.add !acc (Ring.mul m1.(i).(k) m2.(k).(j))
            done ;
            !acc))

  (* A matrix without rows transposes to a matrix without rows, instead of
     failing on the [m.(0)] access used to read the column count. *)
  let transpose m =
    match Array.length m with
    | 0 -> [||]
    | nb_rows ->
        let nb_cols = Array.length m.(0) in
        Array.init nb_cols (fun i -> Array.init nb_rows (fun j -> m.(j).(i)))

  (* The row operations below update the outer array of [m] in place: rows
     are replaced by freshly built arrays or swapped, never written into. *)

  (* m.(i) <- m.(i) + coeff * m.(j); [coeff] defaults to [Ring.one]. *)
  let row_add ?(coeff = Ring.one) i j m =
    m.(i) <- Array.map2 Ring.(fun a b -> add a (mul coeff b)) m.(i) m.(j)

  let row_swap i j m =
    let aux = m.(i) in
    m.(i) <- m.(j) ;
    m.(j) <- aux

  (* m.(i) <- coeff * m.(i) *)
  let row_mul coeff i m = m.(i) <- Array.map (Ring.mul coeff) m.(i)

  (* Builds a new matrix containing the columns whose index satisfies [f]. *)
  let filter_cols f =
    Array.map (fun row ->
        List.filteri (fun i _ -> f i) (Array.to_list row) |> Array.of_list)

  (* Columns [0, n) go left, columns [n, ...) go right. *)
  let split_n n m =
    (filter_cols (fun i -> i < n) m, filter_cols (fun i -> i >= n) m)
end

(** [Make_VectorSpace (Field)] extends [Make_Module (Field)] with Gauss-Jordan
    elimination and matrix inversion. *)
module Make_VectorSpace (Field : Field_sig) :
  VectorSpace_sig with type t = Field.t = struct
  include Make_Module (Field)

  (* Gauss-Jordan elimination. Columns are scanned left to right; for each
     column, a row with a non-zero entry is searched among the rows that do
     not hold a pivot yet. It is moved to the next free pivot position,
     scaled so that its pivot becomes [one], and the column is cleared in
     every other row. Pivot rows are packed at the top in column order, so
     all-zero rows end up at the bottom. *)
  let reduced_row_echelon_form m =
    (* Work on a copy so that the caller's matrix is left untouched. *)
    let m = Array.map Array.copy m in
    let nb_rows = Array.length m in
    let nb_cols = if nb_rows = 0 then 0 else Array.length m.(0) in
    let is_zero x = Field.(eq zero x) in
    (* First row with index >= [r] whose entry in column [col] is non-zero. *)
    let rec find_pivot_row col r =
      if r >= nb_rows then None
      else if is_zero m.(r).(col) then find_pivot_row col (r + 1)
      else Some r
    in
    (* [pivot_row] is the next row to receive a pivot, [col] the next column
       to examine. *)
    let rec eliminate pivot_row col =
      if pivot_row < nb_rows && col < nb_cols then
        match find_pivot_row col pivot_row with
        | None -> eliminate pivot_row (col + 1)
        | Some r ->
            row_swap pivot_row r m ;
            row_mul (Field.inverse_exn m.(pivot_row).(col)) pivot_row m ;
            for i = 0 to nb_rows - 1 do
              (* Rows already zero in this column need no update. *)
              if i <> pivot_row && not (is_zero m.(i).(col)) then
                row_add ~coeff:(Field.negate m.(i).(col)) i pivot_row m
            done ;
            eliminate (pivot_row + 1) (col + 1)
    in
    eliminate 0 0 ;
    m

  (* Reduces [m | I]; when [m] is invertible the left half becomes the
     identity and the right half is the inverse. Otherwise the left half is
     the reduced form of a singular matrix and contains a zero row. *)
  let inverse m =
    let n = Array.length m in
    if Array.exists (fun row -> Array.length row <> n) m then
      invalid_arg "matrix [m] is not square" ;
    let augmented = Array.map2 Array.append m (identity n) in
    let residue, inv = split_n n (reduced_row_echelon_form augmented) in
    let is_zero_row = Array.for_all Field.(eq zero) in
    if Array.exists is_zero_row residue then
      invalid_arg "matrix [m] is not invertible"
    else inv
end
OCaml

Innovation. Community. Security.