package owl

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file owl_dataset.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# 1 "src/owl/misc/owl_dataset.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

(** Dataset: easy access to various datasets *)

open Owl_types

(** [remote_data_path ()] is the base URL that dataset archive names are
    appended to when downloading. It ends with a slash. *)
let remote_data_path () : string =
  "https://github.com/ryanrhymes/owl_dataset/raw/master/"

(** [local_data_path ()] returns the local dataset directory, built as
    [$HOME/.owl/dataset/], after creating [$HOME/.owl] and the dataset
    directory with mode [0o755] when they do not exist yet.

    The result always ends with a directory separator: the loaders in this
    file build file names by plain concatenation, for example
    [local_data_path () ^ "mnist-train-images"], so a path without the
    trailing slash would point outside the dataset directory.

    @raise Not_found if the [HOME] environment variable is not set.
    @raise Unix.Unix_error if a directory cannot be created for a reason
    other than already existing. *)
let local_data_path () : string =
  let home = Sys.getenv "HOME" ^ "/.owl" in
  let d = home ^ "/dataset" in
  Owl_log.info "create %s if not present" d;
  (* Create unconditionally and tolerate EEXIST: testing with
     Sys.file_exists first would be racy. *)
  let mkdir_if_missing dir =
    try Unix.mkdir dir 0o755 with
    | Unix.Unix_error (Unix.EEXIST, _, _) -> ()
  in
  mkdir_if_missing home;
  mkdir_if_missing d;
  (* Hand back the directory with its trailing separator so that
     [path ^ file_name] is a valid file path. *)
  d ^ "/"


(** [download_data fname] downloads the archive [fname] from
    [remote_data_path ()] into the directory given by [local_data_path ()]
    using [wget], then decompresses it in place with [gunzip].

    Both shell arguments are quoted with [Filename.quote], and the local
    path is built with [Filename.concat] so it is correct whether or not the
    directory string ends with a separator.

    The operation stays best-effort: a failing command is reported through
    [Owl_log.warn] instead of raising, and [gunzip] is skipped when the
    download itself failed. *)
let download_data fname =
  let url = remote_data_path () ^ fname in
  let dest = Filename.concat (local_data_path ()) fname in
  let succeeded cmd = Sys.command cmd = 0 in
  let fetch =
    Printf.sprintf "wget %s -O %s" (Filename.quote url) (Filename.quote dest)
  in
  let unpack = "gunzip " ^ Filename.quote dest in
  if not (succeeded fetch)
  then Owl_log.warn "failed to download %s" url
  else if not (succeeded unpack)
  then Owl_log.warn "failed to decompress %s" dest


(** [download_all ()] calls [download_data] on every dataset archive name
    listed below, in order: text corpora first, then MNIST, then CIFAR-10. *)
let download_all () =
  let text_archives =
    [ "stopwords.txt.gz"
    ; "enron.test.gz"
    ; "enron.train.gz"
    ; "nips.test.gz"
    ; "nips.train.gz"
    ]
  in
  let mnist_archives =
    [ "mnist-test-images.gz"
    ; "mnist-test-labels.gz"
    ; "mnist-test-lblvec.gz"
    ; "mnist-train-images.gz"
    ; "mnist-train-labels.gz"
    ; "mnist-train-lblvec.gz"
    ]
  in
  let cifar_archives =
    [ "cifar10_test_data.gz"
    ; "cifar10_test_labels.gz"
    ; "cifar10_test_filenames.gz"
    ; "cifar10_test_lblvec.gz"
    ; "cifar10_train1_data.gz"
    ; "cifar10_train1_labels.gz"
    ; "cifar10_train1_filenames.gz"
    ; "cifar10_train1_lblvec.gz"
    ; "cifar10_train2_data.gz"
    ; "cifar10_train2_labels.gz"
    ; "cifar10_train2_filenames.gz"
    ; "cifar10_train2_lblvec.gz"
    ; "cifar10_train3_data.gz"
    ; "cifar10_train3_labels.gz"
    ; "cifar10_train3_filenames.gz"
    ; "cifar10_train3_lblvec.gz"
    ; "cifar10_train4_data.gz"
    ; "cifar10_train4_labels.gz"
    ; "cifar10_train4_filenames.gz"
    ; "cifar10_train4_lblvec.gz"
    ; "cifar10_train5_data.gz"
    ; "cifar10_train5_labels.gz"
    ; "cifar10_train5_filenames.gz"
    ; "cifar10_train5_lblvec.gz"
    ]
  in
  List.iter download_data (text_archives @ mnist_archives @ cifar_archives)


(** [draw_samples x y n] forwards to
    [Owl_dense_matrix_generic.draw_rows2 ~replacement:false x y n] and
    returns the first two components of its result, dropping the third. *)
let draw_samples x y n =
  match Owl_dense_matrix_generic.draw_rows2 ~replacement:false x y n with
  | xs, ys, _ -> xs, ys


(* Load the MNIST training set from the local dataset directory and return
  the triplet (images, labels, lblvec). The images are expected to form a
  60000 x 784 matrix in which each row is one 28 x 28 picture; labels holds
  the label of each row and lblvec the corresponding unravelled label row
  vector. *)
let load_mnist_train_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  load "mnist-train-images", load "mnist-train-labels", load "mnist-train-lblvec"


(* Load the MNIST test set from the local dataset directory and return the
  triplet (images, labels, lblvec), each read with Owl_dense_matrix.S.load. *)
let load_mnist_test_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  load "mnist-test-images", load "mnist-test-labels", load "mnist-test-lblvec"


(* Print an image on standard output: x is reshaped to 28 x 28 and each row
  becomes one line of text, with a blank for every entry matching 0. and a
  filled square for any other value. *)
let print_mnist_image x =
  let glyph = function
    | 0. -> " "
    | _ -> "■"
  in
  let print_row row =
    Owl_dense_matrix_generic.iter (fun px -> print_string (glyph px)) row;
    print_newline ()
  in
  Owl_dense_matrix_generic.iter_rows
    print_row
    (Owl_dense_matrix_generic.reshape x [| 28; 28 |])


(* Same data as load_mnist_train_data, except that the image matrix is
  reshaped into an ndarray of shape [m, 28, 28, 1], where m is the number of
  rows of the image matrix. Labels and label vectors are passed through. *)
let load_mnist_train_data_arr () =
  let images, label, lblvec = load_mnist_train_data () in
  let shape = [| Owl_dense_matrix.S.row_num images; 28; 28; 1 |] in
  Owl_dense_ndarray.S.reshape images shape, label, lblvec


(* Same data as load_mnist_test_data, except that the image matrix is
  reshaped into an ndarray of shape [m, 28, 28, 1], where m is the number of
  rows of the image matrix. Labels and label vectors are passed through. *)
let load_mnist_test_data_arr () =
  let images, label, lblvec = load_mnist_test_data () in
  let shape = [| Owl_dense_matrix.S.row_num images; 28; 28; 1 |] in
  Owl_dense_ndarray.S.reshape images shape, label, lblvec


(* Load one CIFAR-10 training batch from the local dataset directory; batch
  is the batch number used in the file names (the archive list in this file
  has batches 1 to 5). Returns (data, labels, lblvec): data is read with
  Owl_dense_ndarray.S.load, the other two with Owl_dense_matrix.S.load.
  NOTE(review): an earlier comment described data as a 10000 x 3072 matrix of
  unravelled 32 x 32 three-channel images, while draw_samples_cifar indexes
  it with four dimensions; confirm the stored shape. *)
let load_cifar_train_data batch =
  let stem = local_data_path () ^ "cifar10_train" ^ string_of_int batch in
  ( Owl_dense_ndarray.S.load (stem ^ "_data")
  , Owl_dense_matrix.S.load (stem ^ "_labels")
  , Owl_dense_matrix.S.load (stem ^ "_lblvec") )


(* Load the CIFAR-10 test set from the local dataset directory. Returns
  (data, labels, lblvec): data is read with Owl_dense_ndarray.S.load, the
  other two with Owl_dense_matrix.S.load. *)
let load_cifar_test_data () =
  let stem = local_data_path () ^ "cifar10_test" in
  ( Owl_dense_ndarray.S.load (stem ^ "_data")
  , Owl_dense_matrix.S.load (stem ^ "_labels")
  , Owl_dense_matrix.S.load (stem ^ "_lblvec") )


(* Pick n indices along the first dimension of x with Owl_stats.choose and
  return the matching slices of x (indexed with four dimensions) and of y
  (indexed with two), using the same index list for both. *)
let draw_samples_cifar x y n =
  let total = (Owl_dense_ndarray_generic.shape x).(0) in
  let candidates = Array.init total (fun idx -> idx) in
  let picked = Array.to_list (Owl_stats.choose candidates n) in
  let x_spec = [ L picked; R []; R []; R [] ] in
  let y_spec = [ L picked; R [] ] in
  Owl_dense_ndarray.S.get_fancy x_spec x, Owl_dense_matrix.S.get_fancy y_spec y


(* load text data and stopwords *)
(* Read the stopwords file from the local dataset directory through
  Owl_nlp_utils.load_stopwords. *)
let load_stopwords () =
  let path = local_data_path () ^ "stopwords.txt" in
  Owl_nlp_utils.load_stopwords path


(* Read the NIPS training corpus from the local dataset directory through
  Owl_nlp_utils.load_from_file, passing the given stopwords along. *)
let load_nips_train_data stopwords =
  let path = local_data_path () ^ "nips.train" in
  Owl_nlp_utils.load_from_file ~stopwords path
OCaml

Innovation. Community. Security.