hytos / DTI_PID / WebServer / app / training / index.py @ 5a049576
이력 | 보기 | 이력해설 | 다운로드 (3.23 KB)
1 |
# file name : index.py
|
---|---|
2 |
# pwd : /project_name/app/license/index.py
|
3 |
|
4 |
from flask import Blueprint, request, render_template, flash, redirect, url_for, jsonify |
5 |
from flask_wtf import FlaskForm |
6 |
from flask_wtf.file import FileField, FileAllowed, FileRequired |
7 |
from wtforms import * |
8 |
from werkzeug.utils import secure_filename |
9 |
from wtforms.validators import * |
10 |
|
11 |
import json, base64 |
12 |
import cv2 |
13 |
import sys, os |
14 |
import numpy as np |
15 |
|
16 |
# training
|
17 |
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training') |
18 |
|
19 |
training_service = Blueprint('training', __name__, url_prefix='/training') |
20 |
|
21 |
import train |
22 |
|
23 |
|
24 |
class TrainingForm(FlaskForm): |
25 |
csrf_token = 'id2 web service'
|
26 |
project_name = StringField('Project Name', required=True) |
27 |
class_file = FileField('Class File', validators=[FileRequired()])
|
28 |
training_file = FileField('Training File', validators=[FileRequired()])
|
29 |
|
30 |
|
31 |
@training_service.route('/add', methods=['GET', 'POST']) |
32 |
def add_training(): |
33 |
form = TrainingForm() |
34 |
if request.method == 'POST': |
35 |
if form.validate_on_submit():
|
36 |
f = request.files[form.training_file.name] |
37 |
f.save(f.filename) |
38 |
else:
|
39 |
file_url = None
|
40 |
|
41 |
return render_template("/training/index.html", form=form) |
42 |
|
43 |
|
44 |
@training_service.route('/upload_training_data', methods=['POST']) |
45 |
def upload_training_data(): |
46 |
if request.method == 'POST': |
47 |
r = request |
48 |
datas = json.loads(r.data) |
49 |
|
50 |
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', datas['name']) |
51 |
|
52 |
if not os.path.isdir(data_path): |
53 |
os.mkdir(data_path) |
54 |
os.mkdir(os.path.join(data_path, 'training'))
|
55 |
os.mkdir(os.path.join(data_path, 'training', 'xml')) |
56 |
os.mkdir(os.path.join(data_path, 'training', 'img')) |
57 |
os.mkdir(os.path.join(data_path, 'test'))
|
58 |
os.mkdir(os.path.join(data_path, 'test', 'xml')) |
59 |
os.mkdir(os.path.join(data_path, 'test', 'img')) |
60 |
|
61 |
count = 0
|
62 |
|
63 |
for name, str_img in datas['tiles']: |
64 |
str_img = base64.b64decode(str_img) |
65 |
nparr = np.fromstring(str_img, np.uint8) |
66 |
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) |
67 |
|
68 |
cv2.imwrite(os.path.join(data_path, name), img) |
69 |
count += 1
|
70 |
|
71 |
for name, xml in datas['xmls']: |
72 |
#rect = base64.b64decode(str_img)
|
73 |
with open(os.path.join(data_path, name), 'w') as stream: |
74 |
stream.write(xml) |
75 |
count += 1
|
76 |
|
77 |
return jsonify({'count': count}) |
78 |
|
79 |
@training_service.route('/training_model', methods=['POST']) |
80 |
def training_model(): |
81 |
if request.method == 'POST': |
82 |
r = request |
83 |
datas = json.loads(r.data) |
84 |
|
85 |
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', datas['name']) |
86 |
|
87 |
if os.path.isdir(data_path):
|
88 |
train.train(name=datas['name'], classes=datas['classes'], bigs=datas['bigs'], root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath( |
89 |
__file__)) + '\\..\\..\\symbol_training\\pre_trained_model\\only_params_trained_yolo_voc')
|
90 |
|
91 |
return jsonify({'count': 1}) |