TensorLayerX icon indicating copy to clipboard operation
TensorLayerX copied to clipboard

Pytorch后端NHWC和NCHW问题

Open Windaway opened this issue 2 years ago • 0 comments

New Issue Checklist

Issue Description

Pytorch后端模型定义NHWC和NCHW数据格式主要是定以数据和模型后传到设备时用.to("cuda:0", memory_format=torch.channels_last)确定。

TLX目前做法是pytorch依据nhwc格式时,全部转NCHW然后处理完转回来,这潜在是让模型用NCHW格式计算。对纯GPU应用时问题不大,但是对于一些NHWC友好的设备部署,比如未来的Mindspore,由于多次nhwc nchw切换,性能有损失。

这里可能需要框架对于pytorch这里nhwc支持改成全局变量,即输入时数据做nchw-nhwc,模型转nhwc然后计算即可。

不过Pytorch本身GPU NHWC支持稀烂,倒不是很急。

Windaway avatar Feb 07 '23 12:02 Windaway