求教 expr 中_Fill 的正确用法
平台: Android
Github版本: 2.8.2
问题:使用如下代码 #if 1扩起来的代码,但是结果不符合预期,使用 #else 括起来的符合预期,但是VARP 内存在GPU 上时可能不合理。
需求:填充某个 VARP 的数据全为 0 的合理方法,能兼容多种推理后端,谢谢!
template <typename T>
static void __VARP_FillValue(VARP* varp, T value) {
auto mInfo = (*varp)->getInfo();
if (mInfo->dim.empty()) {
auto destPtr = (*varp)->writeMap<void>();
::memcpy(destPtr, &value, sizeof(value));
return;
}
#if 1
// TODO: these code can't fill zero value
auto var_shape = _Input({(int)mInfo->dim.size()}, NCHW, halide_type_of<int>());
auto shapePtr = var_shape->writeMap<int>();
memcpy(shapePtr, &mInfo->dim[0], mInfo->dim.size() * sizeof(int));
VARP tmpVarp = _Fill(_Shape(var_shape, mInfo->order == Dimensionformat::NCHW), _Scalar<T>(value));
tmpVarp = _Convert(tmpVarp, mInfo->order);
auto sourcePtr = tmpVarp->readMap<T>();
auto destPtr = (*varp)->writeMap<T>();
if (mInfo->size && destPtr && sourcePtr) {
::memcpy(destPtr, sourcePtr, mInfo->size * mInfo->type.bytes());
}
#else
// TODO: 这段代码解决了设置 0 值问题,但是在 GPU 上可能不合理
auto destPtr = (*varp)->writeMap<T>();
std::fill_n(destPtr, mInfo->size, value);
#endif
}
调用如下:
auto inVarp = _Input({6,1,100,14}, NCHW, halide_type_of<float>());
__VARP_FillValue(&inVarp, 0.f);
#if 1 括起来的代码测试结果如下(截取了部分):
0 0 0 0 -0.148806 0.0477117 0.173962 -0.0790967 -0.0392254 -0.0149113 0.00637543 -0.146912 -0.0296952 0.0189911 1.68626e-21 1.62551e-43 0 0 0 0 -0.0292555 -0.0300227 -0.0241131 0.235818 0.131943 0.0236826 0.139795 -0.187976 0.0447565 0.0244021 -0.112458 0.212011 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -24049 1.62551e-43 1.686e-21 1.62551e-43
auto var_shape = _Input({(int)mInfo->dim.size()}, NCHW, halide_type_of<int>());
auto shapePtr = var_shape->writeMap<int>();
memcpy(shapePtr, &mInfo->dim[0], mInfo->dim.size() * sizeof(int));
VARP tmpVarp = _Fill(_Shape(var_shape, mInfo->order == Dimensionformat::NCHW), _Scalar<T>(value));
这段逻辑是有问题的,这样 fill 出来的 shape 是 {dim.size()} 而非 dim 数组
改成 VARP tmpVarp = _Fill(_Shape(*varp, mInfo->order == Dimensionformat::NCHW), _Scalar<T>(value));
建议写法是: template <typename T> static void __VARP_FillValue(VARP* varp, T value) { auto mInfo = (*varp)->getInfo(); VARP tmpVarp = _Fill(_Shape(*varp, mInfo->order == Dimensionformat::NCHW), _Scalar<T>(value)); tmpVarp = _Convert(tmpVarp, mInfo->order); *varp = tmpVarp; } 避免多余的内存拷贝
auto var_shape = _Input({(int)mInfo->dim.size()}, NCHW, halide_type_of<int>()); auto shapePtr = var_shape->writeMap<int>(); memcpy(shapePtr, &mInfo->dim[0], mInfo->dim.size() * sizeof(int)); VARP tmpVarp = _Fill(_Shape(var_shape, mInfo->order == Dimensionformat::NCHW), _Scalar<T>(value));这段逻辑是有问题的,这样 fill 出来的 shape 是 {dim.size()} 而非 dim 数组
改成 VARP tmpVarp = _Fill(_Shape(*varp, mInfo->order == Dimensionformat::NCHW), _Scalar(value));
建议写法是: template static void __VARP_FillValue(VARP* varp, T value) { auto mInfo = (*varp)->getInfo(); VARP tmpVarp = _Fill(_Shape(*varp, mInfo->order == Dimensionformat::NCHW), _Scalar(value)); tmpVarp = _Convert(tmpVarp, mInfo->order); *varp = tmpVarp; } 避免多余的内存拷贝
谢谢指导,确实这样做可以设置为 0 了。
但是您写的 *varp = tmpVarp; 这段会引发了一个问题,就是在将 inVarp 作为模型的 input,模型的 outpus.size() 是不正确的。在 __VARP_FillValue 中保留原来的拷贝方式,输出就是ok 的。
示例代码如下,预期out_vars1.size() == 3。
auto inVarp = _Input({6,1,100,14}, NCHW, halide_type_of<float>());
__VARP_FillValue(&inVarp, 0.f);
std::vector<VARP> in_vars1({ invar_chunk_mel_, inVarp});
auto out_vars1 = module_contenter_->onForward(in_vars1);
assert(out_vars1.size() == 3);
可以试试在 onForward 之前加一句 inVarp.fix(VARP::InputType::CONSTANT)
Marking as stale. No activity in 60 days.