package arrayjit
An array language compiler with multiple backends (CPU, CUDA), staged compilation
Install
Dune Dependency
Authors
Maintainers
Sources
0.5.2.tar.gz
md5=1f62613c37076ccb1c57a78c13a1a388
sha512=bccea3b2ad2cd6a96b1f03aaf8e127c800687a69191e5d09c7adf5e26c3bccd73f993eef91154a1ce2bcf4eeebf5bdb8d5372932018b4307515e8b6f5f4e94ab
doc/src/arrayjit.ppx_arrayjit/ppx_helper.ml.html
Source file ppx_helper.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
open Base
open Ppxlib

type li = longident

(** [collect_list accu e] flattens the cons spine of the list literal [e]
    ([e1 :: e2 :: ... :: tl]) and returns [List.rev accu] followed by the elements of the
    spine. A terminating [[]] is dropped; any other terminating expression is kept as the
    final element. *)
let rec collect_list accu = function
  | [%expr [%e? hd] :: [%e? tl]] -> collect_list (hd :: accu) tl
  | [%expr []] -> List.rev accu
  | expr -> List.rev (expr :: accu)

(** Human-readable description of one axis of an ndarray literal, used in error messages. *)
let dim_spec_to_string = function
  | `Input_dims dim -> "input (tuple) of dim " ^ Int.to_string dim
  | `Output_dims dim -> "output (list) of dim " ^ Int.to_string dim
  | `Batch_dims dim -> "batch (array) of dim " ^ Int.to_string dim

(** [ndarray_constant expr] interprets a nested literal as an ndarray constant. Tuples are
    input axes, lists (including the empty list [[]]) are output axes, and arrays are batch
    axes. The axis structure is read off the first element at every nesting level, and every
    other element is checked against it.

    Returns [(values, batch_dims, output_dims, input_dims)]:
    - [values] is an array expression holding the leaves in depth-first, left-to-right order.
      Integer literals are wrapped in [Float.of_int]; other non-axis leaves are kept as-is.
    - each [*_dims] list holds [eint] size expressions for the axes of that kind, outermost
      first.

    Malformed parts of the literal are replaced by error extension nodes located at the
    offending sub-expression, so the error is reported there. *)
let ndarray_constant expr =
  let loc = expr.pexp_loc in
  (* Traverse the backbone of the ndarray (the first element at every level) to collect the
     dimensions; the resulting list is innermost-first. *)
  let rec loop_dims accu = function
    | { pexp_desc = Pexp_tuple (exp :: _ as exps); _ } ->
        loop_dims (`Input_dims (List.length exps) :: accu) exp
    | { pexp_desc = Pexp_array (exp :: _ as exps); _ } ->
        loop_dims (`Batch_dims (List.length exps) :: accu) exp
    | { pexp_desc = Pexp_tuple []; _ } -> `Input_dims 0 :: accu
    | { pexp_desc = Pexp_array []; _ } -> `Batch_dims 0 :: accu
    | { pexp_desc = Pexp_construct ({ txt = Lident ("::" | "[]"); _ }, _); _ } as expr -> (
        let exps = collect_list [] expr in
        match exps with
        | exp :: _ -> loop_dims (`Output_dims (List.length exps) :: accu) exp
        (* Reached for the empty list literal [[]]: an empty output axis. *)
        | [] -> `Output_dims 0 :: accu)
    | _ -> accu
  in
  (* Outermost axis first. *)
  let dims_spec = Array.of_list_rev @@ loop_dims [] expr in
  let open Ast_builder.Default in
  (* Error node reporting that the axis at [loc] does not have the expected kind/size. *)
  let axis_mismatch ~loc ~got ~expected =
    pexp_extension ~loc
    @@ Location.error_extensionf ~loc
         "Arrayjit: ndarray literal axis mismatch, got %s, expected %s"
         (dim_spec_to_string got) (dim_spec_to_string expected)
  in
  (* Collects the leaves in reverse order into [accu], checking each axis at [depth] against
     [dims_spec.(depth)]. *)
  let rec loop_values depth accu expr =
    if depth >= Array.length dims_spec then
      (* All axes consumed: [expr] should be a scalar. *)
      match expr with
      | { pexp_desc = Pexp_constant (Pconst_float _); _ } -> expr :: accu
      | { pexp_desc = Pexp_constant (Pconst_integer _); pexp_loc = loc; _ } ->
          [%expr Float.of_int [%e expr]] :: accu
      | { pexp_desc = Pexp_tuple _; pexp_loc = loc; _ } ->
          (pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "Arrayjit: ndarray literal found input axis (tuple), expected number")
          :: accu
      | { pexp_desc = Pexp_array _; pexp_loc = loc; _ } ->
          (pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "Arrayjit: ndarray literal found batch axis (array), expected number")
          :: accu
      | { pexp_desc = Pexp_construct ({ txt = Lident ("::" | "[]"); _ }, _); pexp_loc = loc; _ }
        ->
          (pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "Arrayjit: ndarray literal found output axis (list), expected number")
          :: accu
      | expr -> expr :: accu (* it either computes a number, or becomes a type error *)
    else
      match expr with
      | { pexp_desc = Pexp_tuple exps; pexp_loc = loc; _ } -> (
          match dims_spec.(depth) with
          | `Input_dims dim when dim = List.length exps ->
              List.fold_left exps ~init:accu ~f:(loop_values (depth + 1))
          | expected ->
              axis_mismatch ~loc ~got:(`Input_dims (List.length exps)) ~expected :: accu)
      | { pexp_desc = Pexp_array exps; pexp_loc = loc; _ } -> (
          match dims_spec.(depth) with
          | `Batch_dims dim when dim = List.length exps ->
              List.fold_left exps ~init:accu ~f:(loop_values (depth + 1))
          | expected ->
              axis_mismatch ~loc ~got:(`Batch_dims (List.length exps)) ~expected :: accu)
      | { pexp_desc = Pexp_construct ({ txt = Lident ("::" | "[]"); _ }, _); pexp_loc = loc; _ }
        -> (
          let exps = collect_list [] expr in
          match dims_spec.(depth) with
          | `Output_dims dim when dim = List.length exps ->
              List.fold_left exps ~init:accu ~f:(loop_values (depth + 1))
          | expected ->
              axis_mismatch ~loc ~got:(`Output_dims (List.length exps)) ~expected :: accu)
      | { pexp_loc = loc; _ } ->
          (pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "Arrayjit: ndarray literal: expected an axis (tuple, list or array)")
          :: accu
  in
  let result = loop_values 0 [] expr in
  let values = { expr with pexp_desc = Pexp_array (List.rev result) } in
  (* Split the axes by kind; each list is built innermost-first and reversed below. *)
  let batch_dims, output_dims, input_dims =
    Array.fold dims_spec ~init:([], [], []) ~f:(fun (batch_dims, output_dims, input_dims) ->
      function
      | `Input_dims dim -> (batch_dims, output_dims, eint ~loc dim :: input_dims)
      | `Output_dims dim -> (batch_dims, eint ~loc dim :: output_dims, input_dims)
      | `Batch_dims dim -> (eint ~loc dim :: batch_dims, output_dims, input_dims))
  in
  (values, List.rev batch_dims, List.rev output_dims, List.rev input_dims)
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>