MedSegDiff

about the model

Open · Devil-Ideal opened this issue 1 year ago · 3 comments

Hi! I'm interested in your work and I have read both papers. However, after reading the code, I'm a little confused. This code still seems to correspond to the first paper, MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model, since I didn't find the SS-Former or the Anchor Condition. I'm guessing it's an improved version of the first paper's method, because some things also differ from that paper.

I also noticed something strange. ISICDataset returns (img, mask, name). batch gets the first element and cond gets the second (code: batch, cond, name = next(data_iter)), and then they are concatenated by batch = th.cat((batch, cond), dim=1). I think the variable named cond is the mask, and it ends up in the last channel. In the model, UNetModel_newpreview, the function highway_forward is meant to encode the image. However, it seems to be misused to encode the mask, since the variable c in fact represents the mask. The code:

Function in gaussian_diffusion:

    def training_losses_segmentation(self, model, classifier, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start[:, -1:, ...])

        mask = x_start[:, -1:, ...]
        res = torch.where(mask > 0, 1, 0)   # merge all tumor classes into one to get a binary segmentation mask

        res_t = self.q_sample(res, t, noise=noise)     # add noise to the segmentation channel
        x_t = x_start.float()
        x_t[:, -1:, ...] = res_t.float()
        terms = {}

        if self.loss_type == LossType.MSE or self.loss_type == LossType.BCE_DICE or self.loss_type == LossType.RESCALED_MSE:

            model_output, cal = model(x_t, self._scale_timesteps(t), **model_kwargs)
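
In the training step above, only the last channel is binarized and noised. The image channels are passed through unchanged. The following is a minimal pure-Python sketch of that idea. It assumes the standard DDPM closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps and a hypothetical linear beta schedule from 1e-4 to 0.02; the repository's own q_sample and schedule may differ. One sample is represented as a list of flattened channels instead of a tensor, and both helper names are illustrative.

```python
import math
import random


def linear_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for an assumed linear beta schedule."""
    alpha_bar, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    return alpha_bar


def noise_mask_channel(x_start, t, alpha_bar, rng=random):
    """Binarize and noise only the last channel of one sample.

    x_start is a list of flattened channels: image channels first, mask last.
    Returns (x_t, noise). Image channels are copied unchanged and x_start is not mutated.
    """
    *image_channels, mask = x_start
    res = [1.0 if v > 0 else 0.0 for v in mask]  # merge all classes into one binary mask
    noise = [rng.gauss(0.0, 1.0) for _ in res]   # Gaussian noise for the mask channel only
    a = alpha_bar[t]
    # Closed-form forward process: x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps
    res_t = [math.sqrt(a) * r + math.sqrt(1.0 - a) * n for r, n in zip(res, noise)]
    x_t = [list(ch) for ch in image_channels] + [res_t]
    return x_t, noise


alpha_bar = linear_alpha_bar()
x_start = [
    [0.2, 0.4, 0.6, 0.8],  # one image channel (flattened 2x2)
    [0.0, 2.0, 0.0, 1.0],  # mask channel with two hypothetical tumor labels
]
x_t, noise = noise_mask_channel(x_start, t=500, alpha_bar=alpha_bar, rng=random.Random(0))
print(x_t[0])  # image channel is untouched: [0.2, 0.4, 0.6, 0.8]
```

This mirrors the excerpt's split: noise has the shape of the last channel only, and x_t keeps the clean image beside the noisy mask.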

forward of the model:

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        c = h[:,:-1,...]
        anch, cal = self.highway_forward(c)

Devil-Ideal, Jun 22 '23 11:06
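You can check the channel bookkeeping behind this question without running the model. The following is a minimal pure-Python sketch. It assumes one sample stored as nested lists of flattened channels and a hypothetical 3-channel image concatenated with a 1-channel mask. cat_channels mirrors th.cat((batch, cond), dim=1), and split_condition_and_mask mirrors the slices h[:, :-1, ...] and h[:, -1:, ...]. Both helper names are illustrative, not taken from the repository. Because cond is appended last, the :-1 slice keeps the image channels rather than the mask.

```python
def cat_channels(a, b):
    """Concatenate two batches along the channel axis, like th.cat((a, b), dim=1)."""
    return [sample_a + sample_b for sample_a, sample_b in zip(a, b)]


def split_condition_and_mask(x):
    """Mimic c = h[:, :-1, ...] and h[:, -1:, ...] on [N][C][H*W] nested lists."""
    cond = [sample[:-1] for sample in x]  # every channel except the last
    mask = [sample[-1:] for sample in x]  # only the last channel
    return cond, mask


# One sample with 2x2 pixels flattened: three image channels and one mask channel.
img = [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]
mask = [[[0.0, 1.0, 1.0, 0.0]]]

x = cat_channels(img, mask)               # 4 channels: R, G, B, mask
c, m = split_condition_and_mask(x)
print(len(x[0]), len(c[0]), len(m[0]))   # 4 3 1
print(c == img, m == mask)                # True True
```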

Sorry, I misunderstood something: the variable c actually represents the condition. But one thing still confuses me. The image and the mask are both fed into the encoder of the diffusion model, which doesn't match the paper.

Devil-Ideal, Jun 27 '23 07:06

I also did not find the code related to SSFormer. May I ask whether your problem has been solved? Also, it seems that the code only adds the conditional feature at the first layer?

takimailto, Jul 04 '23 07:07
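The difference between injecting a conditional feature once at the first level and injecting it at every encoder level can be shown with a toy model. This is not MedSegDiff's architecture. The toy_encoder function, the pair-averaging downsampling, and fusion by elementwise addition are all hypothetical stand-ins. They are chosen only to show how the two placements differ.

```python
def toy_encoder(x, cond_feats, inject_at):
    """Hypothetical encoder. At each level, optionally add that level's
    conditional feature elementwise, then halve the resolution by averaging pairs.

    Returns the final feature and the list of levels that received conditioning.
    """
    h = list(x)
    injected = []
    for level, cond in enumerate(cond_feats):
        if level in inject_at:
            h = [a + b for a, b in zip(h, cond)]  # fusion by addition (assumed)
            injected.append(level)
        h = [(h[i] + h[i + 1]) / 2 for i in range(0, len(h) - 1, 2)]  # downsample by 2
    return h, injected


x = [0.0] * 8                                   # stand-in for noisy-mask features
cond_feats = [[1.0] * 8, [1.0] * 4, [1.0] * 2]  # one conditional feature per level

first_only, levels_a = toy_encoder(x, cond_feats, inject_at={0})
every_level, levels_b = toy_encoder(x, cond_feats, inject_at={0, 1, 2})
print(first_only, levels_a)    # [1.0] [0]
print(every_level, levels_b)   # [3.0] [0, 1, 2]
```

With first-level-only injection, deeper levels see the condition only through whatever survives downsampling. Injecting at every level re-adds it at each resolution.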

Sorry, I haven't solved it. I'm trying to reproduce the experiments using the old version of the model (--version old, which is not the default).

Devil-Ideal, Jul 04 '23 07:07