nuplan-devkit
Problem running run_simulation.py with ml_planner
I am trying to run a simulation using raster_model as the MLPlanner, but I'm having trouble loading the model: it looks like some parameters are missing when the LightningModuleWrapper class is instantiated. I use the following command to run the simulation:
python run_simulation.py +simulation=open_loop_boxes model=raster_model planner=ml_planner planner.ml_planner.model_config=\${model} scenario_builder=nuplan_mini scenario_filter=all_scenarios scenario_filter.scenario_types="[near_multiple_vehicles, on_pickup_dropoff, starting_unprotected_cross_turn, high_magnitude_jerk]" scenario_filter.num_scenarios_per_type=10
The error message indicates that three parameters are missing.
2024-08-05 10:19:43,437 INFO {/home/fengqi/nuplan-devkit/nuplan/planning/script/builders/model_builder.py:18} Building TorchModuleWrapper...
2024-08-05 10:19:43,796 INFO {/home/fengqi/anaconda3/envs/nuplan/lib/python3.9/site-packages/timm/models/helpers.py:244} Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth)
2024-08-05 10:19:43,831 INFO {/home/fengqi/anaconda3/envs/nuplan/lib/python3.9/site-packages/timm/models/helpers.py:269} Converted input conv conv1 pretrained weights from 3 to 4 channel(s)
2024-08-05 10:19:43,846 INFO {/home/fengqi/nuplan-devkit/nuplan/planning/script/builders/model_builder.py:21} Building TorchModuleWrapper...DONE!
Error executing job with overrides: ['+simulation=open_loop_boxes', 'planner=ml_planner', 'scenario_builder=nuplan_mini', 'scenario_filter=all_scenarios', 'scenario_filter.scenario_types=[near_multiple_vehicles, on_pickup_dropoff, starting_unprotected_cross_turn, high_magnitude_jerk]', 'scenario_filter.num_scenarios_per_type=10']
Traceback (most recent call last):
  File "/home/fengqi/nuplan-devkit/nuplan/planning/script/run_simulation.py", line 110, in main
    run_simulation(cfg=cfg)
  File "/home/fengqi/nuplan-devkit/nuplan/planning/script/run_simulation.py", line 66, in run_simulation
    runners = build_simulations(
  File "/home/fengqi/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py", line 90, in build_simulations
    planners = build_planners(cfg.planner, scenario)
  File "/home/fengqi/nuplan-devkit/nuplan/planning/script/builders/planner_builder.py", line 58, in build_planners
    return [_build_planner(planner, scenario) for planner in planner_cfg.values()]
  File "/home/fengqi/nuplan-devkit/nuplan/planning/script/builders/planner_builder.py", line 58, in <listcomp>
    return [_build_planner(planner, scenario) for planner in planner_cfg.values()]
  File "/home/fengqi/nuplan-devkit/nuplan/planning/script/builders/planner_builder.py", line 26, in _build_planner
    model = LightningModuleWrapper.load_from_checkpoint(
  File "/home/fengqi/anaconda3/envs/nuplan/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 157, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/fengqi/anaconda3/envs/nuplan/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 199, in _load_model_state
    model = cls(**_cls_kwargs)
TypeError: __init__() missing 3 required positional arguments: 'objectives', 'metrics', and 'batch_size'
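The last two frames show Lightning's _load_model_state ending in cls(**_cls_kwargs). As a rough illustration of how that call can fail, here is a simplified pure-Python sketch, not PyTorch Lightning's actual code. It assumes the constructor kwargs are built from the hyperparameters saved in the checkpoint under "hyper_parameters", with explicit keyword arguments such as model=... taking precedence. FakeWrapper and simplified_load are hypothetical names, and the real implementation does additional work (such as restoring the state dict) that is omitted here.

```python
from typing import Any, Dict, List, Optional


class FakeWrapper:
    """Hypothetical stand-in with the same required arguments as LightningModuleWrapper."""

    def __init__(
        self,
        model: Any,
        objectives: List[Any],
        metrics: List[Any],
        batch_size: int,
        optimizer: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.model = model
        self.objectives = objectives
        self.metrics = metrics
        self.batch_size = batch_size
        self.optimizer = optimizer


def simplified_load(cls: type, checkpoint: Dict[str, Any], **kwargs: Any) -> Any:
    """Simplified sketch: saved hyperparameters, updated with explicit kwargs, are passed to __init__."""
    cls_kwargs: Dict[str, Any] = dict(checkpoint.get("hyper_parameters", {}))
    cls_kwargs.update(kwargs)  # explicit kwargs such as model=... take precedence
    return cls(**cls_kwargs)


# A checkpoint dict with no saved hyperparameters: only model= is supplied, so __init__ fails.
empty_ckpt: Dict[str, Any] = {"state_dict": {}}
try:
    simplified_load(FakeWrapper, empty_ckpt, model="torch_module_wrapper")
except TypeError as err:
    print(err)

# A checkpoint dict that recorded the required arguments can be rebuilt with only model=... passed.
full_ckpt: Dict[str, Any] = {
    "state_dict": {},
    "hyper_parameters": {"objectives": [], "metrics": [], "batch_size": 8},
}
wrapper = simplified_load(FakeWrapper, full_ckpt, model="torch_module_wrapper")
print(wrapper.batch_size)
```

In this simplified model, the only values reaching __init__ are those in the checkpoint plus those passed explicitly. A call that passes only model= therefore succeeds only if the checkpoint recorded the other required arguments.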
Below is the code where the error is raised (nuplan/planning/script/builders/planner_builder.py, line 26):
model = LightningModuleWrapper.load_from_checkpoint(
    planner_cfg.checkpoint_path, model=torch_module_wrapper
).model
Here is the __init__ signature of LightningModuleWrapper:
class LightningModuleWrapper(pl.LightningModule):
    """
    Lightning module that wraps the training/validation/testing procedure and handles the objective/metric computation.
    """

    def __init__(
        self,
        model: TorchModuleWrapper,
        objectives: List[AbstractObjective],
        metrics: List[AbstractTrainingMetric],
        batch_size: int,
        optimizer: Optional[DictConfig] = None,
        lr_scheduler: Optional[DictConfig] = None,
        warm_up_lr_scheduler: Optional[DictConfig] = None,
        objective_aggregate_mode: str = 'mean',
    ) -> None:
        ...  # body omitted
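One way to narrow this down is to compare this signature with what the checkpoint actually recorded. The helper below is a hypothetical sketch (missing_init_args is not part of nuplan-devkit or Lightning). It lists the required __init__ parameters that are neither stored under the checkpoint's "hyper_parameters" key nor passed explicitly, and it assumes the checkpoint has already been loaded into a plain dict. ExampleWrapper is a hypothetical class that only mirrors the signature above.

```python
import inspect
from typing import Any, Dict, Iterable, List


def missing_init_args(
    cls: type, checkpoint: Dict[str, Any], passed: Iterable[str] = ()
) -> List[str]:
    """Return required __init__ parameters not covered by saved hyperparameters or explicit kwargs."""
    provided = set(checkpoint.get("hyper_parameters", {})) | set(passed)
    missing = []
    for name, param in inspect.signature(cls.__init__).parameters.items():
        # Skip self and catch-all *args / **kwargs parameters.
        if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        # A parameter is required only if it has no default value.
        if param.default is param.empty and name not in provided:
            missing.append(name)
    return missing


class ExampleWrapper:
    """Hypothetical class mirroring the LightningModuleWrapper signature above."""

    def __init__(
        self,
        model,
        objectives,
        metrics,
        batch_size,
        optimizer=None,
        lr_scheduler=None,
        warm_up_lr_scheduler=None,
        objective_aggregate_mode="mean",
    ):
        pass


# A checkpoint dict with no saved hyperparameters; only model= is passed explicitly.
ckpt = {"state_dict": {}}
print(missing_init_args(ExampleWrapper, ckpt, passed=["model"]))
# ['objectives', 'metrics', 'batch_size']
```

With the real classes, one would pass LightningModuleWrapper and the dict returned by torch.load(planner_cfg.checkpoint_path, map_location="cpu"). A non-empty result would be consistent with a checkpoint that did not save these hyperparameters, although this check alone does not confirm the cause.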
Can you tell me how to solve this problem? Is there an error in the command I am using?