DeviceList/with_device inconsistency

Open staticfloat opened this issue 7 years ago • 4 comments

Because TensorFlow.jl's with_device() expects 1-based device numbering, the output of DeviceList() cannot be passed directly to with_device(): DeviceList() returns zero-indexed device names.

My current workaround is to do something like the following:

function get_device(sess, device_type)
    # Find first device with the given device type (e.g. `"XLA_GPU"`)
    devices = collect(TensorFlow.DeviceList(sess))
    device = first(filter(x -> x.device_type == device_type, devices))

    # Fixup this device name so that it is 1-indexed, as TensorFlow.jl requires
    inc_number(x::Number) = x + 1
    inc_number(x) = x
    function fixup_device!(d::TensorFlow.Device)
        d.parts[:] .= [TensorFlow.DevicePart(p.kind, inc_number(p.index)) for p in d.parts]
        return d
    end
    fixup_device!(d) = d
    return fixup_device!(TensorFlow.Device(device.name))
end

Personally, I would prefer that with_device used zero-indexed device names. On systems with multiple devices (e.g. multiple GPUs), it is an unnecessary mental burden to always remember that /job:job/replica:1/task:1/device:CPU:1 in TensorFlow.jl does not name the same device as the identical string /job:job/replica:1/task:1/device:CPU:1 anywhere else in the TensorFlow ecosystem. Regardless, we should be consistent so that the output of one function can be fed to another within TensorFlow.jl.
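
To make the off-by-one concrete, here is a minimal sketch that works on plain device-name strings and does not touch TensorFlow.jl at all. The helper names (`shift_device_indices`, `to_one_based`, `to_zero_based`) are hypothetical, and the example device name is only an illustrative guess at what a zero-indexed name looks like. The sketch assumes every numeric component of the name is an index that should be shifted, which mirrors what the workaround above does.

```julia
# Hypothetical helper, not part of TensorFlow.jl.
# Shift every numeric component of a device name by `offset`.
# Assumption: a component value made only of digits is always an index.
function shift_device_indices(name::AbstractString, offset::Integer)
    # Match a run of digits that directly follows ':' and is followed
    # by '/' or the end of the string.
    return replace(name, r"(?<=:)\d+(?=/|$)" => m -> string(parse(Int, m) + offset))
end

to_one_based(name) = shift_device_indices(name, 1)
to_zero_based(name) = shift_device_indices(name, -1)

# Illustrative zero-indexed name (an assumed example, not captured output).
zero_based = "/job:localhost/replica:0/task:0/device:XLA_GPU:0"
one_based = to_one_based(zero_based)
println(one_based)  # /job:localhost/replica:1/task:1/device:XLA_GPU:1
```

Non-numeric components such as the job name and the device type are left alone, and `to_zero_based(to_one_based(name))` round-trips back to the original string.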

staticfloat avatar Oct 28 '18 15:10 staticfloat

How about a compromise for now where all functions related to devices take a keyword argument for whether to interpret them as one-based or zero-based? I don't want to break anyone's code which is relying on the current 1-based system for now.
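
A rough sketch of what such a keyword argument could look like, written as standalone Julia so it runs without TensorFlow.jl. The function names (`device_index`, `device_string`) and the `one_based` keyword are hypothetical and are not taken from the package. The default keeps the current 1-based behaviour so that existing callers would be unaffected.

```julia
# Hypothetical sketch, not TensorFlow.jl's API.
# Convert a user-supplied index into the zero-based index used by
# the rest of the TensorFlow ecosystem.
function device_index(i::Integer; one_based::Bool=true)
    idx = one_based ? i - 1 : i
    idx >= 0 || throw(ArgumentError("device index out of range: $i"))
    return idx
end

# Build a zero-based device string from a device kind and an index.
function device_string(kind::AbstractString, i::Integer; one_based::Bool=true)
    return "/device:$(kind):$(device_index(i; one_based=one_based))"
end

println(device_string("GPU", 1))                   # /device:GPU:0
println(device_string("GPU", 1; one_based=false))  # /device:GPU:1
```

The same string comes out for "first GPU" whether the caller counts from one (the default) or passes `one_based=false` with index 0.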

malmaud avatar Oct 29 '18 15:10 malmaud

The easy solution for now is to do this fixup_device() kind of thing within DeviceList(); I think anybody who is already doing this kind of thing to make the output of DeviceList() usable won't mind, and everybody else can continue using 1-based indexing. I can't imagine anybody relying upon the incorrect values of DeviceList() being unhappy about no longer needing to use a workaround like this.

staticfloat avatar Oct 29 '18 17:10 staticfloat

Ya you're right. Want to submit a PR?

malmaud avatar Oct 30 '18 15:10 malmaud

Rather than fixing it after the fact, we could "fix" the heuristic that determines whether things are indexes or not: https://github.com/malmaud/TensorFlow.jl/blob/8a28acb020bef2ff95787687cdfb47bd33c6e821/src/generate_ops.jl#L98

But I think that might be more confusing, since it spreads DeviceList-specific code into more files.

oxinabox avatar Oct 31 '18 11:10 oxinabox