
Thursday, March 13, 2025

Knowledge Distillation in PyTorch: A Hands-On Example with MNIST Handwritten Digit Recognition


In this post we use PyTorch to demonstrate knowledge distillation, letting a small student model learn the knowledge of a large teacher model. This example will:

  1. Define the teacher model
  2. Define the student model
  3. Distill with the standard cross-entropy loss plus a KL-divergence loss
  4. Train the student model and evaluate the results

We use MNIST handwritten digit recognition as the example: the teacher model is a larger CNN, while the student model is a smaller one.


1. Install & Import the Required Libraries

Make sure you have PyTorch installed; if not, install it with:

bash
pip install torch torchvision matplotlib

Then we can start the implementation:

python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

2. Define the Teacher Model

First, we define a larger CNN as the teacher model.

python
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        # Two conv layers with 32 and 64 filters, then two FC layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 28x28 -> 7x7 after two 2x2 poolings
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # raw logits (no softmax here)
        return x

3. Define the Student Model

The student model is a smaller CNN: compared with the teacher, its convolutional layers have fewer filters and its fully connected layers are smaller.

python
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        # Same layout as the teacher, but with 16/32 filters and a 64-unit FC layer
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

4. Set Up the Dataset and DataLoader

We use the MNIST dataset for training and testing:

python
# Download the MNIST dataset, normalizing with MNIST's mean and std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

5. Train the Teacher Model

python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher_model = TeacherModel().to(device)
optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

def train_teacher(model, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Report the average loss per batch for this epoch
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

train_teacher(teacher_model, optimizer, criterion)

6. Define the Knowledge Distillation Loss

The core of knowledge distillation is the KL-divergence loss; we use a temperature parameter T to smooth the softmax distribution:
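
One common way to write the combined loss (matching the code below, where z_s and z_t denote the student and teacher logits and y the true label) is:

latex
L_{KD} = \alpha \,\mathrm{CE}(z_s, y) + (1-\alpha)\, T^2 \,\mathrm{KL}\big(\mathrm{softmax}(z_t/T)\,\|\,\mathrm{softmax}(z_s/T)\big)

The T^2 factor compensates for the 1/T^2 scaling that the temperature introduces into the soft-target gradients, keeping the two loss terms on a comparable scale.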

python
def knowledge_distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Soft targets: KL divergence between the temperature-softened teacher
    # and student distributions, scaled by T^2
    soft_targets = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)
    # Hard targets: standard cross entropy against the true labels
    hard_targets = F.cross_entropy(student_logits, labels)
    return alpha * hard_targets + (1 - alpha) * soft_targets

7. Train the Student Model

Now we let the student model learn the teacher model's knowledge.

python
student_model = StudentModel().to(device)
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

def train_student(teacher_model, student_model, optimizer, epochs=5, T=4.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Teacher predictions (no gradients needed)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
            student_outputs = student_model(images)
            loss = knowledge_distillation_loss(student_outputs, teacher_outputs, labels, T, alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

train_student(teacher_model, student_model, optimizer)

8. Evaluate the Student Model

python
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy: {100 * correct / total:.2f}%")

evaluate(student_model)
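
For a baseline to compare against, you can run the same evaluation on the teacher model:

python
# Baseline: teacher accuracy on the same test set
evaluate(teacher_model)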

Summary

This program implements the basic knowledge distillation workflow:

  1. Train the teacher model
  2. Let the student learn from the teacher's soft labels
  3. Distill knowledge with the KL divergence
  4. Train and evaluate the student model

With a much smaller network architecture, the student model can approach the teacher model's accuracy.
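
As a quick sanity check on "smaller", here is a minimal sketch to compare trainable parameter counts (count_params is just an ad-hoc helper, not part of the tutorial); with the architectures defined above, the student has roughly a quarter of the teacher's parameters:

python
# Count the trainable parameters of each model
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Teacher parameters: {count_params(teacher_model):,}")
print(f"Student parameters: {count_params(student_model):,}")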

This technique is a great fit for edge AI, on-device model compression, and similar scenarios! 🚀
