
Need guidance on using pretrained tokenizer

Open • Masaaki-75 opened this issue 10 months ago • 3 comments

Hi! I am trying to use the pretrained tokenizer to obtain latent codes for my input CT images.

However, I don't get the near-identity reconstruction demonstrated in Figure 3 of your paper, so I suspect something is wrong with the way I handle the input.

Here's the process:

```python
import numpy as np
import torch

# VectorQuantizer, SimpleVQTokenizer and SpatialTensor come from the PUMIT repo (imports omitted);
# rescale_tensor is a linear-rescaling helper (not shown); ckpt is the loaded pretrained checkpoint.

# Step 1: Define the network
quantize = VectorQuantizer(num_embeddings=1024, embedding_dim=512, mode='soft')
tokenizer = SimpleVQTokenizer(quantize=quantize, in_channels=3, start_stride=4)
tokenizer.load_state_dict(ckpt['model'], strict=True)
tokenizer.eval()


# Step 2: Prepare the input
def get_rescaled_ct(npy_path, new_range=(-1, 1)):
    x = np.load(npy_path).clip(-1024, 3071)  # typical CT range (HU)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # [H, W] -> [1, 1, H, W]
    # linearly transforms (x_min, x_max) to (y_min, y_max)
    x = rescale_tensor(x, y_min=new_range[0], y_max=new_range[1], x_min=-1024, x_max=3071)
    return x

def prepare_input(x: torch.Tensor):
    if x.ndim == 4:  # 2D -> 3D: [B, C, H, W] -> [B, C, 1, H, W]
        x = x.unsqueeze(2)
    if x.shape[1] == 1:  # 1-channel -> 3-channel
        x = x.repeat(1, 3, 1, 1, 1)

    if not isinstance(x, SpatialTensor):
        # force aniso_d=6 for 2D input
        x = SpatialTensor(x, aniso_d=6)
    return x


# Step 3: Test the tokenizer
img_path = ".../some_ct_slice.npy"
x0 = get_rescaled_ct(img_path)  # [1, 1, 512, 512], ranging within [-1, 1]
x = prepare_input(x0)  # [1, 3, 1, 512, 512], ranging within [-1, 1]

with torch.no_grad():
    z = tokenizer.encode(x)
    y = tokenizer.decode(z)
```
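For anyone trying to reproduce the preprocessing without the PUMIT classes, here is a minimal NumPy sketch of Steps 2 and the shape part of `prepare_input`. It assumes `rescale_tensor` is a plain linear map (as the comment in Step 2 says) and skips `SpatialTensor` entirely; `rescale`, `hu_to_model`, `model_to_hu` and `to_3ch_volume` are hypothetical names. The inverse map `model_to_hu` can be handy for viewing the decoder output in HU.

```python
import numpy as np

HU_MIN, HU_MAX = -1024.0, 3071.0  # clipping window used in get_rescaled_ct


def rescale(x, x_min, x_max, y_min, y_max):
    """Linearly map [x_min, x_max] onto [y_min, y_max] (hypothetical stand-in for rescale_tensor)."""
    return (x - x_min) / (x_max - x_min) * (y_max - y_min) + y_min


def hu_to_model(hu, new_range=(-1.0, 1.0)):
    """Clip a CT slice in HU and rescale it to new_range."""
    hu = np.clip(hu.astype(np.float32), HU_MIN, HU_MAX)
    return rescale(hu, HU_MIN, HU_MAX, new_range[0], new_range[1])


def model_to_hu(x, new_range=(-1.0, 1.0)):
    """Inverse of hu_to_model (inside the clipping window)."""
    return rescale(x, new_range[0], new_range[1], HU_MIN, HU_MAX)


def to_3ch_volume(x):
    """[B, 1, H, W] -> [B, 3, 1, H, W], mirroring the shape logic of prepare_input."""
    if x.ndim == 4:
        x = x[:, :, None]  # insert a depth axis of size 1
    if x.shape[1] == 1:
        x = np.repeat(x, 3, axis=1)  # copy the single channel into 3 channels
    return x


# Synthetic slice (not real data): air background with a bright square.
slice_hu = np.full((512, 512), -1024, dtype=np.int16)
slice_hu[100:200, 100:200] = 2041
x0 = hu_to_model(slice_hu)[None, None]  # (1, 1, 512, 512)
x = to_3ch_volume(x0)                   # (1, 3, 1, 512, 512)
print(x.shape, float(x.min()), float(x.max()))
```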

I expected y to look similar to x, but the visualization (screenshot not preserved here) shows otherwise.
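To put a number on "similar" instead of eyeballing it, a small NumPy sketch that reports MSE, MAE, maximum error and PSNR between an input and its reconstruction. `reconstruction_report` is a hypothetical helper; it assumes both arrays share a shape and live in the same normalized range (width 2.0 for [-1, 1]).

```python
import numpy as np


def reconstruction_report(x, y, data_range=2.0):
    """Compare an input x with its reconstruction y.

    Both must have the same shape and the same normalized intensity range;
    data_range is the width of that range (2.0 for [-1, 1]).
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    err = y - x
    mse = float(np.mean(err ** 2))
    psnr = float("inf") if mse == 0.0 else float(10.0 * np.log10(data_range ** 2 / mse))
    return {
        "mse": mse,
        "mae": float(np.mean(np.abs(err))),
        "max_abs_err": float(np.max(np.abs(err))),
        "psnr_db": psnr,
    }


# Toy check on synthetic arrays: a constant offset of 0.2 in a [-1, 1] range.
x_demo = np.zeros((1, 3, 1, 8, 8))
report = reconstruction_report(x_demo, x_demo + 0.2)
print(report)
```

For the tensors in Step 3, something like `reconstruction_report(x.detach().cpu().numpy(), y.detach().cpu().numpy())` should work, assuming `SpatialTensor` converts like a regular `torch.Tensor`.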

Any advice on that? Thanks!

BTW, here are the shapes and value ranges of x0, x, z, and y, in case they help:

x0: Shape: (1, 1, 512, 512), Range: (-1., 0.4971).
x: Shape: (1, 3, 1, 512, 512), Range: (-1., 0.4971).
z: Shape: (1, 512, 1, 32, 32), Range: (-0.2395, 17.5686).
y: Shape: (1, 3, 1, 512, 512), Range: (-1.6264, 5.3129).
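Stats like the ones above can be produced with a small helper, plus one extra sanity check: since `prepare_input` builds the three input channels as exact copies, it may be informative to see how much the three decoded channels of y disagree. `describe` and `channel_spread` are hypothetical helpers, and the sketch assumes the tensors have been converted to NumPy arrays first.

```python
import numpy as np


def describe(name, t):
    """Return (and print) shape and value range, roughly in the format of the listing above."""
    a = np.asarray(t)
    line = f"{name}: Shape: {a.shape}, Range: ({a.min():.4f}, {a.max():.4f})."
    print(line)
    return line


def channel_spread(y, channel_axis=1):
    """Largest per-voxel difference between channels (0.0 if all channels are identical)."""
    a = np.asarray(y, dtype=np.float64)
    return float(np.max(a.max(axis=channel_axis) - a.min(axis=channel_axis)))


# Toy example: three channels where the middle one is offset by 0.5.
y_demo = np.zeros((1, 3, 1, 4, 4))
y_demo[:, 1] += 0.5
describe("y_demo", y_demo)
print(channel_spread(y_demo))
```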

Masaaki-75 • Apr 09 '24 11:04