
Adding Share Expert Fusion for DeepSeek

Open DiegoD94 opened this issue 1 year ago • 18 comments

This PR adds a feature that fuses the shared expert into the MoE module for DeepSeek V2/V3 models. Design doc: https://docs.google.com/document/d/1iXgzR6Mt6s0DpT7w2Pz93ExlUJ-nnSnU_o9Sqd8TE34/edit?tab=t.0

TL;DR:

Enabling this feature yields up to a 34% lossless ITL improvement and a 12% lossless TTFT improvement. Benchmarked with FlashAttention-3, with MLA disabled, and without the V1 engine.

vllm: 0.8.3,
LLMPerf: https://github.com/ray-project/llmperf, 
LM-Eval-Harness: https://github.com/EleutherAI/lm-evaluation-harness 0.4.0

Command to enable this feature:

export VLLM_USE_V1=0; export VLLM_MLA_DISABLE=1; export VLLM_ENABLE_SHARE_EXPERT_FUSION=1;python3 -m vllm.entrypoints.openai.api_server --model /home/ubuntu/workspace/ckpt/DeepSeek-R1 --tensor-parallel-size 8 --trust-remote-code --max-model-len 16384

Details:

Implementation:

  1. Currently we use an environment variable (VLLM_ENABLE_SHARE_EXPERT_FUSION) as the toggle to enable/disable the feature. Later it can be promoted to an engine argument.
  2. When the feature is enabled, we internally set the expert count to 264 and top-k to 9 for the FusedMoE module, and clone the shared expert's weights into experts No. 256 through No. 263 during the weight-loading stage. CUDA memory usage increases slightly (about 1-2 GB per GPU, assuming TP=8).
  3. The 9th routed expert ID and its routing weight are assigned manually so that the experts stay balanced and accuracy is preserved.
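The steps above can be sketched as follows. This is an illustrative toy version, not the PR's actual code: names and shapes are assumptions based on DeepSeek-V3/R1 (256 routed experts, top-8 routing, 8 shared-expert replicas appended as experts 256..263 so every token also runs one replica).

```python
import numpy as np

NUM_ROUTED = 256    # routed experts in DeepSeek-V3/R1
NUM_REPLICAS = 8    # shared-expert copies appended to the expert table

def build_fused_expert_table(routed_w: np.ndarray, shared_w: np.ndarray) -> np.ndarray:
    """Stack routed expert weights with replicas of the shared expert.

    routed_w: (256, ...) per-expert weight tensor
    shared_w: (...) the single shared expert's weight tensor
    Returns a (264, ...) table usable by a fused-MoE kernel.
    """
    replicas = np.broadcast_to(shared_w, (NUM_REPLICAS,) + shared_w.shape)
    return np.concatenate([routed_w, replicas], axis=0)

def append_shared_expert(topk_ids: np.ndarray, topk_weights: np.ndarray,
                         shared_scale: float = 1.0):
    """Extend top-8 routing to top-9 by force-assigning one replica per token.

    Tokens are spread round-robin over the replicas (token t -> expert
    256 + t % 8) so the extra load stays balanced across expert slots.
    """
    num_tokens = topk_ids.shape[0]
    forced = NUM_ROUTED + (np.arange(num_tokens) % NUM_REPLICAS)
    ids = np.concatenate([topk_ids, forced[:, None]], axis=1)
    # The shared expert's routing weight is a fixed constant rather than a
    # learned gate value; the exact value depends on the model's routing
    # scaling factor (the PR assigns it manually).
    w = np.concatenate(
        [topk_weights, np.full((num_tokens, 1), shared_scale)], axis=1)
    return ids, w
```

Because the replica assignment is deterministic round-robin over token index, each replica slot sees an equal share of tokens regardless of the router's output.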

Result:

Kernel level profile in the doc above

Latency:

The 117% TTFT improvement (batch size 8, context length 512) is likely due to a stuck prefill in the baseline run, so I excluded it from the conclusions.

All latencies in seconds. "Baseline" is vLLM 0.8.3; "PR" is this PR.

| BatchSize | ContextLen | Baseline ITL_p50 | Baseline ITL_p90 | Baseline TTFT_p50 | Baseline TTFT_p90 | PR ITL_p50 | PR ITL_p90 | PR TTFT_p50 | PR TTFT_p90 | ITL Impr. | TTFT Impr. |
|-----------|------------|------------------|------------------|-------------------|-------------------|------------|------------|-------------|-------------|-----------|------------|
| 1         | 512        | 0.019   | 0.020   | 0.188    | 0.892    | 0.017   | 0.018   | 0.173    | 0.290    | 9.63%  | 9.01%  |
| 1         | 2048       | 0.020   | 0.020   | 0.197    | 0.621    | 0.018   | 0.018   | 0.178    | 0.213    | 12.92% | 10.38% |
| 1         | 8192       | 0.021   | 0.021   | 0.406    | 0.589    | 0.018   | 0.018   | 0.396    | 0.757    | 12.55% | 2.64%  |
| 2         | 512        | 0.022   | 0.022   | 0.201    | 0.214    | 0.018   | 0.019   | 0.178    | 0.288    | 22.97% | 12.65% |
| 2         | 2048       | 0.023   | 0.023   | 0.367    | 0.378    | 0.019   | 0.019   | 0.342    | 0.350    | 20.31% | 7.07%  |
| 2         | 8192       | 0.025   | 0.027   | 0.732    | 0.765    | 0.021   | 0.024   | 0.704    | 0.739    | 18.83% | 4.03%  |
| 4         | 512        | 0.021   | 0.022   | 0.359    | 0.362    | 0.020   | 0.020   | 0.329    | 0.344    | 6.47%  | 8.93%  |
| 4         | 2048       | 0.022   | 0.023   | 0.445    | 0.468    | 0.021   | 0.023   | 0.432    | 0.449    | 3.21%  | 2.87%  |
| 4         | 8192       | 0.027   | 0.042   | 1.425    | 1.526    | 0.027   | 0.039   | 1.316    | 1.369    | 1.87%  | 8.28%  |
| 8         | 512        | 0.023   | 0.024   | 2.364    | 2.520    | 0.022   | 0.023   | 1.088    | 1.844    | 5.31%  | 117.30%|
| 8         | 2048       | 0.025   | 0.027   | 0.753    | 0.785    | 0.025   | 0.026   | 0.715    | 0.738    | 1.24%  | 5.26%  |
| 8         | 8192       | 0.048   | 0.080   | 2.331    | 2.846    | 0.036   | 0.067   | 2.232    | 2.439    | 34.61% | 4.41%  |

Accuracy: GSM8K looks good

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9591|±  |0.0055|
|     |       |strict-match    |     5|exact_match|↑  |0.9583|±  |0.0055|

DiegoD94 avatar Mar 25 '25 21:03 DiegoD94

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run the remaining CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Mar 25 '25 21:03 github-actions[bot]

This is cool, good find! Do you have any idea how much cloning + load-balancing is helping vs. not cloning the expert and instead just running an unbalanced fused-moe?

LucasWilkinson avatar Mar 25 '25 21:03 LucasWilkinson

This is cool, good find! Do you have any idea how much cloning + load-balancing is helping vs. not cloning the expert and instead just running an unbalanced fused-moe?

Actually, cloning the extra replicas neither noticeably improves nor hurts inference latency. I currently use 264 rather than another count (257, 258, etc.) because we may need EP later, especially with disaggregated prefill. I can also make the number of shared-expert copies configurable, by setting VLLM_ENABLE_FUSION_SHARE = 0/1/2/3/4/5/6/7/8 (0 = disabled, 1 = 1 copy, 8 = 8 copies, etc.)

DiegoD94 avatar Mar 26 '25 00:03 DiegoD94

This is cool, good find! Do you have any idea how much cloning + load-balancing is helping vs. not cloning the expert and instead just running an unbalanced fused-moe?

Actually, cloning the extra replicas neither noticeably improves nor hurts inference latency. I currently use 264 rather than another count (257, 258, etc.) because we may need EP later, especially with disaggregated prefill. I can also make the number of shared-expert copies configurable, by setting VLLM_ENABLE_FUSION_SHARE = 0/1/2/3/4/5/6/7/8 (0 = disabled, 1 = 1 copy, 8 = 8 copies, etc.)

Oh, if that's the case then we should default to 1 to avoid wasting memory that could be KV cache in the non-EP case (more KV cache directly benefits the offline-throughput case, i.e. infinite QPS).

LucasWilkinson avatar Mar 26 '25 04:03 LucasWilkinson
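A back-of-envelope sketch of the memory tradeoff being discussed here, using DeepSeek-V3 config values (hidden_size 7168, moe_intermediate_size 2048, 58 MoE layers) and assuming fp8 weights sharded across TP=8; treat the constants and totals as estimates, not measurements:

```python
# Rough weight-memory cost of shared-expert replicas. All constants are
# assumptions taken from the DeepSeek-V3 config (first 3 layers are dense,
# so 58 of 61 layers are MoE); use BYTES_PER_PARAM = 2 for bf16.
HIDDEN = 7168
INTERMEDIATE = 2048
MOE_LAYERS = 58
BYTES_PER_PARAM = 1   # fp8 checkpoint
TP = 8

def replica_bytes_per_gpu(num_replicas: int) -> int:
    # gate_proj + up_proj + down_proj of one shared expert
    params_per_expert = 3 * HIDDEN * INTERMEDIATE
    total = num_replicas * MOE_LAYERS * params_per_expert * BYTES_PER_PARAM
    return total // TP   # expert weights are sharded across TP ranks

for n in (1, 8):
    print(n, round(replica_bytes_per_gpu(n) / 2**30, 2), "GiB/GPU")
```

Under these assumptions one replica costs roughly 0.3 GiB per GPU and eight cost roughly 2.4 GiB, the same ballpark as the 1-2 GB noted in the PR description, which supports defaulting to a single replica in the non-EP case so the freed memory can go to KV cache.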

This is cool, good find! Do you have any idea how much cloning + load-balancing is helping vs. not cloning the expert and instead just running an unbalanced fused-moe?

Actually, cloning the extra replicas neither noticeably improves nor hurts inference latency. I currently use 264 rather than another count (257, 258, etc.) because we may need EP later, especially with disaggregated prefill. I can also make the number of shared-expert copies configurable, by setting VLLM_ENABLE_FUSION_SHARE = 0/1/2/3/4/5/6/7/8 (0 = disabled, 1 = 1 copy, 8 = 8 copies, etc.)

Oh, if that's the case then we should default to 1 to avoid wasting memory that could be KV cache in the non-EP case (more KV cache directly benefits the offline-throughput case, i.e. infinite QPS).

Thanks for the comments! In the next revision I'll keep the enable/disable option, default the replica count to 1, and make the number of replicas configurable.

DiegoD94 avatar Mar 26 '25 07:03 DiegoD94

@LucasWilkinson @tlrmchlsmth @robertgshaw2-redhat Hi all, I have pushed a second revision making the replica count configurable: set VLLM_ENABLE_SHARED_FUSION to 0 to disable the feature, or to 1-8 (or a higher number) to choose the number of replicas. I also added warnings for when both this feature and use_ep are enabled, and for when the replica count differs from the EP size.

DiegoD94 avatar Mar 27 '25 08:03 DiegoD94
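The toggle-plus-warnings behavior described above could look roughly like the following sketch; the names (VLLM_ENABLE_SHARED_FUSION, use_ep, ep_size) follow the comment and are illustrative, not the merged vLLM API:

```python
import logging
import os

logger = logging.getLogger(__name__)

def shared_fusion_replicas(use_ep: bool = False, ep_size: int = 1) -> int:
    """Return the configured shared-expert replica count (0 disables fusion)."""
    replicas = int(os.environ.get("VLLM_ENABLE_SHARED_FUSION", "0"))
    if replicas <= 0:
        return 0
    if use_ep:
        # Warn, as in the PR: the feature is primarily targeted at the
        # non-EP TP path.
        logger.warning("Shared-expert fusion together with EP is untested.")
        if replicas != ep_size:
            logger.warning(
                "Replica count (%d) != ep_size (%d); expert load may be "
                "unbalanced across EP ranks.", replicas, ep_size)
    return replicas
```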

Pushed a new commit addressing the current review comments; let me know if there are any blockers for merging this, thanks! @LucasWilkinson @tlrmchlsmth @maobaolong @robertgshaw2-redhat

DiegoD94 avatar Mar 31 '25 17:03 DiegoD94

curious what's the performance like without kernel config tuning?

sarckk avatar Mar 31 '25 23:03 sarckk

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @DiegoD94.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 01 '25 16:04 mergify[bot]

curious what's the performance like without kernel config tuning?

(Screenshot of benchmark results attached.) This is an internal experiment I benchmarked on vLLM 0.7.2, with and without the tuned config. I think we can include all tuned configs and report the numbers as a next step.

DiegoD94 avatar Apr 01 '25 18:04 DiegoD94

Rebased to catch up with mainline.

DiegoD94 avatar Apr 01 '25 19:04 DiegoD94

I think we should prioritize getting this in for deepseek and llama4, it seems to be a pretty clear win

mgoin avatar Apr 07 '25 00:04 mgoin

Running into an issue trying it on DeepSeekV2. @DiegoD94 could you take a look?

VLLM_USE_V1=0 VLLM_SHARED_EXPERT_FUSION_REPLICAS=1 vllm serve deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --tensor-parallel-size 2 --trust-remote-code --max-model-len 16384
Cloning 1 replicas of shared expert into MoE:   4%|███▎                                                                                      | 1/27 [00:00<00:00, 150.74it/s]
ERROR 04-07 13:35:48 [engine.py:448] 'model.layers.1.mlp.shared_experts.down_proj.weight_scale_inv'
ERROR 04-07 13:35:48 [engine.py:448] Traceback (most recent call last):
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/engine/multiprocessing/engine.py", line 436, in run_mp_engine
ERROR 04-07 13:35:48 [engine.py:448]     engine = MQLLMEngine.from_vllm_config(
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/engine/multiprocessing/engine.py", line 128, in from_vllm_config
ERROR 04-07 13:35:48 [engine.py:448]     return cls(
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/engine/multiprocessing/engine.py", line 82, in __init__
ERROR 04-07 13:35:48 [engine.py:448]     self.engine = LLMEngine(*args, **kwargs)
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/engine/llm_engine.py", line 281, in __init__
ERROR 04-07 13:35:48 [engine.py:448]     self.model_executor = executor_class(vllm_config=vllm_config, )
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/executor/executor_base.py", line 271, in __init__
ERROR 04-07 13:35:48 [engine.py:448]     super().__init__(*args, **kwargs)
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/executor/executor_base.py", line 52, in __init__
ERROR 04-07 13:35:48 [engine.py:448]     self._init_executor()
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/executor/mp_distributed_executor.py", line 125, in _init_executor
ERROR 04-07 13:35:48 [engine.py:448]     self._run_workers("load_model",
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/executor/mp_distributed_executor.py", line 185, in _run_workers
ERROR 04-07 13:35:48 [engine.py:448]     driver_worker_output = run_method(self.driver_worker, sent_method,
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/utils.py", line 2347, in run_method
ERROR 04-07 13:35:48 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/worker/worker.py", line 183, in load_model
ERROR 04-07 13:35:48 [engine.py:448]     self.model_runner.load_model()
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/worker/model_runner.py", line 1113, in load_model
ERROR 04-07 13:35:48 [engine.py:448]     self.model = get_model(vllm_config=self.vllm_config)
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/model_executor/model_loader/__init__.py", line 14, in get_model
ERROR 04-07 13:35:48 [engine.py:448]     return loader.load_model(vllm_config=vllm_config)
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/model_executor/model_loader/loader.py", line 444, in load_model
ERROR 04-07 13:35:48 [engine.py:448]     loaded_weights = model.load_weights(
ERROR 04-07 13:35:48 [engine.py:448]   File "/home/tms/vllm/vllm/model_executor/models/deepseek_v2.py", line 761, in load_weights
ERROR 04-07 13:35:48 [engine.py:448]     f".{suffix}", weights_dict[
ERROR 04-07 13:35:48 [engine.py:448] KeyError: 'model.layers.1.mlp.shared_experts.down_proj.weight_scale_inv'

tlrmchlsmth avatar Apr 07 '25 13:04 tlrmchlsmth

I think we should prioritize getting this in for deepseek and llama4, it seems to be a pretty clear win

Running into an issue trying it on DeepSeekV2. @DiegoD94 could you take a look?

VLLM_USE_V1=0 VLLM_SHARED_EXPERT_FUSION_REPLICAS=1 vllm serve deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --tensor-parallel-size 2 --trust-remote-code --max-model-len 16384
[full traceback quoted above, ending in KeyError: 'model.layers.1.mlp.shared_experts.down_proj.weight_scale_inv']

Yeah, this should be due to the different naming convention in the V2 vs. V3/R1 checkpoints. I'll take a look, but this PR essentially targets V3/R1 only.

DiegoD94 avatar Apr 07 '25 19:04 DiegoD94

Running into an issue trying it on DeepSeekV2. @DiegoD94 could you take a look?

VLLM_USE_V1=0 VLLM_SHARED_EXPERT_FUSION_REPLICAS=1 vllm serve deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --tensor-parallel-size 2 --trust-remote-code --max-model-len 16384
[full traceback quoted above, ending in KeyError: 'model.layers.1.mlp.shared_experts.down_proj.weight_scale_inv']

I've fixed this in a new commit; the previous one assumed DSV3/R1-only weight names, expert counts, and routing scaling.

DiegoD94 avatar Apr 08 '25 00:04 DiegoD94

Can you merge from main? It should fix a bunch of failed tests

DarkLight1337 avatar Apr 15 '25 06:04 DarkLight1337

Can you merge from main? It should fix a bunch of failed tests

Hi, thanks! I still have some failing tests, but the majority are timeout errors (all 4 kernel tests). Do you have any idea why?

DiegoD94 avatar Apr 16 '25 00:04 DiegoD94

@DiegoD94 Are you still actively maintaining this PR? Do you need help getting it merged into main?

tjtanaa avatar Jul 20 '25 03:07 tjtanaa

Closing as stale (there are also major conflicts which indicate that a new PR might be a better approach)

hmellor avatar Nov 12 '25 11:11 hmellor