Rosetta icon indicating copy to clipboard operation
Rosetta copied to clipboard

请问下经Rosetta处理后,返回的tensor都是维度不可知的,后续有些api需要确定的维度怎么办呢

Open xiaoshui240 opened this issue 2 years ago • 11 comments

我尝试使用Rosetta在transformer模型上,使用rtt.PrivateTextLineDataset的方式读取数据集,在普通代码下测试没有问题,加上Rosetta出现类型和维度的问题。 1.有些计算依赖tensor是int或者float,而Rosetta返回的都是string类型,我进行了强制转换; 2.还出现了维度的问题,返回的rtttensor好像是掩去了tensor的shape,显示为unknown类型,但有些tensor是在训练过程中才计算shape,但后续有些计算要求tensor的rank或者shape必须是已知的,我中途使用reshape方式强制添加shape,但后续也报错。 报错:ValueError: Input 0 of layer rtt_dense_1 is incompatible with the layer: its rank is undefined, but the layer requires a defined rank. 3.还发现一个问题,我代码里面用了tf.train.get_or_create_global_step(),引入Rosetta后直接报错 ValueError: Tensor conversion requested dtype int64 for Tensor with dtype string: 'Tensor("AsString_76:0", shape=(), dtype=string)' @EM~T 30 JN $TD_TH}TP 想问下是我使用Rosetta方式不对【比如不能强制转换类型和维度】还是Rosetta对部分运算的确还不支持呢 谢谢

xiaoshui240 avatar Jun 08 '22 15:06 xiaoshui240

我尝试使用Rosetta在transformer模型上,使用rtt.PrivateTextLineDataset的方式读取数据集,在普通代码下测试没有问题,加上Rosetta出现类型和维度的问题。 1.有些计算依赖tensor是int或者float,而Rosetta返回的都是string类型,我进行了强制转换; 2.还出现了维度的问题,返回的rtttensor好像是掩去了tensor的shape,显示为unknown类型,但有些tensor是在训练过程中才计算shape,但后续有些计算要求tensor的rank或者shape必须是已知的,我中途使用reshape方式强制添加shape,但后续也报错。 报错:ValueError: Input 0 of layer rtt_dense_1 is incompatible with the layer: its rank is undefined, but the layer requires a defined rank. 3.还发现一个问题,我代码里面用了tf.train.get_or_create_global_step(),引入Rosetta后直接报错 ValueError: Tensor conversion requested dtype int64 for Tensor with dtype string: 'Tensor("AsString_76:0", shape=(), dtype=string)' @EM~T 30 JN $TD_TH}TP 想问下是我使用Rosetta方式不对【比如不能强制转换类型和维度】还是Rosetta对部分运算的确还不支持呢 谢谢

shape 显示 unkown 一般是 shape inference 没有打开导致,rosetta 默认根据 python 版本选择是否打开 shape inference 功能,如果是 python 版本是 3.6.9 及以下版本就 disable shape inference ,python 版本是 3.7 以上的就 enable shape inference。 至于为什么我们这样选择:原因是我们发现在 python 3.6.x 上打开 shape inference 会导致 TF 和 Rosetta 存在不兼容的情况。 (确认下你的 python 版本是不是 3.6.x)

yuucyf avatar Jun 09 '22 01:06 yuucyf

对的,我的刚好是3.6.9,必须要3.6.9以上或者是3.7版本才行吗

xiaoshui240 avatar Jun 09 '22 01:06 xiaoshui240

对的,我的刚好是3.6.9,必须要3.6.9以上或者是3.7版本才行吗

image

你可以尝试看一下 rosetta.sh 脚本大概就知道逻辑了。

yuucyf avatar Jun 09 '22 01:06 yuucyf

好的,谢谢

对的,我的刚好是3.6.9,必须要3.6.9以上或者是3.7版本才行吗

image

你可以尝试看一下 rosetta.sh 脚本大概就知道逻辑了。

xiaoshui240 avatar Jun 09 '22 01:06 xiaoshui240

对了,如果我强制转换rtttensor的类型后还能完成隐私计算吗,像下面一条语句,再怎么转换也不行 outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)

xiaoshui240 avatar Jun 09 '22 02:06 xiaoshui240

tf.where 算子目前的 rosetta 版本还没有支持,你可以在算子文档中查看一下 rosetta 到底支持哪些算子。 (PS:还有一些已经支持的算子还未进行正式开源)

yuucyf avatar Jun 09 '22 02:06 yuucyf

好的,谢谢解答

xiaoshui240 avatar Jun 09 '22 02:06 xiaoshui240

tf.where 算子目前的 rosetta 版本还没有支持,你可以在算子文档中查看一下 rosetta 到底支持哪些算子。 (PS:还有一些已经支持的算子还未进行正式开源)

你好,想再请教下,我现在已经重装了Rosetta,python也重装成3.7了,shape的问题解决了,但在两个三维tensor相乘却报错了,ValueError: Shape must be rank 2 but is rank 3 for 'encoder/num_blocks_0/multihead_attention/scaled_dot_product_attention/RttMatmul_3' (op: 'RttMatmul') with input shapes: [?,?,32], [?,32,?].代码是outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])),这个在普通tensorlfow代码下测试没有问题,这个报错不是说两者rank不同吗,但[?,?,32], [?,32,?]的两个tensor的rank都是3,怎么会报错呢

xiaoshui240 avatar Jun 12 '22 07:06 xiaoshui240

tf.where 算子目前的 rosetta 版本还没有支持,你可以在算子文档中查看一下 rosetta 到底支持哪些算子。 (PS:还有一些已经支持的算子还未进行正式开源)

你好,想再请教下,我现在已经重装了Rosetta,python也重装成3.7了,shape的问题解决了,但在两个三维tensor相乘却报错了,ValueError: Shape must be rank 2 but is rank 3 for 'encoder/num_blocks_0/multihead_attention/scaled_dot_product_attention/RttMatmul_3' (op: 'RttMatmul') with input shapes: [?,?,32], [?,32,?].代码是outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])),这个在普通tensorlfow代码下测试没有问题,这个报错不是说两者rank不同吗,但[?,?,32], [?,32,?]的两个tensor的rank都是3,怎么会报错呢

可以贴上 tensorflow 下测试没有问题的简单片段代码我们验证一下,谢谢。

yuucyf avatar Jun 13 '22 01:06 yuucyf

tf.where 算子目前的 rosetta 版本还没有支持,你可以在算子文档中查看一下 rosetta 到底支持哪些算子。 (PS:还有一些已经支持的算子还未进行正式开源)

你好,想再请教下,我现在已经重装了Rosetta,python也重装成3.7了,shape的问题解决了,但在两个三维tensor相乘却报错了,ValueError: Shape must be rank 2 but is rank 3 for 'encoder/num_blocks_0/multihead_attention/scaled_dot_product_attention/RttMatmul_3' (op: 'RttMatmul') with input shapes: [?,?,32], [?,32,?].代码是outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])),这个在普通tensorlfow代码下测试没有问题,这个报错不是说两者rank不同吗,但[?,?,32], [?,32,?]的两个tensor的rank都是3,怎么会报错呢

可以贴上 tensorflow 下测试没有问题的简单片段代码我们验证一下,谢谢。

不好意思,上个月事情太多没有来得及回复。 我使用tensorflow官方提供的案例进行了验证,好像二维tensor相乘没有问题,三维tensor相乘就出现我上面提到的Shape must be rank 2 but is rank 3 for 'RttMatmul' (op: 'RttMatmul') with input shapes: [2,2,3], [2,3,2].这个问题了。 测试代码如下:

import tensorflow as tf
import latticex.rosetta as rtt
import numpy as np

# 二维tensor相乘测试
# a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
# b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])

# 三维tensor相乘测试
a = tf.constant(np.arange(1, 13, dtype=np.int32), shape=[2, 2, 3])
b = tf.constant(np.arange(13, 25, dtype=np.int32), shape=[2, 3, 2])

c = tf.matmul(a, b)

xiaoshui240 avatar Jul 05 '22 15:07 xiaoshui240

tf.where 算子目前的 rosetta 版本还没有支持,你可以在算子文档中查看一下 rosetta 到底支持哪些算子。 (PS:还有一些已经支持的算子还未进行正式开源)

你好,想再请教下,我现在已经重装了Rosetta,python也重装成3.7了,shape的问题解决了,但在两个三维tensor相乘却报错了,ValueError: Shape must be rank 2 but is rank 3 for 'encoder/num_blocks_0/multihead_attention/scaled_dot_product_attention/RttMatmul_3' (op: 'RttMatmul') with input shapes: [?,?,32], [?,32,?].代码是outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])),这个在普通tensorlfow代码下测试没有问题,这个报错不是说两者rank不同吗,但[?,?,32], [?,32,?]的两个tensor的rank都是3,怎么会报错呢

可以贴上 tensorflow 下测试没有问题的简单片段代码我们验证一下,谢谢。

不好意思,上个月事情太多没有来得及回复。 我使用tensorflow官方提供的案例进行了验证,好像二维tensor相乘没有问题,三维tensor相乘就出现我上面提到的Shape must be rank 2 but is rank 3 for 'RttMatmul' (op: 'RttMatmul') with input shapes: [2,2,3], [2,3,2].这个问题了。 测试代码如下:

import tensorflow as tf
import latticex.rosetta as rtt
import numpy as np

# 二维tensor相乘测试
# a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
# b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])

# 三维tensor相乘测试
a = tf.constant(np.arange(1, 13, dtype=np.int32), shape=[2, 2, 3])
b = tf.constant(np.arange(13, 25, dtype=np.int32), shape=[2, 3, 2])

c = tf.matmul(a, b)

请参考 (https://github.com/LatticeX-Foundation/Rosetta/issues/119)

yuucyf avatar Jul 25 '22 10:07 yuucyf