top of page
スクリーンショット 2024-01-25 15.42.00.png
  • 執筆者の写真Ryo Shimizu

Liger-KernelによるGPUメモリ削減を試す/A100 80x7で10000コンテキスト長を学習可能に

2024年8月23日に公開されたLiger-KernelはLLM専用に設計されたTritonカーネル集で、これを適用することによってマルチGPUトレーニングのスループットを20%効率化し、VRAMの使用量を60%も削減できるとのことです。


つまり、弊社社長の継之助にはA100 80GBx8しか搭載されていないため、本来ならば80億(8B)パラメータモデルをファインチューニングすると、最大コンテキスト長が2048までしか対応できないのですが、Liger-Kernelを使用すると最大16384まで拡張できるとのこと。


継之助の可能性を大いに広げる技術のため、実際に試してみました。





ただし、継之助ではまだ一つ実験が継続中なので現在使えるGPUは7つだけです。7つのGPUでどこまでコンテキスト長を伸ばせるか試してみます。


まずは試運転。

Liger-Kernelのインストールは簡単です。


$  pip install liger-kernel 

ソースコードからLiger-Kernelを使うこと自体もかなり簡単になっているのですが、ここではそこは触らず、まずはサンプルを試してみます。


$ git clone https://github.com/linkedin/Liger-Kernel.git
$ cd Liger-Kernel
$ cd examples/huggingface/
$ pip install -r requirements.txt
$ CUDA_VISIBLE_DEVICES=1,2,3,4 sh run.sh

まずは4つのGPUを割り当ててサンプルを起動するとしばらくして学習に成功したことがわかった。学習中のGPU使用量は以下の通りでした。


+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off |   00000000:2A:00.0 Off |                    0 |
| N/A   66C    P0            270W /  300W |   63485MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          Off |   00000000:3D:00.0 Off |                    0 |
| N/A   65C    P0            271W /  300W |   75489MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100 80GB PCIe          Off |   00000000:63:00.0 Off |                    0 |
| N/A   63C    P0            310W /  300W |   72309MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100 80GB PCIe          Off |   00000000:AB:00.0 Off |                    0 |
| N/A   63C    P0            128W /  300W |   76107MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

このrun.shの中でliger_kernelを使うかどうかのフラグを設定しているので、念の為liger_kernelを使わないとどうなるか確認しました。


$ cat run.sh
torchrun --nnodes=1 --nproc-per-node=4 training.py \
    --bf16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --eval_strategy "no" \
    --save_strategy "no" \
    --learning_rate 6e-6 \
    --weight_decay 0.05 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --include_num_input_tokens_seen \
    --report_to none \
    --fsdp "full_shard auto_wrap" \
    --fsdp_config config/fsdp_config.json \
    --seed 42 \
    --output_dir alpaca_finetuning
$ CUDA_VISIBLE_DEVICES=1,2,3,4 sh run.sh
W0826 04:18:10.479000 139757400758080 torch/distributed/run.py:779] 
W0826 04:18:10.479000 139757400758080 torch/distributed/run.py:779] *****************************************
W0826 04:18:10.479000 139757400758080 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0826 04:18:10.479000 139757400758080 torch/distributed/run.py:779] *****************************************
/home/shi3z/.pyenv/versions/anaconda3-2023.09-0/envs/ai_scientist/lib/python3.11/site-packages/transformers/utils/hub.py:127: FutureWarning: Using `PYTORCH_TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
/home/shi3z/.pyenv/versions/anaconda3-2023.09-0/envs/ai_scientist/lib/python3.11/site-packages/transformers/utils/hub.py:127: FutureWarning: Using `PYTORCH_TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
/home/shi3z/.pyenv/versions/anaconda3-2023.09-0/envs/ai_scientist/lib/python3.11/site-packages/transformers/utils/hub.py:127: FutureWarning: Using `PYTORCH_TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
/home/shi3z/.pyenv/versions/anaconda3-2023.09-0/envs/ai_scientist/lib/python3.11/site-packages/transformers/utils/hub.py:127: FutureWarning: Using `PYTORCH_TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  8.40it/s]


(中略)


    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/shi3z/.pyenv/versions/anaconda3-2023.09-0/envs/ai_scientist/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
training.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-08-26_04:20:32
  host      : tsuginosuke
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 2344141)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Liger-kernelを使わないとエラーを吐いて止まりました。

次に、コンテキスト長を伸ばしてGPUを7つ使用した実験をしてみます。 「 --max_seq_length」に8192を追加して学習させてみます。


$ cat run.sh
torchrun --nnodes=1 --nproc-per-node=7 training.py \
    --bf16 \
    --num_train_epochs 1 --max_seq_length 8192 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --eval_strategy "no" \
    --save_strategy "no" \
    --learning_rate 6e-6 \
    --weight_decay 0.05 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --include_num_input_tokens_seen \
    --report_to none --use_liger True\
    --fsdp "full_shard auto_wrap" \
    --fsdp_config config/fsdp_config.json \
    --output_dir alpaca_finetuning

$ CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 sh run.sh
W0826 04:30:39.760000 140275924506432 torch/distributed/run.py:779] 
W0826 04:30:39.760000 140275924506432 torch/distributed/run.py:779] *****************************************
W0826 04:30:39.760000 140275924506432 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0826 04:30:39.760000 140275924506432 torch/distributed/run.py:779] *****************************************

これもきちんと動いてるようです。

学習中のGPUの使用量はこちら


+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off |   00000000:2A:00.0 Off |                    0 |
| N/A   51C    P0            182W /  300W |   42139MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          Off |   00000000:3D:00.0 Off |                    0 |
| N/A   52C    P0            269W /  300W |   44769MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100 80GB PCIe          Off |   00000000:63:00.0 Off |                    0 |
| N/A   50C    P0            329W /  300W |   43027MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100 80GB PCIe          Off |   00000000:AB:00.0 Off |                    0 |
| N/A   58C    P0            283W /  300W |   44111MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A100 80GB PCIe          Off |   00000000:BD:00.0 Off |                    0 |
| N/A   54C    P0            116W /  300W |   47037MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A100 80GB PCIe          Off |   00000000:CF:00.0 Off |                    0 |
| N/A   59C    P0            290W /  300W |   48749MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A100 80GB PCIe          Off |   00000000:E1:00.0 Off |                    0 |
| N/A   58C    P0            366W /  300W |   42127MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

こうするとまだまだGPUのメモリに余裕があるので、二倍の16384は無理だとしても、12000くらいはいけるかもしれません。


$ cat run.sh
torchrun --nnodes=1 --nproc-per-node=7 training.py \
    --bf16 \
    --num_train_epochs 1 --max_seq_length 12000 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --eval_strategy "no" \
    --save_strategy "no" \
    --learning_rate 6e-6 \
    --weight_decay 0.05 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --include_num_input_tokens_seen \
    --report_to none --use_liger True\
    --fsdp "full_shard auto_wrap" \
    --fsdp_config config/fsdp_config.json \
    --output_dir alpaca_finetuning
$ CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 sh run.sh
W0826 04:51:32.098000 140138435372864 torch/distributed/run.py:779] 
W0826 04:51:32.098000 140138435372864 torch/distributed/run.py:779] *****************************************
W0826 04:51:32.098000 140138435372864 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0826 04:51:32.098000 140138435372864 torch/distributed/run.py:779] *****************************************

(中略)

{'train_runtime': 1192.1715, 'train_samples_per_second': 39.257, 'train_steps_per_second': 0.088, 'train_loss': 1.1508818785349528, 'epoch': 1.0, 'num_input_tokens_seen': 15095808, 'step': 105, 'step_time_sec': 12.19, 'avg_step_time_sec': 11.22, 'time_to_completion_sec': 0.0, 'estimated_total_time_sec': 1178.36, 'step_peak_memory_allocated_MB': 25737.13, 'total_peak_memory_allocated_MB': 40353.82, 'step_peak_memory_reserved_MB': 60034.0, 'total_peak_memory_reserved_MB': 60034.0, 'step_tokens_per_second': 11675.1, 'avg_tokens_per_second': 12806.84}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [19:52<00:00, 11.35s/it]


こちらの学習中のGPU使用量は以下の通り



+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off |   00000000:2A:00.0 Off |                    0 |
| N/A   57C    P0            100W /  300W |   55987MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          Off |   00000000:3D:00.0 Off |                    0 |
| N/A   60C    P0            110W /  300W |   51903MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100 80GB PCIe          Off |   00000000:63:00.0 Off |                    0 |
| N/A   58C    P0            109W /  300W |   52607MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100 80GB PCIe          Off |   00000000:AB:00.0 Off |                    0 |
| N/A   59C    P0            111W /  300W |   53203MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A100 80GB PCIe          Off |   00000000:BD:00.0 Off |                    0 |
| N/A   59C    P0            113W /  300W |   67507MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A100 80GB PCIe          Off |   00000000:CF:00.0 Off |                    0 |
| N/A   62C    P0            111W /  300W |   56483MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A100 80GB PCIe          Off |   00000000:E1:00.0 Off |                    0 |
| N/A   69C    P0            266W /  300W |   58927MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
 

というわけで、A100 80GBx8の小規模クラスのAIスーパーコンピュータでもロングコンテキストが学習できるのは本当のようです。


継之助は現在、LLMのファインチューニングよりもMoA(Mixture of Agents)に力を入れていますが、今後ファインチューニングがもっと手軽になっていくと、部分的にLLMのファインチューニングを取り入れたほうがいいという選択肢も浮かんでくるかもしれません。

閲覧数:463回

最新記事

すべて表示

Comments


bottom of page