개정판 cfdb92c4
issue #1366: upload training restful api
Change-Id: I2e51bca1194d1ad4d19c04008eb054656df54a9b
DTI_PID/WebServer/app/__init__.py | ||
---|---|---|
7 | 7 |
from app.training.index import training_service |
8 | 8 |
from app.api.controller.license_controller import api as license_ns |
9 | 9 |
from app.api.controller.license_controller import license_service |
10 |
from app.api.controller.TrainingController import api as training_ns |
|
11 |
from app.api.controller.TrainingController import TrainingModel |
|
10 | 12 |
from app.api.models import LicenseDTO |
11 | 13 | |
12 | 14 | |
... | ... | |
39 | 41 |
description='flask restx web service for ID2' |
40 | 42 |
) |
41 | 43 |
api.add_namespace(license_ns, path='/license') |
44 |
api.add_namespace(training_ns, path='/training') |
DTI_PID/WebServer/app/api/controller/TrainingController.py | ||
---|---|---|
1 |
import os |
|
2 |
import zipfile |
|
3 |
from flask import request, Blueprint, redirect, url_for, render_template, make_response, jsonify |
|
4 |
from flask_restx import Namespace, Resource, reqparse |
|
5 |
from werkzeug.datastructures import FileStorage |
|
6 | ||
7 |
api = Namespace('training', description='file related operations') |
|
8 |
upload_parser = api.parser() |
|
9 |
upload_parser.add_argument('project_name', type=str, required=True) |
|
10 |
upload_parser.add_argument('class_file', location='files', type=FileStorage, required=True) |
|
11 |
upload_parser.add_argument('training_file', location='files', type=FileStorage, required=True) |
|
12 | ||
13 |
training_parser = api.parser() |
|
14 |
training_parser.add_argument('project_name', type=str, required=True) |
|
15 |
training_parser.add_argument('class_name_file', location='files', type=FileStorage, required=True) |
|
16 | ||
17 | ||
18 |
def allowed_file(filename): |
|
19 |
return os.path.splitext(filename)[1].upper() == '.ZIP' |
|
20 | ||
21 | ||
22 |
@api.route('/upload_files') |
|
23 |
@api.expect(upload_parser) |
|
24 |
class FileStorage(Resource): |
|
25 |
def post(self): |
|
26 |
args = upload_parser.parse_args() |
|
27 |
project_name = args['project_name'] |
|
28 |
class_file = request.files['class_file'] |
|
29 |
training_file = request.files['training_file'] |
|
30 | ||
31 |
project_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', |
|
32 |
project_name) |
|
33 | ||
34 |
try: |
|
35 |
if not os.path.isdir(project_path): |
|
36 |
os.mkdir(project_path) |
|
37 |
os.mkdir(os.path.join(project_path, 'training')) |
|
38 |
os.mkdir(os.path.join(project_path, 'training', 'xml')) |
|
39 |
os.mkdir(os.path.join(project_path, 'training', 'img')) |
|
40 |
os.mkdir(os.path.join(project_path, 'test')) |
|
41 |
os.mkdir(os.path.join(project_path, 'test', 'xml')) |
|
42 |
os.mkdir(os.path.join(project_path, 'test', 'img')) |
|
43 |
except FileNotFoundError: |
|
44 |
return make_response('File Not Found', 404) |
|
45 | ||
46 |
if class_file and allowed_file(class_file.filename): |
|
47 |
class_file.save(class_file.filename) |
|
48 |
with zipfile.ZipFile(class_file.filename) as zip_file: |
|
49 |
zip_file.extractall(path=project_path) |
|
50 |
else: |
|
51 |
return False |
|
52 | ||
53 |
if training_file and allowed_file(training_file.filename): |
|
54 |
training_file.save(training_file.filename) |
|
55 |
with zipfile.ZipFile(training_file.filename) as zip_file: |
|
56 |
zip_file.extractall(path=project_path) |
|
57 |
else: |
|
58 |
return False |
|
59 | ||
60 |
return True |
|
61 | ||
62 | ||
63 |
@api.route('/training') |
|
64 |
@api.expect(training_parser) |
|
65 |
class TrainingModel(Resource): |
|
66 |
def post(self): |
|
67 |
args = training_parser.parse_args() |
|
68 |
project_name = args['project_name'] |
|
69 |
class_name_file = request.files['class_name_file'] |
|
70 | ||
71 |
project_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\', |
|
72 |
project_name) |
|
73 | ||
74 |
if class_name_file: |
|
75 |
classes = class_name_file.read().decode('ascii').splitlines() |
|
76 |
else: |
|
77 |
return False |
|
78 | ||
79 |
return jsonify({'count': 1}) |
DTI_PID/WebServer/app/templates/training/index.html | ||
---|---|---|
8 | 8 |
{{ form.csrf_token }} |
9 | 9 | |
10 | 10 |
<table border="0"> |
11 |
<tr><td>{{ form.training_description.label }}</td><td>{{ form.training_description() }}</td></tr>
|
|
11 |
<tr><td>{{ form.class_file.label }}</td><td>{{ form.class_file() }}</td></tr>
|
|
12 | 12 |
<tr><td>{{ form.training_file.label }}</td><td>{{ form.training_file() }}</td></tr> |
13 |
<tr><td><button class="btn btn-sm btn-success" type="submit">Add Recipe</button></td></tr>
|
|
13 |
<tr><td><button class="btn btn-sm btn-success" type="submit">Upload training data</button></td></tr>
|
|
14 | 14 |
</table> |
15 | 15 |
</form> |
16 | 16 |
{% endblock content %} |
DTI_PID/WebServer/app/training/index.py | ||
---|---|---|
23 | 23 | |
24 | 24 |
class TrainingForm(FlaskForm): |
25 | 25 |
csrf_token = 'id2 web service' |
26 |
training_description = StringField('Description') |
|
26 |
project_name = StringField('Project Name', required=True) |
|
27 |
class_file = FileField('Class File', validators=[FileRequired()]) |
|
27 | 28 |
training_file = FileField('Training File', validators=[FileRequired()]) |
28 | 29 | |
29 | 30 | |
... | ... | |
87 | 88 |
train.train(name=datas['name'], classes=datas['classes'], root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath( |
88 | 89 |
__file__)) + '\\..\\..\\symbol_training\\pre_trained_model\\only_params_trained_yolo_voc') |
89 | 90 |
|
90 |
return jsonify({'count': 1}) |
|
91 |
return jsonify({'count': 1}) |
내보내기 Unified diff