Three-channel output using the ConvNeXt model
I want to train the ConvNeXt model on the infection classifier problem. The model will take two or more input channels (phase, HSP90, infection score channels, etc.) and should output three channels (background, uninfected, and infected). I can currently run this training with the 2D and 2.5D U-Nets, but not with ConvNeXt: it doesn't allow three-channel output because of a hardcoded scaling parameter value. @ziw-liu, can you help with this? Thanks!
@ziw-liu We are trying two kinds of models:
- Semantic segmentation (background vs uninfected vs infected) based on human annotation.
- Regression of infection score based on engineered features.
The regression models can be trained without any code changes. Which branch is the right one at this point?
For segmentation models, we need a task=segmentation option that tweaks the model architecture. Can we specify the loss via config?
@Soorya19Pradeep what's the configuration that doesn't work? I checked that this works (on main):
```python
from viscy.light.engine import VSUNet

model = VSUNet(
    architecture="2.2D",
    model_config={
        "in_channels": 1,
        "out_channels": 3,
        "in_stack_depth": 5,
        "backbone": "convnextv2_tiny",
        "stem_kernel_size": (5, 4, 4),
        "decoder_mode": "pixelshuffle",
        "head_expansion_ratio": 4,
    },
)
```
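As a quick sanity check (a sketch; the exact shapes assume the 2.2D architecture preserves the Z depth, which is confirmed later in this thread):

```python
import torch

x = torch.rand(2, 1, 5, 256, 256)  # (B, C_in, Z, Y, X)
y = model(x)
print(y.shape)  # expected: torch.Size([2, 3, 5, 256, 256]), i.e. three output channels
```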
Edit: I see that the `2.2D` class doesn't work for 2D-to-2D. For that you need to use the `2d-fcmae` branch.
For example:
```python
# FcmaeUNet is on the 2d-fcmae branch; the import path below assumes it
# lives alongside VSUNet in viscy.light.engine.
from viscy.light.engine import FcmaeUNet

model = FcmaeUNet(
    model_config=dict(
        in_channels=1,
        out_channels=3,
        encoder_blocks=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        decoder_conv_blocks=1,
        stem_kernel_size=(1, 2, 2),
        in_stack_depth=1,
    )
)
```
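The same kind of shape check for the 2D case (a sketch; whether the singleton Z dimension is kept in the output is an assumption, hence the hedged comment):

```python
import torch

x = torch.rand(2, 1, 1, 256, 256)  # (B, C_in, Z=1, Y, X)
y = model(x)
# Expect three output channels; the shape may be (2, 3, 256, 256) or
# (2, 3, 1, 256, 256) depending on how the head treats the singleton Z.
print(y.shape)
```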
> Can we specify the loss via config?
Yes, it's configured at:
https://github.com/mehta-lab/VisCy/blob/40a7044c49729673b64a60b733614538599eec3a/examples/configs/fit_example.yml#L23
For example:
```yaml
loss_function:
  class_path: viscy.light.engine.MixedLoss
  init_args:
    l1_alpha: 0.5
    l2_alpha: 0.0
    ms_dssim_alpha: 0.5
```
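For the segmentation task, presumably any loss that is importable by class path can be swapped in the same way. A sketch, assuming the lightning module passes raw logits straight to the loss:

```yaml
loss_function:
  class_path: torch.nn.CrossEntropyLoss
```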
> For segmentation models, we need a task=segmentation option that tweaks the model architecture.
The final activation layer is not learnable. I think it's cleaner to remove it from all the U-Nets and only define it in the lightning module. For example, we have a lightning module that takes a regression U-Net and performs virtual staining, and another one that adds an optional activation and performs infection phenotyping.
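A minimal sketch of that split (the class and argument names are hypothetical; only the idea of moving the non-learnable activation out of the U-Net and into the lightning module is from this thread):

```python
import torch.nn as nn
from lightning.pytorch import LightningModule


class InfectionPhenotypingModule(LightningModule):
    """Hypothetical task module: wraps a regression U-Net and owns the
    optional, non-learnable final activation."""

    def __init__(self, unet: nn.Module, activation: nn.Module | None = None):
        super().__init__()
        self.unet = unet
        # e.g. nn.Softmax(dim=1) over (background, uninfected, infected) logits
        self.activation = activation

    def forward(self, x):
        y = self.unet(x)
        return y if self.activation is None else self.activation(y)
```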
Thanks @ziw-liu.
Re: regression models: @Soorya19Pradeep The models that predict infection scores (phase IQR or skew of fluorescence) should be trainable with the current code. I suggest that you compute infection score channels while stitching the fragmented time series. Once the data is in that format, we can iterate on the models, figures, and movies.
@ziw-liu `architecture="2.2D"` specifies a model with 3D (ZYX) input and 3D (ZYX) output, correct? What is the right model when the input is 3D (BxCinx5x512x512) and the output is 2D (BxCoutx512x512)?
Cin = (phase, fluorescence); Cout = infection score OR Cout = logits for classes.
Re: segmentation model
Yes, it makes sense for the segmentation model to be implemented as a lightning module.
> `architecture="2.2D"` specifies a model with 3D (ZYX) input and 3D (ZYX) output, correct?
This is correct.
> What is the right model when the input is 3D (BxCinx5x512x512) and the output is 2D (BxCoutx512x512)?
That would be the "2.1D". I just checked and it has bit-rotted since mid-last year as we deprecated it for virtual staining. I can fix it if there's a use case.
> That would be the "2.1D". I just checked and it has bit-rotted since mid-last year as we deprecated it for virtual staining. I can fix it if there's a use case.
How about modifying the 2.2D U-Net so that it has an `out_stack_depth` parameter?
> How about modifying the 2.2D U-Net so that it has an `out_stack_depth` parameter?
This also works.
@ziw-liu, I tried out ConvNeXt after pulling changes on the main branch. I am still getting the error. Do you know where this would be coming from?
```
File /hpc/mydata/soorya.pradeep/viscy_inf/lib/python3.10/site-packages/torch/nn/modules/conv.py:610, in Conv3d.forward(self, input)
    609 def forward(self, input: Tensor) -> Tensor:
--> 610     return self._conv_forward(input, self.weight, self.bias)

File /hpc/mydata/soorya.pradeep/viscy_inf/lib/python3.10/site-packages/torch/nn/modules/conv.py:605, in Conv3d._conv_forward(self, input, weight, bias)
    593 if self.padding_mode != "zeros":
    594     return F.conv3d(
    595         F.pad(
    596             input, self._reversed_padding_repeated_twice, mode=self.padding_mode
    (...)
    603         self.groups,
    604     )
--> 605 return F.conv3d(
    606     input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    607 )

RuntimeError: Given groups=1, weight of size [96, 1, 5, 4, 4], expected input[16, 2, 5, 256, 256] to have 1 channels, but got 2 channels instead
```
> after pulling changes on the main branch
@Soorya19Pradeep Which branch are you using exactly? The fixes are in #80.
@ziw-liu, my data is 3D. Do I still use the fix-3d-to-2d branch? I tried the 2.2D model on the main and infection_phenotying branches. Both gave me the error.
> @ziw-liu, my data is 3D. Do I still use the fix-3d-to-2d branch? I tried the 2.2D model on the main and infection_phenotying branches. Both gave me the error.
If I understand correctly, you are trying to predict 2D infection labels from 3D input data, so you should use #80.
My infection annotation channel has been extended to 3D. Here is a sample dataset: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated.zarr
@Soorya19Pradeep what's your model configuration?
@Soorya19Pradeep The annotations are constant along Z, right? It does make sense to use 6 slices of HSP90 and phase as input and 1 slice of annotation as output. Both the 6-slice -> 6-slice and the 6-slice -> 1-slice model should be configurable with the fix-3d-to-2d branch. I have not tried the branch yet, however.
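A sketch of what that 6-slice -> 1-slice configuration could look like (hypothetical values: `out_stack_depth` is the parameter proposed above, and the stem kernel depth of 6 is only an assumption about how the stem should tile the stack):

```python
from viscy.light.engine import VSUNet

model = VSUNet(
    architecture="2.2D",
    model_config={
        "in_channels": 2,  # HSP90 + phase
        "out_channels": 3,  # background / uninfected / infected logits
        "in_stack_depth": 6,  # 6 input slices
        "out_stack_depth": 1,  # proposed parameter: 1 output slice
        "backbone": "convnextv2_tiny",
        "stem_kernel_size": (6, 4, 4),  # assumed to match in_stack_depth
        "decoder_mode": "pixelshuffle",
        "head_expansion_ratio": 4,
    },
)
```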
@ziw-liu, I got it working! I had hardcoded the '1' in my code. Sorry about the confusion!
@mattersoflight, I have training running with the fix-3d-to-2d branch. You can merge it into main.