
Machine Learning for Medical Imaging

Model Distillation : Big to Small Models

Published 10 months ago • 2 min read

Knowledge Distillation : From Teacher to Student


Have you heard of the term "model distillation" or "knowledge distillation"? It's a very cool technique in deep learning for compressing what a large model has learned into a smaller model. Here's how it works.
In the typical knowledge distillation process, the larger model (the teacher) is first trained on the dataset. Once trained, this model's predictions are used as "soft targets" for training the smaller model (the student).
The student is then trained on the same dataset, but instead of relying only on the original hard target labels (0 or 1 in a binary classification problem), it also uses the output probabilities of the teacher as soft targets.
The smaller model learns to mimic the larger model's behavior, including its handling of more nuanced or borderline cases represented in these soft targets.
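To make the difference between hard and soft targets concrete, here is a minimal sketch in plain Python. The teacher logits are made-up numbers for a single borderline sample, not output from a real model.

```python
import math

def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for one borderline sample, classes = [0, 1]
teacher_logits = [0.4, 1.2]

hard_target = [0.0, 1.0]               # what the dataset label says: "class 1"
soft_target = softmax(teacher_logits)  # what the teacher believes

print("hard target:", hard_target)
print("soft target:", [round(p, 2) for p in soft_target])  # roughly [0.31, 0.69]
```

The hard label only says "class 1", while the soft target also tells the student that the teacher was far from certain about this sample.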
This process allows the smaller model to generalize better than it would if trained on the hard labels alone, and it often achieves performance metrics close to the larger model's despite its reduced size and complexity.
However, it's worth noting that the performance of the distilled model largely depends on the quality of the teacher model.
If the teacher model is poorly trained or not sophisticated enough, the student model's performance will likely be subpar as well.
Below you can see a code sample on how to do this in PyTorch.
The main part to look at in the student training function is the calculation of the loss function.
It has two components:
1 - The traditional cross-entropy loss between the student's predictions and the true labels,
2 - The Kullback-Leibler (KL) divergence between the "softened" output distributions of the student and the teacher. The softmax function is "softened" by the temperature parameter, which is usually set to a value greater than 1.
Notice that you're adding 2 hyperparameters here: alpha, which weights the two loss components against each other, and temperature.
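Here is a minimal sketch of such a loss in PyTorch. The function name distillation_loss, the default values for alpha and temperature, and the random tensors standing in for model outputs are all assumptions made for illustration. Multiplying the KL term by temperature squared is a common convention to keep its gradients on a similar scale to the cross-entropy term, and which of the two terms alpha multiplies varies between implementations.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      alpha=0.5, temperature=4.0):
    """Weighted sum of hard-label cross-entropy and soft-target KL divergence."""
    # 1 - Cross-entropy between the student's predictions and the true labels
    ce = F.cross_entropy(student_logits, labels)

    # 2 - KL divergence between the softened student and teacher distributions.
    # F.kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    kl = kl * temperature ** 2  # common rescaling convention (an assumption here)

    return alpha * ce + (1.0 - alpha) * kl

# Random tensors standing in for a batch of 8 samples and 2 classes
torch.manual_seed(0)
student_logits = torch.randn(8, 2, requires_grad=True)
teacher_logits = torch.randn(8, 2)  # in practice: teacher(x) under torch.no_grad()
labels = torch.randint(0, 2, (8,))

loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()  # gradients flow only into the student's logits
print("distillation loss:", loss.item())
```

In a real training loop, the teacher would be in eval mode and its logits computed under torch.no_grad(), so that only the student's weights get updated.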


Why do you need to know this technique?

Knowledge distillation can help compress the knowledge of big models into small ones. Think about it this way: what if you could deploy a model that's 10 times smaller than some original model?

This opens up a lot of opportunities for edge deployment. For example, below I share how you can deploy PyTorch models on mobile phones.


Deploying PyTorch Models on Mobile Phones

Let's say you want to deploy a PyTorch model in a mobile app (on the edge and not in the cloud). How would you do that?
Well, you can use PyTorch Mobile. Here's how you'd go about it.
Train your model using PyTorch. Once that's done, you convert it to a format that can be used by PyTorch Mobile.
This is usually done with TorchScript, which lets you serialize your model so that it can be loaded in a non-Python environment.
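The conversion step can be sketched roughly as follows. TinyClassifier, its layer sizes, the input shape and the file name are placeholders made up for this example. Tracing with torch.jit.trace is one of two ways to produce TorchScript (the other is torch.jit.script, which also captures data-dependent control flow).

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Hypothetical stand-in for your trained (for example, distilled) model."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(4, 2),
        )

    def forward(self, x):
        return self.net(x)

model = TinyClassifier().eval()           # eval mode before exporting
example_input = torch.rand(1, 1, 64, 64)  # assumed input shape: one grayscale image

# Record the operations run on the example input as a TorchScript module
traced = torch.jit.trace(model, example_input)

# Serialize to a file that no longer needs the Python class definition
export_path = os.path.join(tempfile.mkdtemp(), "tiny_classifier.pt")
traced.save(export_path)

# Sanity check: reload the file and compare against the original model
reloaded = torch.jit.load(export_path)
with torch.no_grad():
    same = torch.allclose(model(example_input), reloaded(example_input))
print("outputs match after reload:", same)
```

For mobile targets specifically, the PyTorch Mobile tutorials add two more steps before shipping the file: passing the module through torch.utils.mobile_optimizer.optimize_for_mobile and saving it with _save_for_lite_interpreter. Check the documentation for your PyTorch version before relying on either.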
Then, to integrate your model inside your mobile app, you have several options depending on your target environment: iOS or Android.
If you're targeting iOS, you can use TorchModule.
If you're targeting Android, you can use org.pytorch.Module.
Both are provided by PyTorch.
You can also deploy your PyTorch model inside a Flutter app by writing custom platform-specific code. You can check this article on how to do that.

Why do you need to know this?

A lot of companies are developing deep learning models that need to run directly on mobile devices. They don't want to deploy these models in the cloud, usually for security and privacy reasons. In some cases, the goal is to keep the app working at all times, even without internet access.






That's it for this week's edition, I hope you enjoyed it!


Machine Learning for Medical Imaging

by Nour Islam Mokhtari from

