
Revision 9be41199

ID: 9be411999eb9dcd875bf277f51364abc0c35f511
Parent: 0d19831e
Children: 9c91a597, 150e63fa

Added by 함의성 over 4 years ago

issue #1366: cu merge test

Change-Id: Ic21cb5de9fa1cff6045d257eb4f746a0357c26ca

Differences:

DTI_PID/WebServer/symbol_training/train.py
@@ -29,7 +29,10 @@
                   'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot'
                   ,'opc']
 
-#print(len(DOFTECH_CLASSES))
+use_voc_model = True
+use_visdom = False
+if use_visdom :
+    visdom = visdom.Visdom()
 
 def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None):
     global DOFTECH_CLASSES
@@ -81,10 +84,11 @@
         torch.cuda.manual_seed(123)
     else:
         torch.manual_seed(123)
+
     learning_rate_schedule = {"0": 1e-5, "5": 1e-4,
                               "80": 1e-5, "110": 1e-6}
 
-    training_params = {"batch_size": opt.batch_size,
+    training_params = {"batch_size": 1,#opt.batch_size,
                        "shuffle": True,
                        "drop_last": True,
                        "collate_fn": custom_collate_fn}
@@ -97,13 +101,16 @@
     training_set = DoftechDataset(opt.data_path, opt.image_size, is_training=True, classes=DOFTECH_CLASSES)
     training_generator = DataLoader(training_set, **training_params)
 
-    test_set = DoftechDatasetTest(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
+    test_set = DoftechDataset(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
     test_generator = DataLoader(test_set, **test_params)
 
-    pre_model = Yolo(20).cuda()
-    pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
-
-    model = YoloD(pre_model, training_set.num_classes).cuda()
+    # BUILDING MODEL =======================================================================
+    if use_voc_model :
+        pre_model = Yolo(20).cuda()
+        pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
+        model = YoloD(pre_model, training_set.num_classes).cuda()
+    else :
+        model = Yolo(training_set.num_classes).cuda()
 
     nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01)
 
@@ -112,22 +119,25 @@
 
     best_loss = 1e10
     best_epoch = 0
-    model.train()
     num_iter_per_epoch = len(training_generator)
-
     loss_step = 0
+    # ======================================================================================
 
+    # TRAINING =============================================================================
     save_count = 0
 
+    model.train()
     for epoch in range(opt.num_epoches):
         if str(epoch) in learning_rate_schedule.keys():
            for param_group in optimizer.param_groups:
                 param_group['lr'] = learning_rate_schedule[str(epoch)]
 
         for iter, batch in enumerate(training_generator):
-            image, label = batch
+            image, label, image2 = batch
+            image = Variable(image.cuda(), requires_grad=False)
             if torch.cuda.is_available():
                 image = Variable(image.cuda(), requires_grad=True)
+                origin = Variable(image2.cuda(), requires_grad=False)
             else:
                 image = Variable(image, requires_grad=True)
 
@@ -135,7 +145,6 @@
             logits = model(image)
             loss, loss_coord, loss_conf, loss_cls = criterion(logits, label)
             loss.backward()
-
             optimizer.step()
 
             if iter % opt.test_interval == 0:
@@ -143,31 +152,37 @@
                     (epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss,
                     loss_coord,loss_conf,loss_cls))
 
-                predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
-                                              opt.nms_threshold)
+                if use_visdom:
+                    predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
+                                                  opt.nms_threshold)
 
-                gt_image = at.tonumpy(image[0])
-                gt_image = visdom_bbox(gt_image, label[0])
-                #visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3)
+                    gt_image = at.tonumpy(image[0])
+                    gt_image = visdom_bbox(gt_image, label[0])
+                    visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3)
+                    #
+                    origin_image = at.tonumpy(origin[0])
+                    origin_image = visdom_bbox(origin_image, [])
+                    visdom.image(origin_image, opts=dict(title='origin_box_image'), win=4)
 
-                if len(predictions) != 0:
                     image = at.tonumpy(image[0])
-                    box_image = visdom_bbox(image, predictions[0])
-                    #visdom.image(box_image, opts=dict(title='box_image'), win=2)
 
-                elif len(predictions) == 0:
-                    box_image = tensor2im(image)
-                    #visdom.image(box_image.transpose([2, 0, 1]), opts=dict(title='box_image'), win=2)
+                    if len(predictions) != 0:
+                        box_image = visdom_bbox(image, predictions[0])
+                        visdom.image(box_image, opts=dict(title='box_image'), win=2)
+
+                    elif len(predictions) == 0:
+                        box_image = visdom_bbox(image, [])
+                        visdom.image(box_image, opts=dict(title='box_image'), win=2)
 
-                loss_dict = {
-                    'total' : loss.item(),
-                    'coord' : loss_coord.item(),
-                    'conf' : loss_conf.item(),
-                    'cls' : loss_cls.item()
-                }
+                    loss_dict = {
+                        'total' : loss.item(),
+                        'coord' : loss_coord.item(),
+                        'conf' : loss_conf.item(),
+                        'cls' : loss_cls.item()
+                    }
 
-                #visdom_loss(visdom, loss_step, loss_dict)
-                loss_step = loss_step + 1
+                    visdom_loss(visdom, loss_step, loss_dict)
+                    loss_step = loss_step + 1
 
         if epoch % opt.test_interval == 0:
             model.eval()
@@ -232,33 +247,6 @@
         update='append'
     )
 
-def tensor2im(image_tensor, imtype=np.uint8):
-
-    image_numpy = image_tensor[0].detach().cpu().float().numpy()
-
-    image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
-
-    image_numpy = np.clip(image_numpy, 0, 255)
-
-    return image_numpy.astype(imtype)
-
-def denormalize(tensors):
-    """ Denormalizes image tensors using mean and std """
-    mean = np.array([0.5, 0.5, 0.5])
-    std = np.array([0.5, 0.5, 0.5])
-
-    # mean = np.array([0.47571, 0.50874, 0.56821])
-    # std = np.array([0.10341, 0.1062, 0.11548])
-
-    denorm = tensors.clone()
-
-    for c in range(tensors.shape[1]):
-        denorm[:, c] = denorm[:, c].mul_(std[c]).add_(mean[c])
-
-    denorm = torch.clamp(denorm, 0, 255)
-
-    return denorm
-
 if __name__ == "__main__":
     datas = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
                   '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure',
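The new use_voc_model branch follows the usual PyTorch transfer-learning pattern: restore a checkpoint trained on the 20 VOC classes with strict=False so that only the weights that can be matched are loaded, then attach a head sized for the symbol classes. The snippet below is a minimal, self-contained sketch of that pattern, not code from this repository; VOCNet, SymbolNet, num_symbol_classes, and the checkpoint path are hypothetical stand-ins for Yolo, YoloD, training_set.num_classes, and opt.pre_trained_model_path, and the .cuda() calls are omitted so it runs on CPU.

import torch
import torch.nn as nn

class VOCNet(nn.Module):
    """Stand-in for the VOC-pretrained Yolo(20): shared backbone + 20-class head."""
    def __init__(self):
        super().__init__()
        self.features = nn.Conv2d(3, 16, 3, padding=1)
        self.voc_head = nn.Conv2d(16, 20, 1)

    def forward(self, x):
        return self.voc_head(self.features(x))

class SymbolNet(nn.Module):
    """Stand-in for YoloD: reuses the pretrained backbone, adds a new class head."""
    def __init__(self, backbone, num_classes):
        super().__init__()
        self.features = backbone.features
        self.symbol_head = nn.Conv2d(16, num_classes, 1)

    def forward(self, x):
        return self.symbol_head(self.features(x))

# Pretend checkpoint standing in for opt.pre_trained_model_path.
torch.save(VOCNet().state_dict(), "voc_pretrained.pth")

use_voc_model = True
num_symbol_classes = 35  # illustrative count, not the real DOFTECH_CLASSES length

if use_voc_model:
    pre_model = VOCNet()
    # strict=False tolerates missing/unexpected keys, so a checkpoint whose layout
    # only partially matches the model still loads whatever weights line up.
    pre_model.load_state_dict(torch.load("voc_pretrained.pth"), strict=False)
    model = SymbolNet(pre_model, num_symbol_classes)
else:
    model = SymbolNet(VOCNet(), num_symbol_classes)  # untrained backbone, analogous to building the network from scratch

# Re-initialize the last layer, as train.py does for the new class head.
nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01)
print(model(torch.zeros(1, 3, 64, 64)).shape)  # torch.Size([1, 35, 64, 64])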

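The use_visdom flag makes the visdom visualization opt-in, since it needs a running visdom server and the logging calls are otherwise skipped entirely. Below is a minimal sketch of the same gating idea, assuming a server started with `python -m visdom.server` and using a hypothetical log_image helper in place of the script's visdom_bbox / visdom_loss utilities.

import numpy as np

use_visdom = False  # leave False unless a visdom server is reachable

if use_visdom:
    import visdom
    vis = visdom.Visdom()  # connects to http://localhost:8097 by default

def log_image(tag, image_chw, win):
    """Send a CHW image to a fixed visdom window; silently skip when disabled."""
    if not use_visdom:
        return
    vis.image(image_chw, opts=dict(title=tag), win=win)

# Always safe to call; it only talks to visdom when the flag is on.
log_image('gt_box_image', np.zeros((3, 416, 416), dtype=np.uint8), win=3)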
