Tracker.jl
preserve size for param(::AbstractArray)
According to https://fluxml.ai/Flux.jl/stable/performance/#Don't-use-more-precision-than-you-need.-1 we should use types that are as small as possible. Implementing
param(xs::AbstractArray) = TrackedArray(float.(xs))
seems somewhat counterintuitive, as it tries to convert to Float64. Thus, if I have already restricted my own data types to Float32, this will just widen them again. Might special-casing smaller float types be a good idea here?
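For context, `float` only widens non-float inputs: integers go to Float64, while Float32 and Float16 keep their width. The sketch below uses a hypothetical helper `tofloat` (not part of Tracker.jl, and defaulting to Float32 purely as an assumption) to show what an element-type-preserving conversion could look like.

```julia
# Hypothetical helper, not part of Tracker.jl: arrays that already have a
# float element type are returned as-is; anything else (e.g. integers)
# is converted to the requested float type, Float32 by default.
tofloat(xs::AbstractArray{<:AbstractFloat}, ::Type{T} = Float32) where {T<:AbstractFloat} = xs
tofloat(xs::AbstractArray, ::Type{T} = Float32) where {T<:AbstractFloat} = T.(xs)

ints = [1, 2, 3]            # Vector{Int}
xs32 = Float32[0.5, 1.5]    # Vector{Float32}

println(eltype(float.(xs32)))   # Float32: float keeps an existing float width
println(eltype(float.(ints)))   # Float64: integers are promoted to Float64
println(eltype(tofloat(ints)))  # Float32 under this sketch
println(eltype(tofloat(xs32)))  # Float32, returned unchanged
```

The design choice here is that an explicit float element type is treated as the user's decision and never touched; only the "no precision chosen yet" case (integers) falls back to a default.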
This doesn't convert Float32 to Float64:
julia> float(1.0f0)
1.0f0
I could see how you might expect param(1) * param(1f0) to be Float32 though. I'm not sure what we can do for this; we could assume Float32 is enough precision for an integer, but then we'd have the same issue with Float16 and so on.
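The widening in `param(1) * param(1f0)` comes from Julia's standard promotion rules rather than from Tracker itself: once one operand is Float64 (as `float(1)` is), the result is Float64. A minimal illustration using only Base:

```julia
a = float(1)    # Float64, because float(::Int) returns Float64
b = 1f0         # Float32 literal

println(typeof(a * b))                  # Float64: mixed widths promote to the wider type
println(promote_type(Float64, Float32)) # Float64
println(typeof(Float32(1) * b))         # Float32: converting explicitly keeps it narrow
println(promote_type(Float32, Float16)) # Float32: the same widening happens one level down
```

The last line is the Float16 analogue of the problem: any fixed default for integers will widen values of some narrower float type.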
> This doesn't convert Float32 to Float64:
> julia> float(1.0f0)
> 1.0f0
I feel a slight urge to bury myself in the ground. 🤦‍♂️
> I could see how you might expect param(1) * param(1f0) to be Float32 though. I'm not sure what we can do for this; we could assume Float32 is enough precision for an integer, but then we'd have the same issue with Float16 and so on.
How about providing something like setprecision!(::Model, ::Type) (or even one for the whole session/Tracker)?
It would then convert everything to the given precision. Maybe also add an option that throws if a value is not exactly representable (to help find bugs). That would make it much easier to just say "ah, I guess Float32 is enough for this problem", and everything needed in the tracking process would automatically become Float32.
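A rough sketch of the proposal, under loud assumptions: `convert_precision` and `PrecisionLossError` are hypothetical names (not an existing Tracker.jl or Flux.jl API), and model parameters are modelled as a plain `Dict` of arrays so the example runs without any packages. The `strict` option implements the "throw if not representable" idea as a round-trip equality check.

```julia
# Hypothetical error type raised when strict conversion would lose information.
struct PrecisionLossError <: Exception
    name::Symbol
    value::Real
end

Base.showerror(io::IO, e::PrecisionLossError) =
    print(io, "parameter ", e.name, ": value ", e.value,
          " is not exactly representable in the target type")

# Hypothetical stand-in for the proposed setprecision!: converts every
# parameter array to element type T. With strict = true, it throws as soon
# as a converted value no longer compares equal to the original.
function convert_precision(params::AbstractDict, ::Type{T};
                           strict::Bool = false) where {T<:AbstractFloat}
    out = Dict{Symbol,AbstractArray{T}}()
    for (name, xs) in params
        ys = T.(xs)
        if strict
            for (x, y) in zip(xs, ys)
                # NaN never compares equal, so treat NaN -> NaN as lossless.
                (y == x || (isnan(x) && isnan(y))) ||
                    throw(PrecisionLossError(name, x))
            end
        end
        out[name] = ys
    end
    return out
end

params = Dict(:W => [0.5 0.25; 1.0 2.0], :b => [0.1, 0.2])
p32 = convert_precision(params, Float32)   # 0.1 and 0.2 are silently rounded
println(eltype(p32[:W]))                   # Float32
```

Calling `convert_precision(params, Float32; strict = true)` on the same input throws, because 0.1 has no exact Float32 representation, which is the kind of bug-finding aid the option is meant to provide.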