L
参数模块化
为了对整个Project的参数进行统一的管理,我们通常会使用一个特定的文件比如config.py文件,将整个工程的参数(包括模型的参数)都保存在该文件里面,然后在其它文件里面通过导入该文件即可使用特定的参数,这样方便实现参数的模块化。
在config.py文件内,我们通过定义一个Config类来保存参数,主要参数包含以下种:
路径(包括训练集、测试集路径,模型保存路径,source路径等等)
数据参数(包括输入图片的大小,通道,类别数等等)
训练通用参数(包括使用GPU,使用已训练模型,学习率,batch_size等等)
模型特定参数
Config类的方法主要有:
初始化路径文件夹
大致的config.py文件内容如下:
# -*- coding: utf-8 -*-
import os
import torch
class Config():
def __init__(self):
# general param
self.RETRAIN = True
self.USE_CUDA = torch.cuda.is_available()
# define the data paths
self.RAW_TRAIN_DATA = "./data/train_data/"
self.RAW_TEST_DATA = "./data/eval_data/"
# define the source path
self.SOURCE_DIR_PATH = {
"MODEL_DIR" : "./source/models/",
"SUMMARY_DIR" : "./source/summary/"
}
# define the file path
self.LABEL_TO_NAME_PATH = "./source/dict/label_to_name_dict.pkl"
self.NAME_TO_LABEL_PATH = "./source/dict/name_to_label_dict.pkl"
# check the path
self.check_dir()
# define the param of the training
self.WIDTH = 488
self.HEIGHT = 488
self.CHANNEL = 3
self.NUM_CLASS = 250
self.BATCH_SIZE = 30
self.NUM_EPOCHS = 500
self.LEARNING_RATE = 0.001
self.VALPERBATCH = 2
def check_dir(self):
'''
This function is used to check the dirs.if data path
does not exists, raise error.if source path does not
exits, make new dirs.
:return: None
'''
# check the data path
if not os.path.exists(self.RAW_TEST_DATA):
raise Exception("==> Error: Data path %s does not exist." % self.RAW_TEST_DATA)
if not os.path.exists(self.RAW_TRAIN_DATA):
raise Exception("==> Error: Data path %s does not exist." % self.RAW_TRAIN_DATA)
# check the source path
for name, path in self.SOURCE_DIR_PATH.items():
if not os.path.exists(path):
print("==> Creating %s : %s" % (name, path))
os.makedirs(path)