flambe.learn.train
Module Contents
-
class flambe.learn.train.Trainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[List[Metric]] = None)
Bases: flambe.compile.Component
Implements a Trainer block.
A Trainer takes a dataset, a model, and an optimizer as input, and executes training incrementally through successive calls to run.
Each call to run should be long enough that its fixed overhead stays small relative to the training it performs: at least a few seconds, and ideally several minutes.
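To get a feel for how much work one call to run performs, the arithmetic below relates the constructor arguments to batch counts. It is a minimal sketch under assumptions this page does not confirm: that an "iteration" is one optimizer update accumulating batches_per_iter batches, and that when iter_per_step is None it is derived from epoch_per_step and the number of training batches per epoch. The helper plan_training is hypothetical and not part of flambe.

```python
import math
from typing import Optional


def plan_training(batches_per_epoch: int,
                  max_steps: int = 10,
                  epoch_per_step: float = 1.0,
                  iter_per_step: Optional[int] = None,
                  batches_per_iter: int = 1) -> dict:
    """Hypothetical helper: estimate the work done per run() call.

    Assumes (not confirmed by the flambe docs) that one iteration is one
    optimizer update over `batches_per_iter` accumulated batches, and that
    `iter_per_step` defaults to enough iterations to cover `epoch_per_step`
    epochs of training data.
    """
    if iter_per_step is None:
        # Iterations needed to see one full epoch of training batches.
        iters_per_epoch = batches_per_epoch / batches_per_iter
        iter_per_step = math.ceil(epoch_per_step * iters_per_epoch)
    batches_per_step = iter_per_step * batches_per_iter
    return {
        "iter_per_step": iter_per_step,
        "batches_per_step": batches_per_step,
        "total_batches": batches_per_step * max_steps,
    }


plan = plan_training(batches_per_epoch=100, epoch_per_step=0.5, batches_per_iter=2)
print(plan)  # {'iter_per_step': 25, 'batches_per_step': 50, 'total_batches': 500}
```

Under these assumptions, if each run call finishes in well under a few seconds, raising epoch_per_step (or setting a larger iter_per_step) is the natural way to make each call longer.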
-
_batch_to_device(self, batch: Tuple[torch.Tensor, ...])
Move the current batch to the correct device.
Can be overridden if a batch does not follow the expected structure, for example if the batch is a dictionary.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
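For a batch that is, for example, a dictionary, an override needs to walk the structure and move each tensor. A minimal sketch of such a recursive move follows; it assumes only that leaf values expose a .to(device) method, as torch.Tensor does, and leaves other values untouched. The name move_to_device is hypothetical; an overridden _batch_to_device could return move_to_device(batch, self.device), assuming the trainer keeps its target device on self.device.

```python
from typing import Any


def move_to_device(batch: Any, device: Any) -> Any:
    """Hypothetical helper: recursively move a (possibly nested) batch.

    Dicts, lists and tuples are rebuilt (dicts as plain dicts) with their
    contents moved; any value with a ``.to`` method (e.g. a torch.Tensor) is
    moved with ``.to(device)``; other values (str, int, ...) are returned as-is.
    """
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # namedtuples take their fields as positional arguments
        return type(batch)(*(move_to_device(item, device) for item in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(item, device) for item in batch)
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch
```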
-
_compute_loss(self, batch: Tuple[torch.Tensor, ...])
Compute the loss for a single batch.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
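As a rough illustration of computing the loss for one batch, the sketch below assumes a batch laid out as (*inputs, target), a model called on the inputs, and a loss called as loss_fn(pred, target). That layout and the helper name compute_loss are assumptions for illustration; flambe's actual batch convention is not described on this page. Plain Python callables stand in for the model and loss so the example runs without PyTorch.

```python
from typing import Any, Callable, Sequence


def compute_loss(model: Callable[..., Any],
                 loss_fn: Callable[[Any, Any], Any],
                 batch: Sequence[Any]) -> Any:
    """Hypothetical helper: loss for a single batch.

    Assumes the last element of the batch is the target and every other
    element is a model input.
    """
    *inputs, target = batch
    pred = model(*inputs)
    return loss_fn(pred, target)


# Toy stand-ins: a "model" that doubles its input and a squared-error loss.
double = lambda x: 2 * x
squared_error = lambda pred, target: (pred - target) ** 2

print(compute_loss(double, squared_error, (3.0, 5.0)))  # 1.0
```

With real tensors, the returned value would be the scalar on which backward() is called during training.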
-
_aggregate_preds(self, data_iterator: Iterator)
Aggregate the predictions and targets for the dataset.
Parameters: data_iterator (Iterator) – Batches of data. Returns: The predictions and targets. Return type: Tuple[torch.Tensor, torch.Tensor]
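Aggregation can be pictured as running the model over every batch and concatenating the per-batch outputs, so a metric can be computed once over the whole dataset. The sketch below reuses an assumed (*inputs, targets) batch layout and Python lists; with tensors, the concatenation would typically be torch.cat along the batch dimension, inside torch.no_grad() since evaluation needs no gradients. aggregate_preds is a hypothetical name, not flambe's implementation.

```python
from typing import Any, Callable, Iterable, List, Sequence, Tuple


def aggregate_preds(model: Callable[..., List[Any]],
                    data_iterator: Iterable[Sequence[Any]]) -> Tuple[List[Any], List[Any]]:
    """Hypothetical helper: collect predictions and targets over all batches.

    Each batch is assumed to be (*inputs, targets), where the model returns
    one prediction per example and `targets` is a list of the same length.
    """
    all_preds: List[Any] = []
    all_targets: List[Any] = []
    for batch in data_iterator:
        *inputs, targets = batch
        preds = model(*inputs)
        # Concatenate along the example ("batch") dimension.
        all_preds.extend(preds)
        all_targets.extend(targets)
    return all_preds, all_targets


# Toy model: predict 1 for positive inputs, 0 otherwise.
sign_model = lambda xs: [1 if x > 0 else 0 for x in xs]
batches = [([-2, 3], [0, 1]), ([5], [0])]
preds, targets = aggregate_preds(sign_model, iter(batches))
accuracy = sum(p == t for p, t in zip(preds, targets)) / len(targets)
```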
-
run(self)
Train until the next checkpoint, then evaluate.
Returns: True if the computable is not yet complete, False once it is. Return type: bool
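Because run reports whether the computable is still incomplete, a caller can drive training by calling it until it returns False. The toy class below is a hypothetical stand-in, not flambe's Trainer: it counts steps against max_steps (default 10, matching the constructor) and mimics only the return contract; a real run would train until the next checkpoint and evaluate before returning.

```python
class ToyTrainer:
    """Hypothetical stand-in that mimics only the run() return contract."""

    def __init__(self, max_steps: int = 10) -> None:
        self.max_steps = max_steps
        self.steps_done = 0

    def run(self) -> bool:
        # A real trainer would train until the next checkpoint and evaluate here.
        self.steps_done += 1
        # True while more steps remain, i.e. the computable is not yet complete.
        return self.steps_done < self.max_steps


trainer = ToyTrainer(max_steps=3)
while trainer.run():
    pass  # e.g. log progress between checkpoints
print(trainer.steps_done)  # 3
```

Returning a continuation flag rather than looping internally presumably lets an outside scheduler pause, compare, or stop variants between calls.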
-
metric(self)
Override this method to enable scheduling.
Returns: The metric used to compare computable variants. Return type: float
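Comparing variants by their metric only works if the direction of "better" is known; the constructor's lower_is_better flag presumably encodes that for this trainer. The sketch below shows one way a caller might rank variants by their metric() values. The function pick_best, the variant names, and the numbers are hypothetical illustrations, and how flambe's schedulers actually consume metric() is not described on this page.

```python
from typing import Dict


def pick_best(metrics: Dict[str, float], lower_is_better: bool = False) -> str:
    """Hypothetical helper: name of the variant with the best metric value."""
    if not metrics:
        raise ValueError("no variants to compare")
    chooser = min if lower_is_better else max
    return chooser(metrics, key=metrics.__getitem__)


# Illustrative values standing in for each variant's metric() result.
val_loss = {"lr=0.1": 0.42, "lr=0.01": 0.35, "lr=0.001": 0.51}
val_accuracy = {"lr=0.1": 0.88, "lr=0.01": 0.86, "lr=0.001": 0.79}

print(pick_best(val_loss, lower_is_better=True))  # lr=0.01
print(pick_best(val_accuracy))                    # lr=0.1
```

Forgetting the direction flag silently selects the worst variant (here, pick_best(val_loss) would return "lr=0.001"), which is why the flag belongs next to the metric.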
-