Bài 13: Giới thiệu về Convents – Pytorch Cơ bản

Trang chủ » Training » Bài 13: Giới thiệu về Convents – Pytorch Cơ bản
24/02/2022 Training 46 viewed
Convents là tất cả về việc xây dựng mô hình CNN từ đầu. Kiến trúc mạng sẽ bao gồm sự kết hợp của các bước sau:
  • Conv2d
  • MaxPool2d
  • Rectified Linear Unit
  • View
  • Linear Layer

Huấn luyện mô hình :

Huấn luyện mô hình là một quá trình giống như các bài toán phân loại ảnh. Đoạn code sau hoàn thành quy trình của mô hình đào tạo trên tập dữ liệu được cung cấp –
def fit(epoch,model,data_loader,phase 
= 'training',volatile = False):
   if phase == 'training':
      model.train()
   if phase == 'training':
      model.train()
   if phase == 'validation':
      model.eval()
   volatile=True
   running_loss = 0.0
   running_correct = 0
   for batch_idx , (data,target) in enumerate(data_loader):
      if is_cuda:
         data,target = data.cuda(),target.cuda()
         data , target = Variable(data,volatile),Variable(target)
      if phase == 'training':
         optimizer.zero_grad()
         output = model(data)
         loss = F.nll_loss(output,target)
         running_loss + = 
         F.nll_loss(output,target,size_average = 
         False).data[0]
         preds = output.data.max(dim = 1,keepdim = True)[1]
         running_correct + = 
         preds.eq(target.data.view_as(preds)).cpu().sum()
         if phase == 'training':
            loss.backward()
            optimizer.step()
   loss = running_loss/len(data_loader.dataset)
   accuracy = 100. * running_correct/len(data_loader.dataset)
   print(f'{phase} loss is {loss:{5}.{2}} and {phase} accuracy is {running_correct}/{len(data_loader.dataset)}{accuracy:{return loss,accuracy}})
Phương pháp này bao gồm các logic khác nhau để training và validation. Có hai lý do chính để sử dụng các mode khác nhau
  • Ở chế độ huấn luyện, dropout sẽ loại bỏ một phần trăm giá trị, điều này sẽ không xảy ra trong quá trình validation hoặc testing.
  • Tính toán độ dốc và thay đổi giá trị tham số của mô hình, nhưng không cần truyền ngược trong giai đoạn thử nghiệm hoặc xác nhận.
Chia sẻ:
Tags:
TOP HOME