1
+ import inspect
1
2
import multiprocessing as mp
3
+ import warnings
2
4
from numpy .random import SeedSequence , default_rng
3
5
4
6
import numpy as np
@@ -18,15 +20,15 @@ def calculate_variable_importance(explainer,
18
20
if processes == 1 :
19
21
result = [None ] * B
20
22
for i in range (B ):
21
- result [i ] = loss_after_permutation (explainer .data , explainer .y , explainer .model , explainer .predict_function ,
23
+ result [i ] = loss_after_permutation (explainer .data , explainer .y , explainer .weights , explainer . model , explainer .predict_function ,
22
24
loss_function , variables , N , np .random )
23
25
else :
24
26
# Create number generator for each iteration
25
27
ss = SeedSequence (random_state )
26
28
generators = [default_rng (s ) for s in ss .spawn (B )]
27
29
pool = mp .get_context ('spawn' ).Pool (processes )
28
30
result = pool .starmap_async (loss_after_permutation , [
29
- (explainer .data , explainer .y , explainer .model , explainer .predict_function , loss_function , variables , N , generators [i ]) for
31
+ (explainer .data , explainer .y , explainer .weights , explainer . model , explainer .predict_function , loss_function , variables , N , generators [i ]) for
30
32
i in range (B )]).get ()
31
33
pool .close ()
32
34
@@ -49,21 +51,24 @@ def calculate_variable_importance(explainer,
49
51
return result , raw_permutations
50
52
51
53
52
- def loss_after_permutation (data , y , model , predict , loss_function , variables , N , rng ):
54
+ def loss_after_permutation (data , y , weights , model , predict , loss_function , variables , N , rng ):
53
55
if isinstance (N , int ):
54
56
N = min (N , data .shape [0 ])
55
57
sampled_rows = rng .choice (np .arange (data .shape [0 ]), N , replace = False )
56
58
sampled_data = data .iloc [sampled_rows , :]
57
59
observed = y [sampled_rows ]
60
+ sample_weights = weights [sampled_rows ] if weights is not None else None
58
61
else :
59
62
sampled_data = data
60
63
observed = y
64
+ sample_weights = weights
61
65
62
66
# loss on the full model or when outcomes are permuted
63
- loss_full = loss_function ( observed , predict (model , sampled_data ))
67
+ loss_full = calculate_loss ( loss_function , observed , predict (model , sampled_data ), sample_weights )
64
68
65
69
sampled_rows2 = rng .choice (range (observed .shape [0 ]), observed .shape [0 ], replace = False )
66
- loss_baseline = loss_function (observed [sampled_rows2 ], predict (model , sampled_data ))
70
+ sample_weights_rows2 = sample_weights [sampled_rows2 ] if sample_weights is not None else None
71
+ loss_baseline = calculate_loss (loss_function , observed [sampled_rows2 ], predict (model , sampled_data ), sample_weights_rows2 )
67
72
68
73
loss_features = {}
69
74
for variables_set_key in variables :
@@ -74,9 +79,24 @@ def loss_after_permutation(data, y, model, predict, loss_function, variables, N,
74
79
75
80
predicted = predict (model , ndf )
76
81
77
- loss_features [variables_set_key ] = loss_function ( observed , predicted )
82
+ loss_features [variables_set_key ] = calculate_loss ( loss_function , observed , predicted , sample_weights )
78
83
79
84
loss_features ['_full_model_' ] = loss_full
80
85
loss_features ['_baseline_' ] = loss_baseline
81
86
82
87
return pd .DataFrame (loss_features , index = [0 ])
88
def calculate_loss(loss_function, observed, predicted, sample_weights=None):
    """Compute a loss, forwarding `sample_weights` only when supported.

    Parameters
    ----------
    loss_function : callable
        Loss of the form ``loss_function(observed, predicted)``; may optionally
        accept a ``sample_weight`` keyword (sklearn-metric convention).
    observed : array-like
        Ground-truth target values.
    predicted : array-like
        Model predictions aligned with `observed`.
    sample_weights : array-like or None
        Per-observation weights; ignored (with a warning) when the loss
        function has no ``sample_weight`` parameter.

    Returns
    -------
    The value returned by `loss_function`.
    """
    # Determine if loss function accepts 'sample_weight'. Some C-implemented
    # callables cannot be introspected — inspect.signature raises
    # ValueError/TypeError for them — so treat those as unweighted rather
    # than crashing.
    try:
        loss_args = inspect.signature(loss_function).parameters
        supports_weight = "sample_weight" in loss_args
    except (ValueError, TypeError):
        supports_weight = False

    if supports_weight:
        return loss_function(observed, predicted, sample_weight=sample_weights)

    if sample_weights is not None:
        # Weights were supplied but cannot be honored — make that visible.
        warnings.warn(
            f"Loss function `{loss_function.__name__}` does not have `sample_weight` argument. Calculating unweighted loss."
        )
    return loss_function(observed, predicted)
0 commit comments