hytos / DTI_PID / WebServer / symbol_training / train.py @ 46aba3e1
이력 | 보기 | 이력해설 | 다운로드 (12.8 KB)
1 | 6374c2c6 | esham21 | """
|
---|---|---|---|
2 | Training For Small Object
|
||
3 | """
|
||
4 | import os |
||
5 | import argparse |
||
6 | import torch.nn as nn |
||
7 | from torch.utils.data import DataLoader |
||
8 | from src.doftech_dataset import DoftechDataset, DoftechDatasetTest |
||
9 | from src.utils import * |
||
10 | 76986eb0 | esham21 | from src.loss import YoloLoss |
11 | 6374c2c6 | esham21 | from src.yolo_net import Yolo |
12 | from src.yolo_doftech import YoloD |
||
13 | import shutil |
||
14 | import visdom |
||
15 | import cv2 |
||
16 | import pickle |
||
17 | import numpy as np |
||
18 | from src.vis_utils import array_tool as at |
||
19 | from src.vis_utils.vis_tool import visdom_bbox |
||
20 | |||
21 | loss_data = {'X': [], 'Y': [], 'legend_U':['total', 'coord', 'conf', 'cls']} |
||
22 | b64cc3b5 | esham21 | #visdom = visdom.Visdom(port='8088')
|
23 | 6374c2c6 | esham21 | |
24 | # 형상 CLASS
|
||
25 | DOFTECH_CLASSES= ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief', |
||
26 | '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure', |
||
27 | 'circle', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs', |
||
28 | 'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close', |
||
29 | 'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot' |
||
30 | ,'opc']
|
||
31 | |||
32 | 23662e23 | esham21 | #print(len(DOFTECH_CLASSES))
|
33 | 6374c2c6 | esham21 | |
34 | 1e8ea226 | esham21 | def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None): |
35 | e473b8aa | esham21 | global DOFTECH_CLASSES
|
36 | 6374c2c6 | esham21 | DOFTECH_CLASSES = classes |
37 | |||
38 | parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
|
||
39 | e473b8aa | esham21 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images") |
40 | 6374c2c6 | esham21 | parser.add_argument("--batch_size", type=int, default=10, help="The number of images per batch") |
41 | |||
42 | # Training 기본 Setting
|
||
43 | parser.add_argument("--momentum", type=float, default=0.9) |
||
44 | parser.add_argument("--decay", type=float, default=0.0005) |
||
45 | parser.add_argument("--dropout", type=float, default=0.5) |
||
46 | aba30784 | esham21 | parser.add_argument("--num_epoches", type=int, default=205) |
47 | b64cc3b5 | esham21 | parser.add_argument("--test_interval", type=int, default=20, help="Number of epoches between testing phases") |
48 | 6374c2c6 | esham21 | parser.add_argument("--object_scale", type=float, default=1.0) |
49 | parser.add_argument("--noobject_scale", type=float, default=0.5) |
||
50 | parser.add_argument("--class_scale", type=float, default=1.0) |
||
51 | parser.add_argument("--coord_scale", type=float, default=5.0) |
||
52 | parser.add_argument("--reduction", type=int, default=32) |
||
53 | parser.add_argument("--es_min_delta", type=float, default=0.0, |
||
54 | help="Early stopping's parameter: minimum change loss to qualify as an improvement")
|
||
55 | parser.add_argument("--es_patience", type=int, default=0, |
||
56 | help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.")
|
||
57 | |||
58 | # 확인해야 하는 PATH
|
||
59 | parser.add_argument("--data_path", type=str, default=os.path.join(root_path, 'training'), help="the root folder of dataset") # 학습 데이터 경로 -> image와 xml의 상위 경로 입력 |
||
60 | parser.add_argument("--data_path_test", type=str, default=os.path.join(root_path, 'test'), help="the root folder of dataset") # 테스트 데이터 경로 -> test할 이미지만 넣으면 됨 |
||
61 | #parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
|
||
62 | parser.add_argument("--pre_trained_model_path", type=str, default=pre_trained_model_path) # Pre-training 된 모델 경로 |
||
63 | |||
64 | parser.add_argument("--saved_path", type=str, default=os.path.join(root_path, 'checkpoint')) # training 된 모델 저장 경로 |
||
65 | parser.add_argument("--conf_threshold", type=float, default=0.35) |
||
66 | parser.add_argument("--nms_threshold", type=float, default=0.5) |
||
67 | opt = parser.parse_args() |
||
68 | |||
69 | a6b28afb | esham21 | if not os.path.isdir(opt.saved_path): |
70 | 7f74c5bd | esham21 | os.mkdir(opt.saved_path) |
71 | a6b28afb | esham21 | |
72 | 28822594 | esham21 | # 학습할 클래스들을 저장하고 인식 시 불러와 사용합니다.
|
73 | 23662e23 | esham21 | with open(os.path.join(opt.saved_path, name + "_info.info"), 'w') as stream: |
74 | con = str(len(DOFTECH_CLASSES)) |
||
75 | names = '\n'.join(DOFTECH_CLASSES)
|
||
76 | 1e8ea226 | esham21 | bigs = '\n'.join(bigs)
|
77 | con = con + '\n' + names + '\n' + '***bigs***' + '\n' + bigs |
||
78 | 23662e23 | esham21 | stream.write(con) |
79 | 6374c2c6 | esham21 | |
80 | if torch.cuda.is_available():
|
||
81 | torch.cuda.manual_seed(123)
|
||
82 | else:
|
||
83 | torch.manual_seed(123)
|
||
84 | learning_rate_schedule = {"0": 1e-5, "5": 1e-4, |
||
85 | "80": 1e-5, "110": 1e-6} |
||
86 | |||
87 | training_params = {"batch_size": opt.batch_size,
|
||
88 | "shuffle": True, |
||
89 | "drop_last": True, |
||
90 | "collate_fn": custom_collate_fn}
|
||
91 | |||
92 | test_params = {"batch_size": opt.batch_size,
|
||
93 | "shuffle": False, |
||
94 | "drop_last": False, |
||
95 | "collate_fn": custom_collate_fn}
|
||
96 | |||
97 | training_set = DoftechDataset(opt.data_path, opt.image_size, is_training=True, classes=DOFTECH_CLASSES)
|
||
98 | training_generator = DataLoader(training_set, **training_params) |
||
99 | |||
100 | test_set = DoftechDatasetTest(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
|
||
101 | test_generator = DataLoader(test_set, **test_params) |
||
102 | |||
103 | pre_model = Yolo(20).cuda()
|
||
104 | pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
|
||
105 | |||
106 | model = YoloD(pre_model, training_set.num_classes).cuda() |
||
107 | |||
108 | nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01) |
||
109 | |||
110 | criterion = YoloLoss(training_set.num_classes, model.anchors, opt.reduction) |
||
111 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=opt.momentum, weight_decay=opt.decay)
|
||
112 | |||
113 | best_loss = 1e10
|
||
114 | best_epoch = 0
|
||
115 | model.train() |
||
116 | num_iter_per_epoch = len(training_generator)
|
||
117 | |||
118 | loss_step = 0
|
||
119 | |||
120 | 87c89301 | esham21 | save_count = 0
|
121 | |||
122 | 6374c2c6 | esham21 | for epoch in range(opt.num_epoches): |
123 | if str(epoch) in learning_rate_schedule.keys(): |
||
124 | for param_group in optimizer.param_groups: |
||
125 | param_group['lr'] = learning_rate_schedule[str(epoch)] |
||
126 | |||
127 | for iter, batch in enumerate(training_generator): |
||
128 | image, label = batch |
||
129 | if torch.cuda.is_available():
|
||
130 | image = Variable(image.cuda(), requires_grad=True)
|
||
131 | else:
|
||
132 | image = Variable(image, requires_grad=True)
|
||
133 | |||
134 | optimizer.zero_grad() |
||
135 | logits = model(image) |
||
136 | loss, loss_coord, loss_conf, loss_cls = criterion(logits, label) |
||
137 | loss.backward() |
||
138 | |||
139 | optimizer.step() |
||
140 | |||
141 | if iter % opt.test_interval == 0: |
||
142 | 87c89301 | esham21 | print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format
|
143 | 6374c2c6 | esham21 | (epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss, |
144 | loss_coord,loss_conf,loss_cls)) |
||
145 | |||
146 | predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold, |
||
147 | opt.nms_threshold) |
||
148 | |||
149 | gt_image = at.tonumpy(image[0])
|
||
150 | gt_image = visdom_bbox(gt_image, label[0])
|
||
151 | bb5f022f | esham21 | #visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3)
|
152 | 6374c2c6 | esham21 | |
153 | if len(predictions) != 0: |
||
154 | image = at.tonumpy(image[0])
|
||
155 | box_image = visdom_bbox(image, predictions[0])
|
||
156 | bb5f022f | esham21 | #visdom.image(box_image, opts=dict(title='box_image'), win=2)
|
157 | 6374c2c6 | esham21 | |
158 | elif len(predictions) == 0: |
||
159 | box_image = tensor2im(image) |
||
160 | bb5f022f | esham21 | #visdom.image(box_image.transpose([2, 0, 1]), opts=dict(title='box_image'), win=2)
|
161 | 6374c2c6 | esham21 | |
162 | loss_dict = { |
||
163 | 'total' : loss.item(),
|
||
164 | 'coord' : loss_coord.item(),
|
||
165 | 'conf' : loss_conf.item(),
|
||
166 | 'cls' : loss_cls.item()
|
||
167 | } |
||
168 | |||
169 | bb5f022f | esham21 | #visdom_loss(visdom, loss_step, loss_dict)
|
170 | 6374c2c6 | esham21 | loss_step = loss_step + 1
|
171 | |||
172 | if epoch % opt.test_interval == 0: |
||
173 | model.eval() |
||
174 | loss_ls = [] |
||
175 | loss_coord_ls = [] |
||
176 | loss_conf_ls = [] |
||
177 | loss_cls_ls = [] |
||
178 | for te_iter, te_batch in enumerate(test_generator): |
||
179 | te_image, te_label = te_batch |
||
180 | num_sample = len(te_label)
|
||
181 | if torch.cuda.is_available():
|
||
182 | te_image = te_image.cuda() |
||
183 | with torch.no_grad():
|
||
184 | te_logits = model(te_image) |
||
185 | batch_loss, batch_loss_coord, batch_loss_conf, batch_loss_cls = criterion(te_logits, te_label) |
||
186 | loss_ls.append(batch_loss * num_sample) |
||
187 | loss_coord_ls.append(batch_loss_coord * num_sample) |
||
188 | loss_conf_ls.append(batch_loss_conf * num_sample) |
||
189 | loss_cls_ls.append(batch_loss_cls * num_sample) |
||
190 | |||
191 | te_loss = sum(loss_ls) / test_set.__len__()
|
||
192 | te_coord_loss = sum(loss_coord_ls) / test_set.__len__()
|
||
193 | te_conf_loss = sum(loss_conf_ls) / test_set.__len__()
|
||
194 | te_cls_loss = sum(loss_cls_ls) / test_set.__len__()
|
||
195 | 87c89301 | esham21 | print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format(
|
196 | 6374c2c6 | esham21 | epoch + 1, opt.num_epoches, optimizer.param_groups[0]['lr'], te_loss, te_coord_loss, te_conf_loss, te_cls_loss)) |
197 | |||
198 | model.train() |
||
199 | if te_loss + opt.es_min_delta < best_loss:
|
||
200 | 87c89301 | esham21 | save_count += 1
|
201 | 6374c2c6 | esham21 | best_loss = te_loss |
202 | best_epoch = epoch |
||
203 | print("SAVE MODEL")
|
||
204 | 61709b69 | esham21 | # for debug for each loss
|
205 | 755b0a9d | esham21 | torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth")) |
206 | 61709b69 | esham21 | torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth")) |
207 | # save
|
||
208 | torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth"))
|
||
209 | torch.save(model, os.path.join(opt.saved_path, name + "_whole_model.pth"))
|
||
210 | 77f7b6b3 | esham21 | else:
|
211 | save_count += 1
|
||
212 | # for debug for each loss
|
||
213 | torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth")) |
||
214 | torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth")) |
||
215 | 6374c2c6 | esham21 | |
216 | # Early stopping
|
||
217 | if epoch - best_epoch > opt.es_patience > 0: |
||
218 | print("Stop training at epoch {}. The lowest loss achieved is {}".format(epoch, te_loss))
|
||
219 | break
|
||
220 | |||
221 | def visdom_loss(visdom, loss_step, loss_dict): |
||
222 | loss_data['X'].append(loss_step)
|
||
223 | loss_data['Y'].append([loss_dict[k] for k in loss_data['legend_U']]) |
||
224 | visdom.line( |
||
225 | X=np.stack([np.array(loss_data['X'])] * len(loss_data['legend_U']), 1), |
||
226 | Y=np.array(loss_data['Y']),
|
||
227 | win=30,
|
||
228 | opts=dict(xlabel='Step', |
||
229 | ylabel='Loss',
|
||
230 | title='YOLO_V2',
|
||
231 | legend=loss_data['legend_U']),
|
||
232 | update='append'
|
||
233 | ) |
||
234 | |||
235 | def tensor2im(image_tensor, imtype=np.uint8): |
||
236 | |||
237 | image_numpy = image_tensor[0].detach().cpu().float().numpy()
|
||
238 | |||
239 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) |
||
240 | |||
241 | image_numpy = np.clip(image_numpy, 0, 255) |
||
242 | |||
243 | return image_numpy.astype(imtype)
|
||
244 | |||
245 | def denormalize(tensors): |
||
246 | """ Denormalizes image tensors using mean and std """
|
||
247 | mean = np.array([0.5, 0.5, 0.5]) |
||
248 | std = np.array([0.5, 0.5, 0.5]) |
||
249 | |||
250 | # mean = np.array([0.47571, 0.50874, 0.56821])
|
||
251 | # std = np.array([0.10341, 0.1062, 0.11548])
|
||
252 | |||
253 | denorm = tensors.clone() |
||
254 | |||
255 | for c in range(tensors.shape[1]): |
||
256 | denorm[:, c] = denorm[:, c].mul_(std[c]).add_(mean[c]) |
||
257 | |||
258 | denorm = torch.clamp(denorm, 0, 255) |
||
259 | |||
260 | return denorm
|
||
261 | |||
262 | if __name__ == "__main__": |
||
263 | b65d47a2 | esham21 | datas = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief', |
264 | '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure', |
||
265 | 'inst', 'func_valve', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs', |
||
266 | 'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close', |
||
267 | 'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot', |
||
268 | 'opc']
|
||
269 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\Data\\', 'VnV') |
||
270 | train(name='VnV', classes=datas, root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath(
|
||
271 | __file__)) + '\\pre_trained_model\\only_params_trained_yolo_voc') |