
JoJoGAN: One-Shot Face Stylization. Given a single face image, it learns that image's style and can then transfer the style to other faces. Training takes only 1-2 minutes.

Results:

Main process:

This article walks through the whole process of running JoJoGAN in a local environment (not Colab). You can follow along and train your own favorite style.

Prepare the environment

Install:

conda create -n torch python=3.9 -y
conda activate torch

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -y

Check:

$ python - <<EOF
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())
EOF
1.10.1 True

Prepare the code

git clone https://github.com/mchong6/JoJoGAN.git
cd JoJoGAN

pip install tqdm gdown matplotlib scipy opencv-python dlib lpips wandb

# Ninja is required to load C++ extensions
wget https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip
sudo unzip ninja-linux.zip -d /usr/local/bin/
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
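
To confirm that PyTorch can actually find Ninja for building its C++ extensions, here is a quick check using torch.utils.cpp_extension, which the extension loader relies on:

import torch.utils.cpp_extension as cpp_ext

# True if PyTorch's JIT extension loader can find ninja
print(cpp_ext.is_ninja_available())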

Then, put the several *.py scripts used in this article into the JoJoGAN directory. You can get them here: https://github.com/ikuokuo/start-deep-learning/tree/master/practice/JoJoGAN .

  • download_models.py : download the pretrained models
  • generate_faces.py : generate test faces
  • stylize.py : stylize images
  • train.py : train your own style

Later, the training workflow section describes how JoJoGAN works, together with the code. For the other *.py scripts, only the usage is shown here, not the implementation.

Get the models

Run python download_models.py to download the models, which are laid out as follows:

models/
├── arcane_caitlyn_preserve_color.pt
├── arcane_caitlyn.pt
├── arcane_jinx_preserve_color.pt
├── arcane_jinx.pt
├── arcane_multi_preserve_color.pt
├── arcane_multi.pt
├── art.pt
├── disney_preserve_color.pt
├── disney.pt
├── dlibshape_predictor_68_face_landmarks.dat
├── e4e_ffhq_encode.pt
├── jojo_preserve_color.pt
├── jojo.pt
├── jojo_yasuho_preserve_color.pt
├── jojo_yasuho.pt
├── restyle_psp_ffhq_encode.pt
├── stylegan2-ffhq-config-f.pt
├── supergirl_preserve_color.pt
└── supergirl.pt

Generate faces

Randomly generate faces with the StyleGAN2 pretrained model for testing:

python generate_faces.py -n 5 -s 2000 -o input
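
Under the hood, generate_faces.py roughly does the following. This is a minimal sketch, assuming the repo's Generator in model.py and mapping -n/-s/-o to the sample count, seed, and output path; the real script may differ in details:

import torch
from torchvision.utils import save_image
from model import Generator  # StyleGAN2 generator from the repo

device = 'cuda'
latent_dim = 512

generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location='cpu')
generator.load_state_dict(ckpt['g_ema'], strict=False)

torch.manual_seed(2000)                            # -s 2000
with torch.no_grad():
    z = torch.randn(5, latent_dim, device=device)  # -n 5
    w = generator.get_latent(z).unsqueeze(1).repeat(1, generator.n_latent, 1)
    faces = generator(w, input_is_latent=True)     # values in [-1, 1]
save_image(faces, 'input/faces.png', normalize=True, value_range=(-1, 1))  # -o input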

Use a pretrained style

JoJoGAN provides 8 pretrained style models, which you can try all together to reproduce the results shown at the beginning of the article:

# Preview how all of JoJoGAN's pretrained styles render one image (test_input/iu.jpeg)
python stylize.py -i test_input/iu.jpeg -s all --save-all --show-all

# Stylize all generated test faces (input/*) with all of JoJoGAN's pretrained styles
find ./input -type f -print0 | xargs -0 -I {} python stylize.py -i {} -s all --save-all

Train your own style

First, prepare a style image:

After that, start training:

python train.py -n yinshi -i style_images/yinshi.jpeg --alpha 1.0 --num_iter 500 --latent_dim 512 --use_wandb --log_interval 50

With --use_wandb, you can view the training logs in Weights & Biases:
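
Inside train.py, the logging can be as simple as the following sketch (the wandb project name here is illustrative):

import wandb

wandb.init(project='JoJoGAN')
# ... inside the training loop:
wandb.log({'loss': loss.item()}, step=idx)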

Finally, test the effect:

python stylize.py -i input/girl.jpeg --save-all --show-all --test_style yinshi --test_ckpt output/yinshi.pt --test_ref output/yinshi/style_images_aligned/yinshi.png

Training workflow

Prepare style images and turn them into training data

Crop and align the face in the style image:

# dlib detects facial landmarks, then the face is cropped and aligned
# (uses models/dlibshape_predictor_68_face_landmarks.dat)
from util import align_face
style_aligned = align_face(img_path)

Use GAN inversion to map the style image back into the latent space of the pretrained model:

import os

name, _ = os.path.splitext(os.path.basename(img_path))
style_code_path = os.path.join(latent_dir, f'{name}.pt')

# GAN inversion with the e4e FFHQ encoder (pSp) to obtain the latent code
from e4e_projection import projection
latent = projection(style_aligned, style_code_path, device)

Load the StyleGAN2 model and fine-tune it

Load the pretrained model:

from copy import deepcopy

import torch
from model import Generator  # StyleGAN2 generator from the repo's model.py

latent_dim = 512

# load the pretrained model
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load("models/stylegan2-ffhq-config-f.pt", map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)

# the copy that will be fine-tuned
generator = deepcopy(original_generator)
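
As a quick sanity check (not part of the original scripts), you can decode the inverted latent from the previous step with the untouched generator; a good inversion should closely reconstruct the aligned style face. This assumes projection returns a [n_latent, 512] code:

with torch.no_grad():
    # [n_latent, 512] -> [1, n_latent, 512]
    recon = original_generator(latent.unsqueeze(0), input_is_latent=True)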

The tunable training parameters:

# style strength in [0, 1]
alpha = 1.0
# inverted: a higher style strength mixes more of the random mean_w
# into the swapped layers below
alpha = 1-alpha

# whether to preserve the original image's colors
preserve_color = True

# number of training iterations (500 works best; the Adam learning rate was tuned for 500 iterations)
num_iter = 500

# style image targets and their latent codes (assembled as sketched below)
targets = ...
latents = ...
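
For reference, targets and latents can be assembled from the aligned style images and their inverted codes obtained above, roughly like this (a sketch; the transform resizes to the generator's 1024x1024 resolution and normalizes to [-1, 1]):

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# one entry per style image; style_aligned and latent come from the steps above
targets = torch.stack([transform(style_aligned)]).to(device)  # [N, 3, 1024, 1024]
latents = torch.stack([latent]).to(device)                    # [N, n_latent, 512]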

Run the training to fit the style in latent space, and finally save the weights:

import lpips
import torch.nn.functional as F
from tqdm import tqdm

# LPIPS (VGG) for the perceptual loss
lpips_fn = lpips.LPIPS(net='vgg').to(device)

# optimizer
g_optim = torch.optim.Adam(generator.parameters(), lr=2e-3, betas=(0, 0.99))

# which layers are swapped to generate the stylized images
if preserve_color:
    id_swap = [7,9,11,15,16,17]
else:
    id_swap = list(range(7, generator.n_latent))

# training loop
for idx in tqdm(range(num_iter)):
    # blend random latents into the swapped layers to mix styles
    mean_w = generator.get_latent(torch.randn([latents.size(0), latent_dim])
        .to(device)).unsqueeze(1).repeat(1, generator.n_latent, 1)
    in_latent = latents.clone()
    in_latent[:, id_swap] = alpha*latents[:, id_swap] + (1-alpha)*mean_w[:, id_swap]

    # generate a stylized image from the latent and compare it with the style target
    img = generator(in_latent, input_is_latent=True)
    loss = lpips_fn(F.interpolate(img, size=(256,256), mode='area'),
        F.interpolate(targets, size=(256,256), mode='area')).mean()

    # optimize
    g_optim.zero_grad()
    loss.backward()
    g_optim.step()

# save the weights; done
torch.save({"g": generator.state_dict()}, save_path)
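
To use the fine-tuned weights directly, here is a sketch of what stylize.py does with --test_ckpt: invert a test face with the same e4e projection, then decode it with the fine-tuned generator (paths here are illustrative):

from util import align_face
from e4e_projection import projection

ckpt = torch.load(save_path, map_location='cpu')
generator.load_state_dict(ckpt['g'], strict=False)

aligned = align_face('input/girl.jpeg')
my_w = projection(aligned, 'output/girl.pt', device)
with torch.no_grad():
    stylized = generator(my_w.unsqueeze(0), input_is_latent=True)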

Epilogue

JoJoGAN works well in practice. With the code from this article, it is easy to start training your own favorite style. It is worth a try.


GoCoding shares hands-on practice notes; follow the official account if you are interested!
