Module nujo.utils

Utility helpers for nujo (currently a graphviz-based computation graph plotter).

Expand source code
''' nujo utils '''

from nujo.utils.computation_graph_plotter import ComputationGraphPlotter

# Public names exported by ``from nujo.utils import *``.
__all__ = ['ComputationGraphPlotter']

Sub-modules

nujo.utils.computation_graph_plotter

Classes

class ComputationGraphPlotter (**kwargs)

Computation Graph Plotter

Renders a nujo computation graph with graphviz.

Expand source code
class ComputationGraphPlotter:
    ''' Computation Graph Plotter

        Draws a nujo computation graph with graphviz. Every node reachable
        from a root through ``children`` is added, with an edge pointing
        from each child to the node that lists it as a child.

        Parameters:
        -----------
        - **kwargs : forwarded unchanged to ``Digraph``

    '''
    def __init__(self, **kwargs):
        self.computation_graph = Digraph(**kwargs)

    @staticmethod
    def get_color(node: _Node) -> str:
        ''' Fill colour for ``node``: 'lightblue' for a Tensor with
            children, 'indianred1' for a Tensor without children and
            'gold2' for any non-Tensor node.
        '''
        if isinstance(node, Tensor):
            if len(node.children) > 0:
                return 'lightblue'
            return 'indianred1'
        return 'gold2'

    @staticmethod
    def get_shape(node: _Node) -> str:
        ''' Node shape: 'box' for a Tensor, 'oval' for anything else. '''
        if isinstance(node, Tensor):
            return 'box'
        return 'oval'

    @staticmethod
    def _node_name(node: _Node, display_values: bool) -> str:
        ''' Graphviz node id: ``str(node)`` if display_values, else
            ``repr(node)``.
        '''
        return str(node) if display_values else repr(node)

    def _add_node(self, node: _Node, name: str) -> None:
        ''' Declare ``name`` in the graph, coloured/shaped by ``node``. '''
        self.computation_graph.node(name,
                                    color=self.get_color(node),
                                    shape=self.get_shape(node),
                                    style='filled')

    def create(self,
               root: _Node,
               display_values=False) -> 'ComputationGraphPlotter':
        ''' Add the sub-graph rooted at ``root`` to the plot.

            Parameters:
            -----------
            - root : node whose sub-graph (followed through ``children``)
              is drawn
            - display_values : if True every node in the sub-graph is
              labelled with ``str(node)``, otherwise with ``repr(node)``

            Returns:
            --------
            - self, so calls can be chained, e.g.
              ``ComputationGraphPlotter().create(z).view()``. A ``root``
              without children adds nothing but still returns self.
        '''
        # Explicit stack instead of recursion, so deep graphs cannot run
        # into Python's recursion limit.
        stack = [root]
        # ids of nodes already expanded: a node shared by several parents
        # gets an edge to each parent, but its own sub-graph is drawn once
        # instead of once per parent (which would duplicate edges).
        expanded = set()

        while stack:
            node = stack.pop()
            if id(node) in expanded:
                continue
            expanded.add(id(node))

            children = list(node.children)
            if not children:
                # Leaves are drawn by their parent; nothing to expand.
                continue

            node_name = self._node_name(node, display_values)
            self._add_node(node, node_name)
            for child in children:
                # Same labelling at every depth, so this child's id matches
                # the id it receives when it is expanded as a parent.
                child_name = self._node_name(child, display_values)
                self._add_node(child, child_name)
                self.computation_graph.edge(child_name, node_name)

            # Reversed so children are expanded in their listed order.
            stack.extend(reversed(children))

        return self

    def view(self) -> None:
        ''' Render the graph built so far via ``computation_graph.view()``. '''
        self.computation_graph.view()

Static methods

def get_color(node: nujo.autodiff._node._Node) -> str
Expand source code
@staticmethod
def get_color(node: _Node) -> str:
    ''' Choose the fill colour for ``node``.

        Non-Tensor nodes are 'gold2'; Tensors are 'lightblue' when they
        have children and 'indianred1' when they do not.
    '''
    if not isinstance(node, Tensor):
        return 'gold2'
    return 'lightblue' if len(node.children) > 0 else 'indianred1'
def get_shape(node: nujo.autodiff._node._Node) -> str
Expand source code
@staticmethod
def get_shape(node: _Node) -> str:
    ''' Tensors are drawn as boxes; every other node kind as an oval. '''
    return 'box' if isinstance(node, Tensor) else 'oval'

Methods

def create(self, root: nujo.autodiff._node._Node, display_values=False) -> ComputationGraphPlotter
Expand source code
def create(self,
           root: _Node,
           display_values=False) -> 'ComputationGraphPlotter':
    ''' Add the sub-graph rooted at ``root`` to the plot.

        Parameters:
        -----------
        - root : node whose sub-graph (followed through ``children``) is
          drawn
        - display_values : if True every node in the sub-graph is labelled
          with ``str(node)``, otherwise with ``repr(node)``

        Returns:
        --------
        - self, so calls can be chained (``.create(z).view()``). A ``root``
          without children adds nothing but still returns self.
    '''
    def name_of(node):
        # Same labelling at every depth, so a child's id matches the id it
        # receives when it is expanded as a parent later on.
        return str(node) if display_values else repr(node)

    def add_node(node, name):
        self.computation_graph.node(name,
                                    color=self.get_color(node),
                                    shape=self.get_shape(node),
                                    style='filled')

    # Explicit stack instead of recursion, so deep graphs cannot run into
    # Python's recursion limit.
    stack = [root]
    # ids of nodes already expanded: a node shared by several parents gets
    # an edge to each parent, but its own sub-graph is drawn only once.
    expanded = set()

    while stack:
        node = stack.pop()
        if id(node) in expanded:
            continue
        expanded.add(id(node))

        children = list(node.children)
        if not children:
            # Leaves are drawn by their parent; nothing to expand.
            continue

        node_name = name_of(node)
        add_node(node, node_name)
        for child in children:
            child_name = name_of(child)
            add_node(child, child_name)
            self.computation_graph.edge(child_name, node_name)

        # Reversed so children are expanded in their listed order.
        stack.extend(reversed(children))

    return self
def view(self) -> NoneType
Expand source code
def view(self) -> None:
    ''' Render and display the graph built so far by delegating to
        ``self.computation_graph.view()``.
    '''
    graph = self.computation_graph
    graph.view()