Tracker.jl

preserve size for param(::AbstractArray)

Open rapus95 opened this issue 6 years ago • 2 comments

According to https://fluxml.ai/Flux.jl/stable/performance/#Don't-use-more-precision-than-you-need.-1 we should use the smallest types that suffice. Implementing

param(xs::AbstractArray) = TrackedArray(float.(xs))

seems somewhat counterintuitive, as it tries to convert to Float64. Thus, if I have already restricted my own data types to Float32, this will just widen them again. Might special-casing smaller float types be a good idea here?

rapus95, Nov 07 '19 16:11

This doesn't convert Float32 to Float64:

julia> float(1.0f0)
1.0f0
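As a small sketch using only Base Julia (no Tracker required), the same behaviour can be checked for other float types and for arrays. The variable names here are illustrative only:

```julia
# float leaves floating-point inputs at their existing precision...
a = float(1.0f0)        # stays Float32
b = float(Float16(1))   # stays Float16
# ...and only non-float inputs such as integers become Float64.
c = float(1)

# The same holds elementwise, which is what float.(xs) does for arrays.
xs32  = Float32[1, 2, 3]
xsint = [1, 2, 3]
t32  = eltype(float.(xs32))    # Float32
tint = eltype(float.(xsint))   # Float64
```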

I could see how you might expect param(1) * param(1f0) to be Float32 though. I'm not sure what we can do for this; we could assume Float32 is enough precision for an integer, but then we'd have the same issue with Float16 and so on.
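The widening in the mixed case comes from ordinary type promotion. The sketch below uses plain Base numbers rather than tracked values, and assumes that param applied to an integer stores float(1), in the same way the array method above applies float:

```julia
# Assumption: param(1) would hold float(1), which is a Float64.
x = float(1)
y = 1.0f0

# Mixing Float64 with Float32 promotes the result to Float64.
z = x * y
tz = typeof(z)

# The wider type wins regardless of how small the other operand is.
p32 = promote_type(Float64, Float32)
p16 = promote_type(Float64, Float16)

# By contrast, a raw integer does not widen a Float32.
tw = typeof(1 * 1f0)
```

The last line is the behaviour one might intuitively expect from param(1) * param(1f0); the difference is that the integer has already been converted to Float64 before the multiplication happens.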

MikeInnes, Nov 07 '19 16:11

> This doesn't convert Float32 to Float64:
>
> julia> float(1.0f0)
> 1.0f0

I feel a slight need to bury myself in the ground. 🤦‍♂

> I could see how you might expect param(1) * param(1f0) to be Float32 though. I'm not sure what we can do for this; we could assume Float32 is enough precision for an integer, but then we'd have the same issue with Float16 and so on.

How about providing some setprecision!(::Model, ::Type) (or even one for the whole session/Tracker) and then converting everything to the given precision? Maybe consider adding an option that throws if a value is not representable (to help find bugs). That would make it a lot easier to just say, "Ah, I guess Float32 is enough for this problem," and everything needed in the tracking process would automatically become Float32.
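A rough sketch of the conversion step such a function might perform on a single array, written against Base Julia only. The name to_precision and its strict keyword are hypothetical and are not part of Tracker or Flux; walking over the parameters of a whole model is left out:

```julia
# Hypothetical helper: convert an array to element type T.
# With strict=true, throw if any value cannot be represented exactly in T.
function to_precision(::Type{T}, xs::AbstractArray; strict::Bool=false) where {T<:AbstractFloat}
    ys = T.(xs)
    if strict
        # Any rounding during conversion makes the values compare unequal.
        # Note that NaN entries also compare unequal and would be reported.
        all(ys .== xs) || throw(ArgumentError("values are not exactly representable as $T"))
    end
    return ys
end

w1 = to_precision(Float32, [1, 2, 3])                  # integers become Float32
w2 = to_precision(Float32, [0.5, 0.25]; strict=true)   # exact in Float32, so no error

# 0.1 is rounded when narrowed to Float32, so the strict check rejects it.
rejected = try
    to_precision(Float32, [0.1]; strict=true)
    false
catch err
    err isa ArgumentError
end
```

The strict check is deliberately conservative: most decimal literals are rounded when narrowed, so in practice it would probably be more useful for catching unexpectedly large integers or overflow than for everyday Float64 data.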

rapus95, Nov 07 '19 16:11