SGMSE代码开发

ToDoList

参数解释

train.py基本参数

  1. “–backbone”, type=str, choices=BackboneRegistry.get_all_names(), default=”ncsnpp” ——使用主干网络,默认为ncsnpp,可以选择dvunet、ncsnpp_48k
  2. “–sde”, type=str, choices=SDERegistry.get_all_names(), default=”ouve” ——使用SDE,默认为ouve,还可以选ouvp
  3. “–logdir”, type=str, default=”logs” ——日志存储处
  4. “–nolog”, action=’store_true’, ——不存储日志
  5. “–wandb_name”, type=str, default=None, ——在wandb生成的run的名字,留空则为随机名
  6. “–ckpt”, type=str, default=None, ——从已存的ckpt继续训练
  7. “–devices”, default=”auto” ——使用gpu
  8. “–accumulate_grad_batches”, type=int, default=1, ——梯度累计量
  9. “–max_epochs”, type=int, default=300
  10. “–wandb_project”, type=str, default=”sgmse”,

数据集参数

  1. “–base_dir”, type=str, required=True, ——数据集路径
  2. “–format”, type=str, choices=(“default”, “reverb”), default=”default”, ——数据集存放格式,详情看data_module.Specs
  3. “–batch_size”, type=int, default=8, ——批量大小
  4. “–num_workers”, type=int, default=4, ——DataLoader_worker数量
  5. “–dummy”, action=”store_true”, ——将数据集规模调整为1/50,用于快速测试和调试
  6. “–n_fft”, type=int, default=510, ——fft点数
  7. –window”, type=str, choices=(“sqrthann”, “hann”), default=”hann”, ——STFT所使用窗函数
  8. “–hop_length”, type=int, default=128, ——窗函数长度
  9. “–num_frames”, type=int, default=256, ——截取帧数
  10. “–spec_factor”, type=float, default=0.15, ——STFT系数缩放比例,Factor to multiply complex STFT coefficients by. 0.15 by default.
  11. “–spec_abs_exponent”, type=float, default=0.5, ——转换公式 abs(z)**e * exp(1j*angle(z)) 的底数e
  12. “–transform_type”, type=str, choices=(“exponent”, “log”, “none”), default=”exponent”, ——频谱图转换类型
    Exponent 转换:abs(z)**e * exp(1j*angle(z)) ——对频谱图的幅度部分应用一个指数变换。指数 e 通常介于 0 和 1 之间,用于压缩频谱图的动态范围。当 e < 1 时,可以减少高幅度值的影响,使得低幅度值的特征更加明显。
    Log转换:log(1 + abs(z)) ——对频谱图的幅度部分应用对数变换。压缩动态范围,使得幅度的变化更平滑。可以减少极大值对模型训练的影响,强调相对较小的变化。
    None:不进行任何转换,直接使用原始的复数频谱图
  13. “–normalize”, type=str, choices=(“clean”, “noisy”, “not”), default=”noisy”, ——选择对输入波形进行归一化的方式,可以选择根据干净信号、噪音信号或不进行归一化。
    noisy:根据噪声信号 y 的最大绝对值对干净信号 x 和噪声信号 y 进行归一化。适用于噪声幅度变化较大的情况。
    clean:根据干净信号 x 的最大绝对值对干净信号 x 和噪声信号 y 进行归一化。适用于干净信号幅度变化较大的情况。
    not:不进行归一化,保持原始信号幅度。适用于信号幅度一致或对幅度不敏感的情况。

模型参数

  1. “–lr”, type=float, default=1e-4, ——学习率,默认0.0001
  2. “–ema_decay”, type=float, default=0.999, ——EMA指数平滑参数
  3. “–t_eps”, type=float, default=0.03, ——最小运行步数(避免时间步为零)
  4. “–num_eval_files”, type=int, default=20, ——在训练任务中使用验证集的文件数量
  5. “–loss_type”, type=str, default=”mse”, choices=(“mse”, “mae”), ——损失函数种类
  6. “–seed”, type=int, default=None, ——设置固定种子,来还原实验结果

SDE参数

  1. “–sde-n”, type=int, default=1000, ——SDE推理步数,似乎是无效参数,实际步数SGMSE由model.enhance默认参数(N=30)控制,或者是由util.inference.py中的setting控制;diff-sep由config.model.default控制,参数在train[model=model_obj(config)]->DiffSepModel.init中初始化sde时传入
  2. “–theta, –sigma-min, –sigma-max”, type=float, default=1.5, 0.05, 0.5, ——VE-SDE参数,具体看论文
  3. –beta-min, –beta-max, –stiffness”, type=float, required=True, ——VP-SDE参数,具体看论文

enhancement.py参数

  1. “–test_dir”, type=str, required=True, ——待增强数据路径
  2. “–enhanced_dir”, type=str, required=True, ——输出路径
  3. “–ckpt”, type=str, ——使用模型参数
  4. “–corrector”, type=str, choices=(“ald”, “langevin”, “none”), default=”ald”, ——使用矫正器
  5. “–corrector_steps”, type=int, default=1, ——使用矫正器步数
  6. “–snr”, type=float, default=0.5, ——朗之万退火采样器的SNR
  7. “–N”, type=int, default=30, ——反向过程步数
  8. “–device”, type=str, default=”cuda”,

calc_metrics.py参数

  1. “–clean_dir”, type=str, required=True, help=’Directory containing the clean data’
  2. “–noisy_dir”, type=str, required=True, help=’Directory containing the noisy data’
  3. “–enhanced_dir”, type=str, required=True, help=’Directory containing the enhanced data’

推荐启动参数

python train.py --base_dir="./data" --num_workers=8 --batch_size=8 --backbone dcunet --n_fft 512 --wandb_name SGMSE --seed 74
python train.py --base_dir="./data" --num_workers=16 --batch_size=8 --wandb_project test --dummy --seed 74
python train.py --base_dir="./data" --num_workers=16 --batch_size=4 --wandb_project test --dummy --seed 2304815 --sde vpsde --backbone SIT --N 15 --wandb_name sit7_SIT-B/2_N15

VPIDM对SGMSE+的改动

  1. train.py中,增加参数T限制最大采样步数;nolog参数改为了no_wandb参数,若使用则不启用wandb而启用tensorboard
  2. 在model.py中,t_eps参数默认值从0.03改到了3e-2
  3. 添加了time_emb和p_name参数,time_emb为bool变量,time_emb为true则使用时间步为条件,否则用sigma
  4. backbone不再由model类成员变量存储,而是直接加载到self.dnn中
  5. 优化器只更新self.dnn的参数,而不是model类的所有参数
  6. 在evaluate_model中,如果sde的名字为VPSDE,则不使用corrector(校正器步数为0)
  7. 给sampling.get_pc_sampler添加了输入参数time_emb
  8. sde.py中改写了discretize抽象方法,不知道有何用途
  9. 给ouve(SGMSE+的sde)添加了–eta_mode和–eta参数,将初始N限定到了正常的30步
  10. 改写了alpha函数,这个也没看懂
  11. 新添了vpsde和bbed作为sde备选项

代码分析

SDE.py

使用注册表Registry管理类,每个包含复用类的子文件先声明SDERegistry = Registry(“SDE”),写复用类时先注册@SDERegistry.register(“ouvp”)在写类的内容,这样使用choices=SDERegistry.get_all_names()就可以得到所有复用类名字,sde_class = SDERegistry.get_by_name(temp_args.sde)就可以直接得到注册类

  1. sde.sde:核心方法,计算并返回正向扩散的漂移项drift和扩散项diffusion
  2. sde.marginal_prob:在时间 t 下,根据给定的初始状态 x0 和目标状态 y,计算系统状态 x(t) 的边际概率分布,返回的是该分布的均值和标准差
  3. sde.prior_sampling:用于在进行推理(反向扩散)时,通过y生成最初扩散样本x_T
  4. sde.discretize:通用抽象方法,用于生成离散化的正向过程的漂移项和扩散项(其实就是乘上dt)
  5. sde.rverse:通用抽象方法,生成逆sde类,即rsde
    • rsde.rsde_parts:生成一个字典,包括正向sde的漂移项和扩散项、分数模型等,最终计算逆sde的总漂移项和总扩散项
    • rsde.sde:调用rsde_parts生成的字典,返回总漂移项和总扩散项
    • rsde.discretize:用于生成离散化的反向过程的漂移项和扩散项,其是通过正向sde.discretize推导得到的,没用到以上两个方法

model.py

  1. Initialization and Setup:
    • __init__: 初始化模型,包括定义网络结构、损失函数、优化器等。
    • setup: 在训练、验证或测试的某个阶段前调用,用于设置数据或其他环境准备工作。通常在第一次调用 fitvalidatetest、或 predict 时触发。
  2. Data Preparation:
    • train_dataloader: 返回训练数据的 DataLoader。Trainer 使用这个数据加载器来获取训练数据。
    • val_dataloader: 返回验证数据的 DataLoader,用于在每个 epoch 结束时进行模型的验证。
    • test_dataloader: 如果执行 trainer.test(),将使用这个数据加载器获取测试数据。
  3. Optimizer and Scheduler Setup:
    • configure_optimizers: 配置优化器和学习率调度器。Trainer 使用这些优化器来更新模型参数。
  4. Training Loop:
    • training_step: 每个训练批次调用一次。定义前向传播、损失计算、反向传播等。通常返回损失。
    • _step: training_step中显示调用的自定义函数,这个函数体现了SDE的核心计算逻辑,要改SDE主要从这里改
    • forward():在_step中被显式调用。负责定义分数匹配模型的前向传播逻辑,在这里调用了self.dnn进行模型推理,模型接收输入张量,输出预测score值。
    • _loss: 用于计算损失的辅助方法,在 _step 中被调用。使用去噪分数匹配
  5. Validation Loop:
    • validation_step: 每个验证批次调用一次。用于在验证数据上评估模型的性能,其部分内容和training_step一致,但由于需要额外计算pesq等指标,需要调用一次inference.evaluate_model来进行一次反向过程(training_step只进行正向扩散过程),最终返回损失。
    • inference.evaluate_model: 进行反向过程推理,此处包括了使用采样器(调用get_pc_sampler)进行采样以及矫正器,最终返回sisdr、pesq等指标结果
    • validation_epoch_end: (可选) 在一个验证 epoch 结束时调用,用于聚合或处理验证结果。
  6. Checkpointing:
    • on_save_checkpoint: 在保存检查点时调用,可以存储额外的状态信息。
    • on_load_checkpoint: 在加载检查点时调用,用于恢复模型状态。
  7. Inference and Other Operations:
    • enhance: 专用于语音增强的功能方法,在推理过程(调用enhancement.py时)调用。其内容基本和inference.evaluate_model一致,但返回值不是指标结果,而是降噪后的语音
  8. get_pc_sampler
    • 一个嵌入版本的get_pc_sampler,本质上和sampling.get_pc_sampler没有太大区别

可选方法:

  1. Test Loop (如果执行 trainer.test()):
    • test_step: 每个测试批次调用一次。与 validation_step 类似,用于在测试数据上评估模型。
    • test_epoch_end: (可选) 在一个测试 epoch 结束时调用,用于处理测试结果。
  2. 验证阶段
    • 调用 on_validation_epoch_start()on_validation_epoch_end()
  3. 优化器步进:
    • 计算完梯度后,Trainer 会自动调用 optimizer.step()optimizer.zero_grad() 来更新模型权重。
  4. Epoch 开始:
    • 进入训练循环,每个 epoch 开始时,框架会调用 on_epoch_start()on_train_epoch_start()(如果这些方法被覆盖)。
  5. 结束处理:
    • 完成所有 epoch 后,调用 on_fit_end() 方法。你可以在这里进行一些清理工作或保存最终的模型状态。

sgmse.sampling

  1. get_pc_sampler
    • 根据预测器、纠正器名字(已经在correctors.py和predictors.py中注册过,如果无法引用则在init中import一下),返回一个函数pc_sampler和校正器步数。一般来说,预测器处理采样的确定性部分;校正器处理采样的随机性部分
    • pc_sampler:先使用sde.prior_sampling从sde的先验分布中采样起始点xt,然后创建一个时间序列步长,从sde.T一直到ade.N。对于每一个时间步,先计算档期时间步的stepsize,然后应用校正器corrector.update_fn细化当前版本xt,再应用预测器predictor.update_fn更新xt。最终输出去噪后结果xt_mean
    • 每一次验证论,都会调用inference.evaluate_model,进而调用model.get_pc_sampler获取(pc_)sampler,然后在调用sampler得到预测结果sample
  2. predictor
    • 一般在get_pc_sampler的predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)进行初始化,获取成员变量sde、rsde=sde.reverse(score_fn)、score_fn等
    • 具体predictor一般使用ReverseDiffusionPredictor(即p_name=’reverse_diffusion’)
    • predictor.update_fn:先调用rsde.discretize方法得到反向sde首项f和尾项g,然后由x – f得到x均值,再加上g * z得到预测结果x
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇