프로젝트

일반

사용자정보

개정판 23662e23

ID23662e2326874ee5a8fa2ef0efc04311928c68b3
상위 0d95699a
하위 1034b01c

함의성이(가) 약 5년 전에 추가함

issue #1366: save classes

Change-Id: I26ca00352a04a8af7d80873c985a76c2b1c43a80

차이점 보기:

DTI_PID/DTI_PID/AppWebService.py
48 48
            mb.exec_()
49 49
            return False
50 50

  
51
    def request_symbol_box(self, name, img, classes):
51
    def request_symbol_box(self, name, img):
52 52
        # send uncroped image
53 53
        try:
54 54
            if not self.test_connection():
55 55
                return []
56 56

  
57
            data = { 'classes':classes, 'name':name, 'img':None }
57
            data = { 'name':name, 'img':None }
58 58
            symbol_box = '/recognition/symbol_box'
59 59

  
60 60
            _, bts = cv2.imencode('.png', img)
DTI_PID/DTI_PID/RecognitionDialog.py
2171 2171
        app_doc_data = AppDocData.instance()
2172 2172
        project = app_doc_data.getCurrentProject()
2173 2173
        area = app_doc_data.getArea('Drawing')
2174

  
2175
        classes = 0
2176
        symbolTypeList = app_doc_data.getSymbolTypeList()
2177
        for symbolType in symbolTypeList:
2178
            if not symbolType[1]: continue  # skip if category is empty
2179
            symbolList = app_doc_data.getSymbolListByType('UID', symbolType[0])
2180
            for symbol in symbolList:
2181
                classes += 1
2182 2174
        
2183 2175
        app_web_service = AppWebService()
2184
        symbols = app_web_service.request_symbol_box(project.name, area.img, classes)
2176
        symbols = app_web_service.request_symbol_box(project.name, area.img)
2185 2177

  
2186 2178
        for targetSymbol in targetSymbols[2]:
2187 2179
            symbolName = targetSymbol.getName()
DTI_PID/WebServer/app/recognition/index.py
59 59

  
60 60
        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', data['name'])
61 61

  
62
        boxes = test_doftech_all_images.get_symbol(imgs, data['classes'], trained_model1=os.path.join(data_path, 'checkpoint', data['name'] + "_only_params.pth"), \
62
        boxes = test_doftech_all_images.get_symbol(imgs, root_path=data_path, trained_model1=os.path.join(data_path, 'checkpoint', data['name'] + "_only_params.pth"), \
63 63
                                                   trained_model2=os.path.dirname(os.path.realpath(
64 64
                                                       __file__)) + '\\..\\..\\symbol_recognition\\MODEL\\doftech_all_class_only_params_opc.pth')
65 65

  
DTI_PID/WebServer/symbol_recognition/test_doftech_all_images.py
47 47
        # x, y, width, height
48 48
        return [self.class_info, self.start_h, self.start_w, self.end_h - self.start_h, self.end_w - self.start_w, self.iou]
49 49

  
50
def get_args():
51
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
52
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
53
    parser.add_argument("--conf_threshold", type=float, default=0.5)
54
    parser.add_argument("--nms_threshold", type=float, default=0.6)
55
    parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="params")
56
    parser.add_argument("--pre_trained_model_path", type=str, default="E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\MODEL/f_doftech_all_class_only_params.pth")
57
    parser.add_argument("--pre_trained_model_path2", type=str, default="E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\MODEL/doftech_all_class_only_params_opc.pth")
58
    #parser.add_argument("--pre_trained_model_path3", type=str, default="trained_models/only_params_trained_yolo_voc")
59
    parser.add_argument("--input", type=str, default="test_images")
60
    parser.add_argument("--output", type=str, default="result_images")
61
    parser.add_argument("--data_path_test", type=str, default="E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\pid_images/", help="the root folder of dataset") #Doftech VOCdevkit
62
    args = parser.parse_args()
63
    return args
64

  
65
def get_symbol(imgs, classes, trained_model1=None, trained_model2=None):
50
def get_symbol(imgs, root_path=None, trained_model1=None, trained_model2=None):
66 51
    global colors
67 52

  
68 53
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
......
75 60
    parser.add_argument("--data_path_test", type=str, default="E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\pid_images/", help="the root folder of dataset") #Doftech VOCdevkit
76 61
    opt = parser.parse_args()
77 62

  
63
    saved_path = os.path.join(root_path, 'checkpoint')
64
    with open(os.path.join(saved_path, name + "_info.info"), 'r') as stream:
65
        con = stream.read(con).split('\n')
66
        classes = int(con[0])
67
        DOFTECH_CLASSES = con[1:]
68

  
78 69
    if torch.cuda.is_available():
79 70
        model1 = Yolo(classes).cuda()
80 71
        model1.load_state_dict(torch.load(opt.pre_trained_model_path))
......
163 154

  
164 155
    return total_symbole_lists
165 156

  
166
def test(opt):
167
    global colors
168
    if torch.cuda.is_available():
169
        model1 = Yolo(35).cuda()
170
        model1.load_state_dict(torch.load(opt.pre_trained_model_path))
171

  
172
        model2 = Yolo(35).cuda()
173
        model2.load_state_dict(torch.load(opt.pre_trained_model_path2))
174
    else:
175
        return None
176

  
177
    colors = pickle.load(open("E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition/src/pallete", "rb"))
178

  
179
    print(opt.input)
180

  
181
    img_list = sorted(glob.glob(os.path.join(opt.data_path_test, '*.png')))
182

  
183
    start = time.time()
184
    print(img_list)
185

  
186
    # ----------------------------------------------
187
    # Get Patch arguments : (img list, patch_size, overlap_size)
188
    # ----------------------------------------------
189

  
190
    for idx_img in range(len(img_list)):
191
        print("=========> CROP PATCH")
192
        small_object_patch_list = []
193
        large_object_patch_list = []
194
        total_symbole_list = []
195

  
196
        small_object_patch_list.append(get_patch(img_list[idx_img], 500, 250))
197
        large_object_patch_list.append(get_patch(img_list[idx_img], 800, 200))
198

  
199
        img_name = img_list[idx_img].split('\\')[1].split('.')[0]
200

  
201
        print("=========> "+ img_name)
202

  
203
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'save', img_name)
204
        if not os.path.isdir(save_dir):
205
            os.mkdir(save_dir)
206
            os.mkdir(save_dir+"a/")
207
            os.mkdir(save_dir+"b/")
208

  
209
        text_file = open(save_dir+'test_result.txt', mode='wt', encoding='utf-8')
210
        text_file.write(img_name+" -> Detection Reesult\n")
211
        text_file.write("================================\n")
212

  
213
            # Text File 만들기
214

  
215
        # ----------------------------------------------
216
        # Small Object Detection (Valve, Sensor, etc.)
217
        # ---------------------------------------------
218
        print("=========> SMALL OBJECT DETECTION")
219
        for idx_small in range(len(small_object_patch_list)):
220
            patchs = small_object_patch_list[idx_small]
221
            total_symbole_list.append(detection_object(patchs, model1, save_dir+"a/", opc=False, opt=opt))
222
        # ----------------------------------------------
223
        # Large Object Detection (OPC etc.)
224
        # ----------------------------------------------
225
        print("=========> LARGE OBJECT DETECTION")
226
        for idx_large in range(len(large_object_patch_list)):
227
            patchs = large_object_patch_list[idx_large]
228
            total_symbole_list.append(detection_object(patchs, model2, save_dir+"b/", opc=True, opt=opt))
229

  
230
        count_result = merge_fn(img_list[idx_img], total_symbole_list, save_dir)
231

  
232
        for idx, value in enumerate(count_result):
233
            text_file.write(DOFTECH_CLASSES[idx]+ " : "+ str(value) + "\n")
234

  
235
        text_file.close()
236
    print("time :", time.time() - start)  # 현재시각 - 시작시간 = 실행 시간
237 157

  
238 158
def detection_object(patchs, model, save_root, opc, opt):
239 159
    symbol_list = []
DTI_PID/WebServer/symbol_training/train.py
22 22
loss_data = {'X': [], 'Y': [], 'legend_U':['total', 'coord', 'conf', 'cls']}
23 23
#visdom = visdom.Visdom(port='8080')
24 24

  
25
def get_args():
26
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
27
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
28
    parser.add_argument("--batch_size", type=int, default=10, help="The number of images per batch")
29

  
30
    # Training 기본 Setting
31
    parser.add_argument("--momentum", type=float, default=0.9)
32
    parser.add_argument("--decay", type=float, default=0.0005)
33
    parser.add_argument("--dropout", type=float, default=0.5)
34
    parser.add_argument("--num_epoches", type=int, default=1000)
35
    parser.add_argument("--test_interval", type=int, default=20, help="Number of epoches between testing phases")
36
    parser.add_argument("--object_scale", type=float, default=1.0)
37
    parser.add_argument("--noobject_scale", type=float, default=0.5)
38
    parser.add_argument("--class_scale", type=float, default=1.0)
39
    parser.add_argument("--coord_scale", type=float, default=5.0)
40
    parser.add_argument("--reduction", type=int, default=32)
41
    parser.add_argument("--es_min_delta", type=float, default=0.0,
42
                        help="Early stopping's parameter: minimum change loss to qualify as an improvement")
43
    parser.add_argument("--es_patience", type=int, default=0,
44
                        help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.")
45

  
46
    # 확인해야 하는 PATH
47
    parser.add_argument("--data_path", type=str, default="D:/data/DATA_Doftech/small_symbol/training", help="the root folder of dataset") # 학습 데이터 경로 -> image와 xml의 상위 경로 입력
48
    parser.add_argument("--data_path_test", type=str, default="./data/Test", help="the root folder of dataset") # 테스트 데이터 경로 -> test할 이미지만 넣으면 됨
49
    #parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
50
    parser.add_argument("--pre_trained_model_path", type=str, default="trained_models/only_params_trained_yolo_voc") # Pre-training 된 모델 경로
51

  
52
    parser.add_argument("--saved_path", type=str, default="./checkpoint") # training 된 모델 저장 경로
53
    parser.add_argument("--conf_threshold", type=float, default=0.35)
54
    parser.add_argument("--nms_threshold", type=float, default=0.5)
55
    args = parser.parse_args()
56
    return args
57

  
58 25
# 형상 CLASS
59 26
DOFTECH_CLASSES= ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
60 27
                  '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure',
......
63 30
                  'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot'
64 31
                  ,'opc']
65 32

  
66
print(len(DOFTECH_CLASSES))
33
#print(len(DOFTECH_CLASSES))
67 34

  
68 35
def train(name=None, classes=None, root_path=None, pre_trained_model_path=None):
69 36
    DOFTECH_CLASSES = classes
......
99 66
    parser.add_argument("--nms_threshold", type=float, default=0.5)
100 67
    opt = parser.parse_args()
101 68

  
69
    with open(os.path.join(opt.saved_path, name + "_info.info"), 'w') as stream:
70
        con = str(len(DOFTECH_CLASSES))
71
        names = '\n'.join(DOFTECH_CLASSES)
72
        con = con + '\n' + names
73
        stream.write(con)
102 74

  
103 75
    if torch.cuda.is_available():
104 76
        torch.cuda.manual_seed(123)

내보내기 Unified diff