保姆级讲解 Stable Diffusion

汪天宇
2023-12-01

保姆级讲解 Stable Diffusion:

https://mp.weixin.qq.com/s?__biz=Mzk0MzIzODM5MA==&mid=2247486486&idx=1&sn=aff9ed60bba2cbf9efd32aa68557c93b&chksm=c337b18ff4403899d24ac32a60dbfd0402aab7309e8442dabdcb14cd61cfb55ad6cc1f977b3b#rd

整体代码

# 1、prompt编码为token。编码器为FrozenCLIPEmbedder(包括1层的 CLIPTextEmbeddings 和12层的自注意力encoder)
c = self.cond_stage_model.encode(c)    # (c为输入的提示语句,重复2次)  输出:(2,77,768)
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
# self.tokenizer来自 transformers包中的 预训练CLIPTokenizer
tokens = batch_encoding["input_ids"].to(self.device)             # (2,77)一句话编码为77维
outputs = self.transformer(input_ids=tokens).last_hidden_state   # 12层self-atten,结果(2,77,768)

# 2、
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                   conditioning=c,
                                   batch_size=opt.n_samples,
                                   shape=shape,
                                   verbose=False,
                                   unconditional_guidance_scale=opt.scale,
                                   unconditional_conditioning=uc,
                                   eta=opt.ddim_eta,
                                   x_T=start_code)
     # 01、
     self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)    # S=50
     # 这一步是ddim中,预先register超参数,如α的连乘(累乘)等
     # Data shape for PLMS sampling is (2, 4, 32, 32) 
     # 02、
     samples, intermediates = self.plms_sampling(conditioning, size,
                                                callback=callback,
                                                img_callback=img_callback,
                                                quantize_denoised=quantize_x0,
                                                mask=mask, x0=x0,
                                                ddim_use_original_steps=False,
                                                noise_dropout=noise_dropout,
                                                temperature=temperature,
                                                score_corrector=score_corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_T=x_T )
          img = torch.randn(shape, device=device)    # (2,4,32,32)
          for i, step in enumerate(iterator):
                index = total_steps - i - 1                                        # index=50-i-1, step=981
                ts = torch.full((b,), step, device=device, dtype=torch.long)       # [981,981]
                outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                          quantize_denoised=quantize_denoised, temperature=temperature,
                                          noise_dropout=noise_dropout, score_corrector=score_corrector,
                                          corrector_kwargs=corrector_kwargs,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning,
                                          old_eps=old_eps, t_next=ts_next)
                    c_in = torch.cat([unconditional_conditioning, c])    # 添加一个空字符,与prompt拼接
                    e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                          t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)    # timesteps:[981,981,981,981] -> (4,320)
                          emb = self.time_embed(t_emb)           # 2*linear:(4,320) -> (4,1280)
                          
                          # unet中带入embed与prompt,具体见源码
                          for module in self.input_blocks:
                              h = module(h, emb, context)        # 输入(4,4,32,32) (4,1280) (4,77,768)
                              hs.append(h)
                          h = self.middle_block(h, emb, context) 
                          for module in self.output_blocks:
                              h = th.cat([h, hs.pop()], dim=1)   # (4,1280,4,4) -> (4,2560,4,4)
                              h = module(h, emb, context)

                          return self.out(h)                     # (4,320,32,32)卷积为(4,4,32,32)

# 3、
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)   # 上步中得到的结果拆开:(2,4,32,32)
   e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)  # 用7.5乘以二者差距,再加回空语句生成的图
   x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)                  # DDIM计算:e_t(2,4,32,32) index:49  -> (2,4,32,32)

# 4、
x_samples_ddim = model.decode_first_stage(samples_ddim)    # (2,4,32,32)
        h = self.conv_in(z)    # 卷积4->512
        x = torch.nn.functional.interpolate(h, scale_factor=2.0, mode="nearest")  #(2,512,64,64)
        h = self.up[i_level].block[i_block](h)    # 经过几次卷积与上采样
        h = self.norm_out(h)   # (2,128,256,256)
        h = nonlinearity(h)    # x*torch.sigmoid(x)
        h = self.conv_out(h)   # conv(128,3) -> (2,3,256,256)

# 5、
后处理
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(x_sample.astype(np.uint8))
img.save(os.path.join(sample_path, f"{base_count:05}.png"))

unet解析

DDIM中的Unet 包含输入模块、中间模块、输出模块三部分:

self.input_blocks

包含12个不同的 TimestepEmbedSequential结构,下面列举三种:

# 1、self.input_blocks
ModuleList(
  (0): TimestepEmbedSequential(
    (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=320, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
      (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=320, out_features=320, bias=False)
            (to_v): Linear(in_features=320, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=320, out_features=2560, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=1280, out_features=320, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=768, out_features=320, bias=False)
            (to_v): Linear(in_features=768, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    )
  )

  (6): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )

前向过程:

为h添加emb以及与prompt的交叉注意力,会执行多次

emb_out = self.emb_layers(emb)      # (4,1280)经SiLU+Linear线性映射为(4,320)
h = h + emb_out                     # (4,320,32,32)+(4,320,1,1)

x = self.attn1(self.norm1(x)) + x                     # 自注意力:x(4,1024,320)映射到qkv,均320维
x = self.attn2(self.norm2(x), context=context) + x    # 交叉注意力:context(4,77,768)映射到kv的320维
x = self.ff(self.norm3(x)) + x

噪音图像h(4,4,32,32)在其中变化为:(4,320,32,32)(4,320,16,16)(4,640,16,16)(4,1280,8,8)(4,1280,4,4)

self.middle_block

TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
    (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=1280, out_features=1280, bias=False)
          (to_k): Linear(in_features=1280, out_features=1280, bias=False)
          (to_v): Linear(in_features=1280, out_features=1280, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=1280, out_features=1280, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=1280, out_features=10240, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=5120, out_features=1280, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=1280, out_features=1280, bias=False)
          (to_k): Linear(in_features=768, out_features=1280, bias=False)
          (to_v): Linear(in_features=768, out_features=1280, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=1280, out_features=1280, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  )
  (2): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )

self.output_blocks

与输入模块相同,包含12个 TimestepEmbedSequential,顺序相反。

 类似资料: