MedSegDiff
MedSegDiff copied to clipboard
Error in create_argparser
defaults.update({k: v for k, v in model_and_diffusion_defaults().items() if k not in defaults}) — Hi, I believe this is what you want; otherwise, the values you set will be overwritten by the predefined values from model_and_diffusion_defaults().
Instead of passing arguments on the command line (CLI), you can also do something like this:
def create_argparser():
    """Return the default settings used in place of command-line arguments.

    Despite its name, this returns a plain ``dict`` (setting name -> value),
    not an ``argparse`` parser. The example comes from the inference/sampling
    script, but the same approach can be followed for training.

    The script-specific values defined below take precedence: entries from
    ``model_and_diffusion_defaults()`` are added only for keys that are not
    already present, so the predefined model/diffusion values can no longer
    overwrite the settings chosen here.

    Returns:
        dict: the merged settings, script-specific keys first.
    """
    defaults = dict(
        data_name='',  # DATASET NAME
        data_dir="",  # PATH TO DATASET DIR
        clip_denoised=True,
        num_samples=1,
        batch_size=1,
        use_ddim=False,
        model_path="",
        num_ensemble=5,  # number of samples in the ensemble
        gpu_dev="0",
        out_dir='./results/',
        multi_gpu=None,  # example value: "0,1,2"
        debug=True,
    )
    # setdefault keeps every value set above and only fills in keys that are
    # missing. A plain update() would let the predefined defaults win.
    for key, value in model_and_diffusion_defaults().items():
        defaults.setdefault(key, value)
    return defaults
class CFG:
    """Attribute-style configuration object built from a dict of settings.

    Every key of the settings dict becomes an attribute on the instance, so
    the object can stand in for parsed command-line arguments that are read
    via attribute access (``args.batch_size`` and so on).
    """

    def __init__(self, arg_dict=None, **overrides):
        """Set one attribute per setting.

        Args:
            arg_dict: mapping of setting name -> value. When omitted (or
                ``None``), a fresh ``create_argparser()`` result is used.
            **overrides: individual settings that replace the corresponding
                values from ``arg_dict``, e.g. ``CFG(batch_size=4)``.
        """
        # Build the defaults per call. An ``arg_dict=create_argparser()``
        # default would be evaluated once, when the class is defined, and
        # that single dict would then be shared by every instance.
        if arg_dict is None:
            arg_dict = create_argparser()
        # Copy so neither the caller's dict nor ``overrides`` is mutated.
        settings = dict(arg_dict)
        settings.update(overrides)
        for key, value in settings.items():
            setattr(self, key, value)

    def __repr__(self):
        return f"{type(self).__name__}({vars(self)!r})"
Then, in main():
args = CFG()  # settings come from create_argparser()'s defaults instead of parsed CLI flags