Shellbye.github.io
Shellbye.github.io copied to clipboard
C++遍历读取tensorflow::Tensor
在将Python
代码转为C++
的过程中,需要一步一步的查看以确保转化过程的正确性,对于常见的数据类型,基本打印出来就可以进行对比查验,但是对于模型输出的Tensor
,本以为它就是个多维向量而已,循环打印之就好,但是缺也是费了很大的周折,最终还是通过看源码找到了解决方案。
因为我也刚开始接触TenserFlow
,所以很多东西还没有搞特别明白,比如C++
模型的输出,在我们的项目中,输出是一个std::vector<Tensor>
,我想做的事儿就是遍历这个Tensor
,查看里面的数据是否与我们的Python
版一致。
在尝试了各种方法之后,最后通过通读文档,凭感觉觉得这个vec
可能就是我要找的东西,然后发现了Eigen
这么一个概念,继续顺藤摸瓜,又找到了这篇博客,最终确定我需要的应该是tensor
这个方法,
代码如下:
Status run_status = session_->Run(input, { "BiasAdd:0", }, {}, &outputs);
if (run_status.ok())
{
auto f = outputs[0];
auto t0 = f.tensor<float, 3>(); // 3来自f.shape()
std::cout << "shape " << f.shape() << std::endl; // [1,255,43]
for (int i = 0; i < 1; i++) {
for (int j = 0; j < 255; j++) {
for (int k = 0; k < 43; k++) {
std::cout << " " << t0(i, j, k);
}
std::cout << std::endl;
}
}
}