flambe.learn.train
Module Contents
-
class flambe.learn.train.Trainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, iter_scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics_log_interval: Optional[int] = None)
Bases: flambe.compile.Component
Implement a Trainer block.
A Trainer takes data, a model, and an optimizer as input, and executes training incrementally in run.
Note that each call to run should last long enough that the per-call overhead stays negligible: at least a few seconds, and ideally several minutes.
-
_batch_to_device(self, batch: Tuple[torch.Tensor, ...])
Move the current batch to the correct device.
Can be overridden if a batch doesn’t follow the expected structure, for example if the batch is a dictionary.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
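The dictionary case mentioned above could be handled along the following lines. This is a minimal sketch, not flambe's implementation: dict_batch_to_device and FakeTensor are hypothetical names introduced here, and FakeTensor only stands in for torch.Tensor so the example runs without PyTorch. In a custom trainer, the same logic would go into an overridden _batch_to_device.

```python
from typing import Any, Dict


def dict_batch_to_device(batch: Dict[str, Any], device: str) -> Dict[str, Any]:
    """Hypothetical helper: move every value exposing ``.to(device)``.

    Values without a ``to`` method (ints, strings, ...) are passed through.
    """
    return {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in batch.items()
    }


class FakeTensor:
    """Stand-in for torch.Tensor (assumption: only ``.to`` is needed here)."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return a new object rather than mutating.
        return FakeTensor(device)


batch = {"tokens": FakeTensor(), "length": 12}
moved = dict_batch_to_device(batch, "cuda:0")
```

The helper returns a new dictionary and leaves the input batch untouched, which mirrors how moving a tensor to a device yields a new tensor.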
-
_compute_loss(self, batch: Tuple[torch.Tensor, ...])
Compute the loss given a single batch.
DEPRECATED: only exists for legacy compatibility with custom trainers.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
-
_compute_batch(self, batch: Tuple[torch.Tensor, ...], metrics: List[Tuple] = [])
Computes a batch.
Runs a model forward pass over a batch and returns the prediction, target, and loss.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
-
static _log_metrics(log_prefix: str, metrics_with_states: List[Tuple], global_step: int)
Logs all provided metrics.
Iterates through the provided list of metrics with states, finalizes each metric, and logs it.
Parameters:
- log_prefix (str) – A string, such as a tensorboard prefix
- metrics_with_states (List[Tuple[Metric, Dict]]) – A list of metric-state tuples
- global_step (int) – The global step for logging
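The finalize-then-log loop described above can be illustrated with a short sketch. Everything in it is an assumption made for illustration: MeanMetric, its aggregate and finalize methods, the log_metrics function, and the log_fn callback are hypothetical stand-ins, not flambe's Metric class or logging API.

```python
from typing import Callable, Dict, List, Tuple


class MeanMetric:
    """Hypothetical metric that keeps a running sum and count in a state dict."""

    def __str__(self) -> str:
        return "mean"

    def aggregate(self, state: Dict[str, float], value: float) -> None:
        # Update the state in place with one more observation.
        state["sum"] = state.get("sum", 0.0) + value
        state["count"] = state.get("count", 0) + 1

    def finalize(self, state: Dict[str, float]) -> float:
        # Turn the accumulated state into the final metric value.
        return state["sum"] / state["count"]


def log_metrics(
    log_prefix: str,
    metrics_with_states: List[Tuple[MeanMetric, Dict[str, float]]],
    global_step: int,
    log_fn: Callable[[str, float, int], None],
) -> None:
    """Finalize each (metric, state) pair and hand the value to ``log_fn``."""
    for metric, state in metrics_with_states:
        log_fn(f"{log_prefix}/{metric}", metric.finalize(state), global_step)


records: List[Tuple[str, float, int]] = []
metric, state = MeanMetric(), {}
for value in (1.0, 2.0, 3.0):
    metric.aggregate(state, value)

log_metrics(
    "Validation",
    [(metric, state)],
    7,
    lambda tag, val, step: records.append((tag, val, step)),
)
```

Keeping the state outside the metric object is what lets one metric instance be reused for several independent accumulations, such as training and validation.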
-
_aggregate_preds(self, data_iterator: Iterator)
DEPRECATED. Aggregate the predictions, targets, and mean loss for the dataset.
This method only existed to aggregate for the metric functions, which now do this in-place.
Parameters: data_iterator (Iterator) – Batches of data.
Returns: The predictions, targets, and mean loss.
Return type: Tuple[torch.Tensor, torch.Tensor, float]
-
run(self)
Evaluate and then train until the next checkpoint.
Returns: Whether the component should continue running.
Return type: bool
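The boolean return value suggests a driver loop that keeps calling run until it returns False. The sketch below shows that contract with a hypothetical stand-in class, CountingStep; the assumption that the component stops once a max_steps counter is reached is made here for illustration and is not taken from the Trainer source.

```python
class CountingStep:
    """Hypothetical component following the documented ``run`` contract."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.steps_done = 0

    def run(self) -> bool:
        # A real Trainer would evaluate, then train until the next
        # checkpoint here; this stand-in only counts the call.
        self.steps_done += 1
        # True means "call me again", False means "finished".
        return self.steps_done < self.max_steps


component = CountingStep(max_steps=3)
while component.run():
    pass  # a scheduler could checkpoint or report progress between calls
```

Returning control to the caller after each step is what makes it possible to checkpoint or stop early between calls.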
-
metric(self)
Override this method to enable scheduling.
Returns: The metric to compare computable variants.
Return type: float
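As an illustration of how a returned metric might be used to compare variants, the sketch below picks the best of several scores while honouring a lower_is_better flag like the one in the Trainer signature. The best_variant function and the score values are hypothetical and do not describe how flambe's scheduling is implemented.

```python
from typing import Dict


def best_variant(scores: Dict[str, float], lower_is_better: bool = False) -> str:
    """Hypothetical helper: return the name of the best-scoring variant."""
    chooser = min if lower_is_better else max
    return chooser(scores, key=scores.get)


# Made-up values standing in for what each variant's metric() might return.
scores = {"lr=0.1": 0.71, "lr=0.01": 0.84, "lr=0.001": 0.79}

best_when_higher = best_variant(scores)
best_when_lower = best_variant(scores, lower_is_better=True)
```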
-