Transfer learning in Spark for image recognition

Transfer learning in Spark, demystified in a less-than-three-minute read

Businesses that need to classify a huge set of images in batch every day can do so by combining the parallel processing power of PySpark with the accuracy of models pretrained on massive image sets through transfer learning. Let’s first explain the two buzzwords used in the title.

Transfer learning

Transfer learning is a machine learning method where a model that was pretrained for a different task is reused for a new task. It is especially common in fields such as computer vision and natural language processing. An advantage of this approach is that it also works for data sets where you don’t have much data, since the starting point of your final model is a model that has already been trained on lots of data. Moreover, these base models have been optimized by experienced developers and may give the accuracy of your model a head start.

In this example, I will use transfer learning for image recognition. Three of the more popular base-model families for transfer learning are VGG (e.g. VGG19), GoogLeNet (e.g. InceptionV3) and Residual Networks (e.g. ResNet50), all of which are neural networks. I will not go into the details of these networks, but instead explain the basic principle.

The image below shows a schematic overview of transfer learning, with the different layers of the neural network depicted. The input of the model is an image, which passes through the successive layers of the network, and the output is a final label. In the original model, this could be a cat, dog or fish.

Now assume we have a data set with images of shoes, pants and t-shirts, and want to predict which is which. To do so, we reuse the model, but retrain only a specific part. The base (orange bars) of the model remains unaltered and is commonly called the ‘base’ or ‘backbone’ of the model. This part transforms the input images and extracts basic image features, such as colors and shapes. The retrained part of the model is the ‘head’ and corresponds to the last layers of the model, which map these basic features onto the right category.
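This freeze-the-base, replace-the-head principle can be sketched in a few lines of tensorflow.keras. This snippet is illustrative only and is not part of the Spark pipeline; `weights=None` is used here purely so the sketch runs without downloading the pretrained weights (in practice you would pass `weights="imagenet"`):

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Pretrained base ("backbone"): in practice use weights="imagenet";
# weights=None here only avoids the pretrained-weight download in this sketch.
base = tf.keras.applications.ResNet50(
    weights=None, include_top=False, pooling="avg", input_shape=(64, 64, 3)
)
base.trainable = False  # freeze the base: these layers are not retrained

# New head: a single dense layer mapping the extracted features
# onto our three categories (shoe / pants / t-shirt).
model = models.Sequential([
    base,
    layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```

Only the head’s weights are updated during training; the frozen base keeps the feature-extraction behavior it learned on its original task.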


Spark

Spark is a unified analytics engine built for big-data processing. It allows computations to be distributed over multiple nodes. Since feature extraction with neural networks can be quite computationally intensive, Spark is used here to run the first layers of the model in parallel on a large number of images.


The link to the code on GitHub can be found at the end of this blog. In this setup, I used Azure Databricks and mounted an ADLS Gen2 blob storage to the cluster; to do this, please follow this link. Using Azure Storage Explorer, you can then transfer the images of your training set onto this blob.

The first step is to read the images into Spark, which can be done using the command below. This command scans the entire folder for images and loads them in a binary format; the resulting dataframe has the columns path, modificationTime, length and content (the binary data). This happens in parallel, distributed over the Spark nodes.

In the next step, the base model for transfer learning is used to extract the base features from the images in this dataframe. The image below shows how a Spark user-defined function (UDF) applies the base model to the images in parallel, creating a new column ‘features’. A Spark UDF is a function that executes computational logic on the rows of a dataframe, in parallel over multiple Spark nodes. For the full code of this UDF, please see the GitHub repository linked at the end of this article.

Now, where does that leave us? At this point, we have a dataframe composed of classical numerical features, on which we can run classical machine learning techniques such as logistic regression, random forests, etc.

In this last step, we attach the labels to this dataframe based on the path or name of each image and create an MLlib pipeline. This pipeline contains an additional layer on top of the base model that predicts the target in our case. The code used for the MLlib pipeline is shown below. I created a VectorAssembler to put the features in the right format and a StringIndexer to convert the text labels to numerical labels. As the final model, I used a logistic regression. These steps are combined into a single Spark ML pipeline object (similar to a scikit-learn pipeline).

Alternative methods

There are multiple design choices to be made in this method:

  • The base model for transfer learning (ResNet, VGG, …) can be changed.
  • The final Spark pipeline can include multiple stages, or use a tree-based model instead of a logistic regression.
  • The hyperparameters of the model can be tuned with a grid search.

Feel free to experiment with this in order to achieve the best performance for your use case.


The method above shows how you can train a model in parallel on a massive image data set using transfer learning. This way of training is ideal when you need to classify a lot of images in batch every day. In case of questions, feel free to contact us and we will help you out!

GitHub link:

Tom Thevelein

This blog was written by Tom Thevelein. Tom is an experienced big-data architect and data scientist who still likes to get his hands dirty optimizing Spark (in any language), implementing data lake architectures and training algorithms.