개정판 9be41199
issue #1366: cu merge test
Change-Id: Ic21cb5de9fa1cff6045d257eb4f746a0357c26ca
DTI_PID/WebServer/symbol_training/train.py | ||
---|---|---|
29 | 29 |
'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot' |
30 | 30 |
,'opc'] |
31 | 31 |
|
32 |
#print(len(DOFTECH_CLASSES)) |
|
32 |
use_voc_model = True |
|
33 |
use_visdom = False |
|
34 |
if use_visdom |
|
35 |
visdom = visdom.Visdom() |
|
33 | 36 |
|
34 | 37 |
def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None): |
35 | 38 |
global DOFTECH_CLASSES |
... | ... | |
81 | 84 |
torch.cuda.manual_seed(123) |
82 | 85 |
else: |
83 | 86 |
torch.manual_seed(123) |
87 |
|
|
84 | 88 |
learning_rate_schedule = {"0": 1e-5, "5": 1e-4, |
85 | 89 |
"80": 1e-5, "110": 1e-6} |
86 | 90 |
|
87 |
training_params = {"batch_size": opt.batch_size, |
|
91 |
training_params = {"batch_size": 1,#opt.batch_size,
|
|
88 | 92 |
"shuffle": True, |
89 | 93 |
"drop_last": True, |
90 | 94 |
"collate_fn": custom_collate_fn} |
... | ... | |
97 | 101 |
training_set = DoftechDataset(opt.data_path, opt.image_size, is_training=True, classes=DOFTECH_CLASSES) |
98 | 102 |
training_generator = DataLoader(training_set, **training_params) |
99 | 103 |
|
100 |
test_set = DoftechDatasetTest(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
|
|
104 |
test_set = DoftechDataset(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES) |
|
101 | 105 |
test_generator = DataLoader(test_set, **test_params) |
102 | 106 |
|
103 |
pre_model = Yolo(20).cuda() |
|
104 |
pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False) |
|
105 |
|
|
106 |
model = YoloD(pre_model, training_set.num_classes).cuda() |
|
107 |
# BUILDING MODEL ======================================================================= |
|
108 |
if use_voc_model : |
|
109 |
pre_model = Yolo(20).cuda() |
|
110 |
pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False) |
|
111 |
model = YoloD(pre_model, training_set.num_classes).cuda() |
|
112 |
else : |
|
113 |
model = Yolo(training_set.num_classes).cuda() |
|
107 | 114 |
|
108 | 115 |
nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01) |
109 | 116 |
|
... | ... | |
112 | 119 |
|
113 | 120 |
best_loss = 1e10 |
114 | 121 |
best_epoch = 0 |
115 |
model.train() |
|
116 | 122 |
num_iter_per_epoch = len(training_generator) |
117 |
|
|
118 | 123 |
loss_step = 0 |
124 |
# ====================================================================================== |
|
119 | 125 |
|
126 |
# TRAINING ============================================================================= |
|
120 | 127 |
save_count = 0 |
121 | 128 |
|
129 |
model.train() |
|
122 | 130 |
for epoch in range(opt.num_epoches): |
123 | 131 |
if str(epoch) in learning_rate_schedule.keys(): |
124 | 132 |
for param_group in optimizer.param_groups: |
125 | 133 |
param_group['lr'] = learning_rate_schedule[str(epoch)] |
126 | 134 |
|
127 | 135 |
for iter, batch in enumerate(training_generator): |
128 |
image, label = batch |
|
136 |
image, label, image2 = batch |
|
137 |
image = Variable(image.cuda(), requires_grad=False) |
|
129 | 138 |
if torch.cuda.is_available(): |
130 | 139 |
image = Variable(image.cuda(), requires_grad=True) |
140 |
origin = Variable(image2.cuda(), requires_grad=False) |
|
131 | 141 |
else: |
132 | 142 |
image = Variable(image, requires_grad=True) |
133 | 143 |
|
... | ... | |
135 | 145 |
logits = model(image) |
136 | 146 |
loss, loss_coord, loss_conf, loss_cls = criterion(logits, label) |
137 | 147 |
loss.backward() |
138 |
|
|
139 | 148 |
optimizer.step() |
140 | 149 |
|
141 | 150 |
if iter % opt.test_interval == 0: |
... | ... | |
143 | 152 |
(epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss, |
144 | 153 |
loss_coord,loss_conf,loss_cls)) |
145 | 154 |
|
146 |
predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold, |
|
147 |
opt.nms_threshold) |
|
155 |
if use_visdom: |
|
156 |
predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold, |
|
157 |
opt.nms_threshold) |
|
148 | 158 |
|
149 |
gt_image = at.tonumpy(image[0]) |
|
150 |
gt_image = visdom_bbox(gt_image, label[0]) |
|
151 |
#visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3) |
|
159 |
gt_image = at.tonumpy(image[0]) |
|
160 |
gt_image = visdom_bbox(gt_image, label[0]) |
|
161 |
visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3) |
|
162 |
# |
|
163 |
origin_image = at.tonumpy(origin[0]) |
|
164 |
origin_image = visdom_bbox(origin_image, []) |
|
165 |
visdom.image(origin_image, opts=dict(title='origin_box_image'), win=4) |
|
152 | 166 |
|
153 |
if len(predictions) != 0: |
|
154 | 167 |
image = at.tonumpy(image[0]) |
155 |
box_image = visdom_bbox(image, predictions[0]) |
|
156 |
#visdom.image(box_image, opts=dict(title='box_image'), win=2) |
|
157 | 168 |
|
158 |
elif len(predictions) == 0: |
|
159 |
box_image = tensor2im(image) |
|
160 |
#visdom.image(box_image.transpose([2, 0, 1]), opts=dict(title='box_image'), win=2) |
|
169 |
if len(predictions) != 0: |
|
170 |
box_image = visdom_bbox(image, predictions[0]) |
|
171 |
visdom.image(box_image, opts=dict(title='box_image'), win=2) |
|
172 |
|
|
173 |
elif len(predictions) == 0: |
|
174 |
box_image = visdom_bbox(image, []) |
|
175 |
visdom.image(box_image, opts=dict(title='box_image'), win=2) |
|
161 | 176 |
|
162 |
loss_dict = { |
|
163 |
'total' : loss.item(), |
|
164 |
'coord' : loss_coord.item(), |
|
165 |
'conf' : loss_conf.item(), |
|
166 |
'cls' : loss_cls.item() |
|
167 |
} |
|
177 |
loss_dict = {
|
|
178 |
'total' : loss.item(),
|
|
179 |
'coord' : loss_coord.item(),
|
|
180 |
'conf' : loss_conf.item(),
|
|
181 |
'cls' : loss_cls.item()
|
|
182 |
}
|
|
168 | 183 |
|
169 |
#visdom_loss(visdom, loss_step, loss_dict)
|
|
170 |
loss_step = loss_step + 1 |
|
184 |
visdom_loss(visdom, loss_step, loss_dict)
|
|
185 |
loss_step = loss_step + 1
|
|
171 | 186 |
|
172 | 187 |
if epoch % opt.test_interval == 0: |
173 | 188 |
model.eval() |
... | ... | |
232 | 247 |
update='append' |
233 | 248 |
) |
234 | 249 |
|
235 |
def tensor2im(image_tensor, imtype=np.uint8): |
|
236 |
|
|
237 |
image_numpy = image_tensor[0].detach().cpu().float().numpy() |
|
238 |
|
|
239 |
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) |
|
240 |
|
|
241 |
image_numpy = np.clip(image_numpy, 0, 255) |
|
242 |
|
|
243 |
return image_numpy.astype(imtype) |
|
244 |
|
|
245 |
def denormalize(tensors): |
|
246 |
""" Denormalizes image tensors using mean and std """ |
|
247 |
mean = np.array([0.5, 0.5, 0.5]) |
|
248 |
std = np.array([0.5, 0.5, 0.5]) |
|
249 |
|
|
250 |
# mean = np.array([0.47571, 0.50874, 0.56821]) |
|
251 |
# std = np.array([0.10341, 0.1062, 0.11548]) |
|
252 |
|
|
253 |
denorm = tensors.clone() |
|
254 |
|
|
255 |
for c in range(tensors.shape[1]): |
|
256 |
denorm[:, c] = denorm[:, c].mul_(std[c]).add_(mean[c]) |
|
257 |
|
|
258 |
denorm = torch.clamp(denorm, 0, 255) |
|
259 |
|
|
260 |
return denorm |
|
261 |
|
|
262 | 250 |
if __name__ == "__main__": |
263 | 251 |
datas = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief', |
264 | 252 |
'3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure', |
내보내기 Unified diff