{ "cells": [ { "cell_type": "markdown", "id": "ac658db4", "metadata": {}, "source": [ "# Building Decision Trees" ] }, { "cell_type": "markdown", "id": "84d00330", "metadata": {}, "source": [ "This page explains how to build a Decision Tree from tree elements (nodes and leaves). To illustrate this, we use an example from the [On the Explanatory Power of Decision Trees](https://arxiv.org/abs/2108.05266) paper: a tree that recognizes Cattleya orchids.\n", "\n", "\"DTbuilder\"\n", "\n", "The decision tree represented by this figure separates Cattleya orchids from other\n", "orchids using the following features:\n", "\n", "* $x_1$: has fragrant flowers.\n", "* $x_2$: has one or two leaves.\n", "* $x_3$: has large flowers.\n", "* $x_4$: is sympodial.\n", "\n", "When an instance reaches a leaf labeled $1$, it is classified as a Cattleya orchid; otherwise, it is considered to belong to another species." ] }, { "cell_type": "markdown", "id": "ced4da4e", "metadata": {}, "source": [ "## Building the Model" ] }, { "cell_type": "markdown", "id": "696ef438", "metadata": {}, "source": [ "First, we import the necessary modules. Recall that the ```Builder``` module provides methods to build the Decision Tree, while the ```Explainer``` module provides methods to explain it." ] }, { "cell_type": "code", "execution_count": 1, "id": "8cfecb66", "metadata": {}, "outputs": [], "source": [ "from pyxai import Builder, Explainer" ] }, { "cell_type": "markdown", "id": "542c29ac", "metadata": {}, "source": [ "Next, we build the tree bottom-up, that is, from the leaves to the root. We therefore start with the five $x_4$ nodes."
] }, { "cell_type": "code", "execution_count": 2, "id": "56ed123e", "metadata": {}, "outputs": [], "source": [ "node_x4_1 = Builder.DecisionNode(4, left=0, right=1)\n", "node_x4_2 = Builder.DecisionNode(4, left=0, right=1)\n", "node_x4_3 = Builder.DecisionNode(4, left=0, right=1)\n", "node_x4_4 = Builder.DecisionNode(4, left=0, right=1)\n", "node_x4_5 = Builder.DecisionNode(4, left=0, right=1)" ] }, { "cell_type": "markdown", "id": "6779c0dc", "metadata": {}, "source": [ "The ```DecisionNode``` class takes as parameters the identifier of the feature, followed by the left and right children. Each child is either a leaf value (here, $0$ or $1$) or another ```DecisionNode```: the left child is reached when the condition of the node is not satisfied, and the right child when it is." ] }, { "cell_type": "markdown", "id": "7abbb8fa", "metadata": {}, "source": [ "{: .attention }\n", "> In this example, each feature is binary (i.e., it takes either the value 0 or the value 1), so we do not pass the operator and threshold parameters to ```Builder.DecisionNode```. These parameters therefore take their default values (```OperatorCondition.GE``` and $0.5$, respectively), which gives each node a condition of the form \"$x_4 \\ge 0.5$ ?\". More complex conditions are created in the [Boosted Trees](/documentation/learning/builder/BTbuilder/) page. 
" ] }, { "cell_type": "markdown", "id": "dfd4b653", "metadata": {}, "source": [ "Next, we construct the remaining nodes, level by level, up to the root node on $x_1$:" ] }, { "cell_type": "code", "execution_count": 3, "id": "58bb0f84", "metadata": {}, "outputs": [], "source": [ "node_x3_1 = Builder.DecisionNode(3, left=0, right=node_x4_1)\n", "node_x3_2 = Builder.DecisionNode(3, left=node_x4_2, right=node_x4_3)\n", "node_x3_3 = Builder.DecisionNode(3, left=node_x4_4, right=node_x4_5)\n", "\n", "node_x2_1 = Builder.DecisionNode(2, left=0, right=node_x3_1)\n", "node_x2_2 = Builder.DecisionNode(2, left=node_x3_2, right=node_x3_3)\n", "\n", "node_x1_1 = Builder.DecisionNode(1, left=node_x2_1, right=node_x2_2)" ] }, { "cell_type": "markdown", "id": "4b30ed2e", "metadata": {}, "source": [ "We can now define the Decision Tree with the ```DecisionTree``` class:" ] }, { "cell_type": "code", "execution_count": 4, "id": "53d65e59", "metadata": {}, "outputs": [], "source": [ "tree = Builder.DecisionTree(4, node_x1_1, force_features_equal_to_binaries=True)" ] }, { "cell_type": "markdown", "id": "47547037", "metadata": {}, "source": [ "This class takes as parameters the number of features, the root node, and the keyword argument ```force_features_equal_to_binaries=True```." ] }, { "cell_type": "markdown", "id": "460ebcf2", "metadata": {}, "source": [ "{: .warning }\n", "> The parameter ```force_features_equal_to_binaries``` makes the binary variables equal to the feature identifiers. These binary variables are used to represent explanations: their values and signs indicate whether or not the condition at each node is satisfied. By default, these binary variables receive arbitrary values that depend on the order in which the tree is traversed. Setting ```force_features_equal_to_binaries``` to ```True``` ensures that each binary variable is instead equal to the identifier of its feature, so that explanations directly match the features without having to use the ```to_features``` method. 
However, this functionality cannot be used with every model: it is incompatible with trees in which nodes test different conditions on the same feature, because it assumes that each feature corresponds to exactly one condition. " ] }, { "cell_type": "markdown", "id": "531f17c5", "metadata": {}, "source": [ "More details about the ```DecisionNode``` and ```DecisionTree``` classes are given in the [Building Models](/documentation/learning/builder/) page." ] }, { "cell_type": "markdown", "id": "d6537ee8", "metadata": {}, "source": [ "## Explaining the Model" ] }, { "cell_type": "markdown", "id": "b15c16d5", "metadata": {}, "source": [ "We can now compute explanations, starting with the instance ```(1,1,1,1)```:" ] }, { "cell_type": "code", "execution_count": 5, "id": "0f509d16", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "instance = (1,1,1,1):\n", "target_prediction: 1\n", "direct: (1, 2, 3, 4)\n", "sufficient_reasons: ((1, 4), (2, 3, 4))\n", "Minimal sufficient reasons: (1, 4)\n", "Contrastives: ((4,), (1, 2), (1, 3))\n" ] } ], "source": [ "print(\"instance = (1,1,1,1):\")\n", "explainer = Explainer.initialize(tree, instance=(1,1,1,1))\n", "\n", "print(\"target_prediction:\", explainer.target_prediction)\n", "direct = explainer.direct_reason()\n", "print(\"direct:\", direct)\n", "assert direct == (1, 2, 3, 4), \"The direct reason is incorrect!\"\n", "\n", "sufficient_reasons = explainer.sufficient_reason(n=Explainer.ALL)\n", "print(\"sufficient_reasons:\", sufficient_reasons)\n", "assert sufficient_reasons == ((1, 4), (2, 3, 4)), \"The sufficient reasons are incorrect!\"\n", "\n", "for sufficient in sufficient_reasons:\n", " assert explainer.is_sufficient_reason(sufficient), \"This is not a sufficient reason!\"\n", "\n", "minimals = explainer.minimal_sufficient_reason()\n", "print(\"Minimal sufficient reasons:\", minimals)\n", "assert minimals == (1, 4), \"The minimal sufficient reasons are incorrect!\"\n", "\n", "contrastives = 
explainer.contrastive_reason(n=Explainer.ALL)\n", "print(\"Contrastives:\", contrastives)\n", "for contrastive in contrastives:\n", " assert explainer.is_contrastive_reason(contrastive), \"This is not a contrastive reason!\"" ] }, { "cell_type": "markdown", "id": "e7a99346", "metadata": {}, "source": [ "Next, we explain the instance ```(0,0,0,0)``` by reusing the same explainer through its ```set_instance``` method. Since all the features of this instance are equal to 0, none of the node conditions are satisfied, so every explanation consists of negative values: for example, $-4$ means that the condition $x_4 \\ge 0.5$ is not satisfied." ] }, { "cell_type": "code", "execution_count": 6, "id": "97c3868a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "instance = (0,0,0,0):\n", "target_prediction: 0\n", "direct: (-1, -2)\n", "sufficient_reasons: ((-4,), (-1, -2), (-1, -3))\n", "Minimal sufficient reasons: (-4,)\n" ] } ], "source": [ "print(\"\\ninstance = (0,0,0,0):\")\n", "\n", "explainer.set_instance((0,0,0,0))\n", "\n", "print(\"target_prediction:\", explainer.target_prediction)\n", "direct = explainer.direct_reason()\n", "print(\"direct:\", direct)\n", "assert direct == (-1, -2), \"The direct reason is incorrect!\"\n", "\n", "sufficient_reasons = explainer.sufficient_reason(n=Explainer.ALL)\n", "print(\"sufficient_reasons:\", sufficient_reasons)\n", "assert sufficient_reasons == ((-4,), (-1, -2), (-1, -3)), \"The sufficient reasons are incorrect!\"\n", "\n", "for sufficient in sufficient_reasons:\n", " assert explainer.is_sufficient_reason(sufficient), \"This is not a sufficient reason!\"\n", "\n", "minimals = explainer.minimal_sufficient_reason(n=1)\n", "print(\"Minimal sufficient reasons:\", minimals)\n", "assert minimals == (-4,), \"The minimal sufficient reasons are incorrect!\"" ] }, { "cell_type": "markdown", "id": "2717ed48", "metadata": {}, "source": [ "Details on explanations are given in the [Explanations Computation](/documentation/explanations/DTexplanations/) page. 
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }