flambe.metric.loss.cross_entropy

Module Contents

class flambe.metric.loss.cross_entropy.MultiLabelCrossEntropy(weight: Optional[torch.Tensor] = None, ignore_index: Optional[int] = None, reduction: str = 'mean')
Bases: flambe.metric.metric.Metric
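For orientation, here is one common way to define a multi-label cross entropy. It covers the logits z (pred, shape B x N) and a multi-hot target t (shape B x N). The formulation makes these assumptions:

- each target row is normalized to sum to one;
- `weight` supplies per-class factors w_c, which are all ones when None;
- `reduction` selects the batch mean, the batch sum, or the per-example vector;
- every row has at least one positive label, so the normalization is defined.

The flambe documentation does not confirm this formulation.

```latex
\tilde{t}_{b,c} = \frac{t_{b,c}}{\sum_{c'=1}^{N} t_{b,c'}}, \qquad
\ell_b = -\sum_{c=1}^{N} w_c \, \tilde{t}_{b,c} \,
  \log \frac{\exp(z_{b,c})}{\sum_{c'=1}^{N} \exp(z_{b,c'})}, \qquad b = 1, \dots, B

L_{\text{mean}} = \frac{1}{B} \sum_{b=1}^{B} \ell_b, \qquad
L_{\text{sum}} = \sum_{b=1}^{B} \ell_b, \qquad
L_{\text{none}} = (\ell_1, \dots, \ell_B)
```

Under this reading, only `reduction='none'` yields a result of shape (B). An `ignore_index` would most naturally zero out column c of t before normalizing, but that is also an assumption.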

compute(self, pred: torch.Tensor, target: torch.Tensor)
Computes the multi-label cross-entropy loss.
Parameters:
- pred (torch.Tensor) – input logits of shape (B x N)
- target (torch.LongTensor) – target tensor of shape (B x N)
Returns: loss – multi-label cross-entropy loss, of shape (B)
Return type: torch.Tensor
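The sketch below shows the kind of computation `compute` describes, using the same pred/target shapes. It makes these assumptions:

- `target` is a 0/1 multi-hot matrix whose rows are L1-normalized;
- `weight` rescales each of the N classes;
- `ignore_index` zeroes out one class column;
- `reduction` follows the torch.nn 'mean'/'sum'/'none' convention.

The function name `multilabel_cross_entropy` is hypothetical. This is not the flambe source.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def multilabel_cross_entropy(
    pred: torch.Tensor,
    target: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    ignore_index: Optional[int] = None,
    reduction: str = "mean",
) -> torch.Tensor:
    """Hypothetical stand-in for MultiLabelCrossEntropy.compute (a sketch).

    pred:   logits of shape (B x N)
    target: multi-hot tensor of shape (B x N); 1 marks a positive label
    """
    # Copy before any in-place edit so the caller's target is left untouched.
    target = target.clone().float()
    if ignore_index is not None:
        # Assumption: an ignored class is removed from every row's targets.
        target[:, ignore_index] = 0.0
    if weight is None:
        # Assumption: no weight means every class counts equally.
        weight = torch.ones(pred.size(1), dtype=pred.dtype, device=pred.device)
    # Spread each row's target mass evenly over its positive labels
    # (all-zero rows stay zero because F.normalize clamps the norm by eps).
    norm_target = F.normalize(target, p=1, dim=1)
    # Per-example loss of shape (B); the (N,) weight broadcasts over the batch.
    loss = -(weight * norm_target * F.log_softmax(pred, dim=1)).sum(dim=1)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"Unknown reduction: {reduction!r}")


if __name__ == "__main__":
    logits = torch.tensor([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
    target = torch.tensor([[1, 0, 0], [1, 1, 0]])  # second row has two positive labels
    per_example = multilabel_cross_entropy(logits, target, reduction="none")
    print(per_example.shape)  # torch.Size([2])
    print(multilabel_cross_entropy(logits, target))  # scalar: mean over the batch
```

In this sketch, only `reduction="none"` keeps the (B) shape listed under Returns. The other two settings, 'mean' and 'sum', collapse the batch to a scalar.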