October 11, 2023

How to address Machine Learning Bias in a pre-trained Hugging Face text classification model

Machine learning models, despite their potential, often face issues like biases and performance inconsistencies. As these models find real-world applications, ensuring their robustness becomes paramount. This tutorial explores these challenges, using the Ecommerce Text Classification dataset as a case study. Through this, we highlight key measures and tools, such as Giskard, to boost model performance.

Mostafa Ibrahim

Overview of Machine Learning bias

Navigating the complexities of machine learning unveils inherent challenges, from biases affecting performance to concerns over robustness and prediction confidence. In the following sections, we'll dissect these intricacies and explore their root causes.

  1. Performance Bias

In machine learning, it's not just overall accuracy that counts, but also consistency across data subsets. Performance bias arises when a model excels in general but falters for specific subsets.
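To make performance bias concrete, here is a minimal sketch that compares a model's accuracy on one data slice with its global accuracy. The helper names (`accuracy`, `slice_gap`) and the toy predictions are hypothetical illustrations, not part of Giskard's API; Giskard finds such slices automatically.

```python
from typing import Callable, List, Sequence, Tuple


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predictions that match the true labels."""
    if not y_true:
        raise ValueError("cannot compute accuracy on an empty set")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def slice_gap(
    texts: List[str],
    y_true: List[int],
    y_pred: List[int],
    in_slice: Callable[[str], bool],
) -> Tuple[float, float]:
    """Return (global accuracy, accuracy on rows whose text satisfies in_slice)."""
    idx = [i for i, text in enumerate(texts) if in_slice(text)]
    global_acc = accuracy(y_true, y_pred)
    slice_acc = accuracy([y_true[i] for i in idx], [y_pred[i] for i in idx])
    return global_acc, slice_acc


# Toy data (hypothetical): 0=Electronics, 1=Household, 2=Books, 3=Clothing & Accessories.
texts = ["usb cable", "cotton fabric shirt", "cookbook", "fabric sofa cover", "hdmi adapter"]
y_true = [0, 3, 2, 1, 0]
y_pred = [0, 1, 2, 3, 0]

global_acc, fabric_acc = slice_gap(texts, y_true, y_pred, lambda t: "fabric" in t)
print(f"global={global_acc:.2f}, 'fabric' slice={fabric_acc:.2f}")  # global=0.60, 'fabric' slice=0.00
```

A model with a respectable global score can still fail completely on a slice, which is exactly what an aggregate metric hides.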

  2. Unrobustness

In machine learning, unrobustness means that small changes in input data can drastically alter a model's predictions. Such models are unreliable as they can be easily thrown off by minor data variations, making them unsuitable for real-world applications where data isn't always perfect. Essentially, unrobustness undermines a model's trustworthiness and consistency.

  3. Unethical Behavior in Machine Learning

Unethical biases in machine learning arise when models unintentionally favor or discriminate against certain groups based on attributes like race or gender. This is often due to biased training data. Such prejudices can have serious consequences, especially in areas like job hiring, financial services, and healthcare.

  4. Confidence Issues in Machine Learning

In machine learning, confidence in model predictions is pivotal. Underconfidence arises when a model predicts correctly but with low certainty, often because of insufficient training or ambiguous data. Conversely, overconfidence occurs when a model assigns high certainty to incorrect predictions, often as a result of biases or overfitting. Both hinder effective decision-making, so both must be addressed to make a model reliable.
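The two definitions above can be turned into a simple per-prediction check. This is a minimal sketch under stated assumptions: the thresholds `high=0.9` and `margin=0.1` are illustrative choices, not Giskard's actual detection rules, and `confidence_issue` is a hypothetical helper.

```python
from typing import Sequence


def confidence_issue(
    probs: Sequence[float], true_label: int, high: float = 0.9, margin: float = 0.1
) -> str:
    """Classify one prediction as 'overconfident', 'underconfident' or 'ok'.

    Assumed rules (illustrative thresholds, not Giskard's exact definitions):
    - overconfident: the top class is wrong but has probability >= high
    - underconfident: the top class is right but beats the runner-up by < margin
    """
    ranked = sorted(range(len(probs)), key=lambda k: probs[k], reverse=True)
    top, runner_up = ranked[0], ranked[1]
    if top != true_label and probs[top] >= high:
        return "overconfident"
    if top == true_label and probs[top] - probs[runner_up] < margin:
        return "underconfident"
    return "ok"


print(confidence_issue([0.95, 0.03, 0.01, 0.01], true_label=3))  # overconfident
print(confidence_issue([0.40, 0.35, 0.15, 0.10], true_label=0))  # underconfident
print(confidence_issue([0.85, 0.05, 0.05, 0.05], true_label=0))  # ok
```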

Fine-tuning Hugging Face models for enhanced Text Classification

Hugging Face hosts models pre-trained on vast amounts of text data. To use one of these models for text classification, we need to fine-tune it on our own dataset so that it learns to classify texts into the categories we've defined.

1. Load a Pre-trained Model: Use Hugging Face's Transformers library to load a model that works well for text classification, such as BERT or DistilBERT. These models are pre-trained to capture the context of text, which is crucial for this task.

2. Prepare Your Data: 

  • Tokenize Your Data: This is about converting your text into a model-readable format. Hugging Face offers tokenizers that do this job efficiently, ensuring your text is prepped properly for the model.
  • Divide Your Data: Split your dataset into training, validation, and test sets. The training set is used to train the model, the validation set to tune it and monitor training, and the test set to measure how well the final model performs.

3. Fine-tune the Model: 

  • Configure the Model: Make sure the model knows the categories it needs to classify texts into. It's like giving the model a list of labels to use.
  • Training: Start from the model's foundational knowledge (from its pre-training) and sharpen it on your specific dataset, so that it gets better at classifying text into your categories.

In the following sections, we'll walk through these steps and use Giskard, a tool for detecting model biases, to find and address the model's weaknesses.

Using Giskard to detect Machine Learning bias in a pre-trained Text Classification model

Biases in real-world data sets are intrinsic and cannot be ignored when training models to accurately represent the underlying trends. Fine-tuning a model involves tailoring its parameters and structure to better suit a specific task, including making adjustments to address biases in the training data for a more equitable model. 

A comprehensive understanding of the data and its biases, as well as the use of specialized techniques and careful evaluation, are essential to this process. Without fine-tuning, a model may retain biases from its pre-training or fail to adapt to the distinct nature of the task.

Next, we'll explore whether a standard pre-trained BERT model can perform text classification on the Ecommerce Text Classification dataset. The dataset has a Label field with four classes, "Electronics", "Household", "Books", and "Clothing & Accessories", and a Text field containing a brief product description. The goal is to assign each product to one of the four classes using only its description.

We'll walk you through a straightforward model and then use Giskard to identify its biases. Next, we'll fine-tune the model further so that it can handle the problematic cases Giskard uncovers. Without further ado, let's start by defining our semi-fine-tuned pre-trained model.

Semi-Fine-Tuned BERT Model for Text Classification

The following model has been fine-tuned on 1,000 data points for one epoch (a single full pass over the training data).

Step 1: Installing and Importing Libraries

Step 2: Setting Up the GPU

Step 3: Loading the Dataset

Step 4: Loading Tokenizer and Pre-Trained Model

Step 5: Defining the PredictDataset Class and the custom_predict Function

Giskard calls the wrapped model's prediction function on raw data, so preprocessing steps such as tokenizing, padding, and batching have to happen inside that function. The custom_predict function handles this: it tokenizes the raw product descriptions, pads and batches them, runs them through the model, and returns class probabilities. Keeping all preprocessing in one function ensures that our own evaluation and Giskard's scan see the data in exactly the same way, so the weaknesses Giskard reports belong to the model rather than to mismatched inputs.
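The tokenize, pad, batch, predict flow described above can be sketched without any external dependencies. This is a minimal illustration, not the tutorial's original code: `toy_tokenize` stands in for a Hugging Face tokenizer, `toy_model` stands in for BERT, and the ids `PAD_ID`/`UNK_ID` are hypothetical. In the real pipeline, the function would receive a pandas DataFrame from Giskard and read its text column.

```python
import math
from typing import Callable, Dict, List, Sequence

PAD_ID = 0  # hypothetical padding token id
UNK_ID = 1  # hypothetical id for out-of-vocabulary words


def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    """Stand-in for a Hugging Face tokenizer: map lowercase words to ids."""
    return [vocab.get(word, UNK_ID) for word in text.lower().split()]


def pad_batch(batch: List[List[int]], max_len: int) -> List[List[int]]:
    """Truncate to max_len, then right-pad each sequence to the longest in the batch."""
    batch = [seq[:max_len] for seq in batch]
    width = max(len(seq) for seq in batch)
    return [seq + [PAD_ID] * (width - len(seq)) for seq in batch]


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn raw scores into probabilities (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def custom_predict(
    texts: List[str],
    model: Callable[[List[List[int]]], List[List[float]]],
    vocab: Dict[str, int],
    batch_size: int = 8,
    max_len: int = 128,
) -> List[List[float]]:
    """Tokenize, pad and batch raw texts, run the model, return one probability row per text."""
    probs: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        ids = pad_batch([toy_tokenize(t, vocab) for t in chunk], max_len)
        # A real BERT call would also need an attention mask marking the padded positions.
        for logits in model(ids):
            probs.append(softmax(logits))
    return probs


def toy_model(batch_ids: List[List[int]]) -> List[List[float]]:
    """Hypothetical 4-class model: scores class 0 higher when token id 2 ("usb") appears."""
    return [[2.0 if 2 in seq else 0.0, 0.0, 0.0, 0.0] for seq in batch_ids]


vocab = {"usb": 2, "cable": 3, "novel": 4}
probs = custom_predict(["USB cable", "A long novel", "usb"], toy_model, vocab, batch_size=2)
print(len(probs), [round(p, 3) for p in probs[1]])  # 3 [0.25, 0.25, 0.25, 0.25]
```

Returning probabilities rather than hard labels matters here: the confidence-related checks in a scan need the full probability distribution for each input.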

Step 6: Assigning numerical values to categories in the 'Label' column 
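Converting the text labels into integers can be done with a simple mapping. The sketch below uses the encoding that appears later in the scan results, {'Electronics': 0, 'Household': 1, 'Books': 2, 'Clothing & Accessories': 3}; the sample rows are hypothetical.

```python
CATEGORIES = ["Electronics", "Household", "Books", "Clothing & Accessories"]

# label2id matches the encoding used in this tutorial's scan results.
label2id = {name: idx for idx, name in enumerate(CATEGORIES)}
id2label = {idx: name for name, idx in label2id.items()}

# Hypothetical rows in (Label, Text) form.
rows = [
    ("Books", "A gripping thriller novel"),
    ("Electronics", "USB-C charging cable"),
    ("Clothing & Accessories", "Men's cotton shirt"),
]
label_nums = [label2id[label] for label, _ in rows]
print(label_nums)  # [2, 0, 3]
```

With pandas, the same mapping could be applied as `df["LabelNum"] = df["Label"].map(label2id)`, assuming the numeric column is named LabelNum as in the scan results. Passing `id2label` and `label2id` to `AutoModelForSequenceClassification.from_pretrained` also lets the model report readable class names.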

Step 7: Splitting and Tokenizing the Dataset
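The splitting half of this step can be sketched in plain Python; tokenization is handled by the Hugging Face tokenizer. This is a minimal sketch, assuming 10% validation and 10% test fractions and a fixed seed; `shuffled_split` is a hypothetical helper (scikit-learn's `train_test_split`, which shuffles by default, does the same job). Shuffling matters: as the conclusion of this article describes, an unshuffled, class-sorted dataset can put a single class into the first 1,000 rows.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def shuffled_split(
    rows: Sequence[T], val_frac: float = 0.1, test_frac: float = 0.1, seed: int = 42
) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle a copy of rows, then cut it into train/validation/test lists."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # seeded so the split is reproducible
    n = len(rows)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = rows[:n_test]
    val = rows[n_test:n_test + n_val]
    train = rows[n_test + n_val:]
    return train, val, test


# A class-sorted toy dataset: without shuffling, the first 1,000 rows would all be label 0.
data = [(f"item {i}", i // 1000) for i in range(4000)]  # labels 0..3, 1,000 rows each
train, val, test = shuffled_split(data)
print(len(train), len(val), len(test))  # 3200 400 400
print(sorted({label for _, label in train[:1000]}))  # labels present in the first 1,000 training rows
```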

Step 8: Creating a Dataloader

Step 9: Fine-tuning our model

Step 10: Wrapping the Dataset Using Giskard

Before diving into model debugging and optimization, we first need to prepare our dataset for the Giskard scan. This step, called "wrapping," puts the dataset into the format Giskard expects. Pass Giskard the raw, unprocessed data as a pandas DataFrame rather than the tokenized version. For text classification, the DataFrame must contain both the raw text and the target output, in our case the "label." Also specify the columns in the dataset and their data types. Once wrapped, the dataset is ready for a thorough Giskard scan.

Step 11: Wrapping the Model Using Giskard with Custom Prediction Function

Step 12: Evaluating the Model’s performance

Output:

Validation Loss: 1.1822343534

Validation Accuracy: 0.4975

A validation accuracy of roughly 50 percent (0.4975) for the BERT model immediately raises an eyebrow. And while accuracy is a straightforward metric, it often oversimplifies a model's true performance, for a couple of reasons.

To begin with, the model might be reflecting the characteristics of an imbalanced dataset, suggesting a bias towards a specific class within our data. Furthermore, relying solely on these metrics restricts our ability to identify and address potential specific problems within the model's predictions and the dataset as a whole.

In short, the 50% accuracy offers only a glimpse. A more comprehensive evaluation, for example with a tool like Giskard, is essential to truly understand BERT's performance and where it can improve.
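One quick sanity check behind the imbalance concern is to compare the model against a majority-class baseline and to inspect recall per class. The sketch below uses hypothetical toy labels (it does not claim the Ecommerce dataset has this distribution); it shows how a degenerate model that always predicts one class can reach 50% accuracy on an imbalanced set.

```python
from collections import Counter
from typing import Dict, List, Tuple


def majority_baseline(labels: List[int]) -> Tuple[int, float]:
    """Most frequent class and the accuracy of always predicting it."""
    label, count = Counter(labels).most_common(1)[0]
    return label, count / len(labels)


def per_class_recall(y_true: List[int], y_pred: List[int]) -> Dict[int, float]:
    """For each class: correct predictions of that class / true examples of that class."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: hits[c] / totals[c] for c in sorted(totals)}


# Toy imbalanced validation set: half the examples are class 1 ("Household").
y_true = [1, 1, 1, 1, 0, 2, 3, 0]
y_pred = [1] * 8  # a degenerate model that always answers "Household"
print(majority_baseline(y_true))         # (1, 0.5)
print(per_class_recall(y_true, y_pred))  # {0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0}
```

If a trained model barely beats the majority baseline, or its per-class recall is lopsided like this, the aggregate accuracy is hiding a bias toward one class.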

Step 13: Scanning the Model Using Giskard, Displaying the Results, and Checking for Model Inconsistencies and Shortcomings

Let's employ Giskard to assess whether the semi-fine-tuned BERT model needs more fine-tuning and to pinpoint its specific challenges.

The images below highlight the model's tendencies regarding performance, robustness, overconfidence, and underconfidence.

Scan Results - Performance

We've spotted 12 performance issues. One notable issue concerns texts shorter than 517 characters: on these, the model's accuracy is slightly lower than its global accuracy.

Another issue involves texts containing words such as “men”, “comfortable”, “fabric”, and a couple more. As the warning section mentions, these issues may stem from too few examples, incorrect labels in the training dataset, or a discrepancy between the training and test datasets.

In this instance, training the model on more examples containing such words would likely improve its overall performance. In real-world use we wouldn't want the model to underperform on these inputs, so catching such weaknesses early is valuable: once we know about them, we can train the model to handle them.

Next, we'll discuss model robustness, where weaknesses often point to overfitting and a lack of diversity in the training data.

Scan Results - Robustness

Both issues are perturbation issues. To explain what that means, let's use a simple example.

Imagine you have a robot that's designed to answer questions. Most of the time, it gives you the correct answer. But sometimes, if you change the question just a little bit (this is called a "perturbation"), the robot gets confused and gives a different answer.

In our case, the model shows exactly this vulnerability. When the punctuation in a product description is removed or altered, the model's prediction changes for 8.42% of the inputs, which means slight punctuation variations can throw off its predictions. Similarly, introducing typos (typographical errors) changes the prediction for 12.33% of the inputs: a description the model handles correctly when spelled properly may be misclassified once it contains a typo.

It's essential to understand these nuances and refine the model accordingly. In real-world scenarios, variations in punctuation and unintentional typos are frequent, so addressing these challenges is paramount for optimal performance.

In essence, the model struggles with small surface-level variations in its input text and could benefit from further training or refinement to handle them more reliably.
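The idea behind these robustness checks can be sketched as a "flip rate": the share of inputs whose prediction changes after a perturbation. This is a minimal illustration only; Giskard's own perturbations are more varied, and `brittle_predict`, `add_typo`, and `flip_rate` are hypothetical helpers (the typo here is a single swap of adjacent characters).

```python
import random
import string
from typing import Callable, List


def remove_punctuation(text: str) -> str:
    """Delete every ASCII punctuation character."""
    return text.translate(str.maketrans("", "", string.punctuation))


def add_typo(text: str, rng: random.Random) -> str:
    """Swap two adjacent characters at a random position (one simple kind of typo)."""
    if len(text) < 2:
        return text
    i = rng.randrange(len(text) - 1)
    return text[:i] + text[i + 1] + text[i] + text[i + 2:]


def flip_rate(predict: Callable[[str], int], texts: List[str], perturb: Callable[[str], str]) -> float:
    """Share of inputs whose predicted label changes after the perturbation."""
    changed = sum(predict(t) != predict(perturb(t)) for t in texts)
    return changed / len(texts)


def brittle_predict(text: str) -> int:
    """Hypothetical classifier: Electronics (0) only if it sees the exact token 'usb-c'."""
    return 0 if "usb-c" in text.lower().split() else 1


texts = ["Fast USB-C charger", "Wooden spoon set", "USB-C hub, 4 ports", "Cotton towel"]
rng = random.Random(0)
print(flip_rate(brittle_predict, texts, remove_punctuation))  # 0.5
print(flip_rate(brittle_predict, texts, lambda t: add_typo(t, rng)))
```

A robust model should keep this rate close to zero, because neither perturbation changes what the product actually is.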

Scan Results - Overconfidence

Moving on, Giskard caught 4 overconfidence issues in our model. A prediction is overconfident when the model assigns it high confidence (e.g., a high probability) but the prediction turns out to be incorrect.

For instance, imagine a model predicts that an image of a cat is a dog with 95% confidence. That would be an overconfident incorrect prediction.

In our case, the model is overconfident on texts in which whitespace (such as the spaces between words) makes up between 14.8% and 16.2% of all characters. For texts in this range, the model strongly believes in its predictions even when they are incorrect.

Continued training on the dataset is likely to enhance the model's performance in addressing these issues.
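To see which texts fall into the flagged range, you can compute the whitespace ratio yourself. This sketch assumes the feature is simply whitespace characters divided by total characters and that the range is half-open; Giskard's exact feature definition may differ.

```python
def whitespace_ratio(text: str) -> float:
    """Fraction of characters in text that are whitespace (assumed feature definition)."""
    if not text:
        return 0.0
    return sum(ch.isspace() for ch in text) / len(text)


def in_flagged_range(text: str, low: float = 0.148, high: float = 0.162) -> bool:
    """True when the whitespace ratio falls in the range the scan flagged (assumed half-open)."""
    return low <= whitespace_ratio(text) < high


print(round(whitespace_ratio("abcd efgh ijk"), 4))  # 0.1538 (2 spaces / 13 characters)
print(in_flagged_range("abcd efgh ijk"), in_flagged_range("a b c"))  # True False
```

Filtering the validation set with a check like this gives you the problematic slice to inspect or to target with additional training examples.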

Scan Results - Underconfidence

Moving on to our last category, we identified 3 underconfidence issues, which can signal underfitting. One notable issue is the “LabelNum < 0.500” slice. Given our label encoding, {'Electronics': 0, 'Household': 1, 'Books': 2, 'Clothing & Accessories': 3}, this slice contains only the items labeled 'Electronics'. The model is unsure about these items: even when it gets them right, it's like it's saying, 'I think this might be Electronics, but I'm not confident.'

Improving Text Classification with Partially Fine-Tuned BERT

In the following section, we'll address the problems above by further fine-tuning our model. As with the previous model, we will continue training on the Ecommerce Text Classification dataset.

To refine the model, we take two additional steps: increasing the number of data points (from 1,000 to 2,000) and increasing the number of epochs (from one to three). An epoch is one full pass of the training data through the model. For example, with 1,000 data points and 2 epochs, the model sees every data point twice, so it processes 2,000 examples in total. Increasing the number of epochs generally improves the model's overall performance up to the point where it starts overfitting, which shows up as validation accuracy starting to decrease.
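Both ideas in this paragraph, the epoch arithmetic and stopping once validation accuracy stops improving, fit in a short sketch. The validation accuracies below are hypothetical, and the `patience` rule is one common early-stopping heuristic rather than the method used in this tutorial.

```python
from typing import Sequence


def samples_seen(n_examples: int, epochs: int) -> int:
    """Total training examples processed: each epoch is one full pass over the data."""
    return n_examples * epochs


def best_epoch(val_accuracies: Sequence[float], patience: int = 2) -> int:
    """Return the 1-based epoch with the highest validation accuracy, stopping early
    once accuracy has failed to improve for `patience` consecutive epochs."""
    best, best_idx, since_best = float("-inf"), 0, 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best:
            best, best_idx, since_best = acc, epoch, 0
        else:
            since_best += 1
            if since_best >= patience:
                break  # validation accuracy has stalled: likely overfitting
    return best_idx


print(samples_seen(1000, 2))  # 2000
print(samples_seen(2000, 3))  # 6000
print(best_epoch([0.70, 0.78, 0.80, 0.79, 0.77]))  # 3
```

In practice, Hugging Face's `Trainer` offers `EarlyStoppingCallback` for the same purpose, so this logic rarely needs to be written by hand.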

Evaluating Text Classification model performance

Here, we reuse the evaluation code from Step 12 of the semi-fine-tuned model.

Output:

Validation Loss: 0.9094364500045776

Validation Accuracy: 0.79625

This indicates that the model achieves an accuracy of roughly 80 percent, signifying a substantial enhancement in the model's performance.

Now let's examine how our model's performance has evolved, beginning with the evaluation of Performance issues.

Scan Results - Performance

When looking at performance, there's a notable improvement: the number of issues dropped from 9 to 5. By training the model on more data and for additional epochs, we've addressed the problems with words like “men”, “comfortable”, and “fabric”. However, two new words have surfaced that the model isn't handling well. It's also worth highlighting that the global percentage metrics have decreased, indicating the model is making fewer errors compared to the overall dataset.

Additionally, there's a reduction in the LabelNum issues, which refer to errors in predicting specific classes of items.

Scan Results - Robustness

In terms of robustness, both earlier issues have clearly improved: the share of predictions changed by punctuation perturbations fell from 8.42% to 5.95%, and the share changed by typos fell from 12.33% to 5.24%.

Scan Results - Overconfidence

There's a noticeable improvement in overconfidence issues, as their error rates have dropped significantly, most likely because the model has been trained on more examples.

Scan Results - Underconfidence

Finally, for underconfidence, two of the three problems have been resolved. The only remaining concern is the 'LabelNum' slice, and even that has improved significantly. We'll see whether further refinement can address it.

Achieving Optimal Text Classification with a Fully Fine-Tuned BERT Model

By further fine-tuning the model on 3,000 data points over 10 epochs, we addressed the earlier challenges. The scan now reports only 3 performance issues, and the evaluation gives:

Validation Loss: 0.3192953437194228

Validation Accuracy: 0.9125

Scan Results - Performance

Conclusion

Given that our dataset was already well-processed, we needed other strategies to further enhance performance. We quickly found that, to mitigate the biases Giskard highlighted, such as overconfidence and sensitivity to punctuation, increasing both the number of training points and the number of epochs was pivotal. Starting with 1,000 data points over one epoch, we scaled up incrementally, moving to 2,000 points over three epochs and culminating at 3,000 points over ten epochs.

Yet, as we delved deeper into the training process, an observation emerged: while the increase in epochs provided certain benefits, the real game-changer was the expansion in training data. This increment not only enriched our data variety but also steered clear of the redundancy associated with merely increasing epoch counts.

However, it wasn't all plain sailing. A curious challenge arose during our early attempts: the model's evaluation accuracy was nearly flawless, which looked promising on the surface. But a closer inspection with Giskard painted a different picture, unearthing inherent biases. We soon found the culprit: because we hadn't shuffled the dataset, our initial 1,000 data points all came from the same class.

Now consider a less-than-perfect dataset, unlike our well-curated one. Here, many other techniques come into play. For instance, Giskard's ability to pinpoint data points with high global variability can help signal whether certain points should be kept or dropped: if adding more diverse inputs does not improve the model's behavior, it may be worth reconsidering their inclusion. Moreover, adjusting hyperparameters such as the learning rate, batch size, number of epochs, dropout rate, weight decay, choice of optimizer (e.g., Adam or SGD), learning rate scheduler, activation functions, and the number and size of layers can significantly improve your model's overall performance.

Drawing our insights together, it's imperative to underscore a core realization: the mastery of machine learning models extends beyond merely opting for state-of-the-art architectures or acclaimed pre-trained models. It's deeply rooted in the finesse of fine-tuning, comprehensive calibration, and intimate data understanding. Navigating hyperparameters is not solely a technical endeavor but an intricate dance of art and science. Embracing the subtleties of this balancing act can mark the distinction between an average model and one that truly excels.
