Allow multiple datasets in fit_to_data and add option to return opt_state
Some loss functions require several arrays rather than just one. This extends fit_to_data so that it batches each of those arrays and passes the batches to the loss function.
It can also be quite useful to reuse the optimizer state across different runs, so I added opt_state and return_opt_state arguments to fit_to_data to facilitate that.
I can see this could be useful, e.g. it would give a simple way to support weighted samples (https://github.com/danielward27/flowjax/pull/211). I appreciate the effort to maintain backwards compatibility, but it does lead to a somewhat ugly interface. If rewriting from scratch, you could consider replacing x with data: Array | tuple[Array, ...], and removing the condition argument (just pass it in the data tuple). Then we could simply document that train/test splitting, batching, etc. happen on all of data, and that the batched arrays should match the positional arguments of the loss function. Right now, the implementation you have given forces condition to be the last positional argument, which, whilst fine, isn't clear until you read the code.
There isn't really a compelling reason to me why condition should be passed as a separate argument once we have chosen to allow x (or data) to be a tuple of arrays. If we go down this road, it might be best to immediately deprecate both condition and x, with a warning encouraging people to use the new data positional argument (placed before x, I guess, so that no change is necessary if x was passed as a positional argument).
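To make the idea concrete, a rough sketch of how data could be normalised and unpacked into the loss (illustrative names and toy losses, not a real implementation):

```python
# Illustrative sketch only; not flowjax's actual implementation.
import jax.numpy as jnp
from jax import Array


def _as_tuple(data: Array | tuple[Array, ...]) -> tuple[Array, ...]:
    """Normalise data so downstream code always sees a tuple of arrays."""
    return (data,) if isinstance(data, Array) else data


def loss_on_batch(loss_fn, data, batch_idx):
    """Slice every array in data identically, then unpack the batch
    positionally into the loss."""
    batch = tuple(arr[batch_idx] for arr in _as_tuple(data))
    return loss_fn(*batch)


# An unconditional loss sees one array; a conditional loss sees two.
x = jnp.ones((8, 2))
cond = jnp.zeros((8, 3))
print(loss_on_batch(lambda x: x.sum(), x, jnp.arange(4)))
print(loss_on_batch(lambda x, c: x.sum() + c.sum(), (x, cond), jnp.arange(4)))
```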
The ability to pass and return the optimizer state is a good idea, but if we include it in fit_to_data, it should also be included in fit_to_key_based_loss for consistency. I would hope (I haven't checked) that it may also avoid recompilation when calling the fitting function twice? Again, if willing to make a breaking change, you could choose to always return the optimizer state, and users could simply ignore it if they didn't need it. If writing from scratch this would be better to me - one consistent return type, and one fewer argument. But obviously it introduces a breaking change, which I think is better to avoid, at least for now (i.e. as you have done).
Although I feel a little unsure about it, I think I am happy to merge this with the addition of the following changes:
- Add a `data` positional argument before `x`, which can be an array or a tuple of arrays (and default `x=None`).
- Add deprecation messages to `x` and `condition` saying they will be removed in the following version in favour of the new `data` positional argument, in a way that is non-breaking for now (sketched below) - obviously it will be breaking when deprecated, but at least there has been some warning.
- Document the "alignment" between the tuple of arrays in `data` and the positional arguments of the loss functions (e.g. `(target, condition)` for `MaximumLikelihoodLoss`).
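A non-breaking deprecation along these lines might look roughly like this (a sketch, not a prescription, and not the code that was merged):

```python
# Sketch of a non-breaking deprecation path for x and condition;
# names and structure are illustrative.
import warnings

from jax import Array


def fit_to_data(key, dist, data=None, x=None, *, condition=None, **kwargs):
    if x is not None:
        warnings.warn(
            "`x` is deprecated and will be removed in the next version; "
            "pass `data` instead.",
            DeprecationWarning,
        )
        data = x
    if condition is not None:
        warnings.warn(
            "`condition` is deprecated; include it as the last element of "
            "the `data` tuple instead.",
            DeprecationWarning,
        )
        data = (data, condition) if isinstance(data, Array) else (*data, condition)
    data = (data,) if isinstance(data, Array) else data
    # ... train/val split, batching, and the training loop proceed on `data`
    return data
```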
I'm definitely open to feedback on this if you or anyone else has any thoughts: there is inevitably a trade-off between maintaining backwards compatibility and simplifying the code and improving the API. Users are probably better placed to discuss this than I am - in my own applications I don't really mind quick-to-fix breaking changes (in both FlowJAX and other dependencies), but I understand that may be different for others.
An issue I have just thought of: it may be necessary to support None in data if we assume its elements map to positional arguments of the loss. Imagine a loss taking x, condition=None, weights=None. For unconditional weighted density estimation, data would have to be (x, None, weights), and then presumably we'd need code to ensure None doesn't get passed to the train/val split and batching code (or we'd need to handle it in those functions). That is a little unsatisfying... Another possibility would be to force the use of keywords, i.e. data could be a dict of arrays matching keyword arguments of the loss rather than positional arguments (see the sketch below) - but that also feels a little unsatisfying. It is probably worth considering a few options before committing to one. Maintaining the current implementation is also somewhat reasonable - it keeps the training loop as simple as possible, so it's easier for users to copy and modify as needed.
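For illustration, the keyword-based option could look something like this (entirely hypothetical, with a toy loss standing in for the real ones):

```python
# Hypothetical keyword-based alternative: data is a dict matching the loss's
# keyword arguments, so optional arguments can simply be omitted.
import jax.numpy as jnp


def toy_weighted_loss(x, condition=None, weights=None):
    """Toy loss with optional condition and weights."""
    loss = jnp.square(x).mean(axis=-1)
    if weights is not None:
        loss = loss * weights
    return loss.mean()


# Unconditional weighted density estimation: no Nones, no wrapping needed.
data = {"x": jnp.ones((8, 2)), "weights": jnp.full((8,), 0.5)}
print(toy_weighted_loss(**data))  # each batch would be unpacked like this
```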
Yeah, I also wasn't too happy about introducing the return_opt_state parameter, but I thought I'd rather avoid a breaking change. I also added this to the other fit function.
Maybe the updated version is a nice solution for the multiple arrays and the condition argument?
I changed the x argument to a *data argument, so multiple arrays can simply be passed as positional arguments.
Strictly speaking this is also a breaking change, because a user might have called it with a keyword argument for x, but this would at least be obvious to fix, and I would hope isn't common.
For the condition argument, I now also emit a deprecation warning, since it can instead be passed as the last data argument.
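Roughly, the shape of the change (a sketch of the idea, not the exact diff):

```python
# Sketch of the *data idea in this revision; illustrative only.
import warnings


def fit_to_data(key, dist, *data, condition=None, **kwargs):
    if condition is not None:
        warnings.warn(
            "`condition` is deprecated; pass it as the last positional "
            "data argument instead.",
            DeprecationWarning,
        )
        data = (*data, condition)
    # ... every array in `data` is split, batched, and unpacked into the loss
    return data
```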
Hmm, I'll have to have a think. I wouldn't merge this as-is because of the breaking change. Maybe it's possible to add x as a keyword-only argument and handle it appropriately to give a deprecation warning without any breaking changes. Also, I think the aforementioned issue, of a loss requiring the data x, condition=None, weights=None, would persist: you would have to wrap the loss if you wanted to provide only x and weights.
How about this version? It avoids the breaking change by keeping the old x argument alongside the extra positional arguments.
I'm not sure what to do about the condition/weights problem. Maybe just not deprecate the condition argument and keep it explicitly in the list of arguments to fit_to_data?
Didn't mean to close...
It is a bit fiddly. Not deprecating condition would, I think, make the interface too confusing/messy if we were to go down this route. I gave it a bit of a go here: https://github.com/danielward27/flowjax/tree/multiple_arrays_fit (focusing only on multiple arrays, not opt_state, to do one thing at a time).
It seems to work without breaking changes. Maybe it's important to support None in data, which the changes in train_utils start to support. In my implementation, if we pass data as *(arr1, None) for example, it works until we reach for batch in zip(*get_batches(val_data, batch_size), strict=True): - i.e. I've supported train/val splitting with None and batching with None, but the loop over the batches will error (reproduced below), and I haven't had time to think about the best solution to that.
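To make the failure mode concrete, here is a toy reproduction and one possible fix (get_batches here is a stand-in for the helper in train_utils, not the real one):

```python
# Toy reproduction: batching tolerates None, but zipping the per-array batch
# lists does not. get_batches is a stand-in, not the flowjax helper.
import jax.numpy as jnp


def get_batches(data, batch_size):
    """Return, for each array, its list of batches; None stays None."""
    return tuple(
        None if arr is None else [arr[i : i + batch_size] for i in range(0, len(arr), batch_size)]
        for arr in data
    )


val_data = (jnp.ones((4, 2)), None)
batches = get_batches(val_data, 2)
# zip(*batches) raises TypeError, because one element is None, not a list.
# One possible fix: substitute a list of Nones of matching length first.
n_batches = max(len(b) for b in batches if b is not None)
safe = tuple([None] * n_batches if b is None else b for b in batches)
for batch in zip(*safe, strict=True):
    print(batch)  # (array_batch, None), forwarded to the loss as-is
```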
Alternatively, you could argue users should be forced to e.g. wrap the loss to remove any non-array arguments (in which case some of my changes are unnecessary). However, this is arguably a little confusing, especially with a loss needing arrays (x, condition=None, weights=None) where the user must wrap the loss to provide arguments 1 and 3: shouldn't the loss function be compatible with the package's training script without modification?
I am still a bit uncertain about whether to support this. The fit_to_data function is <100 lines of code (a bit more with the docstring), so it is relatively easy for users to copy and modify for their use case.
By the way, you have a raise warn(...) in your code, which would raise an error instead of emitting a warning!
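i.e. dropping the raise gives the intended behaviour (illustrative):

```python
import warnings

# Buggy: warnings.warn returns None, so the raise fails (and would halt
# fitting even if it didn't):
# raise warnings.warn("`condition` is deprecated", DeprecationWarning)

# Intended: emit the warning and carry on.
warnings.warn("`condition` is deprecated", DeprecationWarning)
```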
I like your version :-)
fit_to_data now takes a data argument, which can be an array or a tuple of arrays:
https://github.com/danielward27/flowjax/pull/222 https://github.com/danielward27/flowjax/pull/220
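For example (a usage sketch; everything other than data being an array or tuple of arrays is my shorthand from this thread, not checked against the released API):

```python
# Usage sketch of the merged interface; key, flow, and the return values
# are assumptions for illustration.
import jax.numpy as jnp

x = jnp.ones((100, 2))
condition = jnp.zeros((100, 3))

# Unconditional: pass a single array.
#   dist, losses = fit_to_data(key, flow, data=x)

# Conditional: pass a tuple aligned with the loss's positional arguments.
#   dist, losses = fit_to_data(key, flow, data=(x, condition))
```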
Sorry for the delay, and thanks for the suggestion. I'd be happy to accept a separate pull request for a flag argument to return the optimizer state.