flambe.vision.classification
Package Contents
-
class flambe.vision.classification.MNISTDataset(train_images: np.ndarray = None, train_labels: np.ndarray = None, test_images: np.ndarray = None, test_labels: np.ndarray = None, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None)
Bases: flambe.dataset.Dataset
The official MNIST dataset.
-
data_type
-
URL = http://yann.lecun.com/exdb/mnist/
-
train: List[Tuple[torch.Tensor, torch.Tensor]]
Returns the training data.
-
val: List[Tuple[torch.Tensor, torch.Tensor]]
Returns the validation data.
-
test: List[Tuple[torch.Tensor, torch.Tensor]]
Returns the test data.
-
classmethod from_path(cls, train_images_path: str, train_labels_path: str, test_images_path: str, test_labels_path: str, val_ratio: Optional[float] = 0.2, seed: Optional[int] = None)
Initialize the MNISTDataset from local files.
Parameters: - train_images_path (str) – path to the train images file in the idx format
- train_labels_path (str) – path to the train labels file in the idx format
- test_images_path (str) – path to the test images file in the idx format
- test_labels_path (str) – path to the test labels file in the idx format
- val_ratio (Optional[float]) – fraction of the training set held out as the validation set. Default: 0.2
- seed (Optional[int]) – random seed for the validation set split
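The `val_ratio` and `seed` parameters suggest a seeded random hold-out split of the training set. The sketch below shows one way such a split can work, assuming the images and labels are NumPy arrays of equal length. `split_train_val` is a hypothetical helper, not part of flambe, and flambe's own split may differ (for example in rounding or in how `val_ratio=None` is handled).

```python
from typing import Optional, Tuple

import numpy as np


def split_train_val(
    images: np.ndarray,
    labels: np.ndarray,
    val_ratio: Optional[float] = 0.2,
    seed: Optional[int] = None,
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
    """Hypothetical helper: shuffle the training set and carve off a validation slice."""
    if images.shape[0] != labels.shape[0]:
        raise ValueError("images and labels must have the same length")
    ratio = val_ratio or 0.0  # assumption: None means "no validation set"
    if not 0.0 <= ratio < 1.0:
        raise ValueError("val_ratio must be in [0, 1)")
    rng = np.random.default_rng(seed)  # the same seed always yields the same split
    perm = rng.permutation(images.shape[0])
    n_val = int(round(ratio * images.shape[0]))
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    # Index images and labels with the same permutation so pairs stay aligned
    return (images[train_idx], labels[train_idx]), (images[val_idx], labels[val_idx])
```

Shuffling one shared index permutation, rather than the arrays independently, is what keeps every image paired with its label after the split.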
-
classmethod _parse_local_gzipped_idx(cls, path: str)
Parse a local gzipped idx file.
-
classmethod _parse_downloaded_idx(cls, url: str)
Parse a downloaded idx file.
-
classmethod _parse_idx(cls, data: bytes)
Parse an idx file.
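The idx files distributed at the MNIST URL begin with a 4-byte magic number (two zero bytes, a data-type code, and the number of dimensions), followed by one big-endian 32-bit size per dimension and then the raw values. A minimal NumPy sketch of what `_parse_idx` and `_parse_local_gzipped_idx` might do is shown here; `parse_idx` and `parse_local_gzipped_idx` are hypothetical stand-ins, not flambe's implementation.

```python
import gzip
import struct

import numpy as np

# idx data-type codes mapped to big-endian NumPy dtypes
_IDX_DTYPES = {
    0x08: np.dtype(">u1"),  # unsigned byte
    0x09: np.dtype(">i1"),  # signed byte
    0x0B: np.dtype(">i2"),  # short
    0x0C: np.dtype(">i4"),  # int
    0x0D: np.dtype(">f4"),  # float
    0x0E: np.dtype(">f8"),  # double
}


def parse_idx(data: bytes) -> np.ndarray:
    """Hypothetical stand-in for _parse_idx: decode raw idx bytes into an array."""
    zero, type_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or type_code not in _IDX_DTYPES:
        raise ValueError("not a valid idx file")
    dims = struct.unpack(">" + "I" * ndim, data[4 : 4 + 4 * ndim])
    dtype = _IDX_DTYPES[type_code]
    flat = np.frombuffer(data, dtype=dtype, offset=4 + 4 * ndim, count=int(np.prod(dims)))
    # Copy into native byte order so the result is writable and torch-compatible
    return flat.reshape(dims).astype(dtype.newbyteorder("="))


def parse_local_gzipped_idx(path: str) -> np.ndarray:
    """Hypothetical stand-in for _parse_local_gzipped_idx: gunzip, then parse."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())
```

The final `astype` matters in practice: `np.frombuffer` returns a read-only view, and `torch.from_numpy` does not accept arrays in non-native byte order, so converting here avoids both issues downstream.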
-
-
class flambe.vision.classification.ImageClassifier(encoder: Module, output_layer: Module)
Bases: flambe.nn.Module
Implements a simple image classifier.
This classifier consists of an encoder module followed by a fully connected output layer that outputs a probability distribution over the classes.
-
encoder
The encoder layer.
Type: Module
-
forward(self, data: Tensor, target: Optional[Tensor] = None)
Run a forward pass through the network.
Parameters: - data (Tensor) – The input data
- target (Tensor, optional) – The input targets
Returns: The output predictions, or the predictions together with the targets when target is provided
Return type: Union[Tensor, Tuple[Tensor, Tensor]]
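A minimal sketch of the encoder-plus-output-layer pattern and the `forward` contract described above, written against plain `torch.nn.Module` rather than `flambe.nn.Module`. The class name `SketchImageClassifier`, the flatten-plus-MLP encoder, and the layer sizes are illustrative assumptions, as is the `(pred, target)` ordering of the returned tuple.

```python
from typing import Optional, Tuple, Union

import torch
from torch import Tensor, nn


class SketchImageClassifier(nn.Module):
    """Hypothetical stand-in for ImageClassifier built on torch.nn.Module."""

    def __init__(self, encoder: nn.Module, output_layer: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.output_layer = output_layer

    def forward(
        self, data: Tensor, target: Optional[Tensor] = None
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        encoding = self.encoder(data)
        pred = self.output_layer(encoding)
        # Predictions alone, or (predictions, target) when a target is passed in
        return (pred, target) if target is not None else pred


# Hypothetical encoder for 28x28 grayscale MNIST images: flatten, then one hidden layer
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU())
# Fully connected output layer yielding a probability distribution over 10 digits
output_layer = nn.Sequential(nn.Linear(128, 10), nn.Softmax(dim=-1))
model = SketchImageClassifier(encoder, output_layer)

images = torch.rand(4, 1, 28, 28)
labels = torch.randint(0, 10, (4,))
probs = model(images)                # shape (4, 10)
pred, target = model(images, labels)  # predictions plus the unchanged targets
```

Returning the targets alongside the predictions lets a training loop feed both straight into a loss or metric. For training in plain PyTorch, ending the output layer in `LogSoftmax` and using `torch.nn.NLLLoss` is the more numerically stable variant.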
-