
main() on train_image_vae.py

Open • IamIDG opened this issue 2 years ago • 1 comment

Hello Dr. Pati,

I am trying to run the train_image_vae.py script on Colab (where I have copied over all the necessary components). I am running into the following error when calling the main() function without arguments:

TypeError: main() missing 13 required positional arguments: 'dataset_type', 'batch_size', 'num_epochs', 'lr', 'beta', 'capacity', 'gamma', 'delta', 'dec_dist', 'train', 'log', 'rand', and 'reg_type'

I tried passing the required parameters (based on the values specified in the class ImageVAETrainer(Trainer)) as follows:

    main('EchoLV', 2, 10, 1e-04, 4.0, 0.0, 10.0, 1.0, 'bernoulli', 'train', 'log', 'rand', 'volume')

but it seems to be reading in the 'volume' argument letter-by-letter. Here is the error I am getting:

    in main(dataset_type, batch_size, num_epochs, lr, beta, capacity, gamma, delta, dec_dist, train, log, rand, reg_type)
         35         reg_dim = []
         36         for r in reg_type:
    ---> 37             reg_dim.append(attr_dict[r])
         38     else:
         39         reg_dim = [0]

    KeyError: 'v'

I tried using ['volume'] but ran into: TypeError: unhashable type: 'list'

What am I missing? Any help would be appreciated. Thank you!
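Both errors can be reproduced without the training script. The sketch below mirrors only the non-'all' branches of the original reg_type handling and uses a hypothetical one-entry attr_dict (the real dictionary depends on the dataset). A bare string has length greater than 1, so the loop walks over its characters and looks up 'v'; a one-element list reaches `attr_dict[reg_type]` and tries to use the list itself as a dictionary key.

```python
# Hypothetical attr_dict, for illustration only; the real one depends on the dataset.
attr_dict = {"volume": 0}


def original_lookup(reg_type):
    """Mirror the non-'all' branches of the original reg_type handling."""
    if len(reg_type) == 1:
        # A one-element list lands here, but the list itself is used as the key.
        return [attr_dict[reg_type]]
    # A bare string such as 'volume' has len > 1, so this iterates over its characters.
    return [attr_dict[r] for r in reg_type]


for arg in ("volume", ["volume"]):
    try:
        original_lookup(arg)
    except (KeyError, TypeError) as err:
        print(type(arg).__name__, "->", type(err).__name__, err)
```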

IamIDG • Sep 15 '23 21:09

It's been a while since I have looked at this.

There seems to be a bug in the code. You can try changing the lines below (73--86 in main()):

    if len(reg_type) != 0:
        if len(reg_type) == 1:
            if reg_type[0] == 'all':
                reg_dim = []
                for r in attr_dict.keys():
                    if r == 'digit_identity' or r == 'color':
                        continue
                    reg_dim.append(attr_dict[r])
            else:
                reg_dim = [attr_dict[reg_type]]
        else:
            reg_dim = []
            for r in reg_type:
                reg_dim.append(attr_dict[r])

to

    # wrap a bare string such as `volume` in a list: [`volume`]
    if isinstance(reg_type, str):
        reg_type = [reg_type]

    # same logic as before, except the single-attribute branch now indexes
    # attr_dict with reg_type[0] (the name) instead of reg_type (the list)
    if len(reg_type) != 0:
        if len(reg_type) == 1:
            if reg_type[0] == 'all':
                reg_dim = []
                for r in attr_dict.keys():
                    if r == 'digit_identity' or r == 'color':
                        continue
                    reg_dim.append(attr_dict[r])
            else:
                reg_dim = [attr_dict[reg_type[0]]]
        else:
            reg_dim = []
            for r in reg_type:
                reg_dim.append(attr_dict[r])
    
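The same handling can be pulled out into a small function and tested outside the training script. This is a sketch, not the repository's code: the name `resolve_reg_dim` and its `skip` parameter are hypothetical, and the fallback to `[0]` for an empty reg_type is assumed from the `else: reg_dim = [0]` branch visible in the traceback. It indexes attr_dict with each attribute name, never with the list itself, and folds the single- and multi-attribute branches into one loop because they produce the same result.

```python
def resolve_reg_dim(reg_type, attr_dict, skip=("digit_identity", "color")):
    """Return the attribute indices to regularize (sketch of the fixed logic)."""
    # Accept a bare string such as 'volume' as well as a list of names.
    if isinstance(reg_type, str):
        reg_type = [reg_type]
    if len(reg_type) == 0:
        # Assumed fallback, taken from the else-branch shown in the traceback.
        return [0]
    if list(reg_type) == ["all"]:
        # Every attribute except the skipped ones, in dictionary order.
        return [idx for name, idx in attr_dict.items() if name not in skip]
    # Look up each name individually; a list is never used as a dict key.
    return [attr_dict[name] for name in reg_type]
```

For example, `resolve_reg_dim('volume', {'volume': 3})` and `resolve_reg_dim(['volume'], {'volume': 3})` both return `[3]`.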

Note, however, that the code was originally written to work only with the mnist / dsprites datasets, and it seems you are using a custom dataset (EchoLV). So I am not sure it would work even with this fix; you might have to do some additional refactoring. For example, you might need to define a custom attr_dict.
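One possible shape for a custom attr_dict is sketched below. Everything in it is hypothetical: the names `ECHOLV_ATTR_DICT`, `ATTR_DICTS`, and `get_attr_dict`, and the index assigned to 'volume'. The only property the reg_type handling above relies on is that each key is an attribute name you can pass as reg_type and each value is the integer that ends up in reg_dim; what that integer must point to depends on how the trainer uses reg_dim, which this thread does not show.

```python
# Hypothetical attribute dictionary for EchoLV. Keys are names accepted as
# reg_type; values are the integers that end up in reg_dim. Placeholder index.
ECHOLV_ATTR_DICT = {
    "volume": 0,
}

# Hypothetical registry keyed by dataset_type; the existing mnist / dsprites
# dictionaries would be registered here as well.
ATTR_DICTS = {
    "EchoLV": ECHOLV_ATTR_DICT,
}


def get_attr_dict(dataset_type):
    """Return the attr_dict for a dataset, failing loudly if none is defined."""
    try:
        return ATTR_DICTS[dataset_type]
    except KeyError:
        raise ValueError(
            f"no attr_dict defined for dataset_type={dataset_type!r}; "
            f"known: {sorted(ATTR_DICTS)}"
        ) from None
```

Failing with a ValueError that lists the known datasets makes an unsupported dataset_type obvious up front, instead of surfacing later as a bare KeyError on an attribute name.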

ashispati • Sep 16 '23 01:09