LLMを学習させるためのツールはたくさん存在しますが、本格的な学習には分散トレーニングが必要になります。今回は、継之助を使った分散トレーニングのために便利なAccelerateの簡単な使い方を紹介したいと思います。
TRLは、元々はその名の通り強化学習(RL;Reinforcement Learning)のためのツールでしたが、今は教師あり学習(SFT; Supervised Fine Tuning)やDPO(Direct Perfomance Optimization)といった手法にも対応している万能ツールです。
TRLはAccelerateと組み合わせやすいので例として紹介します。
まずはインストールしましょう
$ pip install -U transformers trl accelerate
TRLもAccelerateもtransformersのバージョンに依存するので基本的に全て最新のバージョンを使うのがいいと思います。
例えば、TRL単体でFacebook(Meta)の350m(3億5000万パラメータ)モデルを日本語版Wikipediaで学習させるコードは以下のようになります。
from datasets import load_dataset
from trl import SFTTrainer
#データセットを読み込む
dataset = load_dataset("izumi-lab/wikipedia-ja-20230720", split="train")
# SFTTrainerを設定する
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
))
#学習を実行する
trainer.train()
実行するには通常のPythonコマンドを使います
$ python train.py
これを実行すると確かに動作するのですが、これだと単一のGPUで動作するのみで、8つのA100 80GB GPUを搭載した継之助にとっては宝の持ち腐れになってしまいます。
そこで、Accelerateの出番です。
Accelerateで以下のように書き直します。
from accelerate import Accelerator #追加
accelerator = Accelerator() #追加
from datasets import load_dataset
from trl import SFTTrainer
dataset = load_dataset("izumi-lab/wikipedia-ja-20230720", split="train")
trainer = accelerator.prepare(SFTTrainer( #変更
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
))
trainer.train()
Accelerateの便利なところは、SFTTrainerをaccelerator.prepareで包んであげれば、自動的に各種設定をAccelerate用に変換してくれるところです。
これを実行するにはpythonコマンドではなくaccelerate コマンドを使います。
$ accelerate launch train.py
これだけでちゃんと分散トレーニングが始まります。
$ nvidia-smi
Fri Apr 19 07:12:01 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100 80GB PCIe On | 00000000:17:00.0 Off | 0 |
| N/A 50C P0 291W / 300W | 20585MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA A100 80GB PCIe On | 00000000:2A:00.0 Off | 0 |
| N/A 55C P0 303W / 300W | 20595MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA A100 80GB PCIe On | 00000000:3D:00.0 Off | 0 |
| N/A 56C P0 295W / 300W | 20585MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA A100 80GB PCIe On | 00000000:63:00.0 Off | 0 |
| N/A 52C P0 303W / 300W | 20597MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA A100 80GB PCIe On | 00000000:AB:00.0 Off | 0 |
| N/A 61C P0 156W / 300W | 20585MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA A100 80GB PCIe On | 00000000:BD:00.0 Off | 0 |
| N/A 59C P0 307W / 300W | 20585MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA A100 80GB PCIe On | 00000000:CF:00.0 Off | 0 |
| N/A 59C P0 295W / 300W | 20585MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA A100 80GB PCIe On | 00000000:E1:00.0 Off | 0 |
| N/A 61C P0 305W / 300W | 20585MiB / 81920MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 4089972 C ...nda3-2023.09-0/envs/c310/bin/python 20572MiB |
| 1 N/A N/A 4089973 C ...nda3-2023.09-0/envs/c310/bin/python 20582MiB |
| 2 N/A N/A 4089974 C ...nda3-2023.09-0/envs/c310/bin/python 20572MiB |
| 3 N/A N/A 4089975 C ...nda3-2023.09-0/envs/c310/bin/python 20584MiB |
| 4 N/A N/A 4089976 C ...nda3-2023.09-0/envs/c310/bin/python 20572MiB |
| 5 N/A N/A 4089977 C ...nda3-2023.09-0/envs/c310/bin/python 20572MiB |
| 6 N/A N/A 4089978 C ...nda3-2023.09-0/envs/c310/bin/python 20572MiB |
| 7 N/A N/A 4089979 C ...nda3-2023.09-0/envs/c310/bin/python 20572MiB |
+---------------------------------------------------------------------------------------+
学習が終わるとwandbにログインしていればwandbに結果が表示されます。
{'loss': 3.3256, 'grad_norm': 1.9639703035354614, 'learning_rate': 2.8687127024722932e-05, '
epoch': 1.28}
{'loss': 3.0874, 'grad_norm': 1.8486438989639282, 'learning_rate': 7.374254049445865e-06, 'e
poch': 2.56}
{'train_runtime': 827.3497, 'train_samples_per_second': 90.651, 'train_steps_per_second': 1.
418, 'train_loss': 3.1781245212928924, 'epoch': 3.0}
100%|████████████████████████████████████████████████████████████████████████████████████| 1
173/1173 [13:41<00:00, 1.43it/s]
wandb: | 0.070 MB of 0.070 MB uploaded
wandb: Run history:
wandb: train/epoch ▁▆█
wandb: train/global_step ▁▆█
wandb: train/grad_norm █▁
wandb: train/learning_rate █▁
wandb: train/loss █▁
wandb:
wandb: Run summary:
wandb: total_flos 6.523706571279565e+16
wandb: train/epoch 3.0
wandb: train/global_step 1173
wandb: train/grad_norm 1.84864
wandb: train/learning_rate 1e-05
wandb: train/loss 3.0874
wandb: train_loss 3.17812
wandb: train_runtime 827.3497
wandb: train_samples_per_second 90.651
wandb: train_steps_per_second 1.418
wandb:
wandbにログインしておくと、学習経過をブラウザからグラフで確認できて便利です。
これは出先のiPhoneやiPadからも見ることができるので、いつも「ちゃんと学習上手くいってるかな」とドキドキしながら確認するのが結構楽しい作業になります。
8つのGPUの消費メモリや温度などもモニタリングできます。
wandbは便利なのですが、デバッグ中は鬱陶しいので、SFTTrainerのオプションに「report_to="none"」を指定すると呼び出されなくなります。
from transformers import TrainingArguments
args = TrainingArguments(
output_dir='./output',
num_train_epochs=4,
max_grad_norm=1.0,
learning_rate=1e-4,
gradient_accumulation_steps=64,
per_device_train_batch_size=2,
save_strategy="steps",
logging_steps=20,
lr_scheduler_type="constant",
save_total_limit=1,
fp16=False,
bf16=True,
report_to="none", #これを "wandb"にするとwandbに報告するようになる
)
trainer = accelerator.prepare(SFTTrainer( #変更
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
))
trainer.train()
また、まだTRLでサポートされない新しいアーキテクチャの動作がよくわからない時は、こういうツールを使わずに自分で分散したい時もあると思います。
そういう時もAccelerateを使うとこんなふうに書くことができます。
from torch.utils.data import DataLoader,Dataset
import copy
class CompletionDataset(Dataset): #コンプリーション用のデータセット
def init(self, data):
self.data = data
def len(self):
return len(self.data)
def getitem(self, idx):
data=copy.deepcopy(self.data[idx])
label=torch.LongTensor(data["input_ids"][1:]).to(device)
data["input_ids"]=torch.LongTensor(data["input_ids"][:-1]).to(device)
data["attention_mask"]=torch.LongTensor(data["attention_mask"][:-1]).to(device)
return data,label
dataset = CompletionDataset(tokenized_dataset)
import numpy as np
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99, eps=1e-08)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
train_dataloader = DataLoader(dataset, batch_size=16)
criterion = torch.nn.CrossEntropyLoss()
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
num_epochs=4
print(model)
for epoch in range(num_epochs):
model.train()
for batch,labels in train_dataloader:
optimizer.zero_grad()
outputs = model(**batch,labels=labels)
loss = outputs.loss
print(loss)
loss.backward()
optimizer.step()
途中経過を確認しながらどこに問題があるのか切り分けできるので重宝します。
デバッグはあえて分散させずに行い、良きところで分散して学習できるように切り替えやすいのもAccelerateのメリットと言えるでしょう。