hiddenlayer
Recursively detect framework when multiple inheritance of nn.Module is used
Hi,
I found that the function detect_framework() in graph.py only detects classes that inherit directly from nn.Module (PyTorch style), so I wrote this:
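def find_root_class(outer_class, all_classes):
    # Walk up the chain of first base classes, collecting every
    # tuple of bases seen along the way, until `object` is reached.
    if outer_class.__bases__[0] == object:
        return all_classes + outer_class.__bases__
    else:
        return find_root_class(outer_class.__bases__[0], all_classes + outer_class.__bases__)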
Then the original detect_framework() could be:
def detect_framework(value):
    # Collect the ancestors of value's class, then match on module name.
    classes = find_root_class(value.__class__, ())
    for c in classes:
        if c.__module__.startswith("torch"):
            return "torch"
        elif c.__module__.startswith("tensorflow"):
            return "tensorflow"
Hope this could be useful.
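One caveat: the recursion above follows only the first base class at each level, so a framework class that appears as a second or later base would be missed. A minimal sketch of an alternative, assuming Python's built-in method resolution order (`__mro__`) is acceptable here, is shown below. It is not necessarily what hiddenlayer itself does, and `FakeModule`, `Mixin`, `Base`, and `Net` are hypothetical stand-ins so the example runs without PyTorch installed.

```python
def detect_framework(value):
    """Return "torch" or "tensorflow" if any ancestor class comes from that package."""
    # __mro__ lists the class itself and every ancestor, across all bases.
    for cls in type(value).__mro__:
        module = cls.__module__
        if module.startswith("torch"):
            return "torch"
        if module.startswith("tensorflow"):
            return "tensorflow"
    return None


# Hypothetical stand-in for torch.nn.Module: a class whose __module__
# looks like it came from the torch package.
FakeModule = type("Module", (object,), {"__module__": "torch.nn.modules.module"})


class Mixin:
    pass


class Base(FakeModule):
    pass


class Net(Mixin, Base):
    # The framework class is reachable only through the second base.
    pass
```

With these stand-ins, `detect_framework(Net())` finds the torch-like ancestor even though it is not on the first-base chain, and an object with no framework ancestor yields `None`.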
A more elegant solution has already been implemented in 294f8732b271cbdd6310c55bdf5ce855cbf61c75. However, it has not been merged yet, and unfortunately it probably will not be merged any time soon. In the meantime, you can already use the fixed version if you install hiddenlayer via git instead of pip.