
A reference cycle is detected related to the ThunderModule

Open kiya00 opened this issue 1 year ago • 5 comments

Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.

🐛 Bug

To Reproduce

import torch
import thunder
import weakref
import gc

mod = torch.nn.ReLU()
ref = weakref.ref(mod, lambda _: print("mod deleted!"))
opt_mod = thunder.jit(mod)
# opt_mod = torch.compile(mod)
ref_opt_mod = weakref.ref(opt_mod, lambda _: print("opt_mod deleted!"))
x = torch.randn(10, 10)
refx = weakref.ref(x, lambda _: print("x deleted!"))
opt_mod(x)
del x
del mod
del opt_mod
# gc.collect()
print("done!")  # done!

With the torch.compile() line swapped in (the commented-out line), it outputs:

x deleted!
opt_mod deleted!
mod deleted!
done!

With thunder, it outputs:

x deleted!
done!
mod deleted!
opt_mod deleted!
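
With thunder, mod and opt_mod are only finalized after "done!" prints, i.e. at interpreter shutdown rather than at the del statements. That is the typical signature of objects kept alive by a reference cycle, which plain refcounting cannot break. A quick check, appended to the repro above (gc is already imported there): if the weakrefs die after a manual gc.collect(), the objects were kept alive only by a collectable cycle.

# Appended to the repro above: force a cyclic garbage collection and
# check whether the weakrefs are now dead. If they are, the objects
# were kept alive only by a reference cycle, not by a live reference.
gc.collect()
print(ref() is None, ref_opt_mod() is None)  # expected: True True if the cycle is collectable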

@kshitij12345 detected that there's a reference cycle:

import torch
import thunder
import weakref
import gc

mod = torch.nn.ReLU()
ref = weakref.ref(mod, lambda _: print("mod deleted!"))
opt_mod = thunder.jit(mod)
# opt_mod = torch.compile(mod)
ref_opt_mod = weakref.ref(opt_mod, lambda _: print("opt_mod deleted!"))
x = torch.randn(10, 10)
refx = weakref.ref(x, lambda _: print("x deleted!"))
opt_mod(x)
del x
del mod
del opt_mod
# gc.collect()
print("done!")  # done!

if ref_opt_mod() is not None:
    import refcycle
    graph = refcycle.snapshot()

    try:
        cycle = graph.shortest_cycle(ref_opt_mod())
        print("CYCLE FOUND FROM MOD")
        # Save the latest cycle.
        cycle.export_json("cycle.json")
        cycle.export_image("cycle.png")
    except ValueError:
        print("NO CYCLE FROM MOD")
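
(refcycle here is the third-party refcycle package from PyPI: refcycle.snapshot() captures the live object graph, and shortest_cycle() raises ValueError when the given object is not part of a cycle, which is why the snippet catches it.)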

[image: reference cycle graph exported by refcycle (cycle.png)]

cc @apaz-cli

kiya00 avatar Aug 23 '24 12:08 kiya00

Second ref cycle (one of the objects here holds onto the user module) - [image: reference cycle graph]

Repro

def foo():
    import torch
    import thunder
    import weakref
    import gc

    mod = torch.nn.ReLU()
    ref_mod = weakref.ref(mod, lambda _: print("mod deleted!"))
    opt_mod = thunder.jit(mod)
    ref_opt_mod = weakref.ref(opt_mod, lambda _: print("opt_mod deleted!"))
    x = torch.randn(10, 10)
    refx = weakref.ref(x, lambda _: print("x deleted!"))
    opt_mod(x)
    del x
    del mod
    del opt_mod
    # gc.collect()
    print("done!")  # done!

    if ref_mod() is not None:
        import refcycle
        graph = refcycle.snapshot()

        cycle = None
        try:
            cycle = graph.shortest_cycle(ref_mod())
            print("CYCLE FOUND FROM MOD")
        except ValueError:
            print("NO CYCLE FROM MOD")

        # More cycles are found here
        for anc in graph.ancestors(ref_mod()):
            try:
                cycle = graph.shortest_cycle(anc)
                print("CYCLE FOUND FROM ANCESTOR")
                print(anc)

                # Check the cycle from above
                # print(anc["prologue"].__wrapped__.__wrapped__.__wrapped__.__globals__["prologue"] is anc["prologue"])  # True
                # print(anc["prologue"].__wrapped__.__wrapped__.__wrapped__.__globals__["__function_obj"])
                break
            except ValueError:
                pass

        # for obj in cycle:
        #     print(obj)

        # Save the latest cycle (if any was found).
        if cycle is not None:
            cycle.export_json("cycle.json")
            cycle.export_image("cycle.png")

foo()
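
The commented-out checks in the repro point at the concrete shape of one of these cycles: the compiled prologue function is reachable from its own __globals__ dict (function -> __globals__ -> function). A minimal standalone sketch of that pattern, mirroring the commented checks above rather than Thunder's actual code:

# Minimal sketch of the function -> __globals__ -> function pattern
# suggested by the commented checks above (illustrative, not Thunder's code).
namespace = {}
exec("def prologue():\n    return prologue", namespace)
fn = namespace["prologue"]
# fn.__globals__ is the namespace dict, and that dict holds fn itself,
# so the function participates in a reference cycle with its own globals.
assert fn.__globals__["prologue"] is fn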

kshitij12345 avatar Aug 23 '24 13:08 kshitij12345

So regarding the priority, as discussed in Slack: from what I can see, this cycle keeps modules that go out of scope from being collected. Not nice, but for the most part I don't think we will be compiling short-lived modules, so it might not be a game-breaker right now.

I looked at this a bit. In general:

  • I tried sprinkling a few del fn_.__wrapped__ calls, but apparently did not find all the references of interest: the ref cycle changed but did not go away. I have not been able to track down the remaining ones.
  • In general, I do not think we can keep the ThunderModule or compiled functions from pointing back to the original Module, but we might make any reverse reference a weakref (we already try this in compile_data._thunder_module_map).
  • The ThunderModule / compiled function will need to hold on to traces / functions built from traces.
  • So likely, to avoid cycles, we should make sure traces etc. do not point back to the ThunderModule / compiled module except through weakrefs (see the sketch below).
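
A minimal sketch of the weakref idea (TraceHolder is a hypothetical stand-in, not one of Thunder's actual classes):

import weakref

class TraceHolder:
    """Hypothetical stand-in for anything (trace, prologue, ...) that
    needs a back-reference to the module it was compiled from."""

    def __init__(self, thunder_module):
        # Keep only a weak back-reference so this object does not keep
        # the ThunderModule alive and close a reference cycle.
        self._module_ref = weakref.ref(thunder_module)

    @property
    def module(self):
        mod = self._module_ref()
        if mod is None:
            raise RuntimeError("the ThunderModule has been garbage collected")
        return mod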

WDYT?

t-vi avatar Sep 07 '24 16:09 t-vi

Not nice, but for the most part, I don't think we will be compiling short-lived modules, so it might not be a game-breaker right now.

It's a game-breaker because it blocks using a Thunder-optimized dropout layer inside a larger module, as in

self.dropout = thunder.jit(nn.Dropout(p=0.5))
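
For context, a sketch of the blocked pattern (Block and the layer sizes are made up for illustration):

import torch.nn as nn
import thunder

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        # jit a single submodule; with the ref cycle, the Dropout stays
        # alive even after the surrounding Block has been dropped
        self.dropout = thunder.jit(nn.Dropout(p=0.5))

    def forward(self, x):
        return self.dropout(self.linear(x))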

IvanYashchuk avatar Sep 26 '24 16:09 IvanYashchuk

It's a game-breaker because it blocks using a Thunder-optimized dropout layer inside a larger module

I would like to understand this more. Is it a game-breaker because you disagree that it is less relevant for long-lived modules, or because you expect the modules to be short-lived?

t-vi avatar Sep 26 '24 17:09 t-vi

I'm sorry I confused this issue with https://github.com/Lightning-AI/lightning-thunder/issues/1074. I don't have an important use case for fixing this bug.

IvanYashchuk avatar Sep 27 '24 08:09 IvanYashchuk