package owl-base

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file owl_algodiff_graph_convert.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# 1 "src/base/algodiff/owl_algodiff_graph_convert.ml"
module Make (Core : Owl_algodiff_core_sig.Sig) = struct
  open Core

  (* [_traverse_trace] and the [_convert_*] functions below turn the
     computation graph recorded in backward (reverse) mode into a
     human-readable format. You can write your own converter over the table
     returned by [_traverse_trace] to generate other formats. *)

  (* [_traverse_trace x] walks every node reachable from the roots in [x] and
     returns a hash table binding each visited node to [(id, op, prev)]:
     - [id]: an integer assigned in visiting order, starting from 0;
     - [op]: the operation name stored in a [DR] node's label, ["Const"] for
       [F] and [Arr] nodes, ["DF"] for [DF] nodes;
     - [prev]: the parent nodes stored in a [DR] node's label ([[]] for every
       other constructor).
     The table uses the generic (structural) [Hashtbl] equality, so values
     that compare equal are visited and numbered only once. *)
  let _traverse_trace x =
    let nodes = Hashtbl.create 512 in
    let index = ref 0 in
    (* Worklist walk: a node's parents are put in front of the remaining
       work, so the graph is explored depth-first. Both recursive calls are
       tail calls. *)
    let rec push = function
      | [] -> ()
      | hd :: tl ->
        (* skip nodes that have been visited before *)
        if Hashtbl.mem nodes hd
        then push tl
        else (
          let op, prev =
            match hd with
            | DR (_ap, _aa, (_, _, label), _af, _ai, _) -> label
            | F _a -> "Const", []
            | Arr _a -> "Const", []
            | DF (_, _, _) -> "DF", []
          in
          Hashtbl.add nodes hd (!index, op, prev);
          incr index;
          push (prev @ tl))
    in
    push x;
    nodes


  (* [_convert_terminal_output nodes] renders a table built by
     [_traverse_trace] as plain text, one line per edge:
       "{ i:<id> o:<op> t:<type> } -> { i:<id> o:<op> t:<type> }"
     with the parent on the left, the dependent node on the right, and the
     type column produced by [type_info].
     The text is accumulated in a [Buffer] (instead of repeated [^]) so the
     cost stays linear in the size of the output.
     @raise Not_found if a parent is missing from [nodes] (cannot happen for
     tables built by [_traverse_trace], which inserts every parent it meets). *)
  let _convert_terminal_output nodes =
    let add_node v (v_id, v_op, v_prev) buf =
      let v_ts = type_info v in
      List.iter
        (fun u ->
          let u_id, u_op, _ = Hashtbl.find nodes u in
          let u_ts = type_info u in
          Printf.bprintf
            buf
            "{ i:%i o:%s t:%s } -> { i:%i o:%s t:%s }\n"
            u_id
            u_op
            u_ts
            v_id
            v_op
            v_ts)
        v_prev;
      buf
    in
    (* threading the buffer through [Hashtbl.fold] keeps the same node order
       as a string-accumulating fold would produce *)
    Hashtbl.fold add_node nodes (Buffer.create 1024) |> Buffer.contents


  (* [_convert_dot_output nodes] renders a table built by [_traverse_trace]
     as the body of a Graphviz digraph: first one "parent -> child;" edge
     statement per dependency, then one record-label statement per node.
     Nodes whose op is "Const" are additionally filled grey.
     NOTE(review): [deep_info v] is inserted into the record label without
     escaping; characters that are special in record labels ({ } | < > ")
     would break the layout if [deep_info] ever produced them — confirm.
     @raise Not_found if a parent is missing from [nodes]. *)
  let _convert_dot_output nodes =
    let add_edges _v (v_id, _v_op, v_prev) buf =
      List.iter
        (fun u ->
          let u_id, _u_op, _ = Hashtbl.find nodes u in
          Printf.bprintf buf "\t%i -> %i;\n" u_id v_id)
        v_prev;
      buf
    in
    let add_attrs v (v_id, v_op, _v_prev) buf =
      let style =
        if String.equal v_op "Const" then " fillcolor=gray, style=filled" else ""
      in
      Printf.bprintf
        buf
        "%i [ label=\"#%i | { %s | %s }\"%s ];\n"
        v_id
        v_id
        v_op
        (deep_info v)
        style;
      buf
    in
    (* all edges are emitted before any node attributes *)
    Buffer.create 1024
    |> Hashtbl.fold add_edges nodes
    |> Hashtbl.fold add_attrs nodes
    |> Buffer.contents


  (* [to_trace roots] returns the textual edge list of the graph reachable
     from [roots]; see [_convert_terminal_output] for the line format. *)
  let to_trace nodes = nodes |> _traverse_trace |> _convert_terminal_output

  (* [to_dot roots] returns a complete Graphviz document ("digraph CG") with
     record-shaped nodes for the graph reachable from [roots]. *)
  let to_dot nodes =
    nodes
    |> _traverse_trace
    |> _convert_dot_output
    |> Printf.sprintf "digraph CG {\nnode [shape=record];\n%s}"


  (* [pp_num formatter x] prints [type_info x]; usable with the [%a]
     conversion of [Format]. *)
  let pp_num formatter x = Format.pp_print_string formatter (type_info x)
end
OCaml

Innovation. Community. Security.