TensorLayerX
TensorLayerX copied to clipboard
tensorlayerx.ops.Pad不支持“channels_first”的data_format,后续会补充“channels_first”的格式吗?
New Issue Checklist
- [ ] I have read the Contribution Guidelines
- [ ] I searched for existing GitHub issues
Issue Description
[INSERT DESCRIPTION OF THE PROBLEM]
Reproducible Code
- Which OS are you using ?
- Please provide a reproducible code of your issue. Without any reproducible code, you will probably not receive any help.
[INSERT CODE HERE]
# ======================================================== #
###### tensorlayerx.ops.Pad源码######
# ======================================================== #
class Pad(object):
def __init__(self, paddings, mode="REFLECT", constant_values=0):
if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
raise Exception("Unsupported mode: {}".format(mode))
if mode == 'SYMMETRIC':
raise NotImplementedError
self.paddings = paddings
self.mode = mode.lower()
self.constant_values = constant_values
def __call__(self, x):
if len(x.shape) == 3:
data_format = 'NLC'
self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
elif len(x.shape) == 4:
data_format = 'NHWC'
self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
elif len(x.shape) == 5:
data_format = 'NDHWC'
self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
else:
raise NotImplementedError('Please check the input shape.')
return pd.nn.functional.pad(x, self.paddings, self.mode, value=self.constant_values, data_format=data_format)
def correct_paddings(self, in_shape, paddings, data_format):
if in_shape == 3 and data_format == 'NLC':
correct_output = [paddings[1][0], paddings[1][1]]
elif in_shape == 4 and data_format == 'NHWC':
correct_output = [paddings[2][0], paddings[2][1], paddings[1][0], paddings[1][1]]
elif in_shape == 5 and data_format == 'NDHWC':
correct_output = [
paddings[3][0], paddings[3][1], paddings[2][0], paddings[2][1], paddings[1][0], paddings[1][1]
]
else:
raise NotImplementedError('Does not support channels first')
return correct_output