RL
How do I pass a NeMo model for training? Or does the framework only support HF?
This is what my NeMo model looks like:
```
/path/to/runs/tts/model
│
├── lightning_logs/
├── model/
├── model--reduced_train_loss=4.0176-epoch=1-consumed_samples=91406336.0-last/
│   ├── context/
│   │   ├── <uuid_1>/
│   │   ├── <uuid_2>/
│   │   ├── <uuid_3>/
│   │   ├── <uuid_4>/
│   │   ├── io.json
│   │   └── model.yaml
│   │
│   └── weights/
│       ├── __0_0.distcp
│       ├── __0_1.distcp
│       ├── __1_0.distcp
│       ├── __2_0.distcp
│       ├── ... (many distcp shard files)
│       ├── __127_1.distcp
│       ├── common.pt
│       └── metadata.json
│
└── wandb/
```
In the policy config, can I pass this somehow?
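Whether a directory is something HF can load comes down to its layout. The following is a minimal heuristic sketch, not a NeMo or HF API: `detect_checkpoint_format` is a hypothetical helper that looks for the usual Hugging Face files (`config.json` plus `*.safetensors` or `pytorch_model.bin` weights) versus the NeMo 2.0 distributed-checkpoint layout shown in the tree (`context/` next to `weights/` holding `*.distcp` shards and `metadata.json`). It only inspects file names and does not validate contents.

```python
from pathlib import Path


def detect_checkpoint_format(ckpt_dir: str) -> str:
    """Heuristically classify a checkpoint directory (sketch, not exhaustive).

    Returns "hf" for a Hugging Face-style layout, "nemo2-distcp" for a
    NeMo 2.0 distributed checkpoint layout, and "unknown" otherwise.
    """
    root = Path(ckpt_dir)

    # HF-style local directory: config.json plus single or sharded weight files.
    hf_weights = (
        any(root.glob("*.safetensors"))
        or (root / "pytorch_model.bin").is_file()
        or (root / "pytorch_model.bin.index.json").is_file()
    )
    if (root / "config.json").is_file() and hf_weights:
        return "hf"

    # NeMo 2.0 layout as in the question: context/ and weights/ with *.distcp shards.
    weights = root / "weights"
    if (
        (root / "context").is_dir()
        and (weights / "metadata.json").is_file()
        and any(weights.glob("*.distcp"))
    ):
        return "nemo2-distcp"

    return "unknown"
```

Pointed at the `model--reduced_train_loss=...-last/` folder from the question, this heuristic would be expected to report `"nemo2-distcp"`, i.e. a layout that HF `from_pretrained` does not load directly.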
Hi @aayush-sarvam, do you mean you want to continue from a checkpoint? In general, we accept the HF model format: if HF is able to load it, the NeMo RL DTensor backend should be able to as well. Second, for the DTensor backend we encourage you to use DTensor policy v2 (enabled by setting this attribute: https://github.com/NVIDIA-NeMo/RL/blob/a0755ebf48eaa479c50fbf44dd865e68f1d2d4f2/examples/configs/grpo_math_1B.yaml#L84 ).
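The two recommendations in the reply (point the policy at an HF-loadable directory, and turn on DTensor policy v2) can be sketched as dotted command-line overrides. This is a minimal sketch under stated assumptions: it assumes the example scripts accept `key=value` overrides, that `policy.model_name` accepts a local HF-format directory, and that the linked attribute is `policy.dtensor_cfg._v2` next to `policy.dtensor_cfg.enabled`. `dtensor_overrides` is a hypothetical helper; check the key names against the linked YAML line before relying on them.

```python
from pathlib import Path


def dtensor_overrides(model_dir: str, use_v2: bool = True) -> list[str]:
    """Return dotted config overrides for a NeMo RL example script (sketch).

    Assumed keys: policy.model_name, policy.dtensor_cfg.enabled and
    policy.dtensor_cfg._v2 (verify against the linked config).
    """
    path = Path(model_dir)
    # HF from_pretrained on a local directory normally needs config.json;
    # without it the directory is not in HF format (e.g. a NeMo 2.0 distcp checkpoint).
    if not (path / "config.json").is_file():
        raise ValueError(f"{path} has no config.json; not an HF-format directory")

    overrides = [
        f"policy.model_name={path}",        # local HF checkpoint instead of a Hub ID
        "policy.dtensor_cfg.enabled=true",  # select the DTensor backend
    ]
    if use_v2:
        overrides.append("policy.dtensor_cfg._v2=true")  # assumed DTensor v2 switch
    return overrides
```

The resulting strings would then be appended to an example launch command, e.g. `uv run python examples/run_grpo_math.py <overrides>`. Note that the `config.json` check is only a necessary condition for HF loading, not proof that the model will load.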