axlearn
Fix module import 2
Pull Request: Fix Incorrect Module Import in layers.py
axlearn now supports Python 3.10, so this PR is valid.
Summary
This pull request addresses an issue with an incorrect module import in the axlearn/common/quantized_dot_general/layers.py file. The import and usage of the Context class from the aqt_dot_general module have been corrected to use the appropriate Context class from the aqt_utils module.
Changes
1. Import Correction:
- Changed the import from aqt.jax.v2.aqt_dot_general to aqt.jax.v2.utils and aliased it as aqt_utils.
2. Context Class Usage:
- Updated the usage of the Context class to use aqt_utils.Context instead of aqt_dot_general.Context.
Code Changes
1. Importing aqt_utils
from aqt.jax.v2 import utils as aqt_utils
2. Updating the Context annotation and default value (reference taken from the official aqt README)
context: aqt_dot_general.Context = aqt_dot_general.Context(key=None, train_step=None),
to
context: aqt_utils.Context = aqt_utils.Context(key=None, train_step=None),
Testing
- Verified that the changes resolve the AttributeError and that the code runs successfully.
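
The attribute check behind this fix can be sketched without aqt installed. The snippet below is only an illustration: find_context_owner, _FakeContext, fake_dot_general, and fake_utils are hypothetical stand-ins and not part of aqt or axlearn. It mimics the situation described above, where one module no longer exposes Context and the other does.

```python
from types import SimpleNamespace


def find_context_owner(candidates):
    """Return the name of the first candidate module that defines `Context`.

    `candidates` maps a module name to an already-imported module object.
    Returns None when no candidate exposes the attribute.
    """
    for name, module in candidates.items():
        if hasattr(module, "Context"):
            return name
    return None


# Hypothetical stand-in for the Context class; the real one lives in aqt.
class _FakeContext:
    def __init__(self, key=None, train_step=None):
        self.key = key
        self.train_step = train_step


# Stand-ins for the two aqt modules (assumption: the real modules are not imported).
fake_dot_general = SimpleNamespace()                # no Context attribute
fake_utils = SimpleNamespace(Context=_FakeContext)  # exposes Context

owner = find_context_owner(
    {"aqt_dot_general": fake_dot_general, "aqt_utils": fake_utils}
)
print(owner)
```

With the real library, the same helper could be given the imported aqt.jax.v2.aqt_dot_general and aqt.jax.v2.utils modules to confirm which one defines Context in the installed version.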