개정판 150e63fa
issue #1366:
DTI_PID/WebServer/symbol_training/train.py | ||
---|---|---|
19 | 19 |
from src.vis_utils.vis_tool import visdom_bbox |
20 | 20 |
|
21 | 21 |
loss_data = {'X': [], 'Y': [], 'legend_U':['total', 'coord', 'conf', 'cls']} |
22 |
#visdom = visdom.Visdom(port='8088') |
|
23 | 22 |
|
24 | 23 |
# 형상 CLASS |
25 | 24 |
DOFTECH_CLASSES= ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief', |
... | ... | |
31 | 30 |
|
32 | 31 |
use_voc_model = True |
33 | 32 |
use_visdom = False |
34 |
if use_visdom |
|
35 |
visdom = visdom.Visdom() |
|
33 |
if use_visdom: |
|
34 |
visdom = visdom.Visdom(port='8088') |
|
35 |
#visdom = visdom.Visdom() |
|
36 | 36 |
|
37 | 37 |
def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None): |
38 | 38 |
global DOFTECH_CLASSES |
... | ... | |
217 | 217 |
best_epoch = epoch |
218 | 218 |
print("SAVE MODEL") |
219 | 219 |
# for debug for each loss |
220 |
torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth")) |
|
221 |
torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth")) |
|
220 |
#torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
|
|
221 |
#torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
|
|
222 | 222 |
# save |
223 | 223 |
torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth")) |
224 | 224 |
torch.save(model, os.path.join(opt.saved_path, name + "_whole_model.pth")) |
내보내기 Unified diff