Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file owl_sparse_ndarray_generic.ml
# 1 "src/owl/sparse/owl_sparse_ndarray_generic.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Bigarray

open Owl_ndarray

open Owl_base_dense_common


type ('a, 'b) kind = ('a, 'b) Bigarray.kind

(* A sparse n-dimensional array.
   [s] : the shape.
   [h] : maps the n-d index of every stored element to its slot in [d].
   [d] : element storage; slots 0 .. nnz-1 hold the stored values, any
         remaining slots are spare capacity (see [_allocate_more_space]). *)
type ('a, 'b) t = {
  mutable s : int array;
  mutable h : (int array, int) Hashtbl.t;
  mutable d : ('a, 'b, c_layout) Array1.t;
}


(* Raised internally by [_exists_basic] to stop an iteration early. Using a
   dedicated exception means exceptions raised by user callbacks propagate
   instead of being mistaken for a match. *)
exception Elt_found


(* Number of stored elements, i.e. the number of bindings in [x.h].
   [Hashtbl.length] gives the same count as the num_bindings field of
   [Hashtbl.stats] but in constant time; [Hashtbl.stats] walks every bucket,
   which made every [set] pay for a full table scan. *)
let nnz x = Hashtbl.length x.h


(* Allocate an element array of kind [k] and length [n], filled with zero. *)
let _make_elt_array k n =
  let x = Array1.create k c_layout n in
  Array1.fill x (Owl_const.zero k);
  x


(* Grow [x.d] with [Owl_utils.array1_extend] once every slot is in use, so the
   next insertion (which writes slot [nnz x]) stays in bounds. *)
let _allocate_more_space x =
  let c = nnz x in
  if c >= Array1.dim x.d then (
    Owl_log.debug "allocate space %i" c;
    x.d <- Owl_utils.array1_extend x.d c
  )


(* Delete the value stored in slot [i]: shift slots i+1 .. nnz-1 one step
   down, then drop the binding that pointed at slot [i] and decrement every
   slot number greater than [i] so the table matches the shifted storage. *)
let _remove_ith_item x i =
  Owl_log.debug "_remove_ith_item";
  for j = i to nnz x - 2 do
    x.d.{j} <- x.d.{j + 1}
  done;
  Hashtbl.filter_map_inplace
    (fun _ v -> if v = i then None else if v > i then Some (v - 1) else Some v)
    x.h


(* check whether index x is in slice s: every [Some v] entry of [s] must equal
   the corresponding entry of [x]; [None] entries match anything. Positions of
   [s] beyond the length of [x] are ignored. *)
let _in_slice s x =
  let ns = Array.length s in
  let nx = Array.length x in
  let rec check i =
    i >= ns || i >= nx ||
    ((match s.(i) with Some v -> v = x.(i) | None -> true) && check (i + 1))
  in
  check 0


(* An all-zero sparse array of kind [k] and shape [s]. The initial capacity of
   both the table and the storage is max (numel / 1000) 1024. *)
let zeros k s =
  let n = Array.fold_right (fun c a -> c * a) s 1 in
  let c = max (n / 1000) 1024 in
  { s = Array.copy s; h = Hashtbl.create c; d = _make_elt_array k c }


(* A fresh copy of the shape array. *)
let shape x = Array.copy x.s

let num_dims x = Array.length x.s

let nth_dim x i = x.s.(i)

(* Total number of elements (product of the shape). *)
let numel x = Array.fold_right (fun c a -> c * a) x.s 1

(* Fraction of elements that are stored: nnz / numel. *)
let density x =
  let a = float_of_int (nnz x) in
  let b = float_of_int (numel x) in
  a /. b

let kind x = Array1.kind x.d

(* [true] iff [x] and [y] have the same number of dimensions and equal extents
   along every dimension. *)
let same_shape x y = x.s = y.s

(* Fail unless [x] and [y] have the same shape. *)
let _check_same_shape x y =
  if not (same_shape x y) then
    failwith "Owl_sparse_ndarray: _check_same_shape fails."


(* Deep copy: shape, index table and element storage are all duplicated. *)
let copy x =
  { s = Array.copy x.s; h = Hashtbl.copy x.h; d = Owl_utils.array1_copy x.d }


(* Element at n-d index [i]; zero when nothing is stored there. *)
let get x i =
  match Hashtbl.find x.h i with
  | j -> Array1.unsafe_get x.d j
  | exception Not_found -> Owl_const.zero (kind x)


(* Set the element at n-d index [i] to [a]. Writing zero removes a stored
   element (or does nothing if none is stored); writing a non-zero value
   overwrites in place or appends at slot [nnz x]. [i] is copied before it is
   used as a key, so callers may reuse their index buffer. *)
let set x i a =
  _allocate_more_space x;
  let _a0 = Owl_const.zero (kind x) in
  match Hashtbl.find x.h i with
  | j when a = _a0 ->
    Array1.unsafe_set x.d j _a0;
    _remove_ith_item x j
  | j -> Array1.unsafe_set x.d j a
  | exception Not_found ->
    if a <> _a0 then begin
      let j = nnz x in
      Hashtbl.add x.h (Array.copy i) j;
      Array1.unsafe_set x.d j a
    end


(* Copy [x] with shape [s], re-keying every stored index [i] to [f i]; the
   value slots are copied unchanged. The new keys go into a fresh table:
   re-keying in place (remove old key, add new key) breaks whenever a new key
   equals an old key that is still to be processed, because the later remove
   then deletes the freshly added binding and leaves the stale one behind. *)
let _rekey x s f =
  let h = Hashtbl.create (Hashtbl.length x.h + 1) in
  Hashtbl.iter (fun i j -> Hashtbl.add h (f i) j) x.h;
  { s = Array.copy s; h = h; d = Owl_utils.array1_copy x.d }

(* A one-dimensional copy of [x] with [numel x] elements. *)
let flatten x =
  let stride = Owl_utils.calc_stride (shape x) in
  _rekey x [| numel x |] (fun i -> [| Owl_utils.index_nd_1d i stride |])

(* A copy of [x] viewed with shape [s]. Each index is mapped to its flat offset
   under the old shape and back to an n-d index under [s].
   NOTE(review): no check that [s] has the same number of elements as [x]. *)
let reshape x s =
  let s0 = Owl_utils.calc_stride (shape x) in
  let s1 = Owl_utils.calc_stride s in
  _rekey x s (fun i ->
    let k = Owl_utils.index_nd_1d i s0 in
    let i1 = Array.copy s in
    Owl_utils.index_1d_nd k i1 s1;
    i1)


(* Enumerate every index between [l] and [h] (inclusive, per dimension) in
   row-major order, writing it into the shared buffer [i] and calling
   [f i (get x i)]. [j] is the dimension currently being enumerated. *)
let rec __iteri_fix_axis d j i l h f x =
  if j = d - 1 then
    for k = l.(j) to h.(j) do
      i.(j) <- k;
      f i (get x i)
    done
  else
    for k = l.(j) to h.(j) do
      i.(j) <- k;
      __iteri_fix_axis d (j + 1) i l h f x
    done

(* Iterate over all indices, with dimension j pinned to b wherever
   [axis.(j) = Some b]. *)
let _iteri_fix_axis axis f x =
  let d = num_dims x in
  let i = Array.make d 0 in
  let l = Array.make d 0 in
  let h = shape x in
  Array.iteri
    (fun j a ->
      match a with
      | Some b -> l.(j) <- b; h.(j) <- b
      | None -> h.(j) <- h.(j) - 1)
    axis;
  __iteri_fix_axis d 0 i l h f x

(* Dense iteration (zeros included). The index array passed to [f] is a
   shared buffer that is overwritten on each step; copy it to keep it. *)
let iteri ?axis f x =
  match axis with
  | Some a -> _iteri_fix_axis a f x
  | None -> _iteri_fix_axis (Array.make (num_dims x) None) f x

let iter ?axis f x = iteri ?axis (fun _ y -> f y) x

(* Dense map into a copy of [x]. Values are read from [x], which is left
   untouched; each index is visited once. *)
let mapi ?axis f x =
  let y = copy x in
  iteri ?axis (fun i z -> set y i (f i z)) x;
  y

let map ?axis f x =
  let y = copy x in
  iteri ?axis (fun i z -> set y i (f z)) x;
  y


(* Iteration over stored elements only, in hash-table order. *)
let _iteri_all_axis_nz f x = Hashtbl.iter (fun i j -> f i (x.d.{j})) x.h

let _iteri_fix_axis_nz axis f x =
  Hashtbl.iter (fun i j -> if _in_slice axis i then f i (x.d.{j})) x.h

(* Walk the occupied storage slots 0 .. nnz-1 directly. *)
let _iter_all_axis_nz f x =
  for i = 0 to nnz x - 1 do
    f (Array1.unsafe_get x.d i)
  done

let iteri_nz ?axis f x =
  match axis with
  | Some a -> _iteri_fix_axis_nz a f x
  | None -> _iteri_all_axis_nz f x

let iter_nz ?axis f x =
  match axis with
  | Some a -> _iteri_fix_axis_nz a (fun _ y -> f y) x
  | None -> _iter_all_axis_nz f x

(* Map over stored elements only; results are written back even if zero. *)
let mapi_nz ?axis f x =
  let y = copy x in
  (match axis with
   | Some a ->
     Hashtbl.iter (fun i j -> if _in_slice a i then y.d.{j} <- f i (x.d.{j})) y.h
   | None ->
     Hashtbl.iter (fun i j -> y.d.{j} <- f i (x.d.{j})) y.h);
  y

(* Map over stored elements only. Without an axis the occupied slots
   0 .. nnz-1 are rewritten directly; the upper bound is nnz - 1 so the loop
   never reads or writes the first unoccupied slot (which is out of bounds
   when the storage is full). *)
let map_nz ?axis f x =
  match axis with
  | Some a -> mapi_nz ~axis:a (fun _ z -> f z) x
  | None ->
    let y = copy x in
    for i = 0 to nnz y - 1 do
      Array1.unsafe_set y.d i (f (Array1.unsafe_get y.d i))
    done;
    y


(* Fail unless [axis] is a permutation of 0 .. d-1. *)
let _check_transpose_axis axis d =
  let info = "check_transpose_axiss fails" in
  if Array.length axis <> d then failwith info;
  let h = Hashtbl.create 16 in
  Array.iter
    (fun x ->
      if x < 0 || x >= d then failwith info;
      if Hashtbl.mem h x then failwith info;
      Hashtbl.add h x 0)
    axis

(* Permute dimensions: output dimension k is input dimension [axis.(k)];
   the default reverses all dimensions. Only stored elements are visited,
   since writing a zero into the fresh result would be a no-op anyway. *)
let transpose ?axis x =
  let d = num_dims x in
  let a =
    match axis with
    | Some a -> a
    | None -> Array.init d (fun i -> d - i - 1)
  in
  (* check if axis is a correct permutation *)
  _check_transpose_axis a d;
  let s0 = shape x in
  let s1 = Array.map (fun j -> s0.(j)) a in
  let i' = Array.make d 0 in
  let y = zeros (kind x) s1 in
  iteri_nz
    (fun i z ->
      Array.iteri (fun k j -> i'.(k) <- i.(j)) a;
      set y i' z)
    x;
  y

(* Exchange dimensions [a0] and [a1]. *)
let swap a0 a1 x =
  let d = num_dims x in
  let a = Array.init d (fun i -> i) in
  let t = a.(a0) in
  a.(a0) <- a.(a1);
  a.(a1) <- t;
  transpose ~axis:a x


(* Indices (copied) of all elements satisfying [f]; dense iteration. *)
let filteri ?axis f x =
  let s = Owl_utils.Stack.make () in
  iteri ?axis (fun i y -> if f i y then Owl_utils.Stack.push s (Array.copy i)) x;
  Owl_utils.Stack.to_array s

let filter ?axis f x = filteri ?axis (fun _ y -> f y) x

(* Indices (copied) of all stored elements satisfying [f]. *)
let filteri_nz ?axis f x =
  let s = Owl_utils.Stack.make () in
  iteri_nz ?axis (fun i y -> if f i y then Owl_utils.Stack.push s (Array.copy i)) x;
  Owl_utils.Stack.to_array s

let filter_nz ?axis f x = filteri_nz ?axis (fun _ y -> f y) x


(* Left fold [f acc y] driven by the given iterator. *)
let _fold_basic ?axis iter_fun f a x =
  let r = ref a in
  iter_fun ?axis (fun y -> r := f !r y) x;
  !r

let fold ?axis f a x = _fold_basic ?axis iter f a x

let fold_nz ?axis f a x = _fold_basic ?axis iter_nz f a x

let foldi ?axis f a x =
  let c = ref a in
  iteri ?axis (fun i y -> c := f i !c y) x;
  !c

let foldi_nz ?axis f a x =
  let c = ref a in
  iteri_nz ?axis (fun i y -> c := f i !c y) x;
  !c


(* Extract the sub-array where [axis.(j) = Some b] pins dimension j to b;
   the result keeps only the dimensions marked [None]. *)
let slice axis x =
  (* make the index mapping *)
  let s = Owl_utils.Stack.make () in
  for i = 0 to Array.length axis - 1 do
    match axis.(i) with
    | Some _ -> ()
    | None -> Owl_utils.Stack.push s i
  done;
  let m = Owl_utils.Stack.to_array s in
  (* create a new sparse ndarray for the slice *)
  let s0 = shape x in
  let s1 = Array.map (fun i -> s0.(i)) m in
  let y = zeros (kind x) s1 in
  (* only iterate non-zero elements *)
  iteri_nz
    (fun i v ->
      if _in_slice axis i then (
        let i' = Array.map (fun j -> i.(j)) m in
        set y i' v))
    x;
  y


(* [true] iff [f] holds for some element visited by [iter_fun]. Exceptions
   raised by [f] itself propagate to the caller. *)
let _exists_basic iter_fun f x =
  try
    iter_fun (fun y -> if f y then raise Elt_found) x;
    false
  with Elt_found -> true

let exists f x = _exists_basic iter f x

let not_exists f x = not (exists f x)

let for_all f x = not_exists (fun y -> not (f y)) x

let exists_nz f x = _exists_basic iter_nz f x

let not_exists_nz f x = not (exists_nz f x)

let for_all_nz f x = not_exists_nz (fun y -> not (f y)) x


let is_zero x = nnz x = 0

(* Strictly positive / negative tests need every element stored, since an
   implicit zero already fails them. *)
let is_positive x =
  let _a0 = Owl_const.zero (kind x) in
  nnz x >= numel x && for_all (( < ) _a0) x

let is_negative x =
  let _a0 = Owl_const.zero (kind x) in
  nnz x >= numel x && for_all (( > ) _a0) x

(* Implicit zeros satisfy these tests, so only stored elements are checked. *)
let is_nonpositive x =
  let _a0 = Owl_const.zero (kind x) in
  for_all_nz (( >= ) _a0) x

let is_nonnegative x =
  let _a0 = Owl_const.zero (kind x) in
  for_all_nz (( <= ) _a0) x


(* Scalar operations transform stored elements only; implicit zeros stay
   zero. *)
let add_scalar x a =
  let _op = _add_elt (kind x) in
  map_nz (fun z -> _op z a) x

let sub_scalar x a = add_scalar x (_neg_elt (kind x) a)

let mul_scalar x a =
  let _op = _mul_elt (kind x) in
  map_nz (fun z -> _op z a) x

let div_scalar x a = mul_scalar x ((_inv_elt (kind x)) a)

let scalar_add a x =
  let _op = _add_elt (kind x) in
  map_nz (fun z -> _op a z) x

let scalar_sub a x =
  let _op = _sub_elt (kind x) in
  map_nz (fun z -> _op a z) x

let scalar_mul a x =
  let _op = _mul_elt (kind x) in
  map_nz (fun z -> _op a z) x

let scalar_div a x =
  let _op = _div_elt (kind x) in
  map_nz (fun z -> _op a z) x


(* Element-wise sum. The result has the shape of [x1]; shapes are not checked
   here (presumably callers guarantee they match -- TODO confirm). *)
let add x1 x2 =
  let k = kind x1 in
  let _a0 = Owl_const.zero k in
  let __add_elt = _add_elt k in
  let y = zeros k (shape x1) in
  (* elements stored only in x1 *)
  iteri_nz
    (fun i a ->
      let b = get x2 i in
      if b = _a0 then set y i a)
    x1;
  (* elements stored in x2, combined with whatever x1 holds there *)
  iteri_nz
    (fun i a ->
      let b = get x1 i in
      set y i (__add_elt a b))
    x2;
  y

let neg x = map_nz (_neg_elt (kind x)) x

let sub x1 x2 = add x1 (neg x2)

(* Element-wise product; only positions stored in both inputs can be
   non-zero. *)
let mul x1 x2 =
  let k = kind x1 in
  let _a0 = Owl_const.zero k in
  let __mul_elt = _mul_elt k in
  let y = zeros (kind x1) (shape x1) in
  iteri_nz
    (fun i a ->
      let b = get x2 i in
      if b <> _a0 then set y i (__mul_elt a b))
    x1;
  y

(* Element-wise quotient x1 / x2 over the elements stored in [x1]. Positions
   where [x2] is zero are left at zero (no inf or nan is produced). The
   previous version computed a / (1 / b), i.e. a product. *)
let div x1 x2 =
  let k = kind x1 in
  let _a0 = Owl_const.zero k in
  let __div_elt = _div_elt k in
  let y = zeros k (shape x1) in
  iteri_nz
    (fun i a ->
      let b = get x2 i in
      if b <> _a0 then set y i (__div_elt a b))
    x1;
  y

let abs x =
  let _op = _abs_elt (kind x) in
  map_nz _op x

let sum x =
  let k = kind x in
  fold_nz (_add_elt k) (Owl_const.zero k) x

let mean x = (_mean_elt (kind x)) (sum x) (numel x)


(* Comparisons: all check shapes first and reduce to a difference test. *)
let equal x1 x2 =
  _check_same_shape x1 x2;
  nnz x1 = nnz x2 && is_zero (sub x1 x2)

let not_equal x1 x2 = not (equal x1 x2)

let greater x1 x2 =
  _check_same_shape x1 x2;
  is_positive (sub x1 x2)

let less x1 x2 = greater x2 x1

let greater_equal x1 x2 =
  _check_same_shape x1 x2;
  is_nonnegative (sub x1 x2)

let less_equal x1 x2 = greater_equal x2 x1


(* Minimum and maximum over the stored elements, widened to include zero when
   any element is implicit. NOTE(review): starts from Owl_const.pos_inf and
   neg_inf of the kind -- confirm these are defined for integer kinds. *)
let minmax x =
  let k = kind x in
  let _a0 = Owl_const.zero k in
  let xmin = ref (Owl_const.pos_inf k) in
  let xmax = ref (Owl_const.neg_inf k) in
  iter_nz
    (fun y ->
      if y < !xmin then xmin := y;
      if y > !xmax then xmax := y)
    x;
  if nnz x < numel x then (min !xmin _a0, max !xmax _a0)
  else (!xmin, !xmax)

let min x = fst (minmax x)

let max x = snd (minmax x)


(* input/output functions *)

(* Print an n-d index as [ i0 i1 ... ] followed by a space. *)
let print_index i =
  Printf.printf "[ ";
  Array.iter (fun x -> Printf.printf "%i " x) i;
  Printf.printf "] "

let print_element k v =
  let s = (Owl_utils.elt_to_str k) v in
  Printf.printf "%s" s
[@@warning "-32"]

(* Print every element (zeros included), one per line. *)
let print x =
  let _op = Owl_utils.elt_to_str (kind x) in
  iteri
    (fun i y ->
      print_index i;
      Printf.printf "%s\n" (_op y))
    x

(* Print all elements when numel <= 40, otherwise the first and last 20
   separated by a line of dots. *)
let pp_spnda x =
  let _op = Owl_utils.elt_to_str (kind x) in
  (* [idx] starts as a copy of the shape and is then reused as the n-d index
     buffer filled by [Owl_utils.index_1d_nd]. *)
  let idx = shape x in
  let s = Owl_utils.calc_stride idx in
  let _pp i j =
    for i' = i to j do
      Owl_utils.index_1d_nd i' idx s;
      print_index idx;
      Printf.printf "%s\n" (_op (get x idx))
    done
  in
  let n = numel x in
  if n <= 40 then _pp 0 (n - 1)
  else (
    _pp 0 19;
    print_endline "......";
    _pp (n - 20) (n - 1))

let save x f = Owl_io.marshal_to_file x f

(* NOTE(review): presumably backed by OCaml Marshal, which is not safe on
   untrusted files -- only load data this program wrote. *)
let load _k f = Owl_io.marshal_from_file f


(* Fill a zero array of shape [d] by drawing int_of_float (numel *. a) flat
   offsets with [Owl_stats.uniform_int_rvs] and storing [f ()] at each. The
   draws may repeat, so the resulting density can be lower than [a]. *)
let _random_basic a k f d =
  let x = zeros k d in
  let n = numel x in
  let c = int_of_float (float_of_int n *. a) in
  let i = Array.copy d in
  let s = Owl_utils.calc_stride d in
  for _k = 0 to c - 1 do
    let j = Owl_stats.uniform_int_rvs ~a:0 ~b:(n - 1) in
    Owl_utils.index_1d_nd j i s;
    set x i (f ())
  done;
  x

let binary ?(density = 0.1) k s =
  let _a1 = Owl_const.one k in
  _random_basic density k (fun () -> _a1) s

let uniform ?(scale = 1.) ?(density = 0.1) k s =
  let _op = _owl_uniform_fun k in
  _random_basic density k (fun () -> _op scale) s


(* Stored elements as (copied index, value) pairs, in hash-table order. *)
let to_array x =
  let y = Array.make (nnz x) ([||], Owl_const.zero (kind x)) in
  let j = ref 0 in
  iteri_nz
    (fun i v ->
      y.(!j) <- (Array.copy i, v);
      j := !j + 1)
    x;
  y

(* Build an array of kind [k] and shape [s] from (index, value) pairs. *)
let of_array k s x =
  let y = zeros k s in
  Array.iter (fun (i, v) -> set y i v) x;
  y


(* ends here *)