
Rendering an image with brax.io.image

Open eleninisioti opened this issue 1 year ago • 4 comments

I have seen this functionality mentioned in other issues, but I cannot find code that works. Here is my code that attempts to get an image of the environment at every step:

import jax.numpy as jnp
from brax import envs
from brax.io import html
from brax.io import image
from jax import random
import jax



if __name__ == "__main__":

    env_name = "ant"
    backend = "generalized"
    episode_length = 5

    env = envs.get_environment(backend=backend, env_name=env_name)

    state = jax.jit(env.reset)(random.PRNGKey(0))

    cum_reward = 0
    states = []
    for step in range(1, episode_length + 1):

        action = jnp.array([0]*env.action_size)
        state = jax.jit(env.step)(state, action)
        states.append(state.pipeline_state)
        cum_reward += state.reward

        # This line raises the error. Comment it out to get the HTML video.
        # image.render expects a list of pipeline states, not the env State.
        step_image = image.render(env.sys, [state.pipeline_state])

        if state.done:
            break

    render = html.render(env.sys, states)

    with open("traj.html", "w") as f:
        f.write(render)

    print("Episode ended. Total reward: " + str(cum_reward))

This throws the following error:

  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/brax/io/image.py", line 34, in render_array
    renderer = mujoco.Renderer(sys.mj_model, height=height, width=width)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 83, in __init__
    self._mjr_context = _render.MjrContext(
                        ^^^^^^^^^^^^^^^^^^^
mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
Exception ignored in: <function Renderer.__del__ at 0x7fa3777ec860>
Traceback (most recent call last):
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 330, in __del__
    self.close()
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/mujoco/renderer.py", line 318, in close
    if self._mjr_context:
       ^^^^^^^^^^^^^^^^^
AttributeError: 'Renderer' object has no attribute '_mjr_context'

If you comment out the line calling image.render, you will see that html.render works, so I am not sure why one works and the other does not.

I am using brax 0.10.0 and mujoco 3.1.2.

eleninisioti avatar Feb 22 '24 20:02 eleninisioti

You need to install dependencies for rendering. See https://pytorch.org/rl/reference/generated/knowledge_base/MUJOCO_INSTALLATION.html ('Prerequisite for rendering').

jamesheald avatar Feb 24 '24 16:02 jamesheald
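
On a headless machine, installing the libraries may not be enough on its own: MuJoCo's Python bindings choose their OpenGL backend from the MUJOCO_GL environment variable (glfw, egl, or osmesa), and it has to be set before mujoco is imported. A minimal sketch, assuming the matching system libraries are already installed; the helper name select_mujoco_gl is hypothetical and not part of brax or mujoco:

```python
import os

# Backend names read from the MUJOCO_GL environment variable.
_KNOWN_BACKENDS = ("glfw", "egl", "osmesa")


def select_mujoco_gl(backend: str = "egl") -> str:
    """Set MUJOCO_GL unless it is already set (hypothetical helper).

    Returns the value that is in effect afterwards.
    """
    if backend not in _KNOWN_BACKENDS:
        raise ValueError(f"unknown backend {backend!r}, expected one of {_KNOWN_BACKENDS}")
    # setdefault keeps a value the user exported in the shell.
    return os.environ.setdefault("MUJOCO_GL", backend)


# This must run before `import mujoco` or `from brax.io import image`,
# because the backend is picked when mujoco is first imported.
chosen = select_mujoco_gl("osmesa")
print("MUJOCO_GL =", chosen)
```

Which backend works depends on the machine: osmesa is software rendering and needs no GPU, while egl usually requires a GPU driver with EGL support.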

Thank you! I installed the dependencies with sudo apt-get install libglfw3 libgl1-mesa-glx libosmesa6 and conda install -c conda-forge glew (because apt-get could not find the package libglew2.0).

I no longer get that error, but now I get this one:

    step_image = image.render(env.sys, states)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eleni/anaconda3/envs/brax_env/lib/python3.11/site-packages/brax/io/image.py", line 66, in render
    frames[0].save(f, format=fmt)
    ^^^^^^^^^^^^^^
    AttributeError: 'numpy.ndarray' object has no attribute 'save'. Did you mean: 'ravel'?

My numpy version is 1.24.3

eleninisioti avatar Feb 24 '24 20:02 eleninisioti

I also encountered the same save error (using brax 0.10.0 and mujoco 3.1.2).

Judging from image.py in brax.io in the previous version of brax (0.9.4), each element of frames needs to be wrapped in Image.fromarray(...) before save is called. For example:

from PIL import Image

....

Image.fromarray(frames[0])

This looks like a bug to me.

jamesheald avatar Feb 24 '24 22:02 jamesheald
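
The workaround described above can be sketched as follows. Synthetic frames stand in for the arrays returned by the renderer, so the snippet runs without brax or an OpenGL context. The helper name frames_to_gif and the frame duration are assumptions for illustration, not part of brax:

```python
import io

import numpy as np
from PIL import Image


def frames_to_gif(frames, duration_ms=50):
    """Encode a list of HxWx3 uint8 arrays as an animated GIF (hypothetical helper)."""
    if not frames:
        raise ValueError("need at least one frame")
    # numpy arrays have no .save method, so wrap each one in a PIL image first.
    images = [Image.fromarray(np.asarray(f, dtype=np.uint8)) for f in frames]
    buf = io.BytesIO()
    images[0].save(
        buf,
        format="GIF",
        append_images=images[1:],
        save_all=True,
        duration=duration_ms,
        loop=0,
    )
    return buf.getvalue()


# Five flat frames of increasing brightness, standing in for rendered frames.
frames = [np.full((24, 32, 3), 40 * i, dtype=np.uint8) for i in range(5)]
gif_bytes = frames_to_gif(frames)

with open("synthetic.gif", "wb") as f:
    f.write(gif_bytes)
```

With a working renderer, the same helper could be fed the frames that render_array produces instead of the synthetic ones.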

Thanks for the bug report and for finding the issue. We'll push out a fix.

btaba avatar Feb 27 '24 02:02 btaba

This should be fixed in 0825bcb74b53e36a62d50405a85e540fe2c25a95. Please let us know if that doesn't work! Closing for now.

btaba avatar Feb 28 '24 00:02 btaba

There's another issue. The renderings are not in color. Here are some example gifs and the code used to produce them.

[GIFs: reacher, ant]

import jax.numpy as jnp
from brax import envs
from brax.io import image
from jax import random
import jax
from IPython.display import Image

env_name = "reacher"
backend = "generalized"
episode_length = 5

env = envs.get_environment(backend=backend, env_name=env_name)

state = jax.jit(env.reset)(random.PRNGKey(0))

cum_reward = 0
states = []
rollout = []
for step in range(1, episode_length + 1):

    rollout.append(state.pipeline_state)

    action = jnp.array([0]*env.action_size)
    state = jax.jit(env.step)(state, action)
    states.append(state.pipeline_state)
    cum_reward += state.reward

    if state.done:
        break

gif = Image(image.render(env.sys, rollout, fmt='gif'))
# Use a context manager so the file is flushed and closed after writing.
with open('reacher.gif', 'wb') as f:
    f.write(gif.data)

jamesheald avatar Feb 28 '24 13:02 jamesheald
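
Whether a rendering has really lost its color can be checked on the raw frames: an RGB frame is effectively grayscale when its three channels are identical. A minimal sketch with synthetic frames, assuming HxWx3 arrays like those Image.fromarray accepts; is_grayscale is a hypothetical helper, not a brax function, and it only diagnoses the symptom rather than its cause:

```python
import numpy as np


def is_grayscale(frame) -> bool:
    """Return True if all color channels of the frame are identical (hypothetical helper)."""
    frame = np.asarray(frame)
    if frame.ndim == 2:
        # A single-channel image carries no color information.
        return True
    if frame.ndim != 3 or frame.shape[-1] < 3:
        raise ValueError(f"expected an HxW or HxWx3 array, got shape {frame.shape}")
    r, g, b = frame[..., 0], frame[..., 1], frame[..., 2]
    return bool(np.array_equal(r, g) and np.array_equal(g, b))


# A gray ramp copied into all three channels, and a copy with one red pixel.
gray = np.repeat(np.arange(12, dtype=np.uint8).reshape(3, 4, 1), 3, axis=-1)
color = gray.copy()
color[0, 0] = (255, 0, 0)

print(is_grayscale(gray), is_grayscale(color))
```

Running the check on every frame of a rollout would show whether color is missing from the rendered arrays themselves or is lost later, for example when the GIF is encoded.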