Flux.jl
Make `loss(f,x,y) == loss(f(x), y)`
If `train!` stops accepting implicit parameters, as in #2082, then its loss function needs to accept the model as an argument, rather than close over it.
This PR makes all the built-in losses do so, so that you don't have to define `loss(m, x, y) = mse(m(x), y)` etc. yourself every time.
(Defining `loss(x, y) = mse(model(x), y)` every time used to be the idiom for closing over the model, and IMO this is pretty confusing: it makes "loss function" mean two different things. Cleaner to delete that idiom entirely than to update it to a 3-arg version.)
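To make the change concrete, here is a rough sketch of the intended usage; the `train!(loss, model, data, opt)` signature is the explicit-parameter one proposed in #2082, not something this PR adds, and the layer and data shapes are just illustrative:

```julia
using Flux
using Flux.Losses: mse

model = Dense(10 => 1)
x, y = rand(Float32, 10, 64), rand(Float32, 1, 64)

# The 3-arg method this PR adds: the loss applies the model itself.
mse(model, x, y) == mse(model(x), y)   # true

# With a #2082-style explicit train!, the built-in loss can then be passed as-is:
# train!(mse, model, [(x, y)], opt)
# instead of first writing loss(m, x, y) = mse(m(x), y) by hand.
```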
PR Checklist
- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
A NEWS entry for this feature would be good too
Sorry, I have to say that I'm really not a fan of this signature, because it excludes a bunch of models while adding one more thing for loss-function authors to know. For example, what does `mse(m, x, y)` even mean if you're doing self-supervised learning and `m` is some Siamese network?
Given that the existing `train!` API already requires users to define their own "loss" function, could we not keep that constraint (bring your own function) and pass in `(m, x, y)`? This would be strictly less confusing than the status quo, and we could rename the callback to "forward pass" or some such.
Yes, I agree it's specialised to some uses.
It just seems slightly weird to force people to define a function that does nothing but adjust the signature, without doing any real work or making any choices. They are forced to do so now because, in addition, this function closes over the model, so it must be re-defined whenever you change the model.
I suppose it seems especially odd if the "official" documented way is that you must name this trivial function. And perhaps always writing something like this would be less odd:
```julia
train!(model, [(x1, y1), (x2, y2), ...], opt) do m, x, y
    mse(m(x), y)
end
```
However, that's still quite a bit of boilerplate just to say "use mse". And I know some people find the do syntax super-confusing at first.
If it were just a matter of clarifying how the do syntax works, we could address this with a docs issue. But on the brevity point, ideally we'd be able to extract out some `loss(f, x, y) = loss(f(x), y)` helper so that individual loss functions don't have to be responsible for being model-aware. It would be one more verb/noun to learn, but it would save us from confused users who ported over a `loss(x, y)` function from some other library and don't understand the resulting MethodError (I'm assuming that if they don't understand do, they'd have a hard time with this too). If this wrapper were a named type, there's even a chance to toss in optimisation state and thus simplify https://github.com/FluxML/Flux.jl/pull/2082, but I haven't thought too hard about that yet.
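For concreteness, a minimal sketch of what such a named wrapper might look like (the name `ApplyLoss` is made up here, and the idea of also carrying optimisation state is left out):

```julia
# Hypothetical wrapper type: makes any plain 2-arg loss(ŷ, y) model-aware,
# so individual loss functions never need to know about models.
struct ApplyLoss{F}
    loss::F
end

(a::ApplyLoss)(m, x, y) = a.loss(m(x), y)

# Usage with a #2082-style train!:
# train!(ApplyLoss(mse), model, data, opt)
```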
> confused users who ported over a `loss(x, y)` function from some other library and don't understand the resulting MethodError
Right now this is worse: `loss(x, y) = norm(x - y)` will result in zero gradients but no error.
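A minimal reproduction of that silent failure in the current implicit-parameter style (the layer and shapes are just illustrative):

```julia
using Flux, LinearAlgebra
using Flux: params

model = Dense(3 => 3)
loss(x, y) = norm(x - y)   # ported from elsewhere; note that it never calls the model

x, y = rand(Float32, 3), rand(Float32, 3)
gs = gradient(() -> loss(x, y), params(model))

gs[model.weight]   # nothing: the model's parameters get no gradient, and no error is raised
```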
For implicit-Flux, having methods like `mse(m) = (x, y) -> mse(m(x), y)` would allow `train!(mse(model), params(model), data, opt)`, which is less obscure than what we have now. Or it could be spelled `train!(applyloss(mse, model), params(model), data, opt)`, with one more verb.
For explicit-Flux, we could have `train!(applyloss(mse), model, data, opt)`. Not a big fan of a verb that exists exclusively to translate the built-in losses into what the built-in `train!` wants, though.
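Spelled out, that hypothetical `applyloss` verb (not part of Flux) would just be:

```julia
# Hypothetical helper adapting a plain 2-arg loss to either train! style.
applyloss(loss, m) = (x, y) -> loss(m(x), y)      # implicit-Flux: closes over the model
applyloss(loss)    = (m, x, y) -> loss(m(x), y)   # explicit-Flux: the model is passed in

# train!(applyloss(mse, model), params(model), data, opt)   # implicit
# train!(applyloss(mse), model, data, opt)                  # explicit
```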
We could also just make `train!(loss, model, [(x1, y1), (x2, y2)], opt)` call `loss(model(x1), y1)`. The rule is then `loss(model(data[1][1]), data[1][2:end]...)` instead of `loss(model, data[1]...)` as in #2082. Not sure about that either.
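As a one-line sketch of that rule (the name `batchloss` is made up here), `train!` itself would apply the model to the first element of each batch:

```julia
# Hypothetical: what train! would compute per batch under that rule, so that
# any plain 2-arg loss(ŷ, y) keeps working unchanged.
batchloss(loss, model, d::Tuple) = loss(model(first(d)), Base.tail(d)...)

# batchloss(mse, model, (x1, y1)) == mse(model(x1), y1)
```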
> Right now this is worse: `loss(x, y) = norm(x - y)` will result in zero gradients but no error.
Yeah, that's a good argument for making the `model(x)` part more explicit. Is it too bad to ask users to write a 2-liner?
```julia
data = [(x1, y1), (x2, y2), ...]
train!((m, x, y) -> mse(m(x), y), model, data, opt)
```
Most users can directly copy-paste this, and those who have more complex forward passes can either define a separate function or ease into learning the do syntax. But if one just wants to add a regularization term?
```julia
train!((m, x, y) -> mse(m(x), y) + Optimisers.total(norm, m), model, data, opt)
```
OK, https://fluxml.ai/Flux.jl/previews/PR2114/training/training/ takes the view that we should just always write an anonymous function. It emphasises `gradient` + `update!` over `train!`, and for `gradient` you are always going to want that. And it explains the do block several times.
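For reference, this is (roughly) the pattern those docs lead with, sketched here assuming the explicit-mode `Flux.setup` / `Flux.update!` API; the model, loss and data are just placeholders:

```julia
using Flux

model = Dense(10 => 1)
data  = [(rand(Float32, 10, 64), rand(Float32, 1, 64)) for _ in 1:3]
opt_state = Flux.setup(Adam(), model)

for (x, y) in data
    # The forward pass lives in an anonymous function passed to gradient.
    grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)[1]
    Flux.update!(opt_state, model, grads)
end
```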