
Using the same architecture and training parameters in PyTorch, the model fails to converge.

Open bioczsun opened this issue 1 year ago • 5 comments

Hello, I have rewritten Akita in PyTorch to make debugging easier. I am using the same training parameters as your TensorFlow implementation, but the model does not converge. However, on the same dataset, your original TensorFlow version of Akita converges as expected.

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
        Conv1d-1          [-1, 96, 1048576]           4,224
   BatchNorm1d-2          [-1, 96, 1048576]             192
          ReLU-3          [-1, 96, 1048576]               0
     MaxPool1d-4           [-1, 96, 524288]               0
        Conv1d-5           [-1, 96, 524288]          46,080
   BatchNorm1d-6           [-1, 96, 524288]             192
          ReLU-7           [-1, 96, 524288]               0
     MaxPool1d-8           [-1, 96, 262144]               0
        Conv1d-9           [-1, 96, 262144]          46,080
  BatchNorm1d-10           [-1, 96, 262144]             192
         ReLU-11           [-1, 96, 262144]               0
    MaxPool1d-12           [-1, 96, 131072]               0
       Conv1d-13           [-1, 96, 131072]          46,080
  BatchNorm1d-14           [-1, 96, 131072]             192
         ReLU-15           [-1, 96, 131072]               0
    MaxPool1d-16            [-1, 96, 65536]               0
       Conv1d-17            [-1, 96, 65536]          46,080
  BatchNorm1d-18            [-1, 96, 65536]             192
         ReLU-19            [-1, 96, 65536]               0
    MaxPool1d-20            [-1, 96, 32768]               0
       Conv1d-21            [-1, 96, 32768]          46,080
  BatchNorm1d-22            [-1, 96, 32768]             192
         ReLU-23            [-1, 96, 32768]               0
    MaxPool1d-24            [-1, 96, 16384]               0
       Conv1d-25            [-1, 96, 16384]          46,080
  BatchNorm1d-26            [-1, 96, 16384]             192
         ReLU-27            [-1, 96, 16384]               0
    MaxPool1d-28             [-1, 96, 8192]               0
       Conv1d-29             [-1, 96, 8192]          46,080
  BatchNorm1d-30             [-1, 96, 8192]             192
         ReLU-31             [-1, 96, 8192]               0
    MaxPool1d-32             [-1, 96, 4096]               0
       Conv1d-33             [-1, 96, 4096]          46,080
  BatchNorm1d-34             [-1, 96, 4096]             192
         ReLU-35             [-1, 96, 4096]               0
    MaxPool1d-36             [-1, 96, 2048]               0
       Conv1d-37             [-1, 96, 2048]          46,080
  BatchNorm1d-38             [-1, 96, 2048]             192
         ReLU-39             [-1, 96, 2048]               0
    MaxPool1d-40             [-1, 96, 1024]               0
       Conv1d-41             [-1, 96, 1024]          46,080
  BatchNorm1d-42             [-1, 96, 1024]             192
         ReLU-43             [-1, 96, 1024]               0
    MaxPool1d-44              [-1, 96, 512]               0
       Conv1d-45              [-1, 48, 512]          13,824
  BatchNorm1d-46              [-1, 48, 512]              96
         ReLU-47              [-1, 48, 512]               0
       Conv1d-48              [-1, 96, 512]           4,608
  BatchNorm1d-49              [-1, 96, 512]             192
         ReLU-50              [-1, 96, 512]               0
      Dropout-51              [-1, 96, 512]               0
     Residual-52              [-1, 96, 512]               0
       Conv1d-53              [-1, 48, 512]          13,824
  BatchNorm1d-54              [-1, 48, 512]              96
         ReLU-55              [-1, 48, 512]               0
       Conv1d-56              [-1, 96, 512]           4,608
  BatchNorm1d-57              [-1, 96, 512]             192
         ReLU-58              [-1, 96, 512]               0
      Dropout-59              [-1, 96, 512]               0
     Residual-60              [-1, 96, 512]               0
       Conv1d-61              [-1, 48, 512]          13,824
  BatchNorm1d-62              [-1, 48, 512]              96
         ReLU-63              [-1, 48, 512]               0
       Conv1d-64              [-1, 96, 512]           4,608
  BatchNorm1d-65              [-1, 96, 512]             192
         ReLU-66              [-1, 96, 512]               0
      Dropout-67              [-1, 96, 512]               0
     Residual-68              [-1, 96, 512]               0
       Conv1d-69              [-1, 48, 512]          13,824
  BatchNorm1d-70              [-1, 48, 512]              96
         ReLU-71              [-1, 48, 512]               0
       Conv1d-72              [-1, 96, 512]           4,608
  BatchNorm1d-73              [-1, 96, 512]             192
         ReLU-74              [-1, 96, 512]               0
      Dropout-75              [-1, 96, 512]               0
     Residual-76              [-1, 96, 512]               0
       Conv1d-77              [-1, 48, 512]          13,824
  BatchNorm1d-78              [-1, 48, 512]              96
         ReLU-79              [-1, 48, 512]               0
       Conv1d-80              [-1, 96, 512]           4,608
  BatchNorm1d-81              [-1, 96, 512]             192
         ReLU-82              [-1, 96, 512]               0
      Dropout-83              [-1, 96, 512]               0
     Residual-84              [-1, 96, 512]               0
       Conv1d-85              [-1, 48, 512]          13,824
  BatchNorm1d-86              [-1, 48, 512]              96
         ReLU-87              [-1, 48, 512]               0
       Conv1d-88              [-1, 96, 512]           4,608
  BatchNorm1d-89              [-1, 96, 512]             192
         ReLU-90              [-1, 96, 512]               0
      Dropout-91              [-1, 96, 512]               0
     Residual-92              [-1, 96, 512]               0
       Conv1d-93              [-1, 48, 512]          13,824
  BatchNorm1d-94              [-1, 48, 512]              96
         ReLU-95              [-1, 48, 512]               0
       Conv1d-96              [-1, 96, 512]           4,608
  BatchNorm1d-97              [-1, 96, 512]             192
         ReLU-98              [-1, 96, 512]               0
      Dropout-99              [-1, 96, 512]               0
    Residual-100              [-1, 96, 512]               0
      Conv1d-101              [-1, 48, 512]          13,824
 BatchNorm1d-102              [-1, 48, 512]              96
        ReLU-103              [-1, 48, 512]               0
      Conv1d-104              [-1, 96, 512]           4,608
 BatchNorm1d-105              [-1, 96, 512]             192
        ReLU-106              [-1, 96, 512]               0
     Dropout-107              [-1, 96, 512]               0
    Residual-108              [-1, 96, 512]               0

DilatedResidual1D-109              [-1, 96, 512]               0
       Conv1d-110              [-1, 64, 512]          30,720
  BatchNorm1d-111              [-1, 64, 512]             128
         ReLU-112              [-1, 64, 512]               0
     OneToTwo-113         [-1, 64, 512, 512]               0
 ConcatDist2D-114         [-1, 65, 512, 512]               0
       Conv2d-115         [-1, 48, 512, 512]          28,080
  BatchNorm2d-116         [-1, 48, 512, 512]              96
         ReLU-117         [-1, 48, 512, 512]               0
 Symmetrize2D-118         [-1, 48, 512, 512]               0
       Conv2d-119         [-1, 24, 512, 512]          10,368
  BatchNorm2d-120         [-1, 24, 512, 512]              48
         ReLU-121         [-1, 24, 512, 512]               0
       Conv2d-122         [-1, 48, 512, 512]           1,152
  BatchNorm2d-123         [-1, 48, 512, 512]              96
         ReLU-124         [-1, 48, 512, 512]               0
      Dropout-125         [-1, 48, 512, 512]               0
     Residual-126         [-1, 48, 512, 512]               0
 Symmetrize2D-127         [-1, 48, 512, 512]               0
       Conv2d-128         [-1, 24, 512, 512]          10,368
  BatchNorm2d-129         [-1, 24, 512, 512]              48
         ReLU-130         [-1, 24, 512, 512]               0
       Conv2d-131         [-1, 48, 512, 512]           1,152
  BatchNorm2d-132         [-1, 48, 512, 512]              96
         ReLU-133         [-1, 48, 512, 512]               0
      Dropout-134         [-1, 48, 512, 512]               0
     Residual-135         [-1, 48, 512, 512]               0
 Symmetrize2D-136         [-1, 48, 512, 512]               0
       Conv2d-137         [-1, 24, 512, 512]          10,368
  BatchNorm2d-138         [-1, 24, 512, 512]              48
         ReLU-139         [-1, 24, 512, 512]               0
       Conv2d-140         [-1, 48, 512, 512]           1,152
  BatchNorm2d-141         [-1, 48, 512, 512]              96
         ReLU-142         [-1, 48, 512, 512]               0
      Dropout-143         [-1, 48, 512, 512]               0
     Residual-144         [-1, 48, 512, 512]               0
 Symmetrize2D-145         [-1, 48, 512, 512]               0
       Conv2d-146         [-1, 24, 512, 512]          10,368
  BatchNorm2d-147         [-1, 24, 512, 512]              48
         ReLU-148         [-1, 24, 512, 512]               0
       Conv2d-149         [-1, 48, 512, 512]           1,152
  BatchNorm2d-150         [-1, 48, 512, 512]              96
         ReLU-151         [-1, 48, 512, 512]               0
      Dropout-152         [-1, 48, 512, 512]               0
     Residual-153         [-1, 48, 512, 512]               0
 Symmetrize2D-154         [-1, 48, 512, 512]               0
       Conv2d-155         [-1, 24, 512, 512]          10,368
  BatchNorm2d-156         [-1, 24, 512, 512]              48
         ReLU-157         [-1, 24, 512, 512]               0
       Conv2d-158         [-1, 48, 512, 512]           1,152
  BatchNorm2d-159         [-1, 48, 512, 512]              96
         ReLU-160         [-1, 48, 512, 512]               0
      Dropout-161         [-1, 48, 512, 512]               0
     Residual-162         [-1, 48, 512, 512]               0
 Symmetrize2D-163         [-1, 48, 512, 512]               0
       Conv2d-164         [-1, 24, 512, 512]          10,368
  BatchNorm2d-165         [-1, 24, 512, 512]              48
         ReLU-166         [-1, 24, 512, 512]               0
       Conv2d-167         [-1, 48, 512, 512]           1,152
  BatchNorm2d-168         [-1, 48, 512, 512]              96
         ReLU-169         [-1, 48, 512, 512]               0
      Dropout-170         [-1, 48, 512, 512]               0
     Residual-171         [-1, 48, 512, 512]               0
DilatedResidual2D-172         [-1, 48, 512, 512]               0
   Cropping2D-173         [-1, 48, 448, 448]               0
     UpperTri-174            [-1, 48, 99681]               0
       Linear-175             [-1, 99681, 5]             245
        Final-176             [-1, 99681, 5]               0
================================================================
Total params: 746,149
Trainable params: 746,149
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 16.00
Forward/backward pass size (MB): 10473.61
Params size (MB): 2.85
Estimated Total Size (MB): 10492.46
----------------------------------------------------------------

SeqNN( (feature_extractor_1d): Sequential( (0): Sequential( (0): Conv1d(4, 96, kernel_size=(11,), stride=(1,), padding=(5,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (1): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (2): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (3): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (4): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (5): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (6): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (7): Sequential( (0): 
Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (8): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (9): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (10): Sequential( (0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (11): DilatedResidual1D( (layers): Sequential( (0): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(1,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (1): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (2): Residual( (fn): 
Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (3): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (4): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (5): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (6): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(29,), dilation=(29,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, 
eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) (7): Residual( (fn): Sequential( (0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(50,), dilation=(50,), bias=False) (1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False) (4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.4, inplace=False) ) ) ) ) (12): Sequential( (0): Conv1d(96, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False) (1): BatchNorm1d(64, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() ) ) (feature_extractor_2d): Sequential( (0): Conv2d(65, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): DilatedResidual2D( (layers): Sequential( (0): Symmetrize2D() (1): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (2): Symmetrize2D() (3): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (4): Symmetrize2D() (5): Residual( (fn): Sequential( (0): Conv2d(48, 24, 
kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (6): Symmetrize2D() (7): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (8): Symmetrize2D() (9): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(9, 9), dilation=(9, 9), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) (10): Symmetrize2D() (11): Residual( (fn): Sequential( (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(16, 16), dilation=(16, 16), bias=False) (1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True) (5): ReLU() (6): Dropout(p=0.1, inplace=False) ) ) ) ) ) (oneto_two): OneToTwo() (concat_dist_2d): ConcatDist2D() (crop_2d): Cropping2D() (uppertri): UpperTri() (final): Final( (dense): Linear(in_features=48, out_features=5, bias=True) ) )

I checked the difference between the model's predicted and true values. In the early phase of training, the predictions and true values had roughly the same mean, but the gap in variance was large. R² did rise steadily over the course of training.
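For anyone reproducing this comparison, the mean/variance gap and R² can be tracked with a few lines (a minimal pure-Python sketch; `preds` and `targets` stand in for flattened model outputs and labels):

```python
from statistics import fmean, pvariance

def prediction_stats(preds, targets):
    """Mean gap, variance ratio, and R^2 between predictions and targets."""
    mean_t = fmean(targets)
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, targets))
    ss_tot = sum((t - mean_t) ** 2 for t in targets)
    return {
        "mean_gap": fmean(preds) - mean_t,               # ~0 if means agree
        "var_ratio": pvariance(preds) / pvariance(targets),  # < 1: under-dispersed
        "r2": 1.0 - ss_res / ss_tot,
    }
```

A `var_ratio` well below 1 with a near-zero `mean_gap` matches the symptom described above: the model collapses toward the mean of the targets instead of learning their structure.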

[attached image]

I would be grateful for any advice you can give!

bioczsun avatar Oct 31 '24 01:10 bioczsun

Moving across frameworks can be very challenging. You'll basically have to double check that every detail matches. E.g. make sure the initialization is handled the same way. You might want to train simplified TensorFlow models with our code alongside your PyTorch version to focus on specific components.
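One detail of exactly this kind that often bites ports (a hedged aside, not a claim about this specific one): Keras and PyTorch define BatchNorm `momentum` in opposite directions, so a Keras `momentum=m` corresponds to PyTorch `momentum=1-m` — consistent with the `momentum=0.0735` visible in the model summary above. A minimal sketch of the two running-statistics update rules:

```python
def keras_bn_update(running, batch_stat, momentum):
    # Keras/TF convention: running = momentum * running + (1 - momentum) * batch
    return momentum * running + (1.0 - momentum) * batch_stat

def torch_bn_update(running, batch_stat, momentum):
    # PyTorch convention: running = (1 - momentum) * running + momentum * batch
    return (1.0 - momentum) * running + momentum * batch_stat

# The two agree only when torch_momentum == 1 - keras_momentum:
assert abs(keras_bn_update(0.5, 2.0, momentum=0.9265)
           - torch_bn_update(0.5, 2.0, momentum=0.0735)) < 1e-12
```

If a port reuses the Keras value directly, the running statistics adapt far too quickly, which degrades eval-mode behavior without obviously breaking training.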

davek44 avatar Nov 01 '24 00:11 davek44

> Moving across frameworks can be very challenging. You'll basically have to double check that every detail matches. E.g. make sure the initialization is handled the same way. You might want to train simplified TensorFlow models with our code alongside your PyTorch version to focus on specific components.

Thank you very much for your reply. I have ensured that the parameters are consistent at each layer and have used the same initialization method as in TensorFlow. The dataset is also the same as the original Akita dataset. Initially, I suspected that the issue might be with the Adam optimizer, but after switching to SGD, the problem still persists.
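One optimizer detail worth checking while ruling out Adam (again a hedged aside): Keras's Adam defaults to `epsilon=1e-7`, while `torch.optim.Adam` defaults to `eps=1e-8`. A scalar sketch of a single Adam step shows the difference matters most when gradients are tiny:

```python
def adam_step(grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; returns (step, m, v)."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - b1 ** t)  # bias correction at step t
    v_hat = v / (1 - b2 ** t)
    return lr * m_hat / (v_hat ** 0.5 + eps), m, v

# With a tiny gradient, the effective step size depends strongly on eps:
step_keras, _, _ = adam_step(1e-8, 0.0, 0.0, t=1, eps=1e-7)  # Keras default
step_torch, _, _ = adam_step(1e-8, 0.0, 0.0, t=1, eps=1e-8)  # PyTorch default
```

Here the PyTorch-default step comes out several times larger than the Keras-default one; passing `eps=1e-7` to `torch.optim.Adam` explicitly removes one variable from the comparison.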

bioczsun avatar Nov 01 '24 08:11 bioczsun

hi @bioczsun -- did you try loading the trained model into your pytorch re-implementation rather than re-training? If so, I wonder if it gave similar predictions on test set sequences.

gfudenberg avatar Nov 01 '24 22:11 gfudenberg

> hi @bioczsun -- did you try loading the trained model into your pytorch re-implementation rather than re-training? If so, I wonder if it gave similar predictions on test set sequences.

Hi @gfudenberg , do you mean loading the pre-trained parameters from tensorflow into the pytorch model?

bioczsun avatar Nov 02 '24 02:11 bioczsun

hi @bioczsun yes, exactly
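(For anyone attempting that weight transplant: a sketch of the kernel-layout conversion it requires, assuming Keras Conv1D kernels are stored channels-last as `(width, in_channels, out_channels)` and PyTorch Conv1d weights as `(out_channels, in_channels, width)`; the function name is illustrative.)

```python
import numpy as np

def keras_conv1d_kernel_to_torch(kernel):
    """Convert a Keras Conv1D kernel (width, in_ch, out_ch)
    to the PyTorch Conv1d weight layout (out_ch, in_ch, width)."""
    return np.ascontiguousarray(np.transpose(kernel, (2, 1, 0)))

# Example: the first conv in the summary above has kernel_size=11,
# 4 input channels (one-hot DNA), and 96 filters.
k = np.zeros((11, 4, 96), dtype=np.float32)
w = keras_conv1d_kernel_to_torch(k)
assert w.shape == (96, 4, 11)
```

BatchNorm layers additionally need their running mean/variance copied. If the transplanted model then reproduces the TensorFlow predictions on test sequences, the architectures match and the discrepancy lies in the training dynamics.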

gfudenberg avatar Nov 04 '24 23:11 gfudenberg