Brajraj
5 min read · May 3, 2021


Ktrain: A journey to fine-tune a pre-trained deep learning model with Ktrain

Problem Statement:

Many pre-trained deep learning models are available, for example through Hugging Face, but fine-tuning them can be difficult without a solid understanding of deep learning and the mathematics behind it. Several pre-trained transformer-based models (BERT, ELECTRA, GPT-2, GPT-3, XLNet, etc.) can be fine-tuned for your downstream task.

Recently I had a business requirement to automate a helpdesk ticketing system. The dataset was in Turkish, a language I unfortunately do not know. I first tried a traditional machine learning approach but could not get accuracy above 78%. Because we had a large dataset, I then tried deep learning: a bidirectional LSTM with fastText word embeddings. Many words were missing from the word vectors and had to be handled explicitly, and the overall accuracy ended up close to what the traditional machine learning approach had given.

As we know, the architecture of a model (ANN, CNN, RNN, etc.) plays an important role in its overall performance. Finding a good architecture (number of hidden layers, number of neurons in each layer, etc.) can be challenging when you have a fixed deadline to deliver the solution or limited resources such as GPUs.

Approach:

Many pre-trained transformer-based models are available that only need fine-tuning on your dataset. I first tried Simple Transformers, which made it easy to fine-tune and test the model, but I ran into some difficulty saving the model and deploying it to production.

This is where I came across my savior, Ktrain. It is a wrapper for TensorFlow Keras that makes deep learning and AI more accessible and easier to apply. Saving, reloading and deploying a model is very easy with Ktrain, as we will see in detail in the solution section.

Solution:

The dataset consists of 4,900 news records in total, across different categories. Each record has two fields:

text: Text of the news

category: category of the news (siyaset, spor, dunya etc.)

1. Install ktrain using the command below

pip install ktrain

2. Install the Turkish stemmer using the command below

pip install TurkishStemmer

3. Import all required libraries

4. Define the dataset and the text cleaning function

5. Clean the dataset by applying this function to the text
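Steps 4 and 5 can be sketched roughly as follows. This is a minimal sketch, not the original code: the cleaning rules (lowercasing with Turkish dotted and dotless I handling, stripping URLs, digits and punctuation), the helper names `turkish_lower` and `clean_text`, and the sample rows are all assumptions. Stemming is optional and is delegated to any object with a `stem` method, such as the `TurkishStemmer` class from the package installed in step 2.

```python
import re

URL_PATTERN = re.compile(r"https?://\S+|www\.\S+")
NON_LETTER_PATTERN = re.compile(r"[^\w\s]|\d|_")


def turkish_lower(text):
    """Lowercase text using Turkish casing rules for dotted and dotless I."""
    return text.replace("İ", "i").replace("I", "ı").lower()


def clean_text(text, stemmer=None):
    """Return a normalized, single-spaced version of one document."""
    text = turkish_lower(str(text))
    text = URL_PATTERN.sub(" ", text)         # drop links
    text = NON_LETTER_PATTERN.sub(" ", text)  # drop punctuation, digits, underscores
    tokens = text.split()
    if stemmer is not None:
        # Any object exposing stem(word) works here, e.g. TurkishStemmer().
        tokens = [stemmer.stem(token) for token in tokens]
    return " ".join(tokens)


# Hypothetical sample rows shaped like the dataset (text, category).
records = [
    {"text": "Spor HABERİ: Maç 3-1 bitti! http://example.com/a", "category": "spor"},
    {"text": "  Dünya   gündemi...  ", "category": "dunya"},
]
cleaned = [{**row, "text": clean_text(row["text"])} for row in records]
print(cleaned[0]["text"])
```

If the data lives in a pandas DataFrame (here a hypothetical `df`), the same function could be applied with `df["text"] = df["text"].apply(clean_text)`.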

6. Split the dataset into train and test
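The split could look something like the sketch below. It is a dependency-free illustration of a stratified split, so that every category keeps roughly the same share in both parts; the function name `stratified_split`, the 20 percent test share, the seed and the toy data are assumptions, not values from the original code. scikit-learn's `train_test_split` with its `stratify` argument is a common library alternative.

```python
import random
from collections import defaultdict


def stratified_split(texts, labels, test_size=0.2, seed=42):
    """Split texts and labels so each category is represented in both parts."""
    indices_by_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_label[label].append(index)

    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    train_idx, test_idx = [], []
    for label in sorted(indices_by_label):
        indices = indices_by_label[label]
        rng.shuffle(indices)
        # Keep at least one test item per category (a one-item category
        # therefore ends up only in the test part).
        n_test = max(1, round(len(indices) * test_size))
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])

    rng.shuffle(train_idx)
    rng.shuffle(test_idx)
    x_train = [texts[i] for i in train_idx]
    x_test = [texts[i] for i in test_idx]
    y_train = [labels[i] for i in train_idx]
    y_test = [labels[i] for i in test_idx]
    return x_train, x_test, y_train, y_test


# Hypothetical toy data standing in for the cleaned news records.
labels = ["spor"] * 50 + ["siyaset"] * 30 + ["dunya"] * 20
texts = [f"haber {i}" for i in range(len(labels))]
x_train, x_test, y_train, y_test = stratified_split(texts, labels)
print(len(x_train), len(x_test))
```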

7. Load a pre-trained model from Hugging Face. Here I use an ELECTRA model that supports Turkish. You can change the MODEL_NAME parameter to support other languages as needed.
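A hedged sketch of this step using ktrain's `text.Transformer` wrapper is shown below. The checkpoint name is an assumption (a Turkish ELECTRA model published on the Hugging Face hub; the post does not say which one was used), and so are `maxlen`, `batch_size` and the helper names `encode_labels` and `build_learner`. The ktrain import is deferred into the function, so the label-encoding part runs even where ktrain is not installed.

```python
# Assumed checkpoint; swap in any Hugging Face model that fits your language.
MODEL_NAME = "dbmdz/electra-base-turkish-cased-discriminator"


def encode_labels(labels, class_names=None):
    """Map string categories to integer ids; returns (ids, class_names)."""
    if class_names is None:
        class_names = sorted(set(labels))
    index = {name: i for i, name in enumerate(class_names)}
    return [index[label] for label in labels], class_names


def build_learner(x_train, y_train, x_val, y_val, class_names,
                  model_name=MODEL_NAME, maxlen=128, batch_size=16):
    """Wrap a Hugging Face checkpoint in a ktrain learner (needs ktrain installed)."""
    import ktrain
    from ktrain import text

    preproc = text.Transformer(model_name, maxlen=maxlen, class_names=class_names)
    train_data = preproc.preprocess_train(x_train, y_train)  # tokenizes the texts
    val_data = preproc.preprocess_test(x_val, y_val)
    model = preproc.get_classifier()
    learner = ktrain.get_learner(model, train_data=train_data,
                                 val_data=val_data, batch_size=batch_size)
    return learner, preproc


# Label encoding on a few sample categories from the dataset description.
y_ids, class_names = encode_labels(["spor", "siyaset", "dunya", "spor"])
print(class_names, y_ids)
```

The returned `preproc` object is worth keeping around, since it is needed again when the model is wrapped for prediction.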

8. Fine-tune the model with a chosen learning rate and number of epochs.
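As a rough sketch, fine-tuning with ktrain's one-cycle policy could look like this. The learning rate and epoch count are assumed values, not the ones used for the results reported here, and `fine_tune` is a hypothetical helper that expects a learner created with `ktrain.get_learner`.

```python
LEARNING_RATE = 5e-5  # assumed; small rates are typical for transformer fine-tuning
EPOCHS = 3            # assumed


def fine_tune(learner, learning_rate=LEARNING_RATE, epochs=EPOCHS):
    """Train with ktrain's one-cycle schedule and return what fit_onecycle returns."""
    return learner.fit_onecycle(learning_rate, epochs)
```

A call such as `history = fine_tune(learner)` then runs the training loop and reports validation metrics after each epoch.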

I fine-tuned the model on only 25 percent of the dataset and still achieved 86 percent validation accuracy. That is a very good start.

9. Once you are satisfied with the model's performance, save the model for deployment.
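Saving can be sketched with ktrain's predictor object, which bundles the trained model with its preprocessing. The directory name and the helper `save_predictor` are hypothetical; `learner` and `preproc` are assumed to be the learner and the `text.Transformer` instance created earlier.

```python
PREDICTOR_PATH = "turkish_news_predictor"  # hypothetical output directory


def save_predictor(learner, preproc, path=PREDICTOR_PATH):
    """Bundle the fine-tuned model with its preprocessor and write both to disk."""
    import ktrain  # imported here so the module loads without ktrain installed

    predictor = ktrain.get_predictor(learner.model, preproc=preproc)
    predictor.save(path)
    return predictor
```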

10. Reload the tuned model and use it to predict a category along with a confidence score.
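Reloading and predicting might be sketched as below. The helper names are hypothetical; the sketch relies on `ktrain.load_predictor` and on the predictor's `predict(..., return_proba=True)` and `get_classes()` methods, and takes the highest class probability as the confidence score.

```python
def load_predictor(path):
    """Load a predictor previously written with predictor.save(path)."""
    import ktrain  # imported here so the module loads without ktrain installed

    return ktrain.load_predictor(path)


def predict_category(predictor, document):
    """Return (category, confidence) for a single document string."""
    probabilities = predictor.predict(document, return_proba=True)
    class_names = predictor.get_classes()
    best = max(range(len(class_names)), key=lambda i: probabilities[i])
    return class_names[best], float(probabilities[best])
```

A typical call would be `category, confidence = predict_category(load_predictor(path), some_text)`, where `path` is the directory the predictor was saved to.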

11. The classification report for the test data shows per-category F1 scores ranging from a minimum of 0.72 up to 0.97.
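To make the per-category F1 numbers concrete, here is a small dependency-free sketch of how they are computed from true and predicted labels, using F1 = 2TP / (2TP + FP + FN). The function name and the toy labels are assumptions. In practice, scikit-learn's `classification_report` produces the full table, and ktrain's `learner.validate` prints a similar report for the validation data.

```python
from collections import Counter


def per_class_f1(y_true, y_pred):
    """Compute the F1 score for every category seen in the labels or predictions."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted category gets a false positive
            fn[truth] += 1  # true category gets a false negative
    scores = {}
    for label in sorted(set(y_true) | set(y_pred)):
        denominator = 2 * tp[label] + fp[label] + fn[label]
        scores[label] = 2 * tp[label] / denominator if denominator else 0.0
    return scores


# Hypothetical predictions for four test documents.
y_true = ["spor", "spor", "siyaset", "dunya"]
y_pred = ["spor", "siyaset", "siyaset", "dunya"]
for label, score in per_class_f1(y_true, y_pred).items():
    print(f"{label}: {score:.2f}")
```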

You can get the complete code and dataset on my GitHub account.

Conclusion:

With Ktrain it is really a cakewalk to use a deep learning model and achieve good accuracy with a limited dataset. I fine-tuned this SOTA model on a CPU machine, which is helpful because getting a GPU can be difficult when you work in a customer environment.

I recommend this tool to everyone for downstream tasks, not only text classification but also question answering, other NLU problems, image classification, etc.

Finally, I would like to thank Arun S. Maiya, who developed this tool and made deep learning SOTA models easy to use.

Improvement:

We can improve model accuracy by fine-tuning on a larger share of the dataset and increasing the number of epochs. We can also try different models such as XLNet or DistilBERT. Smaller models such as DistilBERT may need less training time while still performing well.

References and Acknowledgements

Ktrain Github repo: https://github.com/amaiya/ktrain

Python package doc: https://pypi.org/project/ktrain/

Medium blog: https://medium.com/analytics-vidhya/finetuning-bert-using-ktrain-for-disaster-tweets-classification-18f64a50910b
