flambe.learn
Submodules
Package Contents
-
class flambe.learn.Trainer(dataset: Dataset, train_sampler: Sampler, val_sampler: Sampler, model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, iter_scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, max_grad_norm: Optional[float] = None, max_grad_abs_val: Optional[float] = None, extra_validation_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics: Optional[Iterable[Metric]] = None, extra_training_metrics_log_interval: Optional[int] = None)
Bases: flambe.compile.Component
Implement a Trainer block.
A Trainer takes data, a model and an optimizer as input, and executes training incrementally in its run method.
Each call to run should last long enough for the per-call overhead to stay negligible: at least a few seconds, and ideally several minutes.
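How epoch_per_step, iter_per_step and batches_per_iter interact is not spelled out here. The sketch below shows one plausible reading, stated as an assumption rather than flambe's confirmed behavior: an explicit iter_per_step takes precedence; otherwise one epoch contains num_batches // batches_per_iter iterations and a step covers epoch_per_step epochs, rounded up. The helper name is hypothetical.

```python
import math
from typing import Optional


def resolve_iter_per_step(num_batches: int,
                          epoch_per_step: float = 1.0,
                          iter_per_step: Optional[int] = None,
                          batches_per_iter: int = 1) -> int:
    """Hypothetical helper: iterations run by one Trainer step (assumed semantics)."""
    if iter_per_step is not None:
        # Assumption: an explicit iteration count takes precedence.
        return iter_per_step
    if batches_per_iter > num_batches:
        raise ValueError("batches_per_iter cannot exceed the number of batches per epoch")
    # Assumption: each iteration consumes batches_per_iter batches.
    iter_per_epoch = num_batches // batches_per_iter
    return math.ceil(epoch_per_step * iter_per_epoch)


# 250 batches per epoch, half an epoch per step -> 125 iterations per step
print(resolve_iter_per_step(250, epoch_per_step=0.5))
# Consuming 2 batches per iteration halves that: ceil(0.5 * 125) = 63
print(resolve_iter_per_step(250, epoch_per_step=0.5, batches_per_iter=2))
```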
-
validation_metrics
Property kept for backwards compatibility.
-
_create_train_iterator(self)
-
_batch_to_device(self, batch: Tuple[torch.Tensor, ...])
Move the current batch to the correct device.
Can be overridden if a batch doesn’t follow the expected structure, for example if the batch is a dictionary.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
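A minimal, hypothetical helper for such an override when batches are dictionaries. It relies only on duck typing: any leaf with a .to(device) method, such as torch.Tensor, is moved. FakeTensor is a stand-in so the sketch runs without PyTorch; neither name is part of flambe.

```python
from typing import Any


def move_to_device(batch: Any, device: str) -> Any:
    """Recursively move every leaf that has a .to(device) method onto device.

    Dicts, lists and tuples are traversed, so batches that are not a flat
    Tuple[torch.Tensor, ...] are handled too. Other leaves are returned as-is.
    """
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        moved = [move_to_device(item, device) for item in batch]
        return moved if isinstance(batch, list) else tuple(moved)
    if hasattr(batch, "to"):
        return batch.to(device)  # e.g. torch.Tensor.to(device)
    return batch


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


batch = {"text": FakeTensor(), "label": (FakeTensor(), 3)}
moved = move_to_device(batch, "cuda:0")
print(moved["text"].device, moved["label"][0].device, moved["label"][1])
# cuda:0 cuda:0 3
```

In a Trainer subclass, _batch_to_device could then simply return move_to_device(batch, self.device), assuming the trainer keeps its target device on self.device; the model's forward pass must also accept the new batch structure.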
-
_compute_loss(self, batch: Tuple[torch.Tensor, ...])
Compute the loss given a single batch.
DEPRECATED: only exists for legacy compatibility with custom trainers.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
-
_compute_batch(self, batch: Tuple[torch.Tensor, ...], metrics: List[Tuple] = [])
Compute the outputs for a single batch.
Runs a model forward pass over the batch and returns the prediction, target and loss.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on.
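A framework-free sketch of what this method is described as doing, with plain lists in place of tensors. The convention that the model returns a (prediction, target) pair and the aggregate(state, pred, target) metric method are assumptions for illustration, not confirmed flambe interfaces; SumSquaredError, double and mse are hypothetical.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

Vector = List[float]


class SumSquaredError:
    """Hypothetical metric that accumulates its running state in a dict."""

    def aggregate(self, state: Dict[str, float], pred: Vector, target: Vector) -> None:
        state["sse"] = state.get("sse", 0.0) + sum((p - t) ** 2 for p, t in zip(pred, target))
        state["n"] = state.get("n", 0) + len(target)


def compute_batch(batch: Sequence[Any],
                  model: Callable[..., Tuple[Vector, Vector]],
                  loss_fn: Callable[[Vector, Vector], float],
                  metrics: Optional[List[Tuple[Any, Dict[str, float]]]] = None
                  ) -> Tuple[Vector, Vector, float]:
    """Forward pass over one batch; returns (prediction, target, loss)."""
    pred, target = model(*batch)           # assumption: the model returns (pred, target)
    for metric, state in metrics or []:    # update each metric's state in place
        metric.aggregate(state, pred, target)
    loss = loss_fn(pred, target)
    return pred, target, loss


def double(xs: Vector, ys: Vector) -> Tuple[Vector, Vector]:
    """Toy model: predicts 2 * x and passes the target through."""
    return [2 * x for x in xs], ys


def mse(pred: Vector, target: Vector) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


state: Dict[str, float] = {}
pred, target, loss = compute_batch(([1.0, 2.0], [2.0, 5.0]), double, mse,
                                   [(SumSquaredError(), state)])
print(pred, loss, state)  # [2.0, 4.0] 0.5 {'sse': 1.0, 'n': 2}
```

The sketch uses None instead of a [] default for metrics, the usual Python idiom for avoiding shared mutable defaults.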
-
static _log_metrics(log_prefix: str, metrics_with_states: List[Tuple], global_step: int)
Log all provided metrics.
Iterates through the provided list of metrics with states, finalizes each metric, and logs it.
Parameters: - log_prefix (str) – A string, such as a TensorBoard prefix
- metrics_with_states (List[Tuple[Metric, Dict]]) – A list of metric-state tuples
- global_step (int) – The global step for logging
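A minimal sketch of the finalize-then-log loop. The standard logging module stands in for flambe's own TensorBoard-backed logging, and both the finalize(state) method and the "prefix/name" tag format are assumptions; MeanSquaredError and log_metrics are hypothetical.

```python
import logging
from typing import Any, Dict, List, Tuple

logger = logging.getLogger(__name__)


class MeanSquaredError:
    """Hypothetical metric whose state holds a running sum and count."""

    def __str__(self) -> str:
        return "MSE"

    def finalize(self, state: Dict[str, float]) -> float:
        return state["sse"] / state["n"] if state.get("n") else 0.0


def log_metrics(log_prefix: str,
                metrics_with_states: List[Tuple[Any, Dict[str, float]]],
                global_step: int) -> Dict[str, float]:
    """Finalize each (metric, state) pair and log it under log_prefix."""
    logged: Dict[str, float] = {}
    for metric, state in metrics_with_states:
        value = metric.finalize(state)      # reduce the accumulated state to a number
        tag = f"{log_prefix}/{metric}"      # e.g. "Validation/MSE"
        logger.info("step %d: %s = %.4f", global_step, tag, value)
        logged[tag] = value
    return logged


print(log_metrics("Validation", [(MeanSquaredError(), {"sse": 1.0, "n": 4})], global_step=3))
# {'Validation/MSE': 0.25}
```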
-
_train_step(self)
Run a training step over the training data.
-
_aggregate_preds(self, data_iterator: Iterator)
DEPRECATED. Aggregate the predictions, targets and mean loss for the dataset.
Parameters: data_iterator (Iterator) – Batches of data.
Returns: Tuple[torch.Tensor, torch.Tensor, float] – The predictions, targets and mean loss.
This method only existed to aggregate predictions for the metric functions, which now aggregate in place.
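Aggregating in place works because many metrics can be computed from a small running state. The stdlib sketch below (hypothetical function names) compares collecting every prediction before scoring, the pattern this deprecated method supported, with updating a running count per batch. Both give the same accuracy, but the second keeps only two integers in memory.

```python
from typing import Dict, List, Sequence, Tuple

Batch = Tuple[List[int], List[int]]  # (predictions, targets)


def accuracy_concat(batches: Sequence[Batch]) -> float:
    """Old style: collect every prediction and target, then score once."""
    preds: List[int] = []
    targets: List[int] = []
    for pred, target in batches:
        preds.extend(pred)
        targets.extend(target)
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def accuracy_streaming(batches: Sequence[Batch]) -> float:
    """In-place style: update a small state per batch, finalize at the end."""
    state: Dict[str, int] = {"correct": 0, "total": 0}
    for pred, target in batches:
        state["correct"] += sum(p == t for p, t in zip(pred, target))
        state["total"] += len(target)
    return state["correct"] / state["total"]


batches = [([1, 0, 1], [1, 1, 1]), ([0, 0], [0, 1])]
print(accuracy_concat(batches), accuracy_streaming(batches))  # 0.6 0.6
```

For a mean loss, the same idea applies if the running state stores the summed loss and the example count rather than a list of per-batch means, which would over-weight a short final batch.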
-
_eval_step(self)
Run an evaluation step over the validation data.
-
run(self)
Evaluate and then train until the next checkpoint.
Returns: Whether the component should continue running. Return type: bool
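A toy sketch of the calling contract: an external loop calls run repeatedly and stops once it returns False. ToyTrainer is hypothetical, and comparing a step counter with max_steps is an assumed stopping rule that mirrors the max_steps argument.

```python
from typing import List


class ToyTrainer:
    """Hypothetical stand-in that follows the run() contract of Trainer."""

    def __init__(self, max_steps: int = 3) -> None:
        self.max_steps = max_steps
        self.step = 0
        self.history: List[str] = []

    def _eval_step(self) -> None:
        self.history.append(f"eval {self.step}")

    def _train_step(self) -> None:
        self.history.append(f"train {self.step}")

    def run(self) -> bool:
        self._eval_step()    # evaluate first, as the docstring describes
        self._train_step()   # then train until the next checkpoint
        self.step += 1
        return self.step < self.max_steps  # ask to continue until max_steps


trainer = ToyTrainer(max_steps=3)
while trainer.run():
    pass  # a runner could checkpoint or reschedule between calls here
print(trainer.step, trainer.history[:2])  # 3 ['eval 0', 'train 0']
```

Keeping each run call coarse-grained is what lets an external runner checkpoint or reschedule between calls without much overhead.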
-
metric(self)
Override this method to enable scheduling.
Returns: The metric used to compare computable variants. Return type: float
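A hypothetical illustration of why metric returns a single float: a search over variants only needs one comparable number per variant. The lower_is_better flag mirrors the Trainer argument of the same name (True for a loss, False for accuracy); best_variant is not flambe's scheduler.

```python
from typing import Dict


def best_variant(scores: Dict[str, float], lower_is_better: bool = False) -> str:
    """Return the variant whose metric() value is best.

    scores maps a variant name to the float its metric() method returned.
    """
    chooser = min if lower_is_better else max
    return chooser(scores, key=scores.__getitem__)


scores = {"lr=0.1": 0.82, "lr=0.01": 0.87, "lr=0.001": 0.79}
print(best_variant(scores))                                       # lr=0.01
print(best_variant({"a": 0.4, "b": 0.3}, lower_is_better=True))  # b
```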
-
_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any])
-
_load_state(self, state_dict: State, prefix: str, local_metadata: Dict[str, Any], strict: bool, missing_keys: List[Any], unexpected_keys: List[Any], error_msgs: List[Any])
-
classmethod precompile(cls, **kwargs)
Override initialization.
Ensure that the model is compiled and pushed to the right device before its parameters are passed to the optimizer.
-
class flambe.learn.Evaluator(dataset: Dataset, model: Module, metric_fn: Metric, eval_sampler: Optional[Sampler] = None, eval_data: str = 'test', device: Optional[str] = None)
Bases: flambe.compile.Component
Implement an Evaluator block.
An Evaluator takes data and a model as input and executes the evaluation. This is a single-step Component object.
-
run(self, block_name: str = None)
Run the evaluation.
Returns: Whether the component should continue running. Return type: bool
-
metric(self)
Override this method to enable scheduling.
Returns: The metric used to compare computable variants. Return type: float
-
class flambe.learn.Script(script: str, args: List[Any], kwargs: Optional[Dict[str, Any]] = None, output_dir_arg: Optional[str] = None)
Bases: flambe.compile.Component
Implement a Script computable.
The object can be used to turn any script into a Flambé computable. This is useful when you want to integrate code rapidly. Note, however, that this computable does not enable checkpointing or linking to internal components, as it does not have any attributes.
To use this object, your script needs to be in a pip-installable package containing all its dependencies. The script is run with the following command:
python -m script --arg1 value1 --arg2 value2
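A minimal sketch of how such a command can be reproduced in-process with the standard library's runpy module. It assumes positional args come first, followed by --key value pairs built from kwargs; build_argv and run_script are hypothetical helpers, not flambe's implementation, and the handling of output_dir_arg is not shown.

```python
import runpy
import sys
from typing import Any, Dict, List, Optional


def build_argv(args: List[Any], kwargs: Optional[Dict[str, Any]] = None) -> List[str]:
    """Turn positional args and keyword args into command-line tokens."""
    argv = [str(arg) for arg in args]
    for key, value in (kwargs or {}).items():
        argv += [f"--{key}", str(value)]   # {"lr": 0.01} -> ["--lr", "0.01"]
    return argv


def run_script(module: str, args: List[Any],
               kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Execute module as __main__, roughly like `python -m module args...`."""
    saved_argv = sys.argv
    sys.argv = [module] + build_argv(args, kwargs)
    try:
        # alter_sys=True sets sys.argv[0] to the module's file, as -m does
        return runpy.run_module(module, run_name="__main__", alter_sys=True)
    finally:
        sys.argv = saved_argv  # restore the caller's command line


print(build_argv(["data.csv"], {"lr": 0.01, "epochs": 5}))
# ['data.csv', '--lr', '0.01', '--epochs', '5']
```

runpy.run_module can only find modules that are importable from sys.path, which is consistent with the requirement that the script live in an installed package rather than a loose file path.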
-
run(self)
Run the script.
Returns: Report dictionary to use for logging. Return type: Dict[str, float]