DC Wang
2022-03-25

# DATASETS & DATALOADERS

PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset, which allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets.

# Loading a Dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando's article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST dataset with the following parameters:

  • root is the path where the train/test data is stored.
  • train specifies whether to load the training or the test set: True selects the training set, False the test set.
  • download=True downloads the data from the internet if it is not available at root.
  • transform and target_transform specify the feature and label transformations, e.g. converting the image to a tensor or the integer target to a one-hot encoded tensor (see the sketch after the output below).
```python
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
```

Out:

```
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
```
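As a side note on the target_transform parameter mentioned above: the integer label can be turned into a one-hot encoded tensor with a Lambda transform. A minimal sketch, assuming the 10 FashionMNIST classes and reusing the imports from the snippet above:

```python
from torchvision.transforms import Lambda

# Sketch: the label y (an int in 0-9) becomes a 10-element one-hot float tensor.
one_hot_training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    )
)
```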

# Iterating and Visualizing the Dataset

We can index Datasets manually like a list: training_data[index]. We use matplotlib to visualize some samples in our training data.

```python
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```

*(Figure: a 3×3 grid of randomly sampled training images with their class labels.)*

# Creating a Custom Dataset for your files

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__. Take a look at this implementation; the FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file annotations_file.

In the next sections, we’ll break down what’s happening in each of these functions.

```python
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
```

# __init__

The __init__ function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).

The labels.csv file looks like:

```
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
```
```python
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
```

# __len__

The __len__ function returns the number of samples in our dataset.

Example:

```python
def __len__(self):
    return len(self.img_labels)
```

# __getitem__

The __getitem__ function loads and returns a sample from the dataset at the given index idx. Based on the index, it identifies the image’s location on disk, converts that to a tensor using read_image, retrieves the corresponding label from the csv data in self.img_labels, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label in a tuple.

```python
def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
```
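Putting the three methods together, here is a minimal usage sketch; the file labels.csv and the folder images/ are hypothetical placeholders for your own annotations file and image directory:

```python
# Hypothetical paths -- replace with your own annotations file and image folder.
dataset = CustomImageDataset(annotations_file="labels.csv", img_dir="images")

print(len(dataset))        # calls __len__: number of rows in labels.csv
image, label = dataset[0]  # calls __getitem__: read_image yields a uint8 tensor of shape (C, H, W)
print(image.shape, label)
```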

# Preparing your data for training with DataLoaders

The Dataset retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us with a simple API.

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
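The multiprocessing mentioned above is enabled through the num_workers argument of DataLoader; a variant of the same loaders, with the worker count of 2 chosen arbitrarily for illustration:

```python
# Two worker processes load batches in the background (num_workers=0 loads in the main process).
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=2)
```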

# Iterate through the DataLoader

We have loaded that dataset into the DataLoader and can iterate through the dataset as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers).

```python
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
```

*(Figure: the first image of the batch, displayed in grayscale.)*

Out:

```
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 4
```
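The snippet above only pulls a single batch with next(iter(...)); during training the DataLoader is normally consumed in a loop over the whole epoch. A minimal sketch, with the model and loss computation left out:

```python
# One pass over the training set; each iteration yields one shuffled batch.
for batch_idx, (X, y) in enumerate(train_dataloader):
    # X: (64, 1, 28, 28), y: (64,) -- the last batch may be smaller.
    if batch_idx % 100 == 0:
        print(f"batch {batch_idx}: features {tuple(X.shape)}, labels {tuple(y.shape)}")
```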
Tags: #Python #AI #MachineLearning
Last updated: 2022/10/03, 09:24:26