DeviceList/with_device inconsistency
Because with_device() in TensorFlow.jl expects 1-based device numbers, the output of DeviceList() cannot be fed directly into with_device(): DeviceList() returns zero-indexed device names.
My current workaround is to do something like the following:
    function get_device(sess, device_type)
        # Find first device with the given device type (e.g. `"XLA_GPU"`)
        devices = collect(TensorFlow.DeviceList(sess))
        device = first(filter(x -> x.device_type == device_type, devices))

        # Fixup this device name so that it is 1-indexed, as TensorFlow.jl requires
        inc_number(x::Number) = x + 1
        inc_number(x) = x
        function fixup_device!(d::TensorFlow.Device)
            d.parts[:] .= [TensorFlow.DevicePart(p.kind, inc_number(p.index)) for p in d.parts]
            return d
        end
        fixup_device!(d) = d

        return fixup_device!(TensorFlow.Device(device.name))
    end
Personally, I would prefer that with_device() used zero-indexed device names. On systems with multiple devices (e.g. multiple GPUs), it is an unnecessary mental burden to always remember that /job:job/replica:1/task:1/device:CPU:1 in TensorFlow.jl does not refer to the same device as /job:job/replica:1/task:1/device:CPU:1 anywhere else in the TensorFlow ecosystem. Regardless, we should be consistent so that the output of one function can be fed to another within TensorFlow.jl.
How about a compromise for now where all functions related to devices take a keyword argument for whether to interpret them as one-based or zero-based? I don't want to break anyone's code that relies on the current 1-based system for now.
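To make the idea concrete, here is a rough sketch of what such a keyword could look like. It works purely on device name strings and does not touch TensorFlow.jl itself. The names shift_device_indices, normalize_device_name and the zero_based keyword are hypothetical and are not part of the current API. It also assumes that every purely numeric field in the name (replica, task, device number) should be shifted, which is what the workaround above does.

```julia
# Hypothetical helper, not part of TensorFlow.jl: shift every purely numeric
# field of a device name such as "/job:job/replica:0/task:0/device:CPU:0".
function shift_device_indices(name::AbstractString, offset::Integer)
    # A numeric field is a run of digits that directly follows ':' and is
    # followed by '/' or the end of the string.
    return replace(name, r"(?<=:)\d+(?=/|$)" => s -> string(parse(Int, s) + offset))
end

# Hypothetical wrapper showing the proposed keyword: always return the
# 1-based name, whichever convention the caller passed in.
function normalize_device_name(name::AbstractString; zero_based::Bool=false)
    return zero_based ? shift_device_indices(name, 1) : String(name)
end

# A zero-indexed name (as DeviceList() reports it) becomes 1-based:
normalize_device_name("/job:job/replica:0/task:0/device:CPU:0"; zero_based=true)
# "/job:job/replica:1/task:1/device:CPU:1"
```

With the default zero_based=false the name is passed through unchanged, so existing code that already uses 1-based names would keep working.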
The easy solution for now is to do this fixup_device!() kind of thing within DeviceList(). I think anybody who is already doing something like this to make the output of DeviceList() usable won't mind, and everybody else can continue using 1-based indexing. I can't imagine anybody who relies on the incorrect values from DeviceList() being unhappy about no longer needing a workaround like this.
Ya you're right. Want to submit a PR?
Rather than fixing it after the fact, we could "fix" the heuristic that determines whether things are indexes or not:
https://github.com/malmaud/TensorFlow.jl/blob/8a28acb020bef2ff95787687cdfb47bd33c6e821/src/generate_ops.jl#L98
But I think that might be more confusing, since it spreads DeviceList-specific code into more files.