
Model Graph

How to construct a graph of the model

This page describes in detail how to automatically construct clean, readable graphs of your model.

Example

A model_graph example is included in the plotting_examples.ipynb notebook (you can get all examples by running sapsan get_examples). A brief overview of how it works is below:

from sapsan.lib.estimator.cnn.cnn3d_estimator import CNN3d, CNN3dConfig
from sapsan.utils.plot import model_graph
from sapsan.lib.data import get_loader_shape

# assumes `loaders` already holds your data loaded into torch loaders
estimator = CNN3d(config = CNN3dConfig(),
                  loaders = loaders)

shape_x, shape_y = get_loader_shape(loaders)

model_graph(model = estimator.model, shape = shape_x)
Given shape_x = (8, 1, 8, 8, 8), the following graph is produced:

[Figure: cnn_model_graph]

Details

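shape: the shape of the input data, in the format [N, Cin, Db, Hb, Wb], where N is the number of samples, Cin the number of input channels, and Db, Hb, Wb the spatial depth, height, and width. You can either take it from the loader as shown above or provide your own, as long as Cin matches the number of channels of the data your model was initialized with.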
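To make the [N, Cin, Db, Hb, Wb] layout concrete, here is a minimal sketch of a check you could run before calling model_graph on a hand-made shape. The helper check_input_shape and the InputShape tuple are hypothetical illustrations, not part of Sapsan; the model's expected channel count is assumed to be known and passed in by you.

```python
from typing import NamedTuple, Sequence


class InputShape(NamedTuple):
    """Named view of a 5D input shape [N, Cin, Db, Hb, Wb] (hypothetical helper)."""
    n: int    # number of samples
    cin: int  # number of input channels
    db: int   # spatial depth
    hb: int   # spatial height
    wb: int   # spatial width


def check_input_shape(shape: Sequence[int], model_channels: int) -> InputShape:
    """Unpack a 5D shape and check that Cin matches the model's input channels.

    Hypothetical helper for illustration only; not part of Sapsan.
    """
    if len(shape) != 5:
        raise ValueError(f"expected 5 dimensions [N, Cin, Db, Hb, Wb], got {len(shape)}")
    # int() also accepts numpy integers or torch.Size entries
    parsed = InputShape(*(int(d) for d in shape))
    if parsed.cin != model_channels:
        raise ValueError(
            f"shape has Cin={parsed.cin}, but the model expects {model_channels} channel(s)"
        )
    return parsed


# The example shape from this page: 8 samples, 1 channel, 8x8x8 spatial size.
print(check_input_shape((8, 1, 8, 8, 8), model_channels=1))
# InputShape(n=8, cin=1, db=8, hb=8, wb=8)
```

A mismatched channel count or a 4D shape raises ValueError, which surfaces the problem before any graph tracing starts.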

transforms: adjust the graph to your liking. For example, a transform can combine several layers into a single box instead of displaying them separately. Please refer to the API of model_graph to see which transforms are available.

Info

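The order of transforms in the list matters! Transforms are applied one after another, so a longer fold pattern such as "Conv > MaxPool > Relu" must come before a shorter pattern it contains, such as "Conv > MaxPool"; otherwise the shorter pattern folds those layers first and the longer one can no longer match.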
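Why order matters can be seen with a toy model of folding. The sketch below is a simplified stand-in, assuming a model is a flat chain of op names: fold and apply_folds are hypothetical helpers, not hiddenlayer's Fold, which operates on graph nodes rather than a list. The default transforms in the API section below list the three-op fold before the two-op fold, which corresponds to the longest_first case here.

```python
from typing import List, Tuple


def fold(layers: List[str], pattern: str, name: str) -> List[str]:
    """Replace each non-overlapping run of ops matching `pattern` with `name`.

    Toy model on a flat list of op names (hypothetical, not hiddenlayer).
    """
    ops = [op.strip() for op in pattern.split(">")]
    out: List[str] = []
    i = 0
    while i < len(layers):
        if layers[i:i + len(ops)] == ops:
            out.append(name)  # collapse the matched run into one box
            i += len(ops)
        else:
            out.append(layers[i])
            i += 1
    return out


def apply_folds(layers: List[str], folds: List[Tuple[str, str]]) -> List[str]:
    # Folds run sequentially: each one sees the output of the previous one.
    for pattern, name in folds:
        layers = fold(layers, pattern, name)
    return layers


chain = ["Conv", "MaxPool", "Relu", "Conv", "MaxPool", "Linear"]
longest_first = [("Conv > MaxPool > Relu", "ConvPoolRelu"), ("Conv > MaxPool", "ConvPool")]
shortest_first = list(reversed(longest_first))

print(apply_folds(chain, longest_first))
# ['ConvPoolRelu', 'ConvPool', 'Linear']
print(apply_folds(chain, shortest_first))
# ['ConvPool', 'Relu', 'ConvPool', 'Linear']
```

With the shorter pattern first, every Conv > MaxPool pair is already folded, so the Conv > MaxPool > Relu pattern never matches and Relu stays as a separate box.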

Limitations

  • The model parameter must be a PyTorch, TensorFlow, or Keras (with the TensorFlow backend) model.

API for model_graph

sapsan.utils.plot.model_graph(model, shape: np.array, transforms)

Creates a graph of the ML model (requires graphviz to be installed). The method is based on hiddenlayer, originally written by Waleed Abdulla.

Parameters

model (object) - an initialized PyTorch or TensorFlow model

shape (np.array) - shape of the input array in the form [N, Cin, Db, Hb, Wb], where Cin matches the number of input channels of the model (e.g. Cin=1 for single-channel data)

transforms (list[methods]) - a list of hiddenlayer transforms to be applied (Fold, FoldId, Prune, PruneBranch, FoldDuplicates, Rename), defined in transforms.py. Default:

import sapsan.utils.hiddenlayer as hl
transforms = [
    hl.transforms.Fold("Conv > MaxPool > Relu", "ConvPoolRelu"),
    hl.transforms.Fold("Conv > MaxPool", "ConvPool"),
    hl.transforms.Prune("Shape"),
    hl.transforms.Prune("Constant"),
    hl.transforms.Prune("Gather"),
    hl.transforms.Prune("Unsqueeze"),
    hl.transforms.Prune("Concat"),
    hl.transforms.Rename("Cast", to="Input"),
    hl.transforms.FoldDuplicates()
]
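Besides folds, the default list prunes helper ops (Shape, Constant, Gather, ...) and renames Cast to Input. The sketch below is a toy illustration of what those two kinds of transform do to a small traced graph, assuming a graph is just a dict of node id to op type plus a list of edges. The prune and rename helpers are hypothetical and much simpler than hiddenlayer's Prune and Rename, which match patterns on real graph nodes.

```python
from typing import Dict, List, Set, Tuple

# Toy graph: (node id -> op type, list of directed edges). Hypothetical format.
Graph = Tuple[Dict[str, str], List[Tuple[str, str]]]


def prune(graph: Graph, op: str) -> Graph:
    """Drop every node whose op type is `op`, plus every edge touching it."""
    nodes, edges = graph
    removed: Set[str] = {nid for nid, t in nodes.items() if t == op}
    kept_nodes = {nid: t for nid, t in nodes.items() if nid not in removed}
    kept_edges = [(a, b) for a, b in edges if a not in removed and b not in removed]
    return kept_nodes, kept_edges


def rename(graph: Graph, op: str, to: str) -> Graph:
    """Relabel every node of op type `op` as `to`; edges are unchanged."""
    nodes, edges = graph
    return {nid: (to if t == op else t) for nid, t in nodes.items()}, edges


# Helper ops such as Shape and Constant typically sit on side branches
# feeding a Reshape, so pruning them leaves the main chain intact.
g: Graph = (
    {"n0": "Cast", "n1": "Conv", "n2": "Shape", "n3": "Constant",
     "n4": "Reshape", "n5": "Linear"},
    [("n0", "n1"), ("n1", "n4"), ("n1", "n2"), ("n2", "n4"),
     ("n3", "n4"), ("n4", "n5")],
)
for pruned_op in ("Shape", "Constant"):
    g = prune(g, pruned_op)
g = rename(g, "Cast", to="Input")
print(g)
# ({'n0': 'Input', 'n1': 'Conv', 'n4': 'Reshape', 'n5': 'Linear'}, [('n0', 'n1'), ('n1', 'n4'), ('n4', 'n5')])
```

In this toy, pruning removes nodes without reconnecting edges, so it only produces a clean result for ops on side branches, which is the situation the default Prune entries target.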

Return

SVG graph of a model

Return type

graphviz.Digraph object