Dtype error in model.fit()

Open: els2155 opened this issue 3 years ago (0 comments)

I am having an issue when I run through the code provided in the `project_build_tf_sentiment_model` folder. When I try to run this code:

```python
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=2
)
```

in the `02_build_and_train_lstm_example.ipynb` file, I get the following error:

```
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-7-9a34c6a1ecff> in <module>
      2     train_ds,
      3     validation_data=val_ds,
----> 4     epochs=2
      5 )

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1181                 _r=1):
   1182               callbacks.on_train_batch_begin(step)
-> 1183               tmp_logs = self.train_function(iterator)
   1184               if data_handler.should_sync:
   1185                 context.async_wait()

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    887 
    888       with OptionalXlaContext(self._jit_compile):
--> 889         result = self._call(*args, **kwds)
    890 
    891       new_tracing_count = self.experimental_get_tracing_count()

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    948         # Lifting succeeded, so variables are initialized and we can run the
    949         # stateless function.
--> 950         return self._stateless_fn(*args, **kwds)
    951     else:
    952       _, _, _, filtered_flat_args = \

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   3022        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   3023     return graph_function._call_flat(
-> 3024         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   3025 
   3026   @property

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1959       # No tape is watching; skip to running the function.
   1960       return self._build_call_outputs(self._inference_function.call(
-> 1961           ctx, args, cancellation_manager=cancellation_manager))
   1962     forward_backward = self._select_forward_and_backward_functions(
   1963         args,

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    594               inputs=args,
    595               attrs=attrs,
--> 596               ctx=ctx)
    597         else:
    598           outputs = execute.execute_with_cancellation(

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Data type mismatch at component 0: expected double but got int32.
	 [[node IteratorGetNext (defined at <ipython-input-7-9a34c6a1ecff>:4) ]]
  (1) Invalid argument:  Data type mismatch at component 0: expected double but got int32.
	 [[node IteratorGetNext (defined at <ipython-input-7-9a34c6a1ecff>:4) ]]
	 [[GroupCrossDeviceControlEdges_0/Identity_2/_41]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_26737]

Function call stack:
train_function -> train_function
```

I created a new environment from the provided `*.yml` and `requirements.txt` files and got the same error.
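The error refers to "component 0" of a dataset element, meaning the first tensor once the element's nested structure is flattened (`tf.nest` orders dict entries by sorted key), and it fails in `IteratorGetNext`, i.e. when the dataset hands a batch to `fit`, before the model runs. The notebook's data-loading code isn't shown in this issue, so the following is only a NumPy-only sketch under stated assumptions: it assumes each element is an `({"input_ids", "attention_mask"}, labels)` tuple whose token tensors are stored as int32, and it uses hypothetical helpers `flatten` and `find_dtype_mismatches` (not TensorFlow APIs) to compare a declared dtype spec against the stored arrays.

```python
import numpy as np

# TensorFlow's names for a few NumPy dtypes, as printed in tf.data errors
# (for example, float64 is reported as "double").
TF_DTYPE_NAMES = {"float64": "double", "float32": "float",
                  "int64": "int64", "int32": "int32"}


def flatten(structure):
    """Hypothetical helper: flatten nested tuples/lists/dicts into leaves.

    Follows tf.nest's ordering convention: sequences in order,
    dict values in sorted-key order.
    """
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (tuple, list)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def find_dtype_mismatches(declared, element):
    """Hypothetical helper: compare declared dtype names with one real element.

    `declared` mirrors the nesting of `element`, with NumPy dtype names as
    leaves. Returns one message per mismatching component, worded like
    the tf.data error.
    """
    messages = []
    for i, (want, arr) in enumerate(zip(flatten(declared), flatten(element))):
        got = arr.dtype.name
        if want != got:
            messages.append(
                f"component {i}: expected {TF_DTYPE_NAMES.get(want, want)} "
                f"but got {TF_DTYPE_NAMES.get(got, got)}"
            )
    return messages


# Assumed element layout: int32 token IDs and mask, float64 one-hot labels.
element = (
    {"input_ids": np.zeros((16, 512), dtype=np.int32),
     "attention_mask": np.zeros((16, 512), dtype=np.int32)},
    np.zeros((16, 5), dtype=np.float64),
)
# Assumed declared spec that (wrongly) says float64 for the inputs.
declared = ({"input_ids": "float64", "attention_mask": "float64"}, "float64")

for msg in find_dtype_mismatches(declared, element):
    print(msg)
# component 0: expected double but got int32
# component 1: expected double but got int32
```

Unlike TensorFlow's check, which reports a single component per error, this helper lists every mismatching component. If a check like this shows the declared dtypes disagree with what was stored, a plausible fix is to make whatever declares the dataset's types (for example an explicit `element_spec` used when reloading a saved dataset) match the stored dtypes, or to cast before saving; that is an assumption about this notebook, not a confirmed diagnosis.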

els2155, Jul 16 '21 01:07