5 |
5 |
import argparse
|
6 |
6 |
import torch.nn as nn
|
7 |
7 |
from torch.utils.data import DataLoader
|
8 |
|
from src.doftech_dataset import DoftechDataset, DoftechDatasetTest
|
|
8 |
from src.doftech_dataset import DoftechDataset
|
9 |
9 |
from src.utils import *
|
10 |
10 |
from src.loss import YoloLoss
|
11 |
11 |
from src.yolo_net import Yolo
|
... | ... | |
147 |
147 |
loss.backward()
|
148 |
148 |
optimizer.step()
|
149 |
149 |
|
150 |
|
if iter % opt.test_interval == 0:
|
|
150 |
if iter % (opt.test_interval * 5) == 0:
|
151 |
151 |
print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format
|
152 |
152 |
(epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss,
|
153 |
153 |
loss_coord,loss_conf,loss_cls))
|
... | ... | |
191 |
191 |
loss_conf_ls = []
|
192 |
192 |
loss_cls_ls = []
|
193 |
193 |
for te_iter, te_batch in enumerate(test_generator):
|
194 |
|
te_image, te_label = te_batch
|
|
194 |
te_image, te_label, _ = te_batch
|
195 |
195 |
num_sample = len(te_label)
|
196 |
196 |
if torch.cuda.is_available():
|
197 |
197 |
te_image = te_image.cuda()
|
... | ... | |
217 |
217 |
best_epoch = epoch
|
218 |
218 |
print("SAVE MODEL")
|
219 |
219 |
# for debug for each loss
|
220 |
|
#torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
|
|
220 |
torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
|
221 |
221 |
#torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
|
222 |
222 |
# save
|
223 |
223 |
torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth"))
|
... | ... | |
226 |
226 |
save_count += 1
|
227 |
227 |
# for debug for each loss
|
228 |
228 |
torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
|
229 |
|
torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
|
|
229 |
#torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
|
230 |
230 |
|
231 |
231 |
# Early stopping
|
232 |
232 |
if epoch - best_epoch > opt.es_patience > 0:
|