프로젝트

일반

사용자정보

개정판 87c89301

ID87c893012768c8a7a98d86ca6243551ad830275a
상위 e473b8aa
하위 61709b69

함의성이(가) 약 5년 전에 추가함

issue #1366: change save and log

Change-Id: Ief20ef12655ae9db9545ed449351dc2b29e72972

차이점 보기:

DTI_PID/WebServer/symbol_training/train.py
113 113

  
114 114
    loss_step = 0
115 115

  
116
    save_count = 0
117

  
116 118
    for epoch in range(opt.num_epoches):
117 119
        if str(epoch) in learning_rate_schedule.keys():
118 120
            for param_group in optimizer.param_groups:
......
133 135
            optimizer.step()
134 136

  
135 137
            if iter % opt.test_interval == 0:
136
                print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.2f} (Coord:{:.2f} Conf:{:.2f} Cls:{:.2f})".format
138
                print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format
137 139
                    (epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss,
138 140
                    loss_coord,loss_conf,loss_cls))
139 141

  
......
186 188
            te_coord_loss = sum(loss_coord_ls) / test_set.__len__()
187 189
            te_conf_loss = sum(loss_conf_ls) / test_set.__len__()
188 190
            te_cls_loss = sum(loss_cls_ls) / test_set.__len__()
189
            print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.2f} (Coord:{:.2f} Conf:{:.2f} Cls:{:.2f})".format(
191
            print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format(
190 192
                epoch + 1, opt.num_epoches, optimizer.param_groups[0]['lr'], te_loss, te_coord_loss, te_conf_loss, te_cls_loss))
191 193

  
192 194
            model.train()
193 195
            if te_loss + opt.es_min_delta < best_loss:
196
                save_count += 1
194 197
                best_loss = te_loss
195 198
                best_epoch = epoch
196 199
                print("SAVE MODEL")
197
                torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth"))
198
                torch.save(model, os.path.join(opt.saved_path, name + "_whole_model.pth"))
200
                torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + str(best_loss) + ".pth"))
201
                torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + str(best_loss) + ".pth"))
199 202

  
200 203
            # Early stopping
201 204
            if epoch - best_epoch > opt.es_patience > 0:

내보내기 Unified diff

클립보드 이미지 추가 (최대 크기: 500 MB)