🔥 Training and testing a PyTorch model
Have you ever spent hours training a PyTorch model, only to realize you made a mistake? Or have you gotten decent results, but you're not sure whether that's because you built the model correctly or simply because deep learning is powerful enough that even a flawed architecture can produce acceptable outcomes? AI is not perfect, and this trade-off is worth keeping in mind when judging whether it suits a given problem. Still, AI has become increasingly practical and now helps us solve a wide range of complex problems that were previously considered almost impossible.
Training a PyTorch model can be time-consuming, so it is frustrating to discover a mistake in the code afterwards. It is also hard to know whether your model is truly effective or just lucky to be producing good results. To ensure your model will work well on new data, it is important to test it and compare the accuracy and loss to the values you saw at the end of training. A significant difference could mean the model is overfit to the training data and may not perform well on unseen data.
Giskard is a tool designed to address some of these challenges. It lets you quickly test your model for biases and errors. In this tutorial, we will walk through how to use Giskard, with code examples showing how to upload both a PyTorch model built from scratch and a fine-tuned pretrained model for analysis, so you can find edge cases and bugs.
⏪ Before loading your PyTorch model
You’ll need to have Git and Docker installed.
- To get started, run the following commands to install Giskard on your server. This sets up your Python backend.
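A minimal sketch of those install commands, assuming the docker-compose setup shipped in the Giskard-AI/giskard repository (the exact commands may differ between versions, so check the Giskard installation docs for your release):

```shell
# Assumption: Giskard's standard docker-compose distribution.
git clone https://github.com/Giskard-AI/giskard.git
cd giskard
docker compose up -d
```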
Once docker-compose starts all the modules, you'll be able to open Giskard at http://localhost:19000/.
- Log in to Giskard using the default credentials:
- Login: admin
- Password: admin
- Next, install ML Worker which is the component in Giskard that connects your Python environment to the Giskard server that you just installed. It executes the model in your working Python environment (notebook, Python IDE, etc). To start ML Worker as a daemon, execute the following command line:
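For reference, starting the worker looks roughly like this — assuming the giskard Python package's CLI (flag names vary between versions, so treat this as a sketch rather than the exact command):

```shell
pip install giskard          # install the package in your working Python environment
giskard worker start -d      # assumed flag: run the ML Worker as a daemon
```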
After you have installed the basics that you need, the next step is to upload your models.
🏃Load a PyTorch trained model and assess its performance with Giskard
If you built a model from scratch, chances are it contains bugs and edge cases, so testing it with Giskard comes in very handy. Below are the steps we will take to upload a PyTorch model that reads a news article and classifies its category.
Here is our newspaper classification model built from scratch.
Note: The code is fully commented with explanations.
1. Access the raw data iterators and prepare the data processing pipeline
2. Generate data batches and iterators
3. Define a PyTorch model
4. Instantiate the model
5. Define functions to train the PyTorch model and evaluate its results
6. Split the dataset and run the model
7. Evaluate the PyTorch model on the test dataset
8. Test on a random news article
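To make steps 3 and 4 concrete, here is a minimal sketch of the kind of model the tutorial builds: an EmbeddingBag-based text classifier in the style of the torchtext news classification tutorial. The vocabulary size, embedding dimension, and class count are placeholder values, not taken from the original code.

```python
import torch
from torch import nn

class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier (steps 3-4 above)."""

    def __init__(self, vocab_size: int, embed_dim: int, num_class: int):
        super().__init__()
        # EmbeddingBag averages the embeddings of all tokens in a text
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

# Instantiate the model with placeholder sizes (step 4)
model = TextClassificationModel(vocab_size=1000, embed_dim=64, num_class=4)

# A fake batch: token ids of two texts concatenated into one tensor,
# with offsets marking where each text starts
text = torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)
offsets = torch.tensor([0, 3], dtype=torch.long)
logits = model(text, offsets)
print(logits.shape)  # torch.Size([2, 4]) — one row of class scores per text
```

The same model object is what you would train with the functions from step 5 and later wrap in a prediction function for Giskard.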
The next step is to create a pandas DataFrame that contains the data for the final model you want to inspect and deploy to production.
After this, we create a Giskard project and upload our model.
You can find the full code here.
Breaking down the Giskard code:
- When the Giskard server is running, it serves the application on localhost; this is where the uploaded data will go.
- Next, you need to generate your API token in the Admin tab of the Giskard application.
- You can choose your own arguments for your_project = client.create_project("project_key", "PROJECT_NAME", "DESCRIPTION"). In our case it is newspaper = client.create_project("newspaper", "Classification of newspaper article", "Project to classify newspaper article").
Note: "project_key" must be unique and lowercase.
- After the project is created, we upload the model by specifying:
- model_type: either classification or regression. In our example it is a classification model.
- df: the pandas DataFrame that contains the data on which we want to inspect our final model.
- model_names: the name(s) under which the model will appear in Giskard.
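Putting the bullets above together, here is a hedged sketch of the upload step. The DataFrame and prediction function run as-is; the client calls are commented out because they require a running Giskard server and an API token, and the method names and signatures (GiskardClient, create_project, upload_model_and_df) are assumptions based on the v1-era Giskard Python client — check the docs for your version.

```python
import numpy as np
import pandas as pd

# DataFrame with the raw text and the true label column (names are
# illustrative, not taken from the original code)
df = pd.DataFrame({
    "Text": ["Stocks rallied on Wall Street today.",
             "The home team won in overtime."],
    "Target": ["Business", "Sports"],
})

classes = ["World", "Sports", "Business", "Sci/Tech"]

def prediction_function(input_df: pd.DataFrame) -> np.ndarray:
    # Placeholder: a real implementation would tokenize the text and
    # call the trained PyTorch model. Here we return uniform
    # probabilities just to show the expected (n_rows, n_classes) shape.
    n = len(input_df)
    return np.full((n, len(classes)), 1.0 / len(classes))

probs = prediction_function(df)
print(probs.shape)  # (2, 4)

# Server-dependent part (assumed v1-style API; needs server + token):
# from giskard import GiskardClient
# client = GiskardClient("http://localhost:19000", "YOUR_API_TOKEN")
# newspaper = client.create_project(
#     "newspaper", "Classification of newspaper article",
#     "Project to classify newspaper article")
# newspaper.upload_model_and_df(
#     prediction_function=prediction_function,
#     model_type="classification",
#     df=df,
#     target="Target",
#     classification_labels=classes,
#     model_name="newspaper_classifier",
# )
```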
After you have uploaded your model to Giskard, the image below shows what the final output looks like:
The URL http://localhost:19000 will take you to the Giskard interface.
You are all set to try Giskard in action! You can now create some tests to find edge cases and bugs.
📊 How to load PyTorch pretrained models into Giskard
Fine-tuning trains a pretrained model on a new dataset without training it from scratch. After training, you may want to check if the fine-tuned model gives accurate results. This can be done by uploading your fine-tuned model to Giskard for inspection to find edge cases and bugs.
For this example, we will look at how to upload an SST-2 binary text classification model based on XLM-RoBERTa to Giskard.
Note: The code is fully commented with explanations.
1. Transform the raw dataset using the non-batched API (i.e., apply the transformation line by line)
2. Define a prediction function for the PyTorch pretrained model
3. Verify the prediction_function
4. Return the PyTorch model
5. Upload your PyTorch pretrained model to Giskard
With the code below, we will upload our model to Giskard, taking note of the required arguments.
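Since the upload arguments mirror those from the previous section, the part specific to a pretrained model is the prediction function and its verification (steps 2-3 of the list above). A sketch, using a tiny stand-in module instead of the real XLM-RoBERTa model (which would be loaded from torchtext or Hugging Face):

```python
import numpy as np
import torch
from torch import nn

class FakeSentimentModel(nn.Module):
    """Stand-in for the fine-tuned XLM-RoBERTa classifier (2 classes)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)  # negative / positive

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

model = FakeSentimentModel()
model.eval()

def prediction_function(features: torch.Tensor) -> np.ndarray:
    # Wrap the model so it returns class probabilities, which is the
    # output shape a classification upload expects
    with torch.no_grad():
        logits = model(features)
        return torch.softmax(logits, dim=1).numpy()

# Verification step: every row of probabilities must sum to 1
batch = torch.randn(4, 8)
probs = prediction_function(batch)
assert probs.shape == (4, 2)
assert np.allclose(probs.sum(axis=1), 1.0)
```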
You can find the full code here.
Your dataset and model will be uploaded to Giskard and available at http://localhost:19000. Having uploaded the model, it is important to create tests before putting it into production.
With Giskard, you can inspect the model to find edge cases that might need testing. This makes it much easier for data scientists to see what the model is doing and track down any problems.
In this article, we demonstrated how simple it is to test PyTorch models, whether they are built from scratch or fine-tuned using a pretrained model. The goal of testing PyTorch production models is to make sure the model is successfully deployed and works well in production with other services. To show how these ideas work, we have given code snippets and specific examples of how they're used. We hope you find this helpful. Happy testing!