sd-scripts
Why are model names hardcoded in analyze_checkpoint_state?
My checkpoint is split into two shards, diffusion_pytorch_model-00001-of-00002.safetensors and diffusion_pytorch_model-00002-of-00002.safetensors. It won't load because the shard count 00003 is hardcoded in the code:
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
    """
    Analyze the checkpoint state and return whether it is Diffusers or BFL, dev or schnell, and the number of blocks.
    Args:
        ckpt_path (str): Path to the checkpoint file or directory.
    Returns:
        Tuple[bool, bool, Tuple[int, int], List[str]]:
            - bool: Flag indicating whether the checkpoint is Diffusers.
            - bool: Flag indicating whether the checkpoint is Schnell.
            - Tuple[int, int]: Number of double blocks and single blocks.
            - List[str]: List of keys contained in the checkpoint.
    """
    # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
    logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
    if os.path.isdir(ckpt_path):  # if ckpt_path is a directory, it is Diffusers
        ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
    if "00001-of-00003" in ckpt_path:
        ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
    else:
        ckpt_paths = [ckpt_path]
Why is that?
I'd like to simplify the logic to something like this:
import glob  # needed for the directory scan below

def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
    ckpt_paths = []
    if os.path.isdir(ckpt_path):  # if ckpt_path is a directory, it is Diffusers
        # collect every shard, either under "transformer" or directly in the directory
        for sft in glob.glob(os.path.join(ckpt_path, "transformer", "diffusion*.safetensors")) + glob.glob(os.path.join(ckpt_path, "diffusion*.safetensors")):
            ckpt_paths.append(sft)
    else:
        ckpt_paths = [ckpt_path]
    .....
If you're ok with that, I'll submit a PR for this. @kohya-ss
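For reference, here is a minimal standalone sketch of the directory scan proposed above. The helper name `find_checkpoint_files` is hypothetical and not part of sd-scripts. It differs from the snippet above in two ways, both of which are my own assumptions about what is desirable: it sorts the results so the shard order is deterministic, and it prefers the "transformer" subfolder over the directory itself instead of concatenating both.

```python
import glob
import os
from typing import List


def find_checkpoint_files(ckpt_path: str) -> List[str]:
    """Hypothetical helper: list the safetensors files that make up a checkpoint."""
    if not os.path.isdir(ckpt_path):
        # A plain file path is returned unchanged, as in the original code.
        return [ckpt_path]

    # Assumption: look in the Diffusers "transformer" subfolder first,
    # then fall back to the directory itself.
    for folder in (os.path.join(ckpt_path, "transformer"), ckpt_path):
        # Zero-padded shard indices sort correctly as plain strings.
        found = sorted(glob.glob(os.path.join(folder, "diffusion*.safetensors")))
        if found:
            return found

    raise FileNotFoundError(f"No diffusion*.safetensors files found under {ckpt_path}")
```

With this, a directory containing two shards and one containing three shards are handled the same way, because the shard count is never spelled out in the code.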
I was looking to resolve this with https://github.com/kohya-ss/sd-scripts/pull/1913 but there are a lot of edge cases to consider to make it work with all the possible options.
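One of those edge cases is a user passing the path of a single shard file rather than a directory. A rough sketch of how that case could be handled without hardcoding the count is below. `expand_shard_path` and `SHARD_RE` are hypothetical names, not taken from sd-scripts or from the linked PR, and the sketch assumes the five-digit `-NNNNN-of-MMMMM` naming shown in this issue.

```python
import os
import re
from typing import List

# Assumption: shards are named "...-00001-of-00002.safetensors" (five digits each).
SHARD_RE = re.compile(r"-(\d{5})-of-(\d{5})(?=\.safetensors$)")


def expand_shard_path(ckpt_path: str) -> List[str]:
    """Hypothetical helper: given one shard file, return all sibling shard paths."""
    match = SHARD_RE.search(ckpt_path)
    if match is None:
        # Not a sharded name: treat it as a single-file checkpoint.
        return [ckpt_path]

    total = int(match.group(2))  # shard count read from the file name itself
    prefix, suffix = ckpt_path[: match.start()], ckpt_path[match.end():]
    paths = [f"{prefix}-{i:05d}-of-{total:05d}{suffix}" for i in range(1, total + 1)]

    # Fail early with a clear message if the set of shards is incomplete.
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(f"Missing checkpoint shards: {missing}")
    return paths
```

Any shard of the set can be passed in, not only the first one, and an incomplete download is reported before any weights are read.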