btgym
Migration to Tensorflow2
Automatic and manual changes for migrating from TensorFlow 1 to 2.
Because there are no tests, I checked some of the example notebooks:
- very basic environment setup: works
- setting up environment basic: works, except for the last cell (a problem requesting data from the data server)
- setting up environment full: works
- guided A3C: runtime error
I need some help to finish this PR.
@woj-i, do I understand correctly that the only issue is the guided A3C runtime error, or are there other points?
I think there is more to it than that, but there are no tests to pinpoint those places. Currently I get the following error (for the guided_a3c notebook):
File "/home/wind/repositories/gitHub/btgym/btgym/algorithms/aac.py", line 408, in __init__
with tf.device(tf.compat.v1.train.replica_device_setter(1, worker_device=self.worker_device)):
File "/home/wind/repositories/gitHub/btgym/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 5273, in device_v2
raise RuntimeError("tf.device does not support functions.")
RuntimeError: tf.device does not support functions.
@woj-i Same error. Have you solved it?
Changing tf.device to tf.compat.v1.device will solve this.
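A minimal sketch of why the swap helps, assuming TensorFlow 2.x with the tf.compat.v1 API available. In TF2, tf.device rejects any callable, while tf.compat.v1.device still accepts a device function such as the one returned by replica_device_setter, but only while a graph (not eager execution) is the default context. The worker device string below is a hypothetical placeholder for self.worker_device in aac.py.

```python
import tensorflow as tf

# Device function like the one built in aac.py: with ps_tasks=1, variables are
# placed on /job:ps/task:0 and all other ops on the given worker device.
# The worker device string is a hypothetical stand-in for self.worker_device.
device_fn = tf.compat.v1.train.replica_device_setter(
    1, worker_device="/job:worker/task:0/cpu:0")

# TF2's tf.device raises RuntimeError as soon as it receives a callable.
try:
    with tf.device(device_fn):
        pass
    tf2_device_accepts_fn = True
except RuntimeError:
    tf2_device_accepts_fn = False

# tf.compat.v1.device accepts the device function while building a graph.
graph = tf.Graph()
with graph.as_default():
    with tf.compat.v1.device(device_fn):
        w = tf.compat.v1.get_variable(
            "w", shape=[2], initializer=tf.compat.v1.zeros_initializer())
        x = tf.constant(1.0, name="x")

print(tf2_device_accepts_fn)
print(w.device)  # variable handle placed on the parameter-server job
print(x.device)  # ordinary op placed on the worker device
```

Under eager execution, tf.compat.v1.device raises for functions as well, so this fix relies on btgym building its model in graph mode.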
I found that btgym/research/casual_conv/networks.py has the following code:
alignments = attention_mechanism(
query_state,
attention_mechanism.initial_alignments(tf.shape(inputs)[0], dtype=tf.float32)
)
LuongAttention changed significantly in TensorFlow 2.0, so this code needs to be updated.
However, I do not understand the logic here. Could anyone help?
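A hedged sketch of what the call appears to compute, written in plain TensorFlow rather than the seq2seq API (in TF2, LuongAttention lives in TensorFlow Addons as tfa.seq2seq.LuongAttention). Luong attention scores each memory step by a dot product with the query and normalizes the scores with a softmax over time. In tf.contrib.seq2seq, LuongAttention's default probability function ignores the previous alignments, so passing initial_alignments seems to serve only to satisfy the call signature. The helper name luong_alignments is hypothetical, and keys is assumed to be the memory already projected to num_units.

```python
import tensorflow as tf


def luong_alignments(query, keys):
    """Hypothetical helper: multiplicative (Luong) attention alignments.

    query: [batch, num_units], e.g. the current query_state
    keys:  [batch, max_time, num_units], memory projected to num_units
    returns alignments: [batch, max_time], each row summing to 1
    """
    # score[b, t] = dot(keys[b, t, :], query[b, :])
    scores = tf.einsum("btu,bu->bt", keys, query)
    # Plain softmax over time steps; no dependence on previous alignments.
    return tf.nn.softmax(scores, axis=-1)


batch, max_time, units = 2, 6, 4
query_state = tf.random.normal([batch, units])
keys = tf.random.normal([batch, max_time, units])
alignments = luong_alignments(query_state, keys)
print(alignments.shape)  # (2, 6)
```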
The lstm_network function in algorithms/nn/networks.py also needs updating: DropoutWrapper causes an error in TensorFlow 2, so the cells' dropout parameter should be used instead, according to https://stackoverflow.com/questions/62989175/layernormlstmcell-object-has-no-attribute-zero-state-in-tf-2-2 and https://github.com/tensorflow/tensorflow/issues/29129
I also get this error:
Traceback (most recent call last):
File "/Users/cmal/btgym/btgym/algorithms/aac.py", line 409, in __init__
self.network = pi_global = self._make_policy('global')
File "/Users/cmal/btgym/btgym/algorithms/aac.py", line 842, in _make_policy
network = self.policy_class(**self.policy_kwargs)
File "/Users/cmal/btgym/btgym/algorithms/policy/stacked_lstm.py", line 300, in __init__
**kwargs,
File "/Users/cmal/btgym/btgym/algorithms/nn/networks.py", line 148, in lstm_network
lstm_init_state = lstm.zero_state(1, dtype=tf.float32)
File "/Users/cmal/py3tf2.3/lib/python3.7/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py", line 1273, in zero_state
return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
File "/Users/cmal/py3tf2.3/lib/python3.7/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py", line 1273, in <genexpr>
return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
File "/Users/cmal/py3tf2.3/lib/python3.7/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py", line 203, in zero_state
return self.cell.zero_state(batch_size, dtype)
AttributeError: 'LayerNormLSTMCell' object has no attribute 'zero_state'
[2020-10-27 04:19:46.238595] ERROR: Worker_0: Base class __init__() exception occurred.
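A minimal sketch of the direction suggested by the linked answers, assuming TensorFlow 2.x Keras. Dropout is set on each cell instead of wrapping cells in DropoutWrapper, and the Keras RNN layer creates its own zero initial state when none is passed, so no zero_state() call is needed. The function stacked_lstm and its defaults are hypothetical. The original code uses LayerNormLSTMCell, which in TF2 is provided by TensorFlow Addons as tfa.rnn.LayerNormLSTMCell; this sketch swaps in plain tf.keras.layers.LSTMCell, so layer normalization is not included.

```python
import tensorflow as tf


def stacked_lstm(units=(32, 32), dropout=0.2):
    """Hypothetical TF2 stand-in for the stacked LSTM built in lstm_network."""
    # Dropout is configured on each cell instead of wrapping it in DropoutWrapper.
    cells = [tf.keras.layers.LSTMCell(u, dropout=dropout) for u in units]
    # The RNN layer builds a zero initial state itself when none is passed.
    return tf.keras.layers.RNN(cells, return_sequences=True)


rnn = stacked_lstm()
inputs = tf.random.normal([2, 5, 3])  # [batch, time, features]
outputs = rnn(inputs, training=True)  # dropout is only applied when training=True
print(outputs.shape)  # (2, 5, 32)
```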
Hey @cmal! Thank you for your help! You are very welcome to PR your changes to my tensorflow2 branch https://github.com/woj-i/btgym/tree/tensorflow2
It seems install_requires needs refactoring to resolve dependency conflicts.
Is there any progress on this upgrade PR? I want to use this environment for my bachelor thesis in the coming weeks, and perhaps I could help speed things up, since I have only worked with TF 2.0 so far.
No progress from my side. @Arno989 you're welcome to contribute to this PR.