そぬばこ

備忘録とか、多分そんな感じ。

gokart + PyTorch Lightning でいい感じに深層学習モデルを動かす

この記事は Sansan Advent Calendar 2021 20日の記事です。

前日は、id:kur0cky さんの

kur0cky.hatenablog.com

でした。

私は過去陸上部だった時代があるのですが、個人的にはフォームも気にしたほうがいいと思います(感想)。

この記事は何か

私が所属している研究開発部には、パイプラインに則ってコードを書こうという文化が浸透してきました。 これらのパイプラインのパッケージは、弊社の研究員が弊社ブログにて既に様々書いているので、ぜひこちらを御覧ください。

buildersbox.corp-sansan.com

buildersbox.corp-sansan.com

さらに、PyTorch のラッパである PyTorch Lightning はいいぞとの布教をとある研究員の方から受け、良さそうだなとなりました。 PyTorch Lightning については、とある研究員の方がいつだかに書いた記事もぜひご覧ください。

buildersbox.corp-sansan.com

こういったパイプラインやラッパは処理の切り分けが明確になることで、コードの可読性を向上させたり等様々なメリットがあります。 せっかくなので私もこのびっぐうぇーぶに乗って、パイプラインである gokart と PyTorch Lightning を組み合わせてサクッとコードを書いてみようと思います。

準備

今回は gokart を使うので、雑に cookiecutter からテンプレートを持ってきて使います。 エムスリーさんが用意しているものがあるので、ありがたく使っていきます。

github.com

次に、 PyTorch を poetry で入れていきましょう。

ところで、PyTorch と poetry はボチボチ相性が悪いです。 PyTorch はアーキテクチャ等の環境に沿ったものを入れないといけませんが、例えば、特定の source から参照する方法で入れることを試みることができます。

github.com

一方、↑の Issue にのように source を指定すると、他のライブラリまで指定した source を参照しようとしていまい、うまく動きません。

[tool.poetry.dependencies]
python = ">=3.9,<3.10"
gokart = "*"
torch = { version = "=1.9.0+cpu", source = "pytorch" }
torchvision = { version = "=0.10.0+cpu", source = "pytorch" }

[[tool.poetry.source]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cpu/"
secondary = true

上記は、そのうまく動かない例です。 gokart をこの PyTorch で指定している url から取ってこようとして 403 が返ってきます。

これはそもそも PyTorch のインストールが PEP 503 に対応したことで、 pipenv の設定がシンプルになるというものがあり、これ poetry でも出来るやんけと思ったら出来なかったという不具合です。

github.com

Downgrading to Poetry 1.0.10 might be a workaround (ontop of my previous comment) as per: python-poetry/poetry#4704 (comment)

Haven't tested because it's too much of a pain, switching to pip!

github.com

しんどいですね。

今回は、最悪ですが whl の url を直接見に行く*1ことで一旦の回避策とします。 こちらですが、 今現在最新の poetry 1.1.12 では、 torchvision が torch の "x.x.x+cpu" に依存しているのにも関わらず 、 '+' 以降がうまく解釈できずに "x.x.x" に依存してるからうまくいかないよと怒られます。 この記事では、さらに苦肉の策として対応されている poetry 1.2.0a2 のプレビュー版を使っています。 誰か私を楽にしてください。

[tool.poetry.dependencies]
python = ">=3.9,<3.10"
gokart = "*"
torch = { url = "https://download.pytorch.org/whl/cpu/torch-1.10.1%2Bcpu-cp39-cp39-linux_x86_64.whl" }
torchvision = { url = "https://download.pytorch.org/whl/cpu/torchvision-0.11.2%2Bcpu-cp39-cp39-linux_x86_64.whl" }

Python のバージョンに依存するので ">=3.9,<3.10" にしてます。

ここまで、書いたら poetry install です。やっと PyTorch が入りましたね。 PyTorch Lightning は poetry が torch と torchvision のバージョンに合わせて依存解決出来るので問題ありません。

書いた

準備が出来たのでサクッとコードを書きました。 特に PyTorch Lightning は初めて書いたので、もっといい感じの書き方があればぜひ教えて下さい。 今回は ResNet で CIFAR-10 データで学習評価まで行うものにしました*2

gokart のタスクは以下の4つに分けてみました。 それぞれサラッと見ていきましょう。

データ前処理

gokart

import gokart
import luigi


class GokartTask(gokart.TaskOnKart):
    task_namespace = 'sansan_adcal_2021'


class PreprocessDataModuleTask(GokartTask):
    _v: int = luigi.IntParameter(default=0)

    def run(self):
        data_module = DataModule()
        data_module.prepare_data()

        self.dump(data_module)

PyTorch Lightning の LightningDataModule

from pathlib import Path

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class DataModule(pl.LightningDataModule):
    def __init__(self, dataset_root_path: Optional[Path] = None, train_size: float = 0.8, seed: int = 1111):
        super().__init__()
        self.dataset_root_path = dataset_root_path
        self.train_size = train_size
        self.seed = seed
        if self.dataset_root_path is None:
            # gokart の中間ファイルの出力に合わせて resources 以下に入れることにする
            self.dataset_root_path = Path(__file__).resolve().parents[2].joinpath('resources', 'dataset')
        self.data_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

    def prepare_data(self) -> None:
        datasets.CIFAR10(
            root=self.dataset_root_path,
            download=True
        )

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "fit":
            all_train_dataset = datasets.CIFAR10(
                root=self.dataset_root_path,
                train=True,
                transform=self.data_transforms
            )
            len_train_dataset = int(self.train_size * len(all_train_dataset))
            len_val_dataset = len(all_train_dataset) - len_train_dataset
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                all_train_dataset,
                [len_train_dataset, len_val_dataset],
                generator=torch.Generator().manual_seed(self.seed)
            )
        elif stage == "test":
            self.test_dataset = datasets.CIFAR10(
                root=self.dataset_root_path,
                train=False,
                transform=self.data_transforms
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=256, num_workers=8)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=256, num_workers=8)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset, batch_size=256, num_workers=8)

定義した DataModule に基本的にデータセットとしての裁量を託しています。 一度だけ呼ぶ prepare_data() メソッドは、このタスク内で呼ぶことにしました。 今回特に定義をしていませんが、前処理に関するパラメータや train, val のデータセットの分割のシード値等を、gokart 側で受け取って DataModule に渡せると良さそうです。

モデル準備

gokart

class PrepareModelTask(GokartTask):
    _v: int = luigi.IntParameter(default=0)

    def run(self):
        model_module = ModelModule()

        self.dump(model_module)

PyTorch Lightning の LightningModule

import torchmetrics
from torchvision.models import resnet34


class ModelModule(pl.LightningModule):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

        self.model = resnet34(pretrained=True)
        self.model.fc = torch.nn.Linear(512, 10)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

    def configure_optimizers(self):
        optimzier = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimzier

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def training_step(self, batch, *args, **kwargs) -> torch.Tensor:
        x, y = batch
        pred_y = self.forward(x)

        loss = self.criterion(pred_y, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, *args, **kwargs) -> torch.Tensor:
        x, y = batch
        pred_y = self.forward(x)

        loss = self.criterion(pred_y, y)
        self.val_acc(pred_y, y)
        self.log("val_loss", loss)
        self.log("val_acc", self.val_acc, on_step=True, on_epoch=True)
        return loss

    def test_step(self, batch, *args, **kwargs) -> torch.Tensor:
        x, y = batch
        pred_y = self.forward(x)

        loss = self.criterion(pred_y, y)
        self.test_acc(pred_y, y)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc, on_step=True, on_epoch=True)

モデルの構築・準備も同様に gokart はタスクの分割と、パラメータの受け渡し口としての使い方が良さそうに感じています。 (データセット同様、今回は何もパラメータを渡してないですが)

モデル訓練

class TrainTask(GokartTask):
    _v: int = luigi.IntParameter(default=1)

    def requires(self):
        return {
            "dataset": PreprocessDataModuleTask(),
            "model": PrepareModelTask()
        }

    def run(self):
        data_module: pl.LightningDataModule = self.load("dataset")
        model_module: pl.LightningModule = self.load("model")

        trainer = pl.Trainer(
            max_epochs=10,
            min_epochs=1
        )

        data_module.setup('fit')
        trainer.fit(model_module, datamodule=data_module)

        self.dump(trainer)

データセットのタスクとモデルのタスクを依存させるようにして、それぞれの LightningModule を渡しています。 各 epoch の checkpoints は PyTorch Lightning 側で持つので、ここを gokart に持たせる必要はないと判断しました。

モデル評価

class EvaluateTask(GokartTask):
    _v: int = luigi.IntParameter(default=0)

    def requires(self):
        return {
            "dataset": PreprocessDataModuleTask(),
            "model": PrepareModelTask(),
            "trainer": TrainTask(),
        }

    def run(self):
        data_module: pl.LightningDataModule = self.load("dataset")
        model_module: pl.LightningModule = self.load("model")
        trainer: pl.Trainer = self.load("trainer")

        data_module.setup("test")
        result = trainer.test(model_module, datamodule=data_module)

        self.dump(result)

Trainer を持ってこさせるようにしています。 評価の処理そのものは LightningModule 側で持っているので、結果を dump() させておいて、後で参照しやすいようにだけしておきました。

まとめと所感

今回は gokart と PyTorch Lightning を組み合わせて、深層学習モデルを動かすサンプルを書いてみました。 深層学習モデル部分の裁量を PyTorch Lightning に持たせ、処理を gokart のタスクで切ることで全体的に処理フローが見えやすい形になったかと思います。 学習のコードがややわかりづらくなりがちな PyTorch のコードは、 PyTorch Lightning で書くことで嫌でも処理がわかりやすくなって良いですね。 また、アドカレとしての締切の時間の都合上、直書きしてしまったパラメータが多いのですが、こういうパラメータはタスクごとに gokart で受け渡せるようにしておくともっと良いかと思います。

コードの全体像は後日まとめて GitHub 上に公開しようと思います。 文章を書いた人間としては、なんだか Poetry 上での torch + torchvision インストールバトルが本題になってしまったようで複雑な気持ちもありますが、許してください。

*1:ここ (https://download.pytorch.org/whl/torch_stable.html) にあります

*2:時間がかかるので、今回は epoch 数等適当に少なめで動作確認だけしました。