
mapobs problems with composite datasets and vector indexes

Open CarloLucibello opened this issue 2 years ago • 0 comments

With tuples:

julia> t = (zeros(3), ones(3))
([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])

julia> m = mapobs(x -> (x[1], 2*x[2]), t)
mapobs(#15, Tuple{Vector{Float64}, Vector{Float64}})

julia> m[1]
(0.0, 2.0)

julia> m[2]
(0.0, 2.0)

julia> m[1:2]
((0.0, 0.0), (1.0, 2.0))  # expected ([0.0, 0.0], [2.0, 2.0]) 

With named tuples:

julia> t = (a=zeros(3),b=ones(3))
(a = [0.0, 0.0, 0.0], b = [1.0, 1.0, 1.0])

julia> m = mapobs(x -> (x[1], 2*x[2]), t)
mapobs(#17, NamedTuple{(:a, :b), Tuple{Vector{Float64}, Vector{Float64}}})

julia> m[1]
(0.0, 2.0)

julia> m[2]
(0.0, 2.0)

julia> m[1:2]
ERROR: ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved
Stacktrace:
 [1] broadcastable(#unused#::NamedTuple{(:a, :b), Tuple{Vector{Float64}, Vector{Float64}}})
   @ Base.Broadcast ./broadcast.jl:705
 [2] broadcasted
   @ ./broadcast.jl:1295 [inlined]
 [3] getindex(data::MLUtils.MappedData{var"#17#18", NamedTuple{(:a, :b), Tuple{Vector{Float64}, Vector{Float64}}}}, idxs::UnitRange{Int64})
   @ MLUtils ~/.julia/packages/MLUtils/8OXl7/src/obstransform.jl:14
 [4] top-level scope
   @ REPL[35]:1
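
Both results are consistent with `f` being broadcast over the fields of the already-indexed container instead of being applied once per observation, which is what the `broadcasted` frame in the stack trace suggests. The sketch below reproduces that behavior in plain Julia without MLUtils. The names `batch_t`, `batch_nt`, `per_field`, and `nt_error` are made up for illustration, and the assumption is that indexing the dataset with `1:2` slices each field.

```julia
f = x -> (x[1], 2 * x[2])

# Assumed shape of the data after indexing with 1:2 (each field sliced).
batch_t  = ([0.0, 0.0], [1.0, 1.0])
batch_nt = (a = [0.0, 0.0], b = [1.0, 1.0])

# Broadcasting over a Tuple calls f once per field, not once per observation.
per_field = f.(batch_t)   # ((0.0, 0.0), (1.0, 2.0))

# Broadcasting over a NamedTuple is reserved in Base and throws.
nt_error = try
    f.(batch_nt)
    nothing
catch err
    err
end
```

Here `per_field` matches the unexpected tuple output above, and `nt_error` holds the same kind of `ArgumentError` as in the named tuple case.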

The fix could be turning https://github.com/JuliaML/MLUtils.jl/blob/1da3c53c1b6a5c2e4ce51ab74df358f02c17d1bf/src/obstransform.jl#L14

into

Base.getindex(data::MappedData, idxs::AbstractVector) = batch([data.f(getobs(data.data, i)) for i in idxs])
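
The idea behind the proposed fix (apply `f` to each observation, then batch the results) can be tried out in isolation. This is a minimal sketch with toy stand-ins, not the MLUtils implementation: `ToyMapped`, `toy_getobs`, and `toy_batch` are hypothetical names that only mimic the roles of `MappedData`, `getobs`, and `batch` for tuples and named tuples of vectors.

```julia
# Toy stand-ins for illustration only; these are NOT the MLUtils definitions.
struct ToyMapped{F,D}
    f::F
    data::D
end

# Observation i of a tuple / named tuple of vectors: index every field.
toy_getobs(data::Union{Tuple,NamedTuple}, i) = map(x -> x[i], data)

# Stack a vector of same-shaped tuples field by field.
toy_batch(obs::AbstractVector{<:Tuple}) =
    ntuple(k -> [o[k] for o in obs], length(first(obs)))

# Single index: transform one observation.
Base.getindex(m::ToyMapped, i::Integer) = m.f(toy_getobs(m.data, i))

# Vector of indices: transform each observation, then batch the results.
Base.getindex(m::ToyMapped, idxs::AbstractVector) =
    toy_batch([m.f(toy_getobs(m.data, i)) for i in idxs])

f = x -> (x[1], 2 * x[2])
mt  = ToyMapped(f, (zeros(3), ones(3)))
mnt = ToyMapped(f, (a = zeros(3), b = ones(3)))

mt[1:2]    # ([0.0, 0.0], [2.0, 2.0])
mnt[1:2]   # ([0.0, 0.0], [2.0, 2.0])
```

With this per-observation approach the tuple and named tuple datasets give the same batched result, at the cost of calling `f` once per index rather than once per batch.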

CarloLucibello, Jun 30 '22 03:06