The Japanese version of this article can be found on qiita.com.
What is Stratified Splitting?
In machine learning, you often split a data set into two parts: training data and validation data. For classification tasks in particular, it is preferable to split the data so that the proportion of each class in each part matches that of the original data set. This way of splitting data is called Stratified Splitting.
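To make the idea concrete, here is a minimal sketch in plain Python. The helper `stratified_split` and the toy labels are hypothetical names made up for this illustration, and this is not how scikit-learn implements it internally. It simply splits each class separately so that both parts keep the original class ratio.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, val_fraction, seed=0):
    """Return (train_indices, val_indices), splitting each class separately."""
    rng = random.Random(seed)

    # Group the sample indices by class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train_indices, val_indices = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Take the same fraction from every class for validation.
        n_val = round(len(indices) * val_fraction)
        val_indices.extend(indices[:n_val])
        train_indices.extend(indices[n_val:])
    return train_indices, val_indices


# Toy, imbalanced labels: 90% "cat" and 10% "dog".
labels = ["cat"] * 90 + ["dog"] * 10
train_indices, val_indices = stratified_split(labels, val_fraction=0.2)

train_counts = Counter(labels[i] for i in train_indices)
val_counts = Counter(labels[i] for i in val_indices)
print(train_counts["cat"], train_counts["dog"])  # 72 8
print(val_counts["cat"], val_counts["dog"])      # 18 2
```

Both parts keep the 9:1 ratio of the original labels. A purely random split of such an imbalanced data set could easily leave the validation data with very few "dog" samples, or none at all.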
Example in PyTorch
In scikit-learn, you can do a Stratified Split by passing the stratify option to the function sklearn.model_selection.train_test_split().
PyTorch, on the other hand, does not have such a function. So, we will use a function in scikit-learn to achieve a Stratified Split in PyTorch.
The example code looks like the following.
import torch
import torchvision  # needed for torchvision.datasets below
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

transformer = transforms.Compose([
    transforms.ToTensor(),
])

# load images
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

# split the data set into train and validation
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# create DataLoader
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)
Let me explain the details.
First, create a Dataset by importing images with ImageFolder.
transformer = transforms.Compose([
    transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)
Next, we split the data with train_test_split(). Since we can't pass the Dataset to it directly, we generate an index array for the Dataset, [0, 1, 2, ..., num_data - 1], with list(range(len(dataset.targets))) and pass that instead. Then, by passing the class labels dataset.targets as the stratify option, we can split the index array into training and validation indices while keeping the class label ratio of the original data.
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
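If you want to check that the ratio is really preserved, you can count the labels on each side of the split. The following is a small sketch that uses a hypothetical list of labels in place of dataset.targets, so it runs without any image folder. The random_state argument is an addition that is not in the code above. It only fixes the seed so that the split is reproducible.

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Hypothetical stand-in for dataset.targets: 80 samples of class 0, 20 of class 1.
targets = [0] * 80 + [1] * 20

train_indices, val_indices = train_test_split(
    list(range(len(targets))),
    test_size=0.2,
    stratify=targets,
    random_state=0,  # assumption: fixed seed for a reproducible split
)

# Look up the label of every index and count the labels per class.
train_counts = Counter(targets[i] for i in train_indices)
val_counts = Counter(targets[i] for i in val_indices)
print("train:", train_counts[0], train_counts[1])
print("val:  ", val_counts[0], val_counts[1])
```

With these labels, the training side gets 64 and 16 samples and the validation side gets 16 and 4, so both keep the original 4:1 ratio.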
So far we have only split the index array, so next we split the dataset itself based on those indices. Subset, as the name implies, is a class for creating subsets of data. By passing the original Dataset and an index array to it, we get a Dataset that contains only the samples at those indices.
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
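One thing to keep in mind is that a Subset does not have a targets attribute of its own. It keeps the original data set in its dataset attribute and the selected indices in its indices attribute. The following sketch uses a small TensorDataset as a stand-in for the ImageFolder data set. The helper label_counts and the toy data are hypothetical and exist only for this illustration.

```python
from collections import Counter

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, TensorDataset

# Toy stand-in for an image data set: 10 samples, 8 of class 0 and 2 of class 1.
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
targets = [0] * 8 + [1] * 2
dataset = TensorDataset(features, torch.tensor(targets))

train_indices, val_indices = train_test_split(
    list(range(len(targets))),
    test_size=0.5,
    stratify=targets,
    random_state=0,  # assumption: fixed seed for a reproducible split
)
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)


def label_counts(subset):
    """Count class labels by looking them up in the original data set."""
    return Counter(int(subset.dataset[i][1]) for i in subset.indices)


print(len(train_dataset), len(val_dataset))
print(label_counts(train_dataset), label_counts(val_dataset))
```

Each Subset here holds five samples, four of class 0 and one of class 1, which matches the 4:1 ratio of the toy data set.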
After that, just pass the Dataset to the DataLoader as usual.
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)