basenji
Using the same architecture and training parameters in PyTorch, the model fails to converge.
Hello, I have rewritten Akita in PyTorch to make debugging it easier. I am using the same training parameters as your TensorFlow implementation, but the model does not converge. On the same dataset, your TensorFlow version of Akita converges as expected.
`----------------------------------------------------------------
Layer (type)           Output Shape           Param #
----------------------------------------------------------------
Conv1d-1 [-1, 96, 1048576] 4,224
BatchNorm1d-2 [-1, 96, 1048576] 192
ReLU-3 [-1, 96, 1048576] 0
MaxPool1d-4 [-1, 96, 524288] 0
Conv1d-5 [-1, 96, 524288] 46,080
BatchNorm1d-6 [-1, 96, 524288] 192
ReLU-7 [-1, 96, 524288] 0
MaxPool1d-8 [-1, 96, 262144] 0
Conv1d-9 [-1, 96, 262144] 46,080
BatchNorm1d-10 [-1, 96, 262144] 192
ReLU-11 [-1, 96, 262144] 0
MaxPool1d-12 [-1, 96, 131072] 0
Conv1d-13 [-1, 96, 131072] 46,080
BatchNorm1d-14 [-1, 96, 131072] 192
ReLU-15 [-1, 96, 131072] 0
MaxPool1d-16 [-1, 96, 65536] 0
Conv1d-17 [-1, 96, 65536] 46,080
BatchNorm1d-18 [-1, 96, 65536] 192
ReLU-19 [-1, 96, 65536] 0
MaxPool1d-20 [-1, 96, 32768] 0
Conv1d-21 [-1, 96, 32768] 46,080
BatchNorm1d-22 [-1, 96, 32768] 192
ReLU-23 [-1, 96, 32768] 0
MaxPool1d-24 [-1, 96, 16384] 0
Conv1d-25 [-1, 96, 16384] 46,080
BatchNorm1d-26 [-1, 96, 16384] 192
ReLU-27 [-1, 96, 16384] 0
MaxPool1d-28 [-1, 96, 8192] 0
Conv1d-29 [-1, 96, 8192] 46,080
BatchNorm1d-30 [-1, 96, 8192] 192
ReLU-31 [-1, 96, 8192] 0
MaxPool1d-32 [-1, 96, 4096] 0
Conv1d-33 [-1, 96, 4096] 46,080
BatchNorm1d-34 [-1, 96, 4096] 192
ReLU-35 [-1, 96, 4096] 0
MaxPool1d-36 [-1, 96, 2048] 0
Conv1d-37 [-1, 96, 2048] 46,080
BatchNorm1d-38 [-1, 96, 2048] 192
ReLU-39 [-1, 96, 2048] 0
MaxPool1d-40 [-1, 96, 1024] 0
Conv1d-41 [-1, 96, 1024] 46,080
BatchNorm1d-42 [-1, 96, 1024] 192
ReLU-43 [-1, 96, 1024] 0
MaxPool1d-44 [-1, 96, 512] 0
Conv1d-45 [-1, 48, 512] 13,824
BatchNorm1d-46 [-1, 48, 512] 96
ReLU-47 [-1, 48, 512] 0
Conv1d-48 [-1, 96, 512] 4,608
BatchNorm1d-49 [-1, 96, 512] 192
ReLU-50 [-1, 96, 512] 0
Dropout-51 [-1, 96, 512] 0
Residual-52 [-1, 96, 512] 0
Conv1d-53 [-1, 48, 512] 13,824
BatchNorm1d-54 [-1, 48, 512] 96
ReLU-55 [-1, 48, 512] 0
Conv1d-56 [-1, 96, 512] 4,608
BatchNorm1d-57 [-1, 96, 512] 192
ReLU-58 [-1, 96, 512] 0
Dropout-59 [-1, 96, 512] 0
Residual-60 [-1, 96, 512] 0
Conv1d-61 [-1, 48, 512] 13,824
BatchNorm1d-62 [-1, 48, 512] 96
ReLU-63 [-1, 48, 512] 0
Conv1d-64 [-1, 96, 512] 4,608
BatchNorm1d-65 [-1, 96, 512] 192
ReLU-66 [-1, 96, 512] 0
Dropout-67 [-1, 96, 512] 0
Residual-68 [-1, 96, 512] 0
Conv1d-69 [-1, 48, 512] 13,824
BatchNorm1d-70 [-1, 48, 512] 96
ReLU-71 [-1, 48, 512] 0
Conv1d-72 [-1, 96, 512] 4,608
BatchNorm1d-73 [-1, 96, 512] 192
ReLU-74 [-1, 96, 512] 0
Dropout-75 [-1, 96, 512] 0
Residual-76 [-1, 96, 512] 0
Conv1d-77 [-1, 48, 512] 13,824
BatchNorm1d-78 [-1, 48, 512] 96
ReLU-79 [-1, 48, 512] 0
Conv1d-80 [-1, 96, 512] 4,608
BatchNorm1d-81 [-1, 96, 512] 192
ReLU-82 [-1, 96, 512] 0
Dropout-83 [-1, 96, 512] 0
Residual-84 [-1, 96, 512] 0
Conv1d-85 [-1, 48, 512] 13,824
BatchNorm1d-86 [-1, 48, 512] 96
ReLU-87 [-1, 48, 512] 0
Conv1d-88 [-1, 96, 512] 4,608
BatchNorm1d-89 [-1, 96, 512] 192
ReLU-90 [-1, 96, 512] 0
Dropout-91 [-1, 96, 512] 0
Residual-92 [-1, 96, 512] 0
Conv1d-93 [-1, 48, 512] 13,824
BatchNorm1d-94 [-1, 48, 512] 96
ReLU-95 [-1, 48, 512] 0
Conv1d-96 [-1, 96, 512] 4,608
BatchNorm1d-97 [-1, 96, 512] 192
ReLU-98 [-1, 96, 512] 0
Dropout-99 [-1, 96, 512] 0
Residual-100 [-1, 96, 512] 0
Conv1d-101 [-1, 48, 512] 13,824
BatchNorm1d-102 [-1, 48, 512] 96
ReLU-103 [-1, 48, 512] 0
Conv1d-104 [-1, 96, 512] 4,608
BatchNorm1d-105 [-1, 96, 512] 192
ReLU-106 [-1, 96, 512] 0
Dropout-107 [-1, 96, 512] 0
Residual-108 [-1, 96, 512] 0
DilatedResidual1D-109 [-1, 96, 512] 0
Conv1d-110 [-1, 64, 512] 30,720
BatchNorm1d-111 [-1, 64, 512] 128
ReLU-112 [-1, 64, 512] 0
OneToTwo-113 [-1, 64, 512, 512] 0
ConcatDist2D-114 [-1, 65, 512, 512] 0
Conv2d-115 [-1, 48, 512, 512] 28,080
BatchNorm2d-116 [-1, 48, 512, 512] 96
ReLU-117 [-1, 48, 512, 512] 0
Symmetrize2D-118 [-1, 48, 512, 512] 0
Conv2d-119 [-1, 24, 512, 512] 10,368
BatchNorm2d-120 [-1, 24, 512, 512] 48
ReLU-121 [-1, 24, 512, 512] 0
Conv2d-122 [-1, 48, 512, 512] 1,152
BatchNorm2d-123 [-1, 48, 512, 512] 96
ReLU-124 [-1, 48, 512, 512] 0
Dropout-125 [-1, 48, 512, 512] 0
Residual-126 [-1, 48, 512, 512] 0
Symmetrize2D-127 [-1, 48, 512, 512] 0
Conv2d-128 [-1, 24, 512, 512] 10,368
BatchNorm2d-129 [-1, 24, 512, 512] 48
ReLU-130 [-1, 24, 512, 512] 0
Conv2d-131 [-1, 48, 512, 512] 1,152
BatchNorm2d-132 [-1, 48, 512, 512] 96
ReLU-133 [-1, 48, 512, 512] 0
Dropout-134 [-1, 48, 512, 512] 0
Residual-135 [-1, 48, 512, 512] 0
Symmetrize2D-136 [-1, 48, 512, 512] 0
Conv2d-137 [-1, 24, 512, 512] 10,368
BatchNorm2d-138 [-1, 24, 512, 512] 48
ReLU-139 [-1, 24, 512, 512] 0
Conv2d-140 [-1, 48, 512, 512] 1,152
BatchNorm2d-141 [-1, 48, 512, 512] 96
ReLU-142 [-1, 48, 512, 512] 0
Dropout-143 [-1, 48, 512, 512] 0
Residual-144 [-1, 48, 512, 512] 0
Symmetrize2D-145 [-1, 48, 512, 512] 0
Conv2d-146 [-1, 24, 512, 512] 10,368
BatchNorm2d-147 [-1, 24, 512, 512] 48
ReLU-148 [-1, 24, 512, 512] 0
Conv2d-149 [-1, 48, 512, 512] 1,152
BatchNorm2d-150 [-1, 48, 512, 512] 96
ReLU-151 [-1, 48, 512, 512] 0
Dropout-152 [-1, 48, 512, 512] 0
Residual-153 [-1, 48, 512, 512] 0
Symmetrize2D-154 [-1, 48, 512, 512] 0
Conv2d-155 [-1, 24, 512, 512] 10,368
BatchNorm2d-156 [-1, 24, 512, 512] 48
ReLU-157 [-1, 24, 512, 512] 0
Conv2d-158 [-1, 48, 512, 512] 1,152
BatchNorm2d-159 [-1, 48, 512, 512] 96
ReLU-160 [-1, 48, 512, 512] 0
Dropout-161 [-1, 48, 512, 512] 0
Residual-162 [-1, 48, 512, 512] 0
Symmetrize2D-163 [-1, 48, 512, 512] 0
Conv2d-164 [-1, 24, 512, 512] 10,368
BatchNorm2d-165 [-1, 24, 512, 512] 48
ReLU-166 [-1, 24, 512, 512] 0
Conv2d-167 [-1, 48, 512, 512] 1,152
BatchNorm2d-168 [-1, 48, 512, 512] 96
ReLU-169 [-1, 48, 512, 512] 0
Dropout-170 [-1, 48, 512, 512] 0
Residual-171 [-1, 48, 512, 512] 0
DilatedResidual2D-172 [-1, 48, 512, 512] 0
Cropping2D-173 [-1, 48, 448, 448] 0
UpperTri-174 [-1, 48, 99681] 0
Linear-175 [-1, 99681, 5] 245
Final-176 [-1, 99681, 5] 0
----------------------------------------------------------------
Total params: 746,149
Trainable params: 746,149
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 16.00
Forward/backward pass size (MB): 10473.61
Params size (MB): 2.85
Estimated Total Size (MB): 10492.46
----------------------------------------------------------------
SeqNN( (feature_extractor_1d): Sequential( (0): Sequential( (0): Conv1d(4, 96, kernel_size=(11,), stride=(1,), padding=(5,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (1): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (2): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (3): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (4): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (5): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (6): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (7): Sequential( (0): 
Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (8): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (9): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (10): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (11): DilatedResidual1D( (layers): Sequential( (0): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(1,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (1): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (2): Residual( (fn): 
Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (3): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (4): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (5): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (6): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(29,), dilation=(29,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, 
eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (7): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(50,), dilation=(50,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) ) ) (12): Sequential( (0): Conv1d(96, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(64, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() ) ) (feature_extractor_2d): Sequential( (0): Conv2d(65, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): DilatedResidual2D( (layers): Sequential( (0): Symmetrize2D() (1): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (2): Symmetrize2D() (3): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (4): Symmetrize2D() (5): Residual( (fn): Sequential( (0): Conv2d(48, 24, 
kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (6): Symmetrize2D() (7): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (8): Symmetrize2D() (9): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(9, 9), dilation=(9, 9), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (10): Symmetrize2D() (11): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(16, 16), dilation=(16, 16), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) ) ) ) (oneto_two): OneToTwo() (concat_dist_2d): ConcatDist2D() (crop_2d): Cropping2D() (uppertri): UpperTri() (final): Final( (dense): Linear(in_features=48, out_features=5, bias=True) ) )`
I compared the model's predicted values against the true values. In the early phase of training, the predicted and true values had roughly the same mean, but the gap in variance was large, and the R² value remained elevated throughout training.
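As an illustration (a hedged sketch, not the author's actual evaluation code), the mean gap, variance gap, and R² described above can be checked with a short NumPy snippet; the array names and toy data are hypothetical:

```python
import numpy as np

def prediction_stats(y_true, y_pred):
    """Compare first/second moments of predictions vs. targets and compute R^2."""
    mean_gap = float(np.mean(y_pred) - np.mean(y_true))
    var_gap = float(np.var(y_pred) - np.var(y_true))
    ss_res = float(np.sum((y_true - y_pred) ** 2))           # residual sum of squares
    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    return mean_gap, var_gap, r2

# toy example: predictions with the right mean but compressed variance,
# mimicking the symptom described above
rng = np.random.default_rng(0)
y_true = rng.normal(0.0, 1.0, size=10_000)
y_pred = 0.2 * y_true  # same mean (~0), much smaller variance
mean_gap, var_gap, r2 = prediction_stats(y_true, y_pred)
```

A collapsed variance with a matched mean, as here, is a common signature of a model stuck predicting near the target mean.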
I would be grateful for any advice you can give!
Moving across frameworks can be very challenging. You'll basically have to double-check that every detail matches, e.g. make sure the initialization is handled the same way. You might want to train simplified TensorFlow models with our code alongside your PyTorch version to isolate specific components.
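One concrete example of such a detail (a general observation about framework defaults, not a claim about this particular port): Keras layers default to Glorot (Xavier) uniform initialization, while PyTorch conv/linear layers default to a Kaiming-style uniform init with `a=sqrt(5)`, so weight scales differ unless the PyTorch side is re-initialized explicitly. A minimal comparison of the two uniform bounds for Akita's first conv (4 input channels, 96 filters, kernel width 11):

```python
import math

# fan_in / fan_out for a Conv1d(4, 96, kernel_size=11), computed the standard way
fan_in = 4 * 11    # in_channels * kernel_size
fan_out = 96 * 11  # out_channels * kernel_size

# Keras default: glorot_uniform -> U(-limit, limit)
glorot_limit = math.sqrt(6.0 / (fan_in + fan_out))

# PyTorch Conv1d default: kaiming_uniform_ with a=sqrt(5)
# gain = sqrt(2 / (1 + a^2)), bound = gain * sqrt(3 / fan_in)
gain = math.sqrt(2.0 / (1.0 + 5.0))
kaiming_bound = gain * math.sqrt(3.0 / fan_in)
```

For this layer the PyTorch default bound comes out roughly twice the Glorot bound, which is exactly the kind of silent mismatch worth ruling out.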
Thank you very much for your reply. I have ensured that the parameters are consistent at each layer and have used the same initialization method as in TensorFlow. The dataset is also the same as the original Akita dataset. Initially, I suspected that the issue might be with the Adam optimizer, but after switching to SGD, the problem still persists.
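Another framework difference worth ruling out (a general note; the summary above already shows `momentum=0.0735`, which suggests it was handled): Keras and PyTorch define BatchNorm momentum oppositely. Keras updates running statistics as `momentum * running + (1 - momentum) * batch`, while PyTorch uses `(1 - momentum) * running + momentum * batch`, so a Keras momentum of 0.9265 corresponds to a PyTorch momentum of 0.0735. A small sketch of the two update rules:

```python
def keras_bn_update(running, batch_stat, momentum):
    # Keras convention: momentum weights the *old* running statistic
    return momentum * running + (1.0 - momentum) * batch_stat

def pytorch_bn_update(running, batch_stat, momentum):
    # PyTorch convention: momentum weights the *new* batch statistic
    return (1.0 - momentum) * running + momentum * batch_stat

running, batch_stat = 0.5, 2.0
a = keras_bn_update(running, batch_stat, 0.9265)   # TF-style momentum
b = pytorch_bn_update(running, batch_stat, 0.0735) # equivalent PyTorch momentum
```

With the complementary momentum values, both conventions produce identical running statistics.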
hi @bioczsun -- did you try loading the trained model into your pytorch re-implementation rather than re-training? If so, I wonder if it gave similar predictions on test set sequences.
Hi @gfudenberg, do you mean loading the pre-trained parameters from TensorFlow into the PyTorch model?
hi @bioczsun yes, exactly
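For reference, the mechanical part of such a transfer is mostly an axis permutation; this is a minimal sketch using NumPy arrays in place of the real TF/PyTorch tensors, with hypothetical shapes taken from the layer summary above:

```python
import numpy as np

# TF/Keras Conv1D stores kernels as (kernel_width, in_channels, out_channels);
# PyTorch Conv1d expects (out_channels, in_channels, kernel_width).
tf_kernel = np.zeros((11, 4, 96))             # stand-in for the TF conv weights
pt_kernel = np.transpose(tf_kernel, (2, 1, 0))

# Dense layers are transposed too: Keras (in, out) -> PyTorch Linear (out, in)
tf_dense = np.zeros((48, 5))                  # stand-in for the final dense weights
pt_dense = tf_dense.T
```

If the ported architecture is truly equivalent, loading the permuted TF weights should reproduce the TF predictions on test sequences; any divergence then points to a structural mismatch rather than a training issue.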