
Fix DeepSpeed model preparation logic in Trainer class #43780

Merged: SunMarc merged 4 commits into main from fix-deepspeed on Feb 6, 2026
Conversation

@qgallouedec (Member)
The changes in #43711 caused the model to never be prepared when using DeepSpeed. When training, you hit, for example:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/fsx/qgallouedec/trl/trl/scripts/grpo.py", line 193, in <module>
[rank0]:     main(script_args, training_args, model_args, dataset_args)
[rank0]:     ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/trl/trl/scripts/grpo.py", line 162, in main
[rank0]:     trainer.train()
[rank0]:     ~~~~~~~~~~~~~^^
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 2170, in train
[rank0]:     return inner_training_loop(
[rank0]:         args=args,
[rank0]:     ...<2 lines>...
[rank0]:         ignore_keys_for_eval=ignore_keys_for_eval,
[rank0]:     )
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 2537, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/fsx/qgallouedec/trl/trl/trainer/grpo_trainer.py", line 1024, in training_step
[rank0]:     output = super().training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 3804, in training_step
[rank0]:     inputs = self._prepare_inputs(inputs)
[rank0]:   File "/fsx/qgallouedec/trl/trl/extras/profiling.py", line 202, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:   File "/fsx/qgallouedec/trl/trl/trainer/grpo_trainer.py", line 1053, in _prepare_inputs
[rank0]:     generation_batch = self._generate_and_score_completions(generation_batch)
[rank0]:   File "/fsx/qgallouedec/trl/trl/trainer/grpo_trainer.py", line 1551, in _generate_and_score_completions
[rank0]:     ) = self._generate(prompts)
[rank0]:         ~~~~~~~~~~~~~~^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/trl/trl/trainer/grpo_trainer.py", line 1431, in _generate
[rank0]:     prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)
[rank0]:                                                          ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/trl/trl/trainer/grpo_trainer.py", line 1237, in _generate_single_turn
[rank0]:     prompt_completion_ids = unwrapped_model.generate(
[rank0]:         **generate_inputs, generation_config=self.generation_config, disable_compile=True
[rank0]:     )
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/generation/utils.py", line 2638, in generate
[rank0]:     result = decoding_method(
[rank0]:         self,
[rank0]:     ...<5 lines>...
[rank0]:         **model_kwargs,
[rank0]:     )
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/generation/utils.py", line 2833, in _sample
[rank0]:     outputs = self._prefill(input_ids, generation_config, model_kwargs)
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/generation/utils.py", line 3822, in _prefill
[rank0]:     return self(**model_inputs, return_dict=True)
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/utils/generic.py", line 834, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/models/qwen2/modeling_qwen2.py", line 475, in forward
[rank0]:     outputs: BaseModelOutputWithPast = self.model(
[rank0]:                                        ~~~~~~~~~~^
[rank0]:         input_ids=input_ids,
[rank0]:         ^^^^^^^^^^^^^^^^^^^^
[rank0]:     ...<6 lines>...
[rank0]:         **kwargs,
[rank0]:         ^^^^^^^^^
[rank0]:     )
[rank0]:     ^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/utils/generic.py", line 1001, in wrapper
[rank0]:     outputs = func(self, *args, **kwargs)
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/models/qwen2/modeling_qwen2.py", line 373, in forward
[rank0]:     inputs_embeds = self.embed_tokens(input_ids)
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
[rank0]:     return inner()
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1830, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/modules/sparse.py", line 191, in forward
[rank0]:     return F.embedding(
[rank0]:            ~~~~~~~~~~~^
[rank0]:         input,
[rank0]:         ^^^^^^
[rank0]:     ...<5 lines>...
[rank0]:         self.sparse,
[rank0]:         ^^^^^^^^^^^^
[rank0]:     )
[rank0]:     ^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.13/site-packages/torch/nn/functional.py", line 2567, in embedding
[rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank0]:            ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Expected all tensors to be on the same device, but got index is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA__index_select)

This PR fixes it.
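To make the failure mode concrete, here is a toy, torch-free simulation of why skipping model preparation under DeepSpeed blows up. All names below (`forward_embedding`, `train_step`) are hypothetical and for illustration only; the real check happens inside `F.embedding`. Preparation (via `accelerator.prepare()`) is what wraps the model in a DeepSpeed engine and moves its weights to the GPU, so when it is skipped, the embedding weights stay on CPU while `input_ids` are on `cuda:0`:

```python
def forward_embedding(weight_device: str, input_device: str) -> str:
    """Mimic torch's device check in F.embedding: all tensors must share a device."""
    if weight_device != input_device:
        raise RuntimeError(
            f"Expected all tensors to be on the same device, but got index is on "
            f"{input_device}, different from other tensors on {weight_device}"
        )
    return input_device


def train_step(prepare_model: bool) -> str:
    """Toy training step: inputs are always moved to the GPU; weights only move
    if the model went through preparation."""
    model_device = "cuda:0" if prepare_model else "cpu"  # prepare() relocates weights
    input_device = "cuda:0"  # _prepare_inputs() always moves inputs to the device
    return forward_embedding(model_device, input_device)
```

With `prepare_model=False` this raises the same `RuntimeError` shape as the traceback above, which is why the preparation step must never be skipped when DeepSpeed is enabled.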

@qgallouedec (Member, Author)

@ArthurZucker @Cyrilvallez can we have this in the next patch release, please (if any)?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@albertvillanova (Member) left a comment:

Thanks for the fix. Just a question below.

if self.is_deepspeed_enabled:
from accelerate.utils import DummyScheduler

if isinstance(self.lr_scheduler, DummyScheduler):
Member:

Not sure why we need this condition first.

Member:

This is a special case where the lr_scheduler is created by DeepSpeed itself rather than by us or by the user.
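A short sketch of the case being discussed. When the DeepSpeed config defines the LR scheduler, Accelerate hands the Trainer a `DummyScheduler` placeholder (a real class in `accelerate.utils`); the actual scheduler only exists after the DeepSpeed engine is initialized. The stand-in class and `resolve_scheduler` helper below are hypothetical, simplified illustrations of that guard, not the actual Trainer code:

```python
class DummyScheduler:
    """Stand-in for accelerate.utils.DummyScheduler: the placeholder Accelerate
    creates when the DeepSpeed config owns the scheduler definition."""


def resolve_scheduler(lr_scheduler, is_deepspeed_enabled, engine_scheduler=None):
    """Hypothetical helper: pick the scheduler that should actually step."""
    if is_deepspeed_enabled and isinstance(lr_scheduler, DummyScheduler):
        # The real scheduler is created inside DeepSpeed's engine setup, so the
        # placeholder must be swapped for the engine-managed scheduler.
        return engine_scheduler
    # Otherwise the scheduler was created by us / the user and is used as-is.
    return lr_scheduler
```

This is why the `isinstance(self.lr_scheduler, DummyScheduler)` check comes first: only the placeholder case needs the DeepSpeed-specific handling.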

@ArthurZucker (Collaborator) left a comment:

cc @SunMarc, do you know whether our tests just did not catch this, or whether they are slow tests?

@SunMarc (Member) left a comment:

Sorry for the regression! The changes make sense.

@SunMarc (Member) commented Feb 6, 2026:

> cc @SunMarc do you know if our tests just did not catch or they are slow?

Slow tests, but we definitely need to work on improving our tests. Coming soon!

@SunMarc SunMarc enabled auto-merge (squash) February 6, 2026 19:08
@SunMarc SunMarc merged commit f70ec1e into main Feb 6, 2026
26 checks passed
@SunMarc SunMarc deleted the fix-deepspeed branch February 6, 2026 19:20
jiosephlee pushed a commit to jiosephlee/transformers_latest that referenced this pull request Feb 11, 2026
…3780)

Fix deepspeed model preparation logic in Trainer class

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>