MLUtils.jl
mapobs problems with composite datasets and vector indexes
With tuples:
julia> t = (zeros(3), ones(3))
([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
julia> m = mapobs(x -> (x[1], 2*x[2]), t)
mapobs(#15, Tuple{Vector{Float64}, Vector{Float64}})
julia> m[1]
(0.0, 2.0)
julia> m[2]
(0.0, 2.0)
julia> m[1:2]
((0.0, 0.0), (1.0, 2.0)) # expected ([0.0, 0.0], [2.0, 2.0])
With named tuples:
julia> t = (a=zeros(3),b=ones(3))
(a = [0.0, 0.0, 0.0], b = [1.0, 1.0, 1.0])
julia> m = mapobs(x -> (x[1], 2*x[2]), t)
mapobs(#17, NamedTuple{(:a, :b), Tuple{Vector{Float64}, Vector{Float64}}})
julia> m[1]
(0.0, 2.0)
julia> m[2]
(0.0, 2.0)
julia> m[1:2]
ERROR: ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved
Stacktrace:
[1] broadcastable(#unused#::NamedTuple{(:a, :b), Tuple{Vector{Float64}, Vector{Float64}}})
@ Base.Broadcast ./broadcast.jl:705
[2] broadcasted
@ ./broadcast.jl:1295 [inlined]
[3] getindex(data::MLUtils.MappedData{var"#17#18", NamedTuple{(:a, :b), Tuple{Vector{Float64}, Vector{Float64}}}}, idxs::UnitRange{Int64})
@ MLUtils ~/.julia/packages/MLUtils/8OXl7/src/obstransform.jl:14
[4] top-level scope
@ REPL[35]:1
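Both results are consistent with `getindex` broadcasting `f` over the container returned by `getobs`, as the `broadcasted` frame in the stack trace suggests. Here is a minimal plain-Julia sketch, assuming that is what line 14 does. The variable names are illustrative and no MLUtils code is involved:

```julia
f = x -> (x[1], 2 * x[2])

# Assumed result of getobs(t, 1:2) for the tuple dataset: one vector per field.
fields = ([0.0, 0.0], [1.0, 1.0])

# Broadcasting over a Tuple applies f to each field, not to each observation.
tuple_result = f.(fields)   # ((0.0, 0.0), (1.0, 2.0))

# Assumed result of getobs(t, 1:2) for the named-tuple dataset.
named = (a = [0.0, 0.0], b = [1.0, 1.0])

# Broadcasting over a NamedTuple is reserved in Base, so this throws.
threw_argument_error = try
    f.(named)
    false
catch err
    err isa ArgumentError
end
```

In the tuple case `f` receives a whole field vector (`[0.0, 0.0]`, then `[1.0, 1.0]`), which explains the `((0.0, 0.0), (1.0, 2.0))` output above.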
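The stack trace points to the broadcast in `getindex` at src/obstransform.jl:14. A possible fix is to replace that definition (https://github.com/JuliaML/MLUtils.jl/blob/1da3c53c1b6a5c2e4ce51ab74df358f02c17d1bf/src/obstransform.jl#L14) with one that applies `f` to each observation and then batches the results: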
Base.getindex(data::MappedData, idxs::AbstractVector) = Flux.batch([data.f(getobs(data.data, i)) for i in idxs])
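To illustrate the idea without depending on MLUtils or Flux, here is a self-contained sketch. `MappedSketch`, `getobs_sketch`, and `batch_sketch` are hypothetical stand-ins written for this example, not the library's definitions; `batch_sketch` only handles tuple observations:

```julia
# Hypothetical stand-in for MLUtils.MappedData.
struct MappedSketch{F,D}
    f::F
    data::D
end

# Minimal getobs: index vectors directly, recurse into (named) tuples.
getobs_sketch(data::AbstractVector, i) = data[i]
getobs_sketch(data::Union{Tuple,NamedTuple}, i) = map(d -> getobs_sketch(d, i), data)

# Minimal batch: collate a vector of tuples into a tuple of vectors.
batch_sketch(xs::AbstractVector{<:Tuple}) =
    ntuple(k -> [x[k] for x in xs], length(first(xs)))

# Single index: transform one observation.
Base.getindex(m::MappedSketch, i::Integer) = m.f(getobs_sketch(m.data, i))

# Vector index: transform each observation, then batch the results.
Base.getindex(m::MappedSketch, idxs::AbstractVector) =
    batch_sketch([m.f(getobs_sketch(m.data, i)) for i in idxs])

f = x -> (x[1], 2 * x[2])
m_tuple = MappedSketch(f, (zeros(3), ones(3)))
m_named = MappedSketch(f, (a = zeros(3), b = ones(3)))

m_tuple[1:2]  # ([0.0, 0.0], [2.0, 2.0])
m_named[1:2]  # ([0.0, 0.0], [2.0, 2.0])
```

Because `f` is applied per observation, the vector-index path agrees with the single-index path for both the tuple and the named-tuple dataset. The cost is one call to `f` per index instead of one call on the whole batch.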