Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file owl_utils_infer_shape.ml
src/base/misc/owl_utils_infer_shape.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Owl_types

(* This module is for calculating the shape of an ndarray given inputs. *)

(* [require_broadcasting shape_x shape_y] brings both shapes to a common rank
   with [Owl_utils_array.align `Left 1] and returns [true] when the aligned
   shapes still differ, i.e. an elementwise operation needs broadcasting. *)
let require_broadcasting shape_x shape_y =
  let aligned_x, aligned_y = Owl_utils_array.align `Left 1 shape_x shape_y in
  (* structural comparison of the contents, not physical equality *)
  not (aligned_x = aligned_y)


(* [calc_conv2d_output_shape padding input_cols input_rows kernel_cols
   kernel_rows row_stride col_stride] returns [(output_cols, output_rows)] of
   a [conv2d]. For each spatial dimension:
   - [SAME]:  ceil (input / stride)
   - [VALID]: ceil ((input - kernel + 1) / stride)
   The division is done in floating point and rounded up with [ceil]. *)
let calc_conv2d_output_shape
    padding
    input_cols
    input_rows
    kernel_cols
    kernel_rows
    row_stride
    col_stride
  =
  let out_len input kernel stride =
    let numerator =
      match padding with
      | SAME -> float_of_int input
      | VALID -> float_of_int input -. float_of_int kernel +. 1.
    in
    numerator /. float_of_int stride |> ceil |> int_of_float
  in
  let cols = out_len input_cols kernel_cols col_stride in
  let rows = out_len input_rows kernel_rows row_stride in
  cols, rows


(* [calc_transpose_conv2d_output_shape padding input_cols input_rows
   kernel_cols kernel_rows row_stride col_stride] returns
   [(output_cols, output_rows)] of a [transpose_conv2d]. For each spatial
   dimension:
   - [SAME]:  input * stride
   - [VALID]: input * stride + max (kernel - stride) 0 *)
let calc_transpose_conv2d_output_shape
    padding
    input_cols
    input_rows
    kernel_cols
    kernel_rows
    row_stride
    col_stride
  =
  let out_len input kernel stride =
    let scaled = input * stride in
    match padding with
    | SAME -> scaled
    | VALID -> scaled + max (kernel - stride) 0
  in
  out_len input_cols kernel_cols col_stride, out_len input_rows kernel_rows row_stride


(* [calc_conv2d_padding input_cols input_rows kernel_cols kernel_rows
   output_cols output_rows row_stride col_stride] returns
   [(pad_top, pad_left, pad_bottom, pad_right)].
   The total padding along an axis is
   max ((output - 1) * stride + kernel - input) 0; the leading side
   (top/left) receives half of it rounded down and the trailing side
   (bottom/right) the remainder. *)
let calc_conv2d_padding
    input_cols
    input_rows
    kernel_cols
    kernel_rows
    output_cols
    output_rows
    row_stride
    col_stride
  =
  let total input kernel output stride =
    Stdlib.max (((output - 1) * stride) + kernel - input) 0
  in
  let halves n =
    let lead = n / 2 in
    lead, n - lead
  in
  let pad_top, pad_bottom = halves (total input_rows kernel_rows output_rows row_stride) in
  let pad_left, pad_right = halves (total input_cols kernel_cols output_cols col_stride) in
  pad_top, pad_left, pad_bottom, pad_right


(*
[calc_conv1d_output_shape] delegates to the 2d version, treating the input
   as a single row with a 1-row kernel and a row stride of 1, and keeps only
   the resulting column count. *)
let calc_conv1d_output_shape padding input_cols kernel_cols col_stride =
  (* argument order: padding input_cols input_rows kernel_cols kernel_rows
     row_stride col_stride *)
  fst (calc_conv2d_output_shape padding input_cols 1 kernel_cols 1 1 col_stride)


(* [calc_transpose_conv1d_output_shape] likewise delegates to the 2d version
   with a single input row, a 1-row kernel and a row stride of 1, and keeps
   only the resulting column count. *)
let calc_transpose_conv1d_output_shape padding input_cols kernel_cols col_stride =
  fst (calc_transpose_conv2d_output_shape padding input_cols 1 kernel_cols 1 1 col_stride)


(* [calc_conv3d_output_shape padding input_cols input_rows input_dpts
   kernel_cols kernel_rows kernel_dpts row_stride col_stride dpt_stride]
   returns [(output_cols, output_rows, output_dpts)] of a [conv3d].
   For each spatial dimension:
   - [SAME]:  ceil (input / stride)
   - [VALID]: ceil ((input - kernel + 1) / stride)
   The division is done in floating point and rounded up with [ceil]. *)
let calc_conv3d_output_shape
    padding
    input_cols
    input_rows
    input_dpts
    kernel_cols
    kernel_rows
    kernel_dpts
    row_stride
    col_stride
    dpt_stride
  =
  let out_len input kernel stride =
    let extent =
      match padding with
      | SAME -> float_of_int input
      | VALID -> float_of_int input -. float_of_int kernel +. 1.
    in
    int_of_float (ceil (extent /. float_of_int stride))
  in
  let cols = out_len input_cols kernel_cols col_stride in
  let rows = out_len input_rows kernel_rows row_stride in
  let dpts = out_len input_dpts kernel_dpts dpt_stride in
  cols, rows, dpts


(* [calc_transpose_conv3d_output_shape] calculates the output shape of
   [transpose_conv3d] given input, kernel and stride.
*)

(* Per spatial dimension of [calc_transpose_conv3d_output_shape]:
   - [SAME]:  input * stride
   - [VALID]: input * stride + max (kernel - stride) 0
   The result is [(output_cols, output_rows, output_dpts)]. *)
let calc_transpose_conv3d_output_shape
    padding
    input_cols
    input_rows
    input_dpts
    kernel_cols
    kernel_rows
    kernel_dpts
    row_stride
    col_stride
    dpt_stride
  =
  let out_len input kernel stride =
    let scaled = input * stride in
    match padding with
    | SAME -> scaled
    | VALID -> scaled + max (kernel - stride) 0
  in
  ( out_len input_cols kernel_cols col_stride
  , out_len input_rows kernel_rows row_stride
  , out_len input_dpts kernel_dpts dpt_stride )


(* [calc_conv3d_padding ...] calculates the padding along width, height and
   depth, returning
   [(pad_top, pad_left, pad_shallow, pad_bottom, pad_right, pad_deep)].
   The total padding along an axis is
   max ((output - 1) * stride + kernel - input) 0; the leading side
   (top/left/shallow) receives half of it rounded down and the trailing side
   (bottom/right/deep) the remainder. *)
let calc_conv3d_padding
    input_cols
    input_rows
    input_depth
    kernel_cols
    kernel_rows
    kernel_depth
    output_cols
    output_rows
    output_depth
    row_stride
    col_stride
    depth_stride
  =
  let total input kernel output stride =
    Stdlib.max (((output - 1) * stride) + kernel - input) 0
  in
  let halves n =
    let lead = n / 2 in
    lead, n - lead
  in
  let pad_top, pad_bottom = halves (total input_rows kernel_rows output_rows row_stride) in
  let pad_left, pad_right = halves (total input_cols kernel_cols output_cols col_stride) in
  let pad_shallow, pad_deep =
    halves (total input_depth kernel_depth output_depth depth_stride)
  in
  pad_top, pad_left, pad_shallow, pad_bottom, pad_right, pad_deep


(* Various functions to calculate output shapes, used in the computation graph. *)

(* [broadcast1 s0 s1] aligns the two shapes with
   [Owl_utils_array.align `Left 1], checks with
   [Owl_exception.check ... NOT_BROADCASTABLE] that each pair of dimensions is
   compatible (one of them is 1, or they are equal), and returns the
   elementwise maximum as the output shape. *)
let broadcast1 s0 s1 =
  let sa, sb = Owl_utils_array.align `Left 1 s0 s1 in
  Array.iter2
    (fun a b -> Owl_exception.(check (a = 1 || b = 1 || a = b) NOT_BROADCASTABLE))
    sa
    sb;
  Array.map2 max sa sb


(* [broadcast2 s0 s1 s2] is the three-input counterpart of [broadcast1]: the
   shapes are aligned with [Owl_utils_array.align3 `Left 1], the output shape
   is the elementwise maximum, and every input dimension is checked with
   [Owl_exception.check ... NOT_BROADCASTABLE] to be either 1 or equal to that
   maximum. *)
let broadcast2 s0 s1 s2 =
  let sa, sb, sc = Owl_utils_array.align3 `Left 1 s0 s1 s2 in
  let sd = Owl_utils_array.map3 (fun a b c -> max a (max b c)) sa sb sc in
  let fits_dim x d = x = 1 || x = d in
  Owl_utils_array.iter4
    (fun a b c d ->
      Owl_exception.(check (fits_dim a d) NOT_BROADCASTABLE);
      Owl_exception.(check (fits_dim b d) NOT_BROADCASTABLE);
      Owl_exception.(check (fits_dim c d) NOT_BROADCASTABLE))
    sa
    sb
    sc
    sd;
  sd


(* [broadcast1_stride s0 s1]: no need to align the two shapes before passing
   them in.
*)

(* [broadcast1_stride s0 s1] aligns [s0] and [s1] with
   [Owl_utils_array.align `Left 1], computes the strides of each aligned
   shape with [Owl_utils_ndarray.calc_stride], and, wherever the two
   dimensions differ, zeroes the stride of [s0] if its dimension is 1 and
   the stride of [s1] otherwise. Returns [(stride_0, stride_1)].
   NOTE(review): incompatible shapes (differing dims, neither 1) are not
   rejected here; presumably callers validate them first (e.g. with
   [broadcast1]) — confirm. *)
let broadcast1_stride s0 s1 =
  let sa, sb = Owl_utils_array.align `Left 1 s0 s1 in
  let stride_0 = Owl_utils_ndarray.calc_stride sa in
  let stride_1 = Owl_utils_ndarray.calc_stride sb in
  Owl_utils_array.iter2i
    (fun i d0 d1 ->
      if d0 <> d1 then if d0 = 1 then stride_0.(i) <- 0 else stride_1.(i) <- 0)
    sa
    sb;
  stride_0, stride_1


(* [fold shape axis] returns a copy of [shape] in which dimension [axis]
   (after normalisation by [Owl_utils_ndarray.adjust_index]) is set to 1. *)
let fold shape axis =
  let d = Array.length shape in
  let a = Owl_utils_ndarray.adjust_index axis d in
  assert (a >= 0 && a < d);
  let _shape = Array.copy shape in
  _shape.(a) <- 1;
  _shape


(* [tile shape repeats] aligns [shape] and [repeats] with
   [Owl_utils_array.align `Left 1] and multiplies them elementwise.
   Every repeat count must be at least 1 (asserted). *)
let tile shape repeats =
  assert (Array.for_all (fun r -> r >= 1) repeats);
  let s, r = Owl_utils_array.align `Left 1 shape repeats in
  Owl_utils_array.map2 (fun a b -> a * b) s r


(* [repeat shape repeats] multiplies [shape] and [repeats] elementwise; the
   two arrays must have the same length and every repeat count must be at
   least 1 (both asserted). *)
let repeat shape repeats =
  assert (Array.for_all (fun r -> r >= 1) repeats);
  assert (Array.length shape = Array.length repeats);
  Owl_utils_array.map2 ( * ) shape repeats


(* [concatenate shape axis] takes an array of shapes ([shape]) and returns
   the shape obtained by concatenating them along [axis] (normalised by
   [Owl_utils_ndarray.adjust_index] against the rank of the first shape).
   All shapes must agree on every other dimension (asserted). The inputs are
   copied first, so the caller's arrays are not mutated. *)
let concatenate shape axis =
  let d = Array.length shape.(0) in
  let axis = Owl_utils_ndarray.adjust_index axis d in
  let shapes = Array.(map copy shape) in
  let shape0 = Array.copy shapes.(0) in
  (* zero the concatenation axis so the remaining dims can be compared *)
  shape0.(axis) <- 0;
  let acc_dim = ref 0 in
  Array.iter
    (fun shape1 ->
      acc_dim := !acc_dim + shape1.(axis);
      shape1.(axis) <- 0;
      assert (shape0 = shape1))
    shapes;
  shape0.(axis) <- !acc_dim;
  shape0


(* [split shape axis parts] returns one shape per element of [parts]: a copy
   of [shape] whose dimension [axis] (normalised by
   [Owl_utils_ndarray.adjust_index]) is replaced by that part's size. The
   sizes in [parts] must sum to that dimension (asserted). *)
let split shape axis parts =
  let d = Array.length shape in
  let a = Owl_utils_ndarray.adjust_index axis d in
  let e = Array.fold_left ( + ) 0 parts in
  assert (a < d);
  assert (e = shape.(a));
  Array.map
    (fun n ->
      let s = Array.copy shape in
      s.(a) <- n;
      s)
    parts


(* [slice shape slice_spec] infers the shape produced by slicing [shape]
   with [slice_spec], a list holding one int list per leading dimension:
   - [[]]                  keeps the whole dimension;
   - [[index]]             selects a single element (length 1);
   - [[start; stop]]       inclusive range, step 1 if [start <= stop] else -1;
   - [[start; stop; step]] inclusive range with an explicit step.
   Negative indices are counted from the end of the dimension. Dimensions not
   covered by [slice_spec] are kept unchanged.
   @raise Assert_failure if an index is out of range or the step direction
   does not match the order of [start] and [stop].
   @raise Failure if a per-dimension list has more than three elements. *)
let slice shape slice_spec =
  let infer_len orig_len start stop ?step () =
    let start = if start < 0 then orig_len + start else start in
    let stop = if stop < 0 then orig_len + stop else stop in
    let step =
      match step with
      | Some x -> x
      | None -> if start <= stop then 1 else -1
    in
    (* Both endpoints must lie in [0, orig_len) after normalisation. The lower
       bound used to be unchecked, so e.g. [-7] on a dimension of length 5
       (or [-1] on a dimension of length 0) yielded a bogus length instead of
       failing. *)
    let in_range i = i >= 0 && i < orig_len in
    assert (in_range start && in_range stop);
    assert ((start <= stop && step > 0) || (start > stop && step < 0));
    let step_abs = abs step in
    (* number of elements in the inclusive range [start, stop] *)
    (abs (stop - start) + step_abs) / step_abs
  in
  let shape' =
    List.mapi
      (fun i slicei ->
        let length = shape.(i) in
        let infer_len_i = infer_len length in
        match slicei with
        | [] -> length
        | [ index ] -> infer_len_i index index ()
        | [ start; stop ] -> infer_len_i start stop ()
        | [ start; stop; step ] -> infer_len_i start stop ~step ()
        | _ -> failwith "owl_utils_infer_shape: invalid slice specification")
      slice_spec
  in
  let s = Array.copy shape in
  List.iteri (fun i len -> s.(i) <- len) shape';
  s


(* [draw shape axis n] returns a copy of [shape] whose dimension [axis]
   (normalised by [Owl_utils_ndarray.adjust_index]) is set to [n]. *)
let draw shape axis n =
  let d = Array.length shape in
  let a = Owl_utils_ndarray.adjust_index axis d in
  let s = Array.copy shape in
  assert (a < d);
  s.(a) <- n;
  s


(* [reduce shape axis] returns a copy of [shape] in which every dimension
   listed in the array [axis] (each normalised by
   [Owl_utils_ndarray.adjust_index]) is set to 1. *)
let reduce shape axis =
  let d = Array.length shape in
  let a = Array.map (fun i -> Owl_utils_ndarray.adjust_index i d) axis in
  let s = Array.copy shape in
  Array.iter
    (fun i ->
      assert (i < d);
      s.(i) <- 1)
    a;
  s


(* [conv2d input_shape padding kernel_shape stride_shape] reads
   [input_shape] as [|batches; cols; rows; in_channel|], [kernel_shape] as
   [|kernel_cols; kernel_rows; in_channel; out_channel|] (the two
   [in_channel] values must agree, asserted) and [stride_shape] as
   [|col_stride; row_stride|]; returns
   [|batches; output_cols; output_rows; out_channel|]. *)
let conv2d input_shape padding kernel_shape stride_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let input_rows = input_shape.(2) in
  let in_channel = input_shape.(3) in
  let kernel_cols = kernel_shape.(0) in
  let kernel_rows = kernel_shape.(1) in
  let out_channel = kernel_shape.(3) in
  assert (in_channel = kernel_shape.(2));
  let col_stride = stride_shape.(0) in
  let row_stride = stride_shape.(1) in
  let output_cols, output_rows =
    calc_conv2d_output_shape
      padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      row_stride
      col_stride
  in
  [| batches; output_cols; output_rows; out_channel |]


(* [conv1d input_shape padding kernel_shape stride_shape] reads [input_shape]
   as [|batches; cols; in_channel|], [kernel_shape] as
   [|kernel_cols; in_channel; out_channel|] and [stride_shape] as
   [|col_stride|]. It embeds them into a 2d problem with a unit first spatial
   dimension, runs [conv2d], and returns [|batches; output_cols; out_channel|]. *)
let conv1d input_shape padding kernel_shape stride_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let in_channel = input_shape.(2) in
  let input_shape = [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel_shape.(0) in
  let out_channel = kernel_shape.(2) in
  assert (in_channel = kernel_shape.(1));
  let kernel_shape = [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride_shape.(0) in
  let stride_shape = [| 1; col_stride |] in
  let output_shape = conv2d input_shape padding kernel_shape stride_shape in
  let output_cols = output_shape.(2) in
  [| batches; output_cols; out_channel |]


(* [conv3d input_shape padding kernel_shape stride_shape] reads
   [input_shape] as [|batches; cols; rows; dpts; in_channel|],
   [kernel_shape] as [|kernel_cols; kernel_rows; kernel_dpts; in_channel;
   out_channel|] (the [in_channel] values must agree, asserted) and
   [stride_shape] as [|col_stride; row_stride; dpt_stride|]; returns
   [|batches; output_cols; output_rows; output_dpts; out_channel|]. *)
let conv3d input_shape padding kernel_shape stride_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let input_rows = input_shape.(2) in
  let input_dpts = input_shape.(3) in
  let in_channel = input_shape.(4) in
  let kernel_cols = kernel_shape.(0) in
  let kernel_rows = kernel_shape.(1) in
  let kernel_dpts = kernel_shape.(2) in
  let out_channel = kernel_shape.(4) in
  assert (in_channel = kernel_shape.(3));
  let col_stride = stride_shape.(0) in
  let row_stride = stride_shape.(1) in
  let dpt_stride = stride_shape.(2) in
  let output_cols, output_rows, output_dpts =
    calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  [| batches; output_cols; output_rows; output_dpts; out_channel |]


(* [dilated_conv2d input_shape padding kernel_shape stride_shape rate_shape]
   replaces each kernel spatial size [k] by its dilated extent
   [k + (k - 1) * (rate - 1)], with [rate_shape] read as
   [|rate_cols; rate_rows|], and then delegates to [conv2d]. *)
let dilated_conv2d input_shape padding kernel_shape stride_shape rate_shape =
  let kernel_cols = kernel_shape.(0) in
  let kernel_rows = kernel_shape.(1) in
  let in_channel = kernel_shape.(2) in
  let out_channel = kernel_shape.(3) in
  let rate_cols = rate_shape.(0) in
  let rate_rows = rate_shape.(1) in
  let col_up = kernel_cols + ((kernel_cols - 1) * (rate_cols - 1)) in
  let row_up = kernel_rows + ((kernel_rows - 1) * (rate_rows - 1)) in
  let kernel_shape' = [| col_up; row_up; in_channel; out_channel |] in
  conv2d input_shape padding kernel_shape' stride_shape


(* [dilated_conv1d input_shape padding kernel_shape stride_shape rate_shape]
   embeds the 1d problem into 2d exactly as [conv1d] does (with
   [rate_shape] read as [|col_rate|]), runs [dilated_conv2d], and returns
   [|batches; output_cols; out_channel|]. *)
let dilated_conv1d input_shape padding kernel_shape stride_shape rate_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let in_channel = input_shape.(2) in
  let input_shape = [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel_shape.(0) in
  let out_channel = kernel_shape.(2) in
  assert (in_channel = kernel_shape.(1));
  let kernel_shape = [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride_shape.(0) in
  let stride_shape = [| 1; col_stride |] in
  let col_rate = rate_shape.(0) in
  let rate_shape = [| 1; col_rate |] in
  let output_shape =
    dilated_conv2d input_shape padding kernel_shape stride_shape rate_shape
  in
  let output_cols = output_shape.(2) in
  [| batches; output_cols; out_channel |]


(* [dilated_conv3d input_shape padding kernel_shape stride_shape rate_shape]
   replaces each kernel spatial size [k] by [k + (k - 1) * (rate - 1)],
   with [rate_shape] read as [|rate_cols; rate_rows; rate_dpts|], and then
   delegates to [conv3d]. *)
let dilated_conv3d input_shape padding kernel_shape stride_shape rate_shape =
  let kernel_cols = kernel_shape.(0) in
  let kernel_rows = kernel_shape.(1) in
  let kernel_dpts = kernel_shape.(2) in
  let in_channel = kernel_shape.(3) in
  let out_channel = kernel_shape.(4) in
  let rate_cols = rate_shape.(0) in
  let rate_rows = rate_shape.(1) in
  let rate_dpts = rate_shape.(2) in
  let col_up = kernel_cols + ((kernel_cols - 1) * (rate_cols - 1)) in
  let row_up = kernel_rows + ((kernel_rows - 1) * (rate_rows - 1)) in
  let dpt_up = kernel_dpts + ((kernel_dpts - 1) * (rate_dpts - 1)) in
  let kernel_shape' = [| col_up; row_up; dpt_up; in_channel; out_channel |] in
  conv3d input_shape padding kernel_shape' stride_shape


(* [transpose_conv2d input_shape padding kernel_shape stride_shape] reads its
   arguments with the same layout as [conv2d] and returns
   [|batches; output_cols; output_rows; out_channel|] computed by
   [calc_transpose_conv2d_output_shape]. *)
let transpose_conv2d input_shape padding kernel_shape stride_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let input_rows = input_shape.(2) in
  let in_channel = input_shape.(3) in
  let kernel_cols = kernel_shape.(0) in
  let kernel_rows = kernel_shape.(1) in
  let out_channel = kernel_shape.(3) in
  assert (in_channel = kernel_shape.(2));
  let col_stride = stride_shape.(0) in
  let row_stride = stride_shape.(1) in
  let output_cols, output_rows =
    calc_transpose_conv2d_output_shape
      padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      row_stride
      col_stride
  in
  [| batches; output_cols; output_rows; out_channel |]


(* [transpose_conv1d input_shape padding kernel_shape stride_shape] embeds
   the 1d problem into 2d exactly as [conv1d] does, runs [transpose_conv2d],
   and returns [|batches; output_cols; out_channel|]. *)
let transpose_conv1d input_shape padding kernel_shape stride_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let in_channel = input_shape.(2) in
  let input_shape = [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel_shape.(0) in
  let out_channel = kernel_shape.(2) in
  assert (in_channel = kernel_shape.(1));
  let kernel_shape = [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride_shape.(0) in
  let stride_shape = [| 1; col_stride |] in
  let output_shape = transpose_conv2d input_shape padding kernel_shape stride_shape in
  let output_cols = output_shape.(2) in
  [| batches; output_cols; out_channel |]


(* [transpose_conv3d input_shape padding kernel_shape stride_shape] reads its
   arguments with the same layout as [conv3d] and returns
   [|batches; output_cols; output_rows; output_dpts; out_channel|] computed by
   [calc_transpose_conv3d_output_shape]. *)
let transpose_conv3d input_shape padding kernel_shape stride_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let input_rows = input_shape.(2) in
  let input_dpts = input_shape.(3) in
  let in_channel = input_shape.(4) in
  let kernel_cols = kernel_shape.(0) in
  let kernel_rows = kernel_shape.(1) in
  let kernel_dpts = kernel_shape.(2) in
  let out_channel = kernel_shape.(4) in
  assert (in_channel = kernel_shape.(3));
  let col_stride = stride_shape.(0) in
  let row_stride = stride_shape.(1) in
  let dpt_stride = stride_shape.(2) in
  let output_cols, output_rows, output_dpts =
    calc_transpose_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  [| batches; output_cols; output_rows; output_dpts; out_channel |]


(* [pool2d input_shape padding kernel_shape stride_shape] reads
   [input_shape] as [|batches; cols; rows; in_channel|], [kernel_shape] as
   [|kernel_cols; kernel_rows|] and [stride_shape] as
   [|col_stride; row_stride|]; the channel count is preserved. Returns
   [|batches; output_cols; output_rows; in_channel|]. *)
let pool2d input_shape padding kernel_shape stride_shape =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let input_rows = input_shape.(2) in
  let in_channel = input_shape.(3) in
  let kernel_cols = kernel_shape.(0) in
  let kernel_rows = kernel_shape.(1) in
  let col_stride = stride_shape.(0) in
  let row_stride = stride_shape.(1) in
  let output_cols, output_rows =
    calc_conv2d_output_shape
      padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      row_stride
      col_stride
  in
  [| batches; output_cols; output_rows; in_channel |]


(* [upsampling2d input_shape size] multiplies the cols and rows of
   [|batches; cols; rows; in_channel|] by [size.(0)] and [size.(1)]. *)
let upsampling2d input_shape size =
  let batches = input_shape.(0) in
  let input_cols = input_shape.(1) in
  let input_rows = input_shape.(2) in
  let in_channel = input_shape.(3) in
  let col_size = size.(0) in
  let row_size = size.(1) in
  [| batches; input_cols * col_size; input_rows * row_size; in_channel |]


(* [transpose input_shape axis] permutes [input_shape]: output dimension [i]
   is [input_shape.(axis.(i))]. *)
let transpose input_shape axis = Array.map (fun j -> input_shape.(j)) axis

(* [dot x_shape y_shape] returns [|x_shape.(0); y_shape.(1)|].
   NOTE(review): the inner dimensions are not checked for agreement here. *)
let dot x_shape y_shape = [| x_shape.(0); y_shape.(1) |]

(* [onehot input_shape depth] appends [depth] as a new trailing dimension. *)
let onehot input_shape depth = Array.append input_shape [| depth |]

(* ends here *)