Feed: R-bloggers.
Author: msuzen.
A claim one often encounters: applying cross-validation prevents overfitting, and good out-of-sample performance, i.e. low generalisation error on unseen data, indicates that the model is not overfit.
Aim In this post, we will give an intuition for why model validation, i.e. approximating the generalisation error of a model fit, and detection of overfitting cannot be resolved simultaneously on a single model. After a short conceptual introduction, we will work through a concrete example workflow covering overfitting, overtraining and a typical final model-building stage. We restrict the post to regression and cross-validation and avoid Bayesian interpretations and regularisation: regularisation has different ramifications due to its mathematical properties, and prior distributions have different implications in Bayesian statistics. We assume an introductory background in machine learning, so this is not a beginners' tutorial.
A recent question from Andrew Gelman, a Bayesian guru, regarding What is overfitting? was one of the reasons this post came about, together with my frustration at seeing practitioners remain muddled about the meaning of overfitting, and at recently published data-science articles, and even some academic papers, continuing to circulate the above statement.
What do we need to satisfy in supervised learning? One of the most basic tasks in mathematics is to find a solution to a function. If we restrict ourselves to real numbers in $n$ dimensions, our domain of interest is $\mathbb{R}^{n}$. Now imagine a set of $p$ points $x_{i}$ living in this domain which form a dataset; this is actually a partial solution to a function. The main purpose of modelling is to find an explanation of the dataset, meaning that we need to determine $m$ parameters, $a \in \mathbb{R}^{m}$, which are unknown. (Note that a non-parametric model does not mean a model with no parameters.) Mathematically speaking, this manifests as a function, as we said before: $f(x, a)$. This modelling is usually called regression, interpolation or supervised learning, depending on the literature you are reading. It is a form of inverse problem: we do not know the parameters, but we have partial information about the variables. The main issue here is ill-posedness, meaning that the solution is not well-posed. Omitting axiomatic technical details, the practical problem is that we can find many functions $f(x, a)$, or models, explaining the dataset. So we seek the following two properties to be satisfied by our model solution, $f(x, a)=0$.
1. Generalized: A model should not depend on the dataset. This step is called model validation.
2. Minimally complex: A model should be as simple as possible, following the principle of parsimony. This step is called model selection.
Figure 1: A workflow for model validation and selection in supervised learning.
Generalization of a model can be measured by goodness-of-fit, which essentially tells us how well our model (the chosen function) explains the dataset. Finding a minimally complex model, on the other hand, requires comparison against another model.
Up to now, we have not named a technique for checking whether a model is generalized and for selecting it as the best model. Unfortunately, there is no unique way of doing both; that is the task of the data scientist or quantitative practitioner, and it requires human judgement.
Overfitting of models is widely recognized as a concern. It is less recognized, however, that overfitting is not an absolute notion but involves a comparison: a model overfits if it is more complex than another model that fits equally well.
The important point here is what we mean by a complex model, or how we can quantify model complexity. Unfortunately, again, there is no unique way of doing this. One of the most common approaches is to say that a model with more parameters is more complex. But this, again, is a bit of a meme and not generally true, and one could resort to different measures of complexity. For example, by this definition $f_{1}(a,x)=ax$ and $f_{2}(a,x)=ax^2$ have the same complexity, since they have the same number of free parameters, but intuitively $f_{2}$ is more complex, because it is nonlinear. There are many information-theoretic measures of complexity, but their discussion is beyond the scope of this post. For demonstration purposes, we will treat a model with more parameters and a higher degree of nonlinearity as more complex.
Figure 2: Simulated data and the non-stochastic part of the data.
Hands-on example We have covered, intuitively, the reasons why we cannot resolve model validation and judge overfitting simultaneously. We will now try to demonstrate this with a simple dataset and simple models that nevertheless capture the above premise.
A usual procedure is to generate a synthetic, or simulated, dataset from a model, as a gold standard, and use this dataset to build other models. Let's use the following functional form, from the classic text of Bishop, but with added Gaussian noise: $$f(x) = \sin(2\pi x) + \mathcal{N}(0, 0.1).$$ We generate a large enough set, 100 points, to avoid the sample-size issue discussed in Bishop's book; see Figure 2. Let's decide on two models we would like to apply to this dataset in a supervised learning task. Note that we won't be discussing the Bayesian interpretation here, so the equivalence of these models under a strong prior assumption is not an issue, as we are using this example only for ease of demonstrating the concept. A polynomial model of degree $3$ and one of degree $6$, which we call $g(x)$ and $h(x)$ respectively, are used to learn from the simulated data: $$g(x) = a_{0} + a_{1} x + a_{2} x^{2} + a_{3} x^{3}$$ and $$h(x) = b_{0} + b_{1} x + b_{2} x^{2} + b_{3} x^{3} + b_{4} x^{4} + b_{5} x^{5} + b_{6} x^{6}.$$
Figure 3: Overtraining occurs after around 40 percent of the data usage for g(x).
Overtraining is not overfitting Overtraining means that model performance degrades, while learning the model parameters, with respect to some objective variable that affects how the model is built; for example, the objective variable can be the training-data size or the iteration cycle in a neural network. This is more prevalent in neural networks (see Dayhoff 2011). In our practical example, it will manifest in the hold-out method used to measure RMSD when modelling with g(x): in other words, finding an optimal number of data points to use for training so that the model performs better on unseen data. See Figures 3 and 4.
Overfitting with low validation error We can also estimate the 10-fold cross-validation error, CV-RMSD. For this sampling, g and h have a CV-RMSD of 0.13 and 0.12, respectively. So we have a situation in which the more complex model reaches similar predictive power under cross-validation, and we cannot detect this overfitting by looking at the CV-RMSD value alone or at the 'overtraining' curve of Figure 4 alone. We need two models to compare, hence both Figures 3 and 4, together with both CV-RMSD values. One might argue that for small datasets we could tell the difference by looking at the gap between test and training errors; this is exactly how Bishop explains overfitting, where he points out overtraining in small datasets.
Which trained model to deploy? We have now empirically found the best-performing model with minimal complexity. All well, but which trained model should we use in production?
Actually, we have already built the model during model selection. In the above case, since we got similar predictive power from g and h, we will obviously use g, trained at the splitting sweet spot from Figure 3.
Figure 4: Overtraining occurs after around 30 percent of the data usage for h(x).
Conclusion The essential message here is that validation performance alone cannot guarantee the detection of an overfitted model: a model can overfit and still show low validation error, as we have seen from the examples using synthetic data in one dimension. Overtraining is actually what most practitioners mean when they use the term overfitting.
Outlook As more and more people use techniques from machine learning or inverse problems, both in academia and industry, some key technical concepts drift and take on different definitions and meanings for different people, because people often learn these concepts not from reading the literature carefully but verbally from their line managers or senior colleagues. This creates memes which are actually wrong, or which at least create a lot of confusion in the jargon. It is very important for all of us as practitioners to question every technical concept, to seek its origins in the published scientific literature, and not to rely entirely on verbal explanations from experienced colleagues. We should also strongly avoid ridiculing questions from colleagues, even when they sound too simple; at the end of the day we never stop learning, and naive-looking questions can have very important consequences for the fundamentals of the field.
Figure 5: Deployed models h and g on the testing set, together with the original data.
Appendix: Reproducing the example using R
The code used to produce the synthetic data, the modelling steps and the visualisations can be found in the github [repo]. In this appendix we walk through the workflow step by step with commented R snippets; the visualisation code is omitted and is available in the github repository.
R (GNU S) provides a very powerful formula interface; it is probably the most advanced and expressive formula interface in statistical computing, along with S itself of course.
The two polynomials above can be expressed both as R formulas and as plain functions that we can evaluate, as in the sketch below.
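The following is a minimal illustrative sketch rather than the exact repository code; the names g_formula, h_formula, g_fun and h_fun are ours.

```r
# Illustrative sketch: the two polynomials as R formulas ...
g_formula <- y ~ x + I(x^2) + I(x^3)                             # degree 3, g(x)
h_formula <- y ~ x + I(x^2) + I(x^3) + I(x^4) + I(x^5) + I(x^6)  # degree 6, h(x)

# ... and as plain functions of x and a coefficient vector
g_fun <- function(x, a) a[1] + a[2] * x + a[3] * x^2 + a[4] * x^3
h_fun <- function(x, b) b[1] + b[2] * x + b[3] * x^2 + b[4] * x^3 +
                        b[5] * x^4 + b[6] * x^5 + b[7] * x^6
```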
Learning from the data is achieved with the lm function in R.
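A hypothetical wrapper, assuming a data frame with columns x and y; fit_poly is our name, not necessarily the repository's.

```r
# Hypothetical helper: ordinary least-squares fit of one of the polynomial
# formulas above to a data frame with columns x and y.
fit_poly <- function(df, formula) {
  lm(formula, data = df)
}
# e.g. fit_g <- fit_poly(dat, g_formula), once `dat` is generated below
```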
The resulting fitted model can then be applied to a new dataset with the following helper functions, using RMSD as the performance metric.
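Again an illustrative sketch; predict_poly and rmsd are our helper names.

```r
# Apply a fitted model to new data, and measure the root-mean-square deviation
predict_poly <- function(fit, newdata) {
  predict(fit, newdata = newdata)
}
rmsd <- function(y, y_hat) {
  sqrt(mean((y - y_hat)^2))
}
```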
We can generate the simulated data discussed above by using runif, as sketched below.
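A sketch of the data generation, assuming the noise term N(0, 0.1) denotes a standard deviation of 0.1; the seed is arbitrary.

```r
# Simulated data: 100 points, x uniform on [0, 1],
# y = sin(2*pi*x) plus Gaussian noise (sd assumed to be 0.1).
set.seed(424242)   # arbitrary seed, for reproducibility only
n   <- 100
x   <- runif(n)
y   <- sin(2 * pi * x) + rnorm(n, mean = 0, sd = 0.1)
dat <- data.frame(x = x, y = y)
```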
To detect overtraining, we can split the data at different points with increasing training size, and measure the performance both on the training data itself and on the unseen test data, as in the sketch below.
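A sketch of the overtraining experiment for g(x), using the helpers and data defined above; the grid of training fractions is our choice.

```r
# Increasing training fraction: RMSD on the training part vs. the unseen rest
fractions <- seq(0.1, 0.9, by = 0.05)
overtraining_g <- t(sapply(fractions, function(p) {
  n_train <- floor(p * nrow(dat))
  train   <- dat[seq_len(n_train), ]    # first p percent as training data
  test    <- dat[-seq_len(n_train), ]   # remainder as unseen test data
  fit     <- fit_poly(train, g_formula)
  c(fraction   = p,
    train_rmsd = rmsd(train$y, predict_poly(fit, train)),
    test_rmsd  = rmsd(test$y,  predict_poly(fit, test)))
}))
# The same experiment for h(x) uses h_formula instead of g_formula
```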
And the last portion of the code performs 10-fold cross-validation; a sketch follows.
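A hand-rolled 10-fold cross-validation sketch; cv_rmsd is our name, and the repository may implement this step differently.

```r
# CV for g(x) and h(x): k-fold cross-validated RMSD
cv_rmsd <- function(df, formula, k = 10) {
  folds <- sample(rep(seq_len(k), length.out = nrow(df)))   # random fold labels
  errs  <- sapply(seq_len(k), function(i) {
    fit <- fit_poly(df[folds != i, ], formula)               # train on k-1 folds
    rmsd(df$y[folds == i], predict_poly(fit, df[folds == i, ]))  # test on fold i
  })
  mean(errs)
}
cv_rmsd(dat, g_formula)   # g(x), degree 3
cv_rmsd(dat, h_formula)   # h(x), degree 6
```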