Metric3D
Supporting old GPUs?
Hello, thanks again for the great work! Your model uses torch.bfloat16, which is only supported natively by newer GPUs (Ampere and later).
https://github.com/YvanYin/Metric3D/blob/7b5440dcbc17ef5e09805169a5f0b2d6bfe59161/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py#L218-L229
Could you please support older GPUs by adding an option to use torch.float32 instead? It could be as simple as

dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

and then passing dtype to autocast.