PyTorchでImageFolderを使わないやり方

前回は、ImageFolderを使用してデータローダーを作成しましたが、あえてImageFolderを使用しないやり方を考えてみます。

なぜこの記事を書くことにしたのかというと、ディープラーニングで試しにモデルを作って動かしてみる時に、ストックしてある大量の画像データから、種類や枚数を選んでデータローダーを作成したいと思うことがあったからです。

単純に画像データを新しいディレクトリにコピーしてImageFolderを使えばいいのですが、学習モデルごとにあちこちにデータをコピーするのは煩雑でスマートじゃないという、ただそれだけの理由です。

データセットの作成

まずは、お決まりのモジュールインポートから。

import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data

from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

次に、取り込む画像データファイルのフルパスのリストを作成します。前回作成したMNISTの手書きデータのPNG画像を使用します。データディレクトリの構造は、以下のようになっています。

./MNIST_PNG
├── 0
│ ├── mnist_0_0.png
│ ├── mnist_0_1.png
│ ├── ・・・
│
├── 2
│ ├── mnist_2_0.png
│ ├── mnist_2_1.png
│ ├── ・・・
・・・

この中から、今回は1と7のイメージ画像を200枚づつ使用してデータローダーを作成していきます。

画像データファイルの名前は規則的に付与されているので、forループでフルパスのリストをリストに追加していけばよさそうです。

train_img_list = list()

for img_idx in range(200):
  img_path = "./MNIST_PNG/1/mnist_1_" + str(img_idx) + '.png'
  train_img_list.append(img_path)

  img_path = "./MNIST_PNG/7/mnist_7_" + str(img_idx) + '.png'
  train_img_list.append(img_path)

次に、データ変換クラスとデータセットを作成するクラスを定義します。

class ImageTransform():
  def __init__(self, mean, std):
    self.data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
    ])

  def __call__(self, img):
    return self.data_transform(img)

class Mnist_Img_Dataset(data.Dataset):
  def __init__(self, file_list, transform):
    self.file_list = file_list
    self.transform = transform

  def __len__(self):
    return len(self.file_list)

  def __getitem__(self, index):
    img_path = self.file_list[index]
    img = Image.open(img_path).convert('RGB')
    img_label = img_path.split("/")[3]
    img_label = img_label.split("_")[1]
    img_transformed = self.transform(img)
    return img_transformed, img_label, img_path

ImageTransformクラスは、画像データの変換を行うクラスで、前回とまったく同じです。今回は、Mnist_Img_Datasetを追加しています。このクラスで自作のデータセットを作っていきます。

Mnist_Img_Datasetは、torch.utils.data.Datasetのスーパークラスを継承させます。また、__len__と__getitem__メソッドも必須となります。

簡単に説明すると__len__は入力された画像ファイルのリストの長さを取得し、__getitem__で画像のファイルパスからイメージ画像を取り込み、ピクセルの情報とファイル名からラベルをデータセットに出力します。

上記の場合、データセットには画像変換されたピクセル情報、ラベル及びデータファイルパスを出力させています。通常であれば画像データとラベルだけあれば問題ありません。

データローダの作成

では、自前で作成したデータセットを使用してデータローダを作成してみます。

#自前データセットを作成する
mean = (0.5,)
std = (0.5,)
org_dataset = Mnist_Img_Dataset(file_list = train_img_list, transform = ImageTransform(mean, std))

自前データセットを作成するMnist_Img_Datasetクラスには、画像データパスのリストと画像変換のクラスを与えます。

#1バッチに含む画像の枚数を指定する
batch_size = 64

#データローダーを作成する
train_dataloader = data.DataLoader(org_dataset, batch_size = batch_size, shuffle = True)

正しくできているのか確認するため、バッチから画像データを取り出して可視化します。

#1バッチ分取り出してsizeを確認する
imgs, labels, img_path = iter(train_dataloader).next()
#images = next(batch_iterator)

#バッチから取り出した画像の大きさを確認する
print("image shape ==>",imgs[0].shape)

#バッチから取り出した画像のイメージとラベルを表示する
pic = transforms.ToPILImage(mode='RGB')(imgs[0])
plt.imshow(pic)
print("Label is ",labels[0])

image shape ==> torch.Size([3, 64, 64])
Label is 7

正しくできているみたいですね。

画像は、見た目モノクロですが、データ自体はカラー(RGB)画像です。モノクロで取り込む場合は、__getitem__の中でPILを使用して画像を開いたところで変更することができます。

#モノクロで取り込む場合
img = Image.open(img_path).convert('L')

#カラーで取り込む場合
img = Image.open(img_path).convert('RGB')

ImageFolderを使わないやり方について解説してきましたが、結構簡単にできた方だとおもいます。ImageFolderと使い分けることで、学習用の画像データを散らかさずに開発作業ができそうですね。ぜひ、試してみてください。