Regularization of linear regression model#

In this notebook, we will see the limitations of linear regression models and the advantage of using regularized models instead.

We will also present the preprocessing required when dealing with regularized models, especially when the regularization parameter needs to be tuned.

We will start by highlighting the over-fitting issue that can arise with a simple linear regression model.

Effect of regularization#

We will first load the California housing dataset.

Note

If you want a deeper overview regarding this dataset, you can refer to the Appendix - Datasets description section at the end of this MOOC.

from sklearn.datasets import fetch_california_housing

data, target = fetch_california_housing(as_frame=True, return_X_y=True)
target *= 100  # rescale the target in k$
data.head()
MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude Longitude
0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 -122.23
1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 -122.22
2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 -122.24
3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 -122.25
4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 -122.25

In one of the previous notebooks, we showed that linear models can be used even in settings where the data and the target are not linearly related.

We showed that one can use the PolynomialFeatures transformer to create additional features encoding non-linear interactions between features.

Here, we will use this transformer to augment the feature space and then train a linear regression model. We will use cross-validation to evaluate the generalization capabilities of the model on out-of-sample test sets.

from sklearn.model_selection import cross_validate
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

linear_regression = make_pipeline(
    PolynomialFeatures(degree=2), LinearRegression()
)
cv_results = cross_validate(
    linear_regression,
    data,
    target,
    cv=10,
    scoring="neg_mean_squared_error",
    return_train_score=True,
    return_estimator=True,
)

We can compare the mean squared error on the training and testing sets to assess the generalization performance of our model.

train_error = -cv_results["train_score"]
print(
    "Mean squared error of linear regression model on the train set:\n"
    f"{train_error.mean():.3f} ± {train_error.std():.3f}"
)
Mean squared error of linear regression model on the train set:
4190.212 ± 151.123
test_error = -cv_results["test_score"]
print(
    "Mean squared error of linear regression model on the test set:\n"
    f"{test_error.mean():.3f} ± {test_error.std():.3f}"
)
Mean squared error of linear regression model on the test set:
13334.943 ± 20292.681

The error on the training set is much lower than on the testing set. Such a generalization performance gap between the training and testing scores is an indication that our model overfitted the training set.

Indeed, this is one of the dangers of augmenting the number of features with a PolynomialFeatures transformer: the model may focus on some specific features. We can check the weights of the model to confirm this. Let's create a dataframe: the columns will contain the feature names, while each row will hold the coefficient values stored by one model fitted during the cross-validation.

Since we used a PolynomialFeatures transformer to augment the data, we will create feature names representative of the feature combinations. Scikit-learn provides a get_feature_names_out method for this purpose. First, let's get the first fitted model from the cross-validation.

model_first_fold = cv_results["estimator"][0]

Now, we can access the fitted PolynomialFeatures to generate the feature names:

feature_names = model_first_fold[0].get_feature_names_out(
    input_features=data.columns
)
feature_names
array(['1', 'MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population',
       'AveOccup', 'Latitude', 'Longitude', 'MedInc^2', 'MedInc HouseAge',
       'MedInc AveRooms', 'MedInc AveBedrms', 'MedInc Population',
       'MedInc AveOccup', 'MedInc Latitude', 'MedInc Longitude',
       'HouseAge^2', 'HouseAge AveRooms', 'HouseAge AveBedrms',
       'HouseAge Population', 'HouseAge AveOccup', 'HouseAge Latitude',
       'HouseAge Longitude', 'AveRooms^2', 'AveRooms AveBedrms',
       'AveRooms Population', 'AveRooms AveOccup', 'AveRooms Latitude',
       'AveRooms Longitude', 'AveBedrms^2', 'AveBedrms Population',
       'AveBedrms AveOccup', 'AveBedrms Latitude', 'AveBedrms Longitude',
       'Population^2', 'Population AveOccup', 'Population Latitude',
       'Population Longitude', 'AveOccup^2', 'AveOccup Latitude',
       'AveOccup Longitude', 'Latitude^2', 'Latitude Longitude',
       'Longitude^2'], dtype=object)

Finally, we can create the dataframe containing all the information.

import pandas as pd

coefs = [est[-1].coef_ for est in cv_results["estimator"]]
weights_linear_regression = pd.DataFrame(coefs, columns=feature_names)

Now, let's use a box plot to see the variations of the coefficients.

import matplotlib.pyplot as plt

color = {"whiskers": "black", "medians": "black", "caps": "black"}
weights_linear_regression.plot.box(color=color, vert=False, figsize=(6, 16))
_ = plt.title("Linear regression coefficients")
(Figure: box plot of the linear regression coefficients across cross-validation folds)

We can force the linear regression model to consider all features in a more homogeneous manner. In fact, we could force large positive or negative weights to shrink toward zero. This is known as regularization. We will use a ridge model, which enforces such behavior.
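As a reminder of what ridge does: it minimizes the least-squares error plus an alpha-weighted penalty on the squared weights (the standard ridge objective, written here for reference):

\min_{w} \; \|y - Xw\|_2^2 + \alpha \, \|w\|_2^2

The larger alpha is, the more strongly the weights are pulled toward zero.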

from sklearn.linear_model import Ridge

ridge = make_pipeline(PolynomialFeatures(degree=2), Ridge(alpha=100))
cv_results = cross_validate(
    ridge,
    data,
    target,
    cv=10,
    scoring="neg_mean_squared_error",
    return_train_score=True,
    return_estimator=True,
)
/opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py:216: LinAlgWarning: Ill-conditioned matrix (rcond=2.672e-17): result may not be accurate.
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
(the same LinAlgWarning is emitted for each of the 10 cross-validation fits; the repeated messages are omitted here)

The code cell above generates several warnings because the features include both extremely large and extremely small values, which cause numerical problems when fitting the predictive model.
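To see where the ill-conditioning comes from, here is a quick sketch that inspects the scale of the raw degree-2 features: the columns span many orders of magnitude, from the constant bias column up to squared population counts.

poly_features = PolynomialFeatures(degree=2).fit_transform(data)
feature_scales = abs(poly_features).max(axis=0)  # largest magnitude per column
print(f"smallest feature scale: {feature_scales.min():.2e}")
print(f"largest feature scale:  {feature_scales.max():.2e}")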

We can explore the train and test scores of this model.

train_error = -cv_results["train_score"]
print(
    "Mean squared error of ridge model on the train set:\n"
    f"{train_error.mean():.3f} ± {train_error.std():.3f}"
)
Mean squared error of ridge model on the train set:
4373.180 ± 153.942
test_error = -cv_results["test_score"]
print(
    "Mean squared error of ridge model on the test set:\n"
    f"{test_error.mean():.3f} ± {test_error.std():.3f}"
)
Mean squared error of ridge model on the test set:
7303.589 ± 4950.732

We see that the training and testing scores are much closer, indicating that our model is overfitting less. We can compare the values of the ridge weights with those of the unregularized linear regression.

coefs = [est[-1].coef_ for est in cv_results["estimator"]]
weights_ridge = pd.DataFrame(coefs, columns=feature_names)
weights_ridge.plot.box(color=color, vert=False, figsize=(6, 16))
_ = plt.title("Ridge weights")
(Figure: box plot of the ridge weights across cross-validation folds)

Comparing the magnitude of the weights on this plot with the previous one, we see that a ridge model enforces weights of similar magnitude, while the overall magnitude of the weights is shrunk toward zero with respect to the linear regression model.
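To make this shrinkage concrete, a quick sketch comparing the typical absolute weight of both models, reusing the two dataframes built above:

# average of the absolute coefficients over features and CV folds
print(
    "Mean absolute weight, linear regression: "
    f"{weights_linear_regression.abs().mean().mean():.2f}"
)
print(f"Mean absolute weight, ridge: {weights_ridge.abs().mean().mean():.2f}")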

However, in this example, we omitted two important aspects: (i) the need to scale the data and (ii) the need to search for the best regularization parameter.

Feature scaling and regularization#

On the one hand, weights define the link between feature values and the predicted target. On the other hand, regularization adds constraints on the weights of the model through the alpha parameter. Therefore, the effect that feature rescaling has on the final weights also interacts with regularization.

Let's consider the case where features live on the same scale/units: if two features are found to be equally important by the model, they will be affected similarly by the regularization strength.

Now, let's consider the scenario where features have completely different scales (for instance, age in years and annual revenue in dollars). If two features are equally important, our model will attribute a large weight to the feature with the small scale and a small weight to the feature with the large scale.

We recall that regularization forces weights to be smaller and closer to each other. Therefore, we get the intuition that if we want to use regularization, working with rescaled data makes it easier to find an optimal regularization parameter, and thus an adequate model. The sketch below illustrates this intuition.
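Here is a minimal sketch on hypothetical toy data (not the California housing set): two equally informative features on very different scales. Without scaling, ridge shrinks the numerically large weight (attached to the small-scale feature) noticeably while barely touching the tiny weight of the large-scale feature; after scaling, both weights end up comparable and are shrunk in the same proportion.

import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)
x_small = rng.uniform(0, 1, size=1_000)        # e.g. a rate between 0 and 1
x_large = rng.uniform(0, 100_000, size=1_000)  # e.g. an annual revenue
X = np.column_stack([x_small, x_large])
# both features contribute equally to the target
y = 10 * x_small + 1e-4 * x_large + rng.normal(scale=0.1, size=1_000)

print("unscaled:", Ridge(alpha=10).fit(X, y).coef_)
print("scaled:  ", make_pipeline(StandardScaler(), Ridge(alpha=10)).fit(X, y)[-1].coef_)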

As a side note, some solvers based on gradient computation expect such rescaled data; unscaled data can be detrimental when computing the optimal weights. Therefore, when working with a linear model and numerical data, it is generally good practice to scale the data.

Thus, we will add a StandardScaler in the machine learning pipeline. This scaler will be placed just before the regressor.

from sklearn.preprocessing import StandardScaler

ridge = make_pipeline(
    PolynomialFeatures(degree=2), StandardScaler(), Ridge(alpha=0.5)
)
cv_results = cross_validate(
    ridge,
    data,
    target,
    cv=10,
    scoring="neg_mean_squared_error",
    return_train_score=True,
    return_estimator=True,
)
train_error = -cv_results["train_score"]
print(
    "Mean squared error of ridge model on the train set:\n"
    f"{train_error.mean():.3f} ± {train_error.std():.3f}"
)
Mean squared error of ridge model on the train set:
4347.036 ± 156.666
test_error = -cv_results["test_score"]
print(
    "Mean squared error of ridge model on the test set:\n"
    f"{test_error.mean():.3f} ± {test_error.std():.3f}"
)
Mean squared error of ridge model on the test set:
5508.472 ± 1816.642

We observe that scaling the data has a positive impact on the test score and that the test score is closer to the train score. This means that our model is overfitting less and that we are getting closer to the generalization sweet spot.

Let's have an additional look at the different weights.

coefs = [est[-1].coef_ for est in cv_results["estimator"]]
weights_ridge = pd.DataFrame(coefs, columns=feature_names)
weights_ridge.plot.box(color=color, vert=False, figsize=(6, 16))
_ = plt.title("Ridge weights with data scaling")
(Figure: box plot of the ridge weights with data scaling)

Compared to the previous plots, we see that now all weight magnitudes are closer and that all features contribute more equally.

In the previous example, we fixed alpha=0.5. We will now check the impact of alpha by increasing its value.

ridge = make_pipeline(
    PolynomialFeatures(degree=2), StandardScaler(), Ridge(alpha=1_000_000)
)
cv_results = cross_validate(
    ridge,
    data,
    target,
    cv=10,
    scoring="neg_mean_squared_error",
    return_train_score=True,
    return_estimator=True,
)
coefs = [est[-1].coef_ for est in cv_results["estimator"]]
weights_ridge = pd.DataFrame(coefs, columns=feature_names)
weights_ridge.plot.box(color=color, vert=False, figsize=(6, 16))
_ = plt.title("Ridge weights with data scaling and large alpha")
(Figure: box plot of the ridge weights with data scaling and a large alpha)

Looking specifically at the weight values, we observe that increasing the value of alpha decreases the weight values. A negative value of alpha would actually enhance large weights and promote overfitting.
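A quick sketch to make this shrinkage concrete: fitting the pipeline once on the full data (outside the cross-validation used above) for a few increasing values of alpha and reporting the typical weight magnitude.

import numpy as np

for alpha in [0.1, 10, 1_000, 100_000]:
    model = make_pipeline(
        PolynomialFeatures(degree=2), StandardScaler(), Ridge(alpha=alpha)
    ).fit(data, target)
    # average absolute coefficient of the fitted ridge step
    print(f"alpha={alpha:>7}: mean absolute weight = {np.abs(model[-1].coef_).mean():.2f}")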

Note

Here, we only focus on numerical features. For categorical features, it is common to omit scaling when the features are encoded with a OneHotEncoder, since the feature values are already on a similar scale.

However, this choice can be questioned, since scaling interacts with regularization as well. For instance, scaling categorical features that are imbalanced (e.g. more occurrences of a specific category) would even out the impact of regularization on each category. However, scaling such features in the presence of rare categories could be problematic (i.e. division by a very small standard deviation) and can therefore introduce numerical issues.

In the previous analysis, we did not study whether the parameter alpha has an effect on performance. We chose the parameter beforehand and kept it fixed for the analysis.

In the next section, we will check the impact of the regularization parameter alpha and how it should be tuned.

Fine tuning the regularization parameter#

As mentioned, the regularization parameter needs to be tuned on each dataset: the default value will not necessarily lead to the optimal model. Therefore, we need to tune the alpha parameter.

Model hyperparameter tuning should be done with care. Indeed, we want to find an optimal parameter that maximizes some metric, which requires both a training set and a testing set.

However, this testing set should be different from the out-of-sample testing set used to evaluate our model: if we used the same one, we would be using an alpha that was optimized for this specific testing set, which breaks the out-of-sample rule.

Therefore, we should include the search for the hyperparameter alpha within the cross-validation. As we saw in previous notebooks, we could use a grid-search. However, some predictors in scikit-learn come with an integrated hyperparameter search that is more efficient than a grid-search. The names of these predictors end with CV. In the case of Ridge, scikit-learn provides a RidgeCV regressor (a sketch of the grid-search alternative is shown below for comparison).
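For reference only, a minimal sketch of what the less efficient grid-search alternative could look like (not used in the rest of this notebook); the step name "ridge" is the one make_pipeline derives from the Ridge class, and the alpha grid mirrors the one defined further below:

import numpy as np
from sklearn.model_selection import GridSearchCV

ridge_grid_search = GridSearchCV(
    make_pipeline(PolynomialFeatures(degree=2), StandardScaler(), Ridge()),
    param_grid={"ridge__alpha": np.logspace(-2, 0, num=21)},
    scoring="neg_mean_squared_error",
    cv=5,
)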

Therefore, we can use this predictor as the last step of the pipeline. Including the pipeline in a cross-validation amounts to a nested cross-validation: the inner cross-validation searches for the best alpha, while the outer cross-validation gives an estimate of the testing score.

import numpy as np
from sklearn.linear_model import RidgeCV

alphas = np.logspace(-2, 0, num=21)
ridge = make_pipeline(
    PolynomialFeatures(degree=2),
    StandardScaler(),
    RidgeCV(alphas=alphas, store_cv_values=True),
)
from sklearn.model_selection import ShuffleSplit

cv = ShuffleSplit(n_splits=5, random_state=1)
cv_results = cross_validate(
    ridge,
    data,
    target,
    cv=cv,
    scoring="neg_mean_squared_error",
    return_train_score=True,
    return_estimator=True,
    n_jobs=2,
)
train_error = -cv_results["train_score"]
print(
    "Mean squared error of ridge model on the train set:\n"
    f"{train_error.mean():.3f} ± {train_error.std():.3f}"
)
Mean squared error of ridge model on the train set:
4301.922 ± 24.826
test_error = -cv_results["test_score"]
print(
    "Mean squared error of ridge model on the test set:\n"
    f"{test_error.mean():.3f} ± {test_error.std():.3f}"
)
Mean squared error of ridge model on the test set:
4458.556 ± 462.192

By optimizing alpha, we see that the training and testing scores are close. It indicates that our model is not overfitting.

When fitting the ridge regressor, we also requested to store the errors found during cross-validation (by setting the parameter store_cv_values=True). We will plot the mean squared error for the different regularization strengths (alpha) that we tried. The error bars represent one standard deviation of the mean squared error across folds for a given value of alpha.

mse_alphas = [
    est[-1].cv_values_.mean(axis=0) for est in cv_results["estimator"]
]
cv_alphas = pd.DataFrame(mse_alphas, columns=alphas)
cv_alphas = cv_alphas.aggregate(["mean", "std"]).T
cv_alphas
alpha          mean          std
0.010000 10805.760257 7859.294941
0.012589 9835.307516 6588.155387
0.015849 8925.577610 5409.395931
0.019953 8100.047161 4353.016704
0.025119 7375.000527 3438.794963
0.031623 6758.474506 2674.639258
0.039811 6250.760542 2056.950531
0.050119 5846.266069 1572.763214
0.063096 5536.215174 1202.943947
0.079433 5311.576140 925.547806
0.100000 5165.715922 718.596591
0.125893 5096.474028 562.042934
0.158489 5107.489238 439.763119
0.199526 5208.638340 345.021050
0.251189 5415.385561 297.489290
0.316228 5746.801131 353.672626
0.398107 6222.125875 535.737882
0.501187 6856.124687 820.747311
0.630957 7654.025118 1192.443073
0.794328 8607.322588 1638.892918
1.000000 9691.803104 2144.446582
plt.errorbar(cv_alphas.index, cv_alphas["mean"], yerr=cv_alphas["std"])
plt.xlim((0.0, 1.0))
plt.ylim((4_500, 11_000))
plt.ylabel("Mean squared error\n (lower is better)")
plt.xlabel("alpha")
_ = plt.title("Testing error obtained by cross-validation")
(Figure: testing error obtained by cross-validation as a function of alpha)

As we can see, regularization is just like salt in cooking: one must balance its amount to get the best generalization performance. We can check whether the best alpha found is stable across the cross-validation folds.

best_alphas = [est[-1].alpha_ for est in cv_results["estimator"]]
best_alphas
[0.07943282347242812,
 0.12589254117941676,
 0.31622776601683794,
 0.12589254117941676,
 0.09999999999999999]

The optimal regularization strength is not necessarily the same across all cross-validation iterations. But since we expect each cross-validation resampling to stem from the same data distribution, it is common practice to choose for production a value of alpha lying in the range defined by:

print(
    f"Min optimal alpha: {np.min(best_alphas):.2f} and "
    f"Max optimal alpha: {np.max(best_alphas):.2f}"
)
Min optimal alpha: 0.08 and Max optimal alpha: 0.32

This range can be narrowed by using a finer grid of alphas.
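For instance, one could rerun the previous nested cross-validation with a denser grid centered on the observed range; a hypothetical sketch (the bounds 0.05 and 0.5 are illustrative, not prescribed by the notebook):

# denser grid of 41 values around the observed optimal range
finer_alphas = np.logspace(np.log10(0.05), np.log10(0.5), num=41)
ridge_refined = make_pipeline(
    PolynomialFeatures(degree=2),
    StandardScaler(),
    RidgeCV(alphas=finer_alphas, store_cv_values=True),
)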

In this notebook, you learned about the concept of regularization and the importance of preprocessing and parameter tuning.