[RLlib] InvalidArgumentError: cannot compute ConcatV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor
What happened + What you expected to happen
The problem occurred when training a Soft Actor-Critic (SAC) model with TensorFlow 2 on the Hopper environment from Gymnasium.
from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .environment(env="Hopper-v4")
    .framework("tf2")
)
algo = config.build()
Instead of building the algorithm normally, the script threw the following error:
Traceback (most recent call last):
File "/Users/macbookpro/Documents/autonomous/generative/markov/test-rllib/github_report.py", line 9, in <module>
algo = config.build()
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm_config.py", line 1071, in build
return algo_class(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac.py", line 354, in __init__
super().__init__(*args, **kwargs)
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 475, in __init__
super().__init__(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 170, in __init__
self.setup(copy.deepcopy(self.config))
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 601, in setup
self.workers = WorkerSet(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 172, in __init__
self._setup(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 262, in _setup
self._local_worker = self._make_worker(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 967, in _make_worker
worker = cls(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 738, in __init__
self._update_policy_map(policy_dict=self.policy_dict)
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1985, in _update_policy_map
self._build_policy_map(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 2097, in _build_policy_map
new_policy = create_policy_for_framework(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 139, in create_policy_for_framework
return policy_class(observation_space, action_space, merged_config)
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/policy/eager_tf_policy.py", line 470, in __init__
self._initialize_loss_from_dummy_batch(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 1487, in _initialize_loss_from_dummy_batch
self._loss(self, self.model, self.dist_class, train_batch)
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_tf_policy.py", line 333, in sac_actor_critic_loss
q_t, _ = model.get_q_values(
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_tf_model.py", line 212, in get_q_values
return self._get_q_value(model_out, actions, self.q_net)
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/ray/rllib/algorithms/sac/sac_tf_model.py", line 245, in _get_q_value
input_dict = {"obs": tf.concat([model_out, actions], axis=-1)}
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/macbookpro/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 7262, in raise_from_not_ok_status
raise core._status_to_exception(e) from None # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute ConcatV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:ConcatV2] name: concat
The error happened because TensorFlow tried to concatenate two tensors of different datatypes in input_dict = {"obs": tf.concat([model_out, actions], axis=-1)}. The model_out tensor has the float64 datatype (after more backtracking, the datatype seems to be inferred from the observation space), while the actions tensor always has the float32 datatype, since it is explicitly cast to float32 by the following code:
q_t, _ = model.get_q_values(
model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
)
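For illustration, here is a minimal standalone sketch of the same dtype mismatch outside RLlib; the shapes are my own assumptions, chosen to roughly match Hopper-v4 (11-dim observations, 3-dim actions):

import numpy as np
import tensorflow as tf

# NumPy defaults to float64, like Hopper-v4's observation space.
model_out = tf.constant(np.zeros((1, 11)))  # dtype: float64

# The actions are explicitly cast to float32, as in sac_tf_policy.py above.
actions = tf.cast(tf.constant(np.zeros((1, 3))), tf.float32)  # dtype: float32

# Raises InvalidArgumentError: ... expected to be a double tensor but is a float tensor
tf.concat([model_out, actions], axis=-1)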
I have tried specifying the observation_space to use the float32 datatype, but the error persisted:
import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .environment(
        env="Hopper-v4",
        observation_space=gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(11,), dtype=np.float32
        ),
    )
    .framework("tf2")
)
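An avenue I have not tested yet would be to wrap the environment so that the observations themselves are cast to float32, not just the declared space. A sketch (the Float32Obs wrapper class and the "Hopper-v4-f32" registry name are my own inventions, not part of RLlib):

import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.sac import SACConfig
from ray.tune.registry import register_env

class Float32Obs(gym.ObservationWrapper):
    """Casts both the observations and the declared space to float32."""

    def __init__(self, env):
        super().__init__(env)
        box = env.observation_space
        self.observation_space = gym.spaces.Box(
            low=box.low.astype(np.float32),
            high=box.high.astype(np.float32),
            dtype=np.float32,
        )

    def observation(self, obs):
        return obs.astype(np.float32)

register_env("Hopper-v4-f32", lambda env_config: Float32Obs(gym.make("Hopper-v4")))

config = SACConfig().environment(env="Hopper-v4-f32").framework("tf2")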
For now, my workarounds are:
- Modify the library code to cast the model_out tensor to float32 (a monkey-patch version of this is sketched after this list). This worked, but it seems rather hacky, and I don't know whether the fix might break the model when trained on other environments.
- Use PyTorch instead of TF2. This is not ideal because I need to run inference on the model in a web frontend, and there is currently no library that runs PyTorch models effortlessly on the web. Converting PyTorch to TF.js is rather involved (PyTorch -> ONNX -> TF Python -> TF.js), and most tools for the ONNX -> TF Python conversion are no longer maintained.
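For reference, a less invasive variant of the first workaround is to monkey-patch the cast in at runtime instead of editing site-packages. This is a sketch based on the sac_tf_model.py internals quoted in the traceback above; the SACTFModel class and the private _get_q_value method are implementation details of Ray 2.5.0 and may change in other versions:

import tensorflow as tf
from ray.rllib.algorithms.sac import sac_tf_model

_orig_get_q_value = sac_tf_model.SACTFModel._get_q_value

def _patched_get_q_value(self, model_out, actions, net):
    # Cast model_out to float32 so it matches the already-cast actions tensor.
    return _orig_get_q_value(self, tf.cast(model_out, tf.float32), actions, net)

# Apply the patch before calling config.build().
sac_tf_model.SACTFModel._get_q_value = _patched_get_q_value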
Versions / Dependencies
Ray: 2.5.0
OS: macOS Monterey (12.6)
Python: 3.9.16
Reproduction script
from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .environment(env="Hopper-v4")
    .framework("tf2")
)
algo = config.build()
Issue Severity
Low: It annoys or frustrates me.