Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file owl_computation_graph.ml
# 1 "src/base/compute/owl_computation_graph.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Owl_graph

(* Graph layer of the computation engine: wraps the input/output nodes built
   through [Optimiser] into a [graph] record and provides printing,
   save/load, input/output loop-back pairing and optimisation entry points. *)
module Make (Optimiser : Owl_computation_optimiser_sig.Sig) = struct
  module Optimiser = Optimiser
  open Optimiser
  open Optimiser.Operator.Symbol
  open Optimiser.Operator.Symbol.Shape.Type
  open Optimiser.Operator.Symbol.Shape.Type.Device

  type graph =
    { mutable name : string (* name of the graph *)
    ; mutable input : attr node array (* input nodes of the graph *)
    ; mutable output : attr node array (* output nodes of the graph *)
    ; mutable iopair : (attr node * attr node) array
          (* input and output loopback pairs *)
    ; mutable iosafe : bool array
          (* whether it is safe to use unsafe_assign_arr *)
    ; mutable random : attr node array
          (* rvs automatically invalidate themselves *)
    ; mutable htbl : (string, attr node) Hashtbl.t
          (* node name to node mapping *)
    ; mutable device : device (* device-dependent field *)
    }

  (* utility functions *)

  (* Text shown for a node: "v:<value>" when [x] is assigned and its first
     output shape is known and of rank 0 (a scalar), otherwise
     "s:<shape>". *)
  let shape_or_value x =
    let shape = (attr x).shape in
    if is_assigned x
    then (
      match shape.(0) with
      | Some s when Array.length s = 0 ->
        Printf.sprintf "v:%g" (node_to_elt x |> elt_to_float)
      | Some _ | None -> Printf.sprintf "s:%s" (shape_to_str shape))
    else Printf.sprintf "s:%s" (shape_to_str shape)

  (* Map a memory block id to a Graphviz "h s v" colour string. The hue is
     spread by (b_id * 283 mod 360); this is only a lazy attempt at
     distinguishable colours, not a guarantee. *)
  let _block_colour b_id =
    let h = float (b_id * 283 mod 360) /. 360. in
    let s = 0.4 in
    let v = 1. in
    Printf.sprintf "%.3f %.1f %.0f" h s v

  (* Render the graph reachable from [graph.output] as Graphviz DOT. Every
     edge becomes "u -> v;". Every node becomes a record labelled with its
     id, name, operator, [refnum], shape-or-value and block id. Nodes with
     [get_reuse] set and a block id other than -1 are filled with a
     per-block colour. *)
  let graph_to_dot graph =
    let b = Buffer.create 512 in
    Buffer.add_string b "digraph CG {\nnode [shape=record];\n";
    iter_in_edges (fun u v -> Printf.bprintf b "%i -> %i;\n" (id u) (id v)) graph.output;
    iter_ancestors
      (fun n ->
        let svs = shape_or_value n in
        let b_id = get_block_id n in
        Printf.bprintf
          b
          "%i [ label=\"{{#%i | { %s | %s }} | r:%i; %s; b:%i }\""
          (id n)
          (id n)
          (name n)
          (op_to_str (attr n).op)
          (refnum n)
          svs
          b_id;
        if get_reuse n && b_id <> -1
        then (
          let col = _block_colour b_id in
          Printf.bprintf b "style=filled fillcolor=\"%s\"" col);
        Buffer.add_string b "];\n")
      graph.output;
    Buffer.add_char b '}';
    Buffer.contents b

  (* Render every edge reachable from [graph.output] as a "u -> v" line.
     The source strings are first padded with [Owl_utils.pad_strings] to the
     length of the longest one. A [Buffer] keeps the cost linear in the size
     of the output; accumulating with sprintf would be quadratic. *)
  let graph_to_trace graph =
    let u_nodes = Owl_utils_stack.make () in
    let v_nodes = Owl_utils_stack.make () in
    iter_in_edges
      (fun u v ->
        Owl_utils_stack.push u_nodes (node_to_str u);
        Owl_utils_stack.push v_nodes (node_to_str v))
      graph.output;
    let u_strings = Owl_utils_stack.to_array u_nodes in
    let v_strings = Owl_utils_stack.to_array v_nodes in
    let u_longest = Owl_utils.longest_string u_strings in
    let u_strings = Owl_utils.pad_strings `Right u_longest u_strings in
    (* both stacks are pushed in lockstep, so the arrays have equal length *)
    let buf = Buffer.create 256 in
    Array.iter2 (fun u v -> Printf.bprintf buf "%s -> %s\n" u v) u_strings v_strings;
    Buffer.contents buf

  (* Write [graph] to [fname] together with the element type [A.number], so
     that [load_graph] can detect a type mismatch. *)
  let save_graph graph fname =
    let data = graph, A.number in
    Owl_io.marshal_to_file data fname

  (* Read a graph written by [save_graph]. Fails with
     "load_graph: inconsistent type." if the stored element type differs
     from [A.number].
     NOTE(review): [Owl_io.marshal_from_file] presumably wraps
     [Stdlib.Marshal], which is not type-safe. A corrupted or untrusted file
     could crash the program before the [A.number] check runs, so only load
     files you produced yourself. *)
  let load_graph fname =
    let graph, num_typ = Owl_io.marshal_from_file fname in
    if num_typ <> A.number
    then failwith "load_graph: inconsistent type."
    else graph, num_typ

  (* Collect every ancestor of [output] whose operator is a random
     variable. *)
  let collect_rvs output =
    let stack = Owl_utils_stack.make () in
    Owl_graph.iter_ancestors
      (fun v ->
        let op_typ = get_operator v in
        if is_random_variable op_typ then Owl_utils_stack.push stack v)
      output;
    Owl_utils_stack.to_array stack

  (* Apply [invalidate_graph] to each random-variable node recorded in
     [graph.random]. *)
  let invalidate_rvs graph = Array.iter invalidate_graph graph.random

  (* core graph functions *)

  (* Build a [graph] called [name] from [input] and [output].
     Side effects on the nodes:
     - every output is marked non-reusable ([set_reuse x false]);
     - the ancestors of [output] are frozen ([freeze_ancestors]).
     Inputs and outputs are registered in [htbl] under their name, or under
     "n#<id>" when the name is empty. The io pairing starts empty, and the
     random-variable nodes are collected up front.
     Fails with [Owl_exception.INVALID_ARGUMENT] (through
     [Owl_exception.check]/[verify]) if an input is not a variable, or if
     two distinct input/output nodes share a name. *)
  let make_graph ~input ~output name =
    (* check all the inputs must be variables *)
    let all_vars = Array.for_all is_var input in
    let exn = Owl_exception.INVALID_ARGUMENT "inputs must be variables" in
    Owl_exception.(check all_vars exn);
    (* set outputs' memory as not reusable *)
    Array.iter (fun x -> set_reuse x false) output;
    (* create hash table to store input/output names *)
    let input_output = Array.append input output in
    let htbl = Hashtbl.create (Array.length input_output) in
    (* add nodes' name into the hash table *)
    Array.iter
      (fun x ->
        let n_name = Owl_graph.name x in
        let x_name = if n_name = "" then Printf.sprintf "n#%i" (id x) else n_name in
        (* names must be unique across inputs and outputs; only the very same
           node may appear twice *)
        match Hashtbl.find_opt htbl x_name with
        | None -> Hashtbl.add htbl x_name x
        | Some saved_node ->
          let error () =
            let s = Printf.sprintf "node name %s is not unique" x_name in
            Owl_exception.INVALID_ARGUMENT s
          in
          (* reject a genuine name clash first, so the warning below is only
             logged when the same node really occurs twice *)
          Owl_exception.(verify (saved_node == x) error);
          Owl_log.warn "nodes are both input and output: %s" (node_to_str x))
      input_output;
    (* freeze the graph to avoid memory leak *)
    freeze_ancestors output;
    (* empty io pairing by default *)
    let iopair = [||] in
    let iosafe = [||] in
    (* collect all the random variables *)
    let random = collect_rvs output in
    (* create a device dependent field *)
    let device = make_device () in
    (* return the graph record *)
    { name; input; output; iopair; iosafe; random; htbl; device }

  let get_inputs x = x.input

  let get_outputs x = x.output

  (* manipulate input and output pairs *)

  (* First value of node [x], converted with [value_to_arr]. Fails with
     INVALID_ARGUMENT "input values do not exist." if [x] holds no value. *)
  let get_node_arr_val x =
    let value = get_value x in
    let exn = Owl_exception.INVALID_ARGUMENT "input values do not exist." in
    Owl_exception.check (Array.length value > 0) exn;
    value_to_arr value.(0)

  (* First value of node [x], converted with [value_to_elt]. Fails with
     INVALID_ARGUMENT "input values do not exist." if [x] holds no value. *)
  let get_node_elt_val x =
    let value = get_value x in
    let exn = Owl_exception.INVALID_ARGUMENT "input values do not exist." in
    Owl_exception.check (Array.length value > 0) exn;
    value_to_elt value.(0)

  (* Replace the value of node [x] with the single value [v]. *)
  let set_node_arr_val x v = set_value x [| v |]

  (* Replace the value of node [x] with the single value [v]. *)
  let set_node_elt_val x v = set_value x [| v |]

  (* Decide whether output [o] may be written back into input [i] with
     [unsafe_assign_arr] (see [update_iopair]). The function walks the
     descendants of [i] while tracking the largest [refnum] seen so far. The
     pair is unsafe as soon as [o] has been visited and that maximum exceeds
     1. The early exit uses an explicit local exception rather than
     [assert]: with [-noassert] an assertion would vanish and every pair
     would wrongly be reported safe. *)
  let is_iopair_safe i o =
    let exception Unsafe_pair in
    let pass_by_o = ref false in
    let branching = ref 0 in
    try
      Owl_graph.iter_descendants
        (fun v ->
          branching := Stdlib.max !branching (refnum v);
          if v == o then pass_by_o := true;
          if !branching > 1 && !pass_by_o then raise Unsafe_pair)
        [| i |];
      true
    with
    | Unsafe_pair -> false

  (* Pair [input.(k)] with [output.(k)] for loop-back and precompute, for
     each pair, whether the in-place assignment is safe. Fails with
     INVALID_ARGUMENT if the two arrays differ in length. *)
  let make_iopair graph input output =
    let input_len = Array.length input in
    let output_len = Array.length output in
    let is_equal = input_len = output_len in
    let error () =
      let s =
        Printf.sprintf
          "input (%i) and output (%i) must have equal length."
          input_len
          output_len
      in
      Owl_exception.INVALID_ARGUMENT s
    in
    Owl_exception.verify is_equal error;
    let iopair = Array.map2 (fun i o -> i, o) input output in
    let iosafe = Array.map2 (fun i o -> is_iopair_safe i o) input output in
    graph.iopair <- iopair;
    graph.iosafe <- iosafe

  (* Copy each paired output's current value into its input node.
     - Array nodes use [unsafe_assign_arr] when [graph.iosafe] marked the
       pair safe, and [assign_arr] otherwise.
     - Scalar nodes use [assign_elt]. *)
  let update_iopair graph =
    Array.iteri
      (fun idx (i, o) ->
        if is_node_arr i
        then (
          let o_val = get_node_arr_val o in
          let i_arr = node_to_arr i in
          (* make sure the original data will never be modified. *)
          if graph.iosafe.(idx)
          then unsafe_assign_arr i_arr o_val
          else assign_arr i_arr o_val)
        else (
          let i_elt = node_to_elt i in
          let o_val = get_node_elt_val o in
          assign_elt i_elt o_val))
      graph.iopair

  (* Keep only the (input, output) pairs whose input has a non-zero [degree],
     returned as two separate arrays. *)
  let remove_unused_iopair input output =
    Owl_utils_array.filter2_split (fun i _o -> degree i <> 0) input output

  (* Set the value of every input node [v] to [f v]. *)
  let init_inputs f graph = Array.iter (fun v -> set_value v [| f v |]) graph.input

  (* Run [optimise_nodes] over the graph's output nodes. *)
  let optimise graph = optimise_nodes graph.output
end
(* Make functor ends *)