hydra-zen
Calling methods via instantiate API
This is just a prototype. We can use zen-wrappers to let people specify a method call via instantiate:
```python
from hydra_zen import builds, instantiate

def outer(Target, *, mthd_name, mthd_args, mthd_kwargs):
    def wrapper(*args, **kwargs):
        return getattr(Target(*args, **kwargs), mthd_name)(*mthd_args, **mthd_kwargs)
    return wrapper

def call_method(mthd_name, *mthd_args, **mthd_kwargs):
    return builds(
        outer,
        mthd_args=mthd_args,
        mthd_name=mthd_name,
        mthd_kwargs=mthd_kwargs,
        populate_full_signature=True,
        zen_partial=True,
    )
```
```python
class A:
    def __init__(self, x):
        self.x = x

    def x_plus_y(self, y):
        return self.x + y
```

```python
>>> instantiate(builds(A, x=-11, zen_wrappers=call_method("x_plus_y", y=2)))
-9
```
I can't say I love this form factor or its readability. Specifying the method via its name (a string) is not very zen. I would want to see if I can make the YAML more legible:
```yaml
_target_: hydra_zen.funcs.zen_processing
_zen_target: __main__.A
_zen_wrappers:
  _target_: hydra_zen.funcs.zen_processing
  _zen_target: __main__.call_method
  _zen_partial: true
  mthd_name: x_plus_y
  'y': 2
x: -11
```
Some things for us to consider:

- Would it be worthwhile for us to provide users with this sort of wrapper?
- Can the wrapper be designed to handle class, instance, and static methods?
- Can we avoid requiring that the method be specified by name (a string)?
- Is there some sort of optional validation API that we can expose for zen-wrappers in general, so that `builds(A, x=-11, zen_wrappers=call_method("x_plus_y", y=2))` can validate that the method call will be valid upon instantiation?
This doesn't work with the code above:

```python
>>> instantiate(builds(torch.Generator, zen_wrappers=call_method("manual_seed", 42)))
...
TypeError: Error instantiating 'hydra_zen.funcs.zen_processing' : outer() got multiple values for argument 'mthd_name'
```
Ah, whoops. Yeah, things get a little complicated with `*args` and partial. I updated the implementation in the top post; there should no longer be an issue.
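For the curious, the collision can be reproduced with plain `functools.partial`. The sketch below uses toy functions (not the actual hydra-zen internals): when `mthd_name` is a positional-or-keyword parameter bound by keyword, any positional argument can fill it a second time.

```python
from functools import partial

def outer(Target, mthd_name, *mthd_args):
    # `mthd_name` is positional-or-keyword here -- the problematic layout
    return (Target, mthd_name, mthd_args)

f = partial(outer, mthd_name="manual_seed")
try:
    f(int, 42)  # 42 *also* binds to `mthd_name` positionally
except TypeError as e:
    print(e)  # outer() got multiple values for argument 'mthd_name'

# The fix mirrors the updated implementation above: make the method info
# keyword-only, so positional args can never collide with it.
def outer_fixed(Target, *mthd_args, mthd_name):
    return (Target, mthd_name, mthd_args)

g = partial(outer_fixed, mthd_name="manual_seed")
assert g(int, 42) == (int, "manual_seed", (42,))
```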
I think, for now, we should provide this functionality as-is. It provides useful functionality, even if it probably won't be used all that much.
> Can the wrapper be designed to handle class & instance & static methods?

Doesn't `builds` handle these cases already?
> Can we avoid requiring the method be specified by name (string)?

> Is there some sort of optional validation API that we can expose for zen-wrappers in general? So that `builds(A, x=-11, zen_wrappers=call_method("x_plus_y", y=2))` can validate that the method call will be valid upon instantiation?

I agree it would be nicer to have a more user-friendly interface (and one we can validate), but maybe we can wait until we see more use cases. One idea out of left field: is there any possible way to code this API: `builds(A, x=-11).x_plus_y(y=2)`?
> One idea in left field, is there any possible way to code this API: `builds(A, x=-11).x_plus_y(y=2)`?

The generated dataclass would need a special `__getattr__` method.

Edit: Actually, I think you would need a metaclass `__getattr__` in this case, as `__getattr__` defined on the class itself would only do the trick for instances.
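To illustrate the point, a generic sketch (not hydra-zen code): `__getattr__` defined on a class only intercepts attribute access on *instances* of that class; to intercept access on the class object itself -- which `builds(A, x=-11).x_plus_y` would require -- the metaclass must define it.

```python
class OnClass:
    def __getattr__(self, name):
        return f"instance lookup: {name}"

# works for instances...
assert OnClass().x_plus_y == "instance lookup: x_plus_y"

# ...but not for the class object itself
try:
    OnClass.x_plus_y
except AttributeError:
    pass

# a metaclass __getattr__ intercepts attribute access on the class object
class Meta(type):
    def __getattr__(cls, name):
        return f"class lookup: {name}"

class OnMeta(metaclass=Meta):
    pass

assert OnMeta.x_plus_y == "class lookup: x_plus_y"
```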
> I think, for now, we should provide this functionality as-is.

@jgbos I think you are right. Since this is independent of `builds` itself, users can take it or leave it. That being said, I would really like to begin working on an API for permitting wrapper validation. We can start by keeping the API undocumented/internal and using this as a prototype example.
> One idea in left field, is there any possible way to code this API: `builds(A, x=-11).x_plus_y(y=2)`?

I think this would get hairy pretty fast; a purely functional approach would definitely be easier to support.
> Actually I think you might need a metaclass `__getattr__` in this case, as `__getattr__` defined on the class itself would only do the trick for instances.

Yep, I think you are right. I don't think `builds` should ever take on this sort of complexity.
#219 prototypes a potential solution to the above that is much more powerful than what I had initially proposed. This is inspired by @jgbos' out-of-left-field idea 😄

Let's see `hydra_zen.like` in action:
```python
>>> from hydra_zen import just, instantiate, like
>>> import torch

>>> GenLike = like(torch.Generator)  # 'memorizes' sequence of interactions with `torch.Generator`
>>> seeded = GenLike().manual_seed(42)
>>> seeded
Like(<class 'torch._C.Generator'>)().manual_seed(42)

>>> Conf = just(seeded)  # creates Hydra-compatible config for performing interactions via instantiation
>>> instantiate(Conf)  # calls torch.Generator().manual_seed(42)
<torch._C.Generator at 0x14efb92c770>
```
What is going on here? `hydra_zen.like(<target>)` returns a `_Tracker` instance that records all subsequent interactions with `<target>`. Right now, "interactions" are restricted to accessing attributes and calling the target. `hydra_zen.just` then knows how to convert a `_Tracker` instance into a valid Hydra config.
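For intuition, here is a stripped-down sketch of the recording pattern (hypothetical names; the real `_Tracker` in #219 differs): each attribute access or call returns a new tracker with the interaction appended, and "instantiation" replays the recorded sequence against the real target.

```python
class Tracker:
    def __init__(self, target, ops=()):
        self._target = target
        self._ops = tuple(ops)  # recorded sequence of interactions

    def __getattr__(self, name):
        # record an attribute access; return a new tracker (no mutation)
        return Tracker(self._target, self._ops + (("getattr", name),))

    def __call__(self, *args, **kwargs):
        # record a call
        return Tracker(self._target, self._ops + (("call", args, kwargs),))

    def replay(self):
        # the "instantiate" step: apply the recorded ops to the real target
        obj = self._target
        for op in self._ops:
            if op[0] == "getattr":
                obj = getattr(obj, op[1])
            else:
                obj = obj(*op[1], **op[2])
        return obj

tracked = Tracker(list)([1, 2, 3]).count(2)
assert tracked.replay() == 1  # list([1, 2, 3]).count(2)
```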
Let's use this to make a config that actually returns the `manual_seed` method of our Generator instance:

```python
>>> manual_seed_fn = GenLike().manual_seed
>>> instantiate(just(manual_seed_fn))
<function Generator.manual_seed>
```
Here is a silly example, just to help give more of a feel for this:

```python
>>> listy = like(list)
>>> x = listy([1, 2, 3])
>>> y = x.pop()
>>> # note that `x` *does not* reflect any mutation
>>> instantiate(just(x)), instantiate(just(y))
([1, 2, 3], 3)
```
Current thoughts on this prototype

The implementation in #219 is super crude, but it seems to cover ~80% (?) of the desired use cases already!

As @jasha10 points out, there is an issue with recording interactions with class instances vs. class objects. Fortunately, `hydra_zen.like` returns a `_Tracker` instance, and thus we can avoid this headache from the get-go. That all being said, validating the input to `like` -- ensuring it is a class-object or function -- will be important.

One nice thing about `like` is that it looks like a simple pass-through to IDEs, so you get auto-completes, static type checks, etc. as you build your "like expression" prior to creating the config.
Some obvious to-dos would include:

- Adding as much validation as possible. E.g. `GenLike().manual_seedling(42)` should raise a runtime error because `manual_seedling` is not an attribute of `Generator`.
- Making the yaml-representation of the `just(like(<...>).<...>)` expression as clean as possible
  - doing `just(like(<target>)(*args, **kwargs))` should reduce down to `builds(<target>, *args, **kwargs)`
- Enabling users to directly provide `like`'d objects in their configs without needing to call `just` first.
I would definitely like to get feedback on this. I only just thought of it, and will have to think carefully about potential pitfalls as well as about possible use cases that I haven't anticipated yet.
This idea reminds me of `unittest.mock.MagicMock`, which has an API for figuring out which method calls have been made on the object:

```python
>>> from unittest.mock import MagicMock
>>> GenLike = MagicMock(name="GenLike")
>>> seeded = GenLike("init").foo(42)["bar"].baz
>>> GenLike.mock_calls
[call('init'), call().foo(42), call().foo().__getitem__('bar')]
>>> seeded._mock_new_name  # _mock_new_name is undocumented
'baz'
>>> str(seeded)
"<MagicMock name='GenLike().foo().__getitem__().baz' id='139896475473856'>"
```
Oh yeah, that is a great observation! I will see if I might leverage any of that functionality directly, or if some of their design decisions might be worth copying.
Had to throw this at you, the interface made me try it 😵💫

```python
Module = builds(...)  # lightning module
LightningTrainer = like(Trainer)(gpus=2, strategy="ddp")
LightningFit = just(LightningTrainer.fit(Module))

# automatically start fitting
instantiate(LightningFit)
```
This then begs the question: can we do something like this?

```python
LightningTrainer = builds_like(Trainer, gpus=2, strategy="ddp", populate_signature=True)
LightningFit = just(LightningTrainer.fit(Module))
```
Nice! This got me thinking...

While I do like `hydra_zen.like` -- it definitely makes certain things way easier and more intuitive to configure -- I do get concerned that it might encourage bad practices. Beyond cases akin to `torch.Generator().manual_seed(42)`, I think users should be encouraged to include other complicated logic in the task function rather than in their configs.

(That all being said, it is pretty crazy how things "just work", like your example above.)
> This then begs the question, can we do something like this:

Yeah, this is definitely something I have been thinking about... What are the arguments to `like`? Do we expose a `populate_full_signature`? `zen_wrappers`? At what point does this just become `builds` 2.0?

As I mentioned in my previous comment, I think we should take some time to come up with guidelines for what are and aren't "best practices", and to have those guidelines inform how we design `like`.
Heh... I just thought of this:

```python
import torch as tr

tr = like(tr)
tr.tensor([1., 2.])  # is equivalent to `like(tr.tensor)([1.0, 2.0])`
```

That is, this would effectively auto-apply `like` to all members of the module as you access them.
(kind of funny how I post this immediately after my "now now, we should take care to only recommend moderate use cases" post)
It would be nice to support `like`-expressions without `just`; e.g. `make_config(a=like(Generator)().manual_seed(42))`, where `just` is applied under the hood to create the proper structured configs for these values. We already do this for `1+2j`, and runtime support for `like`-expressions is trivial, but there is an issue with annotations.
The annotation for `like` is `like(obj: T) -> T`; thus `like(Generator)().manual_seed(42)` looks identical to `Generator().manual_seed(42)` for type-checkers. This means that our annotations for `make_config` et al. would need to be broadened to `Any` in order to accommodate arbitrary `like(...)<...>` expressions. Any attempt to make type-checkers see that `like` returns a `Tracker` type would break all the nice tool assistance that one would want when writing all but trivial expressions with `like` (e.g. autocomplete, signature inspections, etc.).

But resorting to `Any` would obviously sacrifice all of our ability to statically validate bad vs. good config values to `make_config` et al., whereas we currently provide quite reliable annotations. I don't think this is acceptable. There are two options I can think of:
1. Require users to explicitly wrap `like`-expressions with `just`. I.e. `make_config(a=like(...))` raises at runtime, whereas `make_config(a=just(like(...)))` is OK.
2. Provide runtime support for `make_config(a=like(...))` but don't revise annotations. Thus the expression is valid at runtime, but type-checkers will mark it as an error.
The downside of 1 is that requiring an extra `just` call substantially adds to the boilerplate for writing configs. Without it, users can practically write "native" expressions in their configs, which is quite nice!

The downside of 2 is that it promotes code patterns that will light IDEs/type-checkers up like a Christmas tree. Obviously, it is bad to have code that runs without issue get consistently flagged by static checks. I really don't want hydra-zen to promote this sort of thing.

Is there some nice trick I am missing here? Something with `Annotated` that will let us have our cake and eat it too?
Based on the `torch.Generator()` example, my understanding is that `like` is restricted to patterns where instantiating the config immediately executes a linear sequence of actions. Is that correct?

I have a scenario where I want to expose the default parameter(s) of a class method in Hydra's config, but rather than instantiation calling the method, I want to receive the initialized object with a modified method. For a concrete example, PyTorch Lightning Flash offers finetuning strategies with their `Trainer`. In my application code, I want to call `trainer.finetune()` with the default parameter `strategy` modified by Hydra, while also being able to interact with `trainer` before/after:
```python
import flash
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

class MyCustomCallback(Callback):
    pass

def train(trainer: flash.Trainer, model: pl.LightningModule, datamodule: pl.LightningDataModule):
    """This function would be a Hydra _target_ with the parameters supplied via recursive instantiation"""
    # Interacting with ``trainer`` before and after calling the method whose params are to be altered
    trainer.callbacks.insert(0, MyCustomCallback(...))
    trainer.finetune(model, datamodule=datamodule)  # TODO configure default value for ``strategy`` param
    trainer.test()
```
As @jgbos points out, I could combine the interactions with `trainer` into a single config node which, when instantiated, executes the logic of `train()`. However, as the logic gets more complex, I find it odd to encapsulate it within the config-building code instead of the application code. IMO, Hydra's role is strictly to configure, and the logic of how objects are used should remain inside the main application code base. This way, the programmer's intent can be understood without requiring deep knowledge of Hydra.

Given that (and assuming my interpretation of `like` is correct), I think I'd need a different solution. Before reading this thread, I played around with the following. Perhaps it is helpful in brainstorming the best interface for users:
```python
import inspect

from hydra.utils import instantiate
from hydra_zen import builds
import flash
from omegaconf import OmegaConf

config_strategy = 'foobar'

TrainerConf = builds(
    flash.Trainer,
    zen_meta={'finetune': builds(
        flash.core.trainer.Trainer.finetune,
        strategy=config_strategy,
        zen_partial=True)},
    max_epochs=10,
    gpus=-1)

# FIXME ``builds`` changes ``_target_`` path --> current workaround is to manually specify
TrainerConf.finetune._target_ = 'flash.core.trainer.Trainer.finetune'

print(OmegaConf.to_yaml(TrainerConf))

trainer = instantiate(TrainerConf)
sig = inspect.signature(trainer.finetune)
instantiated_strategy = sig.parameters['strategy'].default

# FIXME ``zen_meta`` attr disappears after instantiation so the ``strategy`` override has no effect
# Can't move the field outside of meta since it doesn't match the signature for Trainer.__init__
if instantiated_strategy != config_strategy:
    raise AssertionError(
        f'TrainerConf.finetune.strategy ({config_strategy}) != '
        f'trainer.finetune.strategy ({instantiated_strategy})')
```
Seeing the wrappers approach in @rsokl's initial comment, I now think that would be the better approach. I could write a wrapper function to replace the `finetune` method with an alternate one (likely `zen_partial` like above, in order to allow Hydra overrides of the default parameters). This would allow me to instantiate objects with modified method signatures. Does that seem like a decent approach?
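A rough sketch of that wrapper idea, using a toy `Trainer` and hypothetical names (not a tested hydra-zen recipe): a zen-wrapper-style factory intercepts the instantiation target and, after the object is built, rebinds `finetune` with a pre-filled default for `strategy`.

```python
import functools

def with_finetune_strategy(strategy):
    """Hypothetical zen-wrapper: after building the object, rebind its
    ``finetune`` method with ``strategy`` pre-bound (still overridable)."""
    def wrapper(target):
        def build(*args, **kwargs):
            obj = target(*args, **kwargs)
            obj.finetune = functools.partial(obj.finetune, strategy=strategy)
            return obj
        return build
    return wrapper

# toy stand-in for flash.Trainer, just to exercise the wrapper
class Trainer:
    def finetune(self, model, strategy="no_freeze"):
        return f"finetuning {model} with {strategy}"

WrappedTrainer = with_finetune_strategy("foobar")(Trainer)
trainer = WrappedTrainer()
assert trainer.finetune("model") == "finetuning model with foobar"
```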
Sorry for the delayed response. I'm finally going through your example here, and I see what you are trying to do. I think this is the equivalent using `like`:
```python
from hydra_zen import to_yaml, like, just, instantiate
import flash

config_strategy = 'foobar'

TrainerConf = like(flash.Trainer)(max_epochs=10, gpus=-1)
FineTunerConf = TrainerConf.finetune(strategy=config_strategy, zen_partial=True)

# works
print(to_yaml(just(TrainerConf)))
instantiate(just(TrainerConf))

# does not work with partials
print(to_yaml(just(FineTunerConf)))
instantiate(just(FineTunerConf))
```
It currently doesn't support `zen_partial` and `populate_full_signature`, which I think are two important features to add.

@rsokl I believe you had reasons for not pursuing `like`. Also, is there anything in the works with hydra-core? I haven't been paying close attention to their future plans.
I'm wondering now if your use case would work with this approach. There would be a configuration for `Trainer` and a `FineTuner` now. So would the following work as expected? I'm wondering, in particular, whether updating `trainer` (the callback insertion) works as expected (I think it will not).
```python
import flash
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from typing import Protocol

class FineTuner(Protocol):
    def __call__(
        self,
        model,
        train_dataloader=None,
        val_dataloaders=None,
        datamodule=None,
        strategy="no_freeze",
        train_bn=True,
    ):
        ...

class MyCustomCallback(Callback):
    pass

def train(
    trainer: flash.Trainer,
    finetuner: FineTuner,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
):
    trainer.callbacks.insert(0, MyCustomCallback(...))
    finetuner(model, datamodule=datamodule)
    trainer.test()
```
We may need a better approach still. Here is a temporary solution:
```python
import flash
import pytorch_lightning as pl

from hydra_zen import to_yaml, instantiate, builds, make_config

config_strategy = "foobar"

TrainerConf = builds(flash.Trainer, max_epochs=10, gpus=-1, populate_full_signature=True)

Config = make_config(
    finetuner_kwargs=make_config(strategy=config_strategy),
    trainer=TrainerConf,
    model=...,
    datamodule=...,
)

def train(
    trainer: flash.Trainer,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    finetuner_kwargs: dict,
):
    trainer.callbacks.insert(0, MyCustomCallback(...))
    trainer.finetune(model, datamodule=datamodule, **finetuner_kwargs)  # was ``**finetuner_args`` (typo)
    trainer.test()
```