Revision acb80620
issue #1366: add file
Change-Id: I54037a35e0f9e4de9485da03b75ec79a9a2dca14
DTI_PID/WebServer/symbol_recognition/test_doftech_all_images.py

"""
Testing

Put the drawings to test into pid_images and run; the results can be checked in pid_results.
Enter the pre-trained model (pre_train_model) path and run.
"""
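
# Typical invocation (flags and defaults are defined in get_args() below):
#   python test_doftech_all_images.py --data_path_test <folder of *.png drawings> \
#       --pre_trained_model_path <model1.pth> --pre_trained_model_path2 <model2.pth>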

import os
import glob
import argparse
import pickle
import time

import cv2
import numpy as np
import torch                          # used for model loading and inference below
from torch.autograd import Variable   # wraps input tensors in detection_object()
from PIL import Image

from src.utils import *
from src.yolo_net import Yolo
from src.yolo_doftech import YoloD

DOFTECH_CLASSES = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
                   '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure', 'ball_motor', 'plug_pressure',
                   'inst', 'func_valve', 'inst_console', 'inst_console_dcs', 'inst_console_sih', 'logic_dcs', 'utility', 'specialty_items', 'logic', 'logic_local_console_dcs',
                   'reducer', 'blind_spectacle_open', 'blind_insertion_open', 'blind_spectacle_close', 'blind_insertion_close',
                   'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot',
                   'opc']
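
# NOTE: this list holds 36 names, while both models below are constructed as
# Yolo(35); merge_fn() keeps one counter per name in this list.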

class Patch:
    def __init__(self, start_h, start_w, img, check_line):
        self.start_h = start_h
        self.start_w = start_w
        self.img = img
        self.check_line = check_line
        self.obj_num = 0

    def setObjectNum(self, obj_num):
        self.obj_num = obj_num
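
# A Patch keeps the offset of its crop inside the full drawing, so detections
# found inside the patch can be shifted back to drawing coordinates
# (see the start_w/start_h additions in detection_object()).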

class Symbol:
    def __init__(self, start_h, start_w, end_h, end_w, iou, class_info):
        self.start_h = start_h
        self.start_w = start_w
        self.end_h = end_h
        self.end_w = end_w
        self.iou = iou
        self.class_info = class_info

    def to_stream(self):
        # y, x, height, width : member name is incorrect
        #return [self.class_info, self.start_w, self.start_h, self.end_w - self.start_w, self.end_h - self.start_h, self.iou]
        # x, y, width, height
        return [self.class_info, self.start_h, self.start_w, self.end_h - self.start_h, self.end_w - self.start_w, self.iou]

def get_args():
    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
    parser.add_argument("--conf_threshold", type=float, default=0.5)
    parser.add_argument("--nms_threshold", type=float, default=0.6)
    parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="params")
    parser.add_argument("--pre_trained_model_path", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\MODEL/f_doftech_all_class_only_params.pth")
    parser.add_argument("--pre_trained_model_path2", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\MODEL/doftech_all_class_only_params_opc.pth")
    #parser.add_argument("--pre_trained_model_path3", type=str, default="trained_models/only_params_trained_yolo_voc")
    parser.add_argument("--input", type=str, default="test_images")
    parser.add_argument("--output", type=str, default="result_images")
    parser.add_argument("--data_path_test", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\pid_images/", help="the root folder of dataset")  # Doftech VOCdevkit
    args = parser.parse_args()
    return args

def get_symbol(imgs, trained_model1=None, trained_model2=None):
    global colors

    parser = argparse.ArgumentParser("You Only Look Once: Unified, Real-Time Object Detection")
    parser.add_argument("--image_size", type=int, default=448, help="The common width and height for all images")
    parser.add_argument("--conf_threshold", type=float, default=0.5)
    parser.add_argument("--nms_threshold", type=float, default=0.6)
    parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="params")
    parser.add_argument("--pre_trained_model_path", type=str, default=trained_model1)
    parser.add_argument("--pre_trained_model_path2", type=str, default=trained_model2)
    parser.add_argument("--data_path_test", type=str, default=r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\pid_images/", help="the root folder of dataset")  # Doftech VOCdevkit
    opt = parser.parse_args()

    if torch.cuda.is_available():
        model1 = Yolo(35).cuda()
        model1.load_state_dict(torch.load(opt.pre_trained_model_path))

        model2 = Yolo(35).cuda()
        model2.load_state_dict(torch.load(opt.pre_trained_model_path2))
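    # NOTE: unlike test(), there is no CPU fallback or early return here, so
    # model1/model2 exist only when CUDA is available.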

    colors = pickle.load(open(os.path.dirname(os.path.realpath(__file__)) + "/src/pallete", "rb"))

    img_list = imgs

    start = time.time()

    total_symbole_lists = []

    # ----------------------------------------------
    # Get Patch arguments : (img list, patch_size, overlap_size)
    # ----------------------------------------------

    for idx_img in range(len(img_list)):
        print("=========> CROP PATCH")
        small_object_patch_list = []
        large_object_patch_list = []
        total_symbole_list = []

        s_image = img_list[idx_img].copy()
        l_image = img_list[idx_img].copy()

        small_object_patch_list.append(get_patch(s_image, 500, 250))
        large_object_patch_list.append(get_patch(l_image, 800, 200))
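        # The calls above cut 500 px patches with a 250 px step for small symbols
        # (valves, sensors, ...) and 800 px patches with a 200 px step for large
        # symbols such as OPCs.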

        img_name = str(idx_img)

        print("=========> " + img_name)

        save_dir = r'E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\save/' + img_name + '/'
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
            os.mkdir(save_dir + "a/")
            os.mkdir(save_dir + "b/")

        # Create the result text file
        text_file = open(save_dir + 'test_result.txt', mode='wt', encoding='utf-8')
        text_file.write(img_name + " -> Detection Result\n")
        text_file.write("================================\n")

        # ----------------------------------------------
        # Small Object Detection (Valve, Sensor, etc.)
        # ----------------------------------------------
        print("=========> SMALL OBJECT DETECTION")
        for idx_small in range(len(small_object_patch_list)):
            patchs = small_object_patch_list[idx_small]
            total_symbole_list.append(detection_object(patchs, model1, save_dir + "a/", opc=False, opt=opt))
        # ----------------------------------------------
        # Large Object Detection (OPC etc.)
        # ----------------------------------------------
        print("=========> LARGE OBJECT DETECTION")
        for idx_large in range(len(large_object_patch_list)):
            patchs = large_object_patch_list[idx_large]
            total_symbole_list.append(detection_object(patchs, model2, save_dir + "b/", opc=True, opt=opt))

        t_image = img_list[idx_img].copy()
        count_result = merge_fn(t_image, total_symbole_list, save_dir)

        for idx, value in enumerate(count_result):
            text_file.write(DOFTECH_CLASSES[idx] + " : " + str(value) + "\n")

        text_file.close()

        res = []
        for symbol in total_symbole_list[0] + total_symbole_list[1]:
            res.append(symbol.to_stream())
        total_symbole_lists.append(res)

    print("time :", time.time() - start)  # elapsed time = current time - start time

    return total_symbole_lists
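
# get_symbol() returns one list per input image; each entry is a
# [class_name, x, y, width, height, confidence] record (see Symbol.to_stream()).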

def test(opt):
    global colors
    if torch.cuda.is_available():
        model1 = Yolo(35).cuda()
        model1.load_state_dict(torch.load(opt.pre_trained_model_path))

        model2 = Yolo(35).cuda()
        model2.load_state_dict(torch.load(opt.pre_trained_model_path2))
    else:
        return None

    colors = pickle.load(open(r"E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition/src/pallete", "rb"))

    print(opt.input)

    img_list = sorted(glob.glob(os.path.join(opt.data_path_test, '*.png')))

    start = time.time()
    print(img_list)

    # ----------------------------------------------
    # Get Patch arguments : (img list, patch_size, overlap_size)
    # ----------------------------------------------

    for idx_img in range(len(img_list)):
        print("=========> CROP PATCH")
        small_object_patch_list = []
        large_object_patch_list = []
        total_symbole_list = []

        small_object_patch_list.append(get_patch(img_list[idx_img], 500, 250))
        large_object_patch_list.append(get_patch(img_list[idx_img], 800, 200))

        img_name = img_list[idx_img].split('\\')[1].split('.')[0]

        print("=========> " + img_name)

        save_dir = r'E:\Projects\DTIPID_GERRIT\DTI_PID\WebServer\symbol_recognition\save/' + img_name + '/'
        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)
            os.mkdir(save_dir + "a/")
            os.mkdir(save_dir + "b/")

        # Create the result text file
        text_file = open(save_dir + 'test_result.txt', mode='wt', encoding='utf-8')
        text_file.write(img_name + " -> Detection Result\n")
        text_file.write("================================\n")

        # ----------------------------------------------
        # Small Object Detection (Valve, Sensor, etc.)
        # ----------------------------------------------
        print("=========> SMALL OBJECT DETECTION")
        for idx_small in range(len(small_object_patch_list)):
            patchs = small_object_patch_list[idx_small]
            total_symbole_list.append(detection_object(patchs, model1, save_dir + "a/", opc=False, opt=opt))
        # ----------------------------------------------
        # Large Object Detection (OPC etc.)
        # ----------------------------------------------
        print("=========> LARGE OBJECT DETECTION")
        for idx_large in range(len(large_object_patch_list)):
            patchs = large_object_patch_list[idx_large]
            total_symbole_list.append(detection_object(patchs, model2, save_dir + "b/", opc=True, opt=opt))

        count_result = merge_fn(img_list[idx_img], total_symbole_list, save_dir)

        for idx, value in enumerate(count_result):
            text_file.write(DOFTECH_CLASSES[idx] + " : " + str(value) + "\n")

        text_file.close()
    print("time :", time.time() - start)  # elapsed time = current time - start time

def detection_object(patchs, model, save_root, opc, opt):
    symbol_list = []
    for idx in range(len(patchs)):
        pil_image = patchs[idx]
        # pil_image.img.show()
        np_img = np.asarray(pil_image.img)
        image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
        height, width = image.shape[:2]
        image = cv2.resize(image, (opt.image_size, opt.image_size))
        image = np.transpose(np.array(image, dtype=np.float32), (2, 0, 1))
        image = image[None, :, :, :]
        width_ratio = float(opt.image_size) / width
        height_ratio = float(opt.image_size) / height

        data = Variable(torch.FloatTensor(image))

        if torch.cuda.is_available():
            data = data.cuda()

        with torch.no_grad():
            logits = model(data)
            predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
                                          opt.nms_threshold)
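        # post_processing() (from src.utils) is assumed to return, per image, boxes
        # of the form [x, y, w, h, confidence, class_name] in 448x448 network space;
        # width_ratio/height_ratio map them back to patch pixel coordinates.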
        if len(predictions) != 0:
            predictions = predictions[0]
            output_image = cv2.cvtColor(np.array(pil_image.img), cv2.COLOR_RGB2BGR)
            for pred in predictions:
                if opc == True:
                    if pred[5] == "opc" and pred[4] > 0.4:  # Classification threshold
                        xmin = int(max(pred[0] / width_ratio, 0))
                        ymin = int(max(pred[1] / height_ratio, 0))
                        xmax = int(min((pred[0] + pred[2]) / width_ratio, width))
                        ymax = int(min((pred[1] + pred[3]) / height_ratio, height))
                        color = colors[DOFTECH_CLASSES.index(pred[5])]
                        symbol = Symbol(xmin + pil_image.start_w, ymin + pil_image.start_h, xmax + pil_image.start_w,
                                        ymax + pil_image.start_h, pred[4], pred[5])  # To save symbol information
                        check = check_duplication(symbol_list, symbol, True)
                        if check is True:
                            symbol_list.append(symbol)
                        cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2)
                        text_size = cv2.getTextSize(pred[5] + ' : %.2f' % pred[4], cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
                        cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4),
                                      color, -1)
                        cv2.putText(
                            output_image, pred[5] + ' : %.2f' % pred[4],
                            (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                            (255, 255, 255), 1)
                else:
                    if pred[4] > 0.1:  # Classification threshold
                        xmin = int(max(pred[0] / width_ratio, 0))
                        ymin = int(max(pred[1] / height_ratio, 0))
                        xmax = int(min((pred[0] + pred[2]) / width_ratio, width))
                        ymax = int(min((pred[1] + pred[3]) / height_ratio, height))
                        color = colors[DOFTECH_CLASSES.index(pred[5])]
                        symbol = Symbol(xmin + pil_image.start_w, ymin + pil_image.start_h, xmax + pil_image.start_w,
                                        ymax + pil_image.start_h, pred[4], pred[5])  # To save symbol information
                        check = check_duplication(symbol_list, symbol)
                        if check is True:
                            symbol_list.append(symbol)
                        cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2)
                        text_size = cv2.getTextSize(pred[5] + ' : %.2f' % pred[4], cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
                        cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4),
                                      color, -1)
                        cv2.putText(
                            output_image, pred[5] + ' : %.2f' % pred[4],
                            (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                            (255, 255, 255), 1)
            # -----------------------fill king image--------------------------------------
            cv2.imwrite(save_root + "{}_prediction.png".format(idx), output_image)
            pil_image.setObjectNum(len(predictions))
    return symbol_list

def merge_fn(total_img, total_symbole_list, save_root):
    """ total_img : a file path or an already-opened image """
    print("=========> MERGE RESULT")

    if type(total_img) is str:
        total_img = Image.open(total_img)

    king_image = total_img.convert('RGB')
    temp_king_img = np.array(king_image)  # Image.new("RGB", (king_width, king_height))
    temp_king_img = temp_king_img[:, :, ::-1].copy()  # RGB -> BGR for OpenCV drawing

    count_result = [0] * len(DOFTECH_CLASSES)  # one counter per class name
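
    # NOTE: start_h/start_w are used as (x, y) pixel positions below; the member
    # names are swapped, as the comment in Symbol.to_stream() points out.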
    for idx_out in range(len(total_symbole_list)):
        for idx in range(len(total_symbole_list[idx_out])):
            symbol = total_symbole_list[idx_out][idx]
            color = colors[DOFTECH_CLASSES.index(symbol.class_info)]
            cv2.rectangle(temp_king_img, (symbol.start_h, symbol.start_w), (symbol.end_h, symbol.end_w), color, 2)
            text_size = cv2.getTextSize(symbol.class_info + ' : %.2f' % symbol.iou, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
            cv2.rectangle(temp_king_img, (symbol.start_h, symbol.start_w),
                          (symbol.start_h + text_size[0] + 3, symbol.start_w + text_size[1] + 4), color, -1)
            cv2.putText(
                temp_king_img, symbol.class_info + ' : %.2f' % symbol.iou,
                (symbol.start_h, symbol.start_w + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                (255, 255, 255), 1)
            # COUNTING
            for index, name in enumerate(DOFTECH_CLASSES):
                if name == symbol.class_info:
                    count_result[index] = count_result[index] + 1
    cv2.imwrite(save_root + "entire.png", temp_king_img)

    return count_result

def get_patch(image, patch_size, overlap_size):
    """ image : a file path or an already-opened PIL image """
    crop_list = []
    if type(image) is str:
        image = Image.open(image)
    (img_h, img_w) = image.size

    # crop size: grid_w, grid_h
    grid_w = patch_size  # crop width
    grid_h = patch_size  # crop height

    range_w = int(img_w / grid_w)
    range_h = int(img_h / grid_h)
    # print(range_w, range_h)

    repeat_num_w = int(grid_w / overlap_size)
    repeat_num_h = int(grid_h / overlap_size)
    # print(repeat_num_w, repeat_num_h)
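
    # The crop origin advances by overlap_size pixels per step, so neighbouring
    # patches overlap by (patch_size - overlap_size); range_* * repeat_num_* is
    # the number of steps needed to sweep the whole image in each direction.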
    for w in range(range_w * repeat_num_w):
        for h in range(range_h * repeat_num_h):
            if h == (range_h * repeat_num_h) - 1 and w == (range_w * repeat_num_w) - 1:
                # last patch: anchor to the bottom-right corner of the image
                start_h = img_h - grid_h
                start_w = img_w - grid_w
                end_h = img_h
                end_w = img_w
                bbox = (start_h, start_w, end_h, end_w)

                crop_img = image.crop(bbox)
                patch = Patch(start_w, start_h, crop_img, True)
                crop_list.append(patch)

            elif (h * overlap_size) + grid_h >= img_h:
                start_h = img_h - grid_h  # OK
                start_w = img_w - grid_w
                end_h = img_h
                end_w = start_w + grid_w
                bbox = (start_h, start_w, end_h, end_w)

                crop_img = image.crop(bbox)
                patch = Patch(start_w, start_h, crop_img, True)
                crop_list.append(patch)

            elif (w * overlap_size) + grid_w >= img_w:
                start_h = img_h - grid_h
                start_w = img_w - grid_w
                end_h = start_h + grid_h
                end_w = img_w
                bbox = (start_h, start_w, end_h, end_w)

                crop_img = image.crop(bbox)
                patch = Patch(start_w, start_h, crop_img, True)
                crop_list.append(patch)

            else:
                bbox = (h * overlap_size, w * overlap_size, (h * overlap_size) + grid_h,
                        (w * overlap_size) + grid_w)  # left, top, right, bottom
                crop_img = image.crop(bbox)
                patch = Patch(w * overlap_size, h * overlap_size, crop_img, False)
                crop_list.append(patch)
                # print(bbox)
    # patch.img.save('./pid_results/_{}.png'.format(len(crop_list)))
    return crop_list
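
# Overlapping patches detect the same symbol more than once; check_duplication()
# keeps the higher-confidence box when two detections start within `threshold`
# pixels of each other, and returns True only for genuinely new symbols.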
def check_duplication(symbol_list, check_symbol, opc=False):
    if check_symbol.class_info == "strainer_basket":
        threshold = 30
    elif check_symbol.class_info == "3way_solenoid":
        threshold = 30
    else:
        threshold = 20

    for idx in range(len(symbol_list)):
        if abs(symbol_list[idx].start_h - check_symbol.start_h) <= threshold:
            if abs(symbol_list[idx].start_w - check_symbol.start_w) <= threshold:
                if symbol_list[idx].iou < check_symbol.iou:
                    symbol_list[idx] = check_symbol
                    return False
                else:
                    return False
    return True

if __name__ == "__main__":
    opt = get_args()
    test(opt)