
[question] GAIL and discretized observation

Open IwamotoTaro opened this issue 5 years ago • 3 comments

Thank you for the wonderful tool.

(1) I was able to complete GAIL training using CartPole expert data.
(2) Next, I added a wrapper to CartPole that discretizes observations and was able to complete training with TRPO.
(3) Finally, I tried GAIL training with the discretized observations, and an error occurred at line 121 of gail/adversary.py.

Is GAIL training with discretized observations currently supported in stable-baselines?
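For readers without the attached files, the discretization in step (2) might look roughly like the sketch below. This is only an illustration, not the code from the zip: the bin edges, the `BIN_EDGES` and `N_STATES` names, and the `discretize` helper are all hypothetical. It maps a continuous 4-dimensional CartPole observation to a single integer index, which is consistent with the `obs (315, 1)` shape in the log.

```python
import numpy as np

# Hypothetical bin edges for each CartPole observation dimension
# (cart position, cart velocity, pole angle, pole angular velocity).
BIN_EDGES = [
    np.linspace(-2.4, 2.4, 5),
    np.linspace(-3.0, 3.0, 5),
    np.linspace(-0.21, 0.21, 5),
    np.linspace(-3.0, 3.0, 5),
]

# Each dimension has len(edges) + 1 bins (values outside the edges included).
DIMS = tuple(len(edges) + 1 for edges in BIN_EDGES)
N_STATES = int(np.prod(DIMS))


def discretize(obs):
    """Map a continuous 4-d observation to one integer state index."""
    # np.digitize gives a bin number in 0..len(edges) for each dimension.
    digits = tuple(int(np.digitize(x, edges)) for x, edges in zip(obs, BIN_EDGES))
    # Flatten the per-dimension bin numbers into a single index.
    return int(np.ravel_multi_index(digits, DIMS))
```

In a gym environment this function would typically be called from the `observation` method of a `gym.ObservationWrapper` subclass whose observation space is `Discrete(N_STATES)`; that wiring is left out here to keep the sketch dependency-free.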

IwamotoTaro avatar Dec 06 '19 08:12 IwamotoTaro

Hello, please fill in the issue template completely.

araffin avatar Dec 06 '19 08:12 araffin

Code example

The files have been uploaded: GAIL and discretized observation.zip

  • TRPO_cartpoleD.py: TRPO training on CartPole with discretized observations
  • Gen_cartpoleD.py: creates expert data for 5 episodes (the cart is operated with the left and right arrow keys)
  • Cartpole_trajD.npz: sample of expert data
  • GAIL_cartpoleD.py: GAIL training with discretized observations

Error messages and stack traces

C:\anaconda3\lib\site-packages\gym\envs\registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated.  Call .resolve and .require separately.
  result = entry_point.load(False)
actions (315, 1)
obs (315, 1)
rewards (315,)
episode_returns (5,)
episode_starts (315,)
Total trajectories: -1
Total transitions: 315
Average returns: 63.0
Std for returns: 30.469657037781044
WARNING:tensorflow:From C:\anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Traceback (most recent call last):
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 511, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1175, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 977, in _TensorTensorConversionFunction
    (dtype.name, t.dtype.name, str(t)))
ValueError: Tensor conversion requested dtype int64 for Tensor with dtype float32: 'Tensor("adversary/obfilter/Cast:0", shape=(), dtype=float32)'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "GAIL_cartpoleD.py", line 49, in <module>
    model = GAIL(MlpPolicy, env, dataset, verbose=1)
  File "C:\anaconda3\lib\site-packages\stable_baselines\gail\model.py", line 49, in __init__
    self.setup_model()
  File "C:\anaconda3\lib\site-packages\stable_baselines\trpo_mpi\trpo_mpi.py", line 129, in setup_model
    entcoeff=self.adversary_entcoeff)#
  File "C:\anaconda3\lib\site-packages\stable_baselines\gail\adversary.py", line 77, in __init__
    generator_logits = self.build_graph(self.generator_obs_ph, self.generator_acs_ph, reuse=False)
  File "C:\anaconda3\lib\site-packages\stable_baselines\gail\adversary.py", line 121, in build_graph
    obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std
  File "C:\anaconda3\lib\site-packages\tensorflow\python\ops\math_ops.py", line 812, in binary_op_wrapper
    return func(x, y, name=name)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 10130, in sub
    "Sub", x=x, y=y, name=name)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 547, in _apply_op_helper
    inferred_from[input_arg.type_attr]))
TypeError: Input 'y' of 'Sub' Op has type float32 that does not match type int64 of argument 'x'.

System Info

  • no GPU
  • Python version: Python 3.6.5 :: Anaconda, Inc.
  • Tensorflow version: 1.13.1
  • stable-baselines version: 2.8.0

IwamotoTaro avatar Dec 06 '19 09:12 IwamotoTaro

obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std

It looks like this normalization should be deactivated when using discrete observations; otherwise a cast error occurs, because the int64 observation placeholder is combined with the float32 running mean. I'm also not sure whether the correct preprocessing is applied for discrete observations (I only had a quick look).

The error comes from this line: https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/gail/adversary.py#L118 (and it seems that the docstring is wrong)
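Until the normalization is handled inside the library, one possible workaround on the environment side is to hand GAIL float observations instead of integer indices. The sketch below is an assumption, not a fix adopted by stable-baselines: the `one_hot` and `normalize` helpers are hypothetical names, and the expert data would need to be regenerated in the same encoding. It shows that once the discrete index is one-hot encoded as float32, the `(obs - mean) / std` expression only mixes float32 operands.

```python
import numpy as np


def one_hot(index, n_states):
    """Encode a discrete state index as a float32 one-hot vector."""
    vec = np.zeros(n_states, dtype=np.float32)
    vec[index] = 1.0
    return vec


def normalize(obs, mean, std, eps=1e-8):
    """Same form as the adversary's normalization, with float32 operands only."""
    return (obs - mean) / (std + eps)


# Example: state 2 out of 5, with placeholder running statistics.
n_states = 5
obs = one_hot(2, n_states)
mean = np.zeros(n_states, dtype=np.float32)
std = np.ones(n_states, dtype=np.float32)
normalized = normalize(obs, mean, std)
```

With this encoding the wrapped environment would expose a float32 `Box` observation space of shape `(n_states,)` rather than a `Discrete` one, so the placeholder dtype and the running statistics agree.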

araffin avatar Dec 06 '19 10:12 araffin