pytorch_tabular
Fix devices_list type validation error in TrainerConfig
Problem
When specifying devices_list in TrainerConfig, users encountered an OmegaConf validation error:
```python
from omegaconf import OmegaConf
from pytorch_tabular.config import TrainerConfig

# Building the dataclass succeeds; OmegaConf.structured() raises ValidationError
trainer_config = TrainerConfig(devices_list=[1, 2, 3, 4])
config = OmegaConf.structured(trainer_config)
```
Error:
```
omegaconf.errors.ValidationError: Value '[1, 2, 3, 4]' of type 'list' could not be converted to Integer
    full_key: devices
    object_type=TrainerConfig
```
Root Cause
The devices field in TrainerConfig was typed as Optional[int], but when devices_list is provided, the __post_init__ method assigns that list to devices. A plain dataclass does not check types, so constructing TrainerConfig succeeds. The error appears only when OmegaConf.structured() wraps the config, because OmegaConf enforces the Optional[int] annotation and rejects the list.
Solution
Changed the type annotation of the devices field from Optional[int] to Any. This approach:
- Allows both integer and list values, matching PyTorch Lightning's Trainer API, which accepts both
- Works correctly with OmegaConf (note: Union[int, List[int]] is not supported by OmegaConf for container types)
- Maintains full backward compatibility with existing code
Changes
src/pytorch_tabular/config/config.py:
- Changed the devices field type from Optional[int] to Any
- Updated the docstring to indicate that devices accepts Union[int, List[int]]
tests/test_config.py (new file):
- Added comprehensive test coverage for devices/devices_list functionality
- Tests cover the documented use case (devices_list=[0, 1])
- Tests verify backward compatibility with integer devices values
- Tests validate config merging scenarios as used in TabularModel
Testing
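All scenarios now work correctly when wrapped with OmegaConf.structured():
```python
from omegaconf import OmegaConf
from pytorch_tabular.config import TrainerConfig

# Multiple GPUs (from the issue)
OmegaConf.structured(TrainerConfig(devices_list=[1, 2, 3, 4]))  # ✅ Works
# Documented example
OmegaConf.structured(TrainerConfig(devices_list=[0, 1]))  # ✅ Works
# Backward compatibility
OmegaConf.structured(TrainerConfig(devices=2))  # ✅ Still works
OmegaConf.structured(TrainerConfig())  # ✅ Default devices=-1 still works
```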
Backward Compatibility
✅ Fully backward compatible - all existing code using devices=<int> continues to work unchanged.
Fixes #issue_number
[!WARNING]
Firewall rules blocked me from connecting to one or more addresses. I tried to connect to the following addresses, but was blocked by firewall rules:
- archive.ics.uci.edu
  - Triggering command: python -m pytest tests/test_config.py -v (dns block)

If you need me to access, download, or install something from one of these locations, you can either:
- Configure Actions setup steps to set up my environment, which run before the firewall is enabled
- Add the appropriate URLs or hosts to the custom allowlist in this repository's Copilot coding agent settings (admins only)
Original prompt
This section details the original issue you should resolve.
Issue title: devices and devices_list type issues
Issue description: https://github.com/manujosephv/pytorch_tabular/blob/023db2776f96a0f2854e837eef62840be1a12a5e/src/pytorch_tabular/config/config.py#L565C9-L566C45
When specifying a devices_list, this line of code causes
```
omegaconf.errors.ValidationError: Value '[1, 2, 3, 4]' of type 'list' could not be converted to Integer
    full_key: devices
    object_type=TrainerConfig
```