프로젝트

일반

사용자정보

통계
| 브랜치(Branch): | 개정판:

hytos / DTI_PID / WebServer / app / training / index.py @ b64cc3b5

이력 | 보기 | 이력해설 | 다운로드 (3.23 KB)

1
# file name : index.py
2
# pwd : /project_name/app/license/index.py
3

    
4
from flask import Blueprint, request, render_template, flash, redirect, url_for, jsonify
5
from flask_wtf import FlaskForm
6
from flask_wtf.file import FileField, FileAllowed, FileRequired
7
from wtforms import *
8
from werkzeug.utils import secure_filename
9
from wtforms.validators import *
10

    
11
import json, base64
12
import cv2
13
import sys, os
14
import numpy as np
15

    
16
# training
17
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training')
18

    
19
training_service = Blueprint('training', __name__, url_prefix='/training')
20

    
21
import train
22

    
23

    
24
class TrainingForm(FlaskForm):
25
    csrf_token = 'id2 web service'
26
    project_name = StringField('Project Name', required=True)
27
    class_file = FileField('Class File', validators=[FileRequired()])
28
    training_file = FileField('Training File', validators=[FileRequired()])
29

    
30

    
31
@training_service.route('/add', methods=['GET', 'POST'])
32
def add_training():
33
    form = TrainingForm()
34
    if request.method == 'POST':
35
        if form.validate_on_submit():
36
            f = request.files[form.training_file.name]
37
            f.save(f.filename)
38
        else:
39
            file_url = None
40

    
41
    return render_template("/training/index.html", form=form)
42

    
43

    
44
@training_service.route('/upload_training_data', methods=['POST'])
45
def upload_training_data():
46
    if request.method == 'POST':
47
        r = request
48
        datas = json.loads(r.data)
49
        
50
        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', datas['name'])
51

    
52
        if not os.path.isdir(data_path):
53
            os.mkdir(data_path)
54
            os.mkdir(os.path.join(data_path, 'training'))
55
            os.mkdir(os.path.join(data_path, 'training', 'xml'))
56
            os.mkdir(os.path.join(data_path, 'training', 'img'))
57
            os.mkdir(os.path.join(data_path, 'test'))
58
            os.mkdir(os.path.join(data_path, 'test', 'xml'))
59
            os.mkdir(os.path.join(data_path, 'test', 'img'))
60

    
61
        count = 0
62

    
63
        for name, str_img in datas['tiles']:
64
            str_img = base64.b64decode(str_img)
65
            nparr = np.fromstring(str_img, np.uint8)
66
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
67
            
68
            cv2.imwrite(os.path.join(data_path, name), img)
69
            count += 1
70

    
71
        for name, xml in datas['xmls']:
72
            #rect = base64.b64decode(str_img)
73
            with open(os.path.join(data_path, name), 'w') as stream:
74
                stream.write(xml)
75
            count += 1
76

    
77
        return jsonify({'count': count})
78

    
79
@training_service.route('/training_model', methods=['POST'])
80
def training_model():
81
    if request.method == 'POST':
82
        r = request
83
        datas = json.loads(r.data)
84
        
85
        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', datas['name'])
86

    
87
        if os.path.isdir(data_path):
88
            train.train(name=datas['name'], classes=datas['classes'], bigs=datas['bigs'], root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath(
89
                                                       __file__)) + '\\..\\..\\symbol_training\\pre_trained_model\\only_params_trained_yolo_voc')
90
            
91
        return jsonify({'count': 1})
클립보드 이미지 추가 (최대 크기: 500 MB)