電通総研 テックブログ

電通総研が運営する技術ブログ

Stable Diffusion TPU版の使い方

電通国際情報サービス、オープンイノベーションラボの比嘉康雄です。

Stable Diffusion(というよりdiffusers)でTPU(JAX / Flax)を使った並列実行バージョンがリリースされたので、早速試してみました。

オリジナルのNotebookはこちら。

僕が作ったNotebookはこちら。

今回は、TPUを使うので、Google Colabに特化しています。自分で1から試す方は、メニューのEdit -> Notebook settingsでTPUを使うように設定してください。

Stable Diffusionのおすすめコンテンツはこちら。

必要なモジュールのインストール

次のようにして必要なモジュールをインストールします。

!pip install --upgrade jax jaxlib
!pip install flax transformers ftfy
!pip install diffusers==0.5.1

必要なモジュールのインポート

次のようにして必要なモジュールをインポートします。

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')
import jax

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from IPython.display import display
import random

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

huggingfaceにログイン

huggingfaceにログインします。 まだ、huggingfaceのトークンを取得していない場合は、huggingfaceでユーザー登録を行い、トークンを取得してください。

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

pipelineとparamsの作成

次のようにして、pipelineとparamsを作成します。

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)

show_images()の定義

呪文から、画像を生成するための関数を作成します。現状のColabでは、TPUは、8個まで並列実行できます。

def show_images(prompt):
  seed = random.randrange(1000000)
  rng = jax.random.PRNGKey(seed)
  rng = jax.random.split(rng, jax.device_count())

  p_params = replicate(params)

  prompt = [prompt] * jax.device_count()
  prompt_ids = pipeline.prepare_inputs(prompt)
  prompt_ids = shard(prompt_ids)

  images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
  images = images.reshape((jax.device_count(), ) + images.shape[-3:])
  images = pipeline.numpy_to_pil(images)

  for image in images:
    display(image)

オリジナルのNotebookでは、8つの画像がGridになっていて、個別にダウンロードできないので、個別にダウンロードできるようにしてみました。

show_images()の実行

次の呪文で、show_images()を実行してみました。

prompt = "illustration of beaultiful girl detailed beautiful face detailed perfect pupil of eyes detailed mouth detailed shoulders detailed chest highly detailed artstation deviantart concept art award winning fantasy scene fantasy composition cinematic lighting ray tracing 8k"

show_images(prompt)

今回の呪文(横長、コピー&ペースト用)

illustration of beaultiful girl detailed beautiful face detailed perfect pupil of eyes detailed mouth detailed shoulders detailed chest highly detailed artstation deviantart concept art award winning fantasy scene fantasy composition cinematic lighting ray tracing 8k

閲覧用呪文(改行版)

illustration of
beaultiful girl
detailed beautiful face
detailed perfect pupil of eyes
detailed mouth
detailed shoulders
detailed chest
highly detailed
artstation
deviantart
concept art
award winning
fantasy scene
fantasy composition
cinematic lighting
ray tracing 8k

画像出力結果

全く選別していない、出力されたそのままの画像です。

まとめ

Stable Diffusionが8つの画像を同時生成できるようになったのは、画期的ではないでしょうか。しかも、TPUを使うので、GPUを使い切っていても使うことができるのです。実際僕は、Google Colab ProのGPUを現在使い切ってしまっていて、GPUを使うことはできないのですが、TPUは使えました。

追記: その後、あっという間に、TPUの無料枠分を使い切ってしまいました。

仲間募集

私たちは同じグループで共に働いていただける仲間を募集しています。
現在、以下のような職種を募集しています。

Stable Diffusionの全コンテンツ

執筆:@higaShodoで執筆されました