프로젝트

일반

사용자정보

개정판 9bdbbda1

ID9bdbbda19e602e568a30397dd5313739f2b4361c
상위 4d2aa82f
하위 d915dcd8

함의성이(가) 약 5년 전에 추가함

issue #1366: tile gpu testing stream server not committed yet

Change-Id: I6458aa749ab755dea0b1baa777a38d72abcaed12

차이점 보기:

DTI_PID/DTI_PID/AppWebService.py
3 3

  
4 4
#import urllib.request
5 5
#import urllib.parse
6
#import json
6
import json
7 7
import requests
8 8
import cv2
9 9
import sys
10
import base64
10 11
import numpy as np
11 12

  
12 13
class AppWebService:
......
26 27
            return False
27 28

  
28 29
    def request_text_box(self, img, img_path, score_path):
30
        # send uncroped image
29 31
        try:
30 32
            if not self.text_connection():
31 33
                return []
......
45 47
                                                          sys.exc_info()[-1].tb_lineno)
46 48
            App.mainWnd().addMessage.emit(MessageType.Error, message)
47 49

  
50
    def request_text_box_tile(self, img_infos, img_path, score_path):
        """Request text boxes for a batch of image tiles from the streaming server.

        img_infos: list of [row, col, x, y, tile_image] entries (tile_image is a
            numpy image array). On success each entry gets the detected box list
            appended as a sixth element.
        img_path, score_path: unused here; kept for interface parity with
            request_text_box.

        Returns the augmented img_infos list, or [] when the service is
        unreachable, the response is inconsistent, or an error occurs.
        """
        try:
            if not self.text_connection():
                return []

            text_box = '/stream_text_box'

            # Encode each tile as a base64 PNG string so the batch can be
            # shipped as a single JSON payload.
            str_imgs = []
            for img in (info[4] for info in img_infos):
                _, bts = cv2.imencode('.png', img)
                str_imgs.append(base64.b64encode(bts).decode())

            response = requests.post(self._url + text_box, data=json.dumps(str_imgs))
            box_list = response.json()['text_box_list']

            # Sanity check: the server must answer one box list per tile.
            if len(img_infos) != len(box_list):
                print('check return values')
                # was a bare `return` (None); callers iterate the result,
                # so keep the return type consistent
                return []

            for info, boxes in zip(img_infos, box_list):
                info.append(boxes)

            return img_infos
        except Exception as ex:
            from App import App
            from AppDocData import MessageType
            message = 'error occurred({}) in {}:{}'.format(ex, sys.exc_info()[-1].tb_frame.f_code.co_filename,
                                                          sys.exc_info()[-1].tb_lineno)
            App.mainWnd().addMessage.emit(MessageType.Error, message)
            # previously fell through returning None; callers iterate the result
            return []
83

  
48 84
    '''
49 85
    def request(self, url):
50 86
    response = urllib.request.urlopen(url)
DTI_PID/DTI_PID/TextDetector.py
212 212

  
213 213
        return res_list, ocr_image
214 214

  
215
    def get_text_image_tile(self, img, size=(1300, 1300), overlap=100):
        """Split *img* into overlapping tiles for tiled text detection.

        img: 2-D (grayscale) numpy array.
        size: (tile_width, tile_height) of the base tile grid
            (tuple default avoids the mutable-default-argument pitfall).
        overlap: pixels that neighbouring tiles share on each inner edge.

        Returns a list of [row, col, x, y, tile_image] entries where (x, y) is
        the tile's top-left corner in the padded canvas and tile_image is a copy.
        """
        width, height = img.shape[1], img.shape[0]
        # Number of tiles per axis; the image is padded up to a full grid.
        width_count, height_count = width // size[0] + 1, height // size[1] + 1
        b_width, b_height = width_count * size[0], height_count * size[1]
        # White canvas covering a whole number of tiles.
        b_img = np.zeros((b_height, b_width), np.uint8) + 255
        b_img[:height, :width] = img[:, :]

        tile_info_list = []
        for row in range(height_count):
            for col in range(width_count):
                # Inner tiles extend by `overlap` on both sides, edge tiles on
                # one side only.
                # BUGFIX: column extents must be tested against width_count and
                # row extents against height_count; the original swapped them,
                # producing wrong tile sizes for non-square tile grids.
                t_width = size[0] if width_count == 1 else (\
                    size[0] + overlap * 2 if col != 0 and col != width_count - 1 else size[0] + overlap)
                t_height = size[1] if height_count == 1 else (\
                    size[1] + overlap * 2 if row != 0 and row != height_count - 1 else size[1] + overlap)

                t_y = 0 if row == 0 else row * size[1] - overlap
                t_x = 0 if col == 0 else col * size[0] - overlap
                t_img = b_img[t_y:t_y + t_height, t_x:t_x + t_width]

                # Copy so each tile is independent of the padded canvas.
                tile_info_list.append([row, col, t_x, t_y, t_img.copy()])

                #Image.fromarray(tile_info_list[-1][4]).show()

        return tile_info_list
240

  
215 241
    def getTextBox_craft(self, ocr_image, maxTextSize, minSize, offset_x, offset_y, web=False):
216 242
        """ get text box by using craft """
217 243

  
......
231 257
            import text_craft
232 258

  
233 259
            boxes = text_craft.get_text_box(ocr_image, img_path, score_path, os.path.dirname(os.path.realpath('./')) + '\\WebServer\\CRAFT_pytorch_master\\weights\\craft_mlt_25k.pth')
234
        else:
260
        elif False:
235 261
            app_web_service = AppWebService()
236 262
            boxes = app_web_service.request_text_box(ocr_image, img_path, score_path)
263
        else:
264
            app_web_service = AppWebService()
265
            #boxes = app_web_service.request_text_box(ocr_image, img_path, score_path)
266

  
267
            tile_image_infos = self.get_text_image_tile(ocr_image)
268
            img_infos = app_web_service.request_text_box_tile(tile_image_infos, img_path, score_path)
269

  
270
            boxes = []
271
            for info in img_infos:
272
                for box in info[5]:
273
                    box[0] = box[0] + info[0]
274
                    box[1] = box[1] + info[1]
275
                    box[4] = box[4] + info[0]
276
                    box[5] = box[5] + info[1]
277

  
278
                boxes.extend(info[5])
237 279

  
238 280
        rects = []
239 281

  
240 282
        for box in boxes:
241 283
            rects.append(QRect(box[0], box[1], box[4] - box[0], box[5] - box[1]))
242 284

  
285
        # merge tile text box
286
        overlap_merges = []
287
        for rect1 in rects:
288
            for rect2 in rects:
289
                if rect1 is rect2:
290
                    continue
291
                l1, l2 = rect1.left(), rect2.left()
292
                r1, r2 = rect1.right(), rect2.right()
293
                l_x, s_x = [l1, r1], [l2, r2]
294
                t1, t2 = rect1.top(), rect2.top()
295
                b1, b2 = rect1.bottom(), rect2.bottom()
296
                l_y, s_y = [t1, b1], [t2, b2]
297
                if not (max(l_x) < min(s_x) or max(s_x) < min(l_x)) and \
298
                    (max(l_y) < min(s_y) or max(s_y) < min(l_y)):
299
                    inserted = False
300
                    for merge in overlap_merges:
301
                        if rect1 in merge and rect2 in merge:
302
                            inserted = True
303
                            break
304
                        elif rect1 in merge and rect2 not in merge:
305
                            merge.append(rect2)
306
                            inserted = True
307
                            break
308
                        elif rect2 in merge and rect1 not in merge:
309
                            merge.append(rect1)
310
                            inserted = True
311
                            break
312
                    if not inserted:
313
                        overlap_merges.append([rect1, rect2])
314

  
315
        for merge in overlap_merges:
316
            for rect in merge:
317
                if rect in rects:
318
                    rects.remove(rect)
319
                else:
320
                    print(str(rect))
321

  
322
        for merge in overlap_merges:
323
            max_x, max_y, min_x, min_y = 0, 0, sys.maxsize, sys.maxsize
324
            for rect in merge:
325
                if rect.left() < min_x:
326
                    min_x = rect.left()
327
                if rect.right() > max_x:
328
                    max_x = rect.right()
329
                if rect.top() < min_y:
330
                    min_y = rect.top()
331
                if rect.bottom() > max_y:
332
                    max_y = rect.bottom()
333

  
334
            rect = QRect(min_x, min_y, max_x - min_x, max_y - min_y)
335
            rects.append(rect)
336
        # up to here
337

  
338
        # merge adjacent text box
243 339
        configs = app_doc_data.getConfigs('Text Recognition', 'Merge Size')
244 340
        mergeSize = int(configs[0].value) if 1 == len(configs) else 10
245 341
        #gap_size = mergeSize / 2
......
322 418
                else:
323 419
                    print(str(rect))
324 420

  
325
        for merge in v_merges:
326
            max_x, max_y, min_x, min_y = 0, 0, sys.maxsize, sys.maxsize
327
            for rect in merge:
328
                if rect.left() < min_x:
329
                    min_x = rect.left()
330
                if rect.right() > max_x:
331
                    max_x = rect.right()
332
                if rect.top() < min_y:
333
                    min_y = rect.top()
334
                if rect.bottom() > max_y:
335
                    max_y = rect.bottom()
336

  
337
            rect = QRect(min_x, min_y, max_x - min_x, max_y - min_y)
338
            rect._vertical = True
339
            rects.append(rect)
340
        
341
        for merge in h_merges:
421
        for merge in v_merges + h_merges:
342 422
            max_x, max_y, min_x, min_y = 0, 0, sys.maxsize, sys.maxsize
343 423
            for rect in merge:
344 424
                if rect.left() < min_x:
......
351 431
                    max_y = rect.bottom()
352 432

  
353 433
            rect = QRect(min_x, min_y, max_x - min_x, max_y - min_y)
354
            rect._vertical = False
434
            if merge in v_merges:
435
                rect._vertical = True
436
            else:
437
                rect._vertical = False
355 438
            rects.append(rect)
439
        # up to here
356 440

  
357 441
        res_rects = []
358 442
        for rect in rects:
DTI_PID/WebServer/CRAFT_pytorch_master/text_craft.py
97 97

  
98 98
    return boxes, polys, ret_score_text
99 99

  
100
def get_text_box(img, img_path, score_path, trained_model=None):
100
def get_text_box_batch(infos):
    """Run get_text_box over a batch.

    Each entry of *infos* is (img, img_path, score_path, trained_model);
    returns the per-image box lists in the same order.
    """
    return [get_text_box(*info[:4]) for info in infos]
106

  
107
def get_text_box(img, img_path=None, score_path=None, trained_model=None):
101 108
    if img.shape[0] == 2: img = img[0]
102 109
    if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
103 110
    if img.shape[2] == 4:   img = img[:,:,:3]
......
109 116
    parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
110 117
    parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
111 118
    parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
112
    parser.add_argument('--cuda', default=False, type=str2bool, help='Use cuda for inference')
113
    parser.add_argument('--canvas_size', default=4000, type=int, help='image size for inference')
119
    parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
120
    parser.add_argument('--canvas_size', default=1000, type=int, help='image size for inference')
114 121
    parser.add_argument('--mag_ratio', default=1, type=float, help='image magnification ratio')
115 122
    parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
116 123
    parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
DTI_PID/WebServer/app.py
2 2
import cv2
3 3
import numpy as np
4 4
import sys, os
5
import json, base64
5 6

  
7
# craft
6 8
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '\\CRAFT_pytorch_master')
9
# service streamer
10
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '\\service_streamer_master')
11
# deep ocr
7 12
#sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '\\deep_text_recognition_benchmark_master')
8 13

  
9 14
app = Flask(__name__)
10 15

  
16
try:
17
    #from model import get_prediction, batch_prediction
18
    import text_craft
19
    from service_streamer import ThreadedStreamer
20

  
21
    streamer = ThreadedStreamer(text_craft.get_text_box_batch, batch_size=64)
22
except ImportError as ex:
23
    ex
24
    pass
25

  
11 26
@app.route('/')
def index():
    """Smoke-test route: proves the server is up and responding."""
    greeting = 'Hello Flask'
    return greeting
14 29
    
15 30
@app.route('/text_box', methods=['POST'])
def text_box():
    """Detect text boxes in a single image POSTed as raw encoded image bytes.

    Returns JSON {'text_box': boxes} with the boxes found by text_craft.
    """
    if request.method == 'POST':
        # np.frombuffer replaces the deprecated np.fromstring for binary data;
        # the result is identical for this use.
        nparr = np.frombuffer(request.data, np.uint8)

        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        #img = img.reshape(1, -1)

        boxes = text_craft.get_text_box(img, img_path=None, score_path=None, trained_model=os.path.dirname(os.path.realpath(__file__)) + '\\CRAFT_pytorch_master\\weights\\craft_mlt_25k.pth')

        return jsonify({'text_box': boxes})
42

  
43
@app.route('/stream_text_box', methods=['POST'])
def stream_text_box():
    """Detect text boxes for a JSON list of base64-encoded PNG tiles.

    Returns JSON {'text_box_list': [...]} with one box list per posted tile,
    in the same order as the input.
    """
    if request.method == 'POST':
        str_imgs = json.loads(request.data)
        imgs = []
        for str_img in str_imgs:
            raw = base64.b64decode(str_img)
            # np.frombuffer replaces the deprecated np.fromstring for bytes.
            nparr = np.frombuffer(raw, np.uint8)
            imgs.append(cv2.imdecode(nparr, cv2.IMREAD_COLOR))

        # Loop-invariant weights path hoisted out of the prediction loop.
        weights = os.path.dirname(os.path.realpath(__file__)) + '\\CRAFT_pytorch_master\\weights\\craft_mlt_25k.pth'
        boxes_list = []
        for img in imgs:
            # NOTE(review): predicting one image per call defeats the
            # streamer's batching; consider one predict() call with all
            # tiles — confirm result ordering before changing.
            boxes = streamer.predict([[img, None, None, weights]])
            boxes_list.append(boxes[0])
        return jsonify({'text_box_list': boxes_list})
29 60

  
30 61
if __name__ == '__main__':
31
    app.run(debug=True)
62
    app.run(debug=True)

내보내기 Unified diff

클립보드 이미지 추가 (최대 크기: 500 MB)