PyXAI

Saving/Loading Models

The PyXAI library provides functions to save and load models and related hyper-parameters, as well as preselected instances. PyXAI can save several models from an experimental protocol in a directory chosen by the user (named <save_directory> in this example). Each model is associated with an index <i> and two files:

  • <save_directory>/<dataset>.<i>.map: JSON file containing training and test indexes, accuracy, solver name, etc.
  • <save_directory>/<dataset>.<i>.pkl: Raw model saved as a Pickle file.

You can also save preselected instances, which requires an additional file:

  • <save_directory>/<dataset>.<i>.instances (optional): JSON file containing the indexes of preselected instances.
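The sketch below uses only the standard library to make this naming convention concrete. It is not PyXAI's code. The helper save_model_files and the metadata keys are hypothetical, and the real content of the .map file is defined by PyXAI.

```python
import json
import os
import pickle
import tempfile


def save_model_files(save_directory, dataset, i, model, info, instance_indexes=None):
    """Hypothetical helper mimicking the <dataset>.<i>.* file layout.

    Not part of PyXAI: it only illustrates which files exist per model.
    """
    os.makedirs(save_directory, exist_ok=True)
    prefix = os.path.join(save_directory, f"{dataset}.{i}")
    # <dataset>.<i>.map: JSON metadata (placeholder keys, not PyXAI's schema)
    with open(prefix + ".map", "w") as f:
        json.dump(info, f)
    # <dataset>.<i>.pkl: the raw model, pickled
    with open(prefix + ".pkl", "wb") as f:
        pickle.dump(model, f)
    # <dataset>.<i>.instances (optional): JSON list of instance indexes
    if instance_indexes is not None:
        with open(prefix + ".instances", "w") as f:
            json.dump(instance_indexes, f)
    return prefix


with tempfile.TemporaryDirectory() as save_directory:
    save_model_files(save_directory, "compas", 0, {"n_trees": 100},
                     {"accuracy": 66.43}, instance_indexes=[3, 17])
    print(sorted(os.listdir(save_directory)))
# ['compas.0.instances', 'compas.0.map', 'compas.0.pkl']
```

In this sketch, file names depend only on the dataset name and the index. Saving again with the same index therefore replaces the previous files, because opening a file in "w" mode truncates it.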

Saving Models

As an illustration, we use the compas dataset. We start by training two Random Forests using a leave-one-group-out cross-validation protocol and selecting one instance:

from pyxai import Learning, Explainer, Tools

learner = Learning.Scikitlearn("../dataset/compas.csv", problem_type=Learning.CLASSIFICATION)
models = learner.evaluate(splitting_method=Learning.LEAVE_ONE_GROUP_OUT, model_type=Learning.RF, splitting_parameters={'n_models': 2})
instance, prediction = learner.get_instances(n=1)
data:
      Number_of_Priors  score_factor  Age_Above_FourtyFive   
0                    0             0                     1  \
1                    0             0                     0   
2                    4             0                     0   
3                    0             0                     0   
4                   14             1                     0   
...                ...           ...                   ...   
6167                 0             1                     0   
6168                 0             0                     0   
6169                 0             0                     1   
6170                 3             0                     0   
6171                 2             0                     0   

      Age_Below_TwentyFive  African_American  Asian  Hispanic   
0                        0                 0      0         0  \
1                        0                 1      0         0   
2                        1                 1      0         0   
3                        0                 0      0         0   
4                        0                 0      0         0   
...                    ...               ...    ...       ...   
6167                     1                 1      0         0   
6168                     1                 1      0         0   
6169                     0                 0      0         0   
6170                     0                 1      0         0   
6171                     1                 0      0         1   

      Native_American  Other  Female  Misdemeanor  Two_yr_Recidivism  
0                   0      1       0            0                  0  
1                   0      0       0            0                  1  
2                   0      0       0            0                  1  
3                   0      1       0            1                  0  
4                   0      0       0            0                  1  
...               ...    ...     ...          ...                ...  
6167                0      0       0            0                  0  
6168                0      0       0            0                  0  
6169                0      1       0            0                  0  
6170                0      0       1            1                  0  
6171                0      0       1            0                  1  

[6172 rows x 12 columns]
--------------   Information   ---------------
Dataset name: ../dataset/compas.csv
nFeatures (nAttributes, with the labels): 12
nInstances (nObservations): 6172
nLabels: 2
---------------   Evaluation   ---------------
method: LeaveOneGroupOut
output: RF
learner_type: Classification
learner_options: {'max_depth': None, 'random_state': 0}
---------   Evaluation Information   ---------
For the evaluation number 0:
metrics:
   accuracy: 66.42903434867142
nTraining instances: 3086
nTest instances: 3086

For the evaluation number 1:
metrics:
   accuracy: 64.45236552171096
nTraining instances: 3086
nTest instances: 3086

---------------   Explainer   ----------------
For the evaluation number 0:
**Random Forest Model**
nClasses: 2
nTrees: 100
nVariables: 63

For the evaluation number 1:
**Random Forest Model**
nClasses: 2
nTrees: 100
nVariables: 69

---------------   Instances   ----------------
number of instances selected: 1
----------------------------------------------
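The evaluation log reports 3086 training and 3086 test instances for each of the two models. A leave-one-group-out protocol with two groups gives exactly these counts on 6172 instances. The standard-library sketch below reproduces them. This page does not state how PyXAI forms the groups, so the shuffling step and the helper leave_one_group_out are assumptions.

```python
import random


def leave_one_group_out(n_instances, n_models, seed=0):
    """Hypothetical: yield one (training_indexes, test_indexes) pair per model.

    Assumption: instances are shuffled, then dealt into n_models groups;
    model k is tested on group k and trained on all the other groups.
    """
    indexes = list(range(n_instances))
    random.Random(seed).shuffle(indexes)
    groups = [indexes[k::n_models] for k in range(n_models)]
    for k in range(n_models):
        test = sorted(groups[k])
        training = sorted(i for j, group in enumerate(groups) if j != k for i in group)
        yield training, test


for k, (training, test) in enumerate(leave_one_group_out(6172, 2)):
    print(k, len(training), len(test))
# 0 3086 3086
# 1 3086 3086
```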

The save method of Learning.ModelIO saves the models into the given directory (here, try_save):

Learning.ModelIO.save(models, "try_save")
Model saved: (try_save/compas.0.model, try_save/compas.0.map)
Model saved: (try_save/compas.1.model, try_save/compas.1.map)

If models based on the same dataset already exist in this folder, the method overwrites them.

Loading Models

After saving the models, you can reload them in another program using load:

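from pyxai import Learning, Explainer

# Reload the learner and the models saved in try_save
learner, models = Learning.ModelIO.load("try_save")

# `instance` is the instance selected with get_instances in the previous section
for model in models:
    explainer = Explainer.initialize(model, instance)
    print("sufficient_reason:", explainer.sufficient_reason())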
----------   Loading Information   -----------
mapping file: try_save/compas.0.map
nFeatures (nAttributes, with the labels): 12
nInstances (nObservations): 6172
nLabels: 2
----------   Loading Information   -----------
mapping file: try_save/compas.1.map
nFeatures (nAttributes, with the labels): 12
nInstances (nObservations): 6172
nLabels: 2
---------   Evaluation Information   ---------
For the evaluation number 0:
metrics: {'accuracy': 66.42903434867142}
nTraining instances: 3086
nTest instances: 3086

For the evaluation number 1:
metrics: {'accuracy': 64.45236552171096}
nTraining instances: 3086
nTest instances: 3086

---------------   Explainer   ----------------
For the evaluation number 0:
**Random Forest Model**
nClasses: 2
nTrees: 100
nVariables: 63

For the evaluation number 1:
**Random Forest Model**
nClasses: 2
nTrees: 100
nVariables: 69

sufficient_reason: (-1, -2, -3, -4, 5, -6, -9, -11, -13)
sufficient_reason: (-1, -2, -3, -4, -6, 8, -13)
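Conceptually, loading pairs each <dataset>.<i>.map file with the model file that has the same index. The following sketch is hypothetical: load_model_files is not a PyXAI function, and it assumes the .pkl extension from the file list at the top of this page. It sorts by the numeric index, so compas.10 comes after compas.2.

```python
import glob
import json
import os
import pickle
import tempfile


def load_model_files(save_directory):
    """Hypothetical loader: return [(index, info, model), ...] sorted by index."""
    loaded = []
    for map_path in glob.glob(os.path.join(save_directory, "*.map")):
        prefix = map_path[: -len(".map")]
        index = int(prefix.rpartition(".")[2])  # "<dataset>.<i>" -> i
        with open(map_path) as f:
            info = json.load(f)
        with open(prefix + ".pkl", "rb") as f:
            model = pickle.load(f)
        loaded.append((index, info, model))
    return sorted(loaded, key=lambda item: item[0])


with tempfile.TemporaryDirectory() as save_directory:
    # Write two placeholder models, deliberately out of order
    for i, accuracy in [(1, 64.45), (0, 66.43)]:
        prefix = os.path.join(save_directory, f"compas.{i}")
        with open(prefix + ".map", "w") as f:
            json.dump({"accuracy": accuracy}, f)
        with open(prefix + ".pkl", "wb") as f:
            pickle.dump({"model_index": i}, f)
    for index, info, model in load_model_files(save_directory):
        print(index, info["accuracy"], model)
# 0 66.43 {'model_index': 0}
# 1 64.45 {'model_index': 1}
```

Only unpickle files you trust, because loading a Pickle file can execute arbitrary code.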

Saving/Loading Instances

PyXAI also allows saving and loading instances through the get_instances method. To save instances (more precisely, their indexes), use the save_directory and instances_id parameters. To reload them, pass the directory to the indexes parameter, together with instances_id.

In this example, for each of the two models, the indexes of 10 test-set instances are saved into the try_save directory:

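for model_id, model in enumerate(models):
    # Select 10 test instances for this model and save their indexes
    # in try_save/compas.<model_id>.instances
    instances = learner.get_instances(
        dataset="../dataset/compas.csv",
        indexes=Learning.TEST,
        n=10,
        model=model,
        save_directory="try_save",
        instances_id=model_id)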
---------------   Instances   ----------------
Indexes of selected instances saved in: try_save/compas.0.instances
number of instances selected: 10
----------------------------------------------
---------------   Instances   ----------------
Indexes of selected instances saved in: try_save/compas.1.instances
number of instances selected: 10
----------------------------------------------

If the dataset has not been loaded yet, get_instances does not load it entirely. It reads only the rows at the required indexes.
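One way to read only the needed rows is to stream the CSV file and stop once every requested index has been seen. This sketch uses the standard csv module and is not PyXAI's implementation. read_rows is a hypothetical helper, indexes count data rows from 0 (header excluded), and the sample data reuses the first rows of the compas table shown above.

```python
import csv
import io


def read_rows(csv_file, wanted_indexes):
    """Hypothetical: return {index: row_dict} for the wanted data rows only."""
    wanted = set(wanted_indexes)
    rows = {}
    reader = csv.reader(csv_file)
    header = next(reader)
    for index, row in enumerate(reader):
        if index in wanted:
            rows[index] = dict(zip(header, row))
            if len(rows) == len(wanted):
                break  # every requested row has been read: skip the rest of the file
    return rows


sample = io.StringIO(
    "Number_of_Priors,score_factor\n"
    "0,0\n0,0\n4,0\n0,0\n14,1\n"
)
print(read_rows(sample, [2, 4]))
# {2: {'Number_of_Priors': '4', 'score_factor': '0'}, 4: {'Number_of_Priors': '14', 'score_factor': '1'}}
```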

Later, in another program, you can load the same instances using these instructions:

for model_id, model in enumerate(models):
    # Reload the indexes saved in try_save/compas.<model_id>.instances
    instances = learner.get_instances(
        dataset="../dataset/compas.csv",
        indexes="try_save",
        model=model,
        instances_id=model_id)
---------------   Instances   ----------------
Loading instances file: try_save/compas.0.instances
number of instances selected: 10
----------------------------------------------
---------------   Instances   ----------------
Loading instances file: try_save/compas.1.instances
number of instances selected: 10
----------------------------------------------
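An .instances file only needs to hold a list of indexes. The following is a minimal sketch of that round trip. It assumes that the indexes are stored as a JSON list and that the n instances are drawn at random from the test indexes, although PyXAI's selection rule may differ. save_instance_indexes and load_instance_indexes are hypothetical helpers.

```python
import json
import os
import random
import tempfile


def instances_path(save_directory, dataset, instances_id):
    # <save_directory>/<dataset>.<instances_id>.instances
    return os.path.join(save_directory, f"{dataset}.{instances_id}.instances")


def save_instance_indexes(save_directory, dataset, instances_id, test_indexes, n, seed=0):
    """Hypothetical: draw n test indexes at random and store them as a JSON list."""
    chosen = sorted(random.Random(seed).sample(test_indexes, n))
    with open(instances_path(save_directory, dataset, instances_id), "w") as f:
        json.dump(chosen, f)
    return chosen


def load_instance_indexes(save_directory, dataset, instances_id):
    """Hypothetical: read back the list written by save_instance_indexes."""
    with open(instances_path(save_directory, dataset, instances_id)) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as save_directory:
    test_indexes = list(range(3086, 6172))  # placeholder test split
    saved = save_instance_indexes(save_directory, "compas", 0, test_indexes, n=10)
    reloaded = load_instance_indexes(save_directory, "compas", 0)
    print(len(reloaded), reloaded == saved)
# 10 True
```

Storing indexes rather than the instances themselves keeps the file small. The rows can then be read back from the original dataset on demand.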

More information about the get_instances method is available on the Generating Models page.