[WIP] refactor trainer#9510
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a custom training loop in the Trainer class by overriding _inner_training_loop and adding several helper methods such as _update_auto_batch_size, set_initial_training_values, _init_training_state, _prepare_for_training, _run_epoch, _finalize_training, and _evaluate. The code reviewer provided valuable feedback pointing out multiple potential AttributeError and FileNotFoundError issues. These issues are mainly related to backward compatibility with older versions of transformers, accelerate, and PyTorch, as well as potential None type dereferences for args and self.lr_scheduler.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| # needed to calculate tokens/s | ||
| self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen |
There was a problem hiding this comment.
self.state.num_input_tokens_seen will raise an AttributeError on older versions of transformers (< 4.43.0) where token tracking is not supported. We should guard it with hasattr to ensure backward compatibility.
| # needed to calculate tokens/s | |
| self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen | |
| # needed to calculate tokens/s | |
| if hasattr(self.state, 'num_input_tokens_seen'): | |
| self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen |
| if num_items_in_batch is not None and self.model_accepts_loss_kwargs: | ||
| loss = loss / self.args.gradient_accumulation_steps |
There was a problem hiding this comment.
self.model_accepts_loss_kwargs will raise an AttributeError on older versions of transformers where this attribute is not defined on Trainer. We should use getattr(self, 'model_accepts_loss_kwargs', False) to ensure backward compatibility.
| if num_items_in_batch is not None and self.model_accepts_loss_kwargs: | |
| loss = loss / self.args.gradient_accumulation_steps | |
| if num_items_in_batch is not None and getattr(self, 'model_accepts_loss_kwargs', False): | |
| loss = loss / self.args.gradient_accumulation_steps |
| if self.is_fsdp_enabled: | ||
| # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA | ||
| if hasattr(self.model, 'generate'): | ||
| dist.fsdp.register_fsdp_forward_method(self.model, 'generate') |
There was a problem hiding this comment.
dist.fsdp.register_fsdp_forward_method is only available in PyTorch 2.2+. If the user is running an older version of PyTorch (e.g., 2.0 or 2.1), this will raise an AttributeError. We should guard this call with hasattr to ensure backward compatibility.
| if self.is_fsdp_enabled: | |
| # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA | |
| if hasattr(self.model, 'generate'): | |
| dist.fsdp.register_fsdp_forward_method(self.model, 'generate') | |
| if self.is_fsdp_enabled: | |
| # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA | |
| if hasattr(self.model, 'generate') and hasattr(getattr(dist, 'fsdp', None), 'register_fsdp_forward_method'): | |
| dist.fsdp.register_fsdp_forward_method(self.model, 'generate') |
| pc = getattr(self.accelerator, 'parallelism_config', None) | ||
| if pc is not None and pc.sp_backend == 'deepspeed' and pc.sp_enabled: | ||
| train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model) |
There was a problem hiding this comment.
self.accelerator.deepspeed_ulysses_dl_adapter will raise an AttributeError on older versions of accelerate where DeepSpeed Ulysses is not supported. We should guard it with hasattr to ensure backward compatibility.
| pc = getattr(self.accelerator, 'parallelism_config', None) | |
| if pc is not None and pc.sp_backend == 'deepspeed' and pc.sp_enabled: | |
| train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model) | |
| pc = getattr(self.accelerator, 'parallelism_config', None) | |
| if pc is not None and pc.sp_backend == 'deepspeed' and pc.sp_enabled: | |
| if hasattr(self.accelerator, 'deepspeed_ulysses_dl_adapter'): | |
| train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model) |
| if not self.accelerator.optimizer_step_was_skipped: | ||
| # Delay optimizer scheduling until metrics are generated | ||
| lr_scheduler_cls = [torch.optim.lr_scheduler.ReduceLROnPlateau] | ||
| if GreedyLR is not None: | ||
| lr_scheduler_cls.append(GreedyLR) | ||
| if not isinstance(self.lr_scheduler, tuple(lr_scheduler_cls)): | ||
| self.lr_scheduler.step() |
There was a problem hiding this comment.
If self.lr_scheduler is None (e.g., if no learning rate scheduler is configured), isinstance(None, tuple(lr_scheduler_cls)) will evaluate to False, and calling self.lr_scheduler.step() will raise an AttributeError. We should add a self.lr_scheduler is not None check to prevent crashes.
| if not self.accelerator.optimizer_step_was_skipped: | |
| # Delay optimizer scheduling until metrics are generated | |
| lr_scheduler_cls = [torch.optim.lr_scheduler.ReduceLROnPlateau] | |
| if GreedyLR is not None: | |
| lr_scheduler_cls.append(GreedyLR) | |
| if not isinstance(self.lr_scheduler, tuple(lr_scheduler_cls)): | |
| self.lr_scheduler.step() | |
| if not self.accelerator.optimizer_step_was_skipped and self.lr_scheduler is not None: | |
| # Delay optimizer scheduling until metrics are generated | |
| lr_scheduler_cls = [torch.optim.lr_scheduler.ReduceLROnPlateau] | |
| if GreedyLR is not None: | |
| lr_scheduler_cls.append(GreedyLR) | |
| if not isinstance(self.lr_scheduler, tuple(lr_scheduler_cls)): | |
| self.lr_scheduler.step() |
| if args.eval_on_start: | ||
| self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) |
There was a problem hiding this comment.
args.eval_on_start will raise an AttributeError on older versions of transformers where eval_on_start is not defined in TrainingArguments. We should use getattr(args, 'eval_on_start', False) to ensure backward compatibility.
| if args.eval_on_start: | |
| self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) | |
| if getattr(args, 'eval_on_start', False): | |
| self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) |
| self._finish_current_push() | ||
|
|
There was a problem hiding this comment.
self._finish_current_push() will raise an AttributeError on older versions of transformers where this method is not defined. We should guard it with hasattr to ensure backward compatibility.
| self._finish_current_push() | |
| # Wait for the checkpoint to be uploaded. | |
| if hasattr(self, '_finish_current_push'): | |
| self._finish_current_push() |
| if self.neftune_noise_alpha is not None: | ||
| if hasattr(self, '_deactivate_neftune'): | ||
| self._deactivate_neftune(self.model) | ||
| else: | ||
| from transformers.trainer import deactivate_neftune | ||
| deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator) |
There was a problem hiding this comment.
self.neftune_noise_alpha will raise an AttributeError on older versions of transformers (< 4.35.0) where NEFTune is not supported. We should use getattr(self, 'neftune_noise_alpha', None) to ensure backward compatibility.
| if self.neftune_noise_alpha is not None: | |
| if hasattr(self, '_deactivate_neftune'): | |
| self._deactivate_neftune(self.model) | |
| else: | |
| from transformers.trainer import deactivate_neftune | |
| deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator) | |
| if getattr(self, 'neftune_noise_alpha', None) is not None: | |
| if hasattr(self, '_deactivate_neftune'): | |
| self._deactivate_neftune(self.model) | |
| else: | |
| from transformers.trainer import deactivate_neftune | |
| deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator) |
| """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing.""" | ||
| # reset everything | ||
| self.accelerator.free_memory() |
There was a problem hiding this comment.
The args parameter is typed as TrainingArguments | None = None in the method signature, but it is accessed directly as args.auto_find_batch_size on line 116. If args is passed as None, this will raise an AttributeError. We should initialize args = args or self.args at the start of the method to prevent potential crashes.
"""Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing."""
args = args or self.args
# reset everything
self.accelerator.free_memory()
| if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: | ||
| for checkpoint in checkpoints_sorted: | ||
| if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): | ||
| logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") | ||
| shutil.rmtree(checkpoint, ignore_errors=True) |
There was a problem hiding this comment.
Using os.path.samefile can raise a FileNotFoundError if either of the paths does not exist on the local disk (e.g., if the best checkpoint was already cleaned up or in multi-node environments). Comparing resolved paths using os.path.realpath is much safer and avoids throwing exceptions.
| if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: | |
| for checkpoint in checkpoints_sorted: | |
| if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): | |
| logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") | |
| shutil.rmtree(checkpoint, ignore_errors=True) | |
| if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: | |
| for checkpoint in checkpoints_sorted: | |
| if os.path.realpath(checkpoint) != os.path.realpath(self.state.best_model_checkpoint): | |
| logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") | |
| shutil.rmtree(checkpoint, ignore_errors=True) |
No description provided.