ToDoList
参数解释
train.py基本参数
- “–backbone”, type=str, choices=BackboneRegistry.get_all_names(), default=”ncsnpp” ——使用主干网络,默认为ncsnpp,可以选择dvunet、ncsnpp_48k
- “–sde”, type=str, choices=SDERegistry.get_all_names(), default=”ouve” ——使用SDE,默认为ouve,还可以选ouvp
- “–logdir”, type=str, default=”logs” ——日志存储处
- “–nolog”, action=’store_true’, ——不存储日志
- “–wandb_name”, type=str, default=None, ——在wandb生成的run的名字,留空则为随机名
- “–ckpt”, type=str, default=None, ——从已存的ckpt继续训练
- “–devices”, default=”auto” ——使用gpu
- “–accumulate_grad_batches”, type=int, default=1, ——梯度累计量
- “–max_epochs”, type=int, default=300
- “–wandb_project”, type=str, default=”sgmse”,
数据集参数
- “–base_dir”, type=str, required=True, ——数据集路径
- “–format”, type=str, choices=(“default”, “reverb”), default=”default”, ——数据集存放格式,详情看data_module.Specs
- “–batch_size”, type=int, default=8, ——批量大小
- “–num_workers”, type=int, default=4, ——DataLoader_worker数量
- “–dummy”, action=”store_true”, ——将数据集规模调整为1/50,用于快速测试和调试
- “–n_fft”, type=int, default=510, ——fft点数
- –window”, type=str, choices=(“sqrthann”, “hann”), default=”hann”, ——STFT所使用窗函数
- “–hop_length”, type=int, default=128, ——窗函数长度
- “–num_frames”, type=int, default=256, ——截取帧数
- “–spec_factor”, type=float, default=0.15, ——STFT系数缩放比例,Factor to multiply complex STFT coefficients by. 0.15 by default.
- “–spec_abs_exponent”, type=float, default=0.5, ——转换公式 abs(z)**e * exp(1j*angle(z)) 的底数e
- “–transform_type”, type=str, choices=(“exponent”, “log”, “none”), default=”exponent”, ——频谱图转换类型
Exponent 转换:abs(z)**e * exp(1j*angle(z)) ——对频谱图的幅度部分应用一个指数变换。指数 e 通常介于 0 和 1 之间,用于压缩频谱图的动态范围。当 e < 1 时,可以减少高幅度值的影响,使得低幅度值的特征更加明显。
Log转换:log(1 + abs(z)) ——对频谱图的幅度部分应用对数变换。压缩动态范围,使得幅度的变化更平滑。可以减少极大值对模型训练的影响,强调相对较小的变化。
None:不进行任何转换,直接使用原始的复数频谱图 - “–normalize”, type=str, choices=(“clean”, “noisy”, “not”), default=”noisy”, ——选择对输入波形进行归一化的方式,可以选择根据干净信号、噪音信号或不进行归一化。
noisy:根据噪声信号y
的最大绝对值对干净信号x
和噪声信号y
进行归一化。适用于噪声幅度变化较大的情况。
clean:根据干净信号x
的最大绝对值对干净信号x
和噪声信号y
进行归一化。适用于干净信号幅度变化较大的情况。
not:不进行归一化,保持原始信号幅度。适用于信号幅度一致或对幅度不敏感的情况。
模型参数
- “–lr”, type=float, default=1e-4, ——学习率,默认0.0001
- “–ema_decay”, type=float, default=0.999, ——EMA指数平滑参数
- “–t_eps”, type=float, default=0.03, ——最小运行步数(避免时间步为零)
- “–num_eval_files”, type=int, default=20, ——在训练任务中使用验证集的文件数量
- “–loss_type”, type=str, default=”mse”, choices=(“mse”, “mae”), ——损失函数种类
- “–seed”, type=int, default=None, ——设置固定种子,来还原实验结果
SDE参数
- “–sde-n”, type=int, default=1000, ——SDE推理步数,似乎是无效参数,实际步数SGMSE由model.enhance默认参数(N=30)控制,或者是由util.inference.py中的setting控制;diff-sep由config.model.default控制,参数在train[model=model_obj(config)]->DiffSepModel.init中初始化sde时传入
- “–theta, –sigma-min, –sigma-max”, type=float, default=1.5, 0.05, 0.5, ——VE-SDE参数,具体看论文
- “–beta-min, –beta-max, –stiffness”, type=float, required=True, ——VP-SDE参数,具体看论文
enhancement.py参数
- “–test_dir”, type=str, required=True, ——待增强数据路径
- “–enhanced_dir”, type=str, required=True, ——输出路径
- “–ckpt”, type=str, ——使用模型参数
- “–corrector”, type=str, choices=(“ald”, “langevin”, “none”), default=”ald”, ——使用矫正器
- “–corrector_steps”, type=int, default=1, ——使用矫正器步数
- “–snr”, type=float, default=0.5, ——朗之万退火采样器的SNR
- “–N”, type=int, default=30, ——反向过程步数
- “–device”, type=str, default=”cuda”,
calc_metrics.py参数
- “–clean_dir”, type=str, required=True, help=’Directory containing the clean data’
- “–noisy_dir”, type=str, required=True, help=’Directory containing the noisy data’
- “–enhanced_dir”, type=str, required=True, help=’Directory containing the enhanced data’
推荐启动参数
python train.py --base_dir="./data" --num_workers=8 --batch_size=8 --backbone dcunet --n_fft 512 --wandb_name SGMSE --seed 74
python train.py --base_dir="./data" --num_workers=16 --batch_size=8 --wandb_project test --dummy --seed 74
python train.py --base_dir="./data" --num_workers=16 --batch_size=4 --wandb_project test --dummy --seed 2304815 --sde vpsde --backbone SIT --N 15 --wandb_name sit7_SIT-B/2_N15
VPIDM对SGMSE+的改动
- train.py中,增加参数T限制最大采样步数;nolog参数改为了no_wandb参数,若使用则不启用wandb而启用tensorboard
- 在model.py中,t_eps参数默认值从0.03改到了3e-2
- 添加了time_emb和p_name参数,time_emb为bool变量,time_emb为true则使用时间步为条件,否则用sigma
- backbone不再由model类成员变量存储,而是直接加载到self.dnn中
- 优化器只更新self.dnn的参数,而不是model类的所有参数
- 在evaluate_model中,如果sde的名字为VPSDE,则不使用corrector(校正器步数为0)
- 给sampling.get_pc_sampler添加了输入参数time_emb
- sde.py中改写了discretize抽象方法,不知道有何用途
- 给ouve(SGMSE+的sde)添加了–eta_mode和–eta参数,将初始N限定到了正常的30步
- 改写了alpha函数,这个也没看懂
- 新添了vpsde和bbed作为sde备选项
代码分析
SDE.py
使用注册表Registry管理类,每个包含复用类的子文件先声明SDERegistry = Registry(“SDE”),写复用类时先注册@SDERegistry.register(“ouvp”)在写类的内容,这样使用choices=SDERegistry.get_all_names()就可以得到所有复用类名字,sde_class = SDERegistry.get_by_name(temp_args.sde)就可以直接得到注册类
- sde.sde:核心方法,计算并返回正向扩散的漂移项drift和扩散项diffusion
- sde.marginal_prob:在时间 t 下,根据给定的初始状态 x0 和目标状态 y,计算系统状态 x(t) 的边际概率分布,返回的是该分布的均值和标准差
- sde.prior_sampling:用于在进行推理(反向扩散)时,通过y生成最初扩散样本x_T
- sde.discretize:通用抽象方法,用于生成离散化的正向过程的漂移项和扩散项(其实就是乘上dt)
- sde.rverse:通用抽象方法,生成逆sde类,即rsde
- rsde.rsde_parts:生成一个字典,包括正向sde的漂移项和扩散项、分数模型等,最终计算逆sde的总漂移项和总扩散项
- rsde.sde:调用rsde_parts生成的字典,返回总漂移项和总扩散项
- rsde.discretize:用于生成离散化的反向过程的漂移项和扩散项,其是通过正向sde.discretize推导得到的,没用到以上两个方法
model.py
- Initialization and Setup:
__init__
: 初始化模型,包括定义网络结构、损失函数、优化器等。setup
: 在训练、验证或测试的某个阶段前调用,用于设置数据或其他环境准备工作。通常在第一次调用fit
、validate
、test
、或predict
时触发。
- Data Preparation:
train_dataloader
: 返回训练数据的DataLoader
。Trainer 使用这个数据加载器来获取训练数据。val_dataloader
: 返回验证数据的DataLoader
,用于在每个 epoch 结束时进行模型的验证。test_dataloader
: 如果执行trainer.test()
,将使用这个数据加载器获取测试数据。
- Optimizer and Scheduler Setup:
configure_optimizers
: 配置优化器和学习率调度器。Trainer 使用这些优化器来更新模型参数。
- Training Loop:
training_step
: 每个训练批次调用一次。定义前向传播、损失计算、反向传播等。通常返回损失。_step
: training_step中显示调用的自定义函数,这个函数体现了SDE的核心计算逻辑,要改SDE主要从这里改forward()
:在_step
中被显式调用。负责定义分数匹配模型的前向传播逻辑,在这里调用了self.dnn进行模型推理,模型接收输入张量,输出预测score值。_loss
: 用于计算损失的辅助方法,在_step
中被调用。使用去噪分数匹配
- Validation Loop:
validation_step
: 每个验证批次调用一次。用于在验证数据上评估模型的性能,其部分内容和training_step一致,但由于需要额外计算pesq等指标,需要调用一次inference.evaluate_model
来进行一次反向过程(training_step只进行正向扩散过程),最终返回损失。inference.evaluate_model
: 进行反向过程推理,此处包括了使用采样器(调用get_pc_sampler
)进行采样以及矫正器,最终返回sisdr、pesq等指标结果validation_epoch_end
: (可选) 在一个验证 epoch 结束时调用,用于聚合或处理验证结果。
- Checkpointing:
on_save_checkpoint
: 在保存检查点时调用,可以存储额外的状态信息。on_load_checkpoint
: 在加载检查点时调用,用于恢复模型状态。
- Inference and Other Operations:
enhance
: 专用于语音增强的功能方法,在推理过程(调用enhancement.py时)调用。其内容基本和inference.evaluate_model
一致,但返回值不是指标结果,而是降噪后的语音
- get_pc_sampler
- 一个嵌入版本的get_pc_sampler,本质上和sampling.get_pc_sampler没有太大区别
可选方法:
- Test Loop (如果执行
trainer.test()
):test_step
: 每个测试批次调用一次。与validation_step
类似,用于在测试数据上评估模型。test_epoch_end
: (可选) 在一个测试 epoch 结束时调用,用于处理测试结果。
- 验证阶段:
- 调用
on_validation_epoch_start()
和on_validation_epoch_end()
。
- 调用
- 优化器步进:
- 计算完梯度后,
Trainer
会自动调用optimizer.step()
和optimizer.zero_grad()
来更新模型权重。
- 计算完梯度后,
- Epoch 开始:
- 进入训练循环,每个 epoch 开始时,框架会调用
on_epoch_start()
和on_train_epoch_start()
(如果这些方法被覆盖)。
- 进入训练循环,每个 epoch 开始时,框架会调用
- 结束处理:
- 完成所有 epoch 后,调用
on_fit_end()
方法。你可以在这里进行一些清理工作或保存最终的模型状态。
- 完成所有 epoch 后,调用
sgmse.sampling
- get_pc_sampler
- 根据预测器、纠正器名字(已经在correctors.py和predictors.py中注册过,如果无法引用则在init中import一下),返回一个函数pc_sampler和校正器步数。一般来说,预测器处理采样的确定性部分;校正器处理采样的随机性部分
- pc_sampler:先使用sde.prior_sampling从sde的先验分布中采样起始点xt,然后创建一个时间序列步长,从sde.T一直到ade.N。对于每一个时间步,先计算档期时间步的stepsize,然后应用校正器corrector.update_fn细化当前版本xt,再应用预测器predictor.update_fn更新xt。最终输出去噪后结果xt_mean
- 每一次验证论,都会调用inference.evaluate_model,进而调用model.get_pc_sampler获取(pc_)sampler,然后在调用sampler得到预测结果sample
- predictor
- 一般在get_pc_sampler的predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)进行初始化,获取成员变量sde、rsde=sde.reverse(score_fn)、score_fn等
- 具体predictor一般使用ReverseDiffusionPredictor(即p_name=’reverse_diffusion’)
- predictor.update_fn:先调用rsde.discretize方法得到反向sde首项f和尾项g,然后由x – f得到x均值,再加上g * z得到预测结果x