{
"cells": [
{
"cell_type": "markdown",
"id": "4800aa8d",
"metadata": {},
"source": [
"# Rectification for Decision Trees"
]
},
{
"cell_type": "markdown",
"id": "87af483d",
"metadata": {},
"source": [
"Further information on the rectification method for Decision Trees is available in the paper [Rectifying Binary Classifiers](https://hal.science/hal-04236309/document)."
]
},
{
"cell_type": "markdown",
"id": "1dc05e33",
"metadata": {},
"source": [
"## Example from a Hand-Crafted Tree"
]
},
{
"cell_type": "markdown",
"id": "37a0365c",
"metadata": {},
"source": [
"To illustrate rectification, we consider a credit scoring scenario.\n",
"\n",
"Each customer is characterized by:\n",
"- an annual income $I$ (in k\\\\$),\n",
"- whether the customer has already repaid a previous loan ($R$),\n",
"- whether the customer has a permanent position ($PP$).\n",
"\n",
"The model is represented by the decision tree $T$ described in the following figure. The Boolean conditions used in $T$ are $I > 30$, $I > 20$, $R$, and $PP$.\n",
"\n",
"\n",
"\n",
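"We start by building the decision tree with PyXAI's `Builder`. Features are identified by their position in `feature_names` (1 for $I$, 2 for $PP$, 3 for $R$):"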
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "db883698",
"metadata": {},
"outputs": [],
"source": [
"from pyxai import Builder, Explainer\n",
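"\n",
"# Feature ids follow the order of feature_names: 1 = I, 2 = PP, 3 = R.\n",
"# In each node, 'left' is followed when the condition is false, 'right' when it is true.\n",
"\n",
"# Subtree for I <= 30: test I > 20, then R == 1.\n",
"node_L_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)\n",
"node_L_2 = Builder.DecisionNode(1, operator=Builder.GT, threshold=20, left=0, right=node_L_1)\n",
"\n",
"# Subtree for I > 30: test PP == 1, then R == 1.\n",
"node_R_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)\n",
"node_R_2 = Builder.DecisionNode(2, operator=Builder.EQ, threshold=1, left=node_R_1, right=1)\n",
"\n",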
"root = Builder.DecisionNode(1, operator=Builder.GT, threshold=30, left=node_L_2, right=node_R_2)\n",
"tree = Builder.DecisionTree(3, root, feature_names=[\"I\", \"PP\", \"R\"])"
]
},
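{
"cell_type": "markdown",
"id": "a7f3c1e2",
"metadata": {},
"source": [
"To make the structure of $T$ explicit, the next cell gives a minimal plain-Python sketch of the same tree, written without PyXAI. It assumes the conventions used by the `Builder` calls above: feature 1 is $I$, feature 2 is $PP$ and feature 3 is $R$ (the order of `feature_names`), and a node's `left` child is followed when its condition is false, its `right` child when it is true. The function name `predict_T` is hypothetical and is not part of PyXAI."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7f3c1e3",
"metadata": {},
"outputs": [],
"source": [
"def predict_T(I, PP, R):\n",
"    # Hypothetical helper mirroring the Builder nodes above\n",
"    # ('left' = condition false, 'right' = condition true).\n",
"    if I > 30:                       # root\n",
"        if PP == 1:                  # node_R_2\n",
"            return 1\n",
"        return 1 if R == 1 else 0    # node_R_1\n",
"    if I > 20:                       # node_L_2\n",
"        return 1 if R == 1 else 0    # node_L_1\n",
"    return 0\n",
"\n",
"print(predict_T(25, 1, 1))  # instance x = (I=25, PP=1, R=1) -> 1"
]
},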
{
"cell_type": "markdown",
"id": "aca44977",
"metadata": {},
"source": [
"Consider the instance $x = (I = 25, PP = 1, R = 1)$, which corresponds to a customer applying for a loan.\n",
"We initialize the explainer with this instance and the associated theory, which is derived from the feature types: $I$ is numerical, while $PP$ and $R$ are binary (see the [Theories](/documentation/explainer/theories/) page for more information)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "30b36230",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------- Theory Feature Types -----------\n",
"Before the one-hot encoding of categorical features:\n",
"Numerical features: 1\n",
"Categorical features: 0\n",
"Binary features: 2\n",
"Number of features: 3\n",
"Characteristics of categorical features: {}\n",
"\n",
"Number of used features in the model (before the encoding of categorical features): 3\n",
"Number of used features in the model (after the encoding of categorical features): 3\n",
"----------------------------------------------\n",
"binary representation: (-1, 2, 3, 4)\n",
"target_prediction: 1\n",
"to_features: ('I <= 30', 'I > 20', 'PP == 1', 'R == 1')\n"
]
}
],
"source": [
"loan_types = {\n",
" \"numerical\": [\"I\"],\n",
" \"binary\": [\"PP\", \"R\"],\n",
"}\n",
"\n",
"explainer = Explainer.initialize(tree, instance=(25, 1, 1), features_type=loan_types)\n",
"print(\"binary representation: \", explainer.binary_representation)\n",
"print(\"target_prediction:\", explainer.target_prediction)\n",
"print(\"to_features:\", explainer.to_features(explainer.binary_representation, eliminate_redundant_features=False))"
]
},
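{
"cell_type": "markdown",
"id": "b2d4e6f1",
"metadata": {},
"source": [
"The binary representation writes the instance as signed literals: literal $v$ is positive when the $v$-th Boolean condition holds for $x$ and negative otherwise. The next cell is a minimal sketch that rebuilds `(-1, 2, 3, 4)` by hand, assuming the variable numbering shown by `to_features` (1: $I > 30$, 2: $I > 20$, 3: $PP = 1$, 4: $R = 1$). `CONDITIONS` and `to_binary` are hypothetical names; this is not PyXAI's internal encoding code."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2d4e6f2",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical variable numbering, read off the to_features output above.\n",
"CONDITIONS = [\n",
"    (1, lambda I, PP, R: I > 30),\n",
"    (2, lambda I, PP, R: I > 20),\n",
"    (3, lambda I, PP, R: PP == 1),\n",
"    (4, lambda I, PP, R: R == 1),\n",
"]\n",
"\n",
"def to_binary(I, PP, R):\n",
"    # Literal +v if condition v holds for the instance, -v otherwise.\n",
"    return tuple(v if cond(I, PP, R) else -v for v, cond in CONDITIONS)\n",
"\n",
"print(to_binary(25, 1, 1))  # -> (-1, 2, 3, 4)"
]
},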
{
"cell_type": "markdown",
"id": "452f1209",
"metadata": {},
"source": [
"The user (a bank employee) disagrees with this prediction (the loan is granted). According to the user, the following classification rule must be obeyed: whenever the client's annual income is at most 30, the application should be rejected (label 0). To perform this rectification, we use the `rectify()` method of the `explainer` object. More information about this method is available on the [Rectification](/documentation/rectification/) page."
]
},
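{
"cell_type": "markdown",
"id": "c3e5a7b1",
"metadata": {},
"source": [
"Conceptually (this is our reading of the method for a single rule), rectifying a classifier $f$ with a rule of the form 'if the conditions hold, predict `label`' gives a classifier that returns `label` on every instance covered by the rule and agrees with $f$ on all other instances. The next cell is a minimal sketch of this input/output behaviour on plain Python functions only; it does not show how PyXAI rewrites and simplifies the tree itself. `rectify_function` is a hypothetical name, and `predict_T` restates the tree $T$ under the assumption that a node's `left` child is followed when its condition is false."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3e5a7b2",
"metadata": {},
"outputs": [],
"source": [
"def predict_T(I, PP, R):\n",
"    # The tree T built above ('left' = condition false).\n",
"    if I > 30:\n",
"        return 1 if PP == 1 or R == 1 else 0\n",
"    if I > 20:\n",
"        return 1 if R == 1 else 0\n",
"    return 0\n",
"\n",
"def rectify_function(f, rule, label):\n",
"    # Hypothetical helper: instances covered by `rule` get `label`,\n",
"    # every other instance keeps the prediction of f.\n",
"    def rectified(*x):\n",
"        return label if rule(*x) else f(*x)\n",
"    return rectified\n",
"\n",
"# The bank employee's rule: I <= 30 implies label 0 (reject).\n",
"g = rectify_function(predict_T, lambda I, PP, R: I <= 30, 0)\n",
"print(predict_T(25, 1, 1), g(25, 1, 1))  # -> 1 0\n",
"print(g(35, 0, 1))  # not covered by the rule, so g agrees with T -> 1"
]
},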
{
"cell_type": "code",
"execution_count": 3,
"id": "f2f3bc71",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"-------------- Rectification information:\n",
"Classification Rule - Number of nodes: 3\n",
"Model - Number of nodes: 11\n",
"Model - Number of nodes (after rectification): 17\n",
"Model - Number of nodes (after simplification using the theory): 11\n",
"Model - Number of nodes (after elimination of redundant nodes): 7\n",
"--------------\n",
"target_prediction: 0\n"
]
}
],
"source": [
"# The term (-1, ) negates literal 1 ('I > 30'), so it encodes the condition 'I <= 30'.\n",
"rectified_model = explainer.rectify(conditions=(-1, ), label=0)\n",
"\n",
"print(\"target_prediction:\", explainer.target_prediction)"
]
},
{
"cell_type": "markdown",
"id": "779a7def",
"metadata": {},
"source": [
"Here is the rectified model before any simplification:\n",
"\n",
"\n",
"\n",
"Here is the rectified model after simplification:\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "44c1f638",
"metadata": {},
"source": [
"We can check that the rectified model now classifies the instance as the user expects (label 0):"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "99697598",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"target_prediction: 0\n"
]
}
],
"source": [
"print(\"target_prediction:\", rectified_model.predict_instance((25, 1, 1)))"
]
},
{
"cell_type": "markdown",
"id": "aa82967b",
"metadata": {},
"source": [
"## Example from a Real Dataset"
]
},
{
"cell_type": "markdown",
"id": "29048b47",
"metadata": {},
"source": [
"For this example, we use the compas.csv dataset. We train a decision tree using the hold-out approach (by default, the test set contains 30% of the instances) and select a misclassified instance from the test set."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "83b88fc4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data:\n",
" Number_of_Priors score_factor Age_Above_FourtyFive \\\n",
"0 0 0 1 \n",
"1 0 0 0 \n",
"2 4 0 0 \n",
"3 0 0 0 \n",
"4 14 1 0 \n",
"... ... ... ... \n",
"6167 0 1 0 \n",
"6168 0 0 0 \n",
"6169 0 0 1 \n",
"6170 3 0 0 \n",
"6171 2 0 0 \n",
"\n",
" Age_Below_TwentyFive African_American Asian Hispanic \\\n",
"0 0 0 0 0 \n",
"1 0 1 0 0 \n",
"2 1 1 0 0 \n",
"3 0 0 0 0 \n",
"4 0 0 0 0 \n",
"... ... ... ... ... \n",
"6167 1 1 0 0 \n",
"6168 1 1 0 0 \n",
"6169 0 0 0 0 \n",
"6170 0 1 0 0 \n",
"6171 1 0 0 1 \n",
"\n",
" Native_American Other Female Misdemeanor Two_yr_Recidivism \n",
"0 0 1 0 0 0 \n",
"1 0 0 0 0 1 \n",
"2 0 0 0 0 1 \n",
"3 0 1 0 1 0 \n",
"4 0 0 0 0 1 \n",
"... ... ... ... ... ... \n",
"6167 0 0 0 0 0 \n",
"6168 0 0 0 0 0 \n",
"6169 0 1 0 0 0 \n",
"6170 0 0 1 1 0 \n",
"6171 0 0 1 0 1 \n",
"\n",
"[6172 rows x 12 columns]\n",
"-------------- Information ---------------\n",
"Dataset name: ../dataset/compas.csv\n",
"nFeatures (nAttributes, with the labels): 12\n",
"nInstances (nObservations): 6172\n",
"nLabels: 2\n",
"--------------- Evaluation ---------------\n",
"method: HoldOut\n",
"output: DT\n",
"learner_type: Classification\n",
"learner_options: {'max_depth': None, 'random_state': 0}\n",
"--------- Evaluation Information ---------\n",
"For the evaluation number 0:\n",
"metrics:\n",
" accuracy: 65.33477321814254\n",
" precision: 65.20423600605145\n",
" recall: 51.126927639383155\n",
" f1_score: 57.31382978723405\n",
" specificity: 77.20515361744302\n",
" true_positive: 431\n",
" true_negative: 779\n",
" false_positive: 230\n",
" false_negative: 412\n",
" sklearn_confusion_matrix: [[779, 230], [412, 431]]\n",
"nTraining instances: 4320\n",
"nTest instances: 1852\n",
"\n",
"--------------- Explainer ----------------\n",
"For the evaluation number 0:\n",
"**Decision Tree Model**\n",
"nFeatures: 11\n",
"nNodes: 539\n",
"nVariables: 46\n",
"\n",
"--------------- Instances ----------------\n",
"number of instances selected: 1\n",
"----------------------------------------------\n",
"prediction: 0\n"
]
}
],
"source": [
"from pyxai import Learning, Explainer\n",
"\n",
"learner = Learning.Scikitlearn(\"../dataset/compas.csv\", learner_type=Learning.CLASSIFICATION)\n",
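"# Train a decision tree (Learning.DT), evaluated with the hold-out method.\n",
"model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.DT)\n",
"\n",
"# Select one test instance that the model misclassifies (correct=False).\n",
"dict_information = learner.get_instances(model, n=1, indexes=Learning.TEST, correct=False, details=True)\n",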
"\n",
"instance = dict_information[\"instance\"]\n",
"label = dict_information[\"label\"]\n",
"prediction = dict_information[\"prediction\"]\n",
"\n",
"print(\"prediction:\", prediction)"
]
},
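{
"cell_type": "markdown",
"id": "d4f6b8c1",
"metadata": {},
"source": [
"The evaluation metrics above can be recomputed from the confusion matrix `[[779, 230], [412, 431]]`, read as in scikit-learn (rows are true labels 0 and 1, columns are predicted labels 0 and 1). The test set thus holds 1852 of the 6172 instances (about 30%), the other 4320 being used for training. The next cell is a plain-Python sketch of that arithmetic, using the standard definitions and expressing the results as percentages like the output above; the variable names are ours, not PyXAI's."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4f6b8c2",
"metadata": {},
"outputs": [],
"source": [
"# Counts copied from the evaluation output above.\n",
"tn, fp, fn, tp = 779, 230, 412, 431\n",
"\n",
"n_test = tp + tn + fp + fn  # number of test instances\n",
"accuracy = 100 * (tp + tn) / n_test\n",
"precision = 100 * tp / (tp + fp)\n",
"recall = 100 * tp / (tp + fn)\n",
"specificity = 100 * tn / (tn + fp)\n",
"f1_score = 2 * precision * recall / (precision + recall)\n",
"\n",
"print(n_test, round(accuracy, 2), round(f1_score, 2))  # -> 1852 65.33 57.31"
]
},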
{
"cell_type": "markdown",
"id": "01198b4d",
"metadata": {},
"source": [
"We initialize the explainer with the selected instance and the associated theory, given by the feature types:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4a9ca637",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------- Theory Feature Types -----------\n",
"Before the one-hot encoding of categorical features:\n",
"Numerical features: 1\n",
"Categorical features: 2\n",
"Binary features: 3\n",
"Number of features: 6\n",
"Characteristics of categorical features: {'African_American': ['{African_American,Asian,Hispanic,Native_American,Other}', 'African_American', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Asian': ['{African_American,Asian,Hispanic,Native_American,Other}', 'Asian', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Hispanic': ['{African_American,Asian,Hispanic,Native_American,Other}', 'Hispanic', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Native_American': ['{African_American,Asian,Hispanic,Native_American,Other}', 'Native_American', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Other': ['{African_American,Asian,Hispanic,Native_American,Other}', 'Other', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Age_Above_FourtyFive': ['Age', 'Above_FourtyFive', ['Above_FourtyFive', 'Below_TwentyFive']], 'Age_Below_TwentyFive': ['Age', 'Below_TwentyFive', ['Above_FourtyFive', 'Below_TwentyFive']]}\n",
"\n",
"Number of used features in the model (before the encoding of categorical features): 6\n",
"Number of used features in the model (after the encoding of categorical features): 11\n",
"----------------------------------------------\n"
]
}
],
"source": [
"compas_types = {\n",
" \"numerical\": [\"Number_of_Priors\"],\n",
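" \"binary\": [\"Misdemeanor\", \"score_factor\", \"Female\"],\n",
" # One-hot groups: the ethnicity columns are listed explicitly, while the 'Age*' key\n",
" # matches the columns Age_Above_FourtyFive and Age_Below_TwentyFive.\n",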
" \"categorical\": {\"{African_American,Asian,Hispanic,Native_American,Other}\": [\"African_American\", \"Asian\", \"Hispanic\", \"Native_American\", \"Other\"],\n",
" \"Age*\": [\"Above_FourtyFive\", \"Below_TwentyFive\"]}\n",
"}\n",
"\n",
"\n",
"explainer = Explainer.initialize(model, instance=instance, features_type=compas_types)"
]
},
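{
"cell_type": "markdown",
"id": "e5a7c9d1",
"metadata": {},
"source": [
"The `categorical` entry groups binary columns that one-hot encode the same original feature (ethnicity, and age through the `Age*` key). Within such a group one would expect at most one column to equal 1; a row may also have all zeros, as row 4 of the data preview above does for both groups. The next cell is a hypothetical sanity check of this property on two rows of the preview, written for this page; `one_hot_ok` is not a PyXAI function, and each row is assumed to be a dict from column name to 0/1."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5a7c9d2",
"metadata": {},
"outputs": [],
"source": [
"def one_hot_ok(row, groups):\n",
"    # Hypothetical check: in every one-hot group, at most one column equals 1.\n",
"    return all(sum(row[col] for col in cols) <= 1 for cols in groups.values())\n",
"\n",
"groups = {\n",
"    'ethnicity': ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other'],\n",
"    'Age': ['Age_Above_FourtyFive', 'Age_Below_TwentyFive'],\n",
"}\n",
"\n",
"# One-hot columns of rows 2 and 4 in the data preview above.\n",
"row_2 = {'African_American': 1, 'Asian': 0, 'Hispanic': 0, 'Native_American': 0,\n",
"         'Other': 0, 'Age_Above_FourtyFive': 0, 'Age_Below_TwentyFive': 1}\n",
"row_4 = dict.fromkeys(row_2, 0)\n",
"print(one_hot_ok(row_2, groups), one_hot_ok(row_4, groups))  # -> True True"
]
},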
{
"cell_type": "markdown",
"id": "d28ea08d",
"metadata": {},
"source": [
"We compute a sufficient reason explaining why the model predicts 0 for this instance:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b72f0e44",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"explanation: (-1, -2, -3, -4)\n",
"to_features: ('Number_of_Priors <= 0.5', 'score_factor = 0', 'Age != Below_TwentyFive')\n"
]
}
],
"source": [
"reason = explainer.sufficient_reason(n=1)\n",
"print(\"explanation:\", reason)\n",
"print(\"to_features:\", explainer.to_features(reason))"
]
},
{
"cell_type": "markdown",
"id": "39f6fe97",
"metadata": {},
"source": [
"Suppose that the user knows that every instance covered by the explanation (-1, -2, -3, -4), i.e. satisfying all of its conditions, should be classified as positive (label 1). The model must then be rectified with the corresponding classification rule.\n",
"Once the model has been rectified, the instance is classified as the user expects:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7e4c567a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"-------------- Rectification information:\n",
"Classification Rule - Number of nodes: 9\n",
"Model - Number of nodes: 1079\n",
"Model - Number of nodes (after rectification): 3559\n",
"Model - Number of nodes (after simplification using the theory): 1079\n",
"Model - Number of nodes (after elimination of redundant nodes): 619\n",
"--------------\n",
"new prediction: 1\n"
]
}
],
"source": [
"model = explainer.rectify(conditions=reason, label=1) \n",
"print(\"new prediction:\", model.predict_instance(instance))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}