
Suggested fixes for the JanusFlow example code

Open songafu opened this issue 11 months ago • 3 comments

While using JanusFlow, I found that the text-to-image part of the example code fails to run under the latest transformers versions (>=4.48.0), and the example also appears to contain a code defect. Suggested fixes:

(1) In the JanusFlow text-to-image example, the data-flow handling contains a variable-reference error: the loop should iterate over `outputs.past_key_values`, not over the empty list created one line earlier. Corrected version:

```python
if step == 0:
    outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                          use_cache=True,
                                          attention_mask=attention_mask,
                                          past_key_values=None)
    past_key_values = []
    for kv_cache in outputs.past_key_values:  # should be outputs.past_key_values
        k, v = kv_cache[0], kv_cache[1]
        past_key_values.append((k[:, :, :inputs_embeds.shape[1], :],
                                v[:, :, :inputs_embeds.shape[1], :]))
    past_key_values = tuple(past_key_values)
```
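To see why the original line is a defect: the loop runs over the empty list that was just created, so its body never executes and `past_key_values` silently ends up as an empty tuple. A minimal stand-alone sketch (plain strings stand in for the real key/value tensors):

```python
# What outputs.past_key_values would hold: one (key, value) pair per layer.
layer_caches = [("k0", "v0"), ("k1", "v1")]

# Buggy version: iterates over the freshly created empty list,
# so the loop body never runs and the cache stays empty.
past_key_values = []
for kv_cache in past_key_values:
    past_key_values.append(kv_cache)
buggy = tuple(past_key_values)

# Fixed version: iterates over the model outputs instead.
past_key_values = []
for kv_cache in layer_caches:
    past_key_values.append(kv_cache)
fixed = tuple(past_key_values)

print(buggy)  # ()
print(fixed)  # (('k0', 'v0'), ('k1', 'v1'))
```

The bug is easy to miss because an empty cache does not raise an error at this point; it only corrupts generation (or crashes) later in the sampling loop.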

(2) Under the latest transformers versions (>=4.48.0), the JanusFlow demo fails with the error below. I suggest either noting in the Quick Start that users should pin an older transformers version (e.g. 4.38.2) or updating the code for compatibility with the latest version.

```
llama/modeling_llama.py", line 551, in forward
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'get_seq_length'
```
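For context on the error (my understanding, worth verifying against the transformers changelog): since transformers 4.36 the Llama code path expects a `Cache` object such as `DynamicCache`, and newer releases call `get_seq_length()` on whatever is passed, so the legacy tuple-of-tuples cache format triggers exactly this `AttributeError`. A dependency-free reproduction of the failure mode:

```python
# A plain tuple in the legacy (key, value)-per-layer cache format has no
# get_seq_length() method, which newer modeling_llama.py calls on the cache.
legacy_cache = (("k_layer0", "v_layer0"),)  # strings stand in for tensors

try:
    legacy_cache.get_seq_length()
    message = None
except AttributeError as exc:
    message = str(exc)

print(message)  # 'tuple' object has no attribute 'get_seq_length'
```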

songafu avatar Jan 28 '25 23:01 songafu

Good job!

SimonYS001 avatar Jan 29 '25 20:01 SimonYS001

just created a fix for this https://github.com/deepseek-ai/Janus/pull/137

Replace the code on lines 108-122

```python
if step == 0:
    outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                          use_cache=True,
                                          attention_mask=attention_mask,
                                          past_key_values=None)
    past_key_values = []
    for kv_cache in past_key_values:  # bug: iterates the empty list just created
        k, v = kv_cache[0], kv_cache[1]
        past_key_values.append((k[:, :, :inputs_embeds.shape[1], :],
                                v[:, :, :inputs_embeds.shape[1], :]))
    past_key_values = tuple(past_key_values)
else:
    outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                          use_cache=True,
                                          attention_mask=attention_mask,
                                          past_key_values=past_key_values)
```

with this

```python
if step == 0:
    past_key_values = None  # Ensure it starts as None
else:
    past_key_values = tuple(past_key_values) if past_key_values else None  # Convert only if it's valid

outputs = vl_gpt.language_model.model(
    inputs_embeds=llm_emb,
    use_cache=True,
    attention_mask=attention_mask,
    past_key_values=past_key_values  # Now correctly assigned
)
```
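One caveat with the replacement above: later steps still pass a plain tuple, which transformers >= 4.48 will reject with the same `AttributeError`. A hedged compatibility sketch (not part of the linked PR; `to_model_cache` is a hypothetical helper) that converts a legacy tuple into the `DynamicCache` object newer versions expect, falling back to the tuple on older installs:

```python
def to_model_cache(past_key_values):
    """Convert a legacy tuple-of-(k, v) cache into whatever this
    transformers install expects; pass None through unchanged."""
    if past_key_values is None:
        return None
    try:
        # DynamicCache.from_legacy_cache is part of the public
        # transformers API from 4.36 onwards.
        from transformers import DynamicCache
        return DynamicCache.from_legacy_cache(tuple(past_key_values))
    except ImportError:
        # Older transformers versions still accept the legacy tuple format.
        return tuple(past_key_values)

# Usage inside the sampling loop (sketch):
# outputs = vl_gpt.language_model.model(
#     inputs_embeds=llm_emb,
#     use_cache=True,
#     attention_mask=attention_mask,
#     past_key_values=to_model_cache(past_key_values),
# )
```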

Hope it helps :)

scifisatan avatar Feb 04 '25 14:02 scifisatan

Just want to confirm: is the KV cache actually used in JanusFlow?

nv-samcheng avatar Feb 15 '25 21:02 nv-samcheng