How to prune a custom tensor? The tensor is a recursive variable and is initialized with tf.zeros.
Describe the bug
How to prune a custom tensor? The tensor is a custom variable and is initialized with tf.zeros.
System information
TensorFlow version (installed from source or binary):
TensorFlow Model Optimization version (installed from source or binary):
Python version: 3.8
Describe the expected behavior
Describe the current behavior
How can the tensor "b" be pruned?
Code to reproduce the issue
import tensorflow as tf
import tensorflow_model_optimization as tfmot


class PruningLayer(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):

    def __init__(self, n, d):
        super().__init__()
        self.n = n
        self.d = d

    def build(self, input_shape):
        # Trainable weight; this is the only variable currently exposed for pruning.
        self.weight = self.add_weight(
            name="weight",
            shape=[1, input_shape[1], self.n, self.d, input_shape[2]],
            initializer="random_normal",
            trainable=True)

    def call(self, x):
        u = tf.matmul(self.weight, x)
        b = self.Rr(u)
        s = tf.multiply(x, b)
        return s

    def get_prunable_weights(self):
        return [self.weight]

    def Rr(self, x):
        input_shape = tf.shape(x)
        # initialize b to zero
        b = tf.zeros((input_shape[0], input_shape[1], self.n, 1))
        # b is a plain tensor that is updated iteratively, not a tracked variable
        for _ in range(3):
            c = tf.nn.softmax(b, axis=2)
            b = b + tf.multiply(x, c)
        return b
Hi @starsky68,
Can you please provide more context? In particular:
- What is self.weight for? It looks like it is never used in the layer.
- What prevents b from being a class member? You could then return it in get_prunable_weights(self).
I have modified the sample code above. self.weight can be returned from get_prunable_weights, but I don't know whether 'b' can also be returned from get_prunable_weights. Here 'b' is an iteratively updated tensor. I hope to get some help.
With TF1, I could apply the pruning interface apply_mask directly to a tensor, but the interface seems to have changed in TF2, and such operations no longer appear to be supported.
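For reference, masking an intermediate tensor by hand in TF2 might look roughly like the sketch below. It is not part of the tfmot API: the helper names magnitude_mask and apply_magnitude_mask are hypothetical, and the sketch assumes that "pruning b" means zeroing its smallest-magnitude entries once per call, with no pruning schedule and no persistent mask variable.

```python
import tensorflow as tf


def magnitude_mask(tensor, sparsity):
    """Hypothetical helper: 0/1 mask that keeps the largest-magnitude entries.

    `sparsity` is the fraction of entries to zero out (an assumption of this
    sketch, not a tfmot parameter).
    """
    abs_values = tf.abs(tf.reshape(tensor, [-1]))
    num_elements = tf.cast(tf.size(abs_values), tf.float32)
    # Number of entries to keep, never fewer than one.
    num_keep = tf.maximum(
        tf.cast(tf.round(num_elements * (1.0 - sparsity)), tf.int32), 1)
    # The smallest of the kept magnitudes serves as the threshold.
    top_values, _ = tf.math.top_k(abs_values, k=num_keep)
    threshold = top_values[-1]
    return tf.cast(tf.abs(tensor) >= threshold, tensor.dtype)


def apply_magnitude_mask(tensor, sparsity):
    """Hypothetical helper: element-wise multiply by the magnitude mask."""
    return tensor * magnitude_mask(tensor, sparsity)


# Example: zero out half of the entries of a small tensor.
b = tf.constant([[0.1, -2.0], [3.0, -0.05]])
pruned_b = apply_magnitude_mask(b, sparsity=0.5)
```

Inside the layer above, such a helper could be called on b at the end of Rr. Whether that matches the intended pruning behavior depends on the use case, since b is recomputed from the input on every call rather than stored as a weight.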