seed_rl icon indicating copy to clipboard operation
seed_rl copied to clipboard

About sac_main.py

Open BlackDeal opened this issue 4 years ago • 1 comments

I tried to add a sac_main.py file for Atari:

 from absl import app
from absl import flags
# from seed_rl.agents.r2d2 import learner
from seed_rl.agents.sac import learner
from seed_rl.agents.sac import networks
from seed_rl.atari import env
# from seed_rl.atari import networks
from seed_rl.common import actor
from seed_rl.common import common_flags  
import tensorflow as tf

# Global absl flag registry; values are read lazily as FLAGS.<name> once
# app.run() has parsed the command line.
FLAGS = flags.FLAGS

# Optimizer settings.
# Read by create_optimizer() as the Adam step size.
flags.DEFINE_float('learning_rate', 0.00048, 'Learning rate.')
# Read by create_optimizer() as Adam's epsilon argument.
flags.DEFINE_float('adam_epsilon', 1e-3, 'Adam epsilon.')
# NOTE(review): not an optimizer setting, and not read anywhere in this
# snippet -- presumably consumed by the Atari environment code; confirm.
flags.DEFINE_integer('stack_size', 4, 'Number of frames to stack.')


def create_agent(env_action_s, env_obs_s, parametric_action_distribution):
  """Builds the SAC agent network handed to the learner.

  Args:
    env_action_s: Not used by this factory.
    env_obs_s: Not used by this factory.
    parametric_action_distribution: Forwarded unchanged as the first
      argument of networks.ActorCriticMLP.

  Returns:
    A networks.ActorCriticMLP instance.
  """
  # NOTE(review): presumably the critic count and the hidden-layer widths --
  # confirm against the networks.ActorCriticMLP constructor.
  second_arg = 1
  third_arg = [32, 32]
  agent = networks.ActorCriticMLP(
      parametric_action_distribution, second_arg, third_arg)
  return agent


def create_optimizer(unused_final_iteration):
  """Creates the Adam optimizer and a constant learning-rate schedule.

  Args:
    unused_final_iteration: Ignored; the learning rate does not decay.

  Returns:
    A tuple (optimizer, learning_rate_fn): a tf.keras Adam optimizer
    configured from FLAGS.learning_rate and FLAGS.adam_epsilon, and a
    callable taking an iteration and returning FLAGS.learning_rate.
  """
  adam = tf.keras.optimizers.Adam(
      FLAGS.learning_rate, epsilon=FLAGS.adam_epsilon)

  def learning_rate_fn(iteration):
    # Constant schedule: the flag is looked up on every call and the
    # iteration argument is ignored.
    return FLAGS.learning_rate

  return adam, learning_rate_fn


def main(argv):
  """Dispatches to the actor or learner loop based on FLAGS.run_mode.

  Args:
    argv: Positional command-line arguments left over after flag parsing;
      anything beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional arguments were supplied.
    ValueError: If FLAGS.run_mode is neither 'actor' nor 'learner'.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  mode = FLAGS.run_mode

  if mode == 'actor':
    actor.actor_loop(env.create_environment)
    return

  if mode == 'learner':
    learner.learner_loop(
        env.create_environment, create_agent, create_optimizer)
    return

  raise ValueError('Unsupported run mode {}'.format(mode))


# Script entry point: absl parses the command-line flags and then calls
# main() with the remaining positional arguments.
if __name__ == '__main__':
  # Debugging leftover that forced learner mode; kept disabled so the mode
  # comes from the run_mode flag instead.
  # FLAGS.run_mode = 'learner'

  app.run(main)

But running it fails with the following error:

`run_main(main, args) File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main sys.exit(main(argv)) File "../atari/sac_main.py", line 61, in main create_optimizer) File "/seed_rl/agents/sac/learner.py", line 402, in learner_loop initialize_agent_variables(agent) File "/seed_rl/agents/sac/learner.py", line 400, in initialize_agent_variables create_variables() File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 580, in call result = self._call(*args, **kwds) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 627, in _call self._initialize(args, kwds, add_initializers_to=initializers) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 506, in _initialize *args, **kwds)) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected graph_function, _, _ = self._maybe_define_function(args, kwargs) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function graph_function = self._create_graph_function(args, kwargs) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function capture_by_value=self._capture_by_value), File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func func_outputs = python_func(*func_args, **func_kwargs) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn return weak_wrapped_fn().wrapped(*args, **kwds) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper raise e.ag_error_metadata.to_exception(e) TypeError: in user code:

/seed_rl/agents/sac/learner.py:399 create_variables  *
    agent.get_Q(*decode(input_), action=decode(input_[0]))]
/seed_rl/agents/sac/networks.py:110 get_action  *
    return self.__call__(*args, **kwargs)
/seed_rl/agents/sac/networks.py:126 __call__  *
    action_params = self.get_action_params(prev_action, env_output, state)
/seed_rl/agents/sac/networks.py:101 get_action_params  *
    return self._actor_mlp(self._concat_obs(env_output.observation))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__  **
    outputs = call_fn(cast_inputs, *args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/sequential.py:291 call
    outputs = layer(inputs, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py:1183 call
    outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:4346 tensordot
    ab_matmul = matmul(a_reshape, b_reshape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:180 wrapper
    return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:2984 matmul
    a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:5587 mat_mul
    name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:578 _apply_op_helper
    param_name=input_name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:61 _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))

TypeError: Value passed to parameter 'a' has DataType uint8 not in list of allowed values: bfloat16, float16, float32, float64, int32, int64, complex64, complex128

` I suspect the cause is that the input is an image (uint8 pixels), but the same error does not occur with R2D2. Please help me solve it. Thank you very much!

BlackDeal avatar Jul 24 '20 02:07 BlackDeal

You use this network:

return networks.ActorCriticMLP(parametric_action_distribution, 1,[32,32])

which isn't really suited for atari. However, if you just do tf.cast([input], tf.float32) on the input in the network I'm guessing it will work.

Ideally you need a mix between the Q-learning network for Atari and the SAC example one.

lespeholt avatar Aug 05 '20 20:08 lespeholt