PUMIT
Need guidance on using pretrained tokenizer
Hi! I am trying to use the pretrained tokenizer to obtain latent codes for my input CT images.
However, I don't see the identity-mapping-like reconstruction demonstrated in Figure 3 of your paper, so I suspect there is something wrong with the way I handle the input.
Here's the process:
"""Step 1: Define the network"""
quantize = VectorQuantizer(num_embeddings=1024, embedding_dim=512, mode='soft')
tokenizer = SimpleVQTokenizer(quantize=quantize, in_channels=3, start_stride=4)
tokenizer.load_state_dict(ckpt['model'], strict=True)
tokenizer.eval()
"""Step 2: Prepare the input"""
def get_rescaled_ct(npy_path, new_range=(-1, 1)):
x = np.load(npy_path).clip(-1024, 3071) # typical CT range
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
# linearly transforms (x_min, x_max) to (y_min, y_max)
x = rescale_tensor(x, y_min=new_range[0], y_max=new_range[1], x_min=-1024, x_max=3071)
return x
def prepare_input(x: torch.Tensor):
if x.ndim == 4: # 2D -> 3D
x = x.unsqueeze(2)
if x.shape[1] == 1: # 1-channel -> 3-channel
x = x.repeat(1, 3, 1, 1, 1)
if not isinstance(x, SpatialTensor):
# force aniso_d=6 for 2D input
x = SpatialTensor(x, aniso_d=6)
return x
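(rescale_tensor is not shown above; it is assumed to be a simple linear remapping consistent with the call signature used here, roughly:)

def rescale_tensor(x, y_min, y_max, x_min, x_max):
    # map values from [x_min, x_max] linearly onto [y_min, y_max]
    return (x - x_min) / (x_max - x_min) * (y_max - y_min) + y_min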
"""Step 3: Test the tokenizer"""
img_path = ".../some_ct_slice.npy"
x0 = get_rescaled_ct(img_path) # [1, 1, 512, 512], ranging within [-1, 1]
x = prepare_input(x0) # [1, 3, 1, 512, 512], ranging within [-1, 1]
with torch.no_grad():
z = tokenizer.encode(x)
y = tokenizer.decode(z)
I was expecting y to look similar to x, but the visualization shows otherwise.
Any advice on that? Thanks!
BTW, here's the info about x0, x, z and y, if needed:
x0: Shape: (1, 1, 512, 512), Range: (-1., 0.4971).
x: Shape: (1, 3, 1, 512, 512), Range: (-1., 0.4971).
z: Shape: (1, 512, 1, 32, 32), Range: (-0.2395, 17.5686).
y: Shape: (1, 3, 1, 512, 512), Range: (-1.6264, 5.3129).
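(As a sanity check, these shapes look consistent with the architecture: the encoder downsamples in-plane by 4 × 2 × 2 = 16, so 512 / 16 = 32, and the 512 latent channels match embedding_dim.)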
Hello, sorry for the long wait; we have been working on other things. Did you solve this issue? I guess it may be caused by a code version mismatch. Which version of the code are you using?
I am not sure about the exact version. I guess it would be from the submit branch in January, but it seems to be gone now. Here's what I can confirm:
- the SimpleVQTokenizer architecture is the same as in https://github.com/function2-llx/PUMIT/blob/67218a2aebf145b0b6f5cd3ae292adfe39f22561/pumit/tokenizer/simple.py, with arguments in_channels=3, start_stride=4, downsample_layer_channels=[128, 256, 512], upsample_layer_channels=[128, 256, 512], encoder_act=nn.GELU (see the sketch after this list).
- the VectorQuantizer architecture is the same as in https://github.com/function2-llx/PUMIT/blob/67218a2aebf145b0b6f5cd3ae292adfe39f22561/pumit/tokenizer/quantize.py
- the SpatialTensor class is from https://github.com/function2-llx/PUMIT/blob/67218a2aebf145b0b6f5cd3ae292adfe39f22561/pumit/sac.py
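For clarity, here is roughly how I construct the tokenizer with those arguments (assuming they are accepted as keyword arguments of SimpleVQTokenizer):

import torch.nn as nn

quantize = VectorQuantizer(num_embeddings=1024, embedding_dim=512, mode='soft')
tokenizer = SimpleVQTokenizer(
    quantize=quantize,
    in_channels=3,
    start_stride=4,
    downsample_layer_channels=[128, 256, 512],
    upsample_layer_channels=[128, 256, 512],
    encoder_act=nn.GELU,
)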
Also, the detailed architecture of SimpleVQTokenizer is as follows, if this will help:
SimpleVQTokenizer(
(quantize): VectorQuantizer(
(proj): Linear(in_features=512, out_features=1024, bias=True)
(embedding): Embedding(1024, 512)
)
(encoder): Sequential(
(0): InflatableConv3d(3, 128, kernel_size=(4, 4, 4), stride=(4, 4, 4))
(1): LayerNormNd(
(0): ChannelLast('n c ... -> n ... c')
(1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(2): ChannelFirst('n ... c -> n c ...')
(3): Contiguous()
)
(2): GELU(approximate='none')
(3): InflatableConv3d(128, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2))
(4): LayerNormNd(
(0): ChannelLast('n c ... -> n ... c')
(1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(2): ChannelFirst('n ... c -> n c ...')
(3): Contiguous()
)
(5): GELU(approximate='none')
(6): InflatableConv3d(256, 512, kernel_size=(2, 2, 2), stride=(2, 2, 2))
(7): LayerNormNd(
(0): ChannelLast('n c ... -> n ... c')
(1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(2): ChannelFirst('n ... c -> n c ...')
(3): Contiguous()
)
(8): GELU(approximate='none')
(9): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(10): GroupNorm(8, 512, eps=1e-05, affine=True)
(11): LeakyReLU(negative_slope=0.01, inplace=True)
(12): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(13): GroupNorm(8, 512, eps=1e-05, affine=True)
(14): LeakyReLU(negative_slope=0.01, inplace=True)
)
(decoder): Sequential(
(0): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): GroupNorm(8, 512, eps=1e-05, affine=True)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
(3): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(4): GroupNorm(8, 512, eps=1e-05, affine=True)
(5): LeakyReLU(negative_slope=0.01, inplace=True)
(6): AdaptiveTransposedConvUpsample(
(transposed_conv): InflatableTransposedConv3d(512, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2))
(conv): Sequential(
(0): InflatableConv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): GroupNorm(8, 256, eps=1e-05, affine=True)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(7): AdaptiveTransposedConvUpsample(
(transposed_conv): InflatableTransposedConv3d(256, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2))
(conv): Sequential(
(0): InflatableConv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
(1): GroupNorm(8, 128, eps=1e-05, affine=True)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(8): InflatableTransposedConv3d(128, 3, kernel_size=(4, 4, 4), stride=(4, 4, 4))
)
)
@Masaaki-75 My dear friend, you forgot to perform the vector quantization. You should call tokenizer.quantize(z) before decoding.
Sorry again for the late reply.
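Concretely, the corrected Step 3 would look roughly like the sketch below. Note that the exact return value of tokenizer.quantize may differ between code versions (it may be the quantized latent directly, or a tuple/output object containing it), so adapt the unpacking accordingly:

with torch.no_grad():
    z = tokenizer.encode(x)
    quantized = tokenizer.quantize(z)  # the missing vector-quantization step
    # depending on the version, `quantized` may already be the latent to decode,
    # or a tuple whose first element is the quantized latent
    z_q = quantized[0] if isinstance(quantized, tuple) else quantized
    y = tokenizer.decode(z_q)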