Knowledge Distillation: A Hands-On Example
In this post we use PyTorch to demonstrate knowledge distillation, letting a small student model learn from a large teacher model. The example will:
- Define the teacher model (Teacher Model)
- Define the student model (Student Model)
- Distill using a standard cross-entropy loss plus a KL-divergence loss
- Train the student model and evaluate the result
We use MNIST handwritten digit recognition as the running example: the teacher is a larger CNN, while the student is a smaller CNN.
1. Install & import the required libraries
Make sure PyTorch is installed; if it is not, you can install everything with:
```bash
pip install torch torchvision matplotlib
```
Then we can start the implementation:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
```
2. Define the Teacher Model
We first define a larger CNN as the teacher model.
```python
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Two 2x2 max-poolings reduce the 28x28 input to 7x7, hence 64 * 7 * 7
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
```
3. Define the Student Model
The student model is a smaller CNN: compared with the teacher, its convolutional layers have fewer channels and its fully connected layers are smaller.
```python
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # Same spatial reduction as the teacher: 28x28 -> 7x7 after two poolings
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
```
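As an optional check (not part of the original walkthrough), you can push a dummy batch through both models to confirm they each produce 10 logits per image:
```python
# Optional shape check: a fake batch of four 28x28 grayscale images
dummy = torch.randn(4, 1, 28, 28)
print(TeacherModel()(dummy).shape)  # torch.Size([4, 10])
print(StudentModel()(dummy).shape)  # torch.Size([4, 10])
```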
4. Set up the dataset and DataLoaders
We use the MNIST dataset for training and testing:
```python
# Download the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
```
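If you want to verify the loaders (an optional extra, not in the original post), inspecting one training batch should show 64 single-channel 28x28 images:
```python
# Peek at one batch to confirm the expected shapes
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```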
5. Train the teacher model
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher_model = TeacherModel().to(device)
optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

def train_teacher(model, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

train_teacher(teacher_model, optimizer, criterion)
```
6. Define the knowledge distillation loss
The core of knowledge distillation is a KL-divergence loss; we use a temperature parameter T to soften the softmax distributions.
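Written out, the loss implemented in the code below combines a hard-label term and a soft-label term (z_s and z_t denote the student and teacher logits, y the ground-truth label):

loss = alpha * CE(z_s, y) + (1 - alpha) * T^2 * KL( softmax(z_t / T) || softmax(z_s / T) )

The T^2 factor compensates for the gradient scaling introduced by the temperature, keeping the soft-target term on a comparable scale to the cross-entropy term, as suggested in Hinton et al.'s original distillation paper.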
```python
def knowledge_distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Soft-target term: KL divergence between the temperature-softened
    # teacher and student distributions, scaled by T^2
    soft_targets = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                            F.softmax(teacher_logits / T, dim=1),
                            reduction='batchmean') * (T * T)
    # Hard-target term: standard cross-entropy against the true labels
    hard_targets = F.cross_entropy(student_logits, labels)
    return alpha * hard_targets + (1 - alpha) * soft_targets
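```
As a quick smoke test (my addition, not in the original post), you can call the loss on random logits to confirm it returns a single scalar:
```python
# Hypothetical sanity check: a batch of 8 samples with 10 classes
s_logits = torch.randn(8, 10)
t_logits = torch.randn(8, 10)
y = torch.randint(0, 10, (8,))
print(knowledge_distillation_loss(s_logits, t_logits, y, T=4.0, alpha=0.5))  # scalar tensor
```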
7. Train the student model
Now we let the student model learn from the teacher.
```python
student_model = StudentModel().to(device)
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

def train_student(teacher_model, student_model, optimizer, epochs=5, T=4.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Teacher predictions (no gradients needed)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
            student_outputs = student_model(images)
            loss = knowledge_distillation_loss(student_outputs, teacher_outputs, labels, T, alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

train_student(teacher_model, student_model, optimizer)
```
8. Evaluate the student model
```python
def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy: {100 * correct / total:.2f}%")

evaluate(student_model)
```
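To put the student's accuracy in context, you can reuse the same function on the teacher and compare model sizes. This is a small optional sketch (not in the original post); `count_params` is a helper defined here:
```python
def count_params(model):
    # Total number of trainable and non-trainable parameters
    return sum(p.numel() for p in model.parameters())

evaluate(teacher_model)  # teacher accuracy for comparison
print(f"Teacher params: {count_params(teacher_model):,}, Student params: {count_params(student_model):,}")
```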
Summary
This program implements a basic knowledge distillation workflow:
- Train the teacher model
- Let the student learn from the teacher's soft labels
- Distill the knowledge through a KL-divergence loss
- Train and evaluate the student model
With a much smaller network, the student model can come close to the teacher's accuracy,
which makes this technique a great fit for edge AI and mobile model compression scenarios! 🚀