参数模块化 为了对整个Project的参数进行统一的管理,我们通常会使用一个特定的文件比如config.py文件,将整个工程的参数(包括模型的参数)都保存在该文件里面,然后在其它文件里面通过导入该文件即可使用特定的参数,这样方便实现参数的模块化。 在config.py文件内,我们通过定义一个Config类来保存参数,主要参数包含以下种: 路径(包括训练集、测试集路径,模型保存路径,source路径等等) 数据参数(包括输入图片的大小,通道,类别数等等) 训练通用参数(包括使用GPU,使用已训练模型,学习率,batch_size等等) 模型特定参数 Config类的方法主要有: 初始化路径文件夹 大致的config.py文件内容如下: # -*- coding: utf-8 -*- import os import torch class Config(): def __init__(self): # general param self.RETRAIN = True self.USE_CUDA = torch.cuda.is_available() # define the data paths self.RAW_TRAIN_DATA = "./data/train_data/" self.RAW_TEST_DATA = "./data/eval_data/" # define the source path self.SOURCE_DIR_PATH = { "MODEL_DIR" : "./source/models/", "SUMMARY_DIR" : "./source/summary/" } # define the file path self.LABEL_TO_NAME_PATH = "./source/dict/label_to_name_dict.pkl" self.NAME_TO_LABEL_PATH = "./source/dict/name_to_label_dict.pkl" # check the path self.check_dir() # define the param of the training self.WIDTH = 488 self.HEIGHT = 488 self.CHANNEL = 3 self.NUM_CLASS = 250 self.BATCH_SIZE = 30 self.NUM_EPOCHS = 500 self.LEARNING_RATE = 0.001 self.VALPERBATCH = 2 def check_dir(self): ''' This function is used to check the dirs.if data path does not exists, raise error.if source path does not exits, make new dirs. :return: None ''' # check the data path if not os.path.exists(self.RAW_TEST_DATA): raise Exception("==> Error: Data path %s does not exist." % self.RAW_TEST_DATA) if not os.path.exists(self.RAW_TRAIN_DATA): raise Exception("==> Error: Data path %s does not exist." % self.RAW_TRAIN_DATA) # check the source path for name, path in self.SOURCE_DIR_PATH.items(): if not os.path.exists(path): print("==> Creating %s : %s" % (name, path)) os.makedirs(path)