pytorchviz
Add variable names to nodes
This is neat, and by solving an NP-hard problem in my head I can tell which node corresponds to which tensor in my code. Could you add the name of each tensor to its node in the graph? Perhaps using https://pypi.org/project/varname/ or by modifying Tensor.__init__ to extract variable names from declarations in the traceback, then saving them in grad_fn.
Hi, I am not familiar with the varname package. But it doesn't seem to be doing what you want here? Or at least I can't manage to get it to return the name from other scopes that are not direct parents. Would you have a code sample showing how this would work?
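For context, the core behaviour of varname is to inspect the calling frame's source and recover the name the return value is assigned to. A minimal sketch of that, assuming the PyPI package is installed:

    from varname import varname

    def make_thing():
        # varname() looks one frame up and parses the assignment statement there
        return varname()

    weights = make_thing()
    print(weights)  # prints 'weights'

Whether this resolves names from frames deeper than the direct caller is exactly the open question below.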
Completely untested:
    import torch
    import varname

    def nameTensors():
        # All tensors defined from here on out will carry a name around in their grad_fn.
        oldinit = torch.Tensor.__init__
        def newinit(self, *args, **kwargs):
            oldinit(self, *args, **kwargs)
            if self.grad_fn is not None:  # leaf tensors created this way have no grad_fn
                self.grad_fn.name = varname.varname(ignore=torch)
        torch.Tensor.__init__ = newinit
The problem is that most of the operations happen in C++, so the Python Tensor.__init__ is not actually called :/
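This is easy to verify. A quick check, assuming a recent PyTorch where torch.Tensor is the Python-level subclass (so its __init__ can be patched at all):

    import torch

    init_calls = []
    orig_init = torch.Tensor.__init__

    def counting_init(self, *args, **kwargs):
        init_calls.append(type(self))
        orig_init(self, *args, **kwargs)

    torch.Tensor.__init__ = counting_init

    t = torch.ones(3) * 2   # factory functions and ops allocate tensors in C++
    print(len(init_calls))  # typically 0: the Python __init__ was never invoked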
Hmm. How about something like this, then? Even more untested, if that's even possible.
    import torch
    import varname

    def nameTensors(module):
        # Wraps module (presumably torch) to have every function name every returned unnamed tensor.
        def wrap(func):
            def wrapped(*args, **kwargs):
                result = func(*args, **kwargs)
                if isinstance(result, torch.Tensor) and result.grad_fn is not None:
                    if not hasattr(result.grad_fn, 'name'):
                        result.grad_fn.name = varname.varname(ignore=torch)
                return result
            return wrapped
        for name, func in list(module.__dict__.items()):  # copy: the dict is mutated below
            if callable(func):
                module.__dict__[name] = wrap(func)
        return module
Usage:

    nameTensors(torch)
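A hedged end-to-end sketch of what that might look like (untested, like the above; note that on recent PyTorch autograd nodes already expose a built-in name() method and may reject arbitrary attribute assignment, in which case the writable grad_fn.metadata dict is a possible fallback):

    import torch

    nameTensors(torch)                       # wrap every callable in the torch namespace

    x = torch.ones(3, requires_grad=True)
    y = torch.sigmoid(x)                     # torch.sigmoid is now the wrapped version
    print(getattr(y.grad_fn, 'name', None))  # 'y', if varname resolved and the attribute stuck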
In support of this, adding variable names to nodes would be helpful for annotating exactly which tensors in the code are being saved for the backward pass.
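Related: on recent PyTorch (1.10+, an assumption here), torch.autograd.graph.saved_tensors_hooks lets you intercept exactly the tensors autograd saves for backward, which a naming scheme like the above could hook into:

    import torch
    from torch.autograd.graph import saved_tensors_hooks

    saved = []

    def pack(t):
        # record a reference to every tensor autograd saves for the backward pass
        saved.append(t)
        return t

    with saved_tensors_hooks(pack, lambda t: t):
        x = torch.randn(3, requires_grad=True)
        y = (x * torch.sigmoid(x)).sum()

    print(f"{len(saved)} tensors were saved for backward")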