Bug: AttributeError when using cached text embeddings with sampling enabled
When training with `cache_text_embeddings: true` and sampling enabled (the default), training fails with:

```
AttributeError: 'FakeTextEncoder' object has no attribute 'encoder'
```
This occurs during the initial baseline sample generation before training starts.
## Root Cause
The issue is in the `get_te_has_grad()` method in `chroma_model.py`. When `cache_text_embeddings` is enabled, the T5 text encoder is replaced with a `FakeTextEncoder` to save memory. However, `get_te_has_grad()` assumes the text encoder always has an `encoder` attribute, which `FakeTextEncoder` does not have.

This method is called during device-state management when generating samples, causing the crash.
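The failure mode can be reproduced in isolation with a minimal sketch. The `FakeTextEncoder` class below is a bare stand-in for the toolkit's memory-saving stub (not the real implementation), and the accessor mirrors the attribute chain from `chroma_model.py`:

```python
class FakeTextEncoder:
    """Minimal stand-in mirroring the report: no `.encoder` attribute."""
    pass

def get_te_has_grad(text_encoder):
    # The original accessor assumes a full T5 module tree is present.
    return text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad

try:
    get_te_has_grad(FakeTextEncoder())
    error = None
except AttributeError as exc:
    error = str(exc)

print(error)  # 'FakeTextEncoder' object has no attribute 'encoder'
```

The first attribute lookup (`.encoder`) fails immediately, which is why the crash surfaces before any sampling work is done.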
## Configuration Used

```yaml
train:
  train_text_encoder: false
  cache_text_embeddings: true
  unload_text_encoder: false
  disable_sampling: false  # sampling is enabled
```
## Expected Behavior

The system should be able to generate samples even when using cached text embeddings; `get_te_has_grad()` should handle the case where a `FakeTextEncoder` is in use.
Quick-and-dirty patch that checks for the fake encoder; it got me up and running with Chroma + cached embeddings:
```diff
index 9bf3e51..855f380 100644
--- a/extensions_built_in/diffusion_models/chroma/chroma_model.py
+++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py
@@ -413,8 +413,11 @@ class ChromaModel(BaseModel):
         return self.model.final_layer.linear.weight.requires_grad
 
     def get_te_has_grad(self):
-        # return from a weight if it has grad
-        return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+        from toolkit.unloader import FakeTextEncoder
+        te = self.text_encoder[1]
+        if isinstance(te, FakeTextEncoder):
+            return False
+        return te.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
 
     def save_model(self, output_path, meta, save_dtype):
         if not output_path.endswith(".safetensors"):
```
A very similar version of this issue is also happening with Flux. (Yes, some folks still want to train Flux LoRAs!)

I tried hacking this fix into `stable_diffusion_model.py` but could not get it to work properly. After some experimentation I got past the error, but then ran into a new issue, so I gave up. The code differs slightly between these model files.

I am new to AI Toolkit, and the frustrating part is that recent videos specifically encourage the use of Cache Text Embeddings, so I carefully built my dataset around that.
Can confirm: this fixes the `FakeTextEncoder` error for Chroma after enabling caching of text embeddings.
This also happened to Wan 2.2 TI2V 5B using UI.
You need to modify `get_te_has_grad()` around line 659 in `wan21.py`:

```python
def get_te_has_grad(self):
    from toolkit.unloader import FakeTextEncoder
    te = getattr(self, "text_encoder", None)
    if te is None or isinstance(te, FakeTextEncoder):
        return False
    for p in te.parameters():
        return p.requires_grad
    return False
```
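The patches above share the same shape, so for anyone porting this to another model file the check can be sketched as a small standalone helper. Everything here is illustrative: the helper name `te_has_grad` and the duck-typed dummy classes are hypothetical, and a real port should import and `isinstance`-check the toolkit's actual `FakeTextEncoder` as the patches above do.

```python
def te_has_grad(te) -> bool:
    """Return True only if a real text encoder is present and its first
    parameter requires grad; False for None or a cached-embeddings stub."""
    # Name check stands in for `isinstance(te, FakeTextEncoder)` so this
    # sketch stays import-free; use the real isinstance check in the toolkit.
    if te is None or type(te).__name__ == "FakeTextEncoder":
        return False
    first = next(iter(te.parameters()), None)
    return first is not None and first.requires_grad


# Duck-typed stand-ins for demonstration (not real torch modules):
class _Param:
    def __init__(self, requires_grad):
        self.requires_grad = requires_grad

class FakeTextEncoder:  # mimics the toolkit's memory-saving stub
    pass

class RealEncoder:
    def parameters(self):
        yield _Param(True)

print(te_has_grad(None))               # False
print(te_has_grad(FakeTextEncoder()))  # False
print(te_has_grad(RealEncoder()))      # True
```

Checking only the first parameter matches the spirit of the existing accessors, which probe a single representative weight rather than walking the whole module.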