Skip to content

modules模块

更新时间:2024年4月13日

train.py

定义训练过程的模块,最终以train函数输出。

模块级输入:

类folder、类Custom_criterion1、变量(类)dataloader、模块simple_functions、参数EPOCHS、CUT_EPOCH、BATCH_SIZE、LATENTDIM、LR_MAX、LR_MIN、VAE_INUSE

注意点:

  • VAE平行装载,在两块GPU上并行训练。
  • Learn_rate采用余弦退火法,输入为LR_MAX和LR_MIN。
  • 优化器使用AdamW
  • criterion与loss函数一致,在CUT_EPOCH之前使用纯MSE,在CUT_EPOCH之后使用0.7xMSE+0.3xSSIM

application.py

定义训练过程的模块,最终以application函数输出。

模块级输入:

类folder、6个损失函数。

用于评估的图片序号:

  • 真实数据集的训练集:1500,2000
  • 真实数据集的测试集:5500,6000
  • 模拟数据集的训练集:2000,3000
  • 模拟数据集的训练集:4500,4700