Stratified train/test-split in PyTorch

The Japanese version of this article is available on Qiita (qiita.com).

What is Stratified Splitting?

In machine learning, you often split a dataset into two parts: training data and validation data. For classification tasks in particular, it is preferable to split the data so that the proportion of each class in each part matches the proportion in the original dataset. For example, if 80% of the images are cats and 20% are dogs, both the training set and the validation set should also be about 80% cats and 20% dogs. This way of splitting data is called stratified splitting.
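To make the idea concrete, here is a minimal from-scratch sketch of a stratified split using only the standard library. It groups sample indices by class, shuffles within each class, and takes the same fraction from every class for validation. The helper name stratified_split is hypothetical, and this is a simplified illustration of the technique, not scikit-learn's implementation (which distributes rounding remainders differently).

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_size=0.2, seed=0):
    """Split sample indices so each class keeps roughly its original share.

    Hypothetical helper for illustration; not scikit-learn's algorithm.
    """
    rng = random.Random(seed)

    # Group the indices of the samples by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)                     # randomize order within each class
        n_val = round(len(idxs) * test_size)  # same fraction from every class
        val_idx.extend(idxs[:n_val])
        train_idx.extend(idxs[n_val:])
    return sorted(train_idx), sorted(val_idx)


labels = ["cat"] * 80 + ["dog"] * 20
train_idx, val_idx = stratified_split(labels, test_size=0.2)
print(Counter(labels[i] for i in train_idx))  # Counter({'cat': 64, 'dog': 16})
print(Counter(labels[i] for i in val_idx))    # Counter({'cat': 16, 'dog': 4})
```

Both parts keep the 80/20 cat/dog ratio of the original labels.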

Example in PyTorch

In scikit-learn, you can do a stratified split by passing the class labels as the stratify argument of sklearn.model_selection.train_test_split().

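PyTorch, on the other hand, has no built-in stratified split: torch.utils.data.random_split splits the data randomly without looking at the class labels. So we borrow train_test_split() from scikit-learn to achieve a stratified split in PyTorch.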

An example code would look like the following.

import torch
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

transformer = transforms.Compose([
    transforms.ToTensor(),
])

# load images
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

# split the data set into train and validation
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# create DataLoader
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)  # no need to shuffle validation data

Let me explain the details.

First, load the images into a Dataset with torchvision.datasets.ImageFolder.

transformer = transforms.Compose([
    transforms.ToTensor(),
])

dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)
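ImageFolder expects one subdirectory per class under root (for example directory_name/cat/001.png and directory_name/dog/001.png). It assigns class indices in sorted order of the folder names and records each sample's class index in dataset.targets, which is the list we later pass to stratify. The sketch below imitates that bookkeeping with the standard library only, so it runs without torchvision. The helper scan_image_folder and the short extension list are assumptions for illustration; the real logic in torchvision accepts more extensions and handles more cases.

```python
import os
import tempfile

IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")  # assumed subset of what ImageFolder accepts


def scan_image_folder(root, extensions=IMG_EXTENSIONS):
    """Rough imitation of how ImageFolder derives classes and targets (hypothetical)."""
    # One class per subdirectory, indexed in sorted order of the folder names.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}

    samples = []
    for name in classes:
        for dirpath, _, filenames in sorted(os.walk(os.path.join(root, name))):
            for fname in sorted(filenames):
                if fname.lower().endswith(extensions):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))

    # targets[i] is the class index of samples[i]; this is what stratify needs.
    targets = [label for _, label in samples]
    return classes, class_to_idx, samples, targets


with tempfile.TemporaryDirectory() as root:
    # Build a tiny fake dataset: two cat images and one dog image (empty files).
    for rel in ["cat/001.png", "cat/002.png", "dog/001.png"]:
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()

    classes, class_to_idx, samples, targets = scan_image_folder(root)
    print(classes, targets)  # ['cat', 'dog'] [0, 0, 1]
```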

Next, we split the data with train_test_split(). Instead of passing the Dataset itself, we generate an index array [0, 1, 2, ..., num_data - 1] with list(range(len(dataset.targets))) and split that. By passing the class labels dataset.targets as the stratify argument, we split the index array into training and validation indices while keeping the class ratio of the original data. This works because index i and dataset.targets[i] refer to the same sample.

train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
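You can check that a split really preserved the class ratio by looking up the label of every selected index. The helper below, class_ratios (a hypothetical name), does that with the standard library. The toy targets list and the hand-picked val_indices are stand-ins for dataset.targets and the output of train_test_split(); with the real dataset you would call class_ratios(dataset.targets, train_indices) on the indices returned above.

```python
from collections import Counter


def class_ratios(targets, indices):
    """Fraction of each class label among the samples selected by `indices`."""
    counts = Counter(targets[i] for i in indices)
    return {label: counts[label] / len(indices) for label in sorted(counts)}


targets = [0] * 15 + [1] * 5             # stand-in for dataset.targets (75% / 25%)
all_indices = list(range(len(targets)))  # [0, 1, ..., 19]: index i <-> targets[i]

# A split of the kind a stratified splitter might return (hand-picked here).
val_indices = [3, 7, 11, 17]
train_indices = [i for i in all_indices if i not in val_indices]

print(class_ratios(targets, all_indices))    # {0: 0.75, 1: 0.25}
print(class_ratios(targets, train_indices))  # {0: 0.75, 1: 0.25}
print(class_ratios(targets, val_indices))    # {0: 0.75, 1: 0.25}
```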

What we have split so far is only the index array, so next we build the actual datasets from it. torch.utils.data.Subset, as the name implies, represents a subset of a dataset: given the original Dataset and an index array, it returns a Dataset that contains only the samples at those indices.

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
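Conceptually, Subset does not copy any data: it keeps a reference to the original dataset plus the index list, and maps position idx to dataset[indices[idx]]. The stand-in class below, MiniSubset (hypothetical, standard library only), imitates that mapping so you can see it without PyTorch installed; a plain list of (image, label) tuples plays the role of the original Dataset.

```python
class MiniSubset:
    """Stand-in showing the index mapping that torch.utils.data.Subset performs."""

    def __init__(self, dataset, indices):
        self.dataset = dataset        # reference to the original data, not a copy
        self.indices = list(indices)  # positions in the original dataset

    def __getitem__(self, idx):
        # Position idx in the subset is position indices[idx] in the original.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


base = [("img_a", 0), ("img_b", 0), ("img_c", 1), ("img_d", 1)]
sub = MiniSubset(base, [3, 0])
print(len(sub), sub[0], sub[1])  # 2 ('img_d', 1) ('img_a', 0)
```

Because a Subset only wraps the original dataset, both subsets share its transform, and a Subset has no targets attribute of its own. To get the labels of the training subset you can use [dataset.targets[i] for i in train_dataset.indices].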

After that, just pass the Dataset to the DataLoader as usual.

train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)  # no need to shuffle validation data
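As a rough mental model, a DataLoader with shuffle=True visits the indices of its dataset in a random order (reshuffled every epoch) and groups them into batches of batch_size; the last batch can be smaller unless drop_last=True is set. The generator below, iterate_batches (a hypothetical helper), sketches only that ordering logic with the standard library. The real DataLoader additionally collates each batch into tensors and can load samples in parallel worker processes (num_workers).

```python
import random


def iterate_batches(dataset, batch_size, shuffle, seed=None):
    """Yield lists of samples in the order a simple shuffling loader would (sketch)."""
    order = list(range(len(dataset)))
    if shuffle:
        # Randomize the visiting order; a real loader reshuffles every epoch.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [dataset[i] for i in order[start:start + batch_size]]


samples = list(range(10))  # stand-in for a Dataset of 10 samples
print(list(iterate_batches(samples, batch_size=4, shuffle=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

shuffled = list(iterate_batches(samples, batch_size=4, shuffle=True, seed=0))
print(sorted(x for batch in shuffled for x in batch) == samples)  # True
```

Whatever the order, every sample is visited exactly once per pass, which is why shuffling the validation data does not change the averaged validation metrics.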