flambe.learn.distillation

Module Contents
class flambe.learn.distillation.DistillationTrainer(dataset: Dataset, train_sampler: Sampler, dev_sampler: Sampler, teacher_model: Module, student_model: Module, loss_fn: Metric, metric_fn: Metric, optimizer: Optimizer, scheduler: Optional[_LRScheduler] = None, device: Optional[str] = None, max_steps: int = 10, epoch_per_step: float = 1.0, iter_per_step: Optional[int] = None, batches_per_iter: int = 1, lower_is_better: bool = False, extra_validation_metrics: Optional[List[Metric]] = None, teacher_columns: Optional[Tuple[int, ...]] = None, student_columns: Optional[Tuple[int, ...]] = None, alpha_kl: float = 0.5, temperature: int = 1, unlabel_dataset: Optional[Dataset] = None, unlabel_sampler: Optional[Sampler] = None)

Bases: flambe.learn.Trainer
Implement a Distillation Trainer.
Perform knowledge distillation between a teacher and a student model. Note that the model outputs are expected to be raw logits. Make sure that you are not applying a softmax after the decoder. You can replace the traditional Decoder with an MLPEncoder.
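For orientation, the snippet below is a minimal sketch of instantiating a DistillationTrainer in Python from the signature documented above. The concrete objects (my_dataset, train_sampler, teacher, student, and so on) are hypothetical placeholders standing in for real flambe Dataset, Sampler, Module, Metric, and Optimizer instances; only the keyword names come from the signature.

    # Hypothetical setup: my_dataset, train_sampler, dev_sampler, teacher,
    # student, loss_metric, eval_metric, and optimizer are placeholders,
    # not names provided by flambe.
    from flambe.learn.distillation import DistillationTrainer

    trainer = DistillationTrainer(
        dataset=my_dataset,            # flambe Dataset with train/dev splits
        train_sampler=train_sampler,   # Sampler over the training split
        dev_sampler=dev_sampler,       # Sampler over the dev split
        teacher_model=teacher,         # pretrained Module emitting raw logits
        student_model=student,         # smaller Module to train, raw logits
        loss_fn=loss_metric,           # Metric used as the hard-label loss
        metric_fn=eval_metric,         # Metric reported on the dev set
        optimizer=optimizer,           # torch Optimizer over student parameters
        alpha_kl=0.5,                  # weight of the soft-label (KL) term
        temperature=2,                 # softens teacher/student distributions
    )

Both models must output raw logits (no softmax after the decoder), since the trainer softens and compares the two distributions itself.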
_compute_loss(self, batch: Tuple[torch.Tensor, ...])

Compute the loss for a single batch.
Important: the student and teacher output predictions must be the raw logits, so ensure that your decoder object is set up with take_log=False.
Parameters: batch (Tuple[torch.Tensor, ...]) – The batch to train on
Returns: The computed loss
Return type: torch.Tensor
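To make the raw-logits requirement concrete, here is a generic sketch of a standard knowledge-distillation loss in plain PyTorch, blending a KL term on temperature-softened distributions with the hard-label cross-entropy via alpha_kl. It illustrates the idea only and is not flambe's exact _compute_loss implementation; the function name and arguments are assumptions for this example.

    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits: torch.Tensor,
                          teacher_logits: torch.Tensor,
                          targets: torch.Tensor,
                          alpha_kl: float = 0.5,
                          temperature: float = 1.0) -> torch.Tensor:
        """Generic distillation loss over raw (pre-softmax) logits."""
        # Soft-label term: KL divergence between temperature-softened
        # teacher and student distributions, scaled by T^2 to keep
        # gradient magnitudes comparable across temperatures.
        kl = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)
        # Hard-label term: standard cross-entropy against gold labels.
        ce = F.cross_entropy(student_logits, targets)
        return alpha_kl * kl + (1 - alpha_kl) * ce

Because softmax is applied inside this computation, feeding already-normalized probabilities (e.g. a decoder with take_log left on) would double-normalize the outputs and distort the loss, which is why raw logits are required.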