hytos / DTI_PID / WebServer / app / training / index.py @ dd6d4de9
이력 | 보기 | 이력해설 | 다운로드 (3.23 KB)
1 | 5f6cb499 | humkyung | # file name : index.py
|
---|---|---|---|
2 | # pwd : /project_name/app/license/index.py
|
||
3 | |||
4 | 6374c2c6 | esham21 | from flask import Blueprint, request, render_template, flash, redirect, url_for, jsonify |
5 | 5f6cb499 | humkyung | from flask_wtf import FlaskForm |
6 | from flask_wtf.file import FileField, FileAllowed, FileRequired |
||
7 | from wtforms import * |
||
8 | from werkzeug.utils import secure_filename |
||
9 | from wtforms.validators import * |
||
10 | |||
11 | 6374c2c6 | esham21 | import json, base64 |
12 | import cv2 |
||
13 | import sys, os |
||
14 | import numpy as np |
||
15 | |||
16 | # training
|
||
17 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training') |
||
18 | |||
19 | 5f6cb499 | humkyung | training_service = Blueprint('training', __name__, url_prefix='/training') |
20 | |||
21 | 6374c2c6 | esham21 | import train |
22 | |||
23 | 5f6cb499 | humkyung | |
24 | class TrainingForm(FlaskForm): |
||
25 | csrf_token = 'id2 web service'
|
||
26 | cfdb92c4 | humkyung | project_name = StringField('Project Name', required=True) |
27 | class_file = FileField('Class File', validators=[FileRequired()])
|
||
28 | 5f6cb499 | humkyung | training_file = FileField('Training File', validators=[FileRequired()])
|
29 | |||
30 | |||
31 | @training_service.route('/add', methods=['GET', 'POST']) |
||
32 | def add_training(): |
||
33 | form = TrainingForm() |
||
34 | if request.method == 'POST': |
||
35 | if form.validate_on_submit():
|
||
36 | f = request.files[form.training_file.name] |
||
37 | f.save(f.filename) |
||
38 | else:
|
||
39 | file_url = None
|
||
40 | |||
41 | return render_template("/training/index.html", form=form) |
||
42 | |||
43 | 6374c2c6 | esham21 | |
44 | @training_service.route('/upload_training_data', methods=['POST']) |
||
45 | def upload_training_data(): |
||
46 | if request.method == 'POST': |
||
47 | r = request |
||
48 | datas = json.loads(r.data) |
||
49 | |||
50 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', datas['name']) |
||
51 | |||
52 | if not os.path.isdir(data_path): |
||
53 | os.mkdir(data_path) |
||
54 | os.mkdir(os.path.join(data_path, 'training'))
|
||
55 | os.mkdir(os.path.join(data_path, 'training', 'xml')) |
||
56 | os.mkdir(os.path.join(data_path, 'training', 'img')) |
||
57 | os.mkdir(os.path.join(data_path, 'test'))
|
||
58 | os.mkdir(os.path.join(data_path, 'test', 'xml')) |
||
59 | os.mkdir(os.path.join(data_path, 'test', 'img')) |
||
60 | |||
61 | count = 0
|
||
62 | |||
63 | for name, str_img in datas['tiles']: |
||
64 | str_img = base64.b64decode(str_img) |
||
65 | nparr = np.fromstring(str_img, np.uint8) |
||
66 | img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) |
||
67 | |||
68 | cv2.imwrite(os.path.join(data_path, name), img) |
||
69 | count += 1
|
||
70 | |||
71 | for name, xml in datas['xmls']: |
||
72 | #rect = base64.b64decode(str_img)
|
||
73 | with open(os.path.join(data_path, name), 'w') as stream: |
||
74 | stream.write(xml) |
||
75 | count += 1
|
||
76 | |||
77 | return jsonify({'count': count}) |
||
78 | |||
79 | @training_service.route('/training_model', methods=['POST']) |
||
80 | def training_model(): |
||
81 | if request.method == 'POST': |
||
82 | r = request |
||
83 | datas = json.loads(r.data) |
||
84 | |||
85 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', datas['name']) |
||
86 | |||
87 | 76986eb0 | esham21 | if os.path.isdir(data_path):
|
88 | 1e8ea226 | esham21 | train.train(name=datas['name'], classes=datas['classes'], bigs=datas['bigs'], root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath( |
89 | 6374c2c6 | esham21 | __file__)) + '\\..\\..\\symbol_training\\pre_trained_model\\only_params_trained_yolo_voc')
|
90 | |||
91 | cfdb92c4 | humkyung | return jsonify({'count': 1}) |