top of page
スクリーンショット 2024-01-25 15.42.00.png
執筆者の写真Ryo Shimizu

簡単にできるAccelerateによるLLMの分散トレーニング

LLMを学習させるためのツールはたくさん存在しますが、本格的な学習には分散トレーニングが必要になります。今回は、継之助を使った分散トレーニングのために便利なAccelerateの簡単な使い方を紹介したいと思います。


TRLは、元々はその名の通り強化学習(RL;Reinforcement Learning)のためのツールでしたが、今は教師あり学習(SFT; Supervised Fine Tuning)やDPO(Direct Perfomance Optimization)といった手法にも対応している万能ツールです。


TRLはAccelerateと組み合わせやすいので例として紹介します。


まずはインストールしましょう


$ pip install -U transformers trl accelerate

TRLもAccelerateもtransformersのバージョンに依存するので基本的に全て最新のバージョンを使うのがいいと思います。


例えば、TRL単体でFacebook(Meta)の350m(3億5000万パラメータ)モデルを日本語版Wikipediaで学習させるコードは以下のようになります。


from datasets import load_dataset
from trl import SFTTrainer

#データセットを読み込む
dataset = load_dataset("izumi-lab/wikipedia-ja-20230720", split="train")

# SFTTrainerを設定する
trainer =  SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
))

#学習を実行する
trainer.train()

実行するには通常のPythonコマンドを使います


$ python train.py

これを実行すると確かに動作するのですが、これだと単一のGPUで動作するのみで、8つのA100 80GB GPUを搭載した継之助にとっては宝の持ち腐れになってしまいます。


そこで、Accelerateの出番です。

Accelerateで以下のように書き直します。


from accelerate import Accelerator #追加
accelerator = Accelerator() #追加

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("izumi-lab/wikipedia-ja-20230720", split="train")

trainer =  accelerator.prepare(SFTTrainer( #変更
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
))
trainer.train()


Accelerateの便利なところは、SFTTrainerをaccelerator.prepareで包んであげれば、自動的に各種設定をAccelerate用に変換してくれるところです。


これを実行するにはpythonコマンドではなくaccelerate コマンドを使います。


$ accelerate launch train.py

これだけでちゃんと分散トレーニングが始まります。


$ nvidia-smi
Fri Apr 19 07:12:01 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100 80GB PCIe          On  | 00000000:17:00.0 Off |                    0 |
| N/A   50C    P0             291W / 300W |  20585MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          On  | 00000000:2A:00.0 Off |                    0 |
| N/A   55C    P0             303W / 300W |  20595MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          On  | 00000000:3D:00.0 Off |                    0 |
| N/A   56C    P0             295W / 300W |  20585MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100 80GB PCIe          On  | 00000000:63:00.0 Off |                    0 |
| N/A   52C    P0             303W / 300W |  20597MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A100 80GB PCIe          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   61C    P0             156W / 300W |  20585MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A100 80GB PCIe          On  | 00000000:BD:00.0 Off |                    0 |
| N/A   59C    P0             307W / 300W |  20585MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A100 80GB PCIe          On  | 00000000:CF:00.0 Off |                    0 |
| N/A   59C    P0             295W / 300W |  20585MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA A100 80GB PCIe          On  | 00000000:E1:00.0 Off |                    0 |
| N/A   61C    P0             305W / 300W |  20585MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   4089972      C   ...nda3-2023.09-0/envs/c310/bin/python    20572MiB |
|    1   N/A  N/A   4089973      C   ...nda3-2023.09-0/envs/c310/bin/python    20582MiB |
|    2   N/A  N/A   4089974      C   ...nda3-2023.09-0/envs/c310/bin/python    20572MiB |
|    3   N/A  N/A   4089975      C   ...nda3-2023.09-0/envs/c310/bin/python    20584MiB |
|    4   N/A  N/A   4089976      C   ...nda3-2023.09-0/envs/c310/bin/python    20572MiB |
|    5   N/A  N/A   4089977      C   ...nda3-2023.09-0/envs/c310/bin/python    20572MiB |
|    6   N/A  N/A   4089978      C   ...nda3-2023.09-0/envs/c310/bin/python    20572MiB |
|    7   N/A  N/A   4089979      C   ...nda3-2023.09-0/envs/c310/bin/python    20572MiB |
+---------------------------------------------------------------------------------------+

学習が終わるとwandbにログインしていればwandbに結果が表示されます。


{'loss': 3.3256, 'grad_norm': 1.9639703035354614, 'learning_rate': 2.8687127024722932e-05, '
epoch': 1.28}                                                                               
{'loss': 3.0874, 'grad_norm': 1.8486438989639282, 'learning_rate': 7.374254049445865e-06, 'e
poch': 2.56}                                                                                
{'train_runtime': 827.3497, 'train_samples_per_second': 90.651, 'train_steps_per_second': 1.
418, 'train_loss': 3.1781245212928924, 'epoch': 3.0}                                        
100%|████████████████████████████████████████████████████████████████████████████████████| 1
173/1173 [13:41<00:00,  1.43it/s]                                                           
wandb: | 0.070 MB of 0.070 MB uploaded                                                      
wandb: Run history:                                                                         
wandb:         train/epoch ▁▆█                                                              
wandb:   train/global_step ▁▆█                                                              
wandb:     train/grad_norm █▁                                                               
wandb: train/learning_rate █▁                                                               
wandb:          train/loss █▁                                                               
wandb:                                                                                      
wandb: Run summary:                                                                         
wandb:               total_flos 6.523706571279565e+16                                       
wandb:              train/epoch 3.0                                                         
wandb:        train/global_step 1173                                                        
wandb:          train/grad_norm 1.84864                                                     
wandb:      train/learning_rate 1e-05                                                       
wandb:               train/loss 3.0874                                                      
wandb:               train_loss 3.17812                                                     
wandb:            train_runtime 827.3497                                                    
wandb: train_samples_per_second 90.651                                                      
wandb:   train_steps_per_second 1.418                                                       
wandb:                                      

wandbにログインしておくと、学習経過をブラウザからグラフで確認できて便利です。

これは出先のiPhoneやiPadからも見ることができるので、いつも「ちゃんと学習上手くいってるかな」とドキドキしながら確認するのが結構楽しい作業になります。


8つのGPUの消費メモリや温度などもモニタリングできます。




wandbは便利なのですが、デバッグ中は鬱陶しいので、SFTTrainerのオプションに「report_to="none"」を指定すると呼び出されなくなります。


from transformers import TrainingArguments

args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=4,
    max_grad_norm=1.0,
    learning_rate=1e-4, 
    gradient_accumulation_steps=64,
    per_device_train_batch_size=2,
    save_strategy="steps",
    logging_steps=20,
    lr_scheduler_type="constant",
    save_total_limit=1,
    fp16=False,
    bf16=True,
    report_to="none", #これを "wandb"にするとwandbに報告するようになる
)

trainer =  accelerator.prepare(SFTTrainer( #変更
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
))
trainer.train()


また、まだTRLでサポートされない新しいアーキテクチャの動作がよくわからない時は、こういうツールを使わずに自分で分散したい時もあると思います。


そういう時もAccelerateを使うとこんなふうに書くことができます。

from torch.utils.data import DataLoader,Dataset

import copy
class CompletionDataset(Dataset): #コンプリーション用のデータセット
    def init(self, data):
        self.data = data

    def len(self):
        return len(self.data)

    def getitem(self, idx):
        data=copy.deepcopy(self.data[idx])
        label=torch.LongTensor(data["input_ids"][1:]).to(device)
        data["input_ids"]=torch.LongTensor(data["input_ids"][:-1]).to(device)
        data["attention_mask"]=torch.LongTensor(data["attention_mask"][:-1]).to(device)
        return data,label

dataset = CompletionDataset(tokenized_dataset)   

import numpy as np
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99, eps=1e-08)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
train_dataloader = DataLoader(dataset, batch_size=16)
criterion = torch.nn.CrossEntropyLoss()

model, optimizer, train_dataloader  = accelerator.prepare(model, optimizer, train_dataloader)

num_epochs=4
print(model)
for epoch in range(num_epochs):
    model.train()
    for batch,labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(**batch,labels=labels)
        loss = outputs.loss
        print(loss)
        loss.backward()
        optimizer.step()

途中経過を確認しながらどこに問題があるのか切り分けできるので重宝します。

デバッグはあえて分散させずに行い、良きところで分散して学習できるように切り替えやすいのもAccelerateのメリットと言えるでしょう。




閲覧数:429回
bottom of page