今回はTransformerを画像処理に適用した、Vision Transformer (ViT) を用いて画像分類タスクを行いたいと思います。
コード付きで分かりやすく解説していきます!
Transformerの復習
自然言語処理の分野では、様々な手法が提案されてきました。
昔から知られているRNN→Attention→Transformerという形で発展しています。
Attentionは、入力されたデータの領域ごとに重要度(どこに注目するのか)を判断する仕組みです。
以下の記事で詳しく解説しています。
Attentionの仕組みを、文章のような前後関係のあるデータに適用したのが、Transformerです。
これは自然言語処理の性能を劇的に引き上げました。
今回利用するVision Transformer (ViT)は、その仕組みを画像処理に応用したものです。
やり方は簡単で、入力画像をバラバラにして横一列に並び替え、文章のように入力するものです。

今回は、このVision Transformerを画像分類タスクで簡単に利用する方法について紹介します!
環境準備
今回は動作環境として、Google Colaboratoryを利用します。
通称Colabは、誰でも手軽に機械学習環境が利用でき、さらにGPUも無料で利用することができます。
今回のコードは全てColab上で公開しています。
参考 Vision Transformer.ipynbcolab.research.google.comColabのノートブックが起動した時点で、利用するライブラリはほとんどインストールされています。
今回は、PyTorchが提供する学習済みモデルを利用するために、timm
を追加でインストールします。
最初のセルで以下のコマンドを実行しましょう(ビックリマークも忘れずに)。
!pip install timm
ここまで来れば、環境設定は終了です。
学習済みモデルのダウンロード
早速、timmを利用して学習済みモデルをダウンロードしましょう。
初めてなので、一番基礎的なものを利用します。
ちなみに提供されているモデルは、ImageNetを用いて事前学習されています。
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)
print(model.eval())
セルに上のように入力します。
実行すれば、ダウンロードしたモデルのネットワーク構造を出力されます。
画像ダウンロードと読み込み
今回は、分類対象の画像として、チワワの画像を用意しました。
これは、ImageNetは犬の画像が多く含まれており、比較的高精度に分類できるからです。

ネットから画像をダウンロードして、プログラムに読み込みます。
読み込んだ画像は、PyTorchで扱うために、テンソル形式に変換しています。
import urllib
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
url, filename = ("https://dendenblog.xyz/wp-content/uploads/2021/08/tamara-bellis-UI7xouE1dpw-unsplash-1.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)
img = Image.open(filename).convert('RGB')
tensor = transform(img).unsqueeze(0)
推論
画像が読み込めたら、モデルを用いて推論を行います。model
にそのまま渡せば、確率分布が出力されます。
import torch
with torch.no_grad():
out = model(tensor)
probabilities = torch.nn.functional.softmax(out[0], dim=0)
print(probabilities.shape)
推論結果の表示
人間にとって、確率分布のままでは何が何だか分かりません。
そこで、それぞれに対応したクラス名で表示させます。
まずは、ImageNetのクラスラベルの一覧をダウンロードし、読み込ませます。
url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", "imagenet_classes.txt")
urllib.request.urlretrieve(url, filename)
with open("imagenet_classes.txt", "r") as f:
categories = [s.strip() for s in f.readlines()]
推論結果で、スコアの高い上位5件のクラスラベルとスコアを表示します。
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
print(categories[top5_catid[i]], top5_prob[i].item())
以下のように結果が出力されれば、成功です!
Chihuahua 0.8976315855979919
toy terrier 0.08969379961490631
miniature pinscher 0.0019271321361884475
Ibizan hound 0.001282992190681398
basenji 0.001141041168011725
一番上が、Chihuahua(チワワ)になっているので、分類成功です!
スコアも89.7%なので、かなり十分な精度で分類できていることがわかります。
まとめ
今回は、事前学習モデルを利用して、Transformerを画像分類に応用してみました。
ダウンロードして、実行しただけですが、利用シーンやその手軽さが認識できたと思います。
今回の記事が好評なら、学習済みモデルをファインチューニングする方法などについて解説したいと思います!