Recursively detect framework when multiple inheritance of nn.Module is used

Open • Damming opened this issue 4 years ago • 1 comment

Hi,

I found that the function detect_framework() in graph.py only detects classes that inherit directly from nn.Module (PyTorch style), so I wrote this:

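```python
def find_root_class(outer_class, all_classes):
    # Accumulate each class's direct bases, recursing only into the
    # first base until that base is `object`.
    if outer_class.__bases__[0] is object:
        return all_classes + outer_class.__bases__
    else:
        return find_root_class(outer_class.__bases__[0],
                               all_classes + outer_class.__bases__)
```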

Then detect_framework() could be rewritten as:

```python
def detect_framework(value):
    # Check the bases collected by find_root_class, not only the direct parent.
    classes = find_root_class(value.__class__, ())
    for c in classes:
        if c.__module__.startswith("torch"):
            return "torch"
        elif c.__module__.startswith("tensorflow"):
            return "tensorflow"
```

Hope this is useful.

Damming • Dec 06 '19 03:12

A more elegant solution has already been implemented in commit 294f8732b271cbdd6310c55bdf5ce855cbf61c75. However, it has not been merged yet (and unfortunately, it is probably not going to be merged any time soon). Nevertheless, you can already use the fixed version if you install hiddenlayer via git rather than via pip.

maxfrei750 • Dec 30 '19 20:12
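A minimal sketch of an MRO-based variant, offered as an illustration and not necessarily what commit 294f873 does: `inspect.getmro` lists a class and every one of its ancestors, following all bases. The proposed `find_root_class` recurses only into `__bases__[0]`, so a class whose PyTorch ancestor is reachable only through a later base (like `MixedNet` below) would be missed. `FakeTorchModule`, `Base`, `Net`, `Mixin`, and `MixedNet` are hypothetical stand-ins so the sketch runs without PyTorch installed; only the `__module__` string of `FakeTorchModule` imitates `torch.nn.Module`.

```python
import inspect


def detect_framework(value):
    """Return "torch" or "tensorflow" if any ancestor class of `value`
    is defined in that package, else None (illustrative sketch only)."""
    # inspect.getmro returns the class itself plus every ancestor in
    # method resolution order, following all bases, not just the first.
    for cls in inspect.getmro(type(value)):
        module = cls.__module__
        if module.startswith("torch"):
            return "torch"
        if module.startswith("tensorflow"):
            return "tensorflow"
    return None


# Hypothetical stand-in for torch.nn.Module: only its __module__ matters here.
class FakeTorchModule:
    __module__ = "torch.nn.modules.module"


class Base(FakeTorchModule):   # user-defined base class, one level down
    pass


class Net(Base):               # two levels below the "torch" class
    pass


class Mixin:                   # unrelated class used as the first base
    pass


class MixedNet(Mixin, Base):   # "torch" ancestor reachable only via the second base
    pass


print(detect_framework(Net()))       # torch
print(detect_framework(MixedNet()))  # torch
print(detect_framework(Mixin()))     # None
```

Walking the MRO removes the need for a recursive helper. The `startswith` checks are kept from the proposal above, so a module such as `torchvision` would also be reported as `"torch"`.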