
Three-channel output using ConvNeXt model

Soorya19Pradeep opened this issue 1 year ago • 9 comments

I want to train the ConvNeXt model for the infection classifier problem. The model will use two or more input channels (phase + HSP90 channels + infection score channels, etc.) and should output three channels (background + uninfected + infected). Currently I can perform this training with the 2D and 2.5D U-Nets, but not with ConvNeXt: it doesn't allow three-channel output due to a hardcoded scaling parameter value. @ziw-liu, can you help with this? Thanks!

Soorya19Pradeep avatar May 23 '24 18:05 Soorya19Pradeep

@ziw-liu We are trying two kinds of models:

  • Semantic segmentation (background vs uninfected vs infected) based on human annotation.
  • Regression of infection score based on engineered features.

The regression models can be trained without any code change. What is the right branch at this point?

For segmentation models, we need a task=segmentation option that tweaks the model architecture. Can we specify the loss via config?

mattersoflight avatar May 23 '24 19:05 mattersoflight

@Soorya19Pradeep what's the configuration that doesn't work? I checked that this works (on main):

from viscy.light.engine import VSUNet

model = VSUNet(
    architecture="2.2D",
    model_config={
        "in_channels": 1,
        "out_channels": 3,
        "in_stack_depth": 5,
        "backbone": "convnextv2_tiny",
        "stem_kernel_size": (5, 4, 4),
        "decoder_mode": "pixelshuffle",
        "head_expansion_ratio": 4,
    },
)

Edit: I see that the '2.2D' class doesn't work for 2D-2D. For that you need to use the 2d-fcmae branch.

ziw-liu avatar May 23 '24 23:05 ziw-liu

For example:

from viscy.light.engine import FcmaeUNet  # on the 2d-fcmae branch

model = FcmaeUNet(
    model_config=dict(
        in_channels=1,
        out_channels=3,
        encoder_blocks=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        decoder_conv_blocks=1,
        stem_kernel_size=(1, 2, 2),
        in_stack_depth=1,
    )
)

ziw-liu avatar May 23 '24 23:05 ziw-liu

Can we specify the loss via config?

Yes, it's configured at:

https://github.com/mehta-lab/VisCy/blob/40a7044c49729673b64a60b733614538599eec3a/examples/configs/fit_example.yml#L23

For example:

  loss_function:
    class_path: viscy.light.engine.MixedLoss
    init_args:
      l1_alpha: 0.5
      l2_alpha: 0.0
      ms_dssim_alpha: 0.5
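
For the segmentation task, a cross-entropy loss should be configurable the same way; a sketch, assuming the CLI accepts an arbitrary torch.nn loss here (untested):

  loss_function:
    class_path: torch.nn.CrossEntropyLoss
    init_args:
      label_smoothing: 0.0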

ziw-liu avatar May 23 '24 23:05 ziw-liu

For segmentation models, we need a task=segmentation option that tweaks the model architecture.

The final activation layer is not learnable. I think it's cleaner to remove it from all the U-Nets and only define it in the lightning module. For example, we have a lightning module that takes a regression U-Net and performs virtual staining, and another one that adds an optional activation and performs infection phenotyping.
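
A minimal sketch of that pattern (PhenotypingModule and its arguments are hypothetical names, not the current API):

import torch
from lightning.pytorch import LightningModule


class PhenotypingModule(LightningModule):
    """Hypothetical: wraps a regression U-Net and adds an optional final activation."""

    def __init__(
        self, unet: torch.nn.Module, out_activation: torch.nn.Module | None = None
    ) -> None:
        super().__init__()
        self.unet = unet
        # e.g. torch.nn.Softmax(dim=1) for class probabilities; None for regression
        self.out_activation = out_activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.unet(x)
        return y if self.out_activation is None else self.out_activation(y)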

ziw-liu avatar May 23 '24 23:05 ziw-liu

Thanks @ziw-liu.

Re: regression models: @Soorya19Pradeep The models that predict infection scores (phase IQR or skew of fluorescence) should be trainable with the current code. I suggest that you compute infection score channels while stitching the fragmented time series. Once the data is in that format, we can iterate on the models, figures, and movies.

@ziw-liu architecture="2.2D" specifies a model with 3D (ZYX) input and 3D (ZYX) output, correct? What is the right model when the input is 3D (BxCinx5x512x512) and the output is 2D (BxCoutx512x512)?

Cin = (phase, fluorescence), Cout = infection score OR logits for classes.
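
Concretely, as a shape sketch (batch size 8 chosen arbitrarily):

import torch

# 3D input: 5-slice stacks with phase + fluorescence channels
x = torch.randn(8, 2, 5, 512, 512)  # (B, Cin, Z, Y, X)
# desired 2D output: (8, 1, 512, 512) for an infection score,
# or (8, 3, 512, 512) for class logits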

Re: segmentation model

Yes, it makes sense for the segmentation model to be implemented as a lightning module.

mattersoflight avatar May 23 '24 23:05 mattersoflight

architecture="2.2D" specifies a model with 3D (ZYX) input and 3D (ZYX) output, correct?

This is correct.

What is the right model when the input is 3D (BxCinx5x512x512) and the output is 2D (BxCoutx512x512)?

That would be the "2.1D". I just checked and it has bit-rotted since mid-last year as we deprecated it for virtual staining. I can fix it if there's a use case.

ziw-liu avatar May 23 '24 23:05 ziw-liu

That would be the "2.1D". I just checked and it has bit-rotted since mid-last year as we deprecated it for virtual staining. I can fix it if there's a use case.

How about modifying the 2.2D UNet so that it has an out_stack_depth parameter?

mattersoflight avatar May 23 '24 23:05 mattersoflight

How about modifying the 2.2D UNet so that it has an out_stack_depth parameter?

This also works.
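
A sketch of what that configuration could look like (out_stack_depth is the proposed parameter, not yet implemented):

model = VSUNet(
    architecture="2.2D",
    model_config={
        "in_channels": 2,
        "out_channels": 3,
        "in_stack_depth": 5,
        "out_stack_depth": 1,  # proposed: collapse Z in the head for 2D output
        "backbone": "convnextv2_tiny",
        "stem_kernel_size": (5, 4, 4),
        "decoder_mode": "pixelshuffle",
        "head_expansion_ratio": 4,
    },
)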

ziw-liu avatar May 24 '24 00:05 ziw-liu

@ziw-liu, I tried out the ConvNeXt model after pulling changes on the main branch. I am still getting the error. Do you know where this would be coming from?

File /hpc/mydata/soorya.pradeep/viscy_inf/lib/python3.10/site-packages/torch/nn/modules/conv.py:610, in Conv3d.forward(self, input)
    609 def forward(self, input: Tensor) -> Tensor:
--> 610     return self._conv_forward(input, self.weight, self.bias)

File /hpc/mydata/soorya.pradeep/viscy_inf/lib/python3.10/site-packages/torch/nn/modules/conv.py:605, in Conv3d._conv_forward(self, input, weight, bias)
    593 if self.padding_mode != "zeros":
    594     return F.conv3d(
    595         F.pad(
    596             input, self._reversed_padding_repeated_twice, mode=self.padding_mode
    (...)
    603         self.groups,
    604     )
--> 605 return F.conv3d(
    606     input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    607 )

RuntimeError: Given groups=1, weight of size [96, 1, 5, 4, 4], expected input[16, 2, 5, 256, 256] to have 1 channels, but got 2 channels instead

Soorya19Pradeep avatar May 27 '24 16:05 Soorya19Pradeep

after pulling changes on the main branch

@Soorya19Pradeep Which branch are you using exactly? The fixes are in #80.

ziw-liu avatar May 27 '24 23:05 ziw-liu

@ziw-liu, my data is 3D. Do I still use the fix-3d-to-2d branch? I tried the 2.2D model on the main and infection_phenotying branches; both gave me the error.

Soorya19Pradeep avatar May 28 '24 01:05 Soorya19Pradeep

@ziw-liu, my data is 3D. Do I still use the fix-3d-to-2d branch? I tried the 2.2D model on the main and infection_phenotying branches; both gave me the error.

If I understand correctly, you are trying to predict 2D infection labels from 3D input data; in that case, you should use #80.

ziw-liu avatar May 28 '24 02:05 ziw-liu

My infection annotation channel has been extended to 3D. Here is a sample dataset: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated.zarr

Soorya19Pradeep avatar May 28 '24 03:05 Soorya19Pradeep

@Soorya19Pradeep what's your model configuration?

ziw-liu avatar May 28 '24 03:05 ziw-liu

@Soorya19Pradeep The annotations are constant along Z, right? It does make sense to use 6 slices of HSP90 and phase as input and 1 slice of annotation as output. Both the 6-slice -> 6-slice and the 6-slice -> 1-slice models should be configurable on the fix-3d-to-2d branch; I have not tried the branch yet, however.

mattersoflight avatar May 28 '24 03:05 mattersoflight

@ziw-liu, I got it working! I had hardcoded the '1' (the input channel count) in my code. Sorry about the confusion!
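
For reference, the stem weight of size [96, 1, 5, 4, 4] in the traceback expects a single input channel, so a two-channel (phase + HSP90) input needs in_channels=2 in the earlier 2.2D example:

model = VSUNet(
    architecture="2.2D",
    model_config={
        "in_channels": 2,  # phase + HSP90, instead of the hardcoded 1
        "out_channels": 3,
        "in_stack_depth": 5,
        "backbone": "convnextv2_tiny",
        "stem_kernel_size": (5, 4, 4),
        "decoder_mode": "pixelshuffle",
        "head_expansion_ratio": 4,
    },
)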

Soorya19Pradeep avatar May 28 '24 20:05 Soorya19Pradeep

@mattersoflight, I have training running with the fix-3d-to-2d branch. You can merge it into main.

Soorya19Pradeep avatar May 28 '24 20:05 Soorya19Pradeep