Revision 072a26bd
issue #1336: package test
DTI_PID/WebServer/run.py
@@ -1,5 +1,5 @@
 from app import app
 
 if __name__ == '__main__':
-    app.run(port=8080, debug=True, host='0.0.0.0')
-    #app.run(debug=False)
+    #app.run(port=8080, debug=False)#, host='0.0.0.0')
+    app.run(debug=False)
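For context, this change flips the Flask entry point from a debug-mode server exposed on 0.0.0.0:8080 to Flask's defaults (debug off, local interface only), which suits the package test. A minimal sketch of an equivalent entry point, assuming a hypothetical environment-variable toggle rather than editing run.py per build:

# Sketch only: the DTI_PID_DEBUG switch is an assumption for illustration;
# the actual run.py simply hard-codes app.run(debug=False) after this revision.
import os

from flask import Flask

app = Flask(__name__)  # stand-in for the project's `from app import app`

if __name__ == '__main__':
    if os.environ.get('DTI_PID_DEBUG') == '1':
        # development: debugger/reloader on, reachable from other hosts
        app.run(port=8080, debug=True, host='0.0.0.0')
    else:
        # packaged build: Flask defaults (127.0.0.1:5000, debug off)
        app.run(debug=False)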
DTI_PID/WebServer/symbol_training/src/doftech_dataset.py
@@ -51,20 +51,22 @@
         img_path = self.img_path_list[index]
         self.img_list.append(cv2.imread(img_path))
 
+        '''
         if self.is_training==True:
             if self.use_rescale == True:
                 new_image_list, new_label_list = self.rescale_func([self.img_list[-1]], [self.anno_list[-1]])
 
                 self.img_list.extend(new_image_list)
                 self.anno_list.extend(new_label_list)
-        '''
+        '''
+
         if self.is_training==True:
             if self.use_rescale == True:
                 new_image_list, new_label_list = self.rescale_func(self.img_list, self.anno_list)
 
                 self.img_list.extend(new_image_list)
                 self.anno_list.extend(new_label_list)
-        '''
+
 
     def rescale_func(self, img_list, anno_list):
         #scale_array = [0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
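The triple-quote swap above moves the comment block from the second rescale branch to the first one: rescale augmentation now runs over the full img_list/anno_list instead of only the last-loaded sample ([self.img_list[-1]], [self.anno_list[-1]]). rescale_func's body is not part of this revision, so the sketch below assumes a simple cv2.resize-based implementation and a box-list annotation format purely for illustration:

# Sketch only: a hypothetical rescale_func is assumed here; the scale choices
# echo the commented-out scale_array shown in the diff.
import random

import cv2


def rescale_func(img_list, anno_list, scales=(0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3)):
    """Return rescaled copies of the given images and matching annotations."""
    new_imgs, new_annos = [], []
    for img, anno in zip(img_list, anno_list):
        s = random.choice(scales)
        h, w = img.shape[:2]
        new_imgs.append(cv2.resize(img, (int(w * s), int(h * s))))
        # assumption: each annotation is a list of [x1, y1, x2, y2, class_id] boxes
        new_annos.append([[x1 * s, y1 * s, x2 * s, y2 * s, c]
                          for x1, y1, x2, y2, c in anno])
    return new_imgs, new_annos


# Old behaviour (now commented out): augment only the most recent sample.
# new_imgs, new_annos = rescale_func([img_list[-1]], [anno_list[-1]])
# New behaviour (active after this revision): augment the whole accumulated list.
# new_imgs, new_annos = rescale_func(img_list, anno_list)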
DTI_PID/WebServer/symbol_training/train.py
@@ -88,7 +88,7 @@
     learning_rate_schedule = {"0": 1e-5, "5": 1e-4,
                               "80": 1e-5, "110": 1e-6}
 
-    training_params = {"batch_size": opt.batch_size,
+    training_params = {"batch_size": 1,#opt.batch_size,
                        "shuffle": True,
                        "drop_last": True,
                        "collate_fn": custom_collate_fn}
@@ -217,7 +217,7 @@
             best_epoch = epoch
             print("SAVE MODEL")
             # for debug for each loss
-            torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
+            #torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
             #torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(best_loss) + ".pth"))
             # save
             torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params.pth"))
@@ -225,7 +225,7 @@
         else:
             save_count += 1
             # for debug for each loss
-            torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
+            #torch.save(model.state_dict(), os.path.join(opt.saved_path, name + "_only_params_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
             #torch.save(model, os.path.join(opt.saved_path, name + "_whole_model_" + str(save_count) + "_" + "{:.5f}".format(te_loss) + ".pth"))
 
         # Early stopping
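Net effect of the train.py changes: the batch size is pinned to 1 for the package test, and the loss-tagged debug checkpoints (file names carrying save_count and the loss value) are commented out, so only the fixed-name state-dict checkpoint name + "_only_params.pth" is still written. A minimal sketch of that remaining save/load round trip, using a hypothetical tiny model in place of the project's detector:

# Sketch only: TinyNet and the paths below are placeholders, not the project's
# actual model or opt.saved_path.
import os

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


saved_path = "./trained_models"   # stands in for opt.saved_path
name = "doftech"                  # stands in for the symbol-set name
os.makedirs(saved_path, exist_ok=True)

model = TinyNet()
# Only this fixed-name state-dict checkpoint survives the revision;
# the loss-tagged debug checkpoints are now commented out in train.py.
torch.save(model.state_dict(), os.path.join(saved_path, name + "_only_params.pth"))

# Loading later: rebuild the architecture, then restore the saved parameters.
restored = TinyNet()
restored.load_state_dict(torch.load(os.path.join(saved_path, name + "_only_params.pth")))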