Additional tensor info besides its shape
**Is your feature request related to a problem? Please describe.**
I'm looking for a transformer visualiser that prints additional information about input and output tensors beside just shape.
**Describe the solution you'd like**
In the visualisation where the shape of a tensor is listed, I'd like additional statistics to be printed, such as:
- tensor datatype
- % of zero elements (sparsity level)

If you could provide this functionality or point me towards the parts of the code that would need to be modified, I'd appreciate it.
It would be especially helpful for researching large LLMs and their optimisation.
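For reference, both statistics are cheap to read off a tensor directly; a minimal sketch in plain PyTorch (no torchview API assumed):

```python
import torch

x = torch.zeros(4, 4)
x[0, 0] = 1.0

dtype = x.dtype  # tensor datatype, e.g. torch.float32
sparsity = (x == 0).float().mean().item()  # fraction of zero elements
```

So the open question is not how to compute these values, but where to attach them during graph traversal.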
The commit edbe1fae97d94bb413c8fecace4f28b7ef9a59b0 seems relevant, but it does not keep a record of the tensor's attributes on TensorNode.
If we add an attributes field to the TensorNode object during traversal, we can keep a record of tensor attributes.
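Concretely, the proposed attributes field could look like this simplified, hypothetical sketch (not the actual torchview TensorNode definition):

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class TensorNode:
    # simplified stand-in for torchview's TensorNode; only the new
    # `attributes` field is the proposal, the rest is illustrative
    tensor_id: int
    tensor_shape: tuple
    attributes: Dict[str, Any] = field(default_factory=dict)

node = TensorNode(tensor_id=1, tensor_shape=(2, 3))
node.attributes["dtype"] = "torch.float32"
node.attributes["sparsity"] = 0.0
```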
@ppetrovicTT What do you think?
Hello! Thanks for including me @mert-kurttutan! However, I'm not sure if it's related.
When TensorNode is created in compute_node.py#L13, it uses tensor as the source for tensor_id and tensor_shape. Perhaps tensor_dtype and tensor_sparsity can be set there as well and then used in computation_graph.py#L389
I can see a new flag being added in a similar way to collect_attributes to control whether these parameters are displayed or not. But if the number of these parameters grows, it's worth considering some sort of Config structure.
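As a rough illustration, such a Config structure could be a small dataclass of display flags (the names below are purely illustrative, not existing torchview API):

```python
from dataclasses import dataclass

@dataclass
class TensorInfoConfig:
    # hypothetical display config; one flag per statistic keeps the
    # draw_graph signature stable as new statistics are added
    show_dtype: bool = True
    show_sparsity: bool = False
    collect_attributes: bool = False

cfg = TensorInfoConfig(show_sparsity=True)
```

The advantage over individual keyword flags is that adding a new statistic later does not change the public function signature.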
@ppetrovicTT I was thinking more about collecting the attributes of the tensor as Python objects; e.g. the dtype attribute can be used to get the dtype of a torch tensor. That seems like low-hanging fruit.
But you are right that we need different logic to check sparsity and other info that is not obtainable just by looking at attributes.
Introducing an API just for the specific case of sparsity checking seems like bad design. A more general and principled approach is to accept an input (function pointer) that computes what to collect from each tensor during traversal.
e.g.

```python
from torch import Tensor

def tensor_sparsity_compute(x: Tensor) -> float:
    # your implementation, e.g. the fraction of zero elements:
    return (x == 0).float().mean().item()

draw_graph_input = (tensor_sparsity_compute,)
draw_graph(
    ...,
    tensor_info_collector=draw_graph_input,
)
```
Then, the functions in this tuple can be used to compute various values that get attached to the TensorNode (e.g. stored in an attribute dictionary property of TensorNode).
Regarding implementation, in addition to what @ppetrovicTT said, you would need to propagate these transformations and compute them for each TensorNode during its initialization.
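The propagation step could be as simple as a helper run at TensorNode creation time; a hedged sketch (apply_collectors is a hypothetical name, not existing code):

```python
import torch

def tensor_sparsity_compute(x: torch.Tensor) -> float:
    # fraction of zero elements
    return (x == 0).float().mean().item()

def apply_collectors(tensor, collectors):
    # run each user-supplied function on the tensor and return a dict
    # that could be merged into TensorNode's attribute dictionary
    return {fn.__name__: fn(tensor) for fn in collectors}

attrs = apply_collectors(torch.zeros(2, 2), (tensor_sparsity_compute,))
# attrs == {"tensor_sparsity_compute": 1.0}
```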
I definitely agree with some functionality that would allow a user to extend the visualisation, like supplying your own sparsity measurement or even different colouring depending on needs. JSON would be a good choice for selecting that.
At a minimum, showing statistics such as the dtype of a Tensor would be really useful, with mixed precision becoming state of the art in LLMs.