PyTorchでImageFolderを使わないやり方
前回は、ImageFolderを使用してデータローダーを作成しましたが、あえてImageFolderを使用しないやり方を考えてみます。
なぜこの記事を書くことにしたのかというと、ディープラーニングで試しにモデルを作って動かしてみる時に、ストックしてある大量の画像データから、種類や枚数を選んでデータローダーを作成したいと思うことがあったからです。
単純に画像データを新しいディレクトリにコピーしてImageFolderを使えばいいのですが、学習モデルごとにあちこちにデータをコピーするのは煩雑でスマートじゃないという、ただそれだけの理由です。
データセットの作成
まずは、お決まりのモジュールインポートから。
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
次に、取り込む画像データファイルのフルパスのリストを作成します。前回作成したMNISTの手書きデータのPNG画像を使用します。データディレクトリの構造は、以下のようになっています。
./MNIST_PNG
├── 0
│ ├── mnist_0_0.png
│ ├── mnist_0_1.png
│ ├── ・・・
│
├── 2
│ ├── mnist_2_0.png
│ ├── mnist_2_1.png
│ ├── ・・・
・・・
この中から、今回は1と7のイメージ画像を200枚づつ使用してデータローダーを作成していきます。
画像データファイルの名前は規則的に付与されているので、forループでフルパスのリストをリストに追加していけばよさそうです。
train_img_list = list()
for img_idx in range(200):
img_path = "./MNIST_PNG/1/mnist_1_" + str(img_idx) + '.png'
train_img_list.append(img_path)
img_path = "./MNIST_PNG/7/mnist_7_" + str(img_idx) + '.png'
train_img_list.append(img_path)
次に、データ変換クラスとデータセットを作成するクラスを定義します。
class ImageTransform():
def __init__(self, mean, std):
self.data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
def __call__(self, img):
return self.data_transform(img)
class Mnist_Img_Dataset(data.Dataset):
def __init__(self, file_list, transform):
self.file_list = file_list
self.transform = transform
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
img_path = self.file_list[index]
img = Image.open(img_path).convert('RGB')
img_label = img_path.split("/")[3]
img_label = img_label.split("_")[1]
img_transformed = self.transform(img)
return img_transformed, img_label, img_path
ImageTransformクラスは、画像データの変換を行うクラスで、前回とまったく同じです。今回は、Mnist_Img_Datasetを追加しています。このクラスで自作のデータセットを作っていきます。
Mnist_Img_Datasetは、torch.utils.data.Datasetのスーパークラスを継承させます。また、__len__と__getitem__メソッドも必須となります。
簡単に説明すると__len__は入力された画像ファイルのリストの長さを取得し、__getitem__で画像のファイルパスからイメージ画像を取り込み、ピクセルの情報とファイル名からラベルをデータセットに出力します。
上記の場合、データセットには画像変換されたピクセル情報、ラベル及びデータファイルパスを出力させています。通常であれば画像データとラベルだけあれば問題ありません。
データローダの作成
では、自前で作成したデータセットを使用してデータローダを作成してみます。
#自前データセットを作成する
mean = (0.5,)
std = (0.5,)
org_dataset = Mnist_Img_Dataset(file_list = train_img_list, transform = ImageTransform(mean, std))
自前データセットを作成するMnist_Img_Datasetクラスには、画像データパスのリストと画像変換のクラスを与えます。
#1バッチに含む画像の枚数を指定する
batch_size = 64
#データローダーを作成する
train_dataloader = data.DataLoader(org_dataset, batch_size = batch_size, shuffle = True)
正しくできているのか確認するため、バッチから画像データを取り出して可視化します。
#1バッチ分取り出してsizeを確認する
imgs, labels, img_path = iter(train_dataloader).next()
#images = next(batch_iterator)
#バッチから取り出した画像の大きさを確認する
print("image shape ==>",imgs[0].shape)
#バッチから取り出した画像のイメージとラベルを表示する
pic = transforms.ToPILImage(mode='RGB')(imgs[0])
plt.imshow(pic)
print("Label is ",labels[0])
image shape ==> torch.Size([3, 64, 64])
Label is 7
正しくできているみたいですね。
画像は、見た目モノクロですが、データ自体はカラー(RGB)画像です。モノクロで取り込む場合は、__getitem__の中でPILを使用して画像を開いたところで変更することができます。
#モノクロで取り込む場合
img = Image.open(img_path).convert('L')
#カラーで取り込む場合
img = Image.open(img_path).convert('RGB')
ImageFolderを使わないやり方について解説してきましたが、結構簡単にできた方だとおもいます。ImageFolderと使い分けることで、学習用の画像データを散らかさずに開発作業ができそうですね。ぜひ、試してみてください。