【機械学習】Transformerで画像分類をする!

プログラミング

今回はTransformerを画像処理に適用した、Vision Transformer (ViT) を用いて画像分類タスクを行いたいと思います。
コード付きで分かりやすく解説していきます!

Transformerの復習

自然言語処理の分野では、様々な手法が提案されてきました。
昔から知られているRNN→Attention→Transformerという形で発展しています。

Attentionは、入力されたデータの領域ごとに重要度(どこに注目するのか)を判断する仕組みです。
以下の記事で詳しく解説しています。

Attentionの仕組みを、文章のような前後関係のあるデータに適用したのが、Transformerです。
これは自然言語処理の性能を劇的に引き上げました。

今回利用するVision Transformer (ViT)は、その仕組みを画像処理に応用したものです。
やり方は簡単で、入力画像をバラバラにして横一列に並び替え、文章のように入力するものです。

今回は、このVision Transformerを画像分類タスクで簡単に利用する方法について紹介します!

環境準備

今回は動作環境として、Google Colaboratoryを利用します。
通称Colabは、誰でも手軽に機械学習環境が利用でき、さらにGPUも無料で利用することができます。

今回のコードは全てColab上で公開しています。

Google Colaboratory

Colabのノートブックが起動した時点で、利用するライブラリはほとんどインストールされています。

今回は、PyTorchが提供する学習済みモデルを利用するために、timmを追加でインストールします。
最初のセルで以下のコマンドを実行しましょう(ビックリマークも忘れずに)。

!pip install timm

ここまで来れば、環境設定は終了です。

学習済みモデルのダウンロード

早速、timmを利用して学習済みモデルをダウンロードしましょう。
初めてなので、一番基礎的なものを利用します。

ちなみに提供されているモデルは、ImageNetを用いて事前学習されています。

import timm

model = timm.create_model('vit_base_patch16_224', pretrained=True)
print(model.eval())

セルに上のように入力します。
実行すれば、ダウンロードしたモデルのネットワーク構造を出力されます。

画像ダウンロードと読み込み

今回は、分類対象の画像として、チワワの画像を用意しました。
これは、ImageNetは犬の画像が多く含まれており、比較的高精度に分類できるからです。

ネットから画像をダウンロードして、プログラムに読み込みます。
読み込んだ画像は、PyTorchで扱うために、テンソル形式に変換しています。

import urllib
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

config = resolve_data_config({}, model=model)
transform = create_transform(**config)

url, filename = ("https://dendenblog.xyz/wp-content/uploads/2021/08/tamara-bellis-UI7xouE1dpw-unsplash-1.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)
img = Image.open(filename).convert('RGB')
tensor = transform(img).unsqueeze(0)

推論

画像が読み込めたら、モデルを用いて推論を行います。
modelにそのまま渡せば、確率分布が出力されます。

import torch

with torch.no_grad():
    out = model(tensor)
probabilities = torch.nn.functional.softmax(out[0], dim=0)
print(probabilities.shape)

推論結果の表示

人間にとって、確率分布のままでは何が何だか分かりません。
そこで、それぞれに対応したクラス名で表示させます。

まずは、ImageNetのクラスラベルの一覧をダウンロードし、読み込ませます。

url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", "imagenet_classes.txt")
urllib.request.urlretrieve(url, filename) 
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

推論結果で、スコアの高い上位5件のクラスラベルとスコアを表示します。

top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

以下のように結果が出力されれば、成功です!

Chihuahua 0.8976315855979919
toy terrier 0.08969379961490631
miniature pinscher 0.0019271321361884475
Ibizan hound 0.001282992190681398
basenji 0.001141041168011725

一番上が、Chihuahua(チワワ)になっているので、分類成功です!
スコアも89.7%なので、かなり十分な精度で分類できていることがわかります。

まとめ

今回は、事前学習モデルを利用して、Transformerを画像分類に応用してみました。
ダウンロードして、実行しただけですが、利用シーンやその手軽さが認識できたと思います。

今回の記事が好評なら、学習済みモデルをファインチューニングする方法などについて解説したいと思います!

コメント

タイトルとURLをコピーしました