Module epiclass.core.trainer
Trainer class extensions module
Functions
def define_callbacks(early_stop_limit: int | None, show_summary=True)
-
Return a list of PyTorch Lightning trainer callbacks: RichModelSummary, EarlyStopping, and ModelCheckpoint.
If there is no early stopping, only the model from the last epoch is saved.
Classes
class MyTrainer (general_log_dir: str, model, **kwargs)
-
Personalized trainer
Metrics expect probabilities and not logits.
Expand source code
class MyTrainer(pl.Trainer):
    """Personalized PyTorch Lightning trainer.

    Extends ``pl.Trainer`` with batch-size reporting in ``fit``, a helper that
    appends the best checkpoint path to a log file, and a hyperparameter
    printout.
    """

    def __init__(self, general_log_dir: str, model, **kwargs):
        """Create the trainer.

        Metrics expect probabilities and not logits.

        Args:
            general_log_dir: Directory containing ``best_checkpoint.list``,
                the file that ``save_model_path`` appends to.
            model: Model whose ``l2_scale``, ``dropout_rate`` and
                ``learning_rate`` attributes are reported by
                ``print_hyperparameters``.
            **kwargs: Forwarded unchanged to ``pl.Trainer.__init__``.
        """
        super().__init__(**kwargs)
        self.best_checkpoint_file = Path(general_log_dir) / "best_checkpoint.list"
        self.my_model = model
        # Set by ``fit`` from the training dataloader; None until then.
        self.batch_size = None

    def fit(self, *args, verbose=True, **kwargs):
        """Base pl.Trainer.fit function, but also records and prints the batch size.

        The training dataloader is taken from the ``train_dataloaders``
        keyword or, failing that, from the second positional argument
        (``fit(model, train_dataloaders, ...)``). If no dataloader is given
        (e.g. only a datamodule is passed) or it has no ``batch_size``
        attribute, ``self.batch_size`` is set to ``None``.

        Args:
            *args: Positional arguments forwarded to ``pl.Trainer.fit``.
            verbose: If True, print the training batch size.
            **kwargs: Keyword arguments forwarded to ``pl.Trainer.fit``.

        Returns:
            Whatever ``pl.Trainer.fit`` returns.
        """
        train_loader = kwargs.get("train_dataloaders")
        if train_loader is None and len(args) > 1:
            # pl.Trainer.fit(model, train_dataloaders, ...): loader is 2nd positional.
            train_loader = args[1]
        self.batch_size = getattr(train_loader, "batch_size", None)
        if verbose:
            print(f"Training batch size : {self.batch_size}")
        return super().fit(*args, **kwargs)

    def save_model_path(self):
        """Save best checkpoint path to a file.

        Appends ``"<best_model_path> <timestamp>"`` as one line to
        ``self.best_checkpoint_file``. The timestamp comes from
        ``datetime.now()`` and is local and timezone-naive. If the checkpoint
        callback cannot be read, prints a message and writes nothing.
        """
        try:
            # Assumes Lightning's ``checkpoint_callback`` property is None when no
            # checkpoint callback is configured, so the lookup raises AttributeError.
            model_path = self.checkpoint_callback.best_model_path  # type: ignore
        except AttributeError:
            print("Cannot save model, no checkpoint callback.")
            return
        print(f"Saving model to {model_path}")
        with open(self.best_checkpoint_file, "a", encoding="utf-8") as ckpt_file:
            ckpt_file.write(f"{model_path} {datetime.now()}\n")

    def print_hyperparameters(self):
        """Print training hyperparameters.

        Reports the model's L2 scale, dropout rate and learning rate. It then
        reports the early-stopping patience and monitored value, or prints
        "No early stopping." if the early-stopping callback cannot be read.
        """
        print("--TRAINING HYPERPARAMETERS--")
        print(f"L2 scale : {self.my_model.l2_scale}")
        print(f"Dropout rate : {self.my_model.dropout_rate}")
        print(f"Learning rate : {self.my_model.learning_rate}")
        try:
            # Read both values first so a missing callback never yields a
            # half-printed early-stopping section.
            stop_callback = self.early_stopping_callback
            patience = stop_callback.patience  # type: ignore
            monitor = stop_callback.monitor  # type: ignore
        except AttributeError:
            print("No early stopping.")
            return
        print(f"Patience : {patience}")
        print(f"Monitored value : {monitor}")
Ancestors
- pytorch_lightning.trainer.trainer.Trainer
Methods
def fit(self, *args, verbose=True, **kwargs)
-
Base pl.Trainer.fit function, but also prints the batch size.
def print_hyperparameters(self)
-
Print training hyperparameters.
def save_model_path(self)
-
Save best checkpoint path to a file.