{ "cells": [ { "cell_type": "markdown", "id": "ac658db4", "metadata": {}, "source": [ "# Building Random Forests" ] }, { "cell_type": "markdown", "id": "84d00330", "metadata": {}, "source": [ "This page shows how to build a Random Forest with tree elements (nodes and leaves). To illustrate it, we take an example from the [Trading Complexity for Sparsity in Random Forest Explanations](https://ojs.aaai.org/index.php/AAAI/article/view/20484) paper for recognizing Cattleya orchids.\n", "\n", "\"RFbase\"\n", "\n", "The Random Forest (composed of three trees) represented by this figure separates Cattleya orchids from other\n", "orchids using the following features: \n", "\n", "* $x_1$: has fragrant flowers.\n", "* $x_2$: has one or two leaves. \n", "* $x_3$: has large flowers. \n", "* $x_4$: is sympodial.\n", "\n", "When a leaf is equal to $1$, the instance is a Cattleya orchid, otherwise it is considered as being from another species." ] }, { "cell_type": "markdown", "id": "2c306cc3", "metadata": {}, "source": [ "## Building the Model" ] }, { "cell_type": "markdown", "id": "ad7f325b", "metadata": {}, "source": [ "First, we need to import some modules. Let us recall that the ```builder``` module contains methods to build the Decision Tree while the ```explainer``` module provides methods to explain it. " ] }, { "cell_type": "code", "execution_count": 1, "id": "34940b06", "metadata": {}, "outputs": [], "source": [ "from pyxai import Builder, Explainer" ] }, { "cell_type": "markdown", "id": "2c722f4a", "metadata": {}, "source": [ "Next, we build the tree in a bottom-up way, that is, from the leaves to the root. So we start with $x_1$ node of the first tree $T_1$." ] }, { "cell_type": "code", "execution_count": 2, "id": "1a9322de", "metadata": {}, "outputs": [], "source": [ "nodeT1_1 = Builder.DecisionNode(1, left=0, right=1)\n", "nodeT1_3 = Builder.DecisionNode(3, left=0, right=nodeT1_1)\n", "nodeT1_2 = Builder.DecisionNode(2, left=1, right=nodeT1_3)\n", "nodeT1_4 = Builder.DecisionNode(4, left=0, right=nodeT1_2)\n", "\n", "tree1 = Builder.DecisionTree(4, nodeT1_4, force_features_equal_to_binaries=True)" ] }, { "cell_type": "markdown", "id": "6ef1d851", "metadata": {}, "source": [ "{: .attention }\n", "> In this example, as each feature is binary (i.e. takes for value either 0 or 1), we do not include in the ```Builder.DecisionNode``` class the operator and the threshold parameters. Thus, the values of these parameters are those by default (respectively, OperatorCondition.GE and $0.5$). This gives us conditions for each node of the form \"$x_i \\ge 0.5$ ?\". In the [Boosted Tree](/documentation/learning/builder/BTbuilder/) page, more complex conditions are created. " ] }, { "cell_type": "markdown", "id": "8d379ace", "metadata": {}, "source": [ "Next, we build the tree $T_2$:" ] }, { "cell_type": "code", "execution_count": 3, "id": "06ca47db", "metadata": {}, "outputs": [], "source": [ "nodeT2_4 = Builder.DecisionNode(4, left=0, right=1)\n", "nodeT2_1 = Builder.DecisionNode(1, left=0, right=nodeT2_4)\n", "nodeT2_2 = Builder.DecisionNode(2, left=nodeT2_1, right=1)\n", "\n", "tree2 = Builder.DecisionTree(4, nodeT2_2, force_features_equal_to_binaries=True) #4 features but only 3 used" ] }, { "cell_type": "markdown", "id": "581df557", "metadata": {}, "source": [ "{: .attention }\n", "> The first parameter (n_features) of the ```Builder.DecitionTree``` method is set to 4 in this tree, even if there are only 3 features used. 
, { "cell_type": "markdown", "id": "8d379ace", "metadata": {}, "source": [ "Next, we build the tree $T_2$:" ] },
{ "cell_type": "code", "execution_count": 3, "id": "06ca47db", "metadata": {}, "outputs": [], "source": [ "nodeT2_4 = Builder.DecisionNode(4, left=0, right=1)\n", "nodeT2_1 = Builder.DecisionNode(1, left=0, right=nodeT2_4)\n", "nodeT2_2 = Builder.DecisionNode(2, left=nodeT2_1, right=1)\n", "\n", "tree2 = Builder.DecisionTree(4, nodeT2_2, force_features_equal_to_binaries=True) # 4 features, but only 3 of them are used" ] },
{ "cell_type": "markdown", "id": "581df557", "metadata": {}, "source": [ "{: .attention }\n", "> The first parameter (n_features) of ```Builder.DecisionTree``` is set to 4 for this tree, even though only 3 features are used. Indeed, it must be the total number of features used by all trees of the model (not only those appearing in this particular Decision Tree)." ] },
{ "cell_type": "markdown", "id": "5b7c96f7", "metadata": {}, "source": [ "And the tree $T_3$:" ] },
{ "cell_type": "code", "execution_count": 4, "id": "568d3608", "metadata": {}, "outputs": [], "source": [ "nodeT3_1_1 = Builder.DecisionNode(1, left=0, right=1)\n", "nodeT3_1_2 = Builder.DecisionNode(1, left=0, right=1)\n", "nodeT3_4_1 = Builder.DecisionNode(4, left=0, right=nodeT3_1_1)\n", "nodeT3_4_2 = Builder.DecisionNode(4, left=0, right=1)\n", "\n", "nodeT3_2_1 = Builder.DecisionNode(2, left=nodeT3_1_2, right=nodeT3_4_1)\n", "nodeT3_2_2 = Builder.DecisionNode(2, left=0, right=nodeT3_4_2)\n", "\n", "nodeT3_3_1 = Builder.DecisionNode(3, left=nodeT3_2_1, right=nodeT3_2_2)\n", "\n", "tree3 = Builder.DecisionTree(4, nodeT3_3_1, force_features_equal_to_binaries=True)" ] },
{ "cell_type": "markdown", "id": "b7905502", "metadata": {}, "source": [ "{: .warning }\n", "> The parameter ```force_features_equal_to_binaries``` allows one to make the binary variables equal to the feature identifiers. These binary variables are used to represent explanations: their values and signs indicate whether the conditions of the nodes are satisfied or not. By default, the binary variables receive arbitrary identifiers that depend on the order in which the tree is traversed. Setting ```force_features_equal_to_binaries``` to ```True``` makes each binary variable coincide with its feature identifier, so explanations directly match the features, without having to use the ```to_features``` method. However, this option cannot be used with every model: it is not compatible with trees in which nodes put different conditions on the same feature, since it assumes that each feature corresponds to exactly one condition." ] },
{ "cell_type": "markdown", "id": "a0dd85c9", "metadata": {}, "source": [ "We can now define the Random Forest:" ] },
{ "cell_type": "code", "execution_count": 5, "id": "694ba536", "metadata": {}, "outputs": [], "source": [ "forest = Builder.RandomForest([tree1, tree2, tree3], n_classes=2)" ] },
{ "cell_type": "markdown", "id": "d1b67838", "metadata": {}, "source": [ "More details about the ```DecisionNode``` and ```RandomForest``` classes are given in the [Building Models](/documentation/learning/builder/) page." ] }
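, { "cell_type": "markdown", "id": "d5a9e1b2", "metadata": {}, "source": [ "Before turning to explanations, here is a small sanity check written in plain Python (it does not use the PyXAI API, and the helper names ```vote_t1```, ```vote_t2```, ```vote_t3``` and ```forest_vote``` are introduced only for this sketch): each function mirrors one of the trees built above, and the prediction of the forest is the majority vote of the three trees." ] },
{ "cell_type": "code", "execution_count": null, "id": "e6b0f3c4", "metadata": {}, "outputs": [], "source": [ "# Plain-Python sanity check (independent of PyXAI): each function mirrors one of the trees above.\n", "# An instance is given by the 0/1 values (x1, x2, x3, x4); every node tests \"x_i >= 0.5 ?\".\n", "\n", "def vote_t1(x1, x2, x3, x4):\n", "    if not x4: return 0\n", "    if not x2: return 1\n", "    if not x3: return 0\n", "    return 1 if x1 else 0\n", "\n", "def vote_t2(x1, x2, x3, x4):\n", "    if x2: return 1\n", "    if not x1: return 0\n", "    return 1 if x4 else 0\n", "\n", "def vote_t3(x1, x2, x3, x4):\n", "    if x3:\n", "        if not x2: return 0\n", "        return 1 if x4 else 0\n", "    if x2:\n", "        return 1 if (x4 and x1) else 0\n", "    return 1 if x1 else 0\n", "\n", "def forest_vote(x1, x2, x3, x4):\n", "    votes = vote_t1(x1, x2, x3, x4) + vote_t2(x1, x2, x3, x4) + vote_t3(x1, x2, x3, x4)\n", "    return 1 if votes >= 2 else 0\n", "\n", "# (1,1,1,1) is classified as a Cattleya orchid (class 1) by the majority vote.\n", "assert forest_vote(1, 1, 1, 1) == 1" ] }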
, { "cell_type": "markdown", "id": "4578d34c", "metadata": {}, "source": [ "## Explaining the Model" ] },
{ "cell_type": "markdown", "id": "baceda37", "metadata": {}, "source": [ "Let us now compute explanations, starting with the instance ```(1,1,1,1)```:" ] },
{ "cell_type": "code", "execution_count": 6, "id": "474dd97e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "For instance = (1,1,1,1):\n", "\n", "target_prediction: 1\n", "direct: (1, 2, 3, 4)\n", "sufficient: (1, 4)\n", "minimal: (1, 4)\n", "majoritary: (1, 2, 4)\n", "minimal_contrastive: ((4,),)\n", "minimals: ((2, 3, 4), (1, 3, 4), (1, 2, 4))\n" ] } ], "source": [ "print(\"For instance = (1,1,1,1):\")\n", "print(\"\")\n", "instance = (1,1,1,1)\n", "explainer = Explainer.initialize(forest, instance=instance)\n", "print(\"target_prediction:\", explainer.target_prediction)\n", "\n", "direct = explainer.direct_reason()\n", "print(\"direct:\", direct)\n", "assert direct == (1, 2, 3, 4), \"The direct reason is not the expected one!\"\n", "\n", "sufficient = explainer.sufficient_reason()\n", "print(\"sufficient:\", sufficient)\n", "assert sufficient == (1, 4), \"The sufficient reason is not the expected one!\"\n", "\n", "minimal = explainer.minimal_sufficient_reason()\n", "print(\"minimal:\", minimal)\n", "assert minimal == (1, 4), \"The minimal sufficient reason is not the expected one!\"\n", "\n", "majoritary = explainer.majoritary_reason()\n", "print(\"majoritary:\", majoritary)\n", "\n", "minimal_contrastives = explainer.minimal_contrastive_reason(n=Explainer.ALL)\n", "print(\"minimal_contrastive: \", minimal_contrastives)\n", "\n", "minimals = explainer.preferred_majoritary_reason(method=Explainer.MINIMAL, n=10)\n", "print(\"minimals:\", minimals)\n", "\n", "for c in minimal_contrastives:\n", "    assert explainer.is_contrastive_reason(c), \"...\"" ] },
{ "cell_type": "markdown", "id": "61689131", "metadata": {}, "source": [ "And now with the instance ```(0,1,0,0)```:" ] },
{ "cell_type": "code", "execution_count": 7, "id": "ef2ababe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "For instance = (0,1,0,0):\n", "\n", "target_prediction: 0\n", "direct: (2, -3, -4)\n", "sufficient: (-1, -3)\n", "minimal: (-4,)\n", "majoritary: ((2, -4), (-1, -4), (-1, 2, -3))\n", "minimals: ((2, -4), (-1, -4))\n", "minimal_contrastive: ((-3, -4), (-1, -4))\n" ] } ], "source": [ "print(\"\\nFor instance = (0,1,0,0):\")\n", "print(\"\")\n", "instance = (0,1,0,0)\n", "explainer.set_instance(instance=instance)\n", "print(\"target_prediction:\", explainer.target_prediction)\n", "\n", "direct = explainer.direct_reason()\n", "print(\"direct:\", direct)\n", "assert direct == (2, -3, -4), \"The direct reason is not the expected one!\"\n", "\n", "sufficient = explainer.sufficient_reason()\n", "print(\"sufficient:\", sufficient)\n", "assert sufficient == (-1, -3), \"The sufficient reason is not the expected one!\"\n", "\n", "minimal = explainer.minimal_sufficient_reason()\n", "print(\"minimal:\", minimal)\n", "assert minimal == (-4, ), \"The minimal sufficient reason is not the expected one!\"\n", "\n", "majoritary = explainer.majoritary_reason(n=Explainer.ALL)\n", "print(\"majoritary:\", majoritary)\n", "\n", "minimals = explainer.preferred_majoritary_reason(method=Explainer.MINIMAL, n=10)\n", "print(\"minimals:\", minimals)\n", "\n", "minimal_contrastives = explainer.minimal_contrastive_reason(n=Explainer.ALL)\n", "print(\"minimal_contrastive: \", minimal_contrastives)\n", "\n", "for c in minimal_contrastives:\n", "    assert explainer.is_contrastive_reason(c), \"...\"\n" ] }
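, { "cell_type": "markdown", "id": "f7c1a4d5", "metadata": {}, "source": [ "To illustrate what these explanations mean, we can reuse the plain-Python ```forest_vote``` helper defined above (again, a purely illustrative check that does not rely on the PyXAI API): the sufficient reason ```(1, 4)``` obtained for ```(1,1,1,1)``` keeps the prediction at $1$ whatever the values of $x_2$ and $x_3$, and the minimal contrastive reason ```(4,)``` means that flipping $x_4$ alone is enough to change that prediction." ] },
{ "cell_type": "code", "execution_count": null, "id": "a8d2b5e6", "metadata": {}, "outputs": [], "source": [ "# The hand-written forest agrees with the target_prediction reported above for (0,1,0,0).\n", "assert forest_vote(0, 1, 0, 0) == 0\n", "\n", "# Sufficient reason (1, 4) for (1,1,1,1): with x1 = 1 and x4 = 1 fixed, the majority vote\n", "# remains 1 for every possible completion of x2 and x3.\n", "assert all(forest_vote(1, x2, x3, 1) == 1 for x2 in (0, 1) for x3 in (0, 1))\n", "\n", "# Minimal contrastive reason (4,) for (1,1,1,1): flipping x4 from 1 to 0 changes the prediction.\n", "assert forest_vote(1, 1, 1, 0) == 0" ] }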
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }