Module nujo.utils
nujo utils
Expand source code
''' Utility helpers for nujo.

Re-exports ComputationGraphPlotter as the public API of this package.
'''
from nujo.utils.computation_graph_plotter import ComputationGraphPlotter

__all__ = ['ComputationGraphPlotter']
Sub-modules
nujo.utils.computation_graph_plotter
Classes
class ComputationGraphPlotter (**kwargs)
-
Computation graph plotter.
Draws a nujo computation graph using graphviz.
Expand source code
class ComputationGraphPlotter:
    ''' Computation Graph Plotter

    Builds a graphviz ``Digraph`` of a nujo computation graph and can
    display it. Keyword arguments given to the constructor are passed
    unchanged to ``Digraph``.
    '''

    def __init__(self, **kwargs):
        self.computation_graph = Digraph(**kwargs)

    @staticmethod
    def get_color(node: _Node) -> str:
        ''' Return the fill colour used for ``node``.

        Tensors with children are 'lightblue', tensors without children
        are 'indianred1', and every other node is 'gold2'.
        '''
        if isinstance(node, Tensor):
            if len(node.children) > 0:
                return 'lightblue'
            return 'indianred1'
        return 'gold2'

    @staticmethod
    def get_shape(node: _Node) -> str:
        ''' Return 'box' for tensors and 'oval' for every other node. '''
        return 'box' if isinstance(node, Tensor) else 'oval'

    def create(self, root: _Node,
               display_values=False) -> 'ComputationGraphPlotter':
        ''' Add the computation graph reachable from ``root`` to the plot.

        Walks ``root`` and its descendants through ``.children``. For every
        parent/child pair it adds both nodes and an edge child -> parent.
        A root without children adds nothing.

        Parameters:
        -----------
        - root : node whose sub-graph is drawn
        - display_values : when True, nodes are named with ``str(node)``;
          otherwise they are named with ``repr(node)``. This applies at
          every depth of the graph.

        Returns:
        --------
        - self, so calls can be chained (e.g. ``.create(x).view()``)

        NOTE(review): graphviz identifies nodes by name, so distinct nodes
        whose str/repr coincide are drawn as a single node.
        '''
        def name_of(node: _Node) -> str:
            return str(node) if display_values else repr(node)

        graph = self.computation_graph
        # Explicit stack instead of recursion: deep graphs cannot hit the
        # recursion limit.
        stack = [root]
        # Identity-based, because nodes may not be hashable. This makes
        # shared sub-graphs expand once instead of once per parent.
        expanded = set()

        while stack:
            node = stack.pop()
            if id(node) in expanded:
                continue
            expanded.add(id(node))

            if len(node.children) == 0:
                continue

            node_name = name_of(node)
            graph.node(node_name,
                       color=self.get_color(node),
                       shape=self.get_shape(node),
                       style='filled')

            for child in node.children:
                child_name = name_of(child)
                graph.node(child_name,
                           color=self.get_color(child),
                           shape=self.get_shape(child),
                           style='filled')
                # One edge per occurrence, so e.g. x*x still shows two edges.
                graph.edge(child_name, node_name)
                stack.append(child)

        return self

    def view(self) -> None:
        ''' Show the plot via graphviz ``Digraph.view`` (renders, then opens). '''
        self.computation_graph.view()
Static methods
def get_color(node: nujo.autodiff._node._Node) -> str
-
Expand source code
@staticmethod
def get_color(node: _Node) -> str:
    ''' Pick the graphviz fill colour for ``node``.

    Non-tensor nodes get 'gold2'. Tensors get 'lightblue' when they have
    children and 'indianred1' otherwise.
    '''
    if not isinstance(node, Tensor):
        return 'gold2'
    has_children = len(node.children) > 0
    return 'lightblue' if has_children else 'indianred1'
def get_shape(node: nujo.autodiff._node._Node) -> str
-
Expand source code
@staticmethod
def get_shape(node: _Node) -> str:
    ''' Pick the graphviz shape for ``node``: 'box' for tensors, 'oval' otherwise. '''
    return 'box' if isinstance(node, Tensor) else 'oval'
Methods
def create(self, root: nujo.autodiff._node._Node, display_values=False) -> ComputationGraphPlotter
-
Expand source code
def create(self, root: _Node,
           display_values=False) -> 'ComputationGraphPlotter':
    ''' Add the computation graph reachable from ``root`` to the plot.

    Walks ``root`` and its descendants through ``.children``. For every
    parent/child pair it adds both nodes and an edge child -> parent.
    A root without children adds nothing.

    Parameters:
    -----------
    - root : node whose sub-graph is drawn
    - display_values : when True, nodes are named with ``str(node)``;
      otherwise they are named with ``repr(node)``. This applies at
      every depth of the graph.

    Returns:
    --------
    - self, so calls can be chained (e.g. ``.create(x).view()``)

    NOTE(review): graphviz identifies nodes by name, so distinct nodes
    whose str/repr coincide are drawn as a single node.
    '''
    def name_of(node: _Node) -> str:
        return str(node) if display_values else repr(node)

    graph = self.computation_graph
    # Explicit stack instead of recursion: deep graphs cannot hit the
    # recursion limit.
    stack = [root]
    # Identity-based, because nodes may not be hashable. This makes
    # shared sub-graphs expand once instead of once per parent.
    expanded = set()

    while stack:
        node = stack.pop()
        if id(node) in expanded:
            continue
        expanded.add(id(node))

        if len(node.children) == 0:
            continue

        node_name = name_of(node)
        graph.node(node_name,
                   color=self.get_color(node),
                   shape=self.get_shape(node),
                   style='filled')

        for child in node.children:
            child_name = name_of(child)
            graph.node(child_name,
                       color=self.get_color(child),
                       shape=self.get_shape(child),
                       style='filled')
            # One edge per occurrence, so e.g. x*x still shows two edges.
            graph.edge(child_name, node_name)
            stack.append(child)

    return self
def view(self) -> None
-
Expand source code
def view(self) -> None:
    ''' Show the plotted graph.

    Delegates to the ``view`` method of the underlying graphviz graph and
    discards whatever that call returns.
    '''
    graph = self.computation_graph
    graph.view()