package owl

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file owl_nlp_similarity.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# 1 "src/owl/nlp/owl_nlp_similarity.ml"
(*
 * OWL - OCaml Scientific Computing
 * Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
 *)

(** Distance metrics known to this module. [distance] maps each constructor
    to the function that implements it, and [to_string] gives its printable
    name. *)
type t =
  | Cosine (** dispatched to [cosine_distance] *)
  | Euclidean (** dispatched to [euclidean_distance] *)
  | KL_D (** dispatched to [kl_distance]; printed as "Kullback–Leibler divergence" *)

(** [to_string metric] returns the printable name of [metric]. *)
let to_string metric =
  match metric with
  | Cosine -> "Cosine"
  | Euclidean -> "Euclidean"
  | KL_D -> "Kullback–Leibler divergence"


(** [kl_distance x y] is a placeholder: both arguments are ignored and the
    result is always [0.]. *)
let kl_distance _x _y = 0.

(** [cosine_distance x y] takes two arrays of [(key, value)] pairs and returns
    the negated sum of [xv *. yv] over every entry of [x] whose key is also
    present in [y].

    No normalisation by vector length is performed here, so the result is a
    genuine cosine measure only if the inputs are already unit-length --
    NOTE(review): confirm against callers. *)
let cosine_distance x y =
  (* index the entries of [y] by key *)
  let by_key = Hashtbl.create (Array.length y) in
  Array.iter (fun (key, value) -> Hashtbl.add by_key key value) y;
  let dot =
    Array.fold_left
      (fun acc (key, xv) ->
        match Hashtbl.find_opt by_key key with
        | Some yv -> acc +. (xv *. yv)
        | None -> acc)
      0.
      x
  in
  (* negate so that a higher similarity yields a smaller distance *)
  -. dot


(** [inner_product x y] takes two arrays of [(key, value)] pairs and returns
    the sum of [xv *. yv] over every entry of [x] whose key is also present
    in [y]. Keys found in only one of the arrays contribute nothing. *)
let inner_product x y =
  (* index the entries of [y] by key *)
  let by_key = Hashtbl.create (Array.length y) in
  Array.iter (fun (key, value) -> Hashtbl.add by_key key value) y;
  Array.fold_left
    (fun acc (key, xv) ->
      match Hashtbl.find_opt by_key key with
      | Some yv -> acc +. (xv *. yv)
      | None -> acc)
    0.
    x


(* [inner_product_fast x y] computes the inner product of two arrays of
   [(key, value)] pairs in a single merge-style pass over both arrays.

   This function assumes that the keys in each array are already sorted in
   increasing order; it does not sort them itself. A caller holding unsorted
   data can sort first, e.g.
   [Array.sort (fun a b -> Stdlib.compare (fst a) (fst b)) x].

   Keys present in only one of the arrays contribute nothing to the result.
 *)
let inner_product_fast x y =
  let xn = Array.length x in
  let yn = Array.length y in
  (* [xi] and [yi] are the cursors into [x] and [y]; [acc] is the running sum *)
  let rec merge xi yi acc =
    if xi >= xn || yi >= yn
    then acc
    else (
      let xk, xv = x.(xi) in
      let yk, yv = y.(yi) in
      if xk = yk
      then merge (xi + 1) (yi + 1) (acc +. (xv *. yv))
      else if xk < yk
      then merge (xi + 1) yi acc
      else
        (* Covers [xk > yk] and also keys that are unordered with respect to
           each other (e.g. [nan]). Advancing a cursor on every step is what
           guarantees termination. *)
        merge xi (yi + 1) acc)
  in
  merge 0 0 0.


(** [euclidean_distance x y] takes two arrays of [(key, value)] pairs and
    returns the square root of the sum of squared per-key differences. A key
    found in only one of the arrays contributes the square of its own value.

    NOTE(review): this presumably expects keys to be unique within each
    array; confirm with callers. *)
let euclidean_distance x y =
  (* start from the entries of [x] ... *)
  let diff = Hashtbl.create (Array.length x) in
  Array.iter (fun (key, xv) -> Hashtbl.add diff key xv) x;
  (* ... then fold the entries of [y] into the same table *)
  let merge_entry (key, yv) =
    match Hashtbl.find_opt diff key with
    | Some current -> Hashtbl.replace diff key (current -. yv)
    | None ->
      (* key absent so far: stored unnegated, which is harmless because
         every stored value is squared below *)
      Hashtbl.add diff key yv
  in
  Array.iter merge_entry y;
  let sum_sq = Hashtbl.fold (fun _ d acc -> acc +. (d *. d)) diff 0. in
  sqrt sum_sq


(** [distance metric] returns the two-argument distance function selected by
    [metric]. *)
let distance metric =
  match metric with
  | Cosine -> cosine_distance
  | Euclidean -> euclidean_distance
  | KL_D -> kl_distance

(* ends here *)
OCaml

Innovation. Community. Security.