The adult census dataset

This dataset is a collection of demographic information on the adult population of the USA as of 1994. The task is to predict whether a person earns a high or a low yearly revenue in USD.

The column named class is the target variable (i.e., the variable which we want to predict). The two possible classes are " <=50K" (low-revenue) and " >50K" (high-revenue).

Before drawing any conclusions based on its statistics or the predictions of models trained on it, remember that this dataset is not only outdated, but is also not representative of the US population. Indeed, the original data contains a feature named fnlwgt that encodes the number of units in the target population that each responding unit represents, so the rows do not all carry the same weight.
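To make the role of such a weight concrete, here is a minimal sketch on a hypothetical toy sample. The labels mimic the two classes, but the weights are invented for illustration and are not taken from the dataset. It compares the share of high-revenue records when every record counts once with the share when every record counts for the number of people it represents.

```python
# Hypothetical toy sample, not taken from the dataset: each record is
# (class label, survey weight playing the role of fnlwgt).
sample = [
    (" <=50K", 300),
    (" <=50K", 100),
    (" >50K", 100),
    (" >50K", 500),
]

high = " >50K"

# Unweighted share: every respondent counts once.
unweighted_share = sum(label == high for label, _ in sample) / len(sample)

# Weighted share: every respondent counts for the number of people represented.
total_weight = sum(weight for _, weight in sample)
weighted_share = (
    sum(weight for label, weight in sample if label == high) / total_weight
)

print(f"unweighted: {unweighted_share:.2f}, weighted: {weighted_share:.2f}")
```

On this toy sample the two shares differ, which is why statistics computed directly on the rows should not be read as statistics of the population.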

First we load the dataset. We then list the few columns of interest that we keep for plotting and extract the target.

import pandas as pd

adult_census = pd.read_csv("../datasets/adult-census.csv")
columns_to_plot = [
    "age",
    "education-num",
    "capital-loss",
    "capital-gain",
    "hours-per-week",
    "relationship",
    "class",
]
target_name = "class"
target = adult_census[target_name]

We explore this dataset in the first module's notebook “First look at our dataset”, where we provide a first intuition on how the data is structured. There, we use a seaborn pairplot to visualize pairwise relationships between the numerical variables in the dataset. This tool arranges scatter plots for every pair of variables in a grid, with histograms on the diagonal.

This approach is limited:

  • pair plots can only deal with numerical features;

  • by observing pairwise interactions, we end up with a two-dimensional projection of a multi-dimensional feature space, which can lead to a wrong interpretation of the individual impact of a feature.
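The second limitation can be illustrated with a small sketch. It uses hypothetical toy data, unrelated to the census: three binary features a, b and c, and a target defined as their parity. The helper target_rate_by_pair is introduced only for this illustration. For every pair of features, the mean target is the same in each cell, so no pairwise view reveals any structure, even though the three features together determine the target exactly.

```python
from itertools import combinations, product

# Hypothetical toy data: three binary features and a target equal to their
# parity (a XOR b XOR c). Each row is (a, b, c, target).
rows = [(a, b, c, a ^ b ^ c) for a, b, c in product([0, 1], repeat=3)]
names = ["a", "b", "c"]


def target_rate_by_pair(i, j):
    """Mean target for each combination of values of features i and j."""
    groups = {}
    for row in rows:
        groups.setdefault((row[i], row[j]), []).append(row[3])
    return {key: sum(vals) / len(vals) for key, vals in groups.items()}


# One entry per pair of features, mimicking what a pairwise view can show.
pairwise_rates = {
    (names[i], names[j]): target_rate_by_pair(i, j)
    for i, j in combinations(range(3), 2)
}
for pair, rates in pairwise_rates.items():
    print(pair, rates)
```

This is an extreme, constructed case, but it shows why a plot that displays all features at once can complement pairwise views.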

Here we explore the relations between features in more detail using plotly Parcoords, a parallel coordinates plot that can display all the selected columns at once.

import plotly.graph_objects as go
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()


def generate_dict(col):
    """Check if column is categorical and generate the appropriate dict"""
    if adult_census[col].dtype == "object":  # Categorical column
        encoded = le.fit_transform(adult_census[col])
        return {
            "tickvals": list(range(len(le.classes_))),
            "ticktext": list(le.classes_),
            "label": col,
            "values": encoded,
        }
    else:  # Numerical column
        return {"label": col, "values": adult_census[col]}


plot_list = [generate_dict(col) for col in columns_to_plot]

fig = go.Figure(
    data=go.Parcoords(
        line=dict(
            color=le.fit_transform(target),
            colorscale="Viridis",
        ),
        dimensions=plot_list,
    )
)
fig.show()
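To see what the categorical branch of generate_dict produces, here is a minimal sketch on a hypothetical toy column. The values are chosen for illustration and the names toy_column, encoder and codes are not part of the notebook. LabelEncoder stores the sorted unique labels in classes_, and each encoded value is an index into that array, which is why the positions can serve as tickvals and the labels as ticktext.

```python
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy column standing in for a categorical feature.
toy_column = ["Husband", "Wife", "Husband", "Own-child"]

encoder = LabelEncoder()
codes = encoder.fit_transform(toy_column)

# classes_ holds the sorted unique labels; each code is an index into it.
ticktext = list(encoder.classes_)
tickvals = list(range(len(ticktext)))

print(dict(zip(tickvals, ticktext)))
print(codes.tolist())
```

Because the codes follow the alphabetical order of the labels, the vertical order of categories on a Parcoords axis is arbitrary and should not be read as a ranking.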