PyTorch Document

1. DATASETS

1.1. Loading a Dataset

1.1.1. Load the FashionMNIST dataset with the following parameters: (1) root is the path where the train/test data is stored; (2) train specifies whether to load the training or test dataset; (3) download=True downloads the data from the internet if it is not available at root; (4) transform and target_transform specify the feature and label transformations.

1.1.2.
    import torch
    from torch.utils.data import Dataset
    from torchvision import datasets
    from torchvision.transforms import ToTensor
    import matplotlib.pyplot as plt

    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

1.2. Creating a Custom Dataset for your files

1.2.1.
    import os
    import pandas as pd
    from torch.utils.data import Dataset
    from torchvision.io import read_image

    class CustomImageDataset(Dataset):
        def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
            # CSV file: column 0 holds the image file name, column 1 the label
            self.img_labels = pd.read_csv(annotations_file)
            self.img_dir = img_dir
            self.transform = transform
            self.target_transform = target_transform

        def __len__(self):
            # Number of samples in the dataset
            return len(self.img_labels)

        def __getitem__(self, idx):
            # Load one image from disk and return it with its label
            img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
            image = read_image(img_path)
            label = self.img_labels.iloc[idx, 1]
            if self.transform:
                image = self.transform(image)
            if self.target_transform:
                label = self.target_transform(label)
            return image, label
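The three methods a custom Dataset needs (__init__, __len__, __getitem__) can be tried out without any image files. The sketch below is only an illustration: InMemoryDataset, one_hot, and the toy tensors are hypothetical names and data that do not come from the tutorial. It shows how transform and target_transform are applied to each sample.

```python
import torch
from torch.utils.data import Dataset


class InMemoryDataset(Dataset):
    """Hypothetical dataset that keeps features and labels in memory."""

    def __init__(self, features, labels, transform=None, target_transform=None):
        self.features = features
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.labels[idx]
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        return x, y


def one_hot(label, num_classes=3):
    # Label transformation: integer class index -> one-hot float vector
    return torch.nn.functional.one_hot(torch.tensor(label), num_classes).float()


# Toy data (assumption): 4 samples with 3 features each, labels in {0, 1, 2}
features = torch.arange(12, dtype=torch.float32).reshape(4, 3)
labels = [0, 2, 1, 2]

ds = InMemoryDataset(
    features,
    labels,
    transform=lambda x: x / 11.0,  # feature transformation: scale into [0, 1]
    target_transform=one_hot,
)

sample, target = ds[1]
print(len(ds), sample, target)
```

Indexing with ds[1] calls __getitem__, so both transformations run lazily, once per requested sample.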

1.3. Datasets & DataLoaders — PyTorch Tutorials 1.9.0+cu102 documentation

2. DATALOADERS

2.1. Preparing your data for training with DataLoaders

2.1.1.
    from torch.utils.data import DataLoader

    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
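To see what a DataLoader yields, it helps to iterate over one. The sketch below uses a synthetic TensorDataset as a stand-in for FashionMNIST so that nothing has to be downloaded. The tensor shapes and the sample count of 100 are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in (assumption): 100 grayscale 28x28 "images", labels 0..9
features = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Each iteration returns one batch of features and the matching labels
batch_shapes = []
for batch_features, batch_labels in loader:
    batch_shapes.append((tuple(batch_features.shape), tuple(batch_labels.shape)))

print(batch_shapes)
```

With 100 samples and batch_size=64 the loader produces two batches, the second holding the remaining 36 samples. Passing drop_last=True would discard that smaller final batch.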

3. Save and load model

3.1. torch.save

3.1.1. state_dict

3.1.1.1. A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor

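3.1.1.1.1. Only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as batch normalization's running_mean) have entries in the state_dict.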

3.1.1.2. The learnable parameters (weights and biases) of a torch.nn.Module model are contained in the model’s parameters (accessed with model.parameters())
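The relationship between state_dict and model.parameters() can be checked on a small model. The sketch below is illustrative only: the layer sizes are arbitrary assumptions. It lists the state_dict entries and compares them with the names of the learnable parameters.

```python
import torch
from torch import nn

# Small example model (assumption): conv -> batch norm -> ReLU -> flatten -> linear
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 26 * 26, 10),
)

# state_dict maps "<layer>.<tensor name>" to the stored tensor
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# Learnable parameters only (weights and biases)
param_names = [name for name, _ in model.named_parameters()]
print(param_names)
```

The ReLU and Flatten layers hold no tensors, so they do not appear in the state_dict. The batch normalization buffers (for example 1.running_mean) do appear there, but they are not among the learnable parameters.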

3.2. torch.load

3.3. torch.nn.Module.load_state_dict
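The three calls above are typically combined as save state_dict, load it back, then copy it into a freshly built model. The following is a minimal sketch of that round trip. The make_model helper, the layer sizes, and the file name model_weights.pth are assumptions made for illustration.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical helper: the architecture must match when loading weights
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_weights.pth")

    # Save only the state_dict (parameter tensors), not the whole model object
    torch.save(model.state_dict(), path)

    # Rebuild the same architecture, then copy the saved tensors into it
    restored = make_model()
    restored.load_state_dict(torch.load(path))

restored.eval()  # switch to inference mode before evaluating

x = torch.rand(3, 4)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)
```

load_state_dict expects a dictionary rather than a file path, which is why the file is first read with torch.load.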

3.4. Saving and Loading Models — PyTorch Tutorials 1.9.1+cu102 documentation

4. torchvision

4.1. transform

4.1.1. torchvision.transforms — Torchvision 0.10.0 documentation

5. model.train() and model.eval()

5.1. PyTorch: usage of and differences between model.train() and model.eval(), and the difference between model.eval() and torch.no_grad() (初识-CV's blog, CSDN)
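The distinction named in this node can be observed on a dropout layer. model.train() and model.eval() change how layers such as dropout and batch normalization behave, while torch.no_grad() only turns off gradient tracking. The sketch below is an illustration under assumed layer sizes, not code taken from the linked post.

```python
import torch
from torch import nn

torch.manual_seed(0)

dropout = nn.Dropout(p=0.5)
ones = torch.ones(1000)

# Training mode: about half the elements are zeroed, the rest scaled by 1 / (1 - p) = 2
dropout.train()
train_values = set(dropout(ones).tolist())

# Evaluation mode: dropout is the identity function
dropout.eval()
eval_out = dropout(ones)

# eval() does not disable autograd; torch.no_grad() does
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.eval()
x = torch.ones(2, 4)

out_tracked = model(x)  # eval mode, gradients still tracked
with torch.no_grad():
    out_untracked = model(x)  # eval mode, no gradient tracking

print(train_values)
print(torch.equal(eval_out, ones))
print(out_tracked.requires_grad, out_untracked.requires_grad)
```

For inference the two are commonly used together: model.eval() for correct layer behavior, and torch.no_grad() to save memory and computation.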