Segmentation Fault When Using PyJulia Inside of PyTorch Custom Autograd Function

Open THargreaves opened this issue 3 years ago • 8 comments

I have a (vector-to-scalar) function and corresponding derivative function written in Julia that I am unable to translate to Python. I would like to use these within PyTorch by defining a custom autograd function. As a simple, reproducible example, let's say the function is sum():

import numpy as np
import torch
from julia import Main

class JuliaSum(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        
        x = input.cpu().detach().numpy()

        return torch.FloatTensor([Main.sum(x)]).to('cuda')
        # return torch.FloatTensor([np.sum(x)]).to('cuda')

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        x = input.cpu().detach().numpy()

        y = torch.FloatTensor(Main.ones(len(x))).to('cuda')
        # y = torch.FloatTensor(np.ones(len(x))).to('cuda')

        return grad_output * y

input = torch.FloatTensor([0.1, 0.2, 0.3]).to('cuda').requires_grad_()

# Works — outputs `tensor([0.6000], device='cuda:0', grad_fn=<JuliaSumBackward>)`
y = JuliaSum.apply(input)
print(y)

# Works — outputs `tensor([1., 1., 1.], device='cuda:0') `
x = input.cpu().detach().numpy().astype(np.float64)
y_test = torch.FloatTensor(Main.ones(len(x))).to('cuda')
print(torch.ones(1).to('cuda') * y_test)

# Doesn't work — segmentation fault
y.backward(torch.ones(1).to('cuda'))
print(input.grad)

Calling the forward method works fine, as does running the code contained in the backward method from the global scope. However, when I call the backward method, I receive:

signal (11): Segmentation fault
in expression starting at none:0
Allocations: 3652709 (Pool: 3650429; Big: 2280); GC: 5
Segmentation fault (core dumped)         

The exact command causing the issue is Main.ones(len(x)). Replacing it with Main.ones(3) still causes a segmentation fault, so the input data does not seem to matter; my guess is that PyJulia is accessing memory that has been deallocated.

Also note that when I replace the two calls to Julia with the corresponding NumPy commands (left commented-out), the backward method works fine. The code also works when all tensors are on the CPU but my application requires GPU-acceleration.

What is causing this segmentation fault, and how can I alter my code to avoid it whilst keeping PyTorch tensors on the GPU?


I've included a Dockerfile that matches my environment to make reproducing this issue as simple as possible. For reference, I am using an RTX 3060.

FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 

ARG PYTHON_VERSION=3.10.1
ARG JULIA_VERSION=1.7.1

ENV container=docker
ENV DEBIAN_FRONTEND=noninteractive
ENV LANG=en_US.utf8
ENV MAKEFLAGS=-j4

RUN mkdir /app
WORKDIR /app

# DEPENDENCIES
#===========================================
RUN apt-get update -y && \
    apt-get install -y gcc make wget libffi-dev \
        build-essential libssl-dev zlib1g-dev \
        libbz2-dev libreadline-dev libsqlite3-dev \
        libncurses5-dev libncursesw5-dev xz-utils \
        git

# INSTALL PYTHON
#===========================================
RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \
    tar -zxf Python-$PYTHON_VERSION.tgz && \
    cd Python-$PYTHON_VERSION && \
    ./configure --with-ensurepip=install --enable-shared && make && make install && \
    ldconfig && \
    ln -sf python3 /usr/local/bin/python
RUN python -m pip install --upgrade pip setuptools wheel && \
    python -m pip install julia numpy torch

# INSTALL JULIA
#====================================
RUN wget https://raw.githubusercontent.com/abelsiqueira/jill/main/jill.sh && \
    bash /app/jill.sh -y -v $JULIA_VERSION && \
    export PYTHON="python" && \
    julia -e 'using Pkg; ENV["PYTHON"] = "/usr/local/bin/python"' && \
    python -c 'import julia; julia.install()'

# CLEAN UP
#===========================================
RUN rm -rf /app/jill.sh \
    /opt/julias/*.tar.gz \
    /app/Python-$PYTHON_VERSION.tgz
RUN apt-get purge -y gcc make wget zlib1g-dev libffi-dev libssl-dev \
        libbz2-dev libreadline-dev \
        libncurses5-dev libncursesw5-dev xz-utils && \
    apt-get autoremove -y

CMD ["/bin/bash"]

THargreaves avatar Dec 14 '22 09:12 THargreaves

  1. What version of pyjulia are you using?
  2. Did this work with another version of pyjulia?
  3. Have you tried to use juliacall instead of pyjulia?

https://pypi.org/project/juliacall/

Why are you using Julia 1.7.1?

mkitti avatar Dec 14 '22 13:12 mkitti

  1. What version of pyjulia are you using?

0.6.0

  2. Did this work with another version of pyjulia?

This is the only version I have tried. Is there a particular version you would suggest?

  3. Have you tried to use juliacall instead of pyjulia?

https://pypi.org/project/juliacall/

I tried after this suggestion and also obtained a segmentation fault.

Why are you using Julia 1.7.1?

My actual function was written for Julia 1.7.1 so for ease of reproducibility I was sticking with that. I have tried the sum example using 1.8.3 and 1.6.7 LTS and the error persists.

THargreaves avatar Dec 15 '22 11:12 THargreaves

Does the issue occur with pyjulia v0.5.7?

My guess is yes, but I just wanted to double check since there were a few major changes with 0.6.0.

mkitti avatar Dec 16 '22 03:12 mkitti

I'm not sure if this is related, but there were also reports of a segmentation fault with PySR: https://github.com/MilesCranmer/PySR/issues/238

That appears to mainly occur on Windows, however, and is difficult to reproduce.

I don't suppose you are able to run a debugger in your environment, are you?

mkitti avatar Dec 16 '22 03:12 mkitti

Is this the same segfault as observed in these?

  • https://github.com/JuliaPy/pyjulia/issues/499
  • https://github.com/JuliaPy/pyjulia/issues/125
  • https://github.com/pytorch/pytorch/issues/78829

MilesCranmer avatar Dec 21 '22 00:12 MilesCranmer

Does the issue occur with pyjulia v0.5.7?

It does.

I don't suppose you are able to run a debugger in your environment, are you?

I can get something rudimentary using pdb. It's a bit hard to interpret given that the segmentation fault doesn't occur in the Python code but I can find the last line of Python called before the segmentation fault.

I think this might be a big hint as to what is going wrong but I'm not sure how to interpret it. Here is the debugging output.

> /usr/local/lib/python3.10/site-packages/julia/libjulia.py(114)__getattr__()
-> return getattr(self.libjulia, name)
(Pdb) s
--Return--
> /usr/local/lib/python3.10/site-packages/julia/libjulia.py(114)__getattr__()-><_FuncPtr obj...x7f8310991540>
-> return getattr(self.libjulia, name)
(Pdb) s

signal (11): Segmentation fault
in expression starting at none:0
Allocations: 6799632 (Pool: 6795414; Big: 4218); GC: 8
Segmentation fault (core dumped)

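The first step into libjulia.py(114) completes and returns a function pointer (the _FuncPtr object above). The segmentation fault occurs on the very next step, which suggests it is triggered when that pointer into libjulia is called rather than in the Python code itself.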
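
A complementary way to see where Python was at the moment of the crash is the standard-library faulthandler module, which dumps the Python stack of every thread when a fatal signal arrives. This is offered as a sketch, not something anyone in this thread ran. The helper name enable_crash_tracebacks is hypothetical. Julia installs its own SIGSEGV handler (the "signal (11)" report above comes from it), so whether faulthandler's output appears may depend on whether it is enabled before or after `from julia import Main`; that interaction is an untested assumption.

```python
import faulthandler
import sys


def enable_crash_tracebacks(stream=None):
    """Ask Python to dump the stack of every thread on a fatal signal.

    `stream` must be a real file object with a file descriptor; it
    defaults to sys.stderr. Returns True once the handler is active.
    """
    target = stream if stream is not None else sys.stderr
    # all_threads=True matters here: the frame of interest may belong to
    # a thread other than the one that started the script.
    faulthandler.enable(file=target, all_threads=True)
    return faulthandler.is_enabled()


# Hypothetical usage near the top of the reproduction script:
#     enable_crash_tracebacks()
#     y.backward(torch.ones(1).to('cuda'))
```

Running the unmodified script with `python -X faulthandler script.py` enables the same handler without editing the code.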

THargreaves avatar Jan 02 '23 14:01 THargreaves

I don't think the linked issues are related as I am running on Linux and have no trouble using PyJulia outside of the backward method.

THargreaves avatar Jan 02 '23 14:01 THargreaves

I am also facing this issue when calling Julia inside a backward() call on the GPU (it only happens when torch uses the GPU). Some things I have noticed:

  • When there are no print or @threads calls, juliacall works most of the time. PyJulia segfaults even on a call to eval("1").
  • Both PyJulia and juliacall have issues in the backward() call, but PyJulia segfaults and juliacall hangs.

Also, juliacall seems to mostly work even on nontrivial calls when they make no use of the Julia runtime. @THargreaves I believe your code could work with juliacall if the vectors were pre-allocated on the Python side.
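
One hypothesis consistent with the GPU-only failures, though not confirmed anywhere in this thread, is that PyTorch runs backward() for CUDA tensors on a worker thread, not on the thread that initialised Julia. A minimal sketch for checking this follows. The helper name describe_calling_thread is hypothetical, and the snippet uses only the standard library so it can be pasted into either reproduction script.

```python
import threading

# Identity of the thread that executed this module. In a script this is the
# main thread, where `from julia import Main` would normally have run too.
_INIT_THREAD_ID = threading.get_ident()


def describe_calling_thread(label):
    """Return a one-line note on whether the caller is on the initial thread."""
    current = threading.current_thread()
    same = threading.get_ident() == _INIT_THREAD_ID
    return f"{label}: thread={current.name!r} same_as_init={same}"


# Hypothetical usage inside the custom autograd function:
#     @staticmethod
#     def backward(ctx, grad_output):
#         print(describe_calling_thread("backward"), flush=True)
#         ...

if __name__ == "__main__":
    print(describe_calling_thread("main"))
    worker = threading.Thread(
        target=lambda: print(describe_calling_thread("worker")),
        name="worker-1",
    )
    worker.start()
    worker.join()
```

If the line printed from backward() reports same_as_init=False on CUDA but True on CPU, that would point at a thread-affinity problem in how the Julia runtime is being called, and away from a problem with the tensors themselves.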

Here's my MWE:

import torch
from julia.api import Julia
jl = Julia(compiled_modules=False)

class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dat):
        return torch.full((1,1), 1., device=dat.device)
    
    @staticmethod
    def backward(ctx, grad_output):
        jl.eval('1') # without this line no deadlock happens
        # jl.seval('println(1)') # in juliacall seval('1') runs without problem, but println hangs the program (no segfault)
        return None

device = "cuda:0" # on cpu works fine
# device = "cpu"

dat = torch.full((5,1), 1.).to(device)
model = torch.nn.Linear(1,1).to(device).train()
output = model(dat)
loss = MyLoss.apply(output)

loss.backward() # segfaults here

I'm using Ubuntu, torch 2.0.0.post200, Julia 1.9.0, and pyjulia 0.6.1.

artemsolod avatar Jun 07 '23 19:06 artemsolod