프로젝트

일반

사용자정보

개정판 cfdb92c4

IDcfdb92c4ad71d286a2f4887738c4320cf3ac2494
상위 180aec46
하위 bfd45a14

백흠경이(가) 약 5년 전에 추가함

issue #1366: upload training restful api

Change-Id: I2e51bca1194d1ad4d19c04008eb054656df54a9b

차이점 보기:

DTI_PID/WebServer/app/__init__.py
7 7
from app.training.index import training_service
8 8
from app.api.controller.license_controller import api as license_ns
9 9
from app.api.controller.license_controller import license_service
10
from app.api.controller.TrainingController import api as training_ns
11
from app.api.controller.TrainingController import TrainingModel
10 12
from app.api.models import LicenseDTO
11 13

  
12 14

  
......
39 41
          description='flask restx web service for ID2'
40 42
          )
41 43
api.add_namespace(license_ns, path='/license')
44
api.add_namespace(training_ns, path='/training')
DTI_PID/WebServer/app/api/controller/TrainingController.py
1
import os
2
import zipfile
3
from flask import request, Blueprint, redirect, url_for, render_template, make_response, jsonify
4
from flask_restx import Namespace, Resource, reqparse
5
from werkzeug.datastructures import FileStorage
6

  
7
api = Namespace('training', description='file related operations')
8
upload_parser = api.parser()
9
upload_parser.add_argument('project_name', type=str, required=True)
10
upload_parser.add_argument('class_file', location='files', type=FileStorage, required=True)
11
upload_parser.add_argument('training_file', location='files', type=FileStorage, required=True)
12

  
13
training_parser = api.parser()
14
training_parser.add_argument('project_name', type=str, required=True)
15
training_parser.add_argument('class_name_file', location='files', type=FileStorage, required=True)
16

  
17

  
18
def allowed_file(filename):
19
    return os.path.splitext(filename)[1].upper() == '.ZIP'
20

  
21

  
22
@api.route('/upload_files')
23
@api.expect(upload_parser)
24
class FileStorage(Resource):
25
    def post(self):
26
        args = upload_parser.parse_args()
27
        project_name = args['project_name']
28
        class_file = request.files['class_file']
29
        training_file = request.files['training_file']
30

  
31
        project_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\',
32
                                 project_name)
33

  
34
        try:
35
            if not os.path.isdir(project_path):
36
                os.mkdir(project_path)
37
                os.mkdir(os.path.join(project_path, 'training'))
38
                os.mkdir(os.path.join(project_path, 'training', 'xml'))
39
                os.mkdir(os.path.join(project_path, 'training', 'img'))
40
                os.mkdir(os.path.join(project_path, 'test'))
41
                os.mkdir(os.path.join(project_path, 'test', 'xml'))
42
                os.mkdir(os.path.join(project_path, 'test', 'img'))
43
        except FileNotFoundError:
44
            return make_response('File Not Found', 404)
45

  
46
        if class_file and allowed_file(class_file.filename):
47
            class_file.save(class_file.filename)
48
            with zipfile.ZipFile(class_file.filename) as zip_file:
49
                zip_file.extractall(path=project_path)
50
        else:
51
            return False
52

  
53
        if training_file and allowed_file(training_file.filename):
54
            training_file.save(training_file.filename)
55
            with zipfile.ZipFile(training_file.filename) as zip_file:
56
                zip_file.extractall(path=project_path)
57
        else:
58
            return False
59

  
60
        return True
61

  
62

  
63
@api.route('/training')
64
@api.expect(training_parser)
65
class TrainingModel(Resource):
66
    def post(self):
67
        args = training_parser.parse_args()
68
        project_name = args['project_name']
69
        class_name_file = request.files['class_name_file']
70

  
71
        project_path = os.path.join(os.path.dirname(os.path.realpath(__file__)) + '\\..\\..\\symbol_training\\Data\\',
72
                                    project_name)
73

  
74
        if class_name_file:
75
            classes = class_name_file.read().decode('ascii').splitlines()
76
        else:
77
            return False
78

  
79
        return jsonify({'count': 1})
DTI_PID/WebServer/app/templates/training/index.html
8 8
        {{ form.csrf_token }}
9 9

  
10 10
        <table border="0">
11
            <tr><td>{{ form.training_description.label }}</td><td>{{ form.training_description() }}</td></tr>
11
            <tr><td>{{ form.class_file.label }}</td><td>{{ form.class_file() }}</td></tr>
12 12
            <tr><td>{{ form.training_file.label }}</td><td>{{ form.training_file() }}</td></tr>
13
            <tr><td><button class="btn btn-sm btn-success" type="submit">Add Recipe</button></td></tr>
13
            <tr><td><button class="btn btn-sm btn-success" type="submit">Upload training data</button></td></tr>
14 14
        </table>
15 15
    </form>
16 16
{% endblock content %}
DTI_PID/WebServer/app/training/index.py
23 23

  
24 24
class TrainingForm(FlaskForm):
25 25
    csrf_token = 'id2 web service'
26
    training_description = StringField('Description')
26
    project_name = StringField('Project Name', required=True)
27
    class_file = FileField('Class File', validators=[FileRequired()])
27 28
    training_file = FileField('Training File', validators=[FileRequired()])
28 29

  
29 30

  
......
87 88
            train.train(name=datas['name'], classes=datas['classes'], root_path=data_path, pre_trained_model_path=os.path.dirname(os.path.realpath(
88 89
                                                       __file__)) + '\\..\\..\\symbol_training\\pre_trained_model\\only_params_trained_yolo_voc')
89 90
            
90
        return jsonify({'count': 1})
91
        return jsonify({'count': 1})

내보내기 Unified diff