
Revision acb80620

ID: acb80620fe3d72473eaf5437eb1dab73094e5bd1
Parent: 2969340e
Child: 13759969

Added by 함의성 about 5 years ago

issue #1366: add file

Change-Id: I54037a35e0f9e4de9485da03b75ec79a9a2dca14

View differences:

DTI_PID/WebServer/symbol_recognition/test_doftech_all_images.py
"""
Testing

Put the drawings to test into pid_images and run; the results can be checked
in pid_results. Set the pre_trained_model path before running.
"""
import os
import glob
import argparse
import pickle
import cv2
import numpy as np
import torch  # used below (torch.cuda, torch.load, torch.no_grad) but was missing from the imports
from torch.autograd import Variable  # used in detection_object; assumed not re-exported by src.utils
from src.utils import *
from src.yolo_net import Yolo
from src.yolo_doftech import YoloD
from PIL import Image
import time

DOFTECH_CLASSES = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
                   '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure', 'ball_motor', 'plug_pressure',
                   'inst', 'func_valve', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs',
                   'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close',
                   'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot',
                   'opc']

class Patch:
    def __init__(self, start_h, start_w, img, check_line):
        self.start_h = start_h
        self.start_w = start_w
        self.img = img
        self.check_line = check_line
        self.obj_num = 0

    def setObjectNum(self, obj_num):
        self.obj_num = obj_num

class Symbol:
    def __init__(self, start_h, start_w, end_h, end_w, iou, class_info):
        self.start_h = start_h
        self.start_w = start_w
        self.end_h = end_h
        self.end_w = end_w
        self.iou = iou
        self.class_info = class_info

    def to_stream(self):
        # NOTE: the member names are misleading: start_h actually holds x and
        # start_w holds y; this swapped convention is used consistently here.
        # y, x, height, width:
        # return [self.class_info, self.start_w, self.start_h, self.end_w - self.start_w, self.end_h - self.start_h, self.iou]
        # x, y, width, height:
        return [self.class_info, self.start_h, self.start_w, self.end_h - self.start_h, self.end_w - self.start_w, self.iou]
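
# A minimal illustration (hypothetical values) of the stream layout emitted by
# Symbol.to_stream() under the swapped naming noted above:
#   Symbol(10, 20, 110, 140, 0.9, 'gate').to_stream()
#   -> ['gate', 10, 20, 100, 120, 0.9]  # [class, x, y, width, height, confidence]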

def get_args():
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
    parser.add_argument("--conf_threshold", type=float, default=0.5)
    parser.add_argument("--nms_threshold", type=float, default=0.6)
    parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="params")
    # raw strings so the Windows backslashes are not read as escape sequences
    parser.add_argument("--pre_trained_model_path", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\MODEL/f_doftech_all_class_only_params.pth")
    parser.add_argument("--pre_trained_model_path2", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\MODEL/doftech_all_class_only_params_opc.pth")
    # parser.add_argument("--pre_trained_model_path3", type=str, default="trained_models/only_params_trained_yolo_voc")
    parser.add_argument("--input", type=str, default="test_images")
    parser.add_argument("--output", type=str, default="result_images")
    parser.add_argument("--data_path_test", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\pid_images/", help="the root folder of dataset")  # Doftech VOCdevkit
    args = parser.parse_args()
    return args
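
# Example invocation (hypothetical relative paths; the defaults above point at
# absolute E:\... locations):
#   python test_doftech_all_images.py ^
#       --pre_trained_model_path MODEL/f_doftech_all_class_only_params.pth ^
#       --pre_trained_model_path2 MODEL/doftech_all_class_only_params_opc.pth ^
#       --data_path_test pid_images/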

def get_symbol(imgs, trained_model1=None, trained_model2=None):
    global colors

    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
    parser.add_argument("--conf_threshold", type=float, default=0.5)
    parser.add_argument("--nms_threshold", type=float, default=0.6)
    parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="params")
    parser.add_argument("--pre_trained_model_path", type=str, default=trained_model1)
    parser.add_argument("--pre_trained_model_path2", type=str, default=trained_model2)
    parser.add_argument("--data_path_test", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\pid_images/", help="the root folder of dataset")  # Doftech VOCdevkit
    opt = parser.parse_args()

    # NOTE: a CUDA device is required; model1/model2 stay undefined otherwise
    if torch.cuda.is_available():
        model1 = Yolo(35).cuda()
        model1.load_state_dict(torch.load(opt.pre_trained_model_path))

        model2 = Yolo(35).cuda()
        model2.load_state_dict(torch.load(opt.pre_trained_model_path2))

    colors = pickle.load(open(os.path.dirname(os.path.realpath(__file__)) + "/src/pallete", "rb"))

    img_list = imgs

    start = time.time()

    total_symbole_lists = []

    # ----------------------------------------------
    # Get patch arguments: (img, patch_size, overlap_size)
    # ----------------------------------------------
    for idx_img in range(len(img_list)):
        print("=========> CROP PATCH")
        small_object_patch_list = []
        large_object_patch_list = []
        total_symbole_list = []

        s_image = img_list[idx_img].copy()
        l_image = img_list[idx_img].copy()

        small_object_patch_list.append(get_patch(s_image, 500, 250))
        large_object_patch_list.append(get_patch(l_image, 800, 200))

        img_name = str(idx_img)

        print("=========> " + img_name)

        save_dir = r'E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\save/' + img_name + '/'
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
            os.mkdir(save_dir + "a/")
            os.mkdir(save_dir + "b/")

        # create the result text file
        text_file = open(save_dir + 'test_result.txt', mode='wt', encoding='utf-8')
        text_file.write(img_name + " -> Detection Result\n")
        text_file.write("================================\n")

        # ----------------------------------------------
        # Small Object Detection (Valve, Sensor, etc.)
        # ----------------------------------------------
        print("=========> SMALL OBJECT DETECTION")
        for idx_small in range(len(small_object_patch_list)):
            patchs = small_object_patch_list[idx_small]
            total_symbole_list.append(detection_object(patchs, model1, save_dir + "a/", opc=False, opt=opt))
        # ----------------------------------------------
        # Large Object Detection (OPC etc.)
        # ----------------------------------------------
        print("=========> LARGE OBJECT DETECTION")
        for idx_large in range(len(large_object_patch_list)):
            patchs = large_object_patch_list[idx_large]
            total_symbole_list.append(detection_object(patchs, model2, save_dir + "b/", opc=True, opt=opt))

        t_image = img_list[idx_img].copy()
        count_result = merge_fn(t_image, total_symbole_list, save_dir)

        for idx, value in enumerate(count_result):
            text_file.write(DOFTECH_CLASSES[idx] + " : " + str(value) + "\n")

        text_file.close()

        res = []
        for symbol in total_symbole_list[0] + total_symbole_list[1]:
            res.append(symbol.to_stream())
        total_symbole_lists.append(res)

    print("time :", time.time() - start)  # elapsed time = current time - start time

    return total_symbole_lists
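
# Sketch of calling get_symbol() from another module (hypothetical paths).
# Note that get_symbol() builds its options via argparse.parse_args(), which
# reads sys.argv, so the host process must not carry unrelated CLI flags:
#   from PIL import Image
#   imgs = [Image.open('pid_images/sample.png')]
#   streams = get_symbol(imgs,
#                        'MODEL/f_doftech_all_class_only_params.pth',
#                        'MODEL/doftech_all_class_only_params_opc.pth')
#   # streams[0] is a list of [class, x, y, width, height, confidence] entries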

def test(opt):
    global colors
    if torch.cuda.is_available():
        model1 = Yolo(35).cuda()
        model1.load_state_dict(torch.load(opt.pre_trained_model_path))

        model2 = Yolo(35).cuda()
        model2.load_state_dict(torch.load(opt.pre_trained_model_path2))
    else:
        return None

    colors = pickle.load(open(r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition/src/pallete", "rb"))

    print(opt.input)

    img_list = sorted(glob.glob(os.path.join(opt.data_path_test, '*.png')))

    start = time.time()
    print(img_list)

    # ----------------------------------------------
    # Get patch arguments: (img path, patch_size, overlap_size)
    # ----------------------------------------------
    for idx_img in range(len(img_list)):
        print("=========> CROP PATCH")
        small_object_patch_list = []
        large_object_patch_list = []
        total_symbole_list = []

        small_object_patch_list.append(get_patch(img_list[idx_img], 500, 250))
        large_object_patch_list.append(get_patch(img_list[idx_img], 800, 200))

        # strip directory and extension; os.path is more robust than split('\\')[1]
        img_name = os.path.splitext(os.path.basename(img_list[idx_img]))[0]

        print("=========> " + img_name)

        save_dir = r'E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\save/' + img_name + '/'
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
            os.mkdir(save_dir + "a/")
            os.mkdir(save_dir + "b/")

        # create the result text file
        text_file = open(save_dir + 'test_result.txt', mode='wt', encoding='utf-8')
        text_file.write(img_name + " -> Detection Result\n")
        text_file.write("================================\n")

        # ----------------------------------------------
        # Small Object Detection (Valve, Sensor, etc.)
        # ----------------------------------------------
        print("=========> SMALL OBJECT DETECTION")
        for idx_small in range(len(small_object_patch_list)):
            patchs = small_object_patch_list[idx_small]
            total_symbole_list.append(detection_object(patchs, model1, save_dir + "a/", opc=False, opt=opt))
        # ----------------------------------------------
        # Large Object Detection (OPC etc.)
        # ----------------------------------------------
        print("=========> LARGE OBJECT DETECTION")
        for idx_large in range(len(large_object_patch_list)):
            patchs = large_object_patch_list[idx_large]
            total_symbole_list.append(detection_object(patchs, model2, save_dir + "b/", opc=True, opt=opt))

        count_result = merge_fn(img_list[idx_img], total_symbole_list, save_dir)

        for idx, value in enumerate(count_result):
            text_file.write(DOFTECH_CLASSES[idx] + " : " + str(value) + "\n")

        text_file.close()
    print("time :", time.time() - start)  # elapsed time = current time - start time

def detection_object(patchs, model, save_root, opc, opt):
    symbol_list = []
    for idx in range(len(patchs)):
        pil_image = patchs[idx]
        # pil_image.img.show()
        np_img = np.asarray(pil_image.img)
        image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
        height, width = image.shape[:2]
        image = cv2.resize(image, (opt.image_size, opt.image_size))
        image = np.transpose(np.array(image, dtype=np.float32), (2, 0, 1))
        image = image[None, :, :, :]
        width_ratio = float(opt.image_size) / width
        height_ratio = float(opt.image_size) / height

        data = Variable(torch.FloatTensor(image))

        if torch.cuda.is_available():
            data = data.cuda()

        with torch.no_grad():
            logits = model(data)
            predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
                                          opt.nms_threshold)
        if len(predictions) != 0:
            predictions = predictions[0]
            output_image = cv2.cvtColor(np.array(pil_image.img), cv2.COLOR_RGB2BGR)
            for pred in predictions:
                if opc:
                    if pred[5] == "opc" and pred[4] > 0.4:  # classification threshold
                        xmin = int(max(pred[0] / width_ratio, 0))
                        ymin = int(max(pred[1] / height_ratio, 0))
                        xmax = int(min((pred[0] + pred[2]) / width_ratio, width))
                        ymax = int(min((pred[1] + pred[3]) / height_ratio, height))
                        color = colors[DOFTECH_CLASSES.index(pred[5])]
                        symbol = Symbol(xmin + pil_image.start_w, ymin + pil_image.start_h, xmax + pil_image.start_w,
                                        ymax + pil_image.start_h, pred[4], pred[5])  # store symbol information
                        check = check_duplication(symbol_list, symbol, True)
                        if check is True:
                            symbol_list.append(symbol)
                        cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2)
                        text_size = cv2.getTextSize(pred[5] + ' : %.2f' % pred[4], cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
                        cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4),
                                      color, -1)
                        cv2.putText(
                            output_image, pred[5] + ' : %.2f' % pred[4],
                            (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                            (255, 255, 255), 1)
                else:
                    if pred[4] > 0.1:  # classification threshold
                        xmin = int(max(pred[0] / width_ratio, 0))
                        ymin = int(max(pred[1] / height_ratio, 0))
                        xmax = int(min((pred[0] + pred[2]) / width_ratio, width))
                        ymax = int(min((pred[1] + pred[3]) / height_ratio, height))
                        color = colors[DOFTECH_CLASSES.index(pred[5])]
                        symbol = Symbol(xmin + pil_image.start_w, ymin + pil_image.start_h, xmax + pil_image.start_w,
                                        ymax + pil_image.start_h, pred[4], pred[5])  # store symbol information
                        check = check_duplication(symbol_list, symbol)
                        if check is True:
                            symbol_list.append(symbol)
                        cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2)
                        text_size = cv2.getTextSize(pred[5] + ' : %.2f' % pred[4], cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
                        cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4),
                                      color, -1)
                        cv2.putText(
                            output_image, pred[5] + ' : %.2f' % pred[4],
                            (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                            (255, 255, 255), 1)
            # ----------------------- save per-patch prediction image -----------------------
            cv2.imwrite(save_root + "{}_prediction.png".format(idx), output_image)
            pil_image.setObjectNum(len(predictions))
    return symbol_list
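
# Worked example of the box rescaling above (hypothetical numbers): for a
# 500x500 patch resized to image_size=448, width_ratio = 448/500 = 0.896, so a
# network-space prediction at x=100 maps back to xmin = int(100 / 0.896) = 111
# in patch coordinates; adding pil_image.start_w then shifts the box into
# full-drawing coordinates.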

def merge_fn(total_img, total_symbole_list, save_root):
    """ total_img: image path or PIL image """
    print("=========> MERGE RESULT")

    if type(total_img) is str:
        total_img = Image.open(total_img)

    king_image = total_img.convert('RGB')
    temp_king_img = np.array(king_image)  # Image.new("RGB", (king_width, king_height))
    temp_king_img = temp_king_img[:, :, ::-1].copy()  # RGB -> BGR for OpenCV

    # one counter per class; sized from DOFTECH_CLASSES instead of the previous
    # hard-coded 35 zeros, which left no slot for the 36th class ('opc')
    count_result = [0] * len(DOFTECH_CLASSES)

    for idx_out in range(len(total_symbole_list)):
        for idx in range(len(total_symbole_list[idx_out])):
            symbol = total_symbole_list[idx_out][idx]
            color = colors[DOFTECH_CLASSES.index(symbol.class_info)]
            cv2.rectangle(temp_king_img, (symbol.start_h, symbol.start_w), (symbol.end_h, symbol.end_w), color, 2)
            text_size = cv2.getTextSize(symbol.class_info + ' : %.2f' % symbol.iou, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
            cv2.rectangle(temp_king_img, (symbol.start_h, symbol.start_w),
                          (symbol.start_h + text_size[0] + 3, symbol.start_w + text_size[1] + 4), color, -1)
            cv2.putText(
                temp_king_img, symbol.class_info + ' : %.2f' % symbol.iou,
                (symbol.start_h, symbol.start_w + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                (255, 255, 255), 1)
            # COUNTING
            for index, name in enumerate(DOFTECH_CLASSES):
                if name == symbol.class_info:
                    count_result[index] = count_result[index] + 1
    cv2.imwrite(save_root + "entire.png", temp_king_img)

    return count_result
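
# count_result is index-aligned with DOFTECH_CLASSES, e.g. (hypothetical call):
#   counts = merge_fn('pid_images/sample.png', symbol_lists, 'save/sample/')
#   counts[DOFTECH_CLASSES.index('gate')]  # number of 'gate' symbols drawn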

def get_patch(image, patch_size, overlap_size):
    """ image: path or PIL image """
    crop_list = []
    if type(image) is str:
        image = Image.open(image)
    # NOTE: PIL's Image.size is (width, height), so img_h actually holds the
    # width; the swapped h/w naming is used consistently within this file.
    (img_h, img_w) = image.size

    # crop size: grid_w, grid_h
    grid_w = patch_size  # crop width
    grid_h = patch_size  # crop height

    range_w = int(img_w / grid_w)
    range_h = int(img_h / grid_h)
    # print(range_w, range_h)

    repeat_num_w = int(grid_w / overlap_size)
    repeat_num_h = int(grid_h / overlap_size)
    # print(repeat_num_w, repeat_num_h)

    for w in range(range_w * repeat_num_w):
        for h in range(range_h * repeat_num_h):
            if h == (range_h * repeat_num_h) - 1 and w == (range_w * repeat_num_w) - 1:
                # last patch: anchor to the bottom-right corner of the image
                start_h = img_h - grid_h
                start_w = img_w - grid_w
                end_h = img_h
                end_w = img_w
                bbox = (start_h, start_w, end_h, end_w)

                crop_img = image.crop(bbox)
                patch = Patch(start_w, start_h, crop_img, True)
                crop_list.append(patch)

            elif (h * overlap_size) + grid_h >= img_h:
                start_h = img_h - grid_h  # OK
                start_w = img_w - grid_w
                end_h = img_h
                end_w = start_w + grid_w
                bbox = (start_h, start_w, end_h, end_w)

                crop_img = image.crop(bbox)
                patch = Patch(start_w, start_h, crop_img, True)
                crop_list.append(patch)

            elif (w * overlap_size) + grid_w >= img_w:
                start_h = img_h - grid_h
                start_w = img_w - grid_w
                end_h = start_h + grid_h
                end_w = img_w
                bbox = (start_h, start_w, end_h, end_w)

                crop_img = image.crop(bbox)
                patch = Patch(start_w, start_h, crop_img, True)
                crop_list.append(patch)

            else:
                bbox = (h * overlap_size, w * overlap_size, (h * overlap_size) + grid_h,
                        (w * overlap_size) + grid_w)  # left, upper, right, lower
                crop_img = image.crop(bbox)
                patch = Patch(w * overlap_size, h * overlap_size, crop_img, False)
                crop_list.append(patch)
            # print(bbox)
            # patch.img.save('./pid_results/_{}.png'.format(len(crop_list)))
    return crop_list
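
# Worked example (hypothetical size): for a 2000x1000 drawing with
# patch_size=500 and overlap_size=250, image.size is (2000, 1000), so
# range_w = int(1000 / 500) = 2, range_h = int(2000 / 500) = 4 and
# repeat_num_w = repeat_num_h = int(500 / 250) = 2; the loops then emit
# (2 * 2) * (4 * 2) = 32 patches in 250-pixel steps, with the edge branches
# above snapping the final rows/columns to the image border.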

def check_duplication(symbol_list, check_symbol, opc=False):
    # per-class pixel threshold for treating two detections as duplicates
    if check_symbol.class_info == "strainer_basket":
        threshold = 30
    elif check_symbol.class_info == "3way_solenoid":
        threshold = 30
    else:
        threshold = 20

    for idx in range(len(symbol_list)):
        if abs(symbol_list[idx].start_h - check_symbol.start_h) <= threshold:
            if abs(symbol_list[idx].start_w - check_symbol.start_w) <= threshold:
                if symbol_list[idx].iou < check_symbol.iou:
                    # keep the higher-confidence detection in place
                    symbol_list[idx] = check_symbol
                    return False
                else:
                    return False
    return True
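
# A minimal sketch of the dedup rule (hypothetical values): two detections
# whose stored corners are within `threshold` pixels are treated as the same
# symbol, and the higher-confidence one survives in place:
#   a = Symbol(100, 100, 150, 150, 0.70, 'gate')
#   b = Symbol(110, 105, 160, 155, 0.90, 'gate')
#   lst = [a]
#   check_duplication(lst, b)  # returns False, but lst[0] has been replaced by b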

if __name__ == "__main__":
    opt = get_args()
    test(opt)