
FairXGBoost: Fairness-aware Classification in XGBoost


Implementation of fairness-aware classification based on our paper (https://arxiv.org/abs/2009.01442).

Discussion at https://github.com/dmlc/xgboost/issues/7282

The regularizer proposed in the paper aims to achieve demographic parity (by reducing disparate impact) for minority groups. For a detailed overview of what demographic parity means, please refer to this wonderful blog post.
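
Formally, demographic parity asks that the positive-prediction rates match across groups, i.e. Pr(y_hat = 1 | s = minority) = Pr(y_hat = 1 | s = majority). Disparate impact is the ratio of these two rates, so values close to 1.0 indicate parity.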

Contributions:

  • Added a MetaInfo field to store the group identifier (the sensitive feature) in memory
  • Support for reading the sensitive feature in both DMatrix and DaskDMatrix
  • Added dataset and demo for ProPublica COMPAS recidivism analysis

s-ravichandran · Oct 27 '21

Thanks for your contribution!

Here is a version of your script that does not change DMatrix. Please check it and tell me if it's correct. I am passing row indices as the labels and then using these indices to fetch the actual labels and sensitive-feature values.

Further comments on this PR:

  • Generally speaking we will not add data to version control. Your options are to fetch the data in the script (not recommended, as links tend to go out of date), add instructions for the user to manually download the data, or generate some synthetic data in the script (this is the most future-proof; see the first sketch below).
  • The metrics code is very verbose; I think you can improve it.
  • Using try/except to deal with zero division is bad practice; I think there is a better way (see the second sketch below).
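
For the synthetic-data option, something along these lines would do (a sketch with made-up feature names, not the COMPAS schema):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 1000
# Five numeric features plus a binary sensitive feature correlated with the label.
X = pd.DataFrame(rng.normal(size=(n, 5)), columns=[f"f{i}" for i in range(5)])
s = rng.integers(0, 2, size=n).astype(float)
y = ((X["f0"] + 0.5 * s + rng.normal(scale=0.5, size=n)) > 0).astype(float)

For the zero-division point, a small helper (a hypothetical safe_div, not currently in the script) removes both the try/except blocks and the repetition:

def safe_div(num, den):
    # Return num / den, or NaN when the denominator is zero.
    return num / den if den else float("nan")

precision = safe_div(TP, TP + FP)
recall = safe_div(TP, TP + FN)
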
"""
Fairness-Aware XGBoost on the ProPublica COMPAS recidivism data
Implementation of https://arxiv.org/abs/2009.01442
"""

import numpy as np
import pandas as pd
import xgboost as xgb
import dask.array as da
import dask.distributed
import dask.dataframe as dd

from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score

"""
Load the data available at 
https://github.com/propublica/compas-analysis/blob/master/compas-scores-two-years.csv
The modified objective function defined in https://arxiv.org/abs/2009.01442 
encourages more favourable outcomes to the minority group. 
In order to ensure this, one needs to follow the following encoding scheme 
If `t` represents the favourable outcome of a classifier and `1 - t`, 
the unfavourable outcome, encode minority members by setting `s = t` 
and majority members (males) by setting `s = 1 - t`
In our example, a label 1.0 represents an unfavourable outcome (likely to be \
re-arrested) and correspondingly, the minority race is represented with a value 1.0
"""


def load_data(filename):
    df = dd.read_csv(filename)
    X = df.drop(columns=["two_year_recid", "race", "id"])
    y = df["two_year_recid"]  # label: 1.0 = re-arrested within two years
    s = df["race"]  # sensitive feature: 1.0 = minority race
    return df, X, y, s


def fairbceobj(global_labels, global_sensitive, fair_reg=0.0):
    # Fair binary cross-entropy objective from https://arxiv.org/abs/2009.01442.
    def fairbceobj_inner(preds, dtrain):
        # The DMatrix labels are row indices; use them to look up the true
        # labels and sensitive-feature values for the rows in this DMatrix.
        idx = dtrain.get_label().astype(int)
        labels = global_labels[idx]
        sensitive_feature = global_sensitive[idx]

        preds = 1.0 / (1.0 + np.exp(-preds))  # sigmoid: raw margin -> probability

        # BCE gradient (preds - labels) plus the fairness regularization term.
        grad = preds - labels + (fair_reg * (sensitive_feature - preds))
        hess = (1.0 - fair_reg) * preds * (1.0 - preds)

        return grad, hess

    return fairbceobj_inner


def get_metrics(eval_df, minority_indicator=0.0):
    # Choose the decision threshold that maximizes overall accuracy.
    max_acc = 0.0
    max_acc_thresh = 0.0
    for threshold in np.linspace(0.1, 1.0, 100):
        preds = (eval_df["score"] > threshold).astype("float32")
        acc = accuracy_score(eval_df["y"].to_numpy(), preds)
        if acc > max_acc:
            max_acc = acc
            max_acc_thresh = threshold
    eval_df["y_pred"] = (eval_df["score"] > max_acc_thresh).astype("float32")

    # Confusion-matrix counts: lowercase m* = minority group, uppercase M* = majority.
    mTP = (
        eval_df.loc[eval_df["s"] == minority_indicator]
        .loc[eval_df["y_pred"] == eval_df["y"]]
        .loc[eval_df["y"] == 1.0]
        .shape[0]
    )
    mTN = (
        eval_df.loc[eval_df["s"] == minority_indicator]
        .loc[eval_df["y_pred"] == eval_df["y"]]
        .loc[eval_df["y"] == 0.0]
        .shape[0]
    )
    mFP = (
        eval_df.loc[eval_df["s"] == minority_indicator]
        .loc[eval_df["y_pred"] != eval_df["y"]]
        .loc[eval_df["y"] == 0.0]
        .shape[0]
    )
    mFN = (
        eval_df.loc[eval_df["s"] == minority_indicator]
        .loc[eval_df["y_pred"] != eval_df["y"]]
        .loc[eval_df["y"] == 1.0]
        .shape[0]
    )

    MTP = (
        eval_df.loc[eval_df["s"] == (1.0 - minority_indicator)]
        .loc[eval_df["y_pred"] == eval_df["y"]]
        .loc[eval_df["y"] == 1.0]
        .shape[0]
    )
    MTN = (
        eval_df.loc[eval_df["s"] == (1.0 - minority_indicator)]
        .loc[eval_df["y_pred"] == eval_df["y"]]
        .loc[eval_df["y"] == 0.0]
        .shape[0]
    )
    MFP = (
        eval_df.loc[eval_df["s"] == (1.0 - minority_indicator)]
        .loc[eval_df["y_pred"] != eval_df["y"]]
        .loc[eval_df["y"] == 0.0]
        .shape[0]
    )
    MFN = (
        eval_df.loc[eval_df["s"] == (1.0 - minority_indicator)]
        .loc[eval_df["y_pred"] != eval_df["y"]]
        .loc[eval_df["y"] == 1.0]
        .shape[0]
    )

    TP = mTP + MTP
    FP = mFP + MFP
    TN = mTN + MTN
    FN = mFN + MFN

    try:
        precision = TP / float(TP + FP)
    except ZeroDivisionError:
        precision = np.nan
    try:
        recall = TP / float(TP + FN)
    except ZeroDivisionError:
        recall = np.nan
    try:
        accuracy = (TP + TN) / float(TP + TN + FP + FN)
    except ZeroDivisionError:
        accuracy = np.nan
    try:
        # F1 is the harmonic mean of precision and recall.
        f1_score = 2.0 / ((1.0 / precision) + (1.0 / recall))
    except ZeroDivisionError:
        f1_score = np.nan
    try:
        m_base_rate = (mTP + mFP) / float(mTP + mFP + mTN + mFN)
    except ZeroDivisionError:
        m_base_rate = np.nan
    try:
        M_base_rate = (MTP + MFP) / float(MTP + MFP + MTN + MFN)
    except ZeroDivisionError:
        M_base_rate = np.nan

    try:
        # Disparate impact: minority base rate divided by majority base rate.
        disparate_impact = m_base_rate / float(M_base_rate)
    except ZeroDivisionError:
        disparate_impact = np.nan
    try:
        mTPR = mTP / float(mTP + mFN)
    except ZeroDivisionError:
        mTPR = np.nan
    try:
        MTPR = MTP / float(MTP + MFN)
    except ZeroDivisionError:
        MTPR = np.nan

    # Equal opportunity difference: TPR gap between minority and majority.
    eod = mTPR - MTPR

    return {
        "precision": precision,
        "recall": recall,
        "accuracy": accuracy,
        "f1_score": f1_score,
        "disparate_impact": disparate_impact,
        "equal_opportunity_difference": eod,
        "mTP": mTP,
        "mFP": mFP,
        "mTN": mTN,
        "mFN": mFN,
        "MTP": MTP,
        "MFP": MFP,
        "MTN": MTN,
        "MFN": MFN,
    }


if __name__ == "__main__":

    train_df, X_train, y_train, s_train = load_data("../data/compas.txt.train")
    test_df, X_test, y_test, s_test = load_data("../data/compas.txt.test")

    client = dask.distributed.Client()

    # Pass row indices as labels; the custom objective maps them back to the
    # true labels and sensitive-feature values.
    dtrain = xgb.dask.DaskDMatrix(client, X_train, label=da.arange(len(y_train)))
    dtest = xgb.dask.DaskDMatrix(client, X_test)

    param = {"max_depth": 5, "eta": 0.1, "reg_lambda": 0}
    num_round = 5
    # Baseline: fair_reg=0.0 reduces the custom objective to plain binary
    # cross-entropy (the built-in objective would misread the index labels).
    vanilla_params = param.copy()
    vanilla_result = xgb.dask.train(
        client, vanilla_params, dtrain, num_round,
        obj=fairbceobj(y_train.values.compute(), s_train.values.compute()),
    )

    fair_results = {}
    fair_params = param.copy()

    print("Training fair models")
    for fair_reg in tqdm(np.linspace(0.0, 0.8, num=100)):
        fair_result = xgb.dask.train(
            client,
            fair_params,
            dtrain,
            num_round,
            obj=fairbceobj(
                y_train.values.compute(), s_train.values.compute(), fair_reg=fair_reg
            ),
        )
        fair_results[fair_reg] = fair_result

    results = {}
    print("Predictions with fair models")
    for fair_reg in tqdm(np.linspace(0.0, 0.8, num=100)):
        # With a custom objective, predict returns raw margins; apply the
        # sigmoid to turn them into probability scores.
        margins = xgb.dask.predict(client, fair_results[fair_reg], dtest)
        eval_df = pd.DataFrame(
            {
                "y": y_test.to_dask_array().compute(),
                "score": 1.0 / (1.0 + np.exp(-margins.compute())),
                "s": s_test.to_dask_array().compute(),
            }
        )
        results[fair_reg] = get_metrics(eval_df, minority_indicator=1.0)

    di_arr = []
    eod_arr = []
    acc_arr = []
    fair_reg_arr = []
    for k, v in results.items():
        fair_reg_arr.append(k)
        di_arr.append(v["disparate_impact"])
        eod_arr.append(v["equal_opportunity_difference"])
        acc_arr.append(v["accuracy"])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle("FairXGBoost on COMPAS data")

    ax1.plot(fair_reg_arr, di_arr)
    ax1.set_xlabel("Fair Regularization Strength")
    ax1.set_ylabel("Disparate Impact")

    ax2.scatter(acc_arr, di_arr)
    ax2.set_ylabel("Disparate Impact")
    ax2.set_xlabel("Accuracy")
    plt.savefig("example.png")

RAMitchell · Oct 28 '21

Hi,

Thank you all for the suggestions, and apologies for not following up on this earlier; I was unable to spend time on it.

I'm following up on this PR to understand what changes are planned for the 1.6 release. If I can contribute towards them, I'd like to know how I can be of help.

As for the suggestion of storing keys, my colleague @bhatturam had already suggested something similar: storing the "key" column in one of the MetaInfo fields and then loading the relevant feature values from a lookup dataframe. However, we couldn't find a suitable MetaInfo field of int (or long) type; only float fields were available, and since a float32 can represent integers exactly only up to 2^24, that would not be sufficient to store an index/key column for larger datasets. Would love to hear your thoughts.

Thanks again.

s-ravichandran · Dec 15 '21

Thanks for your patience. I will focus on this feature once I can clean up some WIP items. For this feature, my preference is that we either integrate it fully into the native XGBoost library or use only the Python package. I will get back to this as soon as possible.

trivialfis · Dec 28 '21

Quick update: I'm moving the implementation into C++.

trivialfis · Jan 08 '22

@trivialfis Is there anything I can help to move this feature forward?

hcho3 · Feb 07 '22

Hi, sorry for the delay. I'm reading the paper, is the sign for the regularization term in the objective correct?
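
Concretely, working backwards from the demo script: its gradient, grad = (p - y) + mu * (s - p), is the derivative of BCE(p, y) - mu * BCE(p, s) with respect to the raw margin, and the hessian (1 - mu) * p * (1 - p) matches that gradient. So the code is internally consistent; what I'd like to confirm is whether the minus sign on the fairness term is what the paper intends.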

trivialfis · Feb 09 '22

@s-ravichandran Could you please take a look into https://github.com/dmlc/xgboost/pull/7640 and try it on your datasets?

trivialfis · Feb 10 '22

Hi, sorry for the delay. I'm reading the paper, is the sign for the regularization term in the objective correct?

Hi there, I had the same confusion when reading the paper as well. I did some testing on the UCI Adult dataset: after modifying the signs in the objective, the fairness measure (disparate impact) does seem to improve more steadily as 'mu' (the fair regularization strength) increases.

Looking forward to learning more from the author!
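
For concreteness, one way to flip that sign in the fairbceobj above would be the following (just a sketch; not necessarily the exact variant I ran):

        grad = preds - labels - fair_reg * (sensitive_feature - preds)
        hess = (1.0 + fair_reg) * preds * (1.0 - preds)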

f9779 · Feb 10 '22

@s-ravichandran Could you please take a look into #7640 and try it on your datasets?

Thanks a lot for working on this @trivialfis and apologies for the delay in response.

Will take this up and post an update in a couple of days.

s-ravichandran · Feb 14 '22

Hi, sorry for the delay. I'm reading the paper, is the sign for the regularization term in the objective correct?

Hi there, I had the same confusion when reading the paper as well. I did some testing on the UCI Adult dataset: after modifying the signs in the objective, the fairness measure (disparate impact) does seem to improve more steadily as 'mu' (the fair regularization strength) increases.

Looking forward to learning more from the author!

Will double-check and get back; I remember encountering a similar scenario. I will provide an update with some clarity on the signs soon, @f9779 and @trivialfis.

Thanks for bringing this up!

s-ravichandran · Feb 14 '22

Hi, any update on this?

trivialfis · Mar 13 '22

Hi, any update on this?

trivialfis · Oct 10 '22

Hi,

Apologies for the delay; I had been unable to spend time on this.

I'll try to close this in the next week or two.

s-ravichandran · Oct 27 '22