brax
Rendering an image with brax.io.image
I have seen this functionality mentioned in other issues, but I cannot find code that works. Here is my code that attempts to get an image of the environment at every step:
import jax
import jax.numpy as jnp
from jax import random
from brax import envs
from brax.io import html
from brax.io import image

if __name__ == "__main__":
    env_name = "ant"
    backend = "generalized"
    episode_length = 5

    env = envs.get_environment(backend=backend, env_name=env_name)
    # jit once, outside the loop, so the compiled functions are reused
    jit_reset = jax.jit(env.reset)
    jit_step = jax.jit(env.step)

    state = jit_reset(random.PRNGKey(0))
    cum_reward = 0
    states = []
    for step in range(1, episode_length + 1):
        action = jnp.zeros(env.action_size)  # zero action for every actuator
        state = jit_step(state, action)
        states.append(state.pipeline_state)
        cum_reward += state.reward
        # this line gives the error; comment it out to get the HTML video.
        # image.render expects a list of pipeline states, not the env State.
        step_image = image.render(env.sys, [state.pipeline_state])
        if state.done:
            break

    render = html.render(env.sys, states)
    with open("traj.html", "w") as f:
        f.write(render)
    print("Episode ended. Total reward: " + str(cum_reward))
This throws an error:
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/brax/io/image.py", line 34, in render_array
    renderer = mujoco.Renderer(sys.mj_model, height=height, width=width)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 83, in __init__
    self._mjr_context = _render.MjrContext(
                        ^^^^^^^^^^^^^^^^^^^
mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
Exception ignored in: <function Renderer.__del__ at 0x7fa3777ec860>
Traceback (most recent call last):
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 330, in __del__
    self.close()
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 318, in close
    if self._mjr_context:
       ^^^^^^^^^^^^^^^^^
AttributeError: 'Renderer' object has no attribute '_mjr_context'
If you comment out the line calling image.render, you will see that html.render works, so I am not sure why one works and the other does not.
I am using brax 0.10.0 and mujoco 3.1.2.
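image.render creates a mujoco.Renderer, which needs a valid OpenGL context inside the Python process (that is what the traceback complains about), whereas html.render only produces an HTML page that the browser draws, so it needs no OpenGL in Python at all. You need to install the dependencies for rendering. See https://pytorch.org/rl/reference/generated/knowledge_base/MUJOCO_INSTALLATION.html ('Prerequisite for rendering').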
Thank you! I installed the dependencies with sudo apt-get install libglfw3 libgl1-mesa-glx libosmesa6 and conda install -c conda-forge glew (because apt-get could not find the package libglew2.0).
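On a machine without a display, installing the libraries may not be enough by itself: MuJoCo's Python bindings choose their OpenGL backend from the MUJOCO_GL environment variable (glfw, egl or osmesa), and they read it when mujoco is first imported, which happens as soon as brax is imported. A minimal sketch, assuming EGL (GPU driver) or OSMesa (CPU software rendering) is installed; configure_gl_backend is a hypothetical helper, not part of brax or mujoco:

```python
import os

# Backend names accepted by MuJoCo's Python bindings via MUJOCO_GL.
SUPPORTED_GL_BACKENDS = ("glfw", "egl", "osmesa")


def configure_gl_backend(backend: str = "egl") -> str:
    """Hypothetical helper: pick MuJoCo's OpenGL backend for this process.

    Must run before `mujoco` (or any `brax` module, which imports it) is
    imported. A MUJOCO_GL value already exported in the shell is kept.
    Returns the backend that will actually be used.
    """
    if backend not in SUPPORTED_GL_BACKENDS:
        raise ValueError(f"unsupported MUJOCO_GL backend: {backend!r}")
    return os.environ.setdefault("MUJOCO_GL", backend)


# Usage (assumption: at the very top of the script, before other imports):
#   configure_gl_backend("osmesa")   # or "egl" on a GPU machine
#   from brax import envs
#   from brax.io import image
```

EGL is usually the faster choice when a GPU driver provides it; OSMesa works without a GPU but renders in software.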
I am no longer getting that error, but I am getting this one:
    step_image = image.render(env.sys, states)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/brax/io/image.py", line 66, in render
    frames[0].save(f, format=fmt)
    ^^^^^^^^^^^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'save'. Did you mean: 'ravel'?
My numpy version is 1.24.3.
I also encountered the same save error (using brax 0.10.0 and mujoco 3.1.2).
Looking at brax/io/image.py in the previous version of brax (0.9.4), it seems that each element of frames needs to be wrapped in Image.fromarray(....) before save is called. For example:
from PIL import Image
....
Image.fromarray(frames[0])
This looks like a bug to me.
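Until a fixed release is available, one workaround along these lines is to call image.render_array yourself and do the PIL wrapping in user code. A minimal sketch, assuming Pillow and NumPy are installed; frames_to_image_bytes is a hypothetical helper that mirrors the Image.fromarray wrapping described above, and the 50 ms frame duration is an arbitrary choice:

```python
import io
from typing import Sequence

import numpy as np
from PIL import Image


def frames_to_image_bytes(
    frames: Sequence[np.ndarray], fmt: str = "gif", duration_ms: int = 50
) -> bytes:
    """Hypothetical helper: encode HxWx3 uint8 frames as PNG/GIF bytes."""
    if len(frames) == 0:
        raise ValueError("need at least one frame")
    # PIL images have .save(); raw numpy arrays do not, hence the wrapping.
    images = [Image.fromarray(np.asarray(f, dtype=np.uint8)) for f in frames]
    buf = io.BytesIO()
    if len(images) == 1:
        images[0].save(buf, format=fmt)
    else:
        # Multi-frame output: append the remaining frames as an animation.
        images[0].save(
            buf,
            format=fmt,
            save_all=True,
            append_images=images[1:],
            duration=duration_ms,
            loop=0,
        )
    return buf.getvalue()


# Usage with brax (assumes a working OpenGL backend; `states` is the list of
# pipeline states collected in the script above):
#   from brax.io import image
#   frames = image.render_array(env.sys, states)
#   with open("traj.gif", "wb") as f:
#       f.write(frames_to_image_bytes(frames, fmt="gif"))
```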
Thanks for the bug report and for finding the issue; we'll push out a fix.
This should be fixed in 0825bcb74b53e36a62d50405a85e540fe2c25a95; please let us know if that doesn't work! Closing for now.
There's another issue: the renderings are not in color. Here are some example GIFs and the code used to produce them:
import jax
import jax.numpy as jnp
from jax import random
from brax import envs
from brax.io import image
from IPython.display import Image

env_name = "reacher"
backend = "generalized"
episode_length = 5

env = envs.get_environment(backend=backend, env_name=env_name)
# jit once, outside the loop, so the compiled functions are reused
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(random.PRNGKey(0))
cum_reward = 0
states = []
rollout = []
for step in range(1, episode_length + 1):
    rollout.append(state.pipeline_state)  # state before this step
    action = jnp.zeros(env.action_size)
    state = jit_step(state, action)
    states.append(state.pipeline_state)
    cum_reward += state.reward
    if state.done:
        break

# IPython's Image wraps the GIF bytes so it can also be displayed in a notebook
gif = Image(image.render(env.sys, rollout, fmt='gif'))
with open('reacher.gif', 'wb') as f:
    f.write(gif.data)
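The thread does not identify why the GIFs lose their color. One way to narrow it down is to check whether the raw frames returned by image.render_array are already gray (all three channels identical) or only lose color when encoded as a GIF. A minimal diagnostic sketch, assuming NumPy; is_grayscale is a hypothetical helper:

```python
import numpy as np


def is_grayscale(frame: np.ndarray) -> bool:
    """Hypothetical helper: True if an HxWx3 RGB frame has identical channels."""
    frame = np.asarray(frame)
    if frame.ndim != 3 or frame.shape[-1] < 3:
        # Single-channel (or unexpected) layout: no color information.
        return True
    r, g, b = frame[..., 0], frame[..., 1], frame[..., 2]
    return bool(np.array_equal(r, g) and np.array_equal(g, b))


# Usage (assumes the reacher script above has run, so `env` and `rollout` exist):
#   frames = image.render_array(env.sys, rollout)
#   print([is_grayscale(f) for f in frames])
# If every frame prints False but the GIF still looks gray, the color is
# being lost during GIF encoding rather than during rendering.
```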