{ "cells": [ { "cell_type": "markdown", "id": "b1db8c9a", "metadata": {}, "source": [ "# Contrastive Reasons" ] }, { "cell_type": "markdown", "id": "dc23fadf", "metadata": {}, "source": [ "{: .attention }\n", "> The algorithms to compute contrastive reasons for multi-class classification problems are still under development and should be available in the next versions of PyXAI (however, the contrastive reasons for binary classification can be calculated)." ] }, { "cell_type": "markdown", "id": "514a5144", "metadata": {}, "source": [ "Unlike abductive explanations that explain why an instance $x$ is classified as belonging to a given class, the **contrastive explanations** explains why $x$ has not been classified by the ML model as expected.\n", "\n", "Let 𝑓 be a Boolean function represented by a random forest 𝑅𝐹, 𝑥 be an instance and 1 (resp. 0) the prediction of 𝑅𝐹 on 𝑥 (𝑓(𝑥)=1 (resp $f(x)=0$)), a **contrastive reason** for $x$ is a term $t$ such that:\n", "* $t \\subseteq t_{x}$, $t_{x} \\setminus t$ is not an implicant of $f;$ \n", "* for every $\\ell \\in t$, $t \\setminus \\{\\ell\\}$ does not satisfy this previous condition (i.e., $t$ is minimal w.r.t. set inclusion).\n", "\n", "Formally, a **contrastive reason** for $x$ is a subset $t$ of the characteristics of $x$ that is minimal w.r.t. set inclusion among those such that at least one instance $x'$ that coincides with $x$ except on the characteristics from $t$ is not classified by the decision tree as $x$ is. Stated otherwhise, a **contrastive reason** represents adjustments of the features that we have to do to change the prediction for an instance. \n", "\n", "A contrastive reason is minimal w.r.t. set inclusion, i.e. there is no subset of this reason which is also a contrastive reason. A **minimal contrastive reason** for $x$ is a contrastive reason for $x$ that contains a minimal number of literals. In other words, a **minimal contrastive reason** has a minimal size. \n", "\n", "More information about contrastive reasons can be found in the paper [On the Explanatory Power of Decision Trees](https://arxiv.org/abs/2108.05266)." ] }, { "cell_type": "markdown", "id": "56693e24", "metadata": {}, "source": [ "{: .note}\n", "For random forests, PyXAI can only compute minimal contrastive reasons." ] }, { "cell_type": "markdown", "id": "24c3fed5", "metadata": {}, "source": [ "| <ExplainerRF Object>.minimal_contrastive_reason(*, n=1, time_limit=None): | \n", "| :----------- | \n", "|This method considers a CNF formula corresponding to the negation of the random forest as hard clauses and adds binary variables representing the instances as unary soft clauses with weights equal to 1. Several calls to a MAXSAT solver ([OPENWBO](https://github.com/sat-group/open-wbo)) are performed and the result of each call is a minimal contrastive reason. The minimal reasons are those with the lowest scores (i.e., the sum of weights). Thus, the algorithm stops its search when a non-minimal reason is found (i.e., when a higher score is found). Moreover, the method prevents from finding the same reason twice or more by adding clauses(called blocking clauses) between each invocation.

Returns ```n``` minimal contrastive reasons for the current instance in a ```Tuple``` (when ```n``` is set to 1, the reason itself is returned rather than a ```Tuple```). Excluded features are supported. The reasons are given in the form of binary variables; you must use the ```to_features``` method if you want a representation based on the features considered at the start.|\n", "| n ```Integer``` ```Explainer.ALL```: The desired number of contrastive reasons. Set this to ```Explainer.ALL``` to request all reasons. Default value is 1.|\n", "| time_limit ```Integer``` ```None```: The time limit of the method in seconds. Set this to ```None``` to give this process an infinite amount of time. Default value is ```None```.|" ] }, { "cell_type": "markdown", "id": "b151176d", "metadata": {}, "source": [ "The PyXAI library provides a way to check that a reason is contrastive:" ] }, { "cell_type": "markdown", "id": "21c0adf9", "metadata": {}, "source": [ "| <Explainer Object>.is_contrastive_reason(reason): | \n", "| :----------- | \n", "| Checks whether the reason is a contrastive one. Replaces, in the binary representation of the instance, each literal of the reason with its opposite, and checks that the model does not predict the same class for the resulting instance as for the initial one. Returns ```True``` if the reason is contrastive, ```False``` otherwise. |\n", "| reason ```List``` of ```Integer```: The reason to be checked.|" ] }, { "cell_type": "markdown", "id": "2b240b8f", "metadata": {}, "source": [ "The basic methods (```initialize```, ```set_instance```, ```to_features```, ```is_reason```, ...) of the ```Explainer``` module used in the next examples are described in the [Explainer Principles](/documentation/explainer/) page." ] }, { "cell_type": "markdown", "id": "c8f0eead", "metadata": {}, "source": [ "## Example from Hand-Crafted Trees" ] }, { "cell_type": "markdown", "id": "ad910b80", "metadata": {}, "source": [ "For this example, we take the random forest of the [Building Models](/documentation/learning/builder/RFbuilder/) page, defined over $4$ binary features ($x_1$, $x_2$, $x_3$ and $x_4$). \n", "\n", "The following figure shows the new instance $x' = (1,1,1,0)$ created from the contrastive reason $(x_4)$ (in red) for the instance $x = (1,1,1,1)$. Thus, the instance $(1,1,1,0)$, which differs from $x$ only on $x_4$, is not classified by the random forest as $x$ is. More precisely, $x'$ is classified as a negative instance while $x$ is classified as a positive instance. Indeed, in this figure, $T_1(x') = 0$, $T_2(x') = 1$ and $T_3(x') = 0$, so $f(x') = 0$ by majority vote. \n", "\n", "*(Figure: RFcontrastive)*\n", "\n", "Now, we show how to get them with PyXAI. 
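Before that, as a quick sanity check, the majority vote above can be reproduced by hand. The following minimal sketch (plain Python, independent of PyXAI) uses the tree outputs $T_1(x')$, $T_2(x')$ and $T_3(x')$ read off the figure:

```python
# Hand-check of the majority vote on x' = (1,1,1,0).
# The individual tree outputs T1(x'), T2(x'), T3(x') are taken from the figure.
tree_outputs = [0, 1, 0]
# The forest predicts the class returned by the majority of its trees.
prediction = int(sum(tree_outputs) > len(tree_outputs) / 2)
assert prediction == 0  # x' is classified as negative, unlike x = (1,1,1,1)
```
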
We start by building the random forest: " ] }, { "cell_type": "code", "execution_count": 1, "id": "a7a88ebf", "metadata": {}, "outputs": [], "source": [ "from pyxai import Builder, Explainer\n", "\n", "nodeT1_1 = Builder.DecisionNode(1, left=0, right=1)\n", "nodeT1_3 = Builder.DecisionNode(3, left=0, right=nodeT1_1)\n", "nodeT1_2 = Builder.DecisionNode(2, left=1, right=nodeT1_3)\n", "nodeT1_4 = Builder.DecisionNode(4, left=0, right=nodeT1_2)\n", "\n", "tree1 = Builder.DecisionTree(4, nodeT1_4, force_features_equal_to_binaries=True)\n", "\n", "nodeT2_4 = Builder.DecisionNode(4, left=0, right=1)\n", "nodeT2_1 = Builder.DecisionNode(1, left=0, right=nodeT2_4)\n", "nodeT2_2 = Builder.DecisionNode(2, left=nodeT2_1, right=1)\n", "\n", "tree2 = Builder.DecisionTree(4, nodeT2_2, force_features_equal_to_binaries=True) #4 features but only 3 used\n", "\n", "nodeT3_1_1 = Builder.DecisionNode(1, left=0, right=1)\n", "nodeT3_1_2 = Builder.DecisionNode(1, left=0, right=1)\n", "nodeT3_4_1 = Builder.DecisionNode(4, left=0, right=nodeT3_1_1)\n", "nodeT3_4_2 = Builder.DecisionNode(4, left=0, right=1)\n", "\n", "nodeT3_2_1 = Builder.DecisionNode(2, left=nodeT3_1_2, right=nodeT3_4_1)\n", "nodeT3_2_2 = Builder.DecisionNode(2, left=0, right=nodeT3_4_2)\n", "\n", "nodeT3_3_1 = Builder.DecisionNode(3, left=nodeT3_2_1, right=nodeT3_2_2)\n", "\n", "tree3 = Builder.DecisionTree(4, nodeT3_3_1, force_features_equal_to_binaries=True)\n", "forest = Builder.RandomForest([tree1, tree2, tree3], n_classes=2)" ] }, { "cell_type": "markdown", "id": "c177fc1f", "metadata": {}, "source": [ "We compute the contrastive reasons for these two instances: " ] }, { "cell_type": "code", "execution_count": 2, "id": "f0733b41", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Contrastives: ((4,),)\n", "-------------------------------\n", "Contrastives: ((-1, -4),)\n" ] } ], "source": [ "explainer = Explainer.initialize(forest)\n", "explainer.set_instance((1,1,1,1))\n", "\n", "contrastives = explainer.minimal_contrastive_reason(n=Explainer.ALL)\n", "print(\"Contrastives:\", contrastives)\n", "for contrastive in contrastives:\n", " assert explainer.is_contrastive_reason(contrastive), \"It is not a contrastive reason!\"\n", "\n", "print(\"-------------------------------\")\n", "\n", "explainer.set_instance((0,0,0,0))\n", "\n", "contrastives = explainer.minimal_contrastive_reason(n=Explainer.ALL)\n", "print(\"Contrastives:\", contrastives)\n", "for contrastive in contrastives:\n", " assert explainer.is_contrastive_reason(contrastive), \"It is not a contrastive reason!\"" ] }, { "cell_type": "markdown", "id": "c75f8563", "metadata": {}, "source": [ "## Example from a Real Dataset" ] }, { "cell_type": "markdown", "id": "ed0ed888", "metadata": {}, "source": [ "For this example, we take the [mnist49](/assets/notebooks/dataset/mnist49.csv) dataset. We create a model using the hold-out approach (by default, the test size is set to 30%) and select a correctly classified instance. " ] }, { "cell_type": "code", "execution_count": 3, "id": "bbeb5462", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data:\n", " 0 1 2 3 4 5 6 7 8 9 ... 775 776 777 778 779 780 781 \n", "0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \\\n", "1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "2 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "3 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "4 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "... .. .. .. .. .. .. .. .. .. .. ... ... ... ... ... ... ... ... 
\n", "13777 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "13778 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "13779 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "13780 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "13781 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 \n", "\n", " 782 783 784 \n", "0 0 0 4 \n", "1 0 0 9 \n", "2 0 0 4 \n", "3 0 0 9 \n", "4 0 0 4 \n", "... ... ... ... \n", "13777 0 0 4 \n", "13778 0 0 4 \n", "13779 0 0 4 \n", "13780 0 0 9 \n", "13781 0 0 4 \n", "\n", "[13782 rows x 785 columns]\n", "-------------- Information ---------------\n", "Dataset name: ../../../dataset/mnist49.csv\n", "nFeatures (nAttributes, with the labels): 785\n", "nInstances (nObservations): 13782\n", "nLabels: 2\n", "--------------- Evaluation ---------------\n", "method: HoldOut\n", "output: RF\n", "learner_type: Classification\n", "learner_options: {'max_depth': None, 'random_state': 0}\n", "--------- Evaluation Information ---------\n", "For the evaluation number 0:\n", "metrics:\n", " accuracy: 98.21039903264813\n", "nTraining instances: 9647\n", "nTest instances: 4135\n", "\n", "--------------- Explainer ----------------\n", "For the evaluation number 0:\n", "**Random Forest Model**\n", "nClasses: 2\n", "nTrees: 100\n", "nVariables: 27880\n", "\n", "--------------- Instances ----------------\n", "number of instances selected: 1\n", "----------------------------------------------\n" ] } ], "source": [ "from pyxai import Learning, Explainer\n", "\n", "learner = Learning.Scikitlearn(\"../../../dataset/mnist49.csv\", learner_type=Learning.CLASSIFICATION)\n", "model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.RF)\n", "instance, prediction = learner.get_instances(model, n=1, correct=True)" ] }, { "cell_type": "markdown", "id": "0bc4b271", "metadata": {}, "source": [ "We compute one contrastive reason. Since it is a hard task, we put a time_limit. 
We can now compute the reason: " ] }, { "cell_type": "code", "execution_count": 4, "id": "c2a2d7f1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "instance prediction: 0\n", "\n", "this is an approximation\n", "contrastive: ('153 <= 127.5', '154 <= 104.5', '155 <= 127.0', '157 <= 2.5', '158 <= 0.5', '161 > 0.5', '162 > 0.5', '182 <= 12.5', '183 <= 253.5', '184 <= 10.5', '186 <= 101.0', '188 > 17.5', '189 > 3.5', '192 <= 0.5', '208 <= 163.5', '209 <= 1.5', '210 <= 162.0', '211 <= 49.0', '212 <= 254.5', '213 <= 5.0', '216 > 3.5', '232 <= 1.0', '235 <= 5.0', '236 <= 252.5', '237 <= 26.5', '239 <= 36.5', '240 <= 6.0', '241 <= 40.5', '243 > 7.5', '244 > 110.5', '266 <= 212.5', '267 <= 249.5', '269 <= 0.5', '289 <= 254.5', '291 <= 0.5', '292 <= 253.5', '295 <= 6.5', '319 <= 246.0', '320 <= 125.0', '321 <= 253.5', '322 <= 253.5', '326 <= 253.5', '328 > 2.5', '382 > 193.5', '399 <= 1.0', '400 <= 2.0', '428 > 19.5', '429 > 21.5', '435 > 24.5', '454 > 0.5', '456 > 4.5', '465 <= 248.5', '469 <= 0.5', '545 <= 254.5', '550 > 6.5', '592 <= 14.0', '612 <= 2.5', '621 <= 88.0', '633 <= 252.5', '688 <= 253.5', '706 <= 77.5', '711 <= 6.5', '712 <= 1.5', '717 <= 18.5', '735 <= 2.0', '745 <= 25.5', '746 <= 21.5', '750 <= 8.5', '770 <= 19.5')\n" ] } ], "source": [ "explainer = Explainer.initialize(model, instance)\n", "print(\"instance prediction:\", prediction)\n", "print()\n", "\n", "contrastive_reason = explainer.minimal_contrastive_reason(n=1, time_limit=10)\n", "if explainer.elapsed_time == Explainer.TIMEOUT: \n", " print('this is an approximation')\n", "if len(contrastive_reason) > 0: \n", " print(\"contrastive: \", explainer.to_features(contrastive_reason, contrastive=True))\n", "else: \n", " print('No contrastive reason found')" ] }, { "cell_type": "markdown", "id": "318185c1", "metadata": {}, "source": [ "Other types of explanations are presented in the [Explanations Computation](/documentation/explanations/RFexplanations/) page." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }