Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file owl_dataset.ml
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157# 1 "src/owl/misc/owl_dataset.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)(** Dataset: easy access to various datasets *)openOwl_typesletremote_data_path()="https://github.com/ryanrhymes/owl_dataset/raw/master/"letlocal_data_path():string=lethome=Sys.getenv"HOME"^"/.owl"inletd=home^"/dataset"inOwl_log.info"create %s if not present"d;(* Note: use of Sys.file_exist is racy *)(tryUnix.mkdirhome0o755with|Unix.Unix_error(EEXIST,_,_)->());(tryUnix.mkdird0o755with|Unix.Unix_error(EEXIST,_,_)->());dletdownload_datafname=letfn0=remote_data_path()^fnameinletfn1=local_data_path()^fnameinletcmd0="wget "^fn0^" -O "^fn1inletcmd1="gunzip "^fn1inignore(Sys.commandcmd0);ignore(Sys.commandcmd1)letdownload_all()=letl=["stopwords.txt.gz";"enron.test.gz";"enron.train.gz";"nips.test.gz";"nips.train.gz";"mnist-test-images.gz";"mnist-test-labels.gz";"mnist-test-lblvec.gz";"mnist-train-images.gz";"mnist-train-labels.gz";"mnist-train-lblvec.gz";"cifar10_test_data.gz";"cifar10_test_labels.gz";"cifar10_test_filenames.gz";"cifar10_test_lblvec.gz";"cifar10_train1_data.gz";"cifar10_train1_labels.gz";"cifar10_train1_filenames.gz";"cifar10_train1_lblvec.gz";"cifar10_train2_data.gz";"cifar10_train2_labels.gz";"cifar10_train2_filenames.gz";"cifar10_train2_lblvec.gz";"cifar10_train3_data.gz";"cifar10_train3_labels.gz";"cifar10_train3_filenames.gz";"cifar10_train3_lblvec.gz";"cifar10_train4_data.gz";"cifar10_train4_labels.gz";"cifar10_train4_filenames.gz";"cifar10_train4_lblvec.gz";"cifar10_train5_data.gz";"cifar10_train5_labels.gz";"cifar10_train5_filenames.gz";"cifar10_train5_lblvec.gz"]inList.iter(funfname->download_datafname)lletdraw_samplesxyn=letx',y',_=Owl_dense_matrix_generic.draw_rows2~replacement:falsexyninx',y'(* load mnist train data, the return is a triplet. The first is a 60000 x 784
matrix where each row represents a 28 x 28 image. The second is label and the
third is the corresponding unravelled row vector of the label. *)
(* All three components are read with [Owl_dense_matrix.S.load] from
   [local_data_path ()]. NOTE(review): the lblvec file presumably holds
   one-hot encoded labels; confirm against the dataset repository. *)
let load_mnist_train_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  (load "mnist-train-images", load "mnist-train-labels", load "mnist-train-lblvec")

(* Test-set counterpart of [load_mnist_train_data]: the same triplet layout,
   read from the mnist-test files. *)
let load_mnist_test_data () =
  let dir = local_data_path () in
  let load name = Owl_dense_matrix.S.load (dir ^ name) in
  (load "mnist-test-images", load "mnist-test-labels", load "mnist-test-lblvec")

(* Print one image to stdout as 28 text rows. Zero pixels are drawn as a
   space and any other value as a filled square. [x] is reshaped to
   [|28; 28|], so it must hold exactly 784 elements. *)
let print_mnist_image x =
  let show_pixel = function
    | 0. -> Printf.printf " "
    | _ -> Printf.printf "■"
  in
  let show_row row =
    Owl_dense_matrix_generic.iter show_pixel row;
    print_endline ""
  in
  Owl_dense_matrix_generic.iter_rows show_row
    (Owl_dense_matrix_generic.reshape x [| 28; 28 |])

(* Same as [load_mnist_train_data], except that the image matrix is reshaped
   into an ndarray of shape [|m; 28; 28; 1|], where m is its row count. *)
let load_mnist_train_data_arr () =
  let images, labels, lblvec = load_mnist_train_data () in
  let rows = Owl_dense_matrix.S.row_num images in
  (Owl_dense_ndarray.S.reshape images [| rows; 28; 28; 1 |], labels, lblvec)

(* Same as [load_mnist_test_data], except that the image matrix is reshaped
   into an ndarray of shape [|m; 28; 28; 1|], where m is its row count. *)
let load_mnist_test_data_arr () =
  let images, labels, lblvec = load_mnist_test_data () in
  let rows = Owl_dense_matrix.S.row_num images in
  (Owl_dense_ndarray.S.reshape images [| rows; 28; 28; 1 |], labels, lblvec)

(* load cifar train data, there are five batches in total. The loaded data is a
10000 * 3072 matrix. Each row represents a 32 x 32 image of three colour
channels, unravelled into a row vector. The labels are also returned. *)
(* The result is a triplet (data, labels, lblvec). The files read are
   cifar10_trainN_data, cifar10_trainN_labels and cifar10_trainN_lblvec,
   with N = [batch], all located in [local_data_path ()].
   NOTE(review): data is read with [Owl_dense_ndarray.S.load], so its actual
   shape is whatever was saved; confirm it matches the 10000 x 3072
   description above. *)
let load_cifar_train_data batch =
  let dir = local_data_path () in
  let stem = dir ^ "cifar10_train" ^ string_of_int batch in
  ( Owl_dense_ndarray.S.load (stem ^ "_data"),
    Owl_dense_matrix.S.load (stem ^ "_labels"),
    Owl_dense_matrix.S.load (stem ^ "_lblvec") )

(* Test-set counterpart of [load_cifar_train_data]: the same triplet layout,
   read from the cifar10_test files. *)
let load_cifar_test_data () =
  let dir = local_data_path () in
  let file name = dir ^ name in
  ( Owl_dense_ndarray.S.load (file "cifar10_test_data"),
    Owl_dense_matrix.S.load (file "cifar10_test_labels"),
    Owl_dense_matrix.S.load (file "cifar10_test_lblvec") )

(* Draw [n] random samples along the first axis of [x] and take the matching
   rows of [y]. NOTE(review): this relies on [Owl_stats.choose] picking
   distinct indices (without replacement); confirm. *)
let draw_samples_cifar x y n =
  let num_samples = (Owl_dense_ndarray_generic.shape x).(0) in
  let all_indices = Array.init num_samples (fun i -> i) in
  let picked = Array.to_list (Owl_stats.choose all_indices n) in
  ( Owl_dense_ndarray.S.get_fancy [ L picked; R []; R []; R [] ] x,
    Owl_dense_matrix.S.get_fancy [ L picked; R [] ] y )

(* load text data and stopwords *)
let load_stopwords () =
  Owl_nlp_utils.load_stopwords (local_data_path () ^ "stopwords.txt")

(* Load the nips.train corpus from [local_data_path ()], filtering it with
   the given [stopwords] via [Owl_nlp_utils.load_from_file]. *)
let load_nips_train_data stopwords =
  let dir = local_data_path () in
  Owl_nlp_utils.load_from_file ~stopwords (dir ^ "nips.train")