hytos / DTI_PID / WebServer / symbol_training / train.py @ 91e72bfe
이력 | 보기 | 이력해설 | 다운로드 (12.9 KB)
1 | 6374c2c6 | esham21 | """
|
---|---|---|---|
2 | Training For Small Object
|
||
3 | """
|
||
4 | import os |
||
5 | import argparse |
||
6 | import torch.nn as nn |
||
7 | from torch.utils.data import DataLoader |
||
8 | dd6d4de9 | esham21 | from src.doftech_dataset import DoftechDataset |
9 | 6374c2c6 | esham21 | from src.utils import * |
10 | 76986eb0 | esham21 | from src.loss import YoloLoss |
11 | 6374c2c6 | esham21 | from src.yolo_net import Yolo |
12 | from src.yolo_doftech import YoloD |
||
13 | import shutil |
||
14 | import visdom |
||
15 | import cv2 |
||
16 | import pickle |
||
17 | import numpy as np |
||
18 | from src.vis_utils import array_tool as at |
||
19 | from src.vis_utils.vis_tool import visdom_bbox |
||
20 | |||
21 | loss_data = {'X': [], 'Y': [], 'legend_U':['total', 'coord', 'conf', 'cls']} |
||
22 | |||
23 | # 형상 CLASS
|
||
24 | DOFTECH_CLASSES= ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief', |
||
25 | '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure', |
||
26 | 'circle', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs', |
||
27 | 'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close', |
||
28 | 'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot' |
||
29 | ,'opc']
|
||
30 | |||
31 | 9be41199 | esham21 | use_voc_model = True
|
32 | use_visdom = False
|
||
33 | 150e63fa | esham21 | if use_visdom:
|
34 | visdom = visdom.Visdom(port='8088')
|
||
35 | #visdom = visdom.Visdom()
|
||
36 | 6374c2c6 | esham21 | |
37 | 1e8ea226 | esham21 | def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None): |
38 | e473b8aa | esham21 | global DOFTECH_CLASSES
|
39 | 6374c2c6 | esham21 | DOFTECH_CLASSES = classes |
40 | |||
41 | parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
|
||
42 | e473b8aa | esham21 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images") |
43 | 6374c2c6 | esham21 | parser.add_argument("--batch_size", type=int, default=10, help="The number of images per batch") |
44 | |||
45 | # Training 기본 Setting
|
||
46 | parser.add_argument("--momentum", type=float, default=0.9) |
||
47 | parser.add_argument("--decay", type=float, default=0.0005) |
||
48 | parser.add_argument("--dropout", type=float, default=0.5) |
||
49 | aba30784 | esham21 | parser.add_argument("--num_epoches", type=int, default=205) |
50 | b64cc3b5 | esham21 | parser.add_argument("--test_interval", type=int, default=20, help="Number of epoches between testing phases") |
51 | 6374c2c6 | esham21 | parser.add_argument("--object_scale", type=float, default=1.0) |
52 | parser.add_argument("--noobject_scale", type=float, default=0.5) |
||
53 | parser.add_argument("--class_scale", type=float, default=1.0) |
||
54 | parser.add_argument("--coord_scale", type=float, default=5.0) |
||
55 | parser.add_argument("--reduction", type=int, default=32) |
||
56 | parser.add_argument("--es_min_delta", type=float, default=0.0, |
||
57 | help="Early stopping's parameter: minimum change loss to qualify as an improvement")
|
||
58 | parser.add_argument("--es_patience", type=int, default=0, |
||
59 | help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.")
|
||
60 | |||
61 | # 확인해야 하는 PATH
|
||
62 | parser.add_argument("--data_path", type=str, default=os.path.join(root_path, 'training'), help="the root folder of dataset") # 학습 데이터 경로 -> image와 xml의 상위 경로 입력 |
||
63 | parser.add_argument("--data_path_test", type=str, default=os.path.join(root_path, 'test'), help="the root folder of dataset") # 테스트 데이터 경로 -> test할 이미지만 넣으면 됨 |
||
64 | #parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
|
||
65 | parser.add_argument("--pre_trained_model_path", type=str, default=pre_trained_model_path) # Pre-training 된 모델 경로 |
||
66 | |||
67 | parser.add_argument("--saved_path", type=str, default=os.path.join(root_path, 'checkpoint')) # training 된 모델 저장 경로 |
||
68 | parser.add_argument("--conf_threshold", type=float, default=0.35) |
||
69 | parser.add_argument("--nms_threshold", type=float, default=0.5) |
||
70 | opt = parser.parse_args() |
||
71 | |||
72 | a6b28afb | esham21 | if not os.path.isdir(opt.saved_path): |
73 | 7f74c5bd | esham21 | os.mkdir(opt.saved_path) |
74 | a6b28afb | esham21 | |
75 | 28822594 | esham21 | # 학습할 클래스들을 저장하고 인식 시 불러와 사용합니다.
|
76 | 23662e23 | esham21 | with open(os.path.join(opt.saved_path, name + "_info.info"), 'w') as stream: |
77 | con = str(len(DOFTECH_CLASSES)) |
||
78 | names = '\n'.join(DOFTECH_CLASSES)
|
||
79 | 1e8ea226 | esham21 | bigs = '\n'.join(bigs)
|
80 | con = con + '\n' + names + '\n' + '***bigs***' + '\n' + bigs |
||
81 | 23662e23 | esham21 | stream.write(con) |
82 | 6374c2c6 | esham21 | |
83 | if torch.cuda.is_available():
|
||
84 | torch.cuda.manual_seed(123)
|
||
85 | else:
|
||
86 | torch.manual_seed(123)
|
||
87 | 9be41199 | esham21 | |
88 | 6374c2c6 | esham21 | learning_rate_schedule = {"0": 1e-5, "5": 1e-4, |
89 | "80": 1e-5, "110": 1e-6} |
||
90 | |||
91 | 072a26bd | esham21 | training_params = {"batch_size": 1,#opt.batch_size, |
92 | 6374c2c6 | esham21 | "shuffle": True, |
93 | "drop_last": True, |
||
94 | "collate_fn": custom_collate_fn}
|
||
95 | |||
96 | test_params = {"batch_size": opt.batch_size,
|
||
97 | "shuffle": False, |
||
98 | "drop_last": False, |
||
99 | "collate_fn": custom_collate_fn}
|
||
100 | |||
101 | training_set = DoftechDataset(opt.data_path, opt.image_size, is_training=True, classes=DOFTECH_CLASSES)
|
||
102 | training_generator = DataLoader(training_set, **training_params) |
||
103 | |||
104 | 9be41199 | esham21 | test_set = DoftechDataset(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
|
105 | 6374c2c6 | esham21 | test_generator = DataLoader(test_set, **test_params) |
106 | |||
107 | 9be41199 | esham21 | # BUILDING MODEL =======================================================================
|
108 | if use_voc_model :
|
||
109 | pre_model = Yolo(20).cuda()
|
||
110 | pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
|
||
111 | model = YoloD(pre_model, training_set.num_classes).cuda() |
||
112 | else :
|
||
113 | model = Yolo(training_set.num_classes).cuda() |
||
114 | 6374c2c6 | esham21 | |
115 | nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01) |
||
116 | |||
117 | criterion = YoloLoss(training_set.num_classes, model.anchors, opt.reduction) |
||
118 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=opt.momentum, weight_decay=opt.decay)
|
||
119 | |||
120 | best_loss = 1e10
|
||
121 | best_epoch = 0
|
||
122 | num_iter_per_epoch = len(training_generator)
|
||
123 | loss_step = 0
|
||
124 | 9be41199 | esham21 | # ======================================================================================
|
125 | 6374c2c6 | esham21 | |
126 | 9be41199 | esham21 | # TRAINING =============================================================================
|
127 | 87c89301 | esham21 | save_count = 0
|
128 | |||
129 | 9be41199 | esham21 | model.train() |
130 | 6374c2c6 | esham21 | for epoch in range(opt.num_epoches): |
131 | if str(epoch) in learning_rate_schedule.keys(): |
||
132 | for param_group in optimizer.param_groups: |
||
133 | param_group['lr'] = learning_rate_schedule[str(epoch)] |
||
134 | |||
135 | for iter, batch in enumerate(training_generator): |
||
136 | 9be41199 | esham21 | image, label, image2 = batch |
137 | image = Variable(image.cuda(), requires_grad=False)
|
||
138 | 6374c2c6 | esham21 | if torch.cuda.is_available():
|
139 | image = Variable(image.cuda(), requires_grad=True)
|
||
140 | 9be41199 | esham21 | origin = Variable(image2.cuda(), requires_grad=False)
|
141 | 6374c2c6 | esham21 | else:
|
142 | image = Variable(image, requires_grad=True)
|
||
143 | |||
144 | optimizer.zero_grad() |
||
145 | logits = model(image) |
||
146 | loss, loss_coord, loss_conf, loss_cls = criterion(logits, label) |
||
147 | loss.backward() |
||
148 | optimizer.step() |
||
149 | |||
150 | dd6d4de9 | esham21 | if iter % (opt.test_interval * 5) == 0: |
151 | 87c89301 | esham21 | print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format
|
152 | 6374c2c6 | esham21 | (epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss, |
153 | loss_coord,loss_conf,loss_cls)) |
||
154 | |||
155 | 9be41199 | esham21 | if use_visdom:
|
156 | predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold, |
||
157 | opt.nms_threshold) |
||
158 | 6374c2c6 | esham21 | |
159 | 9be41199 | esham21 | gt_image = at.tonumpy(image[0])
|
160 | gt_image = visdom_bbox(gt_image, label[0])
|
||
161 | visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3) |
||
162 | #
|
||
163 | origin_image = at.tonumpy(origin[0])
|
||
164 | origin_image = visdom_bbox(origin_image, []) |
||
165 | visdom.image(origin_image, opts=dict(title='origin_box_image'), win=4) |
||
166 | 6374c2c6 | esham21 | |
167 | image = at.tonumpy(image[0])
|
||
168 | |||
169 | 9be41199 | esham21 | if len(predictions) != 0: |
170 | box_image = visdom_bbox(image, predictions[0])
|
||
171 | visdom.image(box_image, opts=dict(title='box_image'), win=2) |
||
172 | |||
173 | elif len(predictions) == 0: |
||
174 | box_image = visdom_bbox(image, []) |
||
175 | visdom.image(box_image, opts=dict(title='box_image'), win=2) |
||
176 | 6374c2c6 | esham21 | |
177 | 9be41199 | esham21 | loss_dict = { |
178 | 'total' : loss.item(),
|
||
179 | 'coord' : loss_coord.item(),
|
||
180 | 'conf' : loss_conf.item(),
|
||
181 | 'cls' : loss_cls.item()
|
||
182 | } |
||
183 | 6374c2c6 | esham21 | |
184 | 9be41199 | esham21 | visdom_loss(visdom, loss_step, loss_dict) |
185 | loss_step = loss_step + 1
|
||
186 | 6374c2c6 | esham21 | |
187 | if epoch % opt.test_interval == 0: |
||
188 | model.eval() |
||
189 | loss_ls = [] |
||
190 | loss_coord_ls = [] |
||
191 | loss_conf_ls = [] |
||
192 | loss_cls_ls = [] |
||
193 | for te_iter, te_batch in enumerate(test_generator): |
||
194 | dd6d4de9 | esham21 | te_image, te_label, _ = te_batch |
195 | 6374c2c6 | esham21 | num_sample = len(te_label)
|
196 | if torch.cuda.is_available():
|
||
197 | te_image = te_image.cuda() |
||
198 | with torch.no_grad():
|
||
199 | te_logits = model(te_image) |
||
200 | batch_loss, batch_loss_coord, batch_loss_conf, batch_loss_cls = criterion(te_logits, te_label) |
||
201 | loss_ls.append(batch_loss * num_sample) |
||
202 | loss_coord_ls.append(batch_loss_coord * num_sample) |
||
203 | loss_conf_ls.append(batch_loss_conf * num_sample) |
||
204 | loss_cls_ls.append(batch_loss_cls * num_sample) |
||
205 | |||
206 | te_loss = sum(loss_ls) / test_set.__len__()
|
||
207 | te_coord_loss = sum(loss_coord_ls) / test_set.__len__()
|
||
208 | te_conf_loss = sum(loss_conf_ls) / test_set.__len__()
|
||
209 | te_cls_loss = sum(loss_cls_ls) / test_set.__len__()
|
||
210 | 87c89301 | esham21 | print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format(
|
211 | 6374c2c6 | esham21 | epoch + 1, opt.num_epoches, optimizer.param_groups[0]['lr'], te_loss, te_coord_loss, te_conf_loss, te_cls_loss)) |
212 | |||
213 | model.train() |
||
214 | if te_loss + opt.es_min_delta < best_loss:
|
||
215 | 87c89301 | esham21 | save_count += 1
|
216 | 6374c2c6 | esham21 | best_loss = te_loss |
217 | best_epoch = epoch |
||
218 | print("SAVE MODEL")
|
||
219 | 61709b69 | esham21 | # for debug for each loss
|
220 | 072a26bd | esham21 | #torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
|
221 | 150e63fa | esham21 | #torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
|
222 | 61709b69 | esham21 | # save
|
223 | torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth"))
|
||
224 | torch.save(model, os.path.join(opt.saved_path, name + "_whole_model.pth"))
|
||
225 | 77f7b6b3 | esham21 | else:
|
226 | save_count += 1
|
||
227 | # for debug for each loss
|
||
228 | 072a26bd | esham21 | #torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
|
229 | dd6d4de9 | esham21 | #torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
|
230 | 6374c2c6 | esham21 | |
231 | # Early stopping
|
||
232 | if epoch - best_epoch > opt.es_patience > 0: |
||
233 | print("Stop training at epoch {}. The lowest loss achieved is {}".format(epoch, te_loss))
|
||
234 | break
|
||
235 | |||
236 | def visdom_loss(visdom, loss_step, loss_dict): |
||
237 | loss_data['X'].append(loss_step)
|
||
238 | loss_data['Y'].append([loss_dict[k] for k in loss_data['legend_U']]) |
||
239 | visdom.line( |
||
240 | X=np.stack([np.array(loss_data['X'])] * len(loss_data['legend_U']), 1), |
||
241 | Y=np.array(loss_data['Y']),
|
||
242 | win=30,
|
||
243 | opts=dict(xlabel='Step', |
||
244 | ylabel='Loss',
|
||
245 | title='YOLO_V2',
|
||
246 | legend=loss_data['legend_U']),
|
||
247 | update='append'
|
||
248 | ) |
||
249 | |||
250 | if __name__ == "__main__": |
||
251 | b65d47a2 | esham21 | datas = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief', |
252 | '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure', |
||
253 | 'inst', 'func_valve', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs', |
||
254 | 'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close', |
||
255 | 'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot', |
||
256 | 'opc']
|
||
257 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\Data\\', 'VnV') |
||
258 | train(name='VnV', classes=datas, root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath(
|
||
259 | __file__)) + '\\pre_trained_model\\only_params_trained_yolo_voc') |