Setfield.jl
Lens and derivatives
@tkf's amazing gist shows how we can use lenses to compute partial derivatives.
Often I am in a situation where I want (partial) gradients, Hessians, etc. instead. Optimizing a struct is a typical example.
Basic case
We have
struct Widget{T <: Number}
    parameter1::T
    parameter2::T
end
cost(w::Widget) = ...
and want to optimize cost. However, many algorithms won't accept the cost function as is;
they need cost(::AbstractVector). They often need gradients and Hessians in the same form as well.
So to use such algorithms, we need a way to convert our struct to a vector and back.
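The ad hoc conversion described here might look like the following minimal sketch for the two-field Widget. The names `tovec`, `fromvec` and `cost_vec` are hypothetical, and the quadratic `cost` is only a placeholder, since the actual cost function is left unspecified above.

```julia
struct Widget{T <: Number}
    parameter1::T
    parameter2::T
end

# Placeholder cost; the real cost function is not specified in the original.
cost(w::Widget) = (w.parameter1 - 1)^2 + (w.parameter2 + 2)^2

# Hypothetical pack/unpack pair: struct -> vector and vector -> struct.
tovec(w::Widget) = [w.parameter1, w.parameter2]
fromvec(::Type{<:Widget}, v::AbstractVector) = Widget(v[1], v[2])

# The form many optimizers expect: a function of an AbstractVector.
cost_vec(v::AbstractVector) = cost(fromvec(Widget, v))

w = Widget(0.5, 0.5)
@assert fromvec(Widget, tovec(w)) == w        # round trip recovers the struct
@assert cost_vec(tovec(w)) == cost(w)         # same cost in both representations
```

Writing such a pair by hand for every (nested) struct is the boilerplate the rest of the thread tries to get rid of.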
Variations:
- Instead of a flat struct, one often has nested structs.
- Some parameters of the struct might be fixed, while others can be optimized.
Currently I write a lot of ad hoc functions to pack the nested fields I need into vector form and unpack them again.
I think there should be a better way. I would like to just provide, for each of my structs, an isomorphism (e.g. a special type of lens) with some type of vector, and have the lens machinery do the following for me:
- It should be easy to define the isomorphism for a big new struct, in terms of the isomorphisms of its fields.
- It should be easy to construct lenses that focus on a selection of nested fields.
This is an interesting direction! How about providing an AbstractVector (or even AbstractArray) interface/wrapper, given an original object (e.g., a struct) and a vector/array of lenses? Or is that equivalent to what you mean by isomorphism? I'm thinking of something like this:
using Setfield

struct VectorInterface{T, O, LS <: Tuple} <: AbstractVector{T}
    obj::O
    lenses::LS
end

function VectorInterface(obj::O, lenses) where O
    orig = promote((get(l, obj) for l in lenses)...)
    lenses = tuple(lenses...)
    return VectorInterface{typeof(orig[1]),
                           O,
                           typeof(lenses)}(obj, lenses)
end

# The AbstractArray interface requires size (display, iteration and broadcasting use it).
Base.size(vi::VectorInterface) = (length(vi.lenses),)
Base.length(vi::VectorInterface) = length(vi.lenses)
Base.getindex(vi::VectorInterface, i) = get(vi.lenses[i], vi.obj)
Base.setindex(vi::VectorInterface, x, i) =
    set(Setfield.compose(vi.lenses[i], (@lens _.obj)), vi, x)
Setfield.hassetindex!(::VectorInterface) = false

struct Widget
    x1
    x2
    x3
    y1
    y2
end

# Accessing only x* as a vector:
vi = VectorInterface(Widget(1:5...),
                     ((@lens _.x1),
                      (@lens _.x2),
                      (@lens _.x3),))
@assert vi[1] == 1
@assert vi[2] == 2
@assert vi[3] == 3
let vi2 = @set vi[1] = 111
    @assert vi2[1] == 111
end
Yeah, I am not exactly sure what I mean by isomorphism :) An isomorphism between two types T, S would be a pair of functions f::T->S and g::S->T that are mutually inverse, i.e. g(f(t)) == t for all t::T and f(g(s)) == s for all s::S.
Such an isomorphism f::T->S uniquely defines a lens by the rules
get(lens, t) = f(t)
set(lens, t, s) = g(s)
A lens of this form is an isomorphism. But these definitions might be too narrow. For instance
struct Point2d{T}; x::T; y::T end
Then
f(pt::Point2d) = @SVector [pt.x, pt.y]
is an isomorphism, while
f(pt::Point2d) = [pt.x, pt.y]
is not with the above definition. Anyway enough distracting abstract nonsense.
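The get/set rules above can be checked on the Point2d example with a small, self-contained sketch. It deliberately avoids Setfield: `IsoLens`, `lensget` and `lensset` are hypothetical names, and a 2-tuple stands in for the SVector so that no StaticArrays dependency is needed.

```julia
# Hypothetical lens built from an isomorphism f::T->S with inverse g::S->T.
struct IsoLens{F, G}
    f::F   # forward map T -> S
    g::G   # inverse map S -> T
end

lensget(l::IsoLens, t) = l.f(t)
lensset(l::IsoLens, t, s) = l.g(s)   # the old value t is ignored entirely

struct Point2d{T}
    x::T
    y::T
end

# Point2d <-> 2-tuple (a Tuple stands in for an SVector here).
pt_iso = IsoLens(p -> (p.x, p.y), s -> Point2d(s...))

p = Point2d(1, 2)
@assert lensget(pt_iso, p) == (1, 2)
# Lens laws: setting what you got changes nothing ...
@assert lensset(pt_iso, p, lensget(pt_iso, p)) == p
# ... and getting after a set returns what was set.
@assert lensget(pt_iso, lensset(pt_iso, p, (3, 4))) == (3, 4)
```

With a plain Vector as the target, g is undefined for vectors whose length is not 2, so f(g(s)) == s cannot hold for every s; a fixed-size target such as a tuple or SVector avoids this.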
VectorInterface is very interesting. For flat structs, it does the job!
However, I think it is not easy to obtain a VectorInterface for a nested
struct from the VectorInterfaces of its components.
Here is something that I think works better for deep structs.
Unfortunately it is hacky and has horrible performance. I think it will be some work to clean it up.
using Setfield
using Setfield: compose

struct VectorLens{LS <: Tuple} <: Lens
    scalar_lenses::LS
end

function Setfield.get(lens::VectorLens, obj)
    [get(l, obj) for l in lens.scalar_lenses]
end

function Setfield.set(l::VectorLens, obj, v)
    for (li, vi) in zip(l.scalar_lenses, v)
        obj = set(li, obj, vi)
    end
    obj
end

function veclens(lenses::Lens...)
    scalar_lenses = []
    for l in lenses
        if l isa VectorLens
            append!(scalar_lenses, l.scalar_lenses)
        else
            push!(scalar_lenses, l)
        end
    end
    VectorLens(tuple(scalar_lenses...))
end

function Setfield.compose(vl::VectorLens, lens::Lens)
    VectorLens(map(vl.scalar_lenses) do l
        compose(l, lens)
    end)
end

Base.:(∘)(l1::Lens, l2::Lens) = Setfield.compose(l1, l2)

struct Point{T}
    x::T
    y::T
end

struct Line{T}
    p1::Point{T}
    p2::Point{T}
end

point_lens = veclens((@lens _.x), (@lens _.y))
line_lens = veclens(
    point_lens ∘ (@lens _.p1),
    point_lens ∘ (@lens _.p2))
line = Line(Point(1, 2), Point(3, 4))
@show get(line_lens, line)
set(line_lens, line, [10, 20, 30, 40])
Cool. I like the compositionality of VectorLens. What is the source of the slowness? Is it because you go back and forth from vector to struct? Would emitting special code for the case where the vector is a static array help (though you can't do the same trick for get [1])? Or another way would be to build VectorInterface on top of VectorLens, so that you don't need to "materialize" a vector.
What you are trying to achieve with Line and line_lens looks related to MultiScaleArrays.jl: https://github.com/JuliaDiffEq/MultiScaleArrays.jl
Did you look at it? Though I guess what you want to achieve here (beyond the Line example) is more general purpose.
BTW, I usually like some small amount of abstract nonsense :)
[1] Maybe we can add interface Setfield.get(::Type, ::Lens, obj) for vectorized lens to decide the container type? Or maybe the vectorized lens has to have the "destination" type?
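One possible reading of footnote [1] is that the caller passes the destination container type. The sketch below is self-contained and only illustrative: `getas` is a hypothetical function (not part of Setfield), and the "vector lens" is simplified to a tuple of getter closures rather than Setfield lenses.

```julia
# Hypothetical: the first argument selects the destination container.
getas(::Type{Vector}, getters::Tuple, obj) = [g(obj) for g in getters]  # allocates a new Vector
getas(::Type{Tuple}, getters::Tuple, obj) = map(g -> g(obj), getters)   # returns a Tuple

struct Point{T}
    x::T
    y::T
end

struct Line{T}
    p1::Point{T}
    p2::Point{T}
end

# Stand-in for line_lens: one getter per scalar leaf.
line_getters = (l -> l.p1.x, l -> l.p1.y, l -> l.p2.x, l -> l.p2.y)
line = Line(Point(1, 2), Point(3, 4))

@assert getas(Vector, line_getters, line) == [1, 2, 3, 4]
@assert getas(Tuple, line_getters, line) == (1, 2, 3, 4)
```

A tuple of isbits elements can typically be returned without heap allocation, while a Vector always allocates; an SVector method could be added in the same style if a StaticArrays dependency is acceptable.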
There are multiple issues:
- The get operation is allocating. I think it shouldn't be hard to return an SVector instead. OTOH Vector is better in the case of large structs. Also, I am not so keen on depending on StaticArrays. Decisions, decisions :)
- The veclens operation is not type-stable. Again, it is not that hard to make it type-stable. It's more of an annoyance than a performance problem, since this function is probably not called in a hot loop.
- The set operation updates each field individually. This is where the real problem lies. We need a way to apply multiple lenses in a single transaction. I have a vague idea how to do this. We need a lens that can be used something like this, with a single constructor call:
struct T; a; b; c end
l = @lens _.(a,b) # Creates MultiFieldLens{(:a,:b)}()
obj = T(1,2,3)
set(l, obj, (a=10, b=20)) # T(10, 20, 3)
get(l, obj) # (a=1, b=2)
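A minimal sketch of such a lens, written without Setfield so it runs on its own. `MultiFieldLens`, `lensget`, `lensset` and `Triple` are hypothetical names, and the @lens _.(a,b) syntax is not implemented here. The point is that `lensset` rebuilds the object with one constructor call instead of one call per field.

```julia
# Hypothetical lens focusing on several fields at once; the field names are a type parameter.
struct MultiFieldLens{fields} end
MultiFieldLens(fields::Symbol...) = MultiFieldLens{fields}()

# get: read the focused fields into a NamedTuple, e.g. (a = 1, b = 2).
lensget(::MultiFieldLens{fields}, obj) where {fields} =
    NamedTuple{fields}(map(f -> getfield(obj, f), fields))

# set: merge the patch into all current field values, then call the constructor once.
# Assumes the default positional constructor accepts the merged values.
function lensset(::MultiFieldLens{fields}, obj, patch::NamedTuple) where {fields}
    issubset(keys(patch), fields) || throw(ArgumentError("patch touches fields outside the lens"))
    names = fieldnames(typeof(obj))
    current = NamedTuple{names}(map(n -> getfield(obj, n), names))
    return typeof(obj)(values(merge(current, patch))...)
end

struct Triple
    a
    b
    c
end

l = MultiFieldLens(:a, :b)
obj = Triple(1, 2, 3)
@assert lensget(l, obj) == (a = 1, b = 2)
@assert lensset(l, obj, (a = 10, b = 20)) == Triple(10, 20, 3)
```

For a parametric type whose parameters would have to change (e.g. Int fields set to Float64), `typeof(obj)(...)` may fail to convert; a real implementation would need a way to rebuild the type from its name.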
I think such a lens, which translates between objects and named tuples, would be useful in lots of other situations as well. For example, it could replace the reconstruct function with:
@set obj.(a=1, b=2)
That snippet is a good starting point, but I am not yet sure how to nest this. For instance, it would be nice to be able to do something like this:
obj = T(1,2,T(3,4,5))
l = @lens _.(a, c=(a,b))
set(l, obj, (a=10, c=(a=30, b=40))) # T(10, 2, T(30,40,5))
Syntax and implementation both need more thought in the nested case.
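One possible semantics for the nested case, sketched as a plain recursive function rather than a lens macro: a NamedTuple value in the patch means "descend into this field", and any other value replaces the field. `apply_patch` and `Triple` are hypothetical names. A known limitation of this convention is that a field cannot be set to a NamedTuple value.

```julia
# Hypothetical helper: apply a (possibly nested) NamedTuple patch to an object.
apply_patch(old, new) = new   # a non-NamedTuple value simply replaces the old one

function apply_patch(obj, patch::NamedTuple)
    names = fieldnames(typeof(obj))
    # Descend into patched fields, keep the others unchanged.
    vals = map(n -> haskey(patch, n) ? apply_patch(getfield(obj, n), patch[n]) : getfield(obj, n), names)
    # One constructor call per struct level.
    return typeof(obj)(vals...)
end

struct Triple
    a
    b
    c
end

obj = Triple(1, 2, Triple(3, 4, 5))
@assert apply_patch(obj, (a = 10, c = (a = 30, b = 40))) == Triple(10, 2, Triple(30, 40, 5))
```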
I was not aware of MultiScaleArrays.jl; it looks interesting, thanks. I need to study it. The idea of letting the caller decide the container type of a vector lens sounds good; I think we need to do that. VectorInterface can indeed be a possible output type here.
Thanks for describing the issues. MultiFieldLens would be awesome. I guess supporting destination container types like (named) tuple and static arrays would be equivalent to having VectorLens.
It would be great to get something solid that did this. There is also https://github.com/rdeits/Flatten.jl and my awful hack https://github.com/rafaqz/CompositeFieldVectors.jl.
I've been playing with flattening nested structs in various ways but mostly without great performance unless I do bad things with eval, and definitely in less theoretically sound ways than you are describing! But I'm keen to get something working using @generated functions similar to https://github.com/rdeits/Flatten.jl but more flexible.
My main use-case is rebuilding nested structs with Dual numbers for AutoDiff and updating parameters for parameter estimation routines, while using composed nested parameter structures. Often this means extracting floats, adding a Dual number wrapper type, and then rewrapping with the correct units for each field.
I also wrote this package, which turned out to be very useful for this kind of thing: for ignoring fields that need to be on the struct but not in the vector, and for separating out units and other wrapper metadata to be applied during rebuild: https://github.com/rafaqz/MetaFields.jl
Derivatives of nested structs are also one of my use cases. I think that in order to specify which pieces of your struct should be differentiated, named tuples are a must. Currently I am still on Julia 0.6, however, where named tuple support is limited. Once I am on 0.7 I will probably toy around with the ideas in this thread again.
Do you currently have a way to differentiate pieces of nested structs that is nice but slow? Can you provide an example?
I currently use this branch: https://github.com/rafaqz/CompositeFieldVectors.jl/tree/manual
I stopped using the getfield/setfield methods I was originally using with structs of type Any and now I just do a complete rebuild of nested structs/tuples to aim for type stability. Importantly for me it takes the value of units without the unit wrapper because ForwardDiff doesn't handle units.
I prefer not to use named tuples etc., as I really want it to be automatic so I can change my models without specifying anything and they just work, and there are hundreds of possible parameter combinations to write tuples for.
To do that the switch needs to be attached to fields directly, and https://github.com/rafaqz/MetaFields.jl that I built for other things happens to be great at that. It generates the @flattenable macro and methods. There are a few fields where I need to do this:
@flattenable mutable struct Test
    a::Int | false
    b::Float64 | true
end
Edit: to clarify, flattenable(Test) would return (false, true)
They are ignored and the current value is used in the reconstruct() function. Basically I'm storing the metadata in function methods. I think it's actually type-stable too, as is the flatten() function. But I'm not sure how to make reconstruct() type-stable; it probably needs to be @generated.
The whole thing has a hacky feel to it and works only in limited conditions: it won't handle arrays or other data types, just Tuples and structs with Number fields.
The interesting things are here: https://github.com/rafaqz/CompositeFieldVectors.jl/blob/manual/src/constructors.jl
I'm currently playing with using @pure for type stability, but it's fickle and I don't think it's a good strategy.
Cool, I like the @flattenable! I also think it's not too ugly to get performance using @generated functions. I usually prefer @generated to @pure, since @pure is so fragile.
In my use case I need quite a lot of flexibility. For instance, one optimization routine may tune some of the fields and the next may tune some other fields, etc. So I don't want to be forced to specify the differentiated fields at type definition.
If you change annotations with MetaFields.jl, do you need to restart the julia session?
That makes sense, I don't really have any fields like that. But you could just define another layer of metafield:
@metafield flattenable2 true
@flattenable2 @flattenable mutable struct Test
    a::Int | false | true
    b::Float64 | true | false
end
You can also redefine the methods after type definition with @reflattenable. Just write the struct again; you can omit the types, it just needs the field names and the metafield. But that's really only meant to happen once, to overwrite a package's defaults.
I think @generated is the way to go too; it's my current code that's ugly! I just need to skill up on writing @generated...
Edit: I missed that question: no, metaparam macros can be updated at any time, although I might have some @pure function in there right now that might mess with that. But my tests pass with @re- macros updating things.
Updating: my fork of Flatten.jl now incorporates the @flattenable metafield and is type stable and probably as fast as this can get for flatten/reconstruct of nested structs.
But there are world age issues with @flattenable and @generated, so using Flatten has to happen last. That breaks your need for switching fields: you actually do have to restart your session if @generated is used.
So I'm tending to agree that passing in tuples is the safest, most generic way of doing this. But it might be slower and harder to make type-stable than Flatten.jl as is, although you could use tuples of Val to make it all happen at compile time. For me type stability is currently more important, as sensitivity analysis needs thousands of reconstructs, and they end up taking far longer than the actual algorithm.
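Passing the selection as a tuple of Vals might look like the sketch below. `Params`, `fieldsym`, `getfields` and `setfields` are hypothetical names. Because the field names live in the type of the selection tuple, the compiler has a chance to resolve them statically, but whether it actually does would need checking, e.g. with @code_warntype.

```julia
struct Params
    a
    b
    c
end

# Hypothetical helpers: the selection is a tuple of Val{field}() instances.
fieldsym(::Val{f}) where {f} = f

getfields(obj, fields::Tuple) = map(v -> getfield(obj, fieldsym(v)), fields)

function setfields(obj, fields::Tuple, vals::Tuple)
    names = fieldnames(typeof(obj))
    current = NamedTuple{names}(map(n -> getfield(obj, n), names))
    patch = NamedTuple{map(fieldsym, fields)}(vals)
    # Rebuild with a single constructor call.
    return typeof(obj)(values(merge(current, patch))...)
end

p = Params(1.0, 2.0, 3.0)
sel = (Val(:a), Val(:c))   # this optimization step tunes a and c only
@assert getfields(p, sel) == (1.0, 3.0)
@assert setfields(p, sel, (10.0, 30.0)) == Params(10.0, 2.0, 30.0)
```

A different optimization step can pass a different selection without touching the type definition, which is the flexibility asked for earlier in the thread.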
Edit: Revise actually solves this.
revise(Flattenable)
revise(Flatten)
And updates work. It's admittedly still kind of ugly.
Nice! Why are there world age issues? Is this because the generated function body is not pure? Do you know which part of your generated functions is problematic? I would guess that this could be solved in principle. Since you have to annotate type definitions anyway, you could try to do the problematic work in the macro instead of in generated functions. It might make the code uglier, though.
I'm not totally sure! But no, it's not necessarily pure. Firstly, it seems @generated functions are very conservative with world age, using the point where they are declared rather than when the code is actually generated.
This is the scenario:
The main function in Flatten, @generated flatten(....), calls the methods of flattenable(T) for every nested type at compile time. flattenable(T) is supplied by the Flattenable module (currently necessary because of world age) and is first declared by @metafield flattenable when the Flattenable module loads. Then methods are added to it when actual types are declared as:
@flattenable struct SomeType
    field1::Int | false
    field2::Float64 | true
end
In the above, the struct SomeType is declared, along with flattenable(SomeType, Val{:field1}) = false and flattenable(SomeType, Val{:field2}) = true.
If the Flatten module was already loaded at that point, those methods will have a world age that is too new when the @generated flatten function is run, even though flatten can't have been compiled for SomeType already, as SomeType didn't exist before those methods were declared.
But we probably want to add more flattenable() methods later anyway, so we will always have methods that are too new - flatten() is already generated for the type. So revise(Flatten) would always be necessary at some point.
This is probably the fastest the code can be, because flatten() ends up being just a list of direct calls to each nested field that is flattenable, filling a vector or tuple. But it might be worth sacrificing some of that to avoid needing weird revise() workarounds. I'm going to test it out for a month or two before I contemplate implementing that.
Yes I think calling flattenable in the generated function is evil. Here is a gist of what I was thinking about:
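@generated function concatenate_tuples(t1::NTuple{N1,Any}, t2::NTuple{N2,Any}) where {N1,N2}
    args1 = [:(t1[$i]) for i in 1:N1]
    args2 = [:(t2[$i]) for i in 1:N2]
    Expr(:tuple, args1..., args2...)
end
# Fold over any number of tuples, so types with other than two fields work too.
concatenate_tuples() = ()
concatenate_tuples(t::Tuple) = t
concatenate_tuples(t1::Tuple, t2::Tuple, t3::Tuple, ts::Tuple...) =
    concatenate_tuples(concatenate_tuples(t1, t2), t3, ts...)

flatten(primitive::Number) = (primitive,)

@generated function flatten(x)
    args = [:(flatten(x.$field)) for field in fieldnames(x)]
    :(concatenate_tuples($(args...)))
end

macro flattenable(code)
    # Sketch: the expansion is hard-coded for S; a real macro would parse `code`
    # and emit flatten only for the fields annotated `| true`.
    esc(quote
        struct S{A,B,C}
            a::A
            b::B
            c::C
        end
        function flatten(s::S)
            concatenate_tuples(flatten(s.a), flatten(s.c))
        end
    end)
end

@flattenable struct S{A,B,C}
    a::A | true
    b::B | false
    c::C | true
end

s = S(1,2,(3,4))
@code_warntype flatten(s)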
Evil! But yeah, I admit it's pretty bad... It mostly arose from already having that method in a non-@generated context and just mashing them together while still learning how @generated actually works.
But your method looks good! Redefining flatten inside the @flattenable macro and doing nested concatenation seems like a much healthier solution, and shouldn't be much slower. We would need to do something similar for the reconstruct() methods.
Interestingly @rdeits thought the whole concept of Flatten was questionable...
Yes I also think this will be as fast as the @generated functions in most cases. Reconstruct is harder, both for the programmer and the compiler. If I come up with something, I will tell you.
Sure, I'll keep thinking about it too. This technique is going to be central to all my work for the next year.
But I'll keep using the evil version until we work something out ;)
Yes, sure you can use it; some excellent Julia packages do evil things. For example, StaticArrays mutates tuples to implement setindex! on MArray.
Let me just clarify why this is evil. In a generated function, you are only allowed to call pure functions. Calling something impure is undefined behavior. Now I am not exactly sure what a pure function is. Intuitively a function that has no side effects is pure.
But there are very subtle forms of side effects in julia. For example a function that may throw an error is impure. Or a function whose method table may change in future is impure. So any function that might get methods added in future is impure and potentially dangerous to use in a @generated function.
Yeah I think only a few core devs really know exactly what pure means in Julia...
In this case I'm pretty obviously breaking purity by overriding methods that are used in a @generated function. However, if Flatten simply loads after the methods are declared and that doesn't happen again, it's actually pure. The method just returns a constant!
Adding new methods to flattenable after Flatten is loaded is where the evil kicks in, before that I would typify it more as "dodgy" than "evil", because of the required ordering of events. It's definitely not something to put in METADATA. But as far as I know its the only existing type-stable option for running sensitivity analysis and parameterization on nested parameters with excluded fields.
Interestingly @rdeits thought the whole concept of Flatten was questionable...
That's entirely possible, but I don't actually recall. Can I help out here?
I thought you might be vaguely interested; we have a few use cases for Flatten.jl: flattening nested structs for derivatives/optimisation etc. It's really the most important problem for my modelling work at the moment.
I thought your original Flatten.jl was pretty sound conceptually, seeing what a lot of other packages do with @generated these days. But the readme is very tentative about whether it should actually be used!
My fork makes it far more dodgy than it was by adding a flattenable method that is called inside the @generated function to see which fields should be included: https://github.com/rafaqz/Flatten.jl
The issue is that in common use cases you often need to exclude a few fields on the struct, and sometimes change which fields are excluded during a session. Doing that in a type stable way is difficult without dodgy practises like breaking purity and calling revise(Flatten). @jw3126 has a better approach (see above) but it needs more thought to get the reconstruction side to be type stable.
Aha, I see. Wow, that was one of the first packages I ever wrote in Julia, so I wouldn't be surprised if it's not entirely sound :slightly_smiling_face: . It was definitely written long before I understood the hyperpure restriction on @generated function bodies.
@jw3126's suggestion does look good, although I can see how that would make it hard to update the set of included fields interactively.
One thing I've learned since I wrote Flatten.jl is to trust that the compiler can actually do a lot for you, even outside of the generated function. For example, Interpolations.jl had this bug: https://github.com/JuliaMath/Interpolations.jl/issues/207 in which a call to promote_type happened inside the generator of an @generated function. That didn't work, because promote_type gets new methods all the time. The fix was just to turn something like:
@generated function foo(x::T1, y::T2) where {T1, T2}
    R = promote_type(T1, T2)
    quote
        Vector{$R}()
    end
end
into something like this:
@generated function foo(x::T1, y::T2) where {T1, T2}
    quote
        R = promote_type($T1, $T2)
        Vector{R}()
    end
end
which looks type-unstable, but is actually fine, since the result type of promote_type is inferrable from just its argument types. Even though it looks like we've made the resulting function worse, all the run-time complexity is compiled away, so there's no performance loss. And, by moving promote_type into the quoted function, adding new promote_type methods becomes totally fine. Even though the generator is not run again, the resulting function it produced will always be recompiled when necessary.
So maybe you can do the same thing here? For example, could you have an @generated function flatten() that, in the generator, just loops over the fields of the given type and emits something like:
combine_tuples(
    if flattenable(MyStruct, Val{:A}()) isa Flattenable
        mystruct.A
    else
        tuple()
    end,
    if flattenable(MyStruct, Val{:B}()) isa Flattenable
        mystruct.B
    else
        tuple()
    end
)
for each field? Then you can implement flattenable(::Type{<:MyStruct}, ::Val{:A}) = Flattenable() or flattenable(::Type{<:MyStruct}, ::Val{:B}) = NotFlattenable().
That generated code looks very type-unstable, but, again, the compiler is pretty good at this kind of thing. Branches of the form if f(x) isa T are generally eliminated by the compiler when it specializes on the type of x, as long as the function f(x) is type-stable.
The benefit of this would be that if you change your mind and redefine flattenable(..., ::Val{:B}) = Flattenable(), even in the same Julia session, your function will be re-compiled and will just work.
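A runnable version of this suggestion, assuming current Julia 1.x semantics, might look like the sketch below. Flattenable/NotFlattenable follow the names used above; combine_tuples is given a variadic definition here, and S is borrowed from the earlier gist. The generator only enumerates field names, while the flattenable calls stay in the generated code, so adding a flattenable method later should cause flatten to be recompiled rather than requiring a restart or revise().

```julia
# Trait types as suggested above (hypothetical definitions for this sketch).
struct Flattenable end
struct NotFlattenable end

# Default: every field is flattened unless a more specific method says otherwise.
flattenable(::Type, ::Val) = Flattenable()

# Concatenate any number of tuples.
combine_tuples() = ()
combine_tuples(t::Tuple) = t
combine_tuples(t1::Tuple, t2::Tuple, ts::Tuple...) = combine_tuples((t1..., t2...), ts...)

flatten(x::Number) = (x,)
flatten(x::Tuple) = combine_tuples(map(flatten, x)...)

# The generator only lists the fields; the trait checks are emitted into the body.
@generated function flatten(x::T) where {T}
    parts = [:(flattenable($T, Val{$(QuoteNode(f))}()) isa Flattenable ?
               flatten(getfield(x, $(QuoteNode(f)))) : tuple()) for f in fieldnames(T)]
    return :(combine_tuples($(parts...)))
end

struct S{A,B,C}
    a::A
    b::B
    c::C
end

s = S(1, 2, (3, 4))
@assert flatten(s) == (1, 2, 3, 4)   # no opt-outs yet: every field is included

# Opt out later in the same session; the next call picks up the new method.
flattenable(::Type{<:S}, ::Val{:b}) = NotFlattenable()
@assert flatten(s) == (1, 3, 4)
```

Whether the branches are actually eliminated can be checked with @code_typed flatten(s).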
Thanks, that's amazing advice. You are totally correct; I had been avoiding long lists of if blocks because I often have over 50 fields and was concerned about the overhead!
But if they are optimised out, that looks like a really good approach. Doing that for reconstruction will be the hard part... Maybe something like this?
function reconstruct(original, data)
    n = 1
    MyStruct(
        if flattenable(MyStruct, Val{:A}()) isa Flattenable
            ChildStruct(
                if flattenable(ChildStruct, Val{:X}()) isa Flattenable
                    v = data[n]
                    n += 1
                    v  # the value of the if block is its last expression
                else
                    original.A.X
                end,
                ....
            )
        else
            original.A
        end,
        ...
    )
end
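The reconstruction side can also be written without a mutable counter by threading the not-yet-consumed data through the recursion. This is a sketch using the Flattenable/NotFlattenable traits suggested above (redefined so the block stands alone), with `_rebuild`, `_rebuild_each` and `reconstruct` as illustrative names. It rebuilds structs through `Base.typename(T).wrapper`, a Base internal, so that type parameters are re-inferred when the new values have a different type (e.g. dual numbers).

```julia
# Hypothetical trait types, following the names suggested above.
struct Flattenable end
struct NotFlattenable end
flattenable(::Type, ::Val) = Flattenable()   # default: include every field

# Every _rebuild method returns (rebuilt value, data not consumed yet).
_rebuild(x::Number, data::Tuple) = (first(data), Base.tail(data))   # a leaf takes one value

# Walk a tuple of values, rebuilding only those whose flag is Flattenable().
_rebuild_each(::Tuple{}, ::Tuple{}, data::Tuple) = ((), data)
function _rebuild_each(vals::Tuple, flags::Tuple, data::Tuple)
    v, data1 = if first(flags) isa Flattenable
        _rebuild(first(vals), data)
    else
        (first(vals), data)   # excluded: keep the original value, consume nothing
    end
    rest_vals, data2 = _rebuild_each(Base.tail(vals), Base.tail(flags), data1)
    return ((v, rest_vals...), data2)
end

# Tuples: every element is included.
_rebuild(x::Tuple, data::Tuple) = _rebuild_each(x, map(e -> Flattenable(), x), data)

# Structs: consult flattenable per field, then call the (parameter-free) constructor once.
function _rebuild(x, data::Tuple)
    T = typeof(x)
    names = fieldnames(T)
    flags = map(n -> flattenable(T, Val(n)), names)
    newvals, rest = _rebuild_each(map(n -> getfield(x, n), names), flags, data)
    return (Base.typename(T).wrapper(newvals...), rest)
end

reconstruct(original, data::Tuple) = first(_rebuild(original, data))

struct S{A,B,C}
    a::A
    b::B
    c::C
end
flattenable(::Type{<:S}, ::Val{:b}) = NotFlattenable()

s = S(1, 2, (3, 4))
@assert reconstruct(s, (10, 30, 40)) == S(10, 2, (30, 40))
@assert reconstruct(s, (1.5, 3.5, 4.5)) == S(1.5, 2, (3.5, 4.5))   # parameters re-inferred
```

Type stability is not guaranteed by this sketch; it would need to be checked with @code_warntype for the structs in question.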
Yeah, Julia's pretty good at this stuff (but the only way to be sure is to try it and benchmark for yourself). But I think the rule of thumb is that the generator part of your @generated function is good for unrolling loops or enumerating fields, while the actual code you generate can often handle a lot of the hard work of type inference.
Also, this is all going to get even better in v0.7 with interprocedural constant propagation. I think you'll be able to do something like:
if flattenable(MyStruct, Val(:A))
    x
else
    y
end
and then just define:
flattenable(::Type{<:MyStruct}, ::Val{:A}) = true
(rather than using the Flattenable and NotFlattenable types), because Julia will be able to propagate that constant true value through your function in order to evaluate the if statement at compile time (the same way it currently can propagate type information)
I got into this last night and have a fully working (non-evil!) version using a combination of the methods we discussed here.
https://github.com/rafaqz/Flatten.jl
You can update fields at any point by overwriting flattenable methods or using @flattenable. It uses Flat() and NotFlat() now for flattenable field values, but I will transition back to true/false for 0.7 if constant propagation works.
Thanks for all the help getting this working! I'll give due credit in the readme.
but I think it is harder for the compiler
@jw3126 Continuing the discussion from https://github.com/jw3126/Setfield.jl/pull/34#discussion_r210092772: isn't it the same effort for the compiler if we generate the code in an @generated function?