{ "cells": [ { "cell_type": "markdown", "id": "4800aa8d", "metadata": {}, "source": [ "# Rectification for Decision Trees" ] }, { "cell_type": "markdown", "id": "87af483d", "metadata": {}, "source": [ "Further information on the rectification method for Decision Trees is available in the paper [Rectifying Binary Classifiers](https://hal.science/hal-04236309/document)." ] }, { "cell_type": "markdown", "id": "1dc05e33", "metadata": {}, "source": [ "## Example from a Hand-Crafted Tree" ] }, { "cell_type": "markdown", "id": "37a0365c", "metadata": {}, "source": [ "To illustrate the method, we consider a credit scoring scenario.\n", "\n", "Each customer is characterized by:\n", "- an annual income $I$ (in k\\\\$),\n", "- whether the customer has already reimbursed a previous loan ($R$),\n", "- whether the customer has a permanent position ($PP$).\n", "\n", "A decision tree $T$ representing the model is shown in the following figure. The Boolean conditions used in $T$ are $I > 30$, $I > 20$, $R$, and $PP$. 
\n", "\n", "\"DTRectify1\"\n", "\n", "We start by building the decision tree: " ] }, { "cell_type": "code", "execution_count": 1, "id": "db883698", "metadata": {}, "outputs": [], "source": [ "from pyxai import Builder, Explainer\n", "\n", "node_L_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)\n", "node_L_2 = Builder.DecisionNode(1, operator=Builder.GT, threshold=20, left=0, right=node_L_1)\n", "\n", "node_R_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)\n", "node_R_2 = Builder.DecisionNode(2, operator=Builder.EQ, threshold=1, left=node_R_1, right=1)\n", "\n", "root = Builder.DecisionNode(1, operator=Builder.GT, threshold=30, left=node_L_2, right=node_R_2)\n", "tree = Builder.DecisionTree(3, root, feature_names=[\"I\", \"PP\", \"R\"])" ] }, { "cell_type": "markdown", "id": "aca44977", "metadata": {}, "source": [ "Consider the instance $x = (I = 25, R = 1, PP = 1)$ corresponding to a customer applying for a loan.\n", "We initialize the explainer with this instance and the associated theory (see the [Theories](/documentation/explainer/theories/) page for more information). 
" ] }, { "cell_type": "code", "execution_count": 2, "id": "30b36230", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------- Theory Feature Types -----------\n", "Before the one-hot encoding of categorical features:\n", "Numerical features: 1\n", "Categorical features: 0\n", "Binary features: 2\n", "Number of features: 3\n", "Characteristics of categorical features: {}\n", "\n", "Number of used features in the model (before the encoding of categorical features): 3\n", "Number of used features in the model (after the encoding of categorical features): 3\n", "----------------------------------------------\n", "binary representation: (-1, 2, 3, 4)\n", "target_prediction: 1\n", "to_features: ('I <= 30', 'I > 20', 'PP == 1', 'R == 1')\n" ] } ], "source": [ "loan_types = {\n", " \"numerical\": [\"I\"],\n", " \"binary\": [\"PP\", \"R\"],\n", "}\n", "\n", "explainer = Explainer.initialize(tree, instance=(25, 1, 1), features_type=loan_types)\n", "print(\"binary representation: \", explainer.binary_representation)\n", "print(\"target_prediction:\", explainer.target_prediction)\n", "print(\"to_features:\", explainer.to_features(explainer.binary_representation, eliminate_redundant_features=False))" ] }, { "cell_type": "markdown", "id": "452f1209", "metadata": {}, "source": [ "The user (a bank employee) disagrees with this prediction (the loan is accepted). According to the user, the following classification rule must be obeyed: whenever the annual income of the client does not exceed 30, the application should be rejected. To perform this rectification, we use the ```rectify()``` method of the ```explainer``` object. More information about this method is available on the [Rectification](/documentation/rectification/) page." 
] }, { "cell_type": "code", "execution_count": 3, "id": "f2f3bc71", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "-------------- Rectification information:\n", "Classification Rule - Number of nodes: 3\n", "Model - Number of nodes: 11\n", "Model - Number of nodes (after rectification): 17\n", "Model - Number of nodes (after simplification using the theory): 11\n", "Model - Number of nodes (after elimination of redundant nodes): 7\n", "--------------\n", "target_prediction: 0\n" ] } ], "source": [ "rectified_model = explainer.rectify(conditions=(-1, ), label=0)\n", "# The condition (-1, ) encodes 'I <= 30', i.e. the negation of the Boolean condition 'I > 30'.\n", "\n", "print(\"target_prediction:\", explainer.target_prediction)" ] }, { "cell_type": "markdown", "id": "779a7def", "metadata": {}, "source": [ "Here is the model without any simplification: \n", "\n", "\"DTRectify2\"\n", "\n", "Here is the model once simplified:\n", "\n", "\"DTRectify3\"\n", "\n" ] }, { "cell_type": "markdown", "id": "44c1f638", "metadata": {}, "source": [ "We can check that the instance is now correctly classified:" ] }, { "cell_type": "code", "execution_count": 4, "id": "99697598", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "target_prediction: 0\n" ] } ], "source": [ "print(\"target_prediction:\", rectified_model.predict_instance((25, 1, 1)))" ] }, { "cell_type": "markdown", "id": "aa82967b", "metadata": {}, "source": [ "## Example from a Real Dataset" ] }, { "cell_type": "markdown", "id": "29048b47", "metadata": {}, "source": [ "For this example, we use the compas.csv dataset. We build a model using the hold-out approach (by default, the test size is set to 30%) and select a misclassified instance. 
" ] }, { "cell_type": "code", "execution_count": 5, "id": "83b88fc4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data:\n", " Number_of_Priors score_factor Age_Above_FourtyFive \\\n", "0 0 0 1 \n", "1 0 0 0 \n", "2 4 0 0 \n", "3 0 0 0 \n", "4 14 1 0 \n", "... ... ... ... \n", "6167 0 1 0 \n", "6168 0 0 0 \n", "6169 0 0 1 \n", "6170 3 0 0 \n", "6171 2 0 0 \n", "\n", " Age_Below_TwentyFive African_American Asian Hispanic \\\n", "0 0 0 0 0 \n", "1 0 1 0 0 \n", "2 1 1 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "... ... ... ... ... \n", "6167 1 1 0 0 \n", "6168 1 1 0 0 \n", "6169 0 0 0 0 \n", "6170 0 1 0 0 \n", "6171 1 0 0 1 \n", "\n", " Native_American Other Female Misdemeanor Two_yr_Recidivism \n", "0 0 1 0 0 0 \n", "1 0 0 0 0 1 \n", "2 0 0 0 0 1 \n", "3 0 1 0 1 0 \n", "4 0 0 0 0 1 \n", "... ... ... ... ... ... \n", "6167 0 0 0 0 0 \n", "6168 0 0 0 0 0 \n", "6169 0 1 0 0 0 \n", "6170 0 0 1 1 0 \n", "6171 0 0 1 0 1 \n", "\n", "[6172 rows x 12 columns]\n", "-------------- Information ---------------\n", "Dataset name: ../dataset/compas.csv\n", "nFeatures (nAttributes, with the labels): 12\n", "nInstances (nObservations): 6172\n", "nLabels: 2\n", "--------------- Evaluation ---------------\n", "method: HoldOut\n", "output: DT\n", "learner_type: Classification\n", "learner_options: {'max_depth': None, 'random_state': 0}\n", "--------- Evaluation Information ---------\n", "For the evaluation number 0:\n", "metrics:\n", " accuracy: 65.33477321814254\n", " precision: 65.20423600605145\n", " recall: 51.126927639383155\n", " f1_score: 57.31382978723405\n", " specificity: 77.20515361744302\n", " true_positive: 431\n", " true_negative: 779\n", " false_positive: 230\n", " false_negative: 412\n", " sklearn_confusion_matrix: [[779, 230], [412, 431]]\n", "nTraining instances: 4320\n", "nTest instances: 1852\n", "\n", "--------------- Explainer ----------------\n", "For the evaluation number 0:\n", "**Decision Tree Model**\n", "nFeatures: 11\n", 
"nNodes: 539\n", "nVariables: 46\n", "\n", "--------------- Instances ----------------\n", "number of instances selected: 1\n", "----------------------------------------------\n", "prediction: 0\n" ] } ], "source": [ "from pyxai import Learning, Explainer\n", "\n", "learner = Learning.Scikitlearn(\"../dataset/compas.csv\", learner_type=Learning.CLASSIFICATION)\n", "model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.DT)\n", "\n", "dict_information = learner.get_instances(model, n=1, indexes=Learning.TEST, correct=False, details=True)\n", "\n", "instance = dict_information[\"instance\"]\n", "label = dict_information[\"label\"]\n", "prediction = dict_information[\"prediction\"]\n", "\n", "print(\"prediction:\", prediction)" ] }, { "cell_type": "markdown", "id": "01198b4d", "metadata": {}, "source": [ "We activate the explainer with the associated theory and the selected instance: " ] }, { "cell_type": "code", "execution_count": 6, "id": "4a9ca637", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------- Theory Feature Types -----------\n", "Before the one-hot encoding of categorical features:\n", "Numerical features: 1\n", "Categorical features: 2\n", "Binary features: 3\n", "Number of features: 6\n", "Characteristics of categorical features: {'African_American': ['{African_American,Asian,Hispanic,Native_American,Other}', 'African_American', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Asian': ['{African_American,Asian,Hispanic,Native_American,Other}', 'Asian', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Hispanic': ['{African_American,Asian,Hispanic,Native_American,Other}', 'Hispanic', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Native_American': ['{African_American,Asian,Hispanic,Native_American,Other}', 'Native_American', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Other': 
['{African_American,Asian,Hispanic,Native_American,Other}', 'Other', ['African_American', 'Asian', 'Hispanic', 'Native_American', 'Other']], 'Age_Above_FourtyFive': ['Age', 'Above_FourtyFive', ['Above_FourtyFive', 'Below_TwentyFive']], 'Age_Below_TwentyFive': ['Age', 'Below_TwentyFive', ['Above_FourtyFive', 'Below_TwentyFive']]}\n", "\n", "Number of used features in the model (before the encoding of categorical features): 6\n", "Number of used features in the model (after the encoding of categorical features): 11\n", "----------------------------------------------\n" ] } ], "source": [ "compas_types = {\n", " \"numerical\": [\"Number_of_Priors\"],\n", " \"binary\": [\"Misdemeanor\", \"score_factor\", \"Female\"],\n", " \"categorical\": {\"{African_American,Asian,Hispanic,Native_American,Other}\": [\"African_American\", \"Asian\", \"Hispanic\", \"Native_American\", \"Other\"],\n", " \"Age*\": [\"Above_FourtyFive\", \"Below_TwentyFive\"]}\n", "}\n", "\n", "\n", "explainer = Explainer.initialize(model, instance=instance, features_type=compas_types)" ] }, { "cell_type": "markdown", "id": "d28ea08d", "metadata": {}, "source": [ "We compute why the model predicts 0 for this instance:" ] }, { "cell_type": "code", "execution_count": 7, "id": "b72f0e44", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "explanation: (-1, -2, -3, -4)\n", "to_features: ('Number_of_Priors <= 0.5', 'score_factor = 0', 'Age != Below_TwentyFive')\n" ] } ], "source": [ "reason = explainer.sufficient_reason(n=1)\n", "print(\"explanation:\", reason)\n", "print(\"to_features:\", explainer.to_features(reason))" ] }, { "cell_type": "markdown", "id": "39f6fe97", "metadata": {}, "source": [ "Suppose that the user knows that every instance covered by the explanation (-1, -2, -3, -4) should be classified as a positive instance. 
The model must be rectified by the corresponding classification rule.\n", "Once the model has been corrected, the instance is classified as expected by the user:" ] }, { "cell_type": "code", "execution_count": 8, "id": "7e4c567a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "-------------- Rectification information:\n", "Classification Rule - Number of nodes: 9\n", "Model - Number of nodes: 1079\n", "Model - Number of nodes (after rectification): 3559\n", "Model - Number of nodes (after simplification using the theory): 1079\n", "Model - Number of nodes (after elimination of redundant nodes): 619\n", "--------------\n", "new prediction: 1\n" ] } ], "source": [ "model = explainer.rectify(conditions=reason, label=1) \n", "print(\"new prediction:\", model.predict_instance(instance))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }