프로젝트

일반

사용자정보

통계
| 브랜치(Branch): | 개정판:

hytos / DTI_PID / WebServer / symbol_training / train.py @ af6e3ebf

이력 | 보기 | 이력해설 | 다운로드 (13.9 KB)

1
"""
2
Training For Small Object
3
-> parser에서 data_path와 pre_train_model 경로만 지정해준 후 돌리면 됩니다.
4
"""
5
import os
6
import argparse
7
import torch.nn as nn
8
from torch.utils.data import DataLoader
9
from src.doftech_dataset import DoftechDataset, DoftechDatasetTest
10
from src.utils import *
11
from src.loss import YoloLoss
12
from src.yolo_net import Yolo
13
from src.yolo_doftech import YoloD
14
import shutil
15
import visdom
16
import cv2
17
import pickle
18
import numpy as np
19
from src.vis_utils import array_tool as at
20
from src.vis_utils.vis_tool import visdom_bbox
21

    
22
loss_data = {'X': [], 'Y': [], 'legend_U':['total', 'coord', 'conf', 'cls']}
23
#visdom = visdom.Visdom(port='8080')
24

    
25
def get_args():
26
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
27
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
28
    parser.add_argument("--batch_size", type=int, default=10, help="The number of images per batch")
29

    
30
    # Training 기본 Setting
31
    parser.add_argument("--momentum", type=float, default=0.9)
32
    parser.add_argument("--decay", type=float, default=0.0005)
33
    parser.add_argument("--dropout", type=float, default=0.5)
34
    parser.add_argument("--num_epoches", type=int, default=1000)
35
    parser.add_argument("--test_interval", type=int, default=20, help="Number of epoches between testing phases")
36
    parser.add_argument("--object_scale", type=float, default=1.0)
37
    parser.add_argument("--noobject_scale", type=float, default=0.5)
38
    parser.add_argument("--class_scale", type=float, default=1.0)
39
    parser.add_argument("--coord_scale", type=float, default=5.0)
40
    parser.add_argument("--reduction", type=int, default=32)
41
    parser.add_argument("--es_min_delta", type=float, default=0.0,
42
                        help="Early stopping's parameter: minimum change loss to qualify as an improvement")
43
    parser.add_argument("--es_patience", type=int, default=0,
44
                        help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.")
45

    
46
    # 확인해야 하는 PATH
47
    parser.add_argument("--data_path", type=str, default="D:/data/DATA_Doftech/small_symbol/training", help="the root folder of dataset") # 학습 데이터 경로 -> image와 xml의 상위 경로 입력
48
    parser.add_argument("--data_path_test", type=str, default="./data/Test", help="the root folder of dataset") # 테스트 데이터 경로 -> test할 이미지만 넣으면 됨
49
    #parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
50
    parser.add_argument("--pre_trained_model_path", type=str, default="trained_models/only_params_trained_yolo_voc") # Pre-training 된 모델 경로
51

    
52
    parser.add_argument("--saved_path", type=str, default="./checkpoint") # training 된 모델 저장 경로
53
    parser.add_argument("--conf_threshold", type=float, default=0.35)
54
    parser.add_argument("--nms_threshold", type=float, default=0.5)
55
    args = parser.parse_args()
56
    return args
57

    
58
# 형상 CLASS
59
DOFTECH_CLASSES= ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
60
                  '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure',
61
                  'circle', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs',
62
                  'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close',
63
                  'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot'
64
                  ,'opc']
65

    
66
print(len(DOFTECH_CLASSES))
67

    
68
def train(name=None, classes=None, root_path=None, pre_trained_model_path=None):
69
    DOFTECH_CLASSES = classes
70

    
71
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
72
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
73
    parser.add_argument("--batch_size", type=int, default=10, help="The number of images per batch")
74

    
75
    # Training 기본 Setting
76
    parser.add_argument("--momentum", type=float, default=0.9)
77
    parser.add_argument("--decay", type=float, default=0.0005)
78
    parser.add_argument("--dropout", type=float, default=0.5)
79
    parser.add_argument("--num_epoches", type=int, default=1000)
80
    parser.add_argument("--test_interval", type=int, default=20, help="Number of epoches between testing phases")
81
    parser.add_argument("--object_scale", type=float, default=1.0)
82
    parser.add_argument("--noobject_scale", type=float, default=0.5)
83
    parser.add_argument("--class_scale", type=float, default=1.0)
84
    parser.add_argument("--coord_scale", type=float, default=5.0)
85
    parser.add_argument("--reduction", type=int, default=32)
86
    parser.add_argument("--es_min_delta", type=float, default=0.0,
87
                        help="Early stopping's parameter: minimum change loss to qualify as an improvement")
88
    parser.add_argument("--es_patience", type=int, default=0,
89
                        help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.")
90

    
91
    # 확인해야 하는 PATH
92
    parser.add_argument("--data_path", type=str, default=os.path.join(root_path, 'training'), help="the root folder of dataset") # 학습 데이터 경로 -> image와 xml의 상위 경로 입력
93
    parser.add_argument("--data_path_test", type=str, default=os.path.join(root_path, 'test'), help="the root folder of dataset") # 테스트 데이터 경로 -> test할 이미지만 넣으면 됨
94
    #parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
95
    parser.add_argument("--pre_trained_model_path", type=str, default=pre_trained_model_path) # Pre-training 된 모델 경로
96

    
97
    parser.add_argument("--saved_path", type=str, default=os.path.join(root_path, 'checkpoint')) # training 된 모델 저장 경로
98
    parser.add_argument("--conf_threshold", type=float, default=0.35)
99
    parser.add_argument("--nms_threshold", type=float, default=0.5)
100
    opt = parser.parse_args()
101

    
102

    
103
    if torch.cuda.is_available():
104
        torch.cuda.manual_seed(123)
105
    else:
106
        torch.manual_seed(123)
107
    learning_rate_schedule = {"0": 1e-5, "5": 1e-4,
108
                              "80": 1e-5, "110": 1e-6}
109

    
110
    training_params = {"batch_size": opt.batch_size,
111
                       "shuffle": True,
112
                       "drop_last": True,
113
                       "collate_fn": custom_collate_fn}
114

    
115
    test_params = {"batch_size": opt.batch_size,
116
                   "shuffle": False,
117
                   "drop_last": False,
118
                   "collate_fn": custom_collate_fn}
119

    
120
    training_set = DoftechDataset(opt.data_path, opt.image_size, is_training=True, classes=DOFTECH_CLASSES)
121
    training_generator = DataLoader(training_set, **training_params)
122

    
123
    test_set = DoftechDatasetTest(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
124
    test_generator = DataLoader(test_set, **test_params)
125

    
126
    pre_model = Yolo(20).cuda()
127
    pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
128

    
129
    model = YoloD(pre_model, training_set.num_classes).cuda()
130

    
131
    nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01)
132

    
133
    criterion = YoloLoss(training_set.num_classes, model.anchors, opt.reduction)
134
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=opt.momentum, weight_decay=opt.decay)
135

    
136
    best_loss = 1e10
137
    best_epoch = 0
138
    model.train()
139
    num_iter_per_epoch = len(training_generator)
140

    
141
    loss_step = 0
142

    
143
    for epoch in range(opt.num_epoches):
144
        if str(epoch) in learning_rate_schedule.keys():
145
            for param_group in optimizer.param_groups:
146
                param_group['lr'] = learning_rate_schedule[str(epoch)]
147

    
148
        for iter, batch in enumerate(training_generator):
149
            image, label = batch
150
            if torch.cuda.is_available():
151
                image = Variable(image.cuda(), requires_grad=True)
152
            else:
153
                image = Variable(image, requires_grad=True)
154

    
155
            optimizer.zero_grad()
156
            logits = model(image)
157
            loss, loss_coord, loss_conf, loss_cls = criterion(logits, label)
158
            loss.backward()
159

    
160
            optimizer.step()
161

    
162
            if iter % opt.test_interval == 0:
163
                print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.2f} (Coord:{:.2f} Conf:{:.2f} Cls:{:.2f})".format
164
                    (epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss,
165
                    loss_coord,loss_conf,loss_cls))
166

    
167
                predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
168
                                              opt.nms_threshold)
169

    
170
                gt_image = at.tonumpy(image[0])
171
                gt_image = visdom_bbox(gt_image, label[0])
172
                #visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3)
173

    
174
                if len(predictions) != 0:
175
                    image = at.tonumpy(image[0])
176
                    box_image = visdom_bbox(image, predictions[0])
177
                    #visdom.image(box_image, opts=dict(title='box_image'), win=2)
178

    
179
                elif len(predictions) == 0:
180
                    box_image = tensor2im(image)
181
                    #visdom.image(box_image.transpose([2, 0, 1]), opts=dict(title='box_image'), win=2)
182

    
183
                loss_dict = {
184
                    'total' : loss.item(),
185
                    'coord' : loss_coord.item(),
186
                    'conf' : loss_conf.item(),
187
                    'cls' : loss_cls.item()
188
                }
189

    
190
                #visdom_loss(visdom, loss_step, loss_dict)
191
                loss_step = loss_step + 1
192

    
193
        if epoch % opt.test_interval == 0:
194
            model.eval()
195
            loss_ls = []
196
            loss_coord_ls = []
197
            loss_conf_ls = []
198
            loss_cls_ls = []
199
            for te_iter, te_batch in enumerate(test_generator):
200
                te_image, te_label = te_batch
201
                num_sample = len(te_label)
202
                if torch.cuda.is_available():
203
                    te_image = te_image.cuda()
204
                with torch.no_grad():
205
                    te_logits = model(te_image)
206
                    batch_loss, batch_loss_coord, batch_loss_conf, batch_loss_cls = criterion(te_logits, te_label)
207
                loss_ls.append(batch_loss * num_sample)
208
                loss_coord_ls.append(batch_loss_coord * num_sample)
209
                loss_conf_ls.append(batch_loss_conf * num_sample)
210
                loss_cls_ls.append(batch_loss_cls * num_sample)
211

    
212
            te_loss = sum(loss_ls) / test_set.__len__()
213
            te_coord_loss = sum(loss_coord_ls) / test_set.__len__()
214
            te_conf_loss = sum(loss_conf_ls) / test_set.__len__()
215
            te_cls_loss = sum(loss_cls_ls) / test_set.__len__()
216
            print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.2f} (Coord:{:.2f} Conf:{:.2f} Cls:{:.2f})".format(
217
                epoch + 1, opt.num_epoches, optimizer.param_groups[0]['lr'], te_loss, te_coord_loss, te_conf_loss, te_cls_loss))
218

    
219
            model.train()
220
            if te_loss + opt.es_min_delta < best_loss:
221
                best_loss = te_loss
222
                best_epoch = epoch
223
                print("SAVE MODEL")
224
                torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth"))
225
                torch.save(model, os.path.join(opt.saved_path, name + "_whole_model.pth"))
226

    
227
            # Early stopping
228
            if epoch - best_epoch > opt.es_patience > 0:
229
                print("Stop training at epoch {}. The lowest loss achieved is {}".format(epoch, te_loss))
230
                break
231

    
232
def visdom_loss(visdom, loss_step, loss_dict):
233
    loss_data['X'].append(loss_step)
234
    loss_data['Y'].append([loss_dict[k] for k in loss_data['legend_U']])
235
    visdom.line(
236
        X=np.stack([np.array(loss_data['X'])] * len(loss_data['legend_U']), 1),
237
        Y=np.array(loss_data['Y']),
238
        win=30,
239
        opts=dict(xlabel='Step',
240
                  ylabel='Loss',
241
                  title='YOLO_V2',
242
                  legend=loss_data['legend_U']),
243
        update='append'
244
    )
245

    
246
def tensor2im(image_tensor, imtype=np.uint8):
247

    
248
    image_numpy = image_tensor[0].detach().cpu().float().numpy()
249

    
250
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
251

    
252
    image_numpy = np.clip(image_numpy, 0, 255)
253

    
254
    return image_numpy.astype(imtype)
255

    
256
def denormalize(tensors):
257
    """ Denormalizes image tensors using mean and std """
258
    mean = np.array([0.5, 0.5, 0.5])
259
    std = np.array([0.5, 0.5, 0.5])
260

    
261
    # mean = np.array([0.47571, 0.50874, 0.56821])
262
    # std = np.array([0.10341, 0.1062, 0.11548])
263

    
264
    denorm = tensors.clone()
265

    
266
    for c in range(tensors.shape[1]):
267
        denorm[:, c] = denorm[:, c].mul_(std[c]).add_(mean[c])
268

    
269
    denorm = torch.clamp(denorm, 0, 255)
270

    
271
    return denorm
272

    
273
if __name__ == "__main__":
274
    datas = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
275
                  '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure',
276
                  'inst', 'func_valve', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs',
277
                  'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close',
278
                  'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot',
279
                  'opc']
280
    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\Data\\', 'VnV')
281
    train(name='VnV', classes=datas, root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath(
282
                                                       __file__)) + '\\pre_trained_model\\only_params_trained_yolo_voc')
클립보드 이미지 추가 (최대 크기: 500 MB)