flambe.nlp.fewshot.model
Module Contents
-
class flambe.nlp.fewshot.model.PrototypicalTextClassifier(embedder: Embedder, distance: str = 'euclidean', detach_mean: bool = False)
Bases: flambe.nn.Module
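Implements a prototypical classifier for few-shot text classification.
Inputs are encoded by the embedder. Each class is represented by a prototype, the mean encoding of its labeled support examples. A query is classified by its distance to each prototype, using the metric selected by the distance argument ('euclidean' by default).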
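For reference, this is the standard prototypical-network formulation, written with f as the embedder, S_k as the support examples labeled k, and d as the chosen distance. This reference does not say whether the module applies the softmax itself or returns the negative distances as logits. The original prototypical-network formulation uses squared Euclidean distance for d.

```latex
c_k = \frac{1}{|S_k|} \sum_{(x_i,\, y_i) \in S_k} f(x_i),
\qquad
p(y = k \mid x) = \frac{\exp\big(-d(f(x), c_k)\big)}{\sum_{k'} \exp\big(-d(f(x), c_{k'})\big)}
```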
-
decoder
The decoder layer.
Type: Decoder
-
drop
The dropout layer.
Type: nn.Dropout
-
compute_prototypes(self, support: Tensor, label: Tensor)
Set the current prototypes used for classification.
Parameters:
- support (torch.Tensor) – Input encodings
- label (torch.Tensor) – Corresponding labels
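To illustrate what computing prototypes involves, here is a minimal, dependency-free sketch. Plain Python lists stand in for torch.Tensor. The function groups support encodings by label and averages each group element-wise. Its name mirrors the method, but it is not flambe's implementation. The requirement that labels be contiguous integers starting at 0 (so prototype k sits at index k) is an assumption of this sketch.

```python
from typing import Dict, List, Sequence

Vector = List[float]


def compute_prototypes(support: Sequence[Vector], label: Sequence[int]) -> List[Vector]:
    """Return one prototype per class: the mean of that class's support encodings.

    Sketch assumption: labels are contiguous integers 0..K-1, so the
    prototype for class k is returned at index k.
    """
    if len(support) != len(label):
        raise ValueError("support and label must have the same length")

    # Group the support encodings by their label.
    groups: Dict[int, List[Vector]] = {}
    for vec, y in zip(support, label):
        groups.setdefault(y, []).append(vec)

    n_classes = len(groups)
    if sorted(groups) != list(range(n_classes)):
        raise ValueError("labels must be contiguous integers starting at 0")

    prototypes: List[Vector] = []
    for k in range(n_classes):
        vecs = groups[k]
        dim = len(vecs[0])
        # Element-wise mean over the support encodings of class k.
        prototypes.append([sum(v[d] for v in vecs) / len(vecs) for d in range(dim)])
    return prototypes


support = [[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]]
label = [0, 0, 1, 1]
print(compute_prototypes(support, label))  # [[1.0, 0.0], [11.0, 10.0]]
```

A class with a single support example simply gets that example's encoding as its prototype, since the mean of one vector is the vector itself.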
-
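forward(self, query: Tensor, query_label: Optional[Tensor] = None, support: Optional[Tensor] = None, support_label: Optional[Tensor] = None)
Run a forward pass through the network.
Parameters:
- query (Tensor) – The query inputs to classify
- query_label (Optional[Tensor]) – Labels for the query inputs
- support (Optional[Tensor]) – Support inputs used to compute the class prototypes
- support_label (Optional[Tensor]) – Labels for the support inputs
Returns: The output predictions
Return type: Union[Tensor, Tuple[Tensor, Tensor]]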
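The scoring step can be sketched in the same dependency-free style. score_queries and predict are hypothetical helpers, not flambe API. They assume the queries have already been encoded and use plain (non-squared) Euclidean distance. Scores are negative distances, so the nearest prototype receives the highest score. That is a common prototypical-network convention, assumed here rather than confirmed by this reference. Returning query_label alongside the scores when it is supplied is also an assumption, chosen to mirror the Union[Tensor, Tuple[Tensor, Tensor]] return type.

```python
import math
from typing import List, Optional, Sequence, Tuple, Union

Vector = List[float]
Scores = List[List[float]]


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Plain (non-squared) Euclidean distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def score_queries(
    query: Sequence[Vector],
    prototypes: Sequence[Vector],
    query_label: Optional[List[int]] = None,
) -> Union[Scores, Tuple[Scores, List[int]]]:
    """Score every query encoding against every prototype (hypothetical helper).

    Scores are negative distances, so larger means closer. If query_label is
    given, it is returned with the scores; this mirrors the Union return type
    of forward and is an assumption of the sketch.
    """
    scores = [[-euclidean(q, p) for p in prototypes] for q in query]
    if query_label is not None:
        return scores, query_label
    return scores


def predict(scores: Sequence[Sequence[float]]) -> List[int]:
    """Index of the highest-scoring (nearest) prototype for each query."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


prototypes = [[1.0, 0.0], [11.0, 10.0]]
queries = [[0.0, 3.0], [9.0, 9.0]]
scores = score_queries(queries, prototypes)
print(predict(scores))  # [0, 1]
```

Using negative distances as scores lets the output feed directly into a softmax or cross-entropy loss as logits, which is one reason the convention is common.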
-