개정판 1e8ea226
issue #1366: test big symbol
Change-Id: I42759e4e53529fc18b1fcb71e2a31ba90ea81f6d
DTI_PID/WebServer/app/training/index.py | ||
---|---|---|
84 | 84 |
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', datas['name']) |
85 | 85 | |
86 | 86 |
if os.path.isdir(data_path): |
87 |
train.train(name=datas['name'], classes=datas['classes'], root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath( |
|
87 |
train.train(name=datas['name'], classes=datas['classes'], bigs=datas['bigs'], root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath(
|
|
88 | 88 |
__file__)) + '\\..\\..\\symbol_training\\pre_trained_model\\only_params_trained_yolo_voc') |
89 | 89 |
|
90 | 90 |
return jsonify({'count': 1}) |
DTI_PID/WebServer/symbol_recognition/test_doftech_all_images.py | ||
---|---|---|
64 | 64 |
with open(os.path.join(saved_path, name + "_info.info"), 'r') as stream: |
65 | 65 |
con = stream.read().split('\n') |
66 | 66 |
classes = int(con[0]) |
67 |
DOFTECH_CLASSES = con[1:] |
|
67 |
DOFTECH_CLASSES = con[1:con.index('***bigs***')] |
|
68 |
bigs = con[con.index('***bigs***') + 1:] |
|
68 | 69 | |
69 | 70 |
if torch.cuda.is_available(): |
70 | 71 |
model1 = Yolo(classes).cuda() |
... | ... | |
147 | 148 |
print("=========> SMALL OBJECT DETECTION") |
148 | 149 |
for idx_small in range(len(small_object_patch_list)): |
149 | 150 |
patchs = small_object_patch_list[idx_small] |
150 |
total_symbole_list.append(detection_object(patchs, model1, opc=False, opt=opt, save_root=save_dir_tile1)) |
|
151 |
total_symbole_list.append(detection_object(patchs, model1, opc=False, opt=opt, save_root=save_dir_tile1, bigs=bigs))
|
|
151 | 152 |
# ---------------------------------------------- |
152 | 153 |
# Large Object Detection (OPC etc.) |
153 | 154 |
# ---------------------------------------------- |
154 | 155 |
print("=========> LARGE OBJECT DETECTION") |
155 | 156 |
for idx_large in range(len(large_object_patch_list)): |
156 | 157 |
patchs = large_object_patch_list[idx_large] |
157 |
total_symbole_list.append(detection_object(patchs, model2, opc=True, opt=opt, save_root=save_dir_tile2)) |
|
158 |
total_symbole_list.append(detection_object(patchs, model2, opc=True, opt=opt, save_root=save_dir_tile2, bigs=bigs))
|
|
158 | 159 | |
159 | 160 |
t_image = img_list[idx_img].copy() |
160 | 161 |
count_result = merge_fn(t_image, total_symbole_list, save_dir) |
... | ... | |
175 | 176 |
return total_symbole_lists |
176 | 177 | |
177 | 178 | |
178 |
def detection_object(patchs, model, opc, opt, save_root=None): |
|
179 |
def detection_object(patchs, model, opc, opt, save_root=None, bigs=None):
|
|
179 | 180 |
global DOFTECH_CLASSES |
180 | 181 |
|
181 | 182 |
symbol_list = [] |
... | ... | |
204 | 205 |
predictions = predictions[0] |
205 | 206 |
output_image = cv2.cvtColor(np.array(pil_image.img), cv2.COLOR_RGB2BGR) |
206 | 207 |
for pred in predictions: |
207 |
if opc == True : |
|
208 |
if opc == True and pred[5] in bigs:
|
|
208 | 209 |
if pred[4] > 0.4:#pred[5] == "opc" and pred[4] > 0.4: # Classification threshold |
209 | 210 |
xmin = int(max(pred[0] / width_ratio, 0)) |
210 | 211 |
ymin = int(max(pred[1] / height_ratio, 0)) |
... | ... | |
224 | 225 |
output_image, pred[5] + ' : %.2f' % pred[4], |
225 | 226 |
(xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, |
226 | 227 |
(255, 255, 255), 1) |
227 |
else :
|
|
228 |
elif opc == False and pred[5] not in bigs:
|
|
228 | 229 |
if pred[4] > 0.1: # Classification threshold |
229 | 230 |
xmin = int(max(pred[0] / width_ratio, 0)) |
230 | 231 |
ymin = int(max(pred[1] / height_ratio, 0)) |
DTI_PID/WebServer/symbol_training/train.py | ||
---|---|---|
31 | 31 | |
32 | 32 |
#print(len(DOFTECH_CLASSES)) |
33 | 33 | |
34 |
def train(name=None, classes=None, root_path=None, pre_trained_model_path=None): |
|
34 |
def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None):
|
|
35 | 35 |
global DOFTECH_CLASSES |
36 | 36 |
DOFTECH_CLASSES = classes |
37 | 37 | |
... | ... | |
72 | 72 |
with open(os.path.join(opt.saved_path, name + "_info.info"), 'w') as stream: |
73 | 73 |
con = str(len(DOFTECH_CLASSES)) |
74 | 74 |
names = '\n'.join(DOFTECH_CLASSES) |
75 |
con = con + '\n' + names |
|
75 |
bigs = '\n'.join(bigs) |
|
76 |
con = con + '\n' + names + '\n' + '***bigs***' + '\n' + bigs |
|
76 | 77 |
stream.write(con) |
77 | 78 | |
78 | 79 |
if torch.cuda.is_available(): |
내보내기 Unified diff