
Constant modeling for getitem

Open tfogal opened this issue 1 year ago • 10 comments

🚀 Feature

It would be useful to model constants and use that information to optimize away getitem.

Motivation

The ThunderFX path sometimes ends up giving us some rather silly graphs:

  class DynamoModule(torch.nn.Module):
    def forward(self):
        scale_t = torch.tensor([2])
        getitem = scale_t[0]
        return (getitem, scale_t)

Our trace ends up correspondingly silly:

def computation():
  scale_t = torch.tensor([2], device=torch.device("cpu"), dtype=None)  # scale_t: "cpu i64[1]"
    # scale_t = ltorch.tensor([2], device=torch.device("cpu"), dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu i64[1]"
      # scale_t = prims.tensor_from_sequence([2], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu i64[1]"
  getitem = Tensor.__getitem__(scale_t, 0)  # getitem: "cpu i64[]"
    # getitem = ltorch.getitem(scale_t, 0)  # getitem: "cpu i64[]"
      # t7 = prims.slice_prim(scale_t, [0], [1], [1])  # t7: "cpu i64[1]"
      # getitem = prims.squeeze(t7, (0,))  # getitem: "cpu i64[]"
  return (getitem, scale_t)

g28.py - self-contained test case.

Pitch

I'd love to see us produce this trace:

def computation():
  scale_t = torch.tensor([2], device=torch.device("cpu"), dtype=None)  # scale_t: "cpu i64[1]"
    # scale_t = ltorch.tensor([2], device=torch.device("cpu"), dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu i64[1]"
      # scale_t = prims.tensor_from_sequence([2], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu i64[1]"
  return (2, scale_t)

That is, we elide the getitem call and just use 2 directly. Emitting foo = 2 and return (foo, scale_t) would also be reasonable for now.

Alternatives

Another idea is to do constant folding and dead-code elimination (DCE) at the FX graph level and, if the resulting graph is empty, not invoke thunder.jit at all. Thunder would then only be tasked with handling "real" computations.

If we reject this entirely, we could say that the user is responsible for applying torch.compiler.allow_in_graph appropriately, so that we never get isolated graphs like the one above.
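A minimal sketch of the FX-level alternative, assuming a toy graph format rather than the real torch.fx API: nodes are `(name, op, args)` tuples, "tensors" are Python lists, and `fold_and_dce` is a hypothetical helper. It folds nodes whose inputs are all known, removes nodes that no output depends on, and treats an empty result as the signal to skip `thunder.jit`.

```python
# Toy stand-in for an FX graph: a list of (name, op, args) nodes, where each arg
# is either the name of an earlier node (str) or a literal. "Tensors" are Python
# lists here; this is an illustration, not the torch.fx API.
FOLDABLE = {
    "tensor": lambda values: list(values),  # plays the role of torch.tensor
    "getitem": lambda seq, idx: seq[idx],
}


def fold_and_dce(nodes, outputs):
    """Constant-fold, then drop nodes that no output depends on."""
    constants = {}   # node name -> value known at compile time
    remaining = []   # nodes that still have to run at runtime
    for name, op, args in nodes:
        names = [a for a in args if isinstance(a, str)]
        if op in FOLDABLE and all(a in constants for a in names):
            values = [constants[a] if isinstance(a, str) else a for a in args]
            constants[name] = FOLDABLE[op](*values)
        else:
            remaining.append((name, op, args))

    # Dead-code elimination: walk backwards, keeping nodes reachable from outputs.
    live = {o for o in outputs if isinstance(o, str)}
    kept = []
    for name, op, args in reversed(remaining):
        if name in live:
            kept.append((name, op, args))
            live.update(a for a in args if isinstance(a, str))
    kept.reverse()
    return kept, constants


# The DynamoModule graph from the issue: everything folds away.
graph = [("scale_t", "tensor", ([2],)), ("getitem", "getitem", ("scale_t", 0))]
kept, constants = fold_and_dce(graph, outputs=["getitem", "scale_t"])
if not kept:
    # No "real" computation is left, so there is nothing to hand to thunder.jit.
    print("skip thunder.jit; outputs:", constants["getitem"], constants["scale_t"])
```

Graphs that still depend on runtime inputs (a "placeholder" node in this toy) keep those nodes, so thunder.jit would only be invoked for them.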

Additional context

These are happening in a variety of subgraphs from the NeVA case, and I suspect we'll see them often as we look into supporting other networks verbatim.

cc @tfogal

tfogal, Oct 02 '24 20:10

I for one would love to see a constant folding pass.

t-vi, Oct 02 '24 20:10

I have an example pass, written as a transform, for this specific case.

I assume that, down the line, the FusionExecutor (or transform_for_execution) may change the exact computation based on this, so this is a pre_prologue transform, which is called before those execution passes are applied.

NOTE: The generated trace is included (as comments) in the snippet.

import torch
import thunder
from thunder.core.trace import from_trace
from thunder.core import utils
from collections.abc import Sequence
from thunder.core.pytree import tree_flatten
from thunder.core.proxies import variableify, Variable
import numbers

class GetItemFromScalarTensorConstantFold(thunder.Transform):
    def transform_traces_pre_prologue(self, prologue_trc, computation_trc, epilogue_trc, **kwargs):
        # Create a new trace
        const_folded_trace = from_trace(computation_trc)
        const_folded_trace.bound_symbols = computation_trc.bound_symbols

        producers = utils.producers(computation_trc)
        swap_map: dict[Variable, numbers.Number] = {}
        for bsym in const_folded_trace.bound_symbols:
            # If symbol `getitem` with zero-dim tensor output.
            if bsym.sym.id == thunder.torch.getitem.id and bsym.output.shape == ():
                getitem_bsym = bsym  # Nicer name for bsym
                getitem_producer = producers[getitem_bsym.args[0]]
                # Check if the producer is `torch.tensor`.
                if getitem_producer.sym.id in (thunder.torch.tensor.id,):
                    # TODO: Support folding scalar elem from multi-dim tensor.
                    tensor_value = getitem_producer.args[0]
                    flat_values, _ = tree_flatten(tensor_value)
                    if isinstance(tensor_value, Sequence) and len(flat_values) == 1:
                        tensor_value = tensor_value[0]
                    elif isinstance(tensor_value, numbers.Number):
                        pass
                    else:
                        # tensor_value is multi-element, or it could be an ndarray, etc.
                        # (though currently thunder only supports Number or Sequence).
                        continue

                    swap_map[variableify(getitem_bsym.output)] = tensor_value

        updated_bsyms = []
        for bsym in const_folded_trace.bound_symbols:
            updated_bsyms.append(bsym.from_bsym_swap_proxies(swap_map))

        const_folded_trace.bound_symbols = updated_bsyms
        const_folded_trace.set_provenance("GetItem from Scalar Tensor Const Folding pass")
        return prologue_trc, const_folded_trace, epilogue_trc


def forward():
    scale_t = torch.tensor([2])
    getitem = scale_t[0]
    return (getitem, scale_t)

jf = thunder.jit(forward, transforms=[GetItemFromScalarTensorConstantFold()])
jf()

print(thunder.last_traces(jf)[-1])

# def computation():
#   scale_t = torch.tensor([2], device=torch.device("cpu"), dtype=None)  # scale_t: "cpu i64[1]"
#     # scale_t = ltorch.tensor([2], device=torch.device("cpu"), dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu i64[1]"
#       # scale_t = prims.tensor_from_sequence([2], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu i64[1]"
#   return (2, scale_t)


def forward(x):
    scale_t = torch.tensor([2])
    getitem = scale_t[0]  # This will be const folded

    # This won't be folded for now because `torch.tensor` has more than 1 elem.
    # Though, will get it working in future.
    getitem_2 = torch.tensor([2, 2])[0]
    return x + getitem + getitem_2

jf = thunder.jit(forward, transforms=[GetItemFromScalarTensorConstantFold()])
jf(torch.randn(3, 3, requires_grad=True))

print(thunder.last_traces(jf)[-1])

# def computation(x):
#   # x: "cpu f32[3, 3]"
#   t3 = torch.tensor([2, 2], device=torch.device("cpu"), dtype=None)  # t3: "cpu i64[2]"
#     # t3 = ltorch.tensor([2, 2], device=torch.device("cpu"), dtype=None, requires_grad=False, pin_memory=False)  # t3: "cpu i64[2]"
#       # t3 = prims.tensor_from_sequence([2, 2], dtype=None, device=devices.Device("cpu"))  # t3: "cpu i64[2]"
#   t7 = torch.add(x, 2.0)  # t7: "cpu f32[3, 3]"
#     # t7 = ltorch.add(x, 2.0, alpha=None)  # t7: "cpu f32[3, 3]"
#       # t7 = prims.add(x, 2.0)  # t7: "cpu f32[3, 3]"
#   t11 = torch_slice_prim_impl(t3, [0], [1], [1])  # t11: "cpu i64[1]"
#   del t3
#   getitem_2 = torch.squeeze(t11, (0,))  # getitem_2: "cpu i64[]"
#     # getitem_2 = ltorch.squeeze(t11, (0,))  # getitem_2: "cpu i64[]"
#       # getitem_2 = prims.squeeze(t11, (0,))  # getitem_2: "cpu i64[]"
#   del t11
#   t14 = Tensor.to(getitem_2, torch.float32, copy=True)  # t14: "cpu f32[]"
#     # t14 = ltorch.to(getitem_2, torch.float32, None, device=None, dtype=None, copy=True, memory_format=None)  # t14: "cpu f32[]"
#       # t14 = prims.convert_element_type(getitem_2, dtypes.float32)  # t14: "cpu f32[]"
#   del getitem_2
#   t9 = torch.add(t7, t14)  # t9: "cpu f32[3, 3]"
#     # t9 = ltorch.add(t7, t14, alpha=None)  # t9: "cpu f32[3, 3]"
#       # t9 = prims.add(t7, t14)  # t9: "cpu f32[3, 3]"
#   del t7, t14
#   return {'output': t9, 'flat_args': [x], 'flat_output': (t9,)}, ((), ())

Just wanted to double-check whether this is what you had in mind. Also, should it be enabled by default in thunder.jit with an opt-out flag? Or should we provide it as a transform that users opt into by passing it via the transforms argument to thunder.jit?
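To make the two options concrete, here is a hypothetical sketch. `resolve_transforms` and its `constant_folding` flag are illustrative names and are not part of `thunder.jit`'s real signature; `ConstantFoldStub` is an empty stand-in for the transform class above.

```python
class ConstantFoldStub:
    """Empty stand-in for a constant-folding transform class (hypothetical)."""


def resolve_transforms(transforms=None, *, constant_folding=True):
    """Return the transforms a jit entry point would run (hypothetical helper).

    With the flag defaulting to True, folding is on by default and users opt
    out; flipping the default to False gives opt-in behaviour, where users pass
    ConstantFoldStub() in `transforms` themselves.
    """
    user = list(transforms or [])
    if constant_folding and not any(isinstance(t, ConstantFoldStub) for t in user):
        # Default-on: prepend the pass unless the user already supplied one.
        return [ConstantFoldStub(), *user]
    return user


print(resolve_transforms())                        # one default fold pass
print(resolve_transforms(constant_folding=False))  # user opted out
```

The main design difference is who pays for a surprise: default-on means every user gets folding unless they opt out, while opt-in keeps thunder.jit's behaviour unchanged unless the pass is requested.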

kshitij12345, Oct 03 '24 13:10

As per offline chat with @t-vi, here is a more generic approach. Is this roughly what you had in mind?

Thanks @t-vi for the idea and suggestions!!

import torch
import thunder
from thunder.core.trace import from_trace, TraceCtx, tracectx
from thunder.core import utils
from collections.abc import Sequence
from thunder.core.pytree import tree_flatten
from thunder.core.proxies import variableify, Variable, TensorProxy, NumberProxy
from thunder.core.symbol import BoundSymbol
from thunder.core.dtypes import to_dtype
from thunder.core.devices import to_device
import numbers
from thunder.core.proxies import ProxyTag
from thunder.torch import _torch_to_thunder_function_map

_thunder_to_torch_function_map = {v: k for k,v in _torch_to_thunder_function_map.items()}

ProxyTag.register_tag("CONSTANT_VALUE")

def is_constant(proxy):
    if isinstance(proxy, TensorProxy) and ProxyTag.CONSTANT_VALUE in proxy.tags:
        return True
    elif isinstance(proxy, NumberProxy) and proxy.is_static_constrained():
        return True
    return False


def compute_with_materialized_tensor(bsym, const_values):

    def materialize_args(a):
        if isinstance(a, TensorProxy):
            return const_values[variableify(a)]
        elif isinstance(a, NumberProxy):
            return a.value
        return a

    new_args = tuple(map(materialize_args, bsym.args))
    new_kwargs = {k: materialize_args(v) for k, v in bsym.kwargs.items()}
    torch_fn = _thunder_to_torch_function_map.get(bsym.sym, None)
    if torch_fn is None:
        return
    return torch_fn(*new_args, **new_kwargs)


class ConstantFold(thunder.Transform):
    def __init__(self, const_fold_scalar_tensor_from_getitem=True):
        self.const_fold_scalar_tensor_from_getitem = const_fold_scalar_tensor_from_getitem

    def transform_traces_pre_prologue(self, prologue_trc, computation_trc, epilogue_trc, **kwargs):
        # Create a new trace
        const_folded_trace = from_trace(computation_trc)
        const_folded_trace.bound_symbols = computation_trc.bound_symbols

        # Factory functions whose value we know.
        TENSOR_FACTORY = (thunder.torch.tensor.id, thunder.torch.ones.id, thunder.torch.zeros.id,)
        # TODO: Clear them at the last usage in trace.
        const_tensors: dict[Variable, BoundSymbol] = {}  

        # Tag output from factory functions as constant value.
        for bsym in const_folded_trace.bound_symbols:
            if bsym.sym.id in TENSOR_FACTORY:
                bsym.output.tags.add(ProxyTag.CONSTANT_VALUE)
                torch_fn = _thunder_to_torch_function_map[bsym.sym]
                t = torch_fn(*bsym.args, **bsym.kwargs)
                const_tensors[variableify(bsym.output)] = t

        new_bsyms = []
        swap_map_getitem = {}
        for bsym in const_folded_trace.bound_symbols:
            # If bsym has constant inputs, try to compute the output.
            if all(map(is_constant, bsym.flat_proxy_args)) and bsym.sym.id not in TENSOR_FACTORY:
                if bsym.flat_args == []:  # eg, unpack_trivial
                    continue
                new_concrete_output = compute_with_materialized_tensor(bsym, const_tensors)
                if new_concrete_output is not None:  # None for symbols with no entry in `_thunder_to_torch_function_map`, e.g. `python_return`.

                    # Create a new symbol with same output proxy but which will now represent the computed constant value.
                    # eg. t = known_tensor + 1 where known_tensor = torch.tensor(2) --> t = torch.tensor(3)
                    new_bsym = BoundSymbol(thunder.prims.tensor_from_sequence,
                                           (new_concrete_output.tolist(),), {'dtype': to_dtype(new_concrete_output.dtype), 'device': to_device(new_concrete_output.device)},
                                           output=bsym.output)
                    new_bsyms.append(new_bsym)

                    # Update const_tensors (so that usage of the output of this symbol will also be used for further computation.)
                    const_tensors[variableify(bsym.output)] = new_concrete_output
                    bsym.output.tags.add(ProxyTag.CONSTANT_VALUE)
                    if self.const_fold_scalar_tensor_from_getitem and bsym.sym.id == thunder.torch.getitem.id and bsym.output.shape == ():
                        swap_map_getitem[variableify(bsym.output)] = new_concrete_output.tolist()
                    continue

            # BoundSymbol with non-constant inputs, keep it as-is
            new_bsyms.append(bsym)

        del const_tensors

        const_folded_trace.bound_symbols = new_bsyms

        # Replace 0-dim tensors from getitem with Python scalars.
        if self.const_fold_scalar_tensor_from_getitem:
            new_bsyms = []
            for bsym in const_folded_trace.bound_symbols:
                new_bsyms.append(bsym.from_bsym_swap_proxies(swap_map_getitem))

        const_folded_trace.bound_symbols = new_bsyms
        const_folded_trace.set_provenance("Const Folding pass")
        return prologue_trc, const_folded_trace, epilogue_trc


def forward():
    scale_t = torch.tensor([2])
    getitem = scale_t[0]
    return (getitem, scale_t)

jf = thunder.jit(forward, transforms=[ConstantFold()])
jf()

print(thunder.last_traces(jf)[-1])

# def computation():
#   scale_t = torch.tensor([2], device=torch.device("cpu"), dtype=None)  # scale_t: "cpu i64[1]"
#     # scale_t = ltorch.tensor([2], device=torch.device("cpu"), dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu i64[1]"
#       # scale_t = prims.tensor_from_sequence([2], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu i64[1]"
#   return (2, scale_t)


def forward(x):
    scale_t = torch.tensor([2])
    getitem = scale_t[0]
    getitem_2 = torch.tensor([2, 2])[0]
    return x + getitem + getitem_2

jf = thunder.jit(forward, transforms=[ConstantFold()])
jf(torch.randn(3, 3, requires_grad=True))

print(thunder.last_traces(jf)[-1])

# def computation(x):
#   # x: "cpu f32[3, 3]"
#   t7 = torch.add(x, 2.0)  # t7: "cpu f32[3, 3]"
#     # t7 = ltorch.add(x, 2.0, alpha=None)  # t7: "cpu f32[3, 3]"
#       # t7 = prims.add(x, 2.0)  # t7: "cpu f32[3, 3]"
#   t9 = torch.add(t7, 2.0)  # t9: "cpu f32[3, 3]"
#     # t9 = ltorch.add(t7, 2.0, alpha=None)  # t9: "cpu f32[3, 3]"
#       # t9 = prims.add(t7, 2.0)  # t9: "cpu f32[3, 3]"
#   del t7
#   return {'output': t9, 'flat_args': [x], 'flat_output': (t9,)}, ((), ())

def forward(x):
    scale_t = torch.tensor([2], dtype=torch.float16)
    ones_t = torch.ones(1, dtype=torch.float32)
    s1 = scale_t * 2
    s2 = scale_t / 1
    s3 = s1 * s2
    ones_mul_10 = ones_t * 10
    return x[0, 0] + s3 + ones_mul_10

jf = thunder.jit(forward, transforms=[ConstantFold()])
forward(torch.randn(3, 3, requires_grad=True))
jf(torch.randn(3, 3, requires_grad=False))

print(thunder.last_traces(jf)[-1])

# def computation(x):
#   # x: "cpu f32[3, 3]"
#   s3 = torch.tensor([8.0], device=torch.device("cpu"), dtype=torch.float16)  # s3: "cpu f16[1]"
#     # s3 = ltorch.tensor([8.0], device=torch.device("cpu"), dtype=torch.float16, requires_grad=False, pin_memory=False)  # s3: "cpu f16[1]"
#       # s3 = prims.tensor_from_sequence([8.0], dtype=dtypes.float16, device=devices.Device("cpu"))  # s3: "cpu f16[1]"
#   ones_mul_10 = torch.tensor([10.0], device=torch.device("cpu"), dtype=torch.float32)  # ones_mul_10: "cpu f32[1]"
#     # ones_mul_10 = ltorch.tensor([10.0], device=torch.device("cpu"), dtype=torch.float32, requires_grad=False, pin_memory=False)  # ones_mul_10: "cpu f32[1]"
#       # ones_mul_10 = prims.tensor_from_sequence([10.0], dtype=dtypes.float32, device=devices.Device("cpu"))  # ones_mul_10: "cpu f32[1]"
#   t14 = Tensor.__getitem__(x, (0, 0))  # t14: "cpu f32[]"
#     # t14 = ltorch.getitem(x, (0, 0))  # t14: "cpu f32[]"
#       # (_, _) = prims.shape(x)
#       # t27 = prims.slice_prim(x, [0, 0], [1, 1], [1, 1])  # t27: "cpu f32[1, 1]"
#       # t14 = prims.squeeze(t27, (0, 1))  # t14: "cpu f32[]"
#   t16 = torch.add(t14, s3)  # t16: "cpu f32[1]"
#     # t16 = ltorch.add(t14, s3, alpha=None)  # t16: "cpu f32[1]"
#       # t29 = prims.convert_element_type(s3, dtypes.float32)  # t29: "cpu f32[1]"
#       # t16 = prims.add(t14, t29)  # t16: "cpu f32[1]"
#   del t14, s3
#   t17 = torch.add(t16, ones_mul_10)  # t17: "cpu f32[1]"
#     # t17 = ltorch.add(t16, ones_mul_10, alpha=None)  # t17: "cpu f32[1]"
#       # t17 = prims.add(t16, ones_mul_10)  # t17: "cpu f32[1]"
#   del t16, ones_mul_10
#   return t17

kshitij12345, Oct 03 '24 20:10

@kshitij12345 Why make this a pass?

What if the implementation of torch.item (or prims.item?) looked like:

def item(a: TensorLike, /) -> Number:
    if is_constexpr(a):
        return a.constexpr_value

    return prims.item(a)

And what if calling torch.ones and other tensor factories appropriately labeled tensors as constexpr when given constexpr inputs?

The reason why I'd prefer to keep the logic with the operators, if possible, is that a pass (as in the above) has to implement operator-specific logic away from the operators. If a pass changes the behavior of specific operators and doesn't require cross-operator analysis then it might best be implemented at the operator level.

For example, would we prefer a pass that translated prims.mul(a, 1) to a nullop, or would we want the mul operation itself to contain the logic that makes it a nullop? I think putting the logic with the operators implements the transform through abstract interpretation, which is nice.
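A minimal sketch of this operator-level alternative, under loud assumptions: `Proxy`, `Trace`, `tensor`, `mul` and `item` below are hypothetical stand-ins (not Thunder's proxies or `ltorch` functions), and "tensors" are Python lists. Each operator checks for constexpr inputs itself, so folding happens while tracing, and `mul(a, 1)` never records a node.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

# Hypothetical mini-tracer: a Proxy stands in for a traced tensor, and
# `constexpr_value` is set only when its value is known at trace time.


@dataclass
class Proxy:
    name: str
    constexpr_value: Optional[Any] = None

    def is_constexpr(self) -> bool:
        return self.constexpr_value is not None


@dataclass
class Trace:
    ops: list = field(default_factory=list)

    def emit(self, op, *args):
        # Record a runtime operation; its result is not known at trace time.
        out = Proxy(f"t{len(self.ops)}")
        self.ops.append((out.name, op, args))
        return out


TRACE = Trace()


def tensor(values):
    # Factory called with literal values: label the result as constexpr.
    return Proxy("const", constexpr_value=list(values))


def mul(a, b):
    # Each operator carries its own constant logic (b is a Python number here).
    if a.is_constexpr():
        return Proxy("const", constexpr_value=[x * b for x in a.constexpr_value])
    if b == 1:
        return a  # multiplying by one is a no-op: nothing is recorded
    return TRACE.emit("mul", a.name, b)


def item(a):
    if a.is_constexpr():
        (value,) = a.constexpr_value  # exactly one element expected
        return value
    return TRACE.emit("item", a.name)


x = Proxy("x")                     # runtime input, value unknown
print(item(mul(tensor([2]), 10)))  # folded while tracing
print(mul(x, 1) is x)              # the no-op returns its input
print(TRACE.ops)                   # nothing has been recorded yet
```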

Maybe @tfogal wants to argue for a transform over operator-logic changes, however.

mruberry, Oct 03 '24 20:10

@mruberry thanks for the suggestion. However, I am leaning a bit more towards a pass.

Example snippet -

def forward():
    scale_t = torch.tensor([2])
    getitem = ((scale_t * 10) - 2)[0]
    return (getitem, scale_t)

For the above snippet, with the operator-based approach you mentioned, mul, sub and torch.getitem (and any other deterministic operation) would each have to know about constant propagation, which I think could lead to coverage problems. I also imagine that, in case of bugs, it would be easier to reason about and debug a single pass than to look through all the operators to figure out why one or two of them are doing something funny. So I feel this logic is better suited to a pass, which takes care of it in one place (using PyTorch eager for the computation). Wdyt?
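For comparison, a toy sketch of the pass-based approach under similar assumptions (a hypothetical `(out, op, args)` program format, Python lists as "tensors", elementwise arithmetic with a Python scalar). One generic pass, `constant_fold`, folds any node whose inputs are all known, and each op only needs an eager implementation in the `EAGER` table; none of these names are Thunder APIs.

```python
import operator


def elementwise(fn):
    # Apply a binary Python operator between every element and a scalar.
    return lambda a, b: [fn(x, b) for x in a]


# One table of eager implementations: supporting a new op means adding one entry.
EAGER = {
    "tensor": lambda values: list(values),
    "mul": elementwise(operator.mul),
    "sub": elementwise(operator.sub),
    "getitem": lambda a, idx: a[idx],
}


def constant_fold(program):
    """Return (residual_program, env) where env maps names to folded values."""
    env = {}
    residual = []
    for out, op, args in program:
        deps = [a for a in args if isinstance(a, str)]  # str args name earlier nodes
        if op in EAGER and all(d in env for d in deps):
            values = [env[a] if isinstance(a, str) else a for a in args]
            env[out] = EAGER[op](*values)
        else:
            residual.append((out, op, args))
    return residual, env


# getitem = ((scale_t * 10) - 2)[0] from the example above
program = [
    ("scale_t", "tensor", ([2],)),
    ("t0", "mul", ("scale_t", 10)),
    ("t1", "sub", ("t0", 2)),
    ("getitem", "getitem", ("t1", 0)),
]
residual, env = constant_fold(program)
print(residual, env["getitem"])
```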

> If a pass changes the behavior of specific operators and doesn't require cross-operator analysis then it might best be implemented at the operator level.

I don't think this pass would change the behaviour of any operator (except getitem, I guess), as the pass should follow the same semantics, just at compile time.

Curious to hear other opinions as well.

kshitij12345, Oct 03 '24 21:10

@tfogal regarding the original request to inline the output of getitem, one thing I wanted to highlight is that scale_t[0] is a scalar (0-dim) tensor, not a Python scalar; scale_t[0].item() returns a Python scalar. So do (or should) we expect a Python scalar here?

Original Graph -

  class DynamoModule(torch.nn.Module):
    def forward(self):
        scale_t = torch.tensor([2])
        getitem = scale_t[0]
        return (getitem, scale_t)

Expected Computation Graph (from the issue description)

def computation():
  scale_t = torch.tensor([2], device=torch.device("cpu"), dtype=None)  # scale_t: "cpu i64[1]"
    # scale_t = ltorch.tensor([2], device=torch.device("cpu"), dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu i64[1]"
      # scale_t = prims.tensor_from_sequence([2], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu i64[1]"
  return (2, scale_t)

PyTorch Eager behavior -

>>> import torch
>>> torch.randn(3)[0]
tensor(-0.3535)
>>> torch.randn(3)[0].item()
0.9367182850837708

kshitij12345, Oct 04 '24 14:10


I think it's easier for always-on, per-operator logic to live with the operator description. Otherwise, if people update an operator, or add a new operator that can reason about constexpr inputs, both the operator and the pass need to be updated. Essentially, the logic for an operation is split across two places. If another pass did the same, a single operator could have its logic in three places, and so on.

The maybe_convert_to_dtype operation:

https://github.com/Lightning-AI/lightning-thunder/blob/b4ca020b6e194ac9cd6ce47a28c04df963110a83/thunder/clang/__init__.py#L150

It follows this logic: in the no-op case, it performs a no-op that is later removed by DCE. Putting that logic into a pass would just mean looking in multiple places to understand what the operation actually did, and people would also have to know which other passes applied to the operation (or didn't). I think it would be confusing.

Transforms/passes are best used for optional per-operator transformations (like grad) or transformations that reason about multiple operations (like distributed).

> If a pass changes the behavior of specific operators and doesn't require cross-operator analysis then it might best be implemented at the operator level.

> I don't think this pass would change the behaviour of any operator (except getitem, I guess), as the pass should follow the same semantics, just at compile time.

My point is that this would be an "always on" and "per operator" transform, which I think is the category to implement with the operator logic.

Implementing things at the operator level is also typically nicer for debuggability and maintainability, as the manipulation is often more direct and reproducible, not requiring modifying an entire pass to experiment.

Another way to look at this is that when a per-operator pass starts to target specific operators, we probably want to implement that pass as an operator->operator transformation, like grad does. But if we're doing the transformation every time, then we should just put it into the operator implementation itself.
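A hypothetical sketch of that operator->operator middle ground: rewrite rules are registered per operator and one generic driver applies them, roughly the way grad pairs operators with rules. `RULES`, `register_rule` and `rewrite` are illustrative names, not Thunder's transform API, and programs are toy `(out, op, args)` lists.

```python
RULES = {}  # op name -> rule(args) returning a replacement name, or None


def register_rule(op):
    def decorator(fn):
        RULES[op] = fn
        return fn
    return decorator


@register_rule("mul")
def mul_by_one(args):
    a, b = args
    return a if b == 1 else None  # returning a name means "forward that input"


@register_rule("sub")
def sub_zero(args):
    a, b = args
    return a if b == 0 else None


def rewrite(program):
    """Apply per-operator rules; return (new_program, aliases)."""
    aliases = {}   # eliminated output -> name that replaces it
    result = []
    for out, op, args in program:
        args = tuple(aliases.get(a, a) if isinstance(a, str) else a for a in args)
        rule = RULES.get(op)
        replacement = rule(args) if rule is not None else None
        if isinstance(replacement, str):
            aliases[out] = replacement  # node dropped; later uses are redirected
        else:
            result.append((out, op, args))
    return result, aliases


program = [("y", "mul", ("x", 1)), ("z", "sub", ("y", 0)), ("w", "mul", ("z", 3))]
print(rewrite(program))
```

The per-operator rules still live next to each operator's name, while the driver stays generic.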

mruberry, Oct 04 '24 16:10

> do we expect a Python scalar?

You are correct, of course; I should have looked at what torch does and not what we do :-). That the graph happened to name the tensor getitem despite never calling item() certainly didn't help me.

Anyway, in general I would say the rule is to match torch, but in this case it seems we've been returning a scalar instead of a tensor for a while, and it seems to have been fine. Maybe torch is just set up to accept a raw Python value and a scalar tensor interchangeably? If we can get away with the raw Python value, that seems cheaper and so is reasonable to me. I'm curious about others' thoughts, though, on whether strict adherence to what torch would do is desired here.

tfogal, Oct 04 '24 17:10


If PyTorch returns a tensor here, we should probably return a tensor. While PyTorch may allow tensors or numbers (almost always interchangeably), not all Python code does.

mruberry, Oct 04 '24 17:10

Triage reviewed - we will pursue an explicit pass.

nvMelissa, Oct 07 '24 15:10

Closing this issue, as we have added ConstantFolding as a transform in #1273. I have filed a separate issue, #1299, on how it could be added as a default transform to thunder.jit.

kshitij12345, Oct 14 '24 19:10