프로젝트

일반

사용자정보

개정판 9be41199

ID9be411999eb9dcd875bf277f51364abc0c35f511
상위 0d19831e
하위 9c91a597, 150e63fa

함의성이(가) 4년 이상 전에 추가함

issue #1366: cu merge test

Change-Id: Ic21cb5de9fa1cff6045d257eb4f746a0357c26ca

차이점 보기:

DTI_PID/WebServer/symbol_recognition/src/utils.py
10 10
    items = list(zip(*batch))
11 11
    items[0] = default_collate(items[0])
12 12
    items[1] = list(items[1])
13
    items[2] = default_collate(items[2])
14

  
13 15
    return items
14 16

  
15 17

  
DTI_PID/WebServer/symbol_recognition/src/yolo_doftech.py
9 9
    def __init__(self, pre_model, num_classes,
10 10
                 anchors=[(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892), (9.47112, 4.84053),
11 11
                          (11.2364, 10.0071)]):
12
                 #anchors=[(0.7, 3.0), (3.0, 0.7), (1.3, 8.0), (8.0, 1.3), (3.0, 6.4), (6.4, 3.0), (6, 6), (1.5, 3.0), (3.0, 1.5)]):
13 12
        super(YoloD, self).__init__()
14 13

  
15

  
16 14
        self.num_classes = num_classes
17 15
        self.anchors = anchors
18 16

  
......
38 36
        self.stage2_a_conv5 = pre_model.stage2_a_conv5
39 37
        self.stage2_a_conv6 = pre_model.stage2_a_conv6
40 38
        self.stage2_a_conv7 = pre_model.stage2_a_conv7
41

  
42 39
        self.stage2_b_conv = pre_model.stage2_b_conv
43 40

  
44 41
        self.stage3_conv1 = pre_model.stage3_conv1
......
80 77
        output = torch.cat((output_1, output_2), 1)
81 78
        output = self.stage3_conv1(output)
82 79
        output = self.stage3_conv2(output)
80

  
83 81
        return output
84 82

  
85 83
if __name__ == '__main__':
DTI_PID/WebServer/symbol_recognition/src/yolo_net.py
1
"""
2
@author: Viet Nguyen <nhviet1009@gmail.com>
3
"""
4 1
import torch.nn as nn
5 2
import torch
6 3

  
7

  
8 4
class Yolo(nn.Module):
9 5
    def __init__(self, num_classes,
10 6
                 anchors=[(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892), (9.47112, 4.84053),
11 7
                          (11.2364, 10.0071)]):
12
                 #anchors=[(0.7, 3.0), (3.0, 0.7), (1.3, 8.0), (8.0, 1.3), (3.0, 6.4), (6.4, 3.0), (6, 6), (1.5, 3.0), (3.0, 1.5)]):
13 8
        super(Yolo, self).__init__()
14 9
        self.num_classes = num_classes
15 10
        self.anchors = anchors
......
42 37
                                           nn.LeakyReLU(0.1, inplace=True))
43 38

  
44 39
        self.stage2_a_maxpl = nn.MaxPool2d(2, 2)
40

  
45 41
        self.stage2_a_conv1 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1, bias=False),
46 42
                                            nn.BatchNorm2d(1024), nn.LeakyReLU(0.1, inplace=True))
47 43
        self.stage2_a_conv2 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1, 0, bias=False), nn.BatchNorm2d(512),
DTI_PID/WebServer/symbol_training/src/data_augmentation.py
5 5
from random import uniform
6 6
import cv2
7 7
from src.bbox_utils import *
8
from src.helpers import *
8 9

  
9 10
class Compose(object):
10 11

  
......
126 127
        return image, new_label
127 128

  
128 129
class Resize(object):
129

  
130 130
    def __init__(self, image_size):
131 131
        super().__init__()
132 132
        self.image_size = image_size
133

  
134 133
    def __call__(self, data):
135 134
        image, label = data
136 135
        height, width = image.shape[:2]
......
151 150
            resize_width = resized_xmax - resized_xmin
152 151
            resize_height = resized_ymax - resized_ymin
153 152
            new_label.append([resized_xmin, resized_ymin, resize_width, resize_height, lb[4]])
154

  
155 153
        return image, new_label
156 154

  
157
class Divide():
158 155

  
156
class ResizeImage(object):
159 157

  
160
    def __init__(self):
158
    def __init__(self, image_size):
161 159
        super().__init__()
162

  
160
        self.image_size = image_size
163 161

  
164 162
    def __call__(self, data):
165
        image, label = data
163
        image = data
164
        image = cv2.resize(image, (self.image_size, self.image_size))
165

  
166
        return image
166 167

  
167
        h, w = image.shape[:2]
168
        cx, cy = w // 2, h // 2
169 168

  
170
        #Resize(x2)
171
        image = cv2.resize(image, (w*2, h*2))
172
        resized_label = []
169
class Rotate(object):
170

  
171
    def __init__(self, angle, image_size):
172
        super().__init__()
173
        self.angle = angle
174
        self.image_size = image_size
175

  
176
    def rotateYolobbox(self, new_image, images, objects, rot_matrix):
177
        new_height, new_width = new_image.shape[:2]
178
        objects = objects
179
        new_bbox = []
180
        H, W = images.shape[:2]
181

  
182
        for x in objects:
183
            bbox = x[:4]
184
            if len(bbox) > 1:
185
                center_x = float((bbox[2]-bbox[0])/2 + bbox[0])/W
186
                center_y = float((bbox[3]-bbox[1])/2 + bbox[1])/H
187
                bbox_width = float((bbox[2]-bbox[0]))/W
188
                bbox_height = float((bbox[3]-bbox[1]))/H
189

  
190
                (center_x, center_y, bbox_width, bbox_height) = yoloFormattocv(center_x, center_y, bbox_width, bbox_height, H, W)
191

  
192
                upper_left_corner_shift = (center_x - W / 2, -H / 2 + center_y)
193
                upper_right_corner_shift = (bbox_width - W / 2, -H / 2 + center_y)
194
                lower_left_corner_shift = (center_x - W / 2, -H / 2 + bbox_height)
195
                lower_right_corner_shift = (bbox_width - W / 2, -H / 2 + bbox_height)
196

  
197
                new_lower_right_corner = [-1, -1]
198
                new_upper_left_corner = []
199

  
200
                for i in (upper_left_corner_shift, upper_right_corner_shift, lower_left_corner_shift,
201
                          lower_right_corner_shift):
202
                    new_coords = np.matmul(rot_matrix, np.array((i[0], -i[1])))
203
                    x_prime, y_prime = new_width / 2 + new_coords[0], new_height / 2 - new_coords[1]
204
                    if new_lower_right_corner[0] < x_prime:
205
                        new_lower_right_corner[0] = x_prime
206
                    if new_lower_right_corner[1] < y_prime:
207
                        new_lower_right_corner[1] = y_prime
208

  
209
                    if len(new_upper_left_corner) > 0:
210
                        if new_upper_left_corner[0] > x_prime:
211
                            new_upper_left_corner[0] = x_prime
212
                        if new_upper_left_corner[1] > y_prime:
213
                            new_upper_left_corner[1] = y_prime
214
                    else:
215
                        new_upper_left_corner.append(x_prime)
216
                        new_upper_left_corner.append(y_prime)
217
                #
218
                new_bbox.append([new_upper_left_corner[0], new_upper_left_corner[1],
219
                                 new_lower_right_corner[0], new_lower_right_corner[1], x[4]])
220
                # new_bbox.append([new_upper_left_corner[0], new_upper_left_corner[1],
221
                #                  new_lower_right_corner[0]-new_upper_left_corner[0], new_lower_right_corner[1]-new_upper_left_corner[1], x[4]])
222

  
223
        return new_bbox
224

  
225
    def rotate_image(self, images, angle):
226
        """
227
        Rotates an image (angle in degrees) and expands image to avoid cropping
228
        """
229
        height, width = images.shape[:2]  # image shape has 3 dimensions
230
        image_center = (width / 2,
231
                        height / 2)  # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
232

  
233
        rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)
234

  
235
        # rotation calculates the cos and sin, taking absolutes of those.
236
        abs_cos = abs(rotation_mat[0, 0])
237
        abs_sin = abs(rotation_mat[0, 1])
238

  
239
        # find the new width and height bounds
240
        bound_w = int(height * abs_sin + width * abs_cos)
241
        bound_h = int(height * abs_cos + width * abs_sin)
242

  
243
        # subtract old image center (bringing image back to origin) and adding the new image center coordinates
244
        rotation_mat[0, 2] += bound_w / 2 - image_center[0]
245
        rotation_mat[1, 2] += bound_h / 2 - image_center[1]
246

  
247
        # rotate image with the new bounds and translated rotation matrix
248
        rotated_mat = cv2.warpAffine(images, rotation_mat, (bound_w, bound_h))
249

  
250
        return rotated_mat
251

  
252
    def resize(self, image, label):
253
        height, width = image.shape[:2]
254
        new_image = cv2.resize(image, (self.image_size, self.image_size))
255
        width_ratio = float(self.image_size) / width
256
        height_ratio = float(self.image_size) / height
257
        new_label = []
173 258
        for lb in label:
174
            resized_xmin = lb[0] * 2
175
            resized_ymin = lb[1] * 2
176
            resized_xmax = lb[2] * 2
177
            resized_ymax = lb[3] * 2
259
            if int(lb[0]) < 0: lb[0] = 0.0
260
            if int(lb[1]) < 0: lb[1] = 0.0
261
            if int(lb[2]) < 0: lb[2] = 0.0
262
            if int(lb[3]) < 0: lb[3] = 0.0
263

  
264
            resized_xmin = lb[0] * width_ratio #+ width_ratio
265
            resized_ymin = lb[1] * height_ratio #+ height_ratio
266
            resized_xmax = lb[2] * width_ratio #+ width_ratio
267
            resized_ymax = lb[3] * height_ratio #+height_ratio
178 268
            resize_width = resized_xmax - resized_xmin
179 269
            resize_height = resized_ymax - resized_ymin
180
            resized_label.append([resized_xmin, resized_ymin, resize_width, resize_height, lb[4]])
181

  
182
        #Divide(/4)
183
        new_image = []
270
            new_label.append([resized_xmin, resized_ymin, resize_width, resize_height, lb[4]])
184 271

  
185
        for num in 4:
186
            copy_image = image.copy()
272
        return new_image, new_label
187 273

  
188
            if num==1:
189
                copy_image = image[0:cx, 0:cy]
190
            elif num==2:
191
                copy_image = image[cx:w, 0:cy]
192
            elif num==3:
193
                copy_image = image[0:cx, cy:h]
194
            else:
195
                copy_image = image[cx:w, cy:h]
274
    def __call__(self, data):
275
        self.images, self.objects = data
276
        rotation_angle = self.angle * np.pi / 180
277
        self.rot_matrix = np.array(
278
            [[np.cos(rotation_angle), -np.sin(rotation_angle)], [np.sin(rotation_angle), np.cos(rotation_angle)]])
196 279

  
197
            new_image.append(copy_image)
280
        new_image = self.rotate_image(self.images, self.angle)
198 281

  
199
        new_label = []
200
        #label 나누기!
282
        new_label = self.rotateYolobbox(new_image, self.images, self.objects, self.rot_matrix)
283
        #
284
        new_image, new_label = self.resize(new_image, new_label)
201 285

  
202 286
        return new_image, new_label
287

  
DTI_PID/WebServer/symbol_training/src/doftech_dataset.py
4 4
from src.data_augmentation import *
5 5
import glob
6 6
import random
7
import torch
8

  
7 9

  
8 10
class DoftechDataset(Dataset):
9 11
    def __init__(self, root_path="./data/", image_size=448, is_training=True, classes=None):
10
        self.data_path = root_path
11
        self.img_list = sorted(glob.glob(os.path.join(self.data_path, 'img/*.png')))
12
        self.anno_list = sorted(glob.glob(os.path.join(self.data_path, 'xml/*.xml')))
13 12

  
14 13
        self.classes = classes
15 14

  
16 15
        self.image_size = image_size
17 16
        self.num_classes = len(self.classes)
18
        self.num_images = len(self.img_list)
19 17
        self.is_training = is_training
18

  
20 19
        self.use_rescale = True
21
        self.use_rotation = False
20
        self.use_rotation = True
21

  
22
        self.data_path = root_path
23
        self.img_path = root_path+'/images/'
24
        self.xml_path = root_path + '/xml/'
25
        self.img_path_list = sorted([item for sublist in [glob.glob(self.img_path + ext) for ext in ["*.jpg", "*.png"]] for item in sublist])#sorted(glob.glob(os.path.join(self.img_path, '*.jpg')))
26

  
27
        self.img_list = []
28
        for img_path in self.img_path_list:
29
            self.img_list.append(cv2.imread(img_path))
30

  
31
        self.img_list_new = []
32
        self.anno_list_new = []
33

  
34
        self.anno_path_list = sorted([item for sublist in [glob.glob(self.xml_path + ext) for ext in ["*.xml"]] for item in sublist])#sorted(glob.glob(os.path.join(self.xml_path, '*.xml')))
35
        self.anno_list = []
36

  
37
        for anno_path in self.anno_path_list:
38
            annot = ET.parse(anno_path)
39
            objects = []
40
            for obj in annot.findall('object'):
41
                xmin, xmax, ymin, ymax = [int(obj.find('bndbox').find(tag).text) - 1 for tag in
42
                                          ["xmin", "xmax", "ymin", "ymax"]]
43
                if obj.find('name').text in self.classes:                     
44
                    #label = self.classes.index(obj.find('name').text.lower().strip())
45
                    label = self.classes.index(obj.find('name').text)
46
                    objects.append([xmin, ymin, xmax, ymax, label])
47
            self.anno_list.append(objects)
48

  
49
        if self.is_training==True:
50
            if self.use_rescale == True:
51
                new_image_list, new_label_list = self.rescale_func(self.img_list, self.anno_list)
52

  
53
                self.img_list.extend(new_image_list)
54
                self.anno_list.extend(new_label_list)
55

  
56
    def rescale_func(self, img_list, anno_list):
57
        scale_array = [0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
58

  
59
        new_image_list = []
60
        new_label_list = []
61

  
62
        for idx, img in enumerate(img_list):
63
            objects = anno_list[idx]
64
            height, width = img.shape[:2]
65
            cut_range = 0
66
            for obj in objects:
67
                new_label = []
68
                if obj[0] - cut_range > 0 and obj[1] - cut_range > 0 and obj[2] + cut_range > 0 and obj[
69
                    3] + cut_range > 0:
70
                    random_int = random.randint(0, 6)
71

  
72
                    cut_xmin = obj[0] - cut_range
73
                    cut_ymin = obj[1] - cut_range
74
                    cut_xmax = obj[2] + cut_range
75
                    cut_ymax = obj[3] + cut_range
76

  
77
                    box_width = obj[2] - obj[0]
78
                    box_height = obj[3] - obj[1]
79

  
80
                    cut_height = cut_ymax - cut_ymin
81
                    cut_width = cut_xmax - cut_xmin
82
                    cut_image = img[cut_ymin:cut_ymin + cut_height, cut_xmin:cut_xmin + cut_width]
83

  
84
                    # Rescale
85
                    rescale_size_height = int((cut_height - cut_range) * scale_array[random_int])
86
                    rescale_size_width = int((cut_width - cut_range) * scale_array[random_int])
87
                    scale_image = cv2.resize(cut_image, (rescale_size_width, rescale_size_height))
88
                    scale_height, scale_width = scale_image.shape[:2]
89

  
90
                    width_ratio = float(rescale_size_width) / (cut_width)
91
                    height_ratio = float(rescale_size_height) / (cut_height)
92

  
93
                    resized_xmin = cut_range * width_ratio
94
                    resized_ymin = cut_range * height_ratio
95
                    resized_xmax = (cut_range + box_width) * width_ratio
96
                    resized_ymax = (cut_range + box_height) * height_ratio
97
                    resize_width = resized_xmax - resized_xmin
98
                    resize_height = resized_ymax - resized_ymin
99

  
100
                    # Calculate new bounding box
101
                    image_center_x, image_center_y = int(width / 2), int(height / 2)
102

  
103
                    bbox_xmin = image_center_x - (resize_width / 2)
104
                    bbox_ymin = image_center_y - (resize_height / 2)
105
                    bbox_xmax = bbox_xmin + resize_width
106
                    bbox_ymax = bbox_ymin + resize_height
107

  
108
                    # Image copy to new image
109
                    copy_range_xmin = int((width / 2) - (scale_width / 2))
110
                    copy_range_ymin = int((height / 2) - (scale_height / 2))
111

  
112
                    scale_new_image = np.zeros((height, width, 3), np.uint8)
113
                    cv2.rectangle(scale_new_image, (0, 0), (width, height), (255, 255, 255), -1)
114

  
115
                    scale_new_image[copy_range_ymin:copy_range_ymin + scale_height,
116
                    copy_range_xmin:copy_range_xmin + scale_width] = scale_image
117

  
118
                    new_label.append(([int(bbox_xmin), int(bbox_ymin), int(bbox_xmax), int(bbox_ymax), obj[4]]))
119
                    new_image_list.append(scale_new_image)
120
                    new_label_list.append(new_label)
121

  
122
        return new_image_list, new_label_list
22 123

  
23 124
    def __len__(self):
24
        return self.num_images
125
        return len(self.img_list)
25 126

  
26 127
    def __getitem__(self, item):
27
        image_path = os.path.join(self.img_list[item])
28
        image = cv2.imread(image_path)
29
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
30
        image_xml_path = os.path.join(self.anno_list[item])
31
        annot = ET.parse(image_xml_path)
32
        #0.7, 0.8, 0.9, 1.0, 1,1 1,2, 1..3
33
        objects = []
34
        for obj in annot.findall('object'):
35
            xmin, xmax, ymin, ymax = [int(obj.find('bndbox').find(tag).text) - 1 for tag in
36
                                      ["xmin", "xmax", "ymin", "ymax"]]
37
            #if obj.find('name').text.lower().strip() in self.classes:
38
            if obj.find('name').text in self.classes:
39
                #label = self.classes.index(obj.find('name').text.lower().strip())
40
                label = self.classes.index(obj.find('name').text)
41
                objects.append([xmin, ymin, xmax, ymax, label])
42
            continue
128
        origin_image = self.img_list[item]
129
        objects = self.anno_list[item]
130
        origin_image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)
43 131
        if self.is_training:
44
            if self.use_rescale:
45
                scale_array = [0.7,0.8,0.9,1.0,1.1,1.2,1.3]
46
                random_int = random.randint(0,6)
47
                height, width = image.shape[:2]
48
                rescale_size = int(height * scale_array[random_int])
49
                transformations = Compose([Rescale(rescale_size), Resize(self.image_size)])
50 132
            if self.use_rotation:
51 133
                angle_array = [0,30,45,60,90,120,135,150,180,210,240,255,270,300,330,345,360]
52 134
                random_int = random.randint(0,16)
135
                transformations = Compose([Rotate(angle_array[random_int], self.image_size)])
53 136
            else :
54 137
                transformations = Compose([Resize(self.image_size)])
55 138

  
56 139
        else:
57 140
            transformations = Compose([Resize(self.image_size)])
58 141

  
59
        image, objects = transformations((image, objects))
60
        return np.transpose(np.array(image, dtype=np.float32), (2, 0, 1)), np.array(objects, dtype=np.float32)
61

  
62
class DoftechDatasetTest(Dataset):
63
    def __init__(self, root_path="./data/", image_size=448, is_training=True, classes=None):
64
        self.data_path = root_path
65
        self.img_list = sorted(glob.glob(os.path.join(self.data_path, 'img/*.png')))
66
        self.anno_list = sorted(glob.glob(os.path.join(self.data_path, 'xml/*.xml')))
67

  
68
        self.classes = classes
142
        image, objects = transformations((origin_image, objects))
69 143

  
70
        self.image_size = image_size
71
        self.num_classes = len(self.classes)
72
        self.num_images = len(self.img_list)
73
        self.is_training = is_training
144
        resize_transform = ResizeImage(self.image_size)
74 145

  
75
    def __len__(self):
76
        return self.num_images
146
        image = np.transpose(np.array(image, dtype=np.float32), (2, 0, 1))
147
        objects = np.array(objects, dtype=np.float32)
77 148

  
78
    def __getitem__(self, item):
79
        image_path = os.path.join(self.img_list[item])
80
        image = cv2.imread(image_path)
81
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
82
        image_xml_path = os.path.join(self.anno_list[item])
83
        annot = ET.parse(image_xml_path)
84

  
85
        objects = []
86
        for obj in annot.findall('object'):
87
            xmin, xmax, ymin, ymax = [int(obj.find('bndbox').find(tag).text) - 1 for tag in
88
                                      ["xmin", "xmax", "ymin", "ymax"]]
89
            #if obj.find('name').text.lower().strip() in self.classes:
90
            if obj.find('name').text in self.classes:
91
                #label = self.classes.index(obj.find('name').text.lower().strip())
92
                label = self.classes.index(obj.find('name').text)
93
                objects.append([xmin, ymin, xmax, ymax, label])
94
            continue
95
        if self.is_training:
96
            transformations = Compose([Resize(self.image_size)])
97
        else:
98
            transformations = Compose([Resize(self.image_size)])
99
        image, objects = transformations((image, objects))
100
        return np.transpose(np.array(image, dtype=np.float32), (2, 0, 1)), np.array(objects, dtype=np.float32)
149
        return image, objects, np.transpose(np.array(resize_transform(origin_image), dtype=np.float32), (2, 0, 1))
DTI_PID/WebServer/symbol_training/src/utils.py
10 10
    items = list(zip(*batch))
11 11
    items[0] = default_collate(items[0])
12 12
    items[1] = list(items[1])
13
    items[2] = default_collate(items[2])
14

  
13 15
    return items
14 16

  
15 17

  
DTI_PID/WebServer/symbol_training/src/yolo_doftech.py
9 9
    def __init__(self, pre_model, num_classes,
10 10
                 anchors=[(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892), (9.47112, 4.84053),
11 11
                          (11.2364, 10.0071)]):
12
                 #anchors=[(0.7, 3.0), (3.0, 0.7), (1.3, 8.0), (8.0, 1.3), (3.0, 6.4), (6.4, 3.0), (4.5, 4.5), (1.5, 3.0), (3.0, 1.5)]):
13 12
        super(YoloD, self).__init__()
14 13

  
15 14
        self.num_classes = num_classes
DTI_PID/WebServer/symbol_training/src/yolo_net.py
1
"""
2
@author: Viet Nguyen <nhviet1009@gmail.com>
3
"""
4 1
import torch.nn as nn
5 2
import torch
6 3

  
7

  
8 4
class Yolo(nn.Module):
9 5
    def __init__(self, num_classes,
10 6
                 anchors=[(1.3221, 1.73145), (3.19275, 4.00944), (5.05587, 8.09892), (9.47112, 4.84053),
......
41 37
                                           nn.LeakyReLU(0.1, inplace=True))
42 38

  
43 39
        self.stage2_a_maxpl = nn.MaxPool2d(2, 2)
40

  
44 41
        self.stage2_a_conv1 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1, bias=False),
45 42
                                            nn.BatchNorm2d(1024), nn.LeakyReLU(0.1, inplace=True))
46 43
        self.stage2_a_conv2 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1, 0, bias=False), nn.BatchNorm2d(512),
DTI_PID/WebServer/symbol_training/train.py
29 29
                  'strainer_basket', 'strainer_conical', 'fitting_capillary_tubing', 'meter_ultrasonic', 'strainer_y', 'tube_pitot'
30 30
                  ,'opc']
31 31

  
32
#print(len(DOFTECH_CLASSES))
32
use_voc_model = True
33
use_visdom = False
34
if use_visdom
35
    visdom = visdom.Visdom()
33 36

  
34 37
def train(name=None, classes=None, bigs=None, root_path=None, pre_trained_model_path=None):
35 38
    global DOFTECH_CLASSES
......
81 84
        torch.cuda.manual_seed(123)
82 85
    else:
83 86
        torch.manual_seed(123)
87

  
84 88
    learning_rate_schedule = {"0": 1e-5, "5": 1e-4,
85 89
                              "80": 1e-5, "110": 1e-6}
86 90

  
87
    training_params = {"batch_size": opt.batch_size,
91
    training_params = {"batch_size": 1,#opt.batch_size,
88 92
                       "shuffle": True,
89 93
                       "drop_last": True,
90 94
                       "collate_fn": custom_collate_fn}
......
97 101
    training_set = DoftechDataset(opt.data_path, opt.image_size, is_training=True, classes=DOFTECH_CLASSES)
98 102
    training_generator = DataLoader(training_set, **training_params)
99 103

  
100
    test_set = DoftechDatasetTest(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
104
    test_set = DoftechDataset(opt.data_path_test, opt.image_size, is_training=False, classes=DOFTECH_CLASSES)
101 105
    test_generator = DataLoader(test_set, **test_params)
102 106

  
103
    pre_model = Yolo(20).cuda()
104
    pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
105

  
106
    model = YoloD(pre_model, training_set.num_classes).cuda()
107
    # BUILDING MODEL =======================================================================
108
    if use_voc_model :
109
        pre_model = Yolo(20).cuda()
110
        pre_model.load_state_dict(torch.load(opt.pre_trained_model_path), strict=False)
111
        model = YoloD(pre_model, training_set.num_classes).cuda()
112
    else :
113
        model = Yolo(training_set.num_classes).cuda()
107 114

  
108 115
    nn.init.normal_(list(model.modules())[-1].weight, 0, 0.01)
109 116

  
......
112 119

  
113 120
    best_loss = 1e10
114 121
    best_epoch = 0
115
    model.train()
116 122
    num_iter_per_epoch = len(training_generator)
117

  
118 123
    loss_step = 0
124
    # ======================================================================================
119 125

  
126
    # TRAINING =============================================================================
120 127
    save_count = 0
121 128

  
129
    model.train()
122 130
    for epoch in range(opt.num_epoches):
123 131
        if str(epoch) in learning_rate_schedule.keys():
124 132
            for param_group in optimizer.param_groups:
125 133
                param_group['lr'] = learning_rate_schedule[str(epoch)]
126 134

  
127 135
        for iter, batch in enumerate(training_generator):
128
            image, label = batch
136
            image, label, image2 = batch
137
            image = Variable(image.cuda(), requires_grad=False)
129 138
            if torch.cuda.is_available():
130 139
                image = Variable(image.cuda(), requires_grad=True)
140
                origin = Variable(image2.cuda(), requires_grad=False)
131 141
            else:
132 142
                image = Variable(image, requires_grad=True)
133 143

  
......
135 145
            logits = model(image)
136 146
            loss, loss_coord, loss_conf, loss_cls = criterion(logits, label)
137 147
            loss.backward()
138

  
139 148
            optimizer.step()
140 149

  
141 150
            if iter % opt.test_interval == 0:
......
143 152
                    (epoch + 1, opt.num_epoches, iter + 1, num_iter_per_epoch, optimizer.param_groups[0]['lr'], loss,
144 153
                    loss_coord,loss_conf,loss_cls))
145 154

  
146
                predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
147
                                              opt.nms_threshold)
155
                if use_visdom:
156
                    predictions = post_processing(logits, opt.image_size, DOFTECH_CLASSES, model.anchors, opt.conf_threshold,
157
                                                  opt.nms_threshold)
148 158

  
149
                gt_image = at.tonumpy(image[0])
150
                gt_image = visdom_bbox(gt_image, label[0])
151
                #visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3)
159
                    gt_image = at.tonumpy(image[0])
160
                    gt_image = visdom_bbox(gt_image, label[0])
161
                    visdom.image(gt_image, opts=dict(title='gt_box_image'), win=3)
162
                    #
163
                    origin_image = at.tonumpy(origin[0])
164
                    origin_image = visdom_bbox(origin_image, [])
165
                    visdom.image(origin_image, opts=dict(title='origin_box_image'), win=4)
152 166

  
153
                if len(predictions) != 0:
154 167
                    image = at.tonumpy(image[0])
155
                    box_image = visdom_bbox(image, predictions[0])
156
                    #visdom.image(box_image, opts=dict(title='box_image'), win=2)
157 168

  
158
                elif len(predictions) == 0:
159
                    box_image = tensor2im(image)
160
                    #visdom.image(box_image.transpose([2, 0, 1]), opts=dict(title='box_image'), win=2)
169
                    if len(predictions) != 0:
170
                        box_image = visdom_bbox(image, predictions[0])
171
                        visdom.image(box_image, opts=dict(title='box_image'), win=2)
172

  
173
                    elif len(predictions) == 0:
174
                        box_image = visdom_bbox(image, [])
175
                        visdom.image(box_image, opts=dict(title='box_image'), win=2)
161 176

  
162
                loss_dict = {
163
                    'total' : loss.item(),
164
                    'coord' : loss_coord.item(),
165
                    'conf' : loss_conf.item(),
166
                    'cls' : loss_cls.item()
167
                }
177
                    loss_dict = {
178
                        'total' : loss.item(),
179
                        'coord' : loss_coord.item(),
180
                        'conf' : loss_conf.item(),
181
                        'cls' : loss_cls.item()
182
                    }
168 183

  
169
                #visdom_loss(visdom, loss_step, loss_dict)
170
                loss_step = loss_step + 1
184
                    visdom_loss(visdom, loss_step, loss_dict)
185
                    loss_step = loss_step + 1
171 186

  
172 187
        if epoch % opt.test_interval == 0:
173 188
            model.eval()
......
232 247
        update='append'
233 248
    )
234 249

  
235
def tensor2im(image_tensor, imtype=np.uint8):
236

  
237
    image_numpy = image_tensor[0].detach().cpu().float().numpy()
238

  
239
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
240

  
241
    image_numpy = np.clip(image_numpy, 0, 255)
242

  
243
    return image_numpy.astype(imtype)
244

  
245
def denormalize(tensors):
246
    """ Denormalizes image tensors using mean and std """
247
    mean = np.array([0.5, 0.5, 0.5])
248
    std = np.array([0.5, 0.5, 0.5])
249

  
250
    # mean = np.array([0.47571, 0.50874, 0.56821])
251
    # std = np.array([0.10341, 0.1062, 0.11548])
252

  
253
    denorm = tensors.clone()
254

  
255
    for c in range(tensors.shape[1]):
256
        denorm[:, c] = denorm[:, c].mul_(std[c]).add_(mean[c])
257

  
258
    denorm = torch.clamp(denorm, 0, 255)
259

  
260
    return denorm
261

  
262 250
if __name__ == "__main__":
263 251
    datas = ['gate', 'globe', 'butterfly', 'check', 'ball', 'relief',
264 252
                  '3way_solenoid', 'gate_pressure', 'globe_pressure', 'butterfly_pressure', 'ball_shutoff', 'ball_pressure','ball_motor', 'plug_pressure',

내보내기 Unified diff

클립보드 이미지 추가 (최대 크기: 500 MB)