目录
前言
在网上没找到对stable diffusion代码的解读。这里记录一下自己读代码的过程和进度作为一个备忘。
官方代码
related contents:
stable Diffusion 代码(二)
Stable Diffusion 代码 (三)
Stable Diffusion 代码(四)
模型的初始化和载入
从prompt生成图片时的命令为
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms就从入口txt2img.py开始阅读。跳过传入参数的parser部分
# 设定随机seedseed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
其中 opt.config= configs/stable-diffusion/v1-inference.yaml,指向一个预定义好的配置文件ckpt是预先下载好的模型
然后看load_model_from_config函数,这一函数就定义在同一个文件(txt2img.py文件)中,但是它调用了ldm.util中的两个方法。这里一起写出来
def instantiate_from_config(config):return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
return getattr(importlib.import_module(module, package=None), cls)
def load_model_from_config(config, ckpt):
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
实际上等效于
from ldm.models.diffusion.ddpm import LatentDiffusionmodel = LatentDiffusion(**config.model.get("params", dict()))
model.load_state_dict(torch.load(ckpt, map_location="cpu")["state_dict"], strict=False)
原code使用importlib.import_module,来读取字典中的模块名称进行灵活的import。从方便理解代码运行和算法原理的视角来看,在实际使用LatentDiffusion时,上下两种写法是完全等效的。
这里多说一句,Config字典类似于
Config = { target: path1.path2.module_1_name,params: { para_1 : value_a,
para_2 : value_b,
module_2:{ target: path1.path2.module_2_name,
params: { para_3 : value_c,
module_3:{ target: path1.path2.module_3_name,
params : {para_4: value_d }
}}}}}
get_obj_from_str接收config字典中target对应的值来导入对应的模块,
在 instantiate_from_config 返回对应的类的实例,返回的实例是以params对应的值初始化的params对应的值是同等格式的字典。
也就是说,config中可以像上面的例子一样,设置好嵌套的各个模块,并且在模块实例化时读取传入的config,在模块的__init__中继续调用instantiate_from_config就可以实现各个模块嵌套式的实例化。具体的例子可以看第三篇。
# 初始化模型的全部逻辑: from ldm.models.diffusion.ddpm import LatentDiffusion
import torch
from omegaconf import OmegaConf
# 读取config
config = OmegaConf.load(f"{opt.config}")
# 初始化模型并传入config中的参数
model = LatentDiffusion(**config.model.get("params", dict()))
model.load_state_dict(torch.load(ckpt, map_location="cpu")["state_dict"], strict=False)
device = torch.device("cuda")
model = model.to(device)
图像生成的准备和图像的生成
有了model之后是sampler的初始化 (基于命令行传入的 --plms,执行判断语句的第一条)
sampler = PLMSSampler(model)紧接着,原代码提供了两种输入prompt的方法,分别是命令行输入和从文件读取,不关键。总之最后prompt进入了data这个变量
data = [batch_size * [prompt]]到这里,我们有了
model-[LatentDiffusion]sampler-[PLMSSampler] prompt
这样就可以开始生成图片了。
这里有两个重要的部分,一个是PLMSSampler的定义,一个是LatentDiffusion的定义。我们先将这两个模块视作黑箱,假定它们能完美的完成各自的任务,之后再详细看它们的代码。
这里先简单回忆一下classifier-free guidance的方法: ϵ(x,t)=ϵ(x,t|ϕ)+α⋅(ϵ(x,t|c)−ϵ(x,t|ϕ))\epsilon(x, t)= \epsilon(x,t ~| ~\phi) + \alpha\cdot (\epsilon(x,t~|~ c) -\epsilon(x,t~ |~ \phi))
因此除了prompt,也就是上式中c所对应的条件,还需要unconditional的ϕ\phi 。
c = model.get_learned_conditioning(prompts)uc = model.get_learned_conditioning(batch_size * [""])
这里可以看到model中的一个方法 get_learned_conditioning() : 输入text, 输出text的embedding 。
之后就是图像的生成了。图像的生成调用sampler实例的sample方法。这里为了直观的理解省略了几个参数,完整的参数和具体的各个参数的作用在后面sampler的代码解读部分再说。
samples_ddim, _ = sampler.sample(S=50,conditioning=c,
batch_size=1,
shape=[4,64,64],
unconditional_guidance_scale=7.5,
unconditional_conditioning=uc,
eta=opt.ddim_eta)
x_samples_ddim = model.decode_first_stage(samples_ddim)
到这里为止,diffusion的任务已经结束了,x_samples_ddim 再经过基本的图像处理就是最终的结果。
以上就是txt2img.py文件的全部内容。这一部分绝大多数代码都是数据的读写和准备工作,核心逻辑部分比较少,还是比较好理解的。
接下来进入plms文件去看sampler的代码实现。