
Additional tensor info beside its shape

fiqas opened this issue 1 year ago • 5 comments

Is your feature request related to a problem? Please describe.

I'm looking for a transformer visualiser that prints additional information about input and output tensors besides just their shape.

Describe the solution you'd like In the visualisation where the shape of a tensor is listed, I'd like additional statistics to be printed, such as:

  • tensor datatype
  • % of zero elements (sparsity level)

If you could provide this functionality or point me towards the parts of the code that would need to be modified, I'd appreciate it.

It would be especially helpful for researching large LLMs and their optimisation.
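For concreteness, both requested statistics are cheap to compute from a tensor. A minimal sketch (`tensor_stats` is a hypothetical helper, not part of torchview):

```python
import torch

def tensor_stats(x: torch.Tensor) -> dict:
    """Return the datatype and sparsity level of a tensor."""
    return {
        "dtype": str(x.dtype),
        # fraction of elements that are exactly zero
        "sparsity": (x == 0).float().mean().item(),
    }

t = torch.tensor([[0.0, 1.0], [0.0, 2.0]], dtype=torch.float16)
stats = tensor_stats(t)  # {'dtype': 'torch.float16', 'sparsity': 0.5}
```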

fiqas avatar Apr 24 '25 15:04 fiqas

The commit edbe1fae97d94bb413c8fecace4f28b7ef9a59b0 seems relevant, but it does not keep a record of tensor attributes on TensorNode.

If we add an attributes field to the TensorNode object during traversal, we can keep a record of the tensor's attributes.
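Something along these lines; note this is a simplified stand-in for torchview's actual TensorNode class, populated at construction time:

```python
import torch

class TensorNode:
    """Simplified sketch of a node carrying an extra `attributes` dict."""

    def __init__(self, tensor: torch.Tensor):
        self.tensor_id = id(tensor)
        self.tensor_shape = tuple(tensor.shape)
        # record extra tensor attributes during traversal
        self.attributes = {
            "dtype": str(tensor.dtype),
            "requires_grad": tensor.requires_grad,
        }

node = TensorNode(torch.zeros(2, 3))
# node.tensor_shape == (2, 3); node.attributes["dtype"] == "torch.float32"
```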

@ppetrovicTT What do you think?

mert-kurttutan avatar Apr 25 '25 04:04 mert-kurttutan

Hello! Thanks for including me @mert-kurttutan! However, I'm not sure if it's related.

When TensorNode is created in compute_node.py#L13, it uses the tensor as the source for tensor_id and tensor_shape. Perhaps tensor_dtype and tensor_sparsity could be set there as well, and then used in computation_graph.py#L389.

I can see a new flag being added, similar to collect_attributes, to control whether these parameters are displayed. But if the number of these parameters grows, it's worth considering some sort of Config structure.
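Such a Config structure could be as simple as a dataclass of display toggles; this is a hypothetical sketch, with invented field names, not an existing torchview API:

```python
from dataclasses import dataclass

@dataclass
class TensorInfoConfig:
    """Hypothetical bundle of flags controlling which tensor stats to show."""
    show_dtype: bool = False
    show_sparsity: bool = False

cfg = TensorInfoConfig(show_dtype=True)
# cfg.show_dtype is True; cfg.show_sparsity stays False by default
```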

ppetrovicTT avatar Apr 25 '25 09:04 ppetrovicTT

@ppetrovicTT I was thinking more about collecting the attributes of the tensor as a Python object, e.g. the dtype attribute can be used to get the dtype of a torch tensor. That seems like low-hanging fruit.

But you are right that we need different logic to check sparsity and other info that is not obtainable just by looking at attributes.

mert-kurttutan avatar Apr 25 '25 10:04 mert-kurttutan

Introducing an API just for the specific example of sparsity checking seems like bad design. A more general and principled approach is to accept an input (a function pointer) that specifies what to compute from each tensor during traversal.

e.g.

def tensor_sparsity_compute(x: Tensor) -> float:
  # your implementation, e.g. the fraction of zero elements:
  return (x == 0).float().mean().item()

draw_graph_input = (tensor_sparsity_compute,)

draw_graph(
  ...
  tensor_info_collector=draw_graph_input,
)

Then, the functions inside can be used to compute various values that get attached to the TensorNode (e.g. to an attribute dictionary property of TensorNode).

Regarding implementation, in addition to what @ppetrovicTT said, you would need to propagate these functions and compute their values for each TensorNode during its initialization.
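The propagation step could look roughly like this; `collect_info` and the collector functions are illustrative names, not existing torchview code:

```python
import torch

def sparsity(x: torch.Tensor) -> float:
    # fraction of zero elements
    return (x == 0).float().mean().item()

def dtype_name(x: torch.Tensor) -> str:
    return str(x.dtype)

def collect_info(tensor: torch.Tensor, collectors: tuple) -> dict:
    """Apply each user-supplied collector to the tensor during traversal,
    keying results by function name for the node's attribute dictionary."""
    return {fn.__name__: fn(tensor) for fn in collectors}

info = collect_info(torch.zeros(4), (sparsity, dtype_name))
# info == {'sparsity': 1.0, 'dtype_name': 'torch.float32'}
```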

mert-kurttutan avatar Apr 25 '25 10:04 mert-kurttutan

I definitely agree with adding some functionality that would allow a user to extend the visualisation, like supplying your own sparsity measurement or even different colouring depending on needs. JSON would be a good format for specifying that configuration.

At a minimum, showing statistics such as the tensor's dtype would be really useful, with mixed precision becoming state-of-the-art in LLMs.

fiqas avatar Apr 25 '25 12:04 fiqas