Source code for refann.train

# -*- coding: utf-8 -*-

from . import optimize
from . import data_process as dp
import torch
from torch.autograd import Variable
import numpy as np


def loss_funcs(name='L1'):
    """Return a freshly constructed PyTorch loss module selected by name.

    Parameters
    ----------
    name : str, optional
        ``'L1'`` -> :class:`torch.nn.L1Loss`, ``'MSE'`` -> :class:`torch.nn.MSELoss`,
        ``'SmoothL1'`` -> :class:`torch.nn.SmoothL1Loss`. Default ``'L1'``.

    Returns
    -------
    torch.nn.Module
        A new loss instance built with the class's default arguments.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported names.
    """
    losses = {
        'L1': torch.nn.L1Loss,
        'MSE': torch.nn.MSELoss,
        'SmoothL1': torch.nn.SmoothL1Loss,
    }
    try:
        loss_cls = losses[name]
    except KeyError:
        # Fail loudly with the valid choices instead of an UnboundLocalError.
        raise ValueError('Unsupported loss function {!r}; choose from {}'.format(
            name, sorted(losses))) from None
    return loss_cls()
class Train(object):
    """Mini-batch trainer for a ``torch.nn.Module``.

    Parameters
    ----------
    net : torch.nn.Module
        The network to train.
    loss_func : str, optional
        Loss name passed to ``loss_funcs`` ('L1', 'MSE' or 'SmoothL1').
    iteration : int, optional
        Number of mini-batch iterations run by :meth:`train_1`; also passed
        as ``iteration`` to ``optimize.LrDecay``.
    optimizer : str, optional
        Optimizer name passed to :meth:`_optimizer` ('Adam', 'SGD' or 'RMSprop').
    lr : float, optional
        Initial learning rate used to build the optimizer and passed as
        ``lr`` to ``optimize.LrDecay``.
    lr_min : float, optional
        Passed as ``lr_min`` to ``optimize.LrDecay`` (presumably the floor of
        the decay schedule -- confirm in ``optimize``).
    batch_size : int, optional
        Number of samples drawn without replacement per iteration.
    """

    def __init__(self, net, loss_func='L1', iteration=10000, optimizer='Adam',
                 lr=1e-1, lr_min=1e-6, batch_size=128):
        self.net = net
        self.loss_func = loss_funcs(name=loss_func)
        self.iteration = iteration
        self.lr = lr
        self.lr_min = lr_min
        self.batch_size = batch_size
        # Built last: the optimizer reads self.net and self.lr.
        self.optimizer = self._optimizer(name=optimizer)

    def _prints(self, items, prints=True):
        """Print ``items`` only when ``prints`` is true."""
        if prints:
            print(items)

    def call_GPU(self, prints=True):
        """Detect CUDA availability and record it on the instance.

        Sets ``self.use_GPU`` (CUDA available) and ``self.use_multiGPU``
        (more than one CUDA device visible), and optionally reports which
        hardware will be used.
        """
        self.use_GPU = torch.cuda.is_available()
        # Always defined so it can be read safely on CPU-only hosts.
        self.use_multiGPU = False
        if not self.use_GPU:
            self._prints('\nTraining the network using CPU', prints=prints)
            return
        gpu_num = torch.cuda.device_count()
        if gpu_num > 1:
            self.use_multiGPU = True
            self._prints('\nTraining the network using {} GPUs'.format(gpu_num), prints=prints)
        else:
            self._prints('\nTraining the network using 1 GPU', prints=prints)

    def transfer_net(self, use_DDP=False, device_ids=None, prints=True):
        """Move ``self.net`` to GPU and wrap it for multi-GPU use if possible.

        Parameters
        ----------
        use_DDP : bool, optional
            With multiple GPUs, wrap in ``DistributedDataParallel`` instead of
            ``DataParallel`` (DDP needs an initialised process group -- the
            caller is presumably responsible for that).
        device_ids : list, optional
            Devices forwarded to the wrapper; its first entry is the device
            the network is moved to. ``None`` uses the default CUDA device.
        prints : bool, optional
            Whether :meth:`call_GPU` reports the hardware in use.
        """
        device = None if device_ids is None else device_ids[0]
        self.call_GPU(prints=prints)
        if not self.use_GPU:
            return
        self.net = self.net.cuda(device=device)
        if self.use_multiGPU:
            if use_DDP:
                self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=device_ids)
            else:
                self.net = torch.nn.DataParallel(self.net, device_ids=device_ids)

    def transfer_data(self, device=None):
        """Convert ``self.inputs`` / ``self.target`` with the ``data_process`` helpers.

        Uses ``dp.numpy2cuda`` when ``self.use_GPU`` is set, otherwise
        ``dp.numpy2torch``. ``self.inputs``, ``self.target`` and
        ``self.use_GPU`` are not set in ``__init__``; they must be assigned
        beforehand (``use_GPU`` by :meth:`call_GPU` / :meth:`transfer_net`).
        """
        if self.use_GPU:
            self.inputs = dp.numpy2cuda(self.inputs, device=device)
            self.target = dp.numpy2cuda(self.target, device=device)
        else:
            self.inputs = dp.numpy2torch(self.inputs)
            self.target = dp.numpy2torch(self.target)

    def _optimizer(self, name='Adam'):
        """Build an optimizer over ``self.net.parameters()`` with ``lr=self.lr``.

        Raises
        ------
        ValueError
            If ``name`` is not 'Adam', 'SGD' or 'RMSprop'.
        """
        optimizers = {
            'Adam': torch.optim.Adam,
            'SGD': torch.optim.SGD,
            'RMSprop': torch.optim.RMSprop,
        }
        try:
            optim_cls = optimizers[name]
        except KeyError:
            raise ValueError('Unsupported optimizer {!r}; choose from {}'.format(
                name, sorted(optimizers))) from None
        return optim_cls(self.net.parameters(), lr=self.lr)

    def train_0(self, xx, yy, iter_mid, repeat_n=3, lr_decay=True):
        """Take ``repeat_n`` optimisation steps on a single mini-batch.

        Parameters
        ----------
        xx, yy : torch.Tensor
            Mini-batch inputs and targets.
        iter_mid : int
            Current iteration number, passed to ``optimize.LrDecay``.
        repeat_n : int, optional
            Number of gradient steps on this batch; must be >= 1.
        lr_decay : bool, optional
            If true, set the first param group's learning rate to
            ``optimize.LrDecay(...).exp()`` after each step.

        Returns
        -------
        (float, torch.Tensor)
            Loss of the last step, and that step's predictions (computed
            before its parameter update), detached via ``.data``.

        Raises
        ------
        ValueError
            If ``repeat_n`` is less than 1.
        """
        if repeat_n < 1:
            raise ValueError('repeat_n should be a positive integer, got {}'.format(repeat_n))
        xx = Variable(xx)
        yy = Variable(yy, requires_grad=False)
        if lr_decay:
            # Depends only on iter_mid and fixed settings, so compute it once per call.
            lrdc = optimize.LrDecay(iter_mid, iteration=self.iteration, lr=self.lr, lr_min=self.lr_min)
            new_lr = lrdc.exp()
        for _ in range(repeat_n):
            _predicted = self.net(xx)
            _loss = self.loss_func(_predicted, yy)
            self.optimizer.zero_grad()
            _loss.backward()
            self.optimizer.step()
            if lr_decay:
                # Assigned after the step: the first step of this call still uses the previous lr.
                self.optimizer.param_groups[0]['lr'] = new_lr
        return _loss.item(), _predicted.data

    def train_1(self, inputs, target, repeat_n=1, set_seed=False, lr_decay=True, print_info=True, showIter_n=200):
        """Train for ``self.iteration`` iterations on randomly drawn mini-batches.

        Parameters
        ----------
        inputs, target : indexable
            Training inputs and targets (e.g. tensors), indexed by the same
            random row indices.
        repeat_n : int, optional
            Gradient steps per mini-batch, forwarded to :meth:`train_0`.
        set_seed : bool, optional
            If true, seed NumPy's global RNG with 1000 before sampling.
        lr_decay : bool, optional
            Forwarded to :meth:`train_0`.
        print_info : bool, optional
            Print progress every ``showIter_n`` iterations.
        showIter_n : int, optional
            Progress print interval; 0 disables printing.

        Returns
        -------
        (net, list of float)
            ``self.net`` and the per-iteration losses.

        Raises
        ------
        ValueError
            If ``self.batch_size`` exceeds ``len(inputs)``.
        """
        if self.batch_size > len(inputs):
            raise ValueError('The batch size should be smaller than the number of the training set')
        if set_seed:
            # Seeds NumPy's *global* RNG, which drives the batch sampling below.
            np.random.seed(1000)
        loss_all = []
        for iter_mid in range(1, self.iteration+1):
            # replace=False: no sample appears twice within a batch.
            batch_index = np.random.choice(len(inputs), self.batch_size, replace=False)
            xx = inputs[batch_index]
            yy = target[batch_index]
            _loss, _ = self.train_0(xx, yy, iter_mid, repeat_n=repeat_n, lr_decay=lr_decay)
            loss_all.append(_loss)
            # showIter_n == 0 would make the modulo divide by zero; treat it as "never print".
            if print_info and showIter_n and iter_mid % showIter_n == 0:
                print('(iteration:%s/%s; loss:%.5f; lr:%.8f)'%(iter_mid, self.iteration, _loss, self.optimizer.param_groups[0]['lr']))
        return self.net, loss_all