
TypeError: <module '__main__'> is a built-in module

Open dangmanhtruong1995 opened this issue 5 months ago • 5 comments

Hi everyone, I'm trying to use jaxtyping to make my code easier to debug, but I've run into a problem that I can't seem to fix. I want to use jaxtyping to type-check one method in a class, and I followed the instructions from here: link. However, I got this error:

" TypeError: <module 'main'> is a built-in module".

I don't know why I get this error even though I followed the tutorial. The code is below. Thank you.

import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

from jaxtyping import Array, Float, PyTree, jaxtyped
from typeguard import typechecked as typechecker
from dataclasses import dataclass

# Select the GPU when it is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class PatchEmbeddings(nn.Module):
    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()

        # Store the input image size
        self.img_size = img_size

        # Store the size of each patch
        self.patch_size = patch_size

        # Calculate the total number of patches
        self.num_patches = (img_size // patch_size) ** 2

        # Create a convolutional layer to extract patch embeddings
        # in_channels=3 assumes the input image has 3 color channels (RGB)
        # out_channels=hidden_dim sets the number of output channels to match the hidden dimension
        # kernel_size=patch_size and stride=patch_size ensure each patch is separately embedded
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    @jaxtyped(typechecker=typechecker)
    def forward(self, X: Float[torch.Tensor, "B C H H"]) -> Float[torch.Tensor, "B C img_size img_size"]:
        # X: (B, C, img_size, img_size)
        
        # Extract patch embeddings from the input image
        # set_trace()
        X = self.conv(X) # (B, hidden_dim, img_size // patch_size, img_size // patch_size)
        
        # set_trace()

        # Flatten the spatial dimensions (height and width) of the patch embeddings
        # This step flattens the patch dimensions into a single dimension
        X = X.flatten(2) # (B, hidden_dim, self.num_patches)
        # set_trace()

        # Transpose the dimensions to obtain the shape [batch_size, num_patches, hidden_dim]
        # This step brings the num_patches dimension to the second position
        # X = X.transpose(1, 2) # (B, self.num_patches, hidden_dim)

        return X

# Testing
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_embeddings = PatchEmbeddings(img_size, patch_size, num_hiddens)
X = torch.zeros(batch_size, 3, img_size, img_size)
print(patch_embeddings(X).shape)

dangmanhtruong1995 avatar Aug 15 '25 10:08 dangmanhtruong1995

Can you minify your example down? Right now this includes e.g. a lot of unnecessary imports, some of them outside the standard library.

patrick-kidger avatar Aug 15 '25 11:08 patrick-kidger

Hi Patrick, I have minified the example by removing unnecessary imports.

dangmanhtruong1995 avatar Aug 15 '25 11:08 dangmanhtruong1995

Thanks! Running your code I get a different error / what looks like a valid error:

jaxtyping.TypeCheckError: Type-check error whilst checking the return value of __main__.PatchEmbeddings.forward.
Actual value: f32[4,512,36](torch)
Expected type: Float[Tensor, 'B C img_size img_size'].
----------------------
Called with parameters: {'self': PatchEmbeddings(...), 'X': f32[4,3,96,96](torch)}
Parameter annotations: (self, X: Float[Tensor, 'B C H H']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
B=4
C=3
H=96

In particular the code you have here does not seem to reproduce the error you describe, I'm afraid.
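
For reference, here is a sketch of annotations that match the shapes this forward pass actually produces, assuming the final transpose stays commented out; the axis names hidden_dim and num_patches are just illustrative labels, not anything jaxtyping requires:

import torch
import torch.nn as nn
from jaxtyping import Float, jaxtyped
from typeguard import typechecked as typechecker

class PatchEmbeddings(nn.Module):
    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A conv whose kernel and stride both equal the patch size embeds each patch independently.
        self.conv = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)

    @jaxtyped(typechecker=typechecker)
    def forward(self, X: Float[torch.Tensor, "B C H W"]) -> Float[torch.Tensor, "B hidden_dim num_patches"]:
        X = self.conv(X)     # (B, hidden_dim, H // patch_size, W // patch_size)
        return X.flatten(2)  # (B, hidden_dim, num_patches)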

patrick-kidger avatar Aug 15 '25 17:08 patrick-kidger

Hi. I was running in a Jupyter notebook with Python 3.13.5, in an environment created with: conda create -n aiagent python=3.13. When I ran the same code as a plain Python script, I got this error:

Traceback (most recent call last):
  File "/media/dangmanhtruong/147E655C7E65379E/TRUONG/Code_tu_hoc/AI_agent_tutorials/pydantic_ai_multi_agent_example/temp2.py", line 14, in <module>
    class PatchEmbeddings(nn.Module):
    ...<38 lines>...
        return X
  File "/media/dangmanhtruong/147E655C7E65379E/TRUONG/Code_tu_hoc/AI_agent_tutorials/pydantic_ai_multi_agent_example/temp2.py", line 34, in PatchEmbeddings
    @jaxtyped(typechecker=typechecker)
     ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/jaxtyping/_decorator.py", line 428, in jaxtyped
    full_fn = _apply_typechecker(typechecker, full_fn)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/jaxtyping/_decorator.py", line 71, in _apply_typechecker
    return typechecker(fn)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_decorators.py", line 230, in typechecked
    retval = instrument(target)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_decorators.py", line 73, in instrument
    instrumentor.visit(module_ast)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 599, in visit_Module
    self.generic_visit(node)
    ~~~~~~~~~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 499, in generic_visit
    node = super().generic_visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 498, in generic_visit
    value = self.visit(value)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 644, in visit_ClassDef
    self.generic_visit(node)
    ~~~~~~~~~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 499, in generic_visit
    node = super().generic_visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 498, in generic_visit
    value = self.visit(value)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 673, in visit_FunctionDef
    with self._use_memo(node):
         ~~~~~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/contextlib.py", line 141, in __enter__
    return next(self.gen)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 557, in _use_memo
    new_memo.return_annotation = self._convert_annotation(
                                 ~~~~~~~~~~~~~~~~~~~~~~~~^
        return_annotation
        ^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 583, in _convert_annotation
    new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
                               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 355, in visit
    new_node = super().visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 422, in visit_Subscript
    [self.visit(item) for item in node.slice.elts],
     ~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 355, in visit
    new_node = super().visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 475, in visit_Constant
    expression = ast.parse(node.value, mode="eval")
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 50, in parse
    return compile(source, filename, mode, flags, _feature_version=feature_version, optimize=optimize)
  File "<unknown>", line 1
    B C img_size img_size
      ^
SyntaxError: invalid syntax

dangmanhtruong1995 avatar Aug 16 '25 09:08 dangmanhtruong1995

This looks like a known error coming from the use of typeguard v4. You should use typeguard v2.13.3, as the newer versions are known to be buggy.
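
For example, in a pip-managed environment you could pin the older release with pip install "typeguard==2.13.3" (a sketch; adapt to however you manage packages), and then confirm the installed version from Python:

import importlib.metadata

# Should print 2.13.3 once the pin has taken effect.
print(importlib.metadata.version("typeguard"))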

patrick-kidger avatar Aug 18 '25 15:08 patrick-kidger