top of page
スクリーンショット 2024-01-25 15.42.00.png
執筆者の写真Ryo Shimizu

CogVideoXをシングルGPUでLoRAファインチューニングする

更新日:9月17日

CogVideoXのLoRAファインチューニングを行います。

使用したデータは、前回作成した動画からPixtral12Bでキャプションを得るプログラムで作ったデータセットです。


もとにしたのは、弊社共同創業者の清水亮のYouTube番組「さすらい魂」のフッテージです。以下のようなイメージの動画です。


これをLoRAファインチューニングしたところ、以下のような動画が生成できました。プロンプトは「footage of shi3z man.a shi3z man. A man is walking at street of Europe.」です。


ヨーロッパの街並み、胸につけたピンマイク、顎髭、人種などがプロンプトなしで再現されているのがわかります。静止画のファインチューニングは20枚程度の画像でも可能ですが、動画のファインチューニングの場合、どこまで画像を用意すべきか、また、プロンプトをどのようにすべきかはまだ研究途上と言えます。



まず、データを以下のように配置します。

data/
├── labels
│   ├── 000.txt
│   ├── 001.txt
│   ├── 002.txt
├── videos
│   ├── 000.mp4
│   ├── 001.mp4
│   ├── 002.mp4

便宜上0-2までしかありませんが、本当は1000近いファイルを用意しています。

これを学習させるわけですが、例によってそのままだとなかなかうまくいかなかったので、ファインチューニング時のパッケージ関連の情報をここに書いておきます。


numpy                     2.0.1
torch                     2.4.0
torchaudio                2.4.1
torchmetrics              1.4.2
torchvision               0.19.0

CUDAは12.1、ビデオカードはA100 80GBx1です。

コンフィグをいじります。と言っても、今回はあまりいじるところがありませんでした。


configs/cogvideox_2b_lora.yaml

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
      ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"

configs/sft.yaml

rgs:                                                                                                                                                                                
  checkpoint_activations: True ## using gradient checkpointing                                                                                                                       
  model_parallel_size: 1                                                                                                                                                             
  experiment_name: lora-anime                                                                                                                                                       
  mode: finetune                                                                                                                                                                     
  load: "CogVideoX-2b-sat/transformer"                                                                                                                                               
  no_load_rng: True                                                                                                                                                                  
  train_iters: 1000 # Suggest more than 1000 For Lora and SFT For 500 is enough                                                                                                      
  eval_iters: 1                                                                                                                                                                      
  eval_interval: 100                                                                                                                                                                 
  eval_batch_size: 1                                                                                                                                                                 
  save: ckpts_2b_lora                                                                                                                                                                
  save_interval: 500                                                                                                                                                                 
  log_interval: 20                                                                                                                                                                   
  train_data: [ "data" ] # Train data path                                                                                                                                         
  valid_data: [ "data" ] # Validation data path, can be the same as train_data(not recommended)   

さらにfinetune_single_gpu.shを修正します。


#! /bin/bash

echo "RUN on `hostname`, CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"

run_cmd="$environs python train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"

echo ${run_cmd}
eval ${run_cmd}

これで準備はできました。


$ bash finetune_single_gpu.sh	

で、学習が開始されます。

学習には数時間かかりました。


学習結果はckpts_2b_loraディレクトリに入ります。 ちゃんと学習結果が得られているか確認しましょう。


推論するには、configを修正する必要があります。

inference.yamlをもとにlora_inference.yamlを作ります。


configs/lora_inference.yaml

args:
  latent_channels: 16
  mode: inference
  load: "ckpts_2b_lora/lora-data-09-16-01-26" 

  batch_size: 1
  input_type: txt
  input_file: configs/lora_test.txt
  sampling_num_frames: 13  # Must be 13, 11 or 9
  sampling_fps: 8
  fp16: True # For CogVideoX-2B
#  bf16: True # For CogVideoX-5B
  output_dir: outputs/
  force_inference: True

さらに、inference.shを改造してlora_inference.shを作ります。

#! /bin/bash

echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"

run_cmd="$environs python sample_video.py --base configs/cogvideox_2b_lora.yaml configs/lora_inference.yaml --seed $RANDOM"

echo ${run_cmd}
eval ${run_cmd}

echo "DONE on `hostname`"

これで推論させることができます。


$ bash lora_inference.sh

感想としては、LoRAファインチューニングでも絵の感じは結構簡単に学習させることかできるなあ、というものでした。


一方で、元々のCogVideoXのモデルが苦手としているテーマは破綻しやすい印象でした。もっとデータセットのバリエーションを増やせばもっと多様なテーマの動画を生成できるようになると思います。


CogVideoXの2Bは推論には20GBくらいのVRAM消費しかしないものの、学習にはLoRAファインチューニングといえど80GBフルで使う場面がちょいちょい出てきて、これくらいのVRAMがないと厳しいようです。


今回、継之助の本体(A100 80GBx8)は24時間AIハッカソンに貸し出していたのでA100 80x1の継之助(小)で試しました。継之助の本体を使えば5Bでも学習できそうです。


最終的に出力された動画はこちらになります。






閲覧数:1,850回
bottom of page