TensorLayerX icon indicating copy to clipboard operation
TensorLayerX copied to clipboard

tensorlayerx.ops.Pad不支持“channels_first”的data_format,后续会补充“channels_first”的格式吗?

Open zhxiucui opened this issue 2 years ago • 0 comments

New Issue Checklist

Issue Description

[INSERT DESCRIPTION OF THE PROBLEM]

Reproducible Code

  • Which OS are you using ?
  • Please provide a reproducible code of your issue. Without any reproducible code, you will probably not receive any help.

[INSERT CODE HERE]

# ======================================================== #
###### tensorlayerx.ops.Pad源码######
# ======================================================== #

class Pad(object):

    def __init__(self, paddings, mode="REFLECT", constant_values=0):
        if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
            raise Exception("Unsupported mode: {}".format(mode))
        if mode == 'SYMMETRIC':
            raise NotImplementedError
        self.paddings = paddings
        self.mode = mode.lower()
        self.constant_values = constant_values

    def __call__(self, x):
        if len(x.shape) == 3:
            data_format = 'NLC'
            self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
        elif len(x.shape) == 4:
            data_format = 'NHWC'
            self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
        elif len(x.shape) == 5:
            data_format = 'NDHWC'
            self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
        else:
            raise NotImplementedError('Please check the input shape.')
        return pd.nn.functional.pad(x, self.paddings, self.mode, value=self.constant_values, data_format=data_format)

    def correct_paddings(self, in_shape, paddings, data_format):
        if in_shape == 3 and data_format == 'NLC':
            correct_output = [paddings[1][0], paddings[1][1]]
        elif in_shape == 4 and data_format == 'NHWC':
            correct_output = [paddings[2][0], paddings[2][1], paddings[1][0], paddings[1][1]]
        elif in_shape == 5 and data_format == 'NDHWC':
            correct_output = [
                paddings[3][0], paddings[3][1], paddings[2][0], paddings[2][1], paddings[1][0], paddings[1][1]
            ]
        else:
            raise NotImplementedError('Does not support channels first')
        return correct_output

zhxiucui avatar Dec 05 '22 03:12 zhxiucui