torchdrug
Use `class-resolver` for readouts and enable choosing `MaxReadout`
This pull request does the following:
- Uses the class-resolver to reduce the if/else logic used to pick readouts
- Enables using the "max" readout
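To illustrate the idea of replacing if/else branching with a lookup, here is a minimal, dependency-free sketch. It is not the class-resolver API and not the code in this pull request: the readout classes are plain-Python stand-ins, and `READOUTS` and `resolve_readout` are hypothetical names chosen for the example.

```python
from statistics import mean


# Hypothetical stand-ins for readout layers; real readouts would be nn.Module subclasses.
class SumReadout:
    def __call__(self, values):
        return sum(values)


class MeanReadout:
    def __call__(self, values):
        return mean(values)


class MaxReadout:
    def __call__(self, values):
        return max(values)


# A table-driven lookup replaces a chain of if/elif branches on the readout name.
READOUTS = {"sum": SumReadout, "mean": MeanReadout, "max": MaxReadout}


def resolve_readout(readout):
    """Accept a name, a class, or an instance, and return an instance."""
    if isinstance(readout, str):
        try:
            return READOUTS[readout.lower()]()
        except KeyError:
            raise ValueError(f"Unknown readout: {readout!r}") from None
    if isinstance(readout, type):
        return readout()
    return readout
```

With this shape, supporting a new readout means adding one entry to the table rather than another branch in every module that accepts a `readout` argument.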
Hi! That's a useful update. Do you want to add more commits or shall I mark this ready for review?
I have a few more ideas for this if you can give me another day to put them in!
Specifically, I'd like to use the class-resolver package to deduplicate this logic as described in https://cthoyt.com/2022/02/06/model-abstraction.html
@KiddoZhu is there any automated testing in this package?
I read the post and that's a pretty cool idea.
Actually, we came up with a very similar idea in core.Configurable and core.Registry. They not only serve as call-by-name dispatchers but also record the arguments passed to the __init__ function of the class. You may refer to the document here.
Currently, we only resolve call-by-name for models, tasks, and optimizers. Ideally, we want everything to have a unique name in the registry, so that core.Registry.search can work smoothly. Let me think about whether it is necessary to register all layers.
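As a rough illustration of a registry that both dispatches by name and records constructor arguments, here is a minimal sketch. It is an assumption about how such a mechanism could work, not torchdrug's actual implementation; the `Registry` class, the `_config` attribute, and `ToyModel` are hypothetical.

```python
import functools
import inspect


class Registry:
    """Hypothetical minimal registry; not torchdrug's actual implementation."""

    table = {}

    @classmethod
    def register(cls, name):
        def decorator(klass):
            original_init = klass.__init__
            signature = inspect.signature(original_init)

            @functools.wraps(original_init)
            def recording_init(self, *args, **kwargs):
                # Bind the call to the signature so positional arguments and
                # defaults are recorded under their parameter names.
                bound = signature.bind(self, *args, **kwargs)
                bound.apply_defaults()
                recorded = dict(list(bound.arguments.items())[1:])  # drop self
                self._config = {"class": name, **recorded}
                original_init(self, *args, **kwargs)

            klass.__init__ = recording_init
            cls.table[name] = klass
            return klass

        return decorator

    @classmethod
    def get(cls, name):
        return cls.table[name]


@Registry.register("models.Toy")
class ToyModel:
    def __init__(self, hidden_dim, num_layers=2):
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers


model = Registry.get("models.Toy")(16)
print(model._config)
```

Because the recorded dictionary carries both the registered name and the resolved arguments, it is enough to rebuild an equivalent object later.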
@cthoyt No, there are some tests already, but I am not very familiar with CI systems. Do you know of any reference for designing CI for such a library, especially when some of the test scripts rely on GPU hardware?
Ideally, code works the same whether you're on CPU or GPU, since the hardware details are abstracted by PyTorch. I've never had to explore testing on GPU.
I’m not really sure how to interpret your feedback. Are you saying the registry system accomplishes the same thing as the class resolver? I think after reading the linked documentation I am only more confused.
I also see you've reverted all of my changes. Is there a reason, such as not wanting to add an external dependency? At the least, I think it would make sense to add a helper function to reduce all of this code duplication.
I reverted the code to a naive branching implementation. The class-resolver is awesome. However, if we want to adopt it, we need to apply it consistently everywhere in the codebase (e.g. covering the activation functions and some others). At this moment, we are not ready for this.
Feel free to open another issue or pull request to discuss this feature.
@cthoyt Sorry, I put the wrong link. It's this one: core.Registry. The registry serves as a call-by-name resolver and supports partial name matching. For example,
from torch import nn
from torchdrug.core import Registry as R

@R.register("layers.readout.max")
class MaxReadout(nn.Module):
    ...

readout = "max"
if isinstance(readout, str):
    # Look up the registered class by (partial) name, then instantiate it
    readout = R.search(readout)()
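To make the partial name matching concrete, here is a minimal sketch of one way such a lookup could behave. It is an assumption for illustration and not torchdrug's actual matching rule: this hypothetical `Registry.search` accepts a query that equals a registered name or matches its trailing dotted components, and refuses ambiguous queries.

```python
class Registry:
    """Hypothetical sketch of partial-name lookup; not torchdrug's actual code."""

    table = {}

    @classmethod
    def register(cls, name):
        def decorator(klass):
            cls.table[name] = klass
            return klass

        return decorator

    @classmethod
    def search(cls, query):
        # A key matches if it equals the query or ends with ".<query>".
        matches = [
            key for key in cls.table
            if key == query or key.endswith("." + query)
        ]
        if not matches:
            raise KeyError(f"No registered name matches {query!r}")
        if len(matches) > 1:
            raise KeyError(f"{query!r} is ambiguous: {sorted(matches)}")
        return cls.table[matches[0]]


@Registry.register("layers.readout.max")
class MaxReadout:
    pass


@Registry.register("layers.readout.sum")
class SumReadout:
    pass
```

Rejecting ambiguous queries is the reason unique names matter: a short name such as "max" only resolves while exactly one registered name ends with it.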
We are also considering augmenting the registry to support both lowercase and canonical names, very similar to yours. We've already applied the registry to reconstructing models from nested dicts of hyperparameters and to resolving molecule features, and we will probably use it for metrics in the near future.
I like your class-resolver design. If we adopt class-resolver, it's better that we apply it everywhere it should be, for better maintenance and aesthetics. The external dependency is also a concern; we might prefer directly integrating a minimal version of this kind of tool if the authors permit. As for the registry, it has some unique advantages in resolving classes from a global namespace (like this hyperparameter configuration), which is potentially useful for building meta training code without knowing in advance which task or model to train. So for now we may prefer resolving through the registry. We will implement the registry for readouts and other components later.
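As an illustration of resolving classes from a global namespace, here is a minimal sketch of rebuilding objects from a nested hyperparameter dict. The `"class"` key convention, the `REGISTRY` table, `load_config`, and the two toy classes are assumptions made for this example, not torchdrug's actual code.

```python
REGISTRY = {}


def register(name):
    """Record a class under a globally unique name (hypothetical helper)."""
    def decorator(klass):
        REGISTRY[name] = klass
        return klass

    return decorator


def load_config(config):
    """Recursively build objects from dicts that carry a "class" key."""
    if isinstance(config, dict):
        if "class" in config:
            kwargs = {k: load_config(v) for k, v in config.items() if k != "class"}
            return REGISTRY[config["class"]](**kwargs)
        return {k: load_config(v) for k, v in config.items()}
    if isinstance(config, list):
        return [load_config(v) for v in config]
    return config


@register("models.Encoder")
class Encoder:
    def __init__(self, hidden_dim):
        self.hidden_dim = hidden_dim


@register("tasks.PropertyPrediction")
class PropertyPrediction:
    def __init__(self, model, num_tasks=1):
        self.model = model
        self.num_tasks = num_tasks


# The training script never imports Encoder or PropertyPrediction by name;
# everything it needs to know comes from the configuration.
task = load_config({
    "class": "tasks.PropertyPrediction",
    "model": {"class": "models.Encoder", "hidden_dim": 64},
    "num_tasks": 3,
})
```

This is the property a per-call-site resolver does not give directly: one generic loader can construct any registered task or model from data alone.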