프로젝트

일반

사용자정보

통계
| 브랜치(Branch): | 개정판:

hytos / DTI_PID / WebServer / symbol_training / train.py @ 9be41199

이력 | 보기 | 이력해설 | 다운로드 (12.9 KB)

1 6374c2c6 esham21
"""
2
Training For Small Object
3
"""
4
import os
5
import argparse
6
import torch.nn as nn
7
from torch.utils.data import DataLoader
8
from src.doftech_dataset import DoftechDataset, DoftechDatasetTest
9
from src.utils import *
10 76986eb0 esham21
from src.loss import YoloLoss
11 6374c2c6 esham21
from src.yolo_net import Yolo
12
from src.yolo_doftech import YoloD
13
import shutil
14
import visdom
15
import cv2
16
import pickle
17
import numpy as np
18
from src.vis_utils import array_tool as at
19
from src.vis_utils.vis_tool import visdom_bbox
20
21
loss_data = {'X': [], 'Y': [], 'legend_U':['total', 'coord', 'conf', 'cls']}
22 b64cc3b5 esham21
#visdom = visdom.Visdom(port='8088')
23 6374c2c6 esham21
24
# 형상 CLASS
25
DOFTECH_CLASSES= ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
26
                  '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure',
27
                  'circle', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs',
28
                  'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close',
29
                  'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot'
30
                  ,'opc']
31
32 9be41199 esham21
use_voc_model = True
33
use_visdom = False
34
if use_visdom
35
    visdom = visdom.Visdom()
36 6374c2c6 esham21
37 1e8ea226 esham21
def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None):
38 e473b8aa esham21
    global DOFTECH_CLASSES
39 6374c2c6 esham21
    DOFTECH_CLASSES = classes
40
41
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
42 e473b8aa esham21
    parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images")
43 6374c2c6 esham21
    parser.add_argument("--batch_size", type=int, default=10, help="The number of images per batch")
44
45
    # Training 기본 Setting
46
    parser.add_argument("--momentum", type=float, default=0.9)
47
    parser.add_argument("--decay", type=float, default=0.0005)
48
    parser.add_argument("--dropout", type=float, default=0.5)
49 aba30784 esham21
    parser.add_argument("--num_epoches", type=int, default=205)
50 b64cc3b5 esham21
    parser.add_argument("--test_interval", type=int, default=20, help="Number of epoches between testing phases")
51 6374c2c6 esham21
    parser.add_argument("--object_scale", type=float, default=1.0)
52
    parser.add_argument("--noobject_scale", type=float, default=0.5)
53
    parser.add_argument("--class_scale", type=float, default=1.0)
54
    parser.add_argument("--coord_scale", type=float, default=5.0)
55
    parser.add_argument("--reduction", type=int, default=32)
56
    parser.add_argument("--es_min_delta", type=float, default=0.0,
57
                        help="Early stopping's parameter: minimum change loss to qualify as an improvement")
58
    parser.add_argument("--es_patience", type=int, default=0,
59
                        help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.")
60
61
    # 확인해야 하는 PATH
62
    parser.add_argument("--data_path", type=str, default=os.path.join(root_path, 'training'), help="the root folder of dataset") # 학습 데이터 경로 -> image와 xml의 상위 경로 입력
63
    parser.add_argument("--data_path_test", type=str, default=os.path.join(root_path, 'test'), help="the root folder of dataset") # 테스트 데이터 경로 -> test할 이미지만 넣으면 됨
64
    #parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
65
    parser.add_argument("--pre_trained_model_path", type=str, default=pre_trained_model_path) # Pre-training 된 모델 경로
66
67
    parser.add_argument("--saved_path", type=str, default=os.path.join(root_path, 'checkpoint')) # training 된 모델 저장 경로
68
    parser.add_argument("--conf_threshold", type=float, default=0.35)
69
    parser.add_argument("--nms_threshold", type=float, default=0.5)
70
    opt = parser.parse_args()
71
72 a6b28afb esham21
    if not os.path.isdir(opt.saved_path):
73 7f74c5bd esham21
        os.mkdir(opt.saved_path)
74 a6b28afb esham21
75 28822594 esham21
    # 학습할 클래스들을 저장하고 인식 시 불러와 사용합니다.
76 23662e23 esham21
    with open(os.path.join(opt.saved_path, name + "_info.info"), 'w') as stream:
77
        con = str(len(DOFTECH_CLASSES))
78
        names = '\n'.join(DOFTECH_CLASSES)
79 1e8ea226 esham21
        bigs = '\n'.join(bigs)
80
        con = con + '\n' + names + '\n' + '***bigs***' + '\n' + bigs
81 23662e23 esham21
        stream.write(con)
82 6374c2c6 esham21
83
    if torch.cuda.is_available():
84
        torch.cuda.manual_seed(123)
85
    else:
86
        torch.manual_seed(123)
87 9be41199 esham21
88 6374c2c6 esham21
    learning_rate_schedule = {"0": 1e-5, "5": 1e-4,
89
                              "80": 1e-5, "110": 1e-6}
90
91 9be41199 esham21
    training_params = {"batch_size": 1,#opt.batch_size,
92 6374c2c6 esham21
                       "shuffle": True,
93
                       "drop_last": True,
94
                       "collate_fn": custom_collate_fn}
95
96
    test_params = {"batch_size": opt.batch_size,
97
                   "shuffle": False,
98
                   "drop_last": False,
99
                   "collate_fn": custom_collate_fn}
100
101
    training_set = DoftechDataset(opt.data_path, opt.image_size, is_training=True, classes=DOFTECH_CLASSES)
102
    training_generator = DataLoader(training_set, **training_params)
103
104 9be41199 esham21
    test_set = DoftechDataset(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
105 6374c2c6 esham21
    test_generator = DataLoader(test_set, **test_params)
106
107 9be41199 esham21
    # BUILDING MODEL =======================================================================
108
    if use_voc_model :
109
        pre_model = Yolo(20).cuda()
110
        pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
111
        model = YoloD(pre_model, training_set.num_classes).cuda()
112
    else :
113
        model = Yolo(training_set.num_classes).cuda()
114 6374c2c6 esham21
115
    nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01)
116
117
    criterion = YoloLoss(training_set.num_classes, model.anchors, opt.reduction)
118
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=opt.momentum, weight_decay=opt.decay)
119
120
    best_loss = 1e10
121
    best_epoch = 0
122
    num_iter_per_epoch = len(training_generator)
123
    loss_step = 0
124 9be41199 esham21
    # ======================================================================================
125 6374c2c6 esham21
126 9be41199 esham21
    # TRAINING =============================================================================
127 87c89301 esham21
    save_count = 0
128
129 9be41199 esham21
    model.train()
130 6374c2c6 esham21
    for epoch in range(opt.num_epoches):
131
        if str(epoch) in learning_rate_schedule.keys():
132
            for param_group in optimizer.param_groups:
133
                param_group['lr'] = learning_rate_schedule[str(epoch)]
134
135
        for iter, batch in enumerate(training_generator):
136 9be41199 esham21
            image, label, image2 = batch
137
            image = Variable(image.cuda(), requires_grad=False)
138 6374c2c6 esham21
            if torch.cuda.is_available():
139
                image = Variable(image.cuda(), requires_grad=True)
140 9be41199 esham21
                origin = Variable(image2.cuda(), requires_grad=False)
141 6374c2c6 esham21
            else:
142
                image = Variable(image, requires_grad=True)
143
144
            optimizer.zero_grad()
145
            logits = model(image)
146
            loss, loss_coord, loss_conf, loss_cls = criterion(logits, label)
147
            loss.backward()
148
            optimizer.step()
149
150
            if iter % opt.test_interval == 0:
151 87c89301 esham21
                print("Epoch: {}/{}, Iteration: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format
152 6374c2c6 esham21
                    (epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss,
153
                    loss_coord,loss_conf,loss_cls))
154
155 9be41199 esham21
                if use_visdom:
156
                    predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
157
                                                  opt.nms_threshold)
158 6374c2c6 esham21
159 9be41199 esham21
                    gt_image = at.tonumpy(image[0])
160
                    gt_image = visdom_bbox(gt_image, label[0])
161
                    visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3)
162
                    #
163
                    origin_image = at.tonumpy(origin[0])
164
                    origin_image = visdom_bbox(origin_image, [])
165
                    visdom.image(origin_image, opts=dict(title='origin_box_image'), win=4)
166 6374c2c6 esham21
167
                    image = at.tonumpy(image[0])
168
169 9be41199 esham21
                    if len(predictions) != 0:
170
                        box_image = visdom_bbox(image, predictions[0])
171
                        visdom.image(box_image, opts=dict(title='box_image'), win=2)
172
173
                    elif len(predictions) == 0:
174
                        box_image = visdom_bbox(image, [])
175
                        visdom.image(box_image, opts=dict(title='box_image'), win=2)
176 6374c2c6 esham21
177 9be41199 esham21
                    loss_dict = {
178
                        'total' : loss.item(),
179
                        'coord' : loss_coord.item(),
180
                        'conf' : loss_conf.item(),
181
                        'cls' : loss_cls.item()
182
                    }
183 6374c2c6 esham21
184 9be41199 esham21
                    visdom_loss(visdom, loss_step, loss_dict)
185
                    loss_step = loss_step + 1
186 6374c2c6 esham21
187
        if epoch % opt.test_interval == 0:
188
            model.eval()
189
            loss_ls = []
190
            loss_coord_ls = []
191
            loss_conf_ls = []
192
            loss_cls_ls = []
193
            for te_iter, te_batch in enumerate(test_generator):
194
                te_image, te_label = te_batch
195
                num_sample = len(te_label)
196
                if torch.cuda.is_available():
197
                    te_image = te_image.cuda()
198
                with torch.no_grad():
199
                    te_logits = model(te_image)
200
                    batch_loss, batch_loss_coord, batch_loss_conf, batch_loss_cls = criterion(te_logits, te_label)
201
                loss_ls.append(batch_loss * num_sample)
202
                loss_coord_ls.append(batch_loss_coord * num_sample)
203
                loss_conf_ls.append(batch_loss_conf * num_sample)
204
                loss_cls_ls.append(batch_loss_cls * num_sample)
205
206
            te_loss = sum(loss_ls) / test_set.__len__()
207
            te_coord_loss = sum(loss_coord_ls) / test_set.__len__()
208
            te_conf_loss = sum(loss_conf_ls) / test_set.__len__()
209
            te_cls_loss = sum(loss_cls_ls) / test_set.__len__()
210 87c89301 esham21
            print("Test>> Epoch: {}/{}, Lr: {}, Loss:{:.5f} (Coord:{:.5f} Conf:{:.5f} Cls:{:.5f})".format(
211 6374c2c6 esham21
                epoch + 1, opt.num_epoches, optimizer.param_groups[0]['lr'], te_loss, te_coord_loss, te_conf_loss, te_cls_loss))
212
213
            model.train()
214
            if te_loss + opt.es_min_delta < best_loss:
215 87c89301 esham21
                save_count += 1
216 6374c2c6 esham21
                best_loss = te_loss
217
                best_epoch = epoch
218
                print("SAVE MODEL")
219 61709b69 esham21
                # for debug for each loss
220 755b0a9d esham21
                torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
221 61709b69 esham21
                torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
222
                # save
223
                torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth"))
224
                torch.save(model, os.path.join(opt.saved_path, name + "_whole_model.pth"))
225 77f7b6b3 esham21
            else:
226
                save_count += 1
227
                # for debug for each loss
228
                torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
229
                torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
230 6374c2c6 esham21
231
            # Early stopping
232
            if epoch - best_epoch > opt.es_patience > 0:
233
                print("Stop training at epoch {}. The lowest loss achieved is {}".format(epoch, te_loss))
234
                break
235
236
def visdom_loss(visdom, loss_step, loss_dict):
237
    loss_data['X'].append(loss_step)
238
    loss_data['Y'].append([loss_dict[k] for k in loss_data['legend_U']])
239
    visdom.line(
240
        X=np.stack([np.array(loss_data['X'])] * len(loss_data['legend_U']), 1),
241
        Y=np.array(loss_data['Y']),
242
        win=30,
243
        opts=dict(xlabel='Step',
244
                  ylabel='Loss',
245
                  title='YOLO_V2',
246
                  legend=loss_data['legend_U']),
247
        update='append'
248
    )
249
250
if __name__ == "__main__":
251 b65d47a2 esham21
    datas = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
252
                  '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure',
253
                  'inst', 'func_valve', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs',
254
                  'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close',
255
                  'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot',
256
                  'opc']
257
    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\Data\\', 'VnV')
258
    train(name='VnV', classes=datas, root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath(
259
                                                       __file__)) + '\\pre_trained_model\\only_params_trained_yolo_voc')
클립보드 이미지 추가 (최대 크기: 500 MB)