
Leveraging Knowledge Distillation for Embedded AI: A Comprehensive Guide
In today’s rapidly evolving AI landscape, deploying deep learning models on resource-constrained devices is both a challenge and a necessity. Embedded AI—where inference happens directly on devices like smartphones, IoT sensors, or embedded systems—demands models that are both efficient and effective. In this guide, we explore how to bridge the gap between high-capacity models and lightweight models using knowledge distillation. We will define key concepts, outline project goals, delve into detailed code examples (with a focus on teacher and student training scripts), and analyze performance outputs.
1. Introduction and Definitions
Embedded AI vs. Normal AI
- Normal AI typically runs on powerful servers or in the cloud using complex models that have minimal resource constraints.
- Embedded AI runs on devices with limited computation, memory, and energy. Therefore, models for embedded AI must be optimized for size, speed, and energy efficiency.
Knowledge Distillation
Knowledge distillation is a model compression technique in which a large, high-capacity teacher model transfers its “knowledge” to a smaller, more efficient student model. Rather than learning from the ground-truth labels alone, the student is also trained to mimic the teacher, using two kinds of targets:
- Hard Targets: The ground truth labels.
- Soft Targets: The probability distributions produced by the teacher model (often softened using temperature scaling).
This approach allows the student model to achieve competitive performance with significantly fewer parameters and faster inference speed, making it ideal for embedded AI applications.
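To make the effect of temperature scaling concrete, the following sketch softens a set of teacher logits at several temperatures. The logit values are made up for illustration, and `softmax_with_temperature` is a hypothetical helper, not part of the project's code.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for three classes.
teacher_logits = [5.0, 3.0, -1.0]

for T in (1.0, 4.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>4}: " + ", ".join(f"{p:.3f}" for p in probs))
```

At T = 1 most of the probability mass sits on the top class. As T grows the distribution flattens while keeping the same ranking, which exposes how the teacher rates the remaining classes relative to one another.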
2. Project Goals
This project aims to:
- Train a High-Capacity Teacher Model: Develop and train a convolutional neural network (CNN) on the CIFAR-10 dataset.
- Train a Lightweight Student Model via Knowledge Distillation: Use the teacher’s outputs to guide the training of a compact student model.
- Compare and Analyze Both Models: Evaluate and visualize their performance, complexity, inference speed, and class-wise metrics to understand the trade-offs between accuracy and efficiency.
3. Code Overview
Our solution is structured into several modular files:
- common.py: Contains shared definitions including the model architectures, data loaders, and evaluation routines.
- teacher_train.py: Trains the teacher model using standard techniques and saves the best-performing model.
- student_train.py: Loads the pre-trained teacher and trains the student model using knowledge distillation.
- compare_models.py: Loads both models to compute metrics, generate visualizations, and output detailed performance analysis.
Each file is designed to be maintainable and reusable, ensuring that the training, distillation, and evaluation processes are clearly separated.
4. Detailed Code and Explanation
4.1. Shared Components – common.py
This module defines both the TeacherNet and StudentNet architectures along with functions to load CIFAR-10 data and evaluate models.
# common.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Teacher Model: A deeper, high-capacity network
class TeacherNet(nn.Module):
    def __init__(self, num_classes=10):
        super(TeacherNet, self).__init__()
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 16x16 output
            # Block 2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 8x8 output
            # Block 3
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)  # 4x4 output
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.classifier(x)
        return x
# Student Model: A lightweight network optimized for efficiency
class StudentNet(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentNet, self).__init__()
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 16x16 output
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 8x8 output
            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)  # 4x4 output
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.classifier(x)
        return x
# Data loaders for CIFAR-10
def get_data_loaders(batch_size=128):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    return trainloader, testloader
# Evaluation routine for a model
def evaluate(model, dataloader, criterion):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Weight the batch-mean loss by batch size so the final average is per sample
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    avg_loss = running_loss / total
    accuracy = correct / total
    return avg_loss, accuracy
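As a sanity check on the two architectures above, their trainable parameter counts can be derived by hand from the layer shapes. This is a framework-free sketch; the helper names (`conv_params`, `bn_params`, `linear_params`, `count_cnn`) are illustrative and not part of common.py.

```python
def conv_params(c_in, c_out, k=3):
    # k*k weights per (input, output) channel pair, plus one bias per output channel
    return c_in * c_out * k * k + c_out

def bn_params(channels):
    # BatchNorm2d learns one scale and one shift per channel
    return 2 * channels

def linear_params(n_in, n_out):
    # Weight matrix plus one bias per output unit
    return n_in * n_out + n_out

def count_cnn(conv_channels, hidden, num_classes=10, spatial=4):
    """conv_channels: list of (c_in, c_out) pairs, each conv followed by BatchNorm."""
    total = sum(conv_params(i, o) + bn_params(o) for i, o in conv_channels)
    flat = conv_channels[-1][1] * spatial * spatial  # features after the last 4x4 pooling
    total += linear_params(flat, hidden) + linear_params(hidden, num_classes)
    return total

teacher = count_cnn([(3, 64), (64, 64), (64, 128), (128, 128), (128, 256)], hidden=512)
student = count_cnn([(3, 32), (32, 64), (64, 128)], hidden=256)
print(f"TeacherNet: {teacher:,} parameters")
print(f"StudentNet: {student:,} parameters")
```

The totals come out to 2,659,402 and 620,810, which is what `count_parameters` reports for the two models in the output section. In both networks the first fully connected layer holds most of the parameters, so the width of the flattened feature map is the main lever on model size.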
4.2. Training the Teacher Model – teacher_train.py
The teacher model is trained using standard supervised learning techniques. Key steps include:
- Data Preparation: The CIFAR-10 data is loaded with data augmentation (random cropping and horizontal flipping) for robust training.
- Model Initialization: A TeacherNet instance is created. This deeper architecture is designed to learn rich feature representations.
- Optimizer and Scheduler:
  - Adam Optimizer: Used with an initial learning rate (0.001 by default).
  - Learning Rate Scheduler (StepLR): Reduces the learning rate periodically (every step_size epochs) by a factor of gamma. This aids in fine-tuning the learning process.
- Training Loop:
  - For each epoch, the model processes batches of images.
  - The cross-entropy loss is computed and backpropagated.
  - A progress bar (tqdm) monitors training progress.
  - After each epoch, the model is evaluated on the CIFAR-10 test set, which doubles as the validation set here. If accuracy improves, the model’s state is saved.
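The StepLR behaviour described above can be reproduced with plain arithmetic. The sketch below assumes the script's defaults (lr = 0.001, step_size = 10, gamma = 0.5) and that `scheduler.step()` is called once at the end of each epoch, as in the training loop. `step_lr` is a hypothetical helper, not a PyTorch function.

```python
def step_lr(initial_lr, epoch, step_size, gamma):
    """Learning rate in effect during a 1-based epoch under a StepLR-style schedule."""
    completed_epochs = epoch - 1  # scheduler steps taken before this epoch starts
    return initial_lr * gamma ** (completed_epochs // step_size)

for epoch in (1, 10, 11, 20):
    lr = step_lr(0.001, epoch, step_size=10, gamma=0.5)
    print(f"epoch {epoch:>2}: lr = {lr:.6f}")
```

With these defaults the learning rate stays at 0.001 for epochs 1 to 10 and is halved to 0.0005 for epochs 11 to 20.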
# teacher_train.py
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import argparse
import os
from tqdm import tqdm
from common import TeacherNet, get_data_loaders, evaluate, device
def train_teacher(model, trainloader, testloader, num_epochs, lr, step_size, gamma):
    print("Starting Teacher Training...")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    criterion = torch.nn.CrossEntropyLoss()
    best_acc = 0.0
    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(enumerate(trainloader), total=len(trainloader),
                            desc=f"Epoch {epoch}/{num_epochs}")
        for batch_idx, (images, labels) in progress_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            progress_bar.set_postfix(loss=f"{running_loss/(batch_idx+1):.4f}")
        scheduler.step()
        val_loss, val_acc = evaluate(model, testloader, criterion)
        print(f"Epoch {epoch}/{num_epochs} - Validation Loss: {val_loss:.4f} | Accuracy: {val_acc:.4f}")
        # Keep only the checkpoint with the best accuracy so far
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "teacher_best.pth")
            print(f" --> New best model saved with accuracy: {best_acc:.4f}")
    print("Teacher training complete.")
    return model
def main():
    parser = argparse.ArgumentParser(description="Teacher Model Training")
    parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="Initial learning rate")
    parser.add_argument("--step_size", type=int, default=10, help="Step size for LR scheduler")
    parser.add_argument("--gamma", type=float, default=0.5, help="LR decay factor")
    args = parser.parse_args()
    trainloader, testloader = get_data_loaders(batch_size=args.batch_size)
    teacher = TeacherNet(num_classes=10)
    train_teacher(teacher, trainloader, testloader,
                  num_epochs=args.epochs, lr=args.lr,
                  step_size=args.step_size, gamma=args.gamma)
    print("Teacher model training complete and saved as teacher_best.pth.")

if __name__ == "__main__":
    main()
4.3. Training the Student Model with Knowledge Distillation – student_train.py
The student model is trained to mimic the teacher while learning from the true labels. Key techniques include:
- Teacher Model Loading and Freezing: The pre-trained teacher model is loaded and set to evaluation mode. Its parameters are frozen so that only the student’s weights are updated.
- Student Model Initialization: A StudentNet instance is created. This model is much smaller in capacity.
- Knowledge Distillation Loss:
  - Hard Loss: Standard cross-entropy loss between the student’s predictions and ground truth.
  - Soft Loss (Distillation Loss):
    - Both teacher and student logits are divided by a temperature T to soften the distributions.
    - The Kullback–Leibler (KL) divergence is computed between these softened outputs.
    - The soft loss is scaled by T² to balance the gradients.
  - Combined Loss: A weighted sum of the two terms, with factor α on the hard loss and 1 − α on the distillation loss.
- Training Loop: The student model is trained similarly to the teacher, with the additional step of calculating and backpropagating the combined loss.
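Written out, the combined objective used in the script is loss = α · CE + (1 − α) · T² · KL(teacher ‖ student), with both distributions softened by T. The framework-free sketch below evaluates it for a single sample so that each term can be inspected. The logits are invented, and `kd_loss` is an illustrative helper rather than the project's implementation, which uses `F.kl_div` over a batch.

```python
import math

def log_softmax(logits, T=1.0):
    """Numerically stable log-softmax of logits / T."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]

def kd_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.7):
    """Return (combined loss, hard loss, soft loss) for one sample."""
    # Hard loss: cross-entropy against the ground-truth label (no temperature).
    ce = -log_softmax(student_logits)[label]
    # Soft loss: KL(teacher || student) on temperature-softened distributions.
    log_p_s = log_softmax(student_logits, T)
    log_p_t = log_softmax(teacher_logits, T)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    # T**2 compensates for the 1/T**2 scaling of the soft-target gradients.
    total = alpha * ce + (1 - alpha) * (T ** 2) * kl
    return total, ce, kl

student_logits = [1.0, 0.5, -0.5]  # hypothetical values
teacher_logits = [4.0, 1.0, -2.0]  # hypothetical values
total, ce, kl = kd_loss(student_logits, teacher_logits, label=0)
print(f"CE={ce:.4f}  KL={kl:.4f}  total={total:.4f}")
```

When the student's logits equal the teacher's, the KL term vanishes and only the weighted cross-entropy remains, which is a convenient check when adapting the loss.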
# student_train.py
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
import argparse
import os
from tqdm import tqdm
from common import TeacherNet, StudentNet, get_data_loaders, evaluate, device
def train_student(teacher, student, trainloader, testloader, num_epochs, lr, step_size, gamma, T, alpha):
    print("Starting Student Training with Knowledge Distillation...")
    teacher.to(device)
    student.to(device)
    # Freeze teacher parameters
    teacher.eval()
    for param in teacher.parameters():
        param.requires_grad = False
    optimizer = optim.Adam(student.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    criterion_ce = torch.nn.CrossEntropyLoss()
    best_acc = 0.0
    for epoch in range(1, num_epochs + 1):
        student.train()
        running_loss = 0.0
        progress_bar = tqdm(enumerate(trainloader), total=len(trainloader),
                            desc=f"Epoch {epoch}/{num_epochs}")
        for batch_idx, (images, labels) in progress_bar:
            images, labels = images.to(device), labels.to(device)
            # The teacher only provides targets, so no gradients are needed
            with torch.no_grad():
                teacher_logits = teacher(images)
            student_logits = student(images)
            # Hard loss against the ground-truth labels
            loss_ce = criterion_ce(student_logits, labels)
            # Soft loss: KL divergence between temperature-softened distributions
            log_student_prob = F.log_softmax(student_logits / T, dim=1)
            teacher_prob = F.softmax(teacher_logits / T, dim=1)
            loss_kd = F.kl_div(log_student_prob, teacher_prob, reduction='batchmean')
            loss = alpha * loss_ce + (1 - alpha) * (T ** 2) * loss_kd
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            progress_bar.set_postfix(loss=f"{running_loss/(batch_idx+1):.4f}")
        scheduler.step()
        val_loss, val_acc = evaluate(student, testloader, criterion_ce)
        print(f"Epoch {epoch}/{num_epochs} - Validation Loss: {val_loss:.4f} | Accuracy: {val_acc:.4f}")
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(student.state_dict(), "student_best.pth")
            print(f" --> New best student model saved with accuracy: {best_acc:.4f}")
    print("Student training complete.")
    return student
def main():
    parser = argparse.ArgumentParser(description="Student Model Training with Knowledge Distillation")
    parser.add_argument("--epochs", type=int, default=30, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="Initial learning rate")
    parser.add_argument("--step_size", type=int, default=15, help="Step size for LR scheduler")
    parser.add_argument("--gamma", type=float, default=0.5, help="LR decay factor")
    parser.add_argument("--T", type=float, default=4.0, help="Temperature for distillation")
    parser.add_argument("--alpha", type=float, default=0.7, help="Weight factor between CE loss and KD loss")
    args = parser.parse_args()
    trainloader, testloader = get_data_loaders(batch_size=args.batch_size)
    teacher_path = "teacher_best.pth"
    if not os.path.exists(teacher_path):
        print("Error: Teacher model not found. Please run teacher_train.py first.")
        return
    teacher = TeacherNet(num_classes=10)
    teacher.load_state_dict(torch.load(teacher_path, map_location=device))
    print("Loaded pre-trained teacher model from teacher_best.pth.")
    student = StudentNet(num_classes=10)
    train_student(teacher, student, trainloader, testloader,
                  num_epochs=args.epochs, lr=args.lr,
                  step_size=args.step_size, gamma=args.gamma,
                  T=args.T, alpha=args.alpha)
    criterion = torch.nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate(student, testloader, criterion)
    print(f"Final Student Model -- Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

if __name__ == "__main__":
    main()
4.4. Model Comparison and Analysis – compare_models.py
After training, we compare the teacher and student models on several dimensions:
- Performance Metrics: Test loss and accuracy.
- Model Complexity: Parameter counts and model file sizes.
- Inference Latency: Average inference time per run.
- Class-wise Analysis: Confusion matrices and classification reports.
Visualizations (using Matplotlib and Seaborn) help illustrate the trade-offs between the two models.
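To show how the class-wise numbers in a classification report follow from a confusion matrix, here is a small sketch on an invented 3-class matrix. Rows are true labels and columns are predictions, matching scikit-learn's convention. The counts and the helper `per_class_metrics` are hypothetical.

```python
def per_class_metrics(cm, class_names):
    """Derive (precision, recall) per class from a square confusion matrix."""
    n = len(cm)
    report = {}
    for k, name in enumerate(class_names):
        true_pos = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        precision = true_pos / predicted_k if predicted_k else 0.0
        recall = true_pos / actual_k if actual_k else 0.0
        report[name] = (precision, recall)
    return report

# Invented counts: rows = true class, columns = predicted class.
cm = [
    [90,  5,  5],   # ship
    [10, 70, 20],   # cat
    [ 5, 25, 70],   # dog
]
for name, (p, r) in per_class_metrics(cm, ["ship", "cat", "dog"]).items():
    print(f"{name:>5}: precision={p:.3f}  recall={r:.3f}")
```

Off-diagonal entries shared between two classes, such as the cat and dog cells here, are where a smaller model would be expected to lose accuracy first.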
# compare_models.py
import torch
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from common import TeacherNet, StudentNet, get_data_loaders, evaluate, device
# Utility Functions
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_model_size(filepath):
    return os.path.getsize(filepath) / 1e6  # in MB
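def measure_inference_time(model, input_size=(1, 3, 32, 32), device=device, num_runs=100):
    model.eval()
    dummy_input = torch.randn(input_size).to(device)
    use_cuda = torch.device(device).type == "cuda"
    with torch.no_grad():
        # Warm-up runs so one-time initialization costs are not timed
        for _ in range(10):
            _ = model(dummy_input)
        # CUDA kernels run asynchronously; wait for them before reading the clock
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            _ = model(dummy_input)
        if use_cuda:
            torch.cuda.synchronize()
        total_time = time.perf_counter() - start_time
    return total_time / num_runs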
def get_predictions(model, dataloader):
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.numpy())
    return np.array(all_labels), np.array(all_preds)
def main():
    teacher_path = "teacher_best.pth"
    student_path = "student_best.pth"
    if not os.path.exists(teacher_path) or not os.path.exists(student_path):
        raise FileNotFoundError("One or both model files not found.")
    teacher = TeacherNet(num_classes=10)
    teacher.load_state_dict(torch.load(teacher_path, map_location=device))
    teacher.to(device)
    student = StudentNet(num_classes=10)
    student.load_state_dict(torch.load(student_path, map_location=device))
    student.to(device)
    _, testloader = get_data_loaders(batch_size=128)
    criterion = torch.nn.CrossEntropyLoss()
    # Performance on Test Set
    teacher_loss, teacher_acc = evaluate(teacher, testloader, criterion)
    student_loss, student_acc = evaluate(student, testloader, criterion)
    print("=== Performance on Test Set ===")
    print(f"Teacher Test Loss: {teacher_loss:.4f}, Accuracy: {teacher_acc:.4f}")
    print(f"Student Test Loss: {student_loss:.4f}, Accuracy: {student_acc:.4f}")
    # Model Complexity
    teacher_params = count_parameters(teacher)
    student_params = count_parameters(student)
    print("\n=== Model Complexity ===")
    print(f"Teacher parameters: {teacher_params}")
    print(f"Student parameters: {student_params}")
    # Save temporary checkpoints to measure on-disk size, then clean up
    torch.save(teacher.state_dict(), "teacher_temp.pth")
    torch.save(student.state_dict(), "student_temp.pth")
    teacher_size = get_model_size("teacher_temp.pth")
    student_size = get_model_size("student_temp.pth")
    os.remove("teacher_temp.pth")
    os.remove("student_temp.pth")
    print(f"Teacher model file size: {teacher_size:.2f} MB")
    print(f"Student model file size: {student_size:.2f} MB")
    # Inference Latency
    teacher_latency = measure_inference_time(teacher, device=device)
    student_latency = measure_inference_time(student, device=device)
    print("\n=== Inference Latency ===")
    print(f"Teacher inference time: {teacher_latency * 1000:.2f} ms per run")
    print(f"Student inference time: {student_latency * 1000:.2f} ms per run")
    # Confusion Matrices and Visualization
    cifar10_labels = ["airplane", "automobile", "bird", "cat", "deer",
                      "dog", "frog", "horse", "ship", "truck"]
    true_labels, teacher_preds = get_predictions(teacher, testloader)
    _, student_preds = get_predictions(student, testloader)
    teacher_cm = confusion_matrix(true_labels, teacher_preds)
    student_cm = confusion_matrix(true_labels, student_preds)
    plt.figure(figsize=(14, 6))
    plt.subplot(1, 2, 1)
    sns.heatmap(teacher_cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=cifar10_labels, yticklabels=cifar10_labels)
    plt.title("Teacher Model Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True Label")
    plt.subplot(1, 2, 2)
    sns.heatmap(student_cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=cifar10_labels, yticklabels=cifar10_labels)
    plt.title("Student Model Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()
    # Comparison Bar Chart
    metrics = {
        "Accuracy": [teacher_acc * 100, student_acc * 100],
        "Test Loss": [teacher_loss, student_loss],
        "Parameters": [teacher_params, student_params],
        "Model Size (MB)": [teacher_size, student_size],
        "Inference Time (ms)": [teacher_latency * 1000, student_latency * 1000]
    }
    labels_list = list(metrics.keys())
    teacher_vals = [metrics[m][0] for m in labels_list]
    student_vals = [metrics[m][1] for m in labels_list]
    x = np.arange(len(labels_list))
    width = 0.35
    fig, ax = plt.subplots(figsize=(12, 6))
    rects1 = ax.bar(x - width/2, teacher_vals, width, label="Teacher", color="tab:blue")
    rects2 = ax.bar(x + width/2, student_vals, width, label="Student", color="tab:green")
    ax.set_ylabel("Value")
    ax.set_title("Comparison of Teacher vs. Student Models")
    ax.set_xticks(x)
    ax.set_xticklabels(labels_list, rotation=45, ha="right")
    ax.legend()

    def autolabel(rects):
        # Write each bar's value just above it
        for rect in rects:
            height = rect.get_height()
            ax.annotate(f"{height:.2f}",
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha="center", va="bottom")

    autolabel(rects1)
    autolabel(rects2)
    plt.tight_layout()
    plt.show()
    # Classification Reports (classification_report is already imported at the top of the file)
    print("\n=== Classification Report for Teacher Model ===")
    print(classification_report(true_labels, teacher_preds, target_names=cifar10_labels))
    print("\n=== Classification Report for Student Model ===")
    print(classification_report(true_labels, student_preds, target_names=cifar10_labels))

if __name__ == '__main__':
    main()
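Latency numbers are sensitive to how they are measured. As a hedged, framework-free sketch of the same warm-up-then-average pattern used in `measure_inference_time`, the helper below times any callable with `time.perf_counter` and accepts an optional `sync` hook. On a CUDA device one would typically pass `torch.cuda.synchronize` so that queued GPU work finishes before the clock is read. The name `benchmark` and its parameters are illustrative.

```python
import statistics
import time

def benchmark(fn, warmup=10, runs=100, sync=None):
    """Return (mean, median) seconds per call of fn()."""
    for _ in range(warmup):        # warm-up calls are not timed
        fn()
    timings = []
    for _ in range(runs):
        if sync:
            sync()                 # drain pending asynchronous work
        start = time.perf_counter()
        fn()
        if sync:
            sync()                 # wait for this call to finish
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings), statistics.median(timings)

# Stand-in workload; in practice fn would wrap model(dummy_input).
mean_s, median_s = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"mean {mean_s * 1e3:.3f} ms, median {median_s * 1e3:.3f} ms")
```

Reporting the median alongside the mean makes it easier to spot runs distorted by background activity on the device.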
5. Output Analysis
When running the comparison script, you might see output similar to:
=== Performance on Test Set ===
Teacher Test Loss: 0.3833, Accuracy: 0.8697
Student Test Loss: 0.5196, Accuracy: 0.8297
=== Model Complexity ===
Teacher parameters: 2659402
Student parameters: 620810
Teacher model file size: 10.65 MB
Student model file size: 2.49 MB
=== Inference Latency ===
Teacher inference time: 2.43 ms per run
Student inference time: 0.88 ms per run
Analysis Summary
- Accuracy: The teacher achieves ~87% accuracy, while the student reaches ~83%. This drop of about 4 percentage points is a reasonable price for the student’s efficiency.
- Model Complexity: With approximately 2.66 million parameters vs. 0.62 million, the student model reduces the parameter count by over 75% and saves significant storage space (10.65 MB vs. 2.49 MB).
- Inference Latency: The student model is roughly 2.8 times faster (0.88 ms vs. 2.43 ms per inference), which is crucial for real-time applications on embedded devices.
- Class-wise Performance: Both models excel on classes like “automobile” and “ship,” but the student has a slightly lower performance on more challenging classes such as “cat” and “dog.” Fine-tuning or additional distillation techniques might further narrow this gap.
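The ratios quoted in the summary follow directly from the sample output. The sketch below recomputes them and also estimates the checkpoint size from the parameter count, assuming 4 bytes per float32 parameter (an approximation that ignores BatchNorm buffers and file overhead).

```python
# Figures copied from the sample output above.
teacher = {"params": 2_659_402, "size_mb": 10.65, "latency_ms": 2.43}
student = {"params": 620_810, "size_mb": 2.49, "latency_ms": 0.88}

param_reduction = 1 - student["params"] / teacher["params"]
compression = teacher["params"] / student["params"]
speedup = teacher["latency_ms"] / student["latency_ms"]

print(f"Parameter reduction: {param_reduction:.1%}")
print(f"Compression ratio:   {compression:.2f}x")
print(f"Latency speedup:     {speedup:.2f}x")

# Rough float32 footprint: 4 bytes per parameter.
for name, model in (("teacher", teacher), ("student", student)):
    estimate_mb = model["params"] * 4 / 1e6
    print(f"{name}: estimated {estimate_mb:.2f} MB vs reported {model['size_mb']:.2f} MB")
```

The student ends up with roughly 77% fewer parameters, about a 4.3x compression, and a latency speedup of about 2.8x. The float32 estimate lands within about 0.01 MB of the reported file sizes, so the checkpoints are dominated by the weights themselves.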
6. Conclusion
This guide demonstrated a complete workflow for implementing knowledge distillation in embedded AI:
- Definitions and Goals: We outlined the need for lightweight models and introduced knowledge distillation as a solution.
- Detailed Code: We walked through modular scripts (common components, teacher training, student training, and model comparison), providing clear explanations for each step.
- Performance Analysis: We compared metrics such as accuracy, parameter count, model size, and inference speed, highlighting the trade-offs between a high-capacity teacher and an efficient student model.
By employing knowledge distillation, you can deploy models that balance accuracy with the practical constraints of embedded systems. Whether you’re optimizing for speed, storage, or power consumption, this approach provides a robust framework for bridging the gap between cutting-edge deep learning and real-world applications.
The full implementation from this post is available in the accompanying GitHub repository.
Happy coding and efficient model deployment!