flambe.metric
Subpackages
Submodules
Package Contents
class flambe.metric.Metric
Bases: flambe.compile.Component
Base Metric interface.
Objects implementing this interface should take in model predictions and ground-truth targets and output the computed metric.
compute(self, pred: torch.Tensor, target: torch.Tensor)
Computes the metric over the given prediction and target.
Parameters: - pred (torch.Tensor) – The model predictions
- target (torch.Tensor) – The ground truth targets
Returns: The computed metric
Return type: torch.Tensor
__call__(self, *args, **kwargs)
Makes Metric a callable.
__str__(self)
Return the name of the Metric (for use in logging).
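To make the interface concrete without depending on flambe or PyTorch, here is a minimal sketch built on a hypothetical stand-in base class (not flambe's source). It assumes that `__call__` simply forwards to `compute` and that `__str__` returns the class name; plain Python lists replace `torch.Tensor`, and `MeanAbsoluteError` is an illustrative subclass that is not part of `flambe.metric`.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class Metric(ABC):
    """Hypothetical stand-in for flambe.metric.Metric (assumed behavior)."""

    @abstractmethod
    def compute(self, pred: Sequence[float], target: Sequence[float]) -> float:
        """Compute the metric over the given prediction and target."""

    def __call__(self, *args, **kwargs) -> float:
        # Assumption: calling the metric simply delegates to compute().
        return self.compute(*args, **kwargs)

    def __str__(self) -> str:
        # Assumption: the name used in logs is the class name.
        return self.__class__.__name__


class MeanAbsoluteError(Metric):
    """Toy metric (hypothetical) used only to exercise the interface."""

    def compute(self, pred: Sequence[float], target: Sequence[float]) -> float:
        if len(pred) != len(target):
            raise ValueError("pred and target must have the same length")
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


mae = MeanAbsoluteError()
print(str(mae), mae([1.0, 2.0, 3.0], [1.0, 2.5, 2.0]))  # prints: MeanAbsoluteError 0.5
```

Because `__call__` delegates to `compute` in this sketch, a subclass only has to implement `compute`, and `mae(...)` and `mae.compute(...)` return the same value.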
class flambe.metric.MultiLabelCrossEntropy(weight: Optional[torch.Tensor] = None, ignore_index: Optional[int] = None, reduction: str = 'mean')
Bases: flambe.metric.metric.Metric
compute(self, pred: torch.Tensor, target: torch.Tensor)
Computes the multilabel cross-entropy loss.
Parameters: - pred (torch.Tensor) – input logits of shape (B x N)
- target (torch.LongTensor) – target tensor of shape (B x N)
Returns: loss – Multilabel cross-entropy loss, of shape (B)
Return type: torch.Tensor
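A common way to define multilabel cross-entropy is to normalize each multi-hot target row so it sums to 1 and take the cross-entropy against the softmax of the logits. The plain-Python sketch below follows that formulation; it is an assumption about the approach, not a transcription of flambe's implementation, and it leaves out `weight`, `ignore_index`, and `reduction`, returning the per-example losses of shape (B).

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def multilabel_cross_entropy(pred: List[List[float]],
                             target: List[List[int]]) -> List[float]:
    """Per-example loss for logits `pred` (B x N) and multi-hot `target` (B x N)."""
    losses = []
    for logits, labels in zip(pred, target):
        total = sum(labels)
        if total == 0:
            raise ValueError("each target row needs at least one positive label")
        # Normalize the multi-hot row into a distribution over the N classes.
        dist = [y / total for y in labels]
        log_probs = log_softmax(logits)
        losses.append(-sum(d * lp for d, lp in zip(dist, log_probs)))
    return losses
```

With the target spread evenly over two positives, the smallest loss this formulation can reach is log 2 rather than 0, which is one visible consequence of normalizing the target rows.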
class flambe.metric.MultiLabelNLLLoss(weight: Optional[torch.Tensor] = None, ignore_index: Optional[int] = None, reduction: str = 'mean')
Bases: flambe.metric.metric.Metric
compute(self, pred: torch.Tensor, target: torch.Tensor)
Computes the multilabel negative log-likelihood loss.
Parameters: - pred (torch.Tensor) – input logits of shape (B x N)
- target (torch.LongTensor) – target tensor of shape (B x N)
Returns: loss – Multilabel negative log-likelihood loss, of shape (B)
Return type: torch.float
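The sketch below illustrates a multilabel negative log-likelihood in plain Python. It assumes that `log_probs` already holds log-probabilities (the usual input for an NLL loss), that multi-hot target rows are normalized to sum to 1, that `weight` scales each class's term, and that `reduction` accepts 'none', 'sum', or 'mean' as PyTorch loss functions do. All of these are assumptions rather than flambe's documented behavior, and `ignore_index` is not modelled.

```python
from typing import List, Optional, Union


def multilabel_nll(log_probs: List[List[float]],
                   target: List[List[int]],
                   weight: Optional[List[float]] = None,
                   reduction: str = "mean") -> Union[float, List[float]]:
    """Multilabel NLL over log-probabilities (B x N) and multi-hot targets (B x N)."""
    n_classes = len(log_probs[0])
    # Assumption: a missing weight means every class counts equally.
    w = weight if weight is not None else [1.0] * n_classes
    losses = []
    for row, labels in zip(log_probs, target):
        total = sum(labels)
        if total == 0:
            raise ValueError("each target row needs at least one positive label")
        # Normalize the positive labels so each target row sums to 1.
        dist = [y / total for y in labels]
        losses.append(-sum(wi * d * lp for wi, d, lp in zip(w, dist, row)))
    if reduction == "none":
        return losses  # one loss per example, shape (B)
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For log-probabilities [[-1.0, -2.0, -3.0], [-0.5, -1.5, -2.5]] and targets [[1, 0, 0], [0, 1, 1]], the per-example losses are 1.0 and 2.0, so 'sum' gives 3.0 and 'mean' gives 1.5.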
class flambe.metric.Accuracy
Bases: flambe.metric.metric.Metric
compute(self, pred: torch.Tensor, target: torch.Tensor)
Computes the accuracy.
Parameters: - pred (torch.Tensor) – input logits of shape (B x N)
- target (torch.LongTensor) – target tensor of shape (B) or (B x N)
Returns: accuracy – single-label accuracy, of shape (B)
Return type: torch.Tensor
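A minimal plain-Python sketch of single-label accuracy, assuming the predicted class is the argmax of each logit row and that a (B x N) target is one-hot, so its argmax recovers the label index. It returns the batch mean as one number; the per-example 0/1 values it averages have shape (B). This illustrates the idea and is not flambe's implementation.

```python
from typing import List, Sequence, Union


def argmax(row: Sequence[float]) -> int:
    # Index of the largest score; ties resolve to the first maximum.
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(pred: List[List[float]],
             target: List[Union[int, List[int]]]) -> float:
    """Fraction of rows whose highest-scoring class matches the target."""
    correct = 0
    for scores, label in zip(pred, target):
        # Assumption: a (B x N) target row is one-hot, so reduce it to its index.
        gold = argmax(label) if isinstance(label, list) else label
        correct += int(argmax(scores) == gold)
    return correct / len(pred)
```

Passing targets as indices [1, 2, 2] or as the equivalent one-hot rows gives the same result in this sketch.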
class flambe.metric.Perplexity
Bases: flambe.metric.Metric
compute(self, pred: torch.Tensor, target: torch.Tensor)
Computes the perplexity given the input and target.
Parameters: - pred (torch.Tensor) – input logits of shape (B x N)
- target (torch.LongTensor) – target tensor of shape (B x N)
Returns: Output perplexity
Return type: torch.float
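Perplexity is conventionally the exponential of the mean cross-entropy. The plain-Python sketch below computes it that way, assuming one integer class index per row as the target (rather than the (B x N) shape listed above); it is not taken from flambe's source.

```python
import math
from typing import List


def perplexity(pred: List[List[float]], target: List[int]) -> float:
    """exp(mean cross-entropy) for logits `pred` (B x N) and class indices `target` (B)."""
    total_nll = 0.0
    for logits, label in zip(pred, target):
        # Log-sum-exp with the max subtracted for numerical stability.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        # Negative log-probability of the correct class.
        total_nll += log_z - logits[label]
    return math.exp(total_nll / len(pred))
```

Uniform logits over N classes give a perplexity of N, which makes a convenient sanity check, while a confident correct prediction drives it toward 1.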
class flambe.metric.AUC(max_fpr=1.0)
Bases: flambe.metric.metric.Metric
compute(self, pred: torch.Tensor, target: torch.Tensor)
Computes the AUC up to the given maximum false positive rate (max_fpr).
Parameters: - pred (torch.Tensor) – The model predictions
- target (torch.Tensor) – The binary targets
Returns: The computed AUC
Return type: torch.Tensor
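For `max_fpr=1.0`, the AUC equals the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counted as one half (the Mann-Whitney formulation). The plain-Python sketch below computes exactly that. The partial AUC used when `max_fpr < 1.0` is not covered, and this sketch does not model how flambe handles that case.

```python
from typing import Sequence


def roc_auc(pred: Sequence[float], target: Sequence[int]) -> float:
    """Full ROC AUC (max_fpr = 1.0) via the pairwise ranking formulation."""
    pos = [p for p, t in zip(pred, target) if t == 1]
    neg = [p for p, t in zip(pred, target) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    score = 0.0
    for p in pos:
        for n in neg:
            # A positive ranked above a negative counts 1; a tie counts 1/2.
            if p > n:
                score += 1.0
            elif p == n:
                score += 0.5
    return score / (len(pos) * len(neg))
```

For scores [0.1, 0.4, 0.35, 0.8] and targets [0, 0, 1, 1], three of the four positive/negative pairs are ordered correctly, so the AUC is 0.75. A sort-based version runs in O(n log n) and gives the same value; the pairwise loop just makes the definition explicit.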
class flambe.metric.BinaryPrecision
Bases: flambe.metric.dev.binary.BinaryMetric
Compute binary precision.
An example is considered negative when its score is below the specified threshold. Binary precision is computed as follows:
`|True Positives| / (|True Positives| + |False Positives|)`
compute_binary(self, pred: torch.Tensor, target: torch.Tensor)
Compute binary precision.
Parameters: - pred (torch.Tensor) – Predictions made by the model. It should be a probability 0 <= p <= 1 for each sample, 1 being the positive class.
- target (torch.Tensor) – Ground truth. Each label should be either 0 or 1.
Returns: The computed binary metric
Return type: torch.float
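A plain-Python sketch of thresholded binary precision, following the formula given for this class. It assumes a default `threshold` of 0.5 (the text refers to a specified threshold but does not give its value) and returns 0.0 when no example is predicted positive; both choices are assumptions, not flambe's documented behavior.

```python
from typing import Sequence


def binary_precision(pred: Sequence[float], target: Sequence[int],
                     threshold: float = 0.5) -> float:
    """|TP| / (|TP| + |FP|), where a score below `threshold` is negative."""
    # Scores at or above the threshold are predicted positive.
    predicted = [1 if p >= threshold else 0 for p in pred]
    tp = sum(1 for y_hat, y in zip(predicted, target) if y_hat == 1 and y == 1)
    fp = sum(1 for y_hat, y in zip(predicted, target) if y_hat == 1 and y == 0)
    # Assumption: report 0.0 when nothing is predicted positive.
    return tp / (tp + fp) if tp + fp > 0 else 0.0
```

For scores [0.9, 0.7, 0.3, 0.6, 0.1] and targets [1, 0, 1, 1, 0] at threshold 0.5, there are two true positives and one false positive (the 0.7), so precision is 2/3; raising the threshold to 0.8 leaves a single true positive and precision becomes 1.0.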
class flambe.metric.BinaryRecall
Bases: flambe.metric.dev.binary.BinaryMetric
Compute binary recall.
An example is considered negative when its score is below the specified threshold. Binary recall is computed as follows:
`|True Positives| / (|True Positives| + |False Negatives|)`
compute_binary(self, pred: torch.Tensor, target: torch.Tensor)
Compute binary recall.
Parameters: - pred (torch.Tensor) – Predictions made by the model. It should be a probability 0 <= p <= 1 for each sample, 1 being the positive class.
- target (torch.Tensor) – Ground truth. Each label should be either 0 or 1.
Returns: The computed binary metric
Return type: torch.float
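A plain-Python sketch of thresholded binary recall, following the formula given for this class. As with any such sketch, the default `threshold` of 0.5 and the 0.0 result when there are no positive targets are assumptions rather than flambe's documented behavior.

```python
from typing import Sequence


def binary_recall(pred: Sequence[float], target: Sequence[int],
                  threshold: float = 0.5) -> float:
    """|TP| / (|TP| + |FN|), where a score below `threshold` is negative."""
    # Scores at or above the threshold are predicted positive.
    predicted = [1 if p >= threshold else 0 for p in pred]
    tp = sum(1 for y_hat, y in zip(predicted, target) if y_hat == 1 and y == 1)
    # A false negative is a true positive whose score fell below the threshold.
    fn = sum(1 for y_hat, y in zip(predicted, target) if y_hat == 0 and y == 1)
    # Assumption: report 0.0 when there are no positive targets.
    return tp / (tp + fn) if tp + fn > 0 else 0.0
```

For scores [0.9, 0.7, 0.3, 0.6, 0.1] and targets [1, 0, 1, 1, 0], the positive scored 0.3 falls below 0.5, giving two true positives and one false negative, so recall is 2/3; lowering the threshold to 0.2 recovers that example and raises recall to 1.0.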