segmentation_models

"InvalidArgumentError: ConcatOp : Dimensions of inputs should match" for inference with larger input size (TF 2.7)

Open • ben0it8 opened this issue 3 years ago • 0 comments

After training a UNet on 224x224 images, I save it as a SavedModel using TF 2.7. Later, I load the model in the same environment to run inference on larger inputs (side lengths that are multiples of 32, e.g. 1280x1280), but prediction fails with the following error:

```
/srv/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   1705       TypeError: If the arguments do not match the function's signature.
   1706     """
-> 1707     return self._call_impl(args, kwargs)
   1708
   1709   def _call_impl(self, args, kwargs, cancellation_manager=None):

/srv/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_impl(self, args, kwargs, cancellation_manager)
   1719           try:
   1720             return self._call_with_flat_signature(args, kwargs,
-> 1721                                                   cancellation_manager)
   1722           except TypeError:
   1723             raise structured_err

/srv/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_with_flat_signature(self, args, kwargs, cancellation_manager)
   1772                         f"#{i}(zero-based) to be a Tensor; "
   1773                         f"got {type(arg).__name__} ({arg}).")
-> 1774     return self._call_flat(args, self.captured_inputs, cancellation_manager)
   1775
   1776   def _call_with_structured_signature(self, args, kwargs, cancellation_manager):

/srv/venv/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in _call_flat(self, args, captured_inputs, cancellation_manager)
    121       captured_inputs = list(map(get_unused_handle, captured_inputs))
    122     return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
--> 123                                                     cancellation_manager)
    124
    125

/srv/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1958       # No tape is watching; skip to running the function.
   1959       return self._build_call_outputs(self._inference_function.call(
-> 1960           ctx, args, cancellation_manager=cancellation_manager))
   1961     forward_backward = self._select_forward_and_backward_functions(
   1962         args,

/srv/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    601               inputs=args,
    602               attrs=attrs,
--> 603               ctx=ctx)
    604         else:
    605           outputs = execute.execute_with_cancellation(

/srv/venv/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57     ctx.ensure_initialized()
     58     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 59                                         inputs, attrs, num_outputs)
     60   except core._NotOkStatusException as e:
     61     if name is not None:

InvalidArgumentError:  ConcatOp : Dimensions of inputs should match: shape[0] = [2,14,14,512] vs. shape[1] = [2,28,28,256]
         [[node network/model_2/decoder_stage0_concat/concat
 (defined at /srv/repo/backend/serving/inference/engine.py:56)
]] [Op:__inference_signature_wrapper_28656]

Errors may have originated from an input operation.
Input Source operations connected to node network/model_2/decoder_stage0_concat/concat:
In[0] network/model_2/decoder_stage0_upsampling/resize/ResizeNearestNeighbor:
In[1] network/model_2/stage4_unit1_relu1/Relu:
In[2] network/model_2/decoder_stage0_concat/concat/axis:
```

This used to work with TF 2.4, as long as the input side length was divisible by 32. Any clue what causes this?
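The two shapes in the error can be checked with a little stride arithmetic. The sketch below is only an illustration under assumptions inferred from the layer names in the traceback, not confirmed by the issue: a ResNet-style encoder that halves the resolution five times (total stride 32, which is why side lengths must be multiples of 32), a `stage4_unit1_relu1` skip connection at stride 16, and a `decoder_stage0_upsampling` layer that doubles the stride-32 bottleneck. The helper names `expected_stage0_sides` and `input_side_for` are hypothetical.

```python
from typing import Tuple

# Assumptions (inferred from layer names, not confirmed by the issue):
# - the encoder halves the resolution 5 times, so the bottleneck is at stride 32;
# - the skip connection `stage4_unit1_relu1` is at stride 16;
# - `decoder_stage0_upsampling` upsamples the bottleneck by a factor of 2.
ENCODER_STRIDE = 32
SKIP_STRIDE = 16


def expected_stage0_sides(side: int) -> Tuple[int, int]:
    """Return (upsampled bottleneck side, skip side) for a side x side input."""
    if side % ENCODER_STRIDE != 0:
        raise ValueError(f"side length {side} is not a multiple of {ENCODER_STRIDE}")
    upsampled = (side // ENCODER_STRIDE) * 2  # bottleneck doubled by the decoder
    skip = side // SKIP_STRIDE                # stride-16 skip connection
    return upsampled, skip


def input_side_for(stride16_side: int) -> int:
    """Input side length that yields a stride-16 feature map of this size."""
    return stride16_side * SKIP_STRIDE


# If the upsampling size follows the runtime input, both concat inputs match.
for side in (224, 448, 1280):
    up, skip = expected_stage0_sides(side)
    print(f"{side}: upsampled={up}, skip={skip}, match={up == skip}")

# Sides reported in the error: 14 (In[0], upsampling) vs 28 (In[1], skip).
print(input_side_for(14), input_side_for(28))
```

Under these assumptions both concat inputs should always have side H/16. The reported 14 corresponds to a 224x224 input (the training size), while 28 corresponds to a 448x448 input, which would be consistent with, though not proof of, the upsampled size having been fixed when the model was traced at 224x224.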

ben0it8 • Nov 30 '21 18:11