
Add SkyReels V2: Infinite-Length Film Generative Model

Open tolgacangoz opened this issue 6 months ago • 19 comments

Thanks for the opportunity to fix #11374!

Original Work

Original repo: https://github.com/SkyworkAI/SkyReels-V2
Paper: https://huggingface.co/papers/2504.13074

SkyReels V2's main contributions are summarized as follows:
• A comprehensive video captioner that understands shot language while capturing the general description of the video, which dramatically improves prompt adherence.
• Motion-specific preference optimization that enhances motion dynamics via a semi-automatic data collection pipeline.
• An effective Diffusion Forcing adaptation that enables ultra-long video generation and story generation capabilities, providing a robust framework for extending temporal coherence and narrative depth.
• SkyCaptioner-V1 and the SkyReels-V2 model series, including diffusion-forcing, text2video, image2video, camera director and elements2video models in various sizes (1.3B, 5B, 14B), are open-sourced.

(figure: main pipeline overview)

TODOs:
:white_check_mark: FlowMatchUniPCMultistepScheduler: just copy-pasted from the original repo
:white_check_mark: SkyReelsV2Transformer3DModel: 90% WanTransformer3DModel
:white_check_mark: SkyReelsV2DiffusionForcingPipeline
:white_check_mark: SkyReelsV2DiffusionForcingImageToVideoPipeline: includes FLF2V.
:white_check_mark: SkyReelsV2DiffusionForcingVideoToVideoPipeline: extends a given video.
:white_check_mark: SkyReelsV2Pipeline
:white_check_mark: SkyReelsV2ImageToVideoPipeline: includes FLF2V.
:white_check_mark: scripts/convert_skyreelsv2_to_diffusers.py

⏳ Did you make sure to update the documentation with your changes? Did you write any new necessary tests?: We will construct these during review.

T2V with Diffusion Forcing (OLD)

Skywork/SkyReels-V2-DF-1.3B-540P
(comparison videos, original repo vs. diffusers integration, for: seed 0 / num_frames 97, seed 37 / num_frames 97, seed 0 / num_frames 257, seed 37 / num_frames 257)
!pip install git+https://github.com/tolgacangoz/diffusers.git@skyreels-v2 ftfy -q
import torch
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingPipeline
from diffusers.utils import export_to_video

vae = AutoencoderKLWan.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			subfolder="vae",
			torch_dtype=torch.float32)
pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			vae=vae,
			torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
pipe.transformer.set_ar_attention(causal_block_size=5)

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."

output = pipe(
    prompt=prompt,
    num_inference_steps=30,
    height=544,
    width=960,
    num_frames=97,
    ar_step=5,  # Controls asynchronous inference (0 for synchronous mode)
    generator=torch.Generator(device="cpu").manual_seed(0),
    overlap_history=None,  # Number of frames to overlap for smooth transitions in long videos; 17 for long
    addnoise_condition=20,  # Improves consistency in long video generation
).frames[0]
export_to_video(output, "T2V.mp4", fps=24, quality=8)

"""
You can set `ar_step=5` to enable asynchronous inference. For asynchronous inference,
`causal_block_size=5` is recommended, while it should not be set for
synchronous generation. Asynchronous inference takes more steps to diffuse the
whole sequence, which means it is SLOWER than synchronous mode. In our
experiments, asynchronous inference may improve instruction following and visual consistency.
"""

I2V with Diffusion Forcing (OLD)

prompt="A penguin dances." diffusers integration
#!pip uninstall diffusers -yq
#!pip install git+https://github.com/tolgacangoz/diffusers.git@skyreels-v2 ftfy -q
import torch
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

vae = AutoencoderKLWan.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			subfolder="vae",
			torch_dtype=torch.float32)
pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			vae=vae,
			torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
#pipe.transformer.set_ar_attention(causal_block_size=5)

image = load_image("Penguin from https://huggingface.co/tasks/image-to-video")  # placeholder: the penguin example image from the HF image-to-video task page
prompt = "A penguin dances."

output = pipe(
    image=image,
    prompt=prompt,
    num_inference_steps=50,
    height=544,
    width=960,
    num_frames=97,
    #ar_step=5,  # Controls asynchronous inference (0 for synchronous mode)
    generator=torch.Generator(device="cpu").manual_seed(0),
    overlap_history=None,  # Number of frames to overlap for smooth transitions in long videos; 17 for long
    addnoise_condition=20,  # Improves consistency in long video generation
).frames[0]
export_to_video(output, "I2V.mp4", fps=24, quality=8)

"""
When I set `ar_step=5` and `causal_block_size=5`, the results seem really bad.
"""

FLF2V with Diffusion Forcing (OLD)

Now, Houston, we have a problem. I have been unable to produce good results for this task. I tried many hyperparameter combinations with the original code. The first frame's latent (torch.Size([1, 16, 1, 68, 120])) is overwritten onto the first of the 25 frame latents in latents (torch.Size([1, 16, 25, 68, 120])). Then the last frame's latent is concatenated, so latents becomes torch.Size([1, 16, 26, 68, 120]). After the denoising process, the extra last-frame latent is discarded and the remaining latents are decoded by the VAE. I also tried not concatenating the last frame but instead overwriting it onto the last of the 25 frame latents and not discarding any frame latent at the end, but still got bad results.
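A minimal sketch of that latent layout (shapes only, with made-up tensor values; not the actual pipeline code):

import torch

# Stand-ins for the encoded first/last frame latents and the 25 video frame latents.
first_frame_latent = torch.randn(1, 16, 1, 68, 120)
last_frame_latent = torch.randn(1, 16, 1, 68, 120)
latents = torch.randn(1, 16, 25, 68, 120)

# The first frame's latent overwrites the first of the 25 frame latents...
latents[:, :, :1] = first_frame_latent
# ...and the last frame's latent is concatenated along the frame dimension.
latents = torch.cat([latents, last_frame_latent], dim=2)
print(latents.shape)  # torch.Size([1, 16, 26, 68, 120])

# After denoising, the extra last-frame latent is discarded before VAE decoding.
latents = latents[:, :, :-1]

Here are some results: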

(input images: first frame and last frame)
#!pip uninstall diffusers -yq
#!pip install git+https://github.com/tolgacangoz/diffusers.git@skyreels-v2 ftfy -q
import torch
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

vae = AutoencoderKLWan.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			subfolder="vae",
			torch_dtype=torch.float32)
pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			vae=vae,
			torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
#pipe.transformer.set_ar_attention(causal_block_size=5)

prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")

output = pipe(
    image=first_frame,
    last_image=last_frame,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    height=544,
    width=960,
    num_frames=97,
    #ar_step=5,  # Controls asynchronous inference (0 for synchronous mode)
    generator=torch.Generator(device="cpu").manual_seed(0),
    overlap_history=None,  # Number of frames to overlap for smooth transitions in long videos; 17 for long
    addnoise_condition=20,  # Improves consistency in long video generation
).frames[0]
export_to_video(output, "FLF2V.mp4", fps=24, quality=8)

V2V with Diffusion Forcing (OLD)

This pipeline extends a given video.

(videos: input video and diffusers integration output)
#!pip uninstall diffusers -yq
#!pip install git+https://github.com/tolgacangoz/diffusers.git@skyreels-v2 ftfy -q
import torch
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPipeline
from diffusers.utils import export_to_video, load_video

vae = AutoencoderKLWan.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			subfolder="vae",
			torch_dtype=torch.float32)
pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
			"tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers",
			vae=vae,
			torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
#pipe.transformer.set_ar_attention(causal_block_size=5)

prompt = "CG animation style, a small blue bird flaps its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its continuing flight and the vastness of the sky from a close-up, low-angle perspective."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = load_video("Input video.mp4")  # local path to the video to extend

output = pipe(
    video=video,
    prompt=prompt,
    num_inference_steps=50,
    height=544,
    width=960,
    num_frames=120,
    base_num_frames=97,
    ar_step=0,  # Controls asynchronous inference (0 for synchronous mode)
    generator=torch.Generator(device="cpu").manual_seed(0),
    overlap_history=17,  # Number of frames to overlap for smooth transitions in long videos
    addnoise_condition=20,  # Improves consistency in long video generation
).frames[0]
export_to_video(output, "V2V.mp4", fps=24, quality=8)

Firstly, I want to congratulate you on this great work, and thanks for open-sourcing it, SkyReels Team! This PR proposes an integration of your model.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@yiyixuxu @a-r-r-o-w @linoytsaban @yjp999 @Howe2018 @RoseRollZhu @pftq @Langdx @guibinchen @qiudi0127 @nitinmukesh @tin2tin @ukaprch @okaris

tolgacangoz avatar May 07 '25 18:05 tolgacangoz

It's about time. Thanks.

ukaprch avatar May 08 '25 15:05 ukaprch

Hi @yiyixuxu @a-r-r-o-w

Mid-PR questions:

  1. The issue was labelled as "contributions-welcome" but not as "community-examples". Also, the number of stars for this model has surpassed that of SkyReels-V1. Thus, I placed these pipelines in src/diffusers/pipelines/skyreels_v2/. Should I move these pipelines to examples/? Should I also split this PR into one per pipeline (group)?

  2. Just like SkyReels-V1 was based on HunyuanVideo, SkyReels-V2 is based on Wan, but some differences exist. I thought of moving the differences up to the parent abstraction, i.e. the pipeline code, so that we can use WanTransformer3DModel for both, but it didn't seem appropriate enough to me at first. But then, if we introduce the Diffusion Forcing and autoregressive properties into WanTransformer3DModel (as natively as possible, not with the exact diff below), it seems possible to me. You can examine the current diff between transformer_wan.py and transformer_skyreels_v2.py: https://www.diffchecker.com/U72HJ6ox/ WDYT?

  3. Since SkyReels-V2 is not a completely new architecture, should I move its pipelines into src/diffusers/pipelines/wan/, similar to HunyuanSkyreelsImageToVideoPipeline, if SkyReels-V2 is seen as an official pipeline?

  4. I am removing TeaCache-related code because it is planned as a modular extension, right? If this PR is required to move to examples/, then there is no need to remove it, I think.

  5. I came across this: https://github.com/huggingface/diffusers/blob/01abfc873659e29a8d002f20782fa5b5e6d03f9c/src/diffusers/models/embeddings.py#L1153 At first, [: (dim // 2)] confused me :S Isn't it redundant? dim was already confirmed to be even with assert dim % 2 == 0 (see the quick check below). Can I remove [: (dim // 2)] in a separate PR?
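
A quick check of that redundancy (sketch; `dim` here is just an example value):

import torch

dim = 1536  # the code already asserts dim % 2 == 0
idx = torch.arange(0, dim, 2)
print(idx.shape[0] == dim // 2)             # True: arange already has dim // 2 elements
print(torch.equal(idx, idx[: (dim // 2)]))  # True: the slice is a no-op for even dim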

tolgacangoz avatar May 14 '25 15:05 tolgacangoz

@tolgacangoz Thanks for working on this, really cool work so far!

  1. I think we should add SkyReels models to core diffusers, so src/ is fine.

2 and 3. I think in this case we should have separate implementations of SkyReelsV2 and Wan due to the autoregressive nature of the former. Adding any extra code to Wan might complicate it for readers. Will let @yiyixuxu comment on this though.

  4. Yeah, let's remove the cache code. We'll try to write a model-agnostic implementation in the future once more of the cache-related code is stabilized, after adding some of the easier methods that are not too model-intrusive (such as first block cache).

  5. You're right, it's redundant. Let's remove it in a separate PR.

a-r-r-o-w avatar May 15 '25 10:05 a-r-r-o-w

FWIW, I have been successful in using the same T5 encoder for WAN 2.1 for this model just by fiddling with their pipeline:

# Assumed imports for this snippet:
# from optimum.quanto import QuantizedTransformersModel
# from transformers import UMT5EncoderModel

print('Quantize text_encoder qint8')

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = UMT5EncoderModel
    auto_class.from_config = auto_class._from_config

text_encoder = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "./wan quantro T2V-14B-720P Diffusers/basemodel/t5encodermodel_qint8"
).to(dtype=dtype)

# Text2VideoPipeline here is the pipeline class from the original SkyReels-V2 repo.
pipe = Text2VideoPipeline(
    model_path=model_path,
    transformer=transformer,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    weight_dtype=dtype)

Then, in the pipeline's __init__, I incorporate my bitsandbytes NF4 transformer, their tokenizer, and the Wan-based T5 encoder:

def __init__(
    self, model_path, transformer, text_encoder, tokenizer, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False,
):
    self.transformer = transformer          #get_transformer(model_path, 'cpu', weight_dtype)
    vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
    self.vae = get_vae(vae_model_path, 'cpu', weight_dtype=torch.float32)
    if text_encoder is not None:
        self.text_encoder = text_encoder        #get_text_encoder(model_path, 'cuda', weight_dtype)
    if tokenizer is not None: 
        self.tokenizer = tokenizer

I need to add this function to the pipeline for the T5 encoder to work:

def encode(self, texts):
    ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
    ids = ids.to(self.device)
    mask = mask.to(self.device)
    context = self.text_encoder(ids, mask)
    #seq_lens = mask.gt(0).sum(dim=1).long()
    context = context.last_hidden_state * mask.unsqueeze(-1)
    return context

ukaprch avatar May 16 '25 15:05 ukaprch

That seems appropriate to me; only the Diffusion Forcing pipelines are different for the large models. How are the results with your setup?

tolgacangoz avatar May 19 '25 08:05 tolgacangoz

Hi @yiyixuxu @a-r-r-o-w and SkyReels Team @yjp999 @pftq @Langdx @guibinchen ...

This PR will be ready for review for SkyReelsV2Transformer3DModel and SkyReelsV2DiffusionForcingPipeline soon. Other pipelines will follow quickly after initial feedback...

tolgacangoz avatar May 23 '25 11:05 tolgacangoz

@tolgacangoz Awesome work so far, just checking in on the progress. Still trying to fully wrap my head around diffusion forcing and trying to visually verify that the diffusers code matches the original. As a sanity check, do we know why the T2V output from the original code vs. diffusers is different? Typically, we try to ensure that, given the same starting embeddings, seed and other starting conditions, the output from different implementations matches numerically with a threshold < 1e-3. I will try to help with debugging and testing :hugs:

a-r-r-o-w avatar Jun 02 '25 05:06 a-r-r-o-w

The original code had a default negative prompt (which I wasn't aware of at the time); this might have been the reason, or it could be something about timestep processing. I will try to make the two match as deterministically as possible.

tolgacangoz avatar Jun 02 '25 06:06 tolgacangoz

Today I mostly followed along by reading the code, but tomorrow I will be more systematic: I will go through the inputs and outputs of each module step by step.

tolgacangoz avatar Jun 02 '25 16:06 tolgacangoz

I have finally discovered the reasons for the discrepancy and will share them tomorrow :partying_face:!

tolgacangoz avatar Jun 04 '25 17:06 tolgacangoz

Hi @a-r-r-o-w. Initially, I didn't consider that the integration should match with an error smaller than 1e-3 because of numerical errors, etc. But when you mentioned the ideal condition for an integration with a proper quantitative metric, I examined the whole PR from scratch. I found two things:

  1. self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) should have been self.timesteps_proj = ..embeddings.get_1d_sincos_pos_embed_from_grid. My mistake :sweat_smile:.

  2. This one is interesting. I debugged one abstraction level at a time until I spotted the first point where torch.equal(diffusers_output_at_that_point, skyreelsv2_output_at_that_point) gave False, which was norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) in diffusers and out = mul_add_add_compile(self.norm1(x), e[1], e[0]) in the original code. As far as I knew/imagined/expected, a compiled function is just a tritonized/faster equivalent. Therefore, I didn't remove the compiled code from the original repository while comparing. The compiled one produces different results:

import torch

def mul_add_add(x, y, z):
    return x.float() * (1 + y) + z

mul_add_add_compile = torch.compile(mul_add_add, dynamic=True)

x = torch.randn([1, 51000, 1536], device='cuda')
y = torch.randn([1, 51000, 1536], device='cuda')
z = torch.randn([1, 51000, 1536], device='cuda')
torch.allclose(mul_add_add(x, y, z), mul_add_add_compile(x, y, z))
# T4 and L40S give `False`

Normally, this would be considered a bug, right? But in this case, if the model was also trained with the compiled function, this wouldn't be a bug but a feature, right :open_mouth:! The model would have seen the compiled function's outputs. So, are we supposed to use the compiled versions in diffusers? If not, the results are visibly different.

I am now able to produce similar videos:

Skywork/SkyReels-V2-DF-1.3B-540P
(comparison videos: original repo vs. diffusers integration)

I modified the original code to make the two comparable: https://github.com/tolgacangoz/SkyReels-V2/tree/comparable-skyreelsv2 Is my understanding of the comparison proper? torch.allclose still gives False for the final latents. I used float32 for everything except the attention inputs. The original code's text_encoder was bfloat16 in its repo; thus, converting it to float32 might have created a discrepancy. Moreover, there might be another source of discrepancy after the first attention's norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)...
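
For concreteness, the final-latent check I mean is roughly this (a hedged sketch with placeholder tensors; the real comparison runs both full pipelines in float32 with identical prompts and seeds):

import torch

# Placeholders for the final denoised latents from the two implementations.
diffusers_latents = torch.randn(1, 16, 25, 68, 120)
original_latents = diffusers_latents + 1e-4 * torch.randn_like(diffusers_latents)

max_abs_diff = (diffusers_latents - original_latents).abs().max().item()
print("max abs diff:", max_abs_diff)
print("allclose(atol=1e-3):", torch.allclose(diffusers_latents, original_latents, atol=1e-3))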

Since my GPU resources are depleted for now, I won't be able to produce example videos and compare the two implementations via their final latents. After taking a look at the other TODOs, I will open this PR for review.

I am merging scripts/convert_skyreelsv2_to_diffusers.py with scripts/convert_wan_to_diffusers.py, right?

tolgacangoz avatar Jun 05 '25 13:06 tolgacangoz

@tolgacangoz Wow, that's awesome progress!

To address the points about precision, it is completely normal to have a difference of up to 1e-2 or sometimes higher when using 16-bit precision. It is actually very input-dependent. Even just changing the order of operations in the modeling code very slightly may yield wildly different absmax differences. This is because floating point operations are not associative:

(image: floating point addition is not associative, e.g. (a + b) + c may differ from a + (b + c))
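
A tiny demonstration of that non-associativity in plain Python:

a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)                 # 0.6000000000000001
print(a + (b + c))                 # 0.6
print((a + b) + c == a + (b + c))  # False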

When we do integrations, we try to run a smaller dummy version of the model in float32 precision. Whenever there is a mismatch, we try to match the internal layers in the same way you did, one abstraction level at a time, so really great work there. As long as the latent-space difference is < 1e-3 in float32, the model implementations can be considered the same. Small differences are amplified in lower precisions, so it's completely understandable even if we don't get a full match in bf16/fp16.

Regarding torch.compile producing different results: it is expected. The programs in eager mode vs. the tritonized version may have a different order of operations, certain options may have been enabled to get a higher speedup in compiled code at the cost of some precision (an example is using bf16 accumulation in matmuls instead of fp32, which is a lot faster), or different elementwise operations may be fused together (so a completely different order, which is not associative), etc. The differences are usually on the order of 1e-2 to 1e-5 or lower. My personal method of doing integrations sometimes includes rewriting parts of the original codebase to remove any kind of "clever" optimizations that are not relevant to the integration itself.

In this case, after all the recent updates, I think we have a perfect match between the implementations, and the subtle differences can be attributed to using bf16.

I am merging scripts/convert_skyreelsv2_to_diffusers.py with scripts/convert_wan_to_diffusers.py, right?

It's completely alright to have a separate one (I would personally prefer it to be separate). Conversion scripts are more or less just reference code that is not actively maintained, so any number of files that help with separating responsibilities should be okay.

a-r-r-o-w avatar Jun 05 '25 15:06 a-r-r-o-w

Alright, thanks for the information! I was aware of non-associativity, but hadn't fully grasped the extent to which compilation changes the computation :+1:.

tolgacangoz avatar Jun 06 '25 10:06 tolgacangoz

Thank you @tolgacangoz @a-r-r-o-w. Could you take a look, please?

DN6 avatar Jun 09 '25 03:06 DN6

Hi @nitinmukesh @tin2tin. You can test and review this PR just as you have done for other PRs, if you want.

tolgacangoz avatar Jun 10 '25 07:06 tolgacangoz

Thank you @tolgacangoz for making the feature available in diffusers.

I will test it now.

nitinmukesh avatar Jun 10 '25 08:06 nitinmukesh

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Thanks so much for the invaluable reviews @a-r-r-o-w, @yiyixuxu! I am continuing to work on your comments... Btw, could you examine FLF2V with Diffusion Forcing (OLD) part in the first comment of the PR?

tolgacangoz avatar Jun 18 '25 07:06 tolgacangoz

Thanks so much for the invaluable reviews @a-r-r-o-w, @yiyixuxu! I am continuing to work on your comments... Btw, could you examine FLF2V with Diffusion Forcing (OLD) part in the first comment of the PR?

Similarly, FLF2V with SkyReelsV2ImageToVideoPipeline, which I recently added support (?) for, yields very poor results.


tolgacangoz avatar Jun 22 '25 08:06 tolgacangoz

cc @yiyixuxu could you take a final look + scheduler?

a-r-r-o-w avatar Jun 30 '25 09:06 a-r-r-o-w

Btw, the current usage of the scheduler (removing shift from the pipeline's call) produces slightly different timesteps, and thus slightly different videos. I was investigating this but haven't reached a conclusion; I will take another look at it.

Also, before merging, should these two messages be investigated?

tolgacangoz avatar Jul 02 '25 10:07 tolgacangoz

@tolgacangoz I think it'll be good to decouple the FLF2V into a separate PR if the results are not good. I'm afraid I don't have the time to help in investigating the cause here right now, and this PR has been open for a really long time already and anticipated to be in master by many. Let's try to merge the ones that work for now :)

a-r-r-o-w avatar Jul 02 '25 11:07 a-r-r-o-w

Or, I think they can stay as a placeholder or potential feature, because even with the original code I could not produce good results for FLF2V with the 1.3B models. Or maybe it was me who couldn't run this task properly, idk :S. Maybe it is OK with larger models. I think this PR does its job as an integration.

Edit: I opened an issue in the original repo about this. I forgot to open it earlier, sry :smiling_face_with_tear:.

tolgacangoz avatar Jul 02 '25 12:07 tolgacangoz

@tolgacangoz are you able to refactor the current FlowMatchUniPCMultistepScheduler instead of adding a new one?

yiyixuxu avatar Jul 03 '25 01:07 yiyixuxu

This will be my 3rd pipeline contribution, yay :partying_face:!

tolgacangoz avatar Jul 04 '25 05:07 tolgacangoz

Right.

tolgacangoz avatar Jul 06 '25 06:07 tolgacangoz

hi @tolgacangoz

can you send a PR to the official repo for the weights? I think they have created placeholders for all the checkpoints, e.g.

https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers

yiyixuxu avatar Jul 08 '25 00:07 yiyixuxu

I thought they were supposed to do this themselves by examining/verifying the conversion script, etc., since we are talking about the official repository. Sorry for the misunderstanding; working on it.

tolgacangoz avatar Jul 08 '25 06:07 tolgacangoz

Or, I think they can stay as a placeholder or potential feature, because even with the original code I could not produce good results for FLF2V with the 1.3B models. Or maybe it was me who couldn't run this task properly, idk :S. Maybe it is OK with larger models. I think this PR does its job as an integration.

Edit: I opened an issue in the original repo about this. I forgot to open it earlier, sry 🥲.

They say to try the 14B models for FLF2V, so this issue (?) is not relevant to this PR, IMO.

tolgacangoz avatar Jul 10 '25 13:07 tolgacangoz

@tolgacangoz can you point me to where this conversation is so I can get some context? 😛

They say to try the 14B models for FLF2V, so this issue (?) is not relevant to this PR, IMO.

yiyixuxu avatar Jul 14 '25 17:07 yiyixuxu