개정판 87c89301
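issue #1366: change save and log (save each improved checkpoint under a numbered filename that includes the best test loss instead of overwriting a single file; print training and test losses with five decimal places instead of two)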
Change-Id: Ief20ef12655ae9db9545ed449351dc2b29e72972
DTI_PID/WebServer/symbol_training/train.py | ||
---|---|---|
113 | 113 |
|
114 | 114 |
loss_step = 0 |
115 | 115 |
|
116 |
save_count = 0 |
|
117 |
|
|
116 | 118 |
for epoch in range(opt.num_epoches): |
117 | 119 |
if str(epoch) in learning_rate_schedule.keys(): |
118 | 120 |
for param_group in optimizer.param_groups: |
... | ... | |
133 | 135 |
optimizer.step() |
134 | 136 |
|
135 | 137 |
if iter % opt.test_interval == 0: |
136 |
print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.2f} (Coord:{:.2f} Conf:{:.2f} Cls:{:.2f})".format
|
|
138 |
print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format
|
|
137 | 139 |
(epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss, |
138 | 140 |
loss_coord,loss_conf,loss_cls)) |
139 | 141 |
|
... | ... | |
186 | 188 |
te_coord_loss = sum(loss_coord_ls) / test_set.__len__() |
187 | 189 |
te_conf_loss = sum(loss_conf_ls) / test_set.__len__() |
188 | 190 |
te_cls_loss = sum(loss_cls_ls) / test_set.__len__() |
189 |
print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.2f} (Coord:{:.2f} Conf:{:.2f} Cls:{:.2f})".format(
|
|
191 |
print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format(
|
|
190 | 192 |
epoch + 1, opt.num_epoches, optimizer.param_groups[0]['lr'], te_loss, te_coord_loss, te_conf_loss, te_cls_loss)) |
191 | 193 |
|
192 | 194 |
model.train() |
193 | 195 |
if te_loss + opt.es_min_delta < best_loss: |
196 |
save_count += 1 |
|
194 | 197 |
best_loss = te_loss |
195 | 198 |
best_epoch = epoch |
196 | 199 |
print("SAVE MODEL") |
197 |
torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth")) |
|
198 |
torch.save(model, os.path.join(opt.saved_path, name + "_whole_model.pth")) |
|
200 |
torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + str(best_loss) + ".pth"))
|
|
201 |
torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + str(best_loss) + ".pth"))
|
|
199 | 202 |
|
200 | 203 |
# Early stopping |
201 | 204 |
if epoch - best_epoch > opt.es_patience > 0: |
내보내기 Unified diff