ML/DL Model Overfitting and Solutions

Hello ML enthusiast!

Today's post is about "overfitting" (high variance) in both machine learning (ML) and deep learning (DL) models. So, let's jump right into it.

Overfitting can happen for several reasons. An ML model may be trained on features that are not actually correlated with the target. An artificial neural network (ANN) may be trained for more epochs than necessary or given more hidden layers than needed. An overfit model captures the noise in the training data instead of only the underlying pattern. As a result, it performs very well on the training dataset (e.g. 90% accuracy or above) but noticeably worse on the test dataset (e.g. 80% accuracy).

Let's look at an example. Suppose you are working with a "Housing" dataset, where your goal is to predict the house price from features such as garage, square footage, swimming pool, porch, etc.

You start with a simple linear regression model that predicts house prices from the square footage. However, the results aren't ideal: the model doesn't capture the complex relationships within the data. To improve the model, you try polynomial regression, fitting higher-degree polynomials. The model fits the training data much better, so you keep increasing the degree of the polynomial. Eventually you have a high-degree polynomial that fits almost every training point perfectly, and the training score is very high. Surprisingly, this high-degree polynomial performs very poorly on the test dataset: it fails to generalize and predicts house prices incorrectly.
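To see this effect numerically, here is a minimal sketch using NumPy. It uses synthetic data, not a real housing dataset. The "price" is assumed to grow linearly with a rescaled square-footage feature plus random noise, and polynomials of degree 1, 3 and 12 are fitted with `np.polyfit`. The data-generating rule, noise level, sample sizes and degrees are all assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def true_price(x):
    # Assumed data-generating rule: price (in thousands) grows linearly with size.
    return 100 + 300 * x

# Synthetic "square footage", rescaled to [0, 1] so high-degree fits stay stable.
x_train = np.sort(rng.uniform(0, 1, 15))
x_test = np.sort(rng.uniform(0, 1, 15))
y_train = true_price(x_train) + rng.normal(0, 20, x_train.size)
y_test = true_price(x_test) + rng.normal(0, 20, x_test.size)

def mse(y, y_hat):
    return float(np.mean((y - y_hat) ** 2))

results = {}
for degree in (1, 3, 12):
    coeffs = np.polyfit(x_train, y_train, degree)  # least-squares polynomial fit
    train_err = mse(y_train, np.polyval(coeffs, x_train))
    test_err = mse(y_test, np.polyval(coeffs, x_test))
    results[degree] = (train_err, test_err)
    print(f"degree {degree:2d}: train MSE {train_err:9.1f}   test MSE {test_err:9.1f}")
```

The degree-12 polynomial has 13 coefficients for only 15 training points, so it can bend to follow the noise. Its training error is therefore the lowest of the three, while its error on unseen points is typically much higher than its training error.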

So what's the solution?

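For machine learning models, use K-fold cross-validation. The dataset is split into K equal-sized subsets (folds). The model is trained on K-1 folds and validated on the remaining one. This is repeated K times, so every fold serves as the validation set exactly once. Because the entire dataset is used for both training and validation, the averaged score gives a more reliable estimate of how well the model generalizes. That helps you spot overfitting and choose settings such as the polynomial degree.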
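The K-fold procedure can be sketched in plain Python as follows. The helper names (`k_fold_splits`, `cross_validate`, `fit_mean`, `mse_score`) are hypothetical, chosen for this illustration. In practice, scikit-learn's `KFold` and `cross_val_score` in `sklearn.model_selection` do this for you. The toy model always predicts the training mean; it only stands in for a real regressor.

```python
import random

def k_fold_splits(n_samples, k, seed=0):
    """Yield (train_indices, val_indices) for each of the k folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # shuffle once before splitting
    # Spread samples as evenly as possible: the first n % k folds get one extra.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, val
        start += size

def cross_validate(fit, score, X, y, k=5):
    """Train on k-1 folds, score on the held-out fold, and average the k scores."""
    scores = []
    for train, val in k_fold_splits(len(X), k):
        model = fit([X[i] for i in train], [y[i] for i in train])
        scores.append(score(model, [X[i] for i in val], [y[i] for i in val]))
    return sum(scores) / len(scores)

# Toy usage (hypothetical model): always predict the training mean, score by MSE.
def fit_mean(X, y):
    return sum(y) / len(y)

def mse_score(model, X, y):
    return sum((t - model) ** 2 for t in y) / len(y)

X = list(range(10))
y = [2.0 * x for x in X]
print(cross_validate(fit_mean, mse_score, X, y, k=5))
```

To compare models, you would call `cross_validate` once per candidate (for example, once per polynomial degree) and keep the one with the best average validation score.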

Another technique is to use an ensemble model such as a random forest. A random forest combines many decision trees and aggregates their predictions: an average for regression, a majority vote for classification. Each tree in the forest is trained on a random subset of the data (a bootstrap sample drawn with replacement) and considers only a random subset of the features. This randomness encourages each tree to learn different aspects of the data, reducing the likelihood of overfitting to specific patterns.
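The idea of bootstrap samples, random feature subsets and averaging can be sketched with a toy "forest" of depth-one trees (stumps) in plain Python. This is a simplified illustration, not how a full random forest is implemented. Real forests grow deep trees and draw a new feature subset at every split; with a single split per stump, drawing one subset per tree is equivalent. All function names and the tiny dataset are hypothetical. For real use, scikit-learn provides `RandomForestRegressor` and `RandomForestClassifier` in `sklearn.ensemble`.

```python
import random

def fit_stump(X, y, feature_ids):
    """Fit a depth-1 regression tree using only the given features.

    Tries every feature/threshold pair, keeps the split with the lowest squared
    error, and makes each leaf predict the mean target of its samples.
    """
    best = None
    for f in feature_ids:
        for threshold in sorted(set(row[f] for row in X)):
            left = [t for row, t in zip(X, y) if row[f] <= threshold]
            right = [t for row, t in zip(X, y) if row[f] > threshold]
            if not left or not right:
                continue
            left_mean = sum(left) / len(left)
            right_mean = sum(right) / len(right)
            sse = (sum((t - left_mean) ** 2 for t in left)
                   + sum((t - right_mean) ** 2 for t in right))
            if best is None or sse < best[0]:
                best = (sse, f, threshold, left_mean, right_mean)
    if best is None:  # no valid split (feature constant in this sample): predict the mean
        mean = sum(y) / len(y)
        return lambda row: mean
    _, f, threshold, left_mean, right_mean = best
    return lambda row: left_mean if row[f] <= threshold else right_mean

def fit_forest(X, y, n_trees=25, max_features=1, seed=0):
    rng = random.Random(seed)
    n, n_features = len(X), len(X[0])
    trees = []
    for _ in range(n_trees):
        rows = [rng.randrange(n) for _ in range(n)]               # bootstrap sample
        features = rng.sample(range(n_features), max_features)    # random feature subset
        trees.append(fit_stump([X[i] for i in rows], [y[i] for i in rows], features))
    return trees

def predict_forest(trees, row):
    # Regression: average the individual trees' predictions.
    return sum(tree(row) for tree in trees) / len(trees)

# Hypothetical toy data: [square_feet / 1000, has_pool] -> price in thousands.
X = [[0.8, 0], [1.0, 0], [1.2, 1], [1.5, 0], [2.0, 1], [2.4, 1]]
y = [150, 180, 230, 250, 330, 380]
forest = fit_forest(X, y, n_trees=50, max_features=1)
print(round(predict_forest(forest, [1.8, 1]), 1))
```

Each stump alone is a crude model fitted to a different resample. Averaging 50 of them smooths out the quirks that any single tree picked up from its particular sample.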

L1 (Lasso), L2 (Ridge) and Elastic Net (L1 + L2) are regularisation techniques that add a penalty term to the cost function. The L1 penalty (the sum of the absolute values of the coefficients) can drive some coefficients to exactly 0. The L2 penalty (the sum of the squared coefficients) shrinks large coefficients towards 0. Either way, the model's complexity is reduced.
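Here is a small NumPy sketch of both penalties on synthetic data where only the first two of five features actually affect the target (an assumption of this example). Ridge uses its closed-form solution, $w = (X^\top X + \lambda I)^{-1} X^\top y$, assuming centred data and no intercept. Lasso is solved by coordinate descent with soft-thresholding on the objective $\frac{1}{2n}\lVert y - Xw \rVert^2 + \alpha \lVert w \rVert_1$; this is the same scaling scikit-learn's `Lasso` documents. For real work, `Ridge`, `Lasso` and `ElasticNet` in `sklearn.linear_model` are the usual choice; the functions below are illustrative only.

```python
import numpy as np

def ridge(X, y, lam):
    """L2 / Ridge: minimise ||y - Xw||^2 + lam * ||w||^2 (closed form, no intercept)."""
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(n_features), X.T @ y)

def soft_threshold(rho, alpha):
    # Shrinks rho towards 0 by alpha, and sets it to exactly 0 if |rho| <= alpha.
    return np.sign(rho) * max(abs(rho) - alpha, 0.0)

def lasso(X, y, alpha, n_iter=200):
    """L1 / Lasso: minimise (1/2n)||y - Xw||^2 + alpha * ||w||_1 by coordinate descent."""
    n, n_features = X.shape
    w = np.zeros(n_features)
    for _ in range(n_iter):
        for j in range(n_features):
            # Residual with feature j's current contribution added back in.
            residual = y - X @ w + X[:, j] * w[j]
            rho = X[:, j] @ residual / n
            z = X[:, j] @ X[:, j] / n
            w[j] = soft_threshold(rho, alpha) / z
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
true_w = np.array([3.0, -2.0, 0.0, 0.0, 0.0])  # only two features actually matter
y = X @ true_w + rng.normal(0, 0.5, 100)

print("OLS  :", np.round(ridge(X, y, 0.0), 2))   # lam = 0 is ordinary least squares
print("Ridge:", np.round(ridge(X, y, 50.0), 2))  # all coefficients shrink
print("Lasso:", np.round(lasso(X, y, 0.5), 2))   # irrelevant ones become exactly 0
```

The printout shows the difference described above. Ridge pulls every coefficient towards zero without eliminating any, while Lasso sets the coefficients of the irrelevant features to exactly zero. Elastic Net simply adds both penalty terms.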

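There are also other processes such as standardisation, normalisation and principal component analysis (PCA). Standardisation rescales each feature to zero mean and unit variance, while normalisation rescales features to a fixed range such as [0, 1]. PCA is a dimensionality-reduction technique. Metaphorically, it takes a "screenshot" of the data from the directions in which the data is most spread out (has the most variance), and uses those new features, the principal components, for prediction. (For details, see my other article on PCA.)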
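A minimal NumPy sketch of standardisation followed by PCA via the singular value decomposition is shown below. The three synthetic features (size, rooms, age) are assumptions. "Rooms" is deliberately generated from "size", so two columns carry largely the same information and PCA can compress them. In practice, scikit-learn offers `StandardScaler` and `MinMaxScaler` in `sklearn.preprocessing` and `PCA` in `sklearn.decomposition`.

```python
import numpy as np

def standardise(X):
    """Rescale each column to zero mean and unit variance."""
    return (X - X.mean(axis=0)) / X.std(axis=0)

def pca(X, n_components):
    """Project standardised data onto its top principal components via SVD."""
    X_std = standardise(X)
    # Rows of Vt are the principal directions, ordered by decreasing variance.
    _, s, Vt = np.linalg.svd(X_std, full_matrices=False)
    explained_variance_ratio = s**2 / np.sum(s**2)
    components = Vt[:n_components]
    return X_std @ components.T, explained_variance_ratio[:n_components]

rng = np.random.default_rng(0)
size = rng.normal(1500, 300, 200)             # square feet (synthetic)
rooms = size / 400 + rng.normal(0, 0.3, 200)  # strongly correlated with size
age = rng.normal(30, 10, 200)                 # roughly independent of both
X = np.column_stack([size, rooms, age])

Z, ratio = pca(X, n_components=2)
print(Z.shape, np.round(ratio, 2))
```

Here three correlated features are reduced to two uncorrelated components. The explained-variance ratios show how much of the original spread each component keeps.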

For ANNs, there are techniques such as dropout, early stopping, batch normalisation, data augmentation, etc.

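Dropout randomly deactivates neurons (sets their outputs to zero) during training, with a fresh random selection at each iteration. Each neuron is dropped with some probability p, say 0.5. On every training step roughly half of the neurons in that layer are switched off, so the network cannot rely too heavily on any single neuron. At test time dropout is turned off and all neurons are used. To keep the expected output the same, one of two scalings is applied: activations are multiplied by the keep probability 1 - p (here 0.5) at test time, or, in the common "inverted dropout" variant, they are scaled up by 1/(1 - p) during training.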
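Here is a minimal NumPy sketch of inverted dropout. It is the variant that PyTorch's `torch.nn.Dropout` and Keras's `Dropout` layer use, although both frameworks handle this inside the layer for you. The function name and its `training` flag are hypothetical, written just for this illustration.

```python
import numpy as np

def dropout(activations, p=0.5, training=True, rng=None):
    """Inverted dropout: zero each unit with probability p during training.

    Surviving activations are scaled by 1 / (1 - p) so their expected value is
    unchanged; at inference (training=False) the input is returned as-is.
    """
    if not training or p == 0.0:
        return activations
    rng = np.random.default_rng() if rng is None else rng
    keep_mask = rng.random(activations.shape) >= p  # True means the neuron is kept
    return activations * keep_mask / (1.0 - p)

rng = np.random.default_rng(0)
h = np.ones((4, 8))                       # a batch of 4 examples, 8 hidden units
out = dropout(h, p=0.5, rng=rng)
print(out)                                # every entry is either 0.0 or 2.0
print(dropout(h, p=0.5, training=False))  # unchanged at inference
```

Because a new mask is drawn on every call, each training step effectively trains a different "thinned" network. Nothing needs to be rescaled at inference time.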

With early stopping, training is halted early if the performance metric on the validation set fails to improve for a certain number of consecutive epochs, often called the "patience". For example, the validation loss may stop decreasing, or the validation accuracy may start to drop.
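The stopping rule can be sketched in plain Python on a list of per-epoch validation losses. The function name and the loss values are hypothetical. In Keras, the built-in equivalent is the `EarlyStopping` callback (`monitor="val_loss"`, `patience`, `min_delta`, `restore_best_weights`).

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    Training stops once the loss has failed to improve on the best value seen so
    far by more than min_delta for `patience` consecutive epochs. If that never
    happens, training runs to the last epoch.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch

# Hypothetical validation losses: improving at first, then rising (overfitting).
losses = [0.90, 0.70, 0.55, 0.50, 0.52, 0.51, 0.56, 0.60, 0.65]
print(early_stopping_epoch(losses, patience=3))  # (6, 3)
```

Training stops at epoch 6, three epochs after the best validation loss at epoch 3. In a real training loop you would usually also restore the weights saved at that best epoch.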

All of the techniques mentioned above are big topics that I can't describe fully here, but I've tried to give an overview that I hope will help. Thanks for reading this far!