trak
trak copied to clipboard
TRAK on 5D shapes
Hi, I have a classification model trained on inputs of shape batch_size x N (=46) x channel x height x width. How can I adapt TRAK to work with inputs of that shape? I get an error in the featurize function itself. Do I have to modify the in_dims?
/opt/conda/lib/python3.8/site-packages/trak/gradient_computers.py in compute_per_sample_grad(self, batch) 148 149 # map over batch dimensions (hence 0 for each batch dimension, and None for model params) --> 150 grads = torch.func.vmap( 151 grads_loss, 152 in_dims=(None, None, None, *([0] * len(batch))),
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in wrapped(*args, **kwargs) 432 433 # If chunk_size is not specified. --> 434 return _flat_vmap( 435 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs 436 )
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs) 37 def fn(*args, **kwargs): 38 with torch.autograd.graph.disable_saved_tensors_hooks(message): ---> 39 return f(*args, **kwargs) 40 return fn 41
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs) 617 try: 618 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec) --> 619 batched_outputs = func(*batched_inputs, **kwargs) 620 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func) 621 finally:
/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs) 1378 @wraps(func) 1379 def wrapper(*args, **kwargs): -> 1380 results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs) 1381 if has_aux: 1382 grad, (_, aux) = results
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs) 37 def fn(*args, **kwargs): 38 with torch.autograd.graph.disable_saved_tensors_hooks(message): ---> 39 return f(*args, **kwargs) 40 return fn 41
/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs) 1243 tree_map(partial(_create_differentiable, level=level), diff_args) 1244 -> 1245 output = func(*args, **kwargs) 1246 if has_aux: 1247 if not (isinstance(output, tuple) and len(output) == 2):
/opt/conda/lib/python3.8/site-packages/trak/modelout_functions.py in get_output(model, weights, buffers, image, label) 138 """ 139 logits = ch.func.functional_call(model, (weights, buffers), image.unsqueeze(0)) --> 140 bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False) 141 logits_correct = logits[bindex, label.unsqueeze(0)] 142
AttributeError: 'tuple' object has no attribute 'shape'