[WIP] refactor trainer #9510

Open
Jintao-Huang wants to merge 1 commit into modelscope:main from Jintao-Huang:refactor_trainer_0608

Conversation

@Jintao-Huang

No description provided.

@gemini-code-assist (bot) left a comment

Code Review

This pull request implements a custom training loop in the Trainer class by overriding _inner_training_loop and adding helper methods: _update_auto_batch_size, set_initial_training_values, _init_training_state, _prepare_for_training, _run_epoch, _finalize_training, and _evaluate. The comments below flag several potential AttributeError and FileNotFoundError failures. Most concern backward compatibility with older versions of transformers, accelerate, and PyTorch; the rest concern dereferencing args or self.lr_scheduler when either is None.

Comment thread swift/trainers/trainer.py
Comment on lines +159 to +160
    # needed to calculate tokens/s
    self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen

high

self.state.num_input_tokens_seen will raise an AttributeError on older versions of transformers (< 4.43.0) where token tracking is not supported. We should guard it with hasattr to ensure backward compatibility.

Suggested change
-    # needed to calculate tokens/s
-    self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen
+    # needed to calculate tokens/s
+    if hasattr(self.state, 'num_input_tokens_seen'):
+        self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen
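
One consequence of the hasattr guard is that _initial_num_input_tokens_seen is never set when the attribute is missing, so any later read of it would need its own guard. A minimal sketch of an alternative, assuming a default of 0 is acceptable when token tracking is unavailable. The helper names and the SimpleNamespace objects are hypothetical stand-ins, not part of transformers or this PR.

```python
from types import SimpleNamespace


def initial_tokens_seen(state) -> int:
    """Return the token counter if the state tracks it, else 0."""
    return getattr(state, "num_input_tokens_seen", 0)


def tokens_per_second(state, initial_tokens: int, elapsed_s: float) -> float:
    """Throughput since the start of the run; 0.0 when tracking is unavailable."""
    seen = getattr(state, "num_input_tokens_seen", 0) - initial_tokens
    return seen / elapsed_s if elapsed_s > 0 else 0.0


# Hypothetical stand-ins for a trainer state with and without token tracking.
new_state = SimpleNamespace(num_input_tokens_seen=1_000)
old_state = SimpleNamespace()

start_new = initial_tokens_seen(new_state)
start_old = initial_tokens_seen(old_state)

# Simulate 4,000 more tokens being consumed during training.
new_state.num_input_tokens_seen += 4_000
```

With a default in place, the throughput computation takes the same code path on both kinds of state and simply reports zero when the counter does not exist.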

Comment thread swift/trainers/trainer.py
Comment on lines 100 to 101
    if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
        loss = loss / self.args.gradient_accumulation_steps

high

self.model_accepts_loss_kwargs will raise an AttributeError on older versions of transformers where this attribute is not defined on Trainer. We should use getattr(self, 'model_accepts_loss_kwargs', False) to ensure backward compatibility.

Suggested change
-    if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
-        loss = loss / self.args.gradient_accumulation_steps
+    if num_items_in_batch is not None and getattr(self, 'model_accepts_loss_kwargs', False):
+        loss = loss / self.args.gradient_accumulation_steps

Comment thread swift/trainers/trainer.py
Comment on lines +376 to +379
    if self.is_fsdp_enabled:
        # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA
        if hasattr(self.model, 'generate'):
            dist.fsdp.register_fsdp_forward_method(self.model, 'generate')

high

dist.fsdp.register_fsdp_forward_method is only available in PyTorch 2.2+. If the user is running an older version of PyTorch (e.g., 2.0 or 2.1), this will raise an AttributeError. We should guard this call with hasattr to ensure backward compatibility.

Suggested change
-    if self.is_fsdp_enabled:
-        # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA
-        if hasattr(self.model, 'generate'):
-            dist.fsdp.register_fsdp_forward_method(self.model, 'generate')
+    if self.is_fsdp_enabled:
+        # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA
+        if hasattr(self.model, 'generate') and hasattr(getattr(dist, 'fsdp', None), 'register_fsdp_forward_method'):
+            dist.fsdp.register_fsdp_forward_method(self.model, 'generate')
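
The nested hasattr(getattr(...)) check in the suggestion can be read as a two-level lookup: first the optional submodule, then the optional function on it. A minimal sketch of that pattern pulled into a helper, using SimpleNamespace objects instead of torch.distributed so it runs without PyTorch. The helper name maybe_register_forward_method and the stand-in objects are hypothetical.

```python
from types import SimpleNamespace


def maybe_register_forward_method(dist_module, model, method_name: str) -> bool:
    """Call register_fsdp_forward_method only when both the API and the method exist.

    Returns True if the registration call was made, False otherwise.
    """
    # getattr(None, name, None) is safe, so a missing submodule falls through to None.
    fsdp = getattr(dist_module, "fsdp", None)
    register = getattr(fsdp, "register_fsdp_forward_method", None)
    if register is None or not hasattr(model, method_name):
        return False
    register(model, method_name)
    return True


# Stand-ins: new_dist exposes the API, old_dist does not.
calls = []
new_dist = SimpleNamespace(
    fsdp=SimpleNamespace(
        register_fsdp_forward_method=lambda m, name: calls.append(name)))
old_dist = SimpleNamespace()
model = SimpleNamespace(generate=lambda: None)

registered_new = maybe_register_forward_method(new_dist, model, "generate")
registered_old = maybe_register_forward_method(old_dist, model, "generate")
```

Returning a boolean lets the caller log a warning when the registration is skipped, rather than silently continuing.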

Comment thread swift/trainers/trainer.py
Comment on lines +383 to +385
    pc = getattr(self.accelerator, 'parallelism_config', None)
    if pc is not None and pc.sp_backend == 'deepspeed' and pc.sp_enabled:
        train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)

high

self.accelerator.deepspeed_ulysses_dl_adapter will raise an AttributeError on older versions of accelerate where DeepSpeed Ulysses is not supported. We should guard it with hasattr to ensure backward compatibility.

Suggested change
-    pc = getattr(self.accelerator, 'parallelism_config', None)
-    if pc is not None and pc.sp_backend == 'deepspeed' and pc.sp_enabled:
-        train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)
+    pc = getattr(self.accelerator, 'parallelism_config', None)
+    if pc is not None and pc.sp_backend == 'deepspeed' and pc.sp_enabled:
+        if hasattr(self.accelerator, 'deepspeed_ulysses_dl_adapter'):
+            train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)

Comment thread swift/trainers/trainer.py
Comment on lines +514 to +520
    if not self.accelerator.optimizer_step_was_skipped:
        # Delay optimizer scheduling until metrics are generated
        lr_scheduler_cls = [torch.optim.lr_scheduler.ReduceLROnPlateau]
        if GreedyLR is not None:
            lr_scheduler_cls.append(GreedyLR)
        if not isinstance(self.lr_scheduler, tuple(lr_scheduler_cls)):
            self.lr_scheduler.step()

high

If self.lr_scheduler is None (e.g., no learning rate scheduler is configured), isinstance(None, tuple(lr_scheduler_cls)) evaluates to False, so the "not isinstance(...)" branch is taken and self.lr_scheduler.step() raises an AttributeError. Adding a self.lr_scheduler is not None check prevents the crash.

Suggested change
-    if not self.accelerator.optimizer_step_was_skipped:
-        # Delay optimizer scheduling until metrics are generated
-        lr_scheduler_cls = [torch.optim.lr_scheduler.ReduceLROnPlateau]
-        if GreedyLR is not None:
-            lr_scheduler_cls.append(GreedyLR)
-        if not isinstance(self.lr_scheduler, tuple(lr_scheduler_cls)):
-            self.lr_scheduler.step()
+    if not self.accelerator.optimizer_step_was_skipped and self.lr_scheduler is not None:
+        # Delay optimizer scheduling until metrics are generated
+        lr_scheduler_cls = [torch.optim.lr_scheduler.ReduceLROnPlateau]
+        if GreedyLR is not None:
+            lr_scheduler_cls.append(GreedyLR)
+        if not isinstance(self.lr_scheduler, tuple(lr_scheduler_cls)):
+            self.lr_scheduler.step()
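
The decision logic in this suggestion has three outcomes: skip because the optimizer step was skipped or no scheduler exists, defer because the scheduler waits for metrics, or step now. A minimal sketch of that logic as a standalone function, assuming plain Python classes in place of the torch schedulers. StepScheduler, MetricScheduler, and maybe_step_scheduler are hypothetical names used only for illustration.

```python
class StepScheduler:
    """Hypothetical scheduler that is stepped after every optimizer step."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


class MetricScheduler(StepScheduler):
    """Hypothetical stand-in for a ReduceLROnPlateau-style scheduler that waits for metrics."""


def maybe_step_scheduler(lr_scheduler, optimizer_step_was_skipped: bool,
                         deferred_classes=(MetricScheduler,)) -> bool:
    """Step the scheduler unless it is missing, deferred, or the optimizer step was skipped."""
    if optimizer_step_was_skipped or lr_scheduler is None:
        return False
    # Metric-driven schedulers are stepped later, once evaluation metrics exist.
    if isinstance(lr_scheduler, deferred_classes):
        return False
    lr_scheduler.step()
    return True


plain = StepScheduler()
deferred = MetricScheduler()

stepped_plain = maybe_step_scheduler(plain, optimizer_step_was_skipped=False)
stepped_deferred = maybe_step_scheduler(deferred, optimizer_step_was_skipped=False)
stepped_none = maybe_step_scheduler(None, optimizer_step_was_skipped=False)
stepped_skipped = maybe_step_scheduler(plain, optimizer_step_was_skipped=True)
```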

Comment thread swift/trainers/trainer.py
Comment on lines +171 to +172
    if args.eval_on_start:
        self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

medium

args.eval_on_start will raise an AttributeError on older versions of transformers where eval_on_start is not defined in TrainingArguments. We should use getattr(args, 'eval_on_start', False) to ensure backward compatibility.

Suggested change
-    if args.eval_on_start:
-        self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+    if getattr(args, 'eval_on_start', False):
+        self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

Comment thread swift/trainers/trainer.py
Comment on lines +606 to +607
    self._finish_current_push()

medium

self._finish_current_push() will raise an AttributeError on older versions of transformers where this method is not defined. We should guard it with hasattr to ensure backward compatibility.

Suggested change
-    self._finish_current_push()
+    # Wait for the checkpoint to be uploaded.
+    if hasattr(self, '_finish_current_push'):
+        self._finish_current_push()

Comment thread swift/trainers/trainer.py
Comment on lines +610 to +615
    if self.neftune_noise_alpha is not None:
        if hasattr(self, '_deactivate_neftune'):
            self._deactivate_neftune(self.model)
        else:
            from transformers.trainer import deactivate_neftune
            deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator)

medium

self.neftune_noise_alpha will raise an AttributeError on older versions of transformers (< 4.35.0) where NEFTune is not supported. We should use getattr(self, 'neftune_noise_alpha', None) to ensure backward compatibility.

Suggested change
-    if self.neftune_noise_alpha is not None:
-        if hasattr(self, '_deactivate_neftune'):
-            self._deactivate_neftune(self.model)
-        else:
-            from transformers.trainer import deactivate_neftune
-            deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator)
+    if getattr(self, 'neftune_noise_alpha', None) is not None:
+        if hasattr(self, '_deactivate_neftune'):
+            self._deactivate_neftune(self.model)
+        else:
+            from transformers.trainer import deactivate_neftune
+            deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator)

Comment thread swift/trainers/trainer.py
Comment on lines +113 to +115
    """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing."""
    # reset everything
    self.accelerator.free_memory()

medium

The args parameter is typed as TrainingArguments | None = None in the method signature, but it is accessed directly as args.auto_find_batch_size on line 116. If args is passed as None, this will raise an AttributeError. We should initialize args = args or self.args at the start of the method to prevent potential crashes.

        """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing."""
        args = args or self.args
        # reset everything
        self.accelerator.free_memory()
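
The args = args or self.args fallback relies on truthiness, which works as long as the arguments object is never falsy. A minimal sketch of the same fallback written with an explicit None check, assuming a stand-in class rather than the real Trainer. LoopOwner and resolve_args are hypothetical names.

```python
from types import SimpleNamespace


class LoopOwner:
    """Hypothetical stand-in for a trainer that holds default arguments."""

    def __init__(self, args):
        self.args = args

    def resolve_args(self, args=None):
        """Return the explicitly passed args, or fall back to self.args when None."""
        return self.args if args is None else args


default_args = SimpleNamespace(auto_find_batch_size=False)
override_args = SimpleNamespace(auto_find_batch_size=True)
owner = LoopOwner(default_args)

resolved_default = owner.resolve_args()
resolved_override = owner.resolve_args(override_args)
```

Either spelling removes the AttributeError described in the comment; the explicit check only differs in that it never discards an argument object that happens to evaluate as false.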

Comment thread swift/trainers/trainer.py
Comment on lines +597 to +601
    if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
        for checkpoint in checkpoints_sorted:
            if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                shutil.rmtree(checkpoint, ignore_errors=True)

medium

Using os.path.samefile can raise a FileNotFoundError if either path does not exist on the local disk (e.g., if the best checkpoint was already cleaned up, or in multi-node environments). Comparing the paths resolved by os.path.realpath does not require either path to exist, so it avoids the exception.

Suggested change
-    if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
-        for checkpoint in checkpoints_sorted:
-            if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
-                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
-                shutil.rmtree(checkpoint, ignore_errors=True)
+    if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+        for checkpoint in checkpoints_sorted:
+            if os.path.realpath(checkpoint) != os.path.realpath(self.state.best_model_checkpoint):
+                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+                shutil.rmtree(checkpoint, ignore_errors=True)
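
The difference between the two comparisons can be reproduced with a temporary directory in which one checkpoint path was never created. A minimal sketch, assuming made-up checkpoint directory names; is_same_checkpoint is a hypothetical helper, not a function from this PR.

```python
import os
import tempfile


def is_same_checkpoint(path_a: str, path_b: str) -> bool:
    """Compare two checkpoint paths without requiring either path to exist."""
    return os.path.realpath(path_a) == os.path.realpath(path_b)


with tempfile.TemporaryDirectory() as root:
    best = os.path.join(root, "checkpoint-200")
    os.mkdir(best)
    missing = os.path.join(root, "checkpoint-100")  # never created on disk

    # os.path.samefile stats both paths, so a missing path raises.
    try:
        os.path.samefile(missing, best)
        samefile_raised = False
    except FileNotFoundError:
        samefile_raised = True

    # realpath only normalizes and resolves symlinks; missing paths are allowed.
    same_via_realpath = is_same_checkpoint(missing, best)
    same_with_dots = is_same_checkpoint(os.path.join(root, ".", "checkpoint-200"), best)
```

One trade-off to keep in mind: realpath compares path strings after resolution, so unlike samefile it does not treat two hard links to the same directory entry as equal.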
