Training seg faults after a few iterations
Hello, I'm trying to train a simple network (a MobileNet classifier), which seems fine, but I'm getting a segfault after a few batches. I'm hoping someone can point out what I'm doing wrong or give some pointers for debugging the segfault, since it just errors out with no useful traceback. Thanks!
MacBook Pro M2 Max 32GB
```python
import itertools
import mlx.core as mx
import numpy as np
import mlx.nn as nn
import mlx.optimizers as optim
from datasets import load_dataset
import cv2
from tqdm import tqdm


class DSConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        # Depthwise step: one single-channel Conv2d per input channel
        self.depth = [
            nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
            for _ in range(in_channels)
        ]
        self.bn1 = nn.LayerNorm(in_channels)
        # Pointwise step: a 1x1 conv that mixes channels
        self.point = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
        self.bn2 = nn.LayerNorm(out_channels)

    def __call__(self, x):
        x = x.split(x.shape[-1], axis=-1)  # Split across channels
        depth = mx.concatenate([l(_x) for l, _x in zip(self.depth, x)], axis=-1)
        point = self.point(nn.relu(self.bn1(depth)))
        return nn.relu(self.bn2(point))


class MobileNet(nn.Module):
    def __init__(self, input_channels, num_classes, slim: bool = False):
        super().__init__()
        self.input_conv = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.LayerNorm(32)
        layers = [
            DSConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            DSConv(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            DSConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            DSConv(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            DSConv(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            DSConv(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
        ]
        if not slim:
            for _ in range(5):
                layers += [
                    DSConv(512, 512, kernel_size=3, stride=1, padding=1),
                ]
        layers += [
            DSConv(512, 1024, kernel_size=3, stride=2, padding=1),
            DSConv(1024, 1024, kernel_size=3, stride=2, padding=4),
        ]
        self.layers = nn.Sequential(*layers)
        self.linear = nn.Linear(1024, num_classes)

    def __call__(self, x):
        x = nn.relu(self.bn1(self.input_conv(x)))
        x = self.layers(x)
        x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)  # global average pool
        x = self.linear(x)
        return x


def grouper(iterable, n, *, incomplete="fill", fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks"
    args = [iter(iterable)] * n
    match incomplete:
        case "fill":
            return itertools.zip_longest(*args, fillvalue=fillvalue)
        case "strict":
            return zip(*args, strict=True)
        case "ignore":
            return zip(*args)
        case _:
            raise ValueError("Expected fill, strict, or ignore")


def collate(rows):
    # resize with openCV
    _images = [np.array(item["image"]) for item in rows]
    _images = [cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) if len(im.shape) == 2 else im for im in _images]
    _images = [cv2.resize(im, (224, 224), interpolation=cv2.INTER_CUBIC) for im in _images]
    images = np.array(_images, dtype=np.float32)
    images /= 255.0
    labels = np.array([item["label"] for item in rows], dtype=np.uint32)
    return mx.array(images), mx.array(labels)  # images: (b, h, w, c), labels: (b,)


def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))


def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)


def main():
    batch_size = 2
    model = MobileNet(input_channels=3, num_classes=1000, slim=True)
    mx.eval(model.parameters())
    loss_helper = nn.value_and_grad(model, loss_fn)
    optimizer = optim.SGD(learning_rate=0.1)
    datasets = load_dataset("imagenet-1k", trust_remote_code=True)
    # incomplete="ignore" drops a short final batch; fillvalue="ignore" would instead
    # pad it with the string "ignore", which collate() cannot index.
    for rows in tqdm(grouper(datasets["train"], batch_size, incomplete="ignore")):
        images, labels = collate(rows)
        _loss, grads = loss_helper(model, images, labels)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)


if __name__ == "__main__":
    main()
```
```
21it [00:11, 1.81it/s]zsh: segmentation fault  python3 mobilenet/main.py
/Users/bento/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
```
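Since the crash gives no Python traceback, one standard-library option is `faulthandler`. Running the script as `python3 -X faulthandler mobilenet/main.py` (or calling `faulthandler.enable()` at the top of it) makes the interpreter print the Python stack of every thread when it receives SIGSEGV, SIGBUS, or a few other fatal signals. It won't show the C++ frame inside MLX, but it does show which Python line was running (forward pass, `optimizer.update`, or `mx.eval`). Below is a minimal self-contained sketch that demonstrates this with a deliberate crash; `run_with_faulthandler` is a hypothetical helper name, not part of any library.

```python
import subprocess
import sys


def run_with_faulthandler(argv):
    """Run Python code/scripts in a child interpreter with faulthandler enabled.

    `-X faulthandler` is equivalent to calling faulthandler.enable() at startup:
    on SIGSEGV, SIGFPE, SIGABRT, SIGBUS or SIGILL the child writes the Python
    traceback of every thread to stderr before dying.
    """
    return subprocess.run(
        [sys.executable, "-X", "faulthandler", *argv],
        capture_output=True,
        text=True,
    )


if __name__ == "__main__":
    # Simulate a native crash: reading from address 0 segfaults inside ctypes.
    result = run_with_faulthandler(["-c", "import ctypes; ctypes.string_at(0)"])
    print("exit code:", result.returncode)  # negative signal number on POSIX
    print(result.stderr)  # "Fatal Python error: Segmentation fault" + traceback
```

For the real script, `run_with_faulthandler(["mobilenet/main.py"])` would do the same thing, although running it directly with `-X faulthandler` is simpler and keeps the tqdm progress bar visible.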
Hmm. One thing I'm wondering is if you can try just looping over the data without using MLX. Just to make sure this is an MLX issue and not something to do with the datasets package you are using.
Also, it's good to monitor your memory while you do so, to see if there is a leak or if you are using way too much (use Activity Monitor or asitop).
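A minimal sketch of that data-only check, which also samples memory as it goes. `loop_without_mlx` and `peak_rss_mb` are hypothetical helper names; the memory number comes from `resource.getrusage` (Unix only) and is the process's *peak* resident set size, so it can only stay flat or grow, and a steady climb across batches suggests a leak. It may not capture every GPU-side allocation, so it complements Activity Monitor or asitop rather than replacing them. Pass a `collate` that stays in NumPy if you want MLX completely out of the loop (the `collate` in the original post returns `mx.array`s).

```python
import resource
import sys
from typing import Callable, Iterable, List


def peak_rss_mb() -> float:
    """Peak resident set size of this process so far, in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux.
    return peak / (1024 * 1024) if sys.platform == "darwin" else peak / 1024


def loop_without_mlx(batches: Iterable, collate: Callable, log_every: int = 100) -> List[float]:
    """Run only the data pipeline (no forward/backward pass) and sample peak memory."""
    readings = []
    for i, rows in enumerate(batches):
        collate(rows)  # decode + resize only; the result is discarded
        if i % log_every == 0:
            readings.append(peak_rss_mb())
            print(f"batch {i}: peak RSS {readings[-1]:.0f} MiB")
    return readings
```

With the pipeline from the post, a call would look like `loop_without_mlx(grouper(datasets["train"], 2, incomplete="ignore"), collate)`.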
When I comment out these lines, I'm able to loop through the entire dataset just fine:
```python
_loss, grads = loss_helper(model, images, labels)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
```
Looking at the memory usage, I suspect it's running out of memory.
It doesn't look to be out of memory. And it definitely shouldn't segfault. Does it segfault reliably for you? How far into the training?
I'm running your script on an M1 Max with 32 GB. So far no segfault 🤷♂️; I'm at iteration 600. Did it segfault before that?
Also what's your OS? What version of MLX are you using? (Commit hash if from source?)
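A small sketch for collecting the details asked for here in one go. `collect_env` is a hypothetical helper; it reads the installed package version through `importlib.metadata`, which reports the version of the installed distribution. For a build from source you would still need the commit hash, e.g. from `git rev-parse HEAD` in the MLX checkout.

```python
import platform
from importlib import metadata


def collect_env(packages=("mlx",)):
    """Gather OS, architecture, Python and package versions for a bug report."""
    mac_release = platform.mac_ver()[0]  # '' when not running on macOS
    info = {
        "os": f"macOS {mac_release}" if mac_release else platform.platform(),
        "machine": platform.machine(),  # 'arm64' on Apple silicon
        "python": platform.python_version(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```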
macOS Sonoma 14.2.1, M2 Max 32 GB, Python 3.11.7
Yeah, I've had it segfault right away before; it's very sporadic. Sometimes it just hangs and I have to kill the process manually.
```
mlx ❯ python3 mobilenet/train.py
9it [00:04, 2.30it/s]zsh: segmentation fault  python3 mobilenet/train.py
/Users/bento/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
~/Repos/mlx-stuff/mlx-playground main* 8s
mlx ❯ python3 mobilenet/train.py
69it [01:18, 2.02it/s]zsh: segmentation fault  python3 mobilenet/train.py
/Users/bento/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
~/Repos/mlx-stuff/mlx-playground main* 1m 22s
mlx ❯ python3 mobilenet/train.py
12it [00:05, 2.29it/s]
```
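For the runs that hang instead of crashing, the standard library's `faulthandler.dump_traceback_later` can act as a watchdog: if a step takes longer than a timeout, it writes the traceback of every thread to stderr from a separate watchdog thread, so it can still report while the main thread is stuck inside a native call. A minimal sketch, where `train_step_with_watchdog` is a hypothetical helper name and the 120 s default is an arbitrary assumption:

```python
import faulthandler


def train_step_with_watchdog(step_fn, timeout_s=120.0):
    """Call step_fn(); if it runs longer than timeout_s, dump all thread tracebacks.

    exit=False keeps the process alive after the dump, so a slow but healthy
    step is reported without being killed.
    """
    faulthandler.dump_traceback_later(timeout_s, exit=False)
    try:
        return step_fn()
    finally:
        # Disarm the timer so a step that finished in time never triggers a dump.
        faulthandler.cancel_dump_traceback_later()
```

In the training loop from the post, you could put the `loss_helper` call, `optimizer.update`, and `mx.eval` into one function and pass it as `step_fn`. Because MLX evaluates lazily, the actual computation (and therefore any hang) happens inside `mx.eval`, so it should be inside the watched call.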
> It doesn't look to be out of memory.
You're right; I thought the little widget on the right was tracking memory.
What about your MLX version (or commit hash if building from source)?
0.0.6
Not sure if it helps, but earlier I saw a bus error instead of a segfault.
I can't reproduce it either. I left it running for about an hour on my M2 Air. My initial thought was that it had to do with the implementation of the separable convolution, which ends up with about 1000 layers and concatenates about 1000 arrays, but that doesn't seem to cause a problem at all.
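To put numbers on that estimate: each `DSConv` creates one single-channel `nn.Conv2d` per input channel and joins their outputs with one `mx.concatenate`. The sketch below counts them for the configuration in the original post. It assumes the final 512→1024 and 1024→1024 blocks are appended whether or not `slim` is set (the indentation was lost in the paste, but `nn.Linear(1024, ...)` only works if they are). `depthwise_stats` and the list names are hypothetical.

```python
# (in_channels, out_channels) of each DSConv in the post's MobileNet
BASE = [(32, 64), (64, 128), (128, 128), (128, 256), (256, 256), (256, 512)]
EXTRA = [(512, 512)] * 5  # only added when slim=False
TAIL = [(512, 1024), (1024, 1024)]


def depthwise_stats(slim: bool):
    """Return (total single-channel convs, largest number of arrays in one concat)."""
    blocks = BASE + ([] if slim else EXTRA) + TAIL
    total_convs = sum(c_in for c_in, _ in blocks)  # one 1-channel Conv2d per input channel
    largest_concat = max(c_in for c_in, _ in blocks)  # arrays joined by one mx.concatenate
    return total_convs, largest_concat


if __name__ == "__main__":
    print(depthwise_stats(slim=True))   # (2400, 1024)
    print(depthwise_stats(slim=False))  # (4960, 1024)
```

So the `slim=True` model in the post runs about 2,400 tiny convolutions per forward pass, and its last block concatenates 1,024 arrays, consistent with the "about 1000" figure.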