[BUG] TFT + categorical features seems not to be compatible with DDP in some situations.
Describe the bug
I've recently been updating package dependencies in my project (Python, PyTorch, Lightning). Without changing anything else in my code or hardware, aside from the Lightning import convention, I now get RuntimeErrors when training a TFT model with DDP.
Each rank immediately returns an error similar to this, but with different shapes:
```
[rank5]: RuntimeError: [5]: params[0] in this process with sizes [53, 15] appears not to match sizes of the same param in process 0.
```
I believe this is because the DDP strategy restarts the script in separate processes, each of which instantiates its own copy of the model on a subset of the data. However, the pytorch-forecasting implementation of TFT encodes categorical features internally. If different subsets of the data contain different categorical values, the embedding shapes won't match across ranks.
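A minimal illustration of the failure mode (hypothetical data and column name): if each rank derives its embedding size from its own shard, the per-rank vocabulary sizes can differ, which produces exactly this kind of parameter-shape mismatch.

```python
import pandas as pd

# Hypothetical full dataset with one categorical column
df = pd.DataFrame({"store_id": ["a", "b", "a", "c", "a", "a", "a", "d"]})

# DDP-style sharding: each rank sees every world_size-th row
world_size = 2
shards = [df.iloc[rank::world_size] for rank in range(world_size)]

# Each rank would size its embedding from its local shard
local_vocab_sizes = [shard["store_id"].nunique() for shard in shards]
print(local_vocab_sizes)  # -> [1, 4]: embedding shapes disagree across ranks
```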
Expected behavior
TFT with categorical variables should support ddp training strategy.
Additional context
I'm training on a single EC2 node with 8 GPUs.
This works but is slow:
```
Trainer(accelerator="gpu", strategy="ddp", devices=1, ...)
```
```
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
```
This doesn't work:
```
Trainer(accelerator="gpu", strategy="ddp", devices=8, ...)
```
```
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 8 processes
----------------------------------------------------------------------------------------------------
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
```
Versions
This doesn't work:
```toml
[tool.poetry.dependencies]
python = "~=3.11.0"
pytorch-forecasting = "~=1.2.0"
pytorch-lightning = "==2.0.0"
torch = [
    { version = "==2.5.1+cu118", source = "pytorch-cuda", markers = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
    { version = "==2.5.1", source = "picnic", markers = "sys_platform == 'darwin'" },
]
```
This works:
```toml
[tool.poetry.dependencies]
python = "~=3.10.0"
pytorch-forecasting = "~=0.10.2"
pytorch-lightning = "~=1.8.0"
torch = [
    { version = "==1.13.1+cu117", source = "pytorch-cuda", markers = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
    { version = "==1.13.1", source = "picnic", markers = "sys_platform == 'darwin'" },
]
```
Manually setting `embedding_sizes` when initialising the model with `.from_dataset` resolved the size mismatch, confirming this as the cause of the bug, e.g.:
```python
# embedding_sizes = {"category_column": (num_categories, embedding_dim)}
embedding_sizes = {
    "store_id": (100, 50),
    "weekday_name": (7, 50),
    "month_name": (12, 50),
}
tft = TemporalFusionTransformer.from_dataset(
    dataset,
    embedding_sizes=embedding_sizes,
    **hyperparameters,
    ...
)
```
However, I think this means the categorical values will be assigned different labels, and therefore different vectors in the embedding space, on each rank; when the processes synchronize weights, the same parameters will refer to different categories, rendering all categorical variables useless. The `embedding_labels` should be pre-computed globally as well.
Hi, I'm facing the same issue when training a TFT model with categorical variables using the DDP strategy. The model works fine with devices=1, but fails with devices=4, throwing similar RuntimeError messages about mismatched parameter sizes across ranks.
Does anyone know how to overcome this?
Thanks in advance!
FYI @fnhirwa, @phoeenniixx, @PranavBhatP - any idea?
@Marcrb2 I check the embedding size on each rank and then broadcast the maximum embedding size required via a CPU process group. This is quite a hacky solution, but it works for my dataset.
An important note about this solution: different subsets of the data could contain different subsets of categorical values. The solution is guaranteed not to throw a tensor dimension error, because it broadcasts the maximum set size. However, I'm not sure how the label mapping behaves on different subsets. I tested this on my data and the label-value mapping always matched between subsets, but with highly skewed data it could differ. I would be interested to hear from anyone more knowledgeable about how the categorical encoders work under the hood whether this can be an issue.
```python
import logging

import pandas as pd
import torch
import torch.distributed as dist
from pytorch_forecasting.utils import get_embedding_size

LOGGER = logging.getLogger(__name__)


def get_embed_sizes_distributed(
    df: pd.DataFrame, columns: list[str], rank: int = 0
) -> dict[str, tuple[int, int]]:
    """Return embedding sizes for categorical variables,
    ensuring consistency across distributed processes.

    Args:
        df: The input DataFrame (subset on each rank)
        columns: List of column names to analyze
        rank: The current process rank

    Returns:
        dict[str, tuple[int, int]]: Embedding sizes synchronized across all processes
    """
    # 1. Each rank counts locally
    local_counts = {}
    for col in columns:
        if col in df.columns:
            # number of unique elements + 1 for NaN
            local_counts[col] = df[col].nunique() + 1
        else:
            local_counts[col] = 0

    # 2. Synchronize counts across processes
    if dist.is_initialized():
        # Use a CPU (gloo) process group; see GlooGroupManager below
        cpu_pg = gloo_manager.get_group()
        # Convert counts to a tensor
        cols = [col for col in columns if col in df.columns]
        counts_tensor = torch.tensor(
            [local_counts[col] for col in cols], dtype=torch.long, device="cpu"
        )
        # Use a max reduction to get the highest count for each column
        dist.all_reduce(counts_tensor, op=dist.ReduceOp.MAX, group=cpu_pg)
        # Update local counts with the global maximums
        for i, col in enumerate(cols):
            local_counts[col] = counts_tensor[i].item()
        LOGGER.info("Rank %s synchronized embedding counts: %s", rank, local_counts)

    # 3. Return embedding sizes with synchronized counts, using
    # pytorch-forecasting's default embedding-size heuristic
    return {
        col: (local_counts.get(col, 1), get_embedding_size(n=local_counts.get(col, 1)))
        for col in columns
        if col in df.columns
    }
```
```python
embedding_sizes = get_embed_sizes_distributed(
    train_data,
    [
        "categorical_column_1",
        "categorical_column_2",
        "categorical_column_3",
        ...
    ],
    rank,
)
tft = TemporalFusionTransformer.from_dataset(
    ts_dataset,
    embedding_sizes=embedding_sizes,
    **hyperparameters,
)
```
@fkiraly Do you know why the behavior of parallel process creation, distribution, and destruction would have changed so much relative to the older pytorch-forecasting, PyTorch, and Lightning versions?
@mkuiack, thanks for sharing the code—much appreciated!
I have a quick question about the following line in the `get_embed_sizes_distributed` function:
```python
cpu_pg = gloo_manager.get_group()
```
I'm trying to understand where gloo_manager is coming from. Is it part of a custom utility, or is it defined elsewhere in the code?
Any pointers on how it should be initialized or imported would be greatly appreciated.
Thanks again!
@Marcrb2
```python
import datetime
import logging

import torch.distributed as dist

LOGGER = logging.getLogger(__name__)


class GlooGroupManager:
    """Manager for the distributed gloo process group."""

    def __init__(self):
        """Initialize a new GlooGroupManager instance with no process group."""
        self._process_group = None

    def initialize(self):
        """Initialize the gloo process group once during startup."""
        if self._process_group is None and dist.is_initialized():
            self._process_group = dist.new_group(
                backend="gloo", timeout=datetime.timedelta(days=10)
            )
        return self._process_group

    def get_group(self):
        """Get the gloo process group, initializing it if necessary."""
        if self._process_group is None:
            LOGGER.info("Initializing global gloo group")
            return self.initialize()
        return self._process_group


gloo_manager = GlooGroupManager()
```
The categorical encoders and embeddings are unfortunately closely coupled with the dataset; see the design discussion about decoupling them in v2: https://github.com/sktime/pytorch-forecasting/issues/1736
@fkiraly Do you know why the old version of pytorch-forecasting worked without issue, and only after upgrading did I start seeing these errors and have to develop the work-around above?