TypeError: <module '__main__'> is a built-in module
Hi everyone, I'm trying to use jaxtyping to make my code easier to debug. However, I've run into a problem that I can't seem to fix. I was trying to use jaxtyping to check one method in a class, following the instructions here: link. However, I got this error:
"TypeError: <module '__main__'> is a built-in module".
I don't know why I get this error even though I followed the tutorial. The code is below. Thank you.
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from jaxtyping import Array, Float, PyTree, jaxtyped
from typeguard import typechecked as typechecker
from dataclasses import dataclass

# Ensure every computation happens on the GPU when available
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class PatchEmbeddings(nn.Module):
    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()
        # Store the input image size
        self.img_size = img_size
        # Store the size of each patch
        self.patch_size = patch_size
        # Calculate the total number of patches
        self.num_patches = (img_size // patch_size) ** 2
        # Create a convolutional layer to extract patch embeddings
        # in_channels=3 assumes the input image has 3 color channels (RGB)
        # out_channels=hidden_dim sets the number of output channels to match the hidden dimension
        # kernel_size=patch_size and stride=patch_size ensure each patch is separately embedded
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    @jaxtyped(typechecker=typechecker)
    def forward(self, X: Float[torch.Tensor, "B C H H"]) -> Float[torch.Tensor, "B C img_size img_size"]:
        # X: (B, C, img_size, img_size)
        # Extract patch embeddings from the input image
        # set_trace()
        X = self.conv(X)  # (B, hidden_dim, img_size // patch_size, img_size // patch_size)
        # set_trace()
        # Flatten the spatial dimensions (height and width) of the patch embeddings
        # This step flattens the patch dimensions into a single dimension
        X = X.flatten(2)  # (B, hidden_dim, self.num_patches)
        # set_trace()
        # Transpose the dimensions to obtain the shape [batch_size, num_patches, hidden_dim]
        # This step brings the num_patches dimension to the second position
        # X = X.transpose(1, 2)  # (B, self.num_patches, hidden_dim)
        return X


# testing
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_embeddings = PatchEmbeddings(img_size, patch_size, num_hiddens)
X = torch.zeros(batch_size, 3, img_size, img_size)
patch_embeddings(X).shape
Can you minify your example down? Right now this includes e.g. a lot of unnecessary imports, some of them outside the standard library.
Hi Patrick, I have minified the example by removing unnecessary imports.
Thanks! Running your code, I get a different error, and it looks like a valid one:
jaxtyping.TypeCheckError: Type-check error whilst checking the return value of __main__.PatchEmbeddings.forward.
Actual value: f32[4,512,36](torch)
Expected type: Float[Tensor, 'B C img_size img_size'].
----------------------
Called with parameters: {'self': PatchEmbeddings(...), 'X': f32[4,3,96,96](torch)}
Parameter annotations: (self, X: Float[Tensor, 'B C H H']) -> Any.
The current values for each jaxtyping axis annotation are as follows.
B=4
C=3
H=96
In particular, the code you have here does not seem to reproduce the error you describe, I'm afraid.
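The mismatch in the return value follows from the layer arithmetic. Here is a minimal sketch in plain Python (no torch needed), assuming the standard Conv2d output-size formula with padding=0 and dilation=1, which are the defaults used in the question's code. The helper names conv2d_out_size and patch_embedding_shape are hypothetical:

```python
# Hedged sketch: recompute the shape PatchEmbeddings.forward returns,
# assuming Conv2d with padding=0 and dilation=1 (the PyTorch defaults).


def conv2d_out_size(size: int, kernel: int, stride: int) -> int:
    # floor((size - kernel) / stride) + 1, valid for padding=0, dilation=1
    return (size - kernel) // stride + 1


def patch_embedding_shape(batch: int, img_size: int, patch_size: int,
                          hidden_dim: int) -> tuple[int, int, int]:
    # Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)
    # produces (batch, hidden_dim, side, side)
    side = conv2d_out_size(img_size, patch_size, patch_size)
    # X.flatten(2) merges the two spatial axes into one axis of length side * side
    return (batch, hidden_dim, side * side)


shape = patch_embedding_shape(batch=4, img_size=96, patch_size=16, hidden_dim=512)
print(shape)  # (4, 512, 36)
```

So forward returns a rank-3 tensor (B, hidden_dim, num_patches), which a four-axis annotation such as "B C img_size img_size" can never match. jaxtyping treats img_size in that string as an axis name rather than reading self.img_size. A return annotation along the lines of Float[torch.Tensor, "B hidden num_patches"] (the axis names here are illustrative) would fit. Reusing C for the 512-channel axis would also fail, because C is already bound to 3 by the input within the same call.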
Hi. I was running this in a Jupyter notebook with Python 3.13.5, in an environment set up using: conda create -n aiagent python=3.13. When I ran it as a plain Python script instead, I got this error:
Traceback (most recent call last):
  File "/media/dangmanhtruong/147E655C7E65379E/TRUONG/Code_tu_hoc/AI_agent_tutorials/pydantic_ai_multi_agent_example/temp2.py", line 14, in <module>
    class PatchEmbeddings(nn.Module):
    ...<38 lines>...
        return X
  File "/media/dangmanhtruong/147E655C7E65379E/TRUONG/Code_tu_hoc/AI_agent_tutorials/pydantic_ai_multi_agent_example/temp2.py", line 34, in PatchEmbeddings
    @jaxtyped(typechecker=typechecker)
     ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/jaxtyping/_decorator.py", line 428, in jaxtyped
    full_fn = _apply_typechecker(typechecker, full_fn)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/jaxtyping/_decorator.py", line 71, in _apply_typechecker
    return typechecker(fn)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_decorators.py", line 230, in typechecked
    retval = instrument(target)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_decorators.py", line 73, in instrument
    instrumentor.visit(module_ast)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 599, in visit_Module
    self.generic_visit(node)
    ~~~~~~~~~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 499, in generic_visit
    node = super().generic_visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 498, in generic_visit
    value = self.visit(value)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 644, in visit_ClassDef
    self.generic_visit(node)
    ~~~~~~~~~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 499, in generic_visit
    node = super().generic_visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 498, in generic_visit
    value = self.visit(value)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 673, in visit_FunctionDef
    with self._use_memo(node):
         ~~~~~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/contextlib.py", line 141, in __enter__
    return next(self.gen)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 557, in _use_memo
    new_memo.return_annotation = self._convert_annotation(
                                 ~~~~~~~~~~~~~~~~~~~~~~~~^
        return_annotation
        ^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 583, in _convert_annotation
    new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
                                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 355, in visit
    new_node = super().visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 422, in visit_Subscript
    [self.visit(item) for item in node.slice.elts],
     ~~~~~~~~~~^^^^^^
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 355, in visit
    new_node = super().visit(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 422, in visit
    return visitor(node)
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/site-packages/typeguard/_transformer.py", line 475, in visit_Constant
    expression = ast.parse(node.value, mode="eval")
  File "/home/dangmanhtruong/anaconda3/envs/aiagent/lib/python3.13/ast.py", line 50, in parse
    return compile(source, filename, mode, flags, _feature_version=feature_version, optimize=optimize)
  File "<unknown>", line 1
    B C img_size img_size
      ^
SyntaxError: invalid syntax
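The last frames show where the SyntaxError comes from. typeguard's visit_Constant passes the string "B C img_size img_size" to ast.parse(..., mode="eval"), so as far as the traceback shows, typeguard v4 treats the jaxtyping shape string as a forward-reference expression. Here is a minimal standard-library sketch of that single step. The helper parses_as_python is hypothetical and not part of typeguard:

```python
import ast


def parses_as_python(annotation: str) -> bool:
    """Return True if `annotation` is a valid Python expression (eval mode)."""
    try:
        ast.parse(annotation, mode="eval")
    except SyntaxError:
        return False
    return True


# jaxtyping shape strings are space-separated axis names, not Python code
print(parses_as_python("B C img_size img_size"))  # False
print(parses_as_python("B C H H"))                # False
# A single name is a valid expression, so a real forward reference parses fine
print(parses_as_python("Tensor"))                 # True
```

Space-separated axis names are not a Python expression, so a checker that parses string constants inside annotations this way rejects them before jaxtyping gets a chance to interpret them.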
This looks like a known issue with typeguard v4. You should use typeguard v2.13.3 instead (for example via pip install typeguard==2.13.3), as the newer versions are known to be buggy.
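The notebook error in the title plausibly has the same root. The instrument frame in the script traceback parses the source of the whole module (instrumentor.visit(module_ast)), so typeguard v4 needs access to the module's source file. The sketch below rests on two assumptions, neither confirmed in this thread: that typeguard reads that source with inspect.getsource, and that a Jupyter kernel's __main__ module has no __file__. Under those assumptions, the standard library alone reproduces the message. module_source_or_error is a hypothetical helper:

```python
import inspect
import types

# A module object with no __file__ attribute, standing in for the __main__
# module of a Jupyter kernel (assumption: the kernel's __main__ has no __file__)
notebook_like_main = types.ModuleType("__main__")


def module_source_or_error(module: types.ModuleType) -> str:
    """Try to read a module's source, as a source-instrumenting decorator must."""
    try:
        return inspect.getsource(module)
    except TypeError as exc:
        # inspect raises TypeError("<module ...> is a built-in module") when
        # the module has no __file__ to locate its source
        return f"TypeError: {exc}"


# Prints a TypeError message ending in "is a built-in module"
print(module_source_or_error(notebook_like_main))
```

typeguard 2.x instead wraps the function and checks arguments and the return value at call time, so it does not need to read the module's source. That is consistent with the advice to pin v2.13.3.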