Commit 332898e

add sample_size visualization and update notebooks (#667)
* add sample_size vis and update notebooks
* fix mypy error
* change Group Size to sample_size
1 parent fa9f19c commit 332898e

File tree

16 files changed (+540, -253 lines)

benchmarks/mimiciv/discharge_prediction.ipynb (+2, -2)

@@ -1182,9 +1182,9 @@
 "# Reformatting the fairness metrics\n",
 "fairness_results = copy.deepcopy(results[\"fairness\"])\n",
 "fairness_metrics = {}\n",
-"# remove the group size from the fairness results and add it to the slice name\n",
+"# remove the sample_size from the fairness results and add it to the slice name\n",
 "for slice_name, slice_results in fairness_results.items():\n",
-"    group_size = slice_results.pop(\"Group Size\")\n",
+"    group_size = slice_results.pop(\"sample_size\")\n",
 "    fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
 ]
 },

benchmarks/mimiciv/icu_mortality_prediction.ipynb (+2, -2)

@@ -1159,9 +1159,9 @@
 "# Reformatting the fairness metrics\n",
 "fairness_results = copy.deepcopy(results[\"fairness\"])\n",
 "fairness_metrics = {}\n",
-"# remove the group size from the fairness results and add it to the slice name\n",
+"# remove the sample_size from the fairness results and add it to the slice name\n",
 "for slice_name, slice_results in fairness_results.items():\n",
-"    group_size = slice_results.pop(\"Group Size\")\n",
+"    group_size = slice_results.pop(\"sample_size\")\n",
 "    fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
 ]
 },
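Note: both notebooks apply the same rename. As a standalone sketch of what the updated cell does, with made-up slice names and metric values (only the loop itself comes from the notebooks):

import copy

# Illustrative evaluation output in the shape the notebook cell consumes;
# slice names and numbers are invented for this sketch.
results = {
    "fairness": {
        "age:[30 - 50)": {"sample_size": 340, "accuracy": 0.88},
        "age:[50 - 70)": {"sample_size": 215, "accuracy": 0.84},
    }
}

fairness_results = copy.deepcopy(results["fairness"])
fairness_metrics = {}
# remove the sample_size from the fairness results and add it to the slice name
for slice_name, slice_results in fairness_results.items():
    group_size = slice_results.pop("sample_size")  # key was "Group Size" before this commit
    fairness_metrics[f"{slice_name} (Size={group_size})"] = slice_results

print(fairness_metrics)
# {'age:[30 - 50) (Size=340)': {'accuracy': 0.88},
#  'age:[50 - 70) (Size=215)': {'accuracy': 0.84}}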

cyclops/evaluate/evaluator.py (+1)

@@ -311,6 +311,7 @@ def _compute_metrics(
 model_name: str = "model_for_%s" % prediction_column
 results.setdefault(model_name, {})
 results[model_name][slice_name] = metric_output
+results[model_name][slice_name]["sample_size"] = len(sliced_dataset)

 set_decode(dataset, True)  # restore decoding features
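With this one-line addition, every per-slice entry in the evaluator's results also records how many rows the slice contained. A rough sketch of the resulting structure, assuming a prediction column named preds_prob; the metric names and numbers are invented:

# Hypothetical shape of `results` after _compute_metrics runs (values invented).
results = {
    "model_for_preds_prob": {
        "overall": {"accuracy": 0.91, "sample_size": 1200},
        "age:[30 - 50)": {"accuracy": 0.88, "sample_size": 340},
    }
}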

cyclops/evaluate/fairness/evaluator.py (+2, -2)

@@ -260,7 +260,7 @@ def evaluate_fairness( # noqa: PLR0912
 for prediction_column in fmt_prediction_columns:
     results.setdefault(prediction_column, {})
     results[prediction_column].setdefault(slice_name, {}).update(
-        {"Group Size": len(sliced_dataset)},
+        {"sample_size": len(sliced_dataset)},
     )

 pred_result = _get_metric_results_for_prediction_and_slice(

@@ -966,7 +966,7 @@ def _compute_parity_metrics(
 parity_results[key] = {}
 for slice_name, slice_result in prediction_result.items():
     for metric_name, metric_value in slice_result.items():
-        if metric_name == "Group Size":
+        if metric_name == "sample_size":
             continue

         # add 'Parity' to the metric name before @threshold, if specified
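The parity computation skips the sample_size key because it is slice metadata rather than a metric, so no parity ratio should be derived from it. A minimal illustration of that filtering with an invented per-slice result:

# One real metric plus the sample_size metadata for a single slice (values invented).
slice_result = {"accuracy": 0.88, "sample_size": 340}

parity_candidates = {
    metric_name: metric_value
    for metric_name, metric_value in slice_result.items()
    if metric_name != "sample_size"  # metadata, not a metric
}
print(parity_candidates)  # {'accuracy': 0.88}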

cyclops/report/model_card/fields.py (+10)

@@ -380,6 +380,11 @@ class PerformanceMetric(
         default_factory=list,
     )

+    sample_size: Optional[StrictInt] = Field(
+        None,
+        description="The sample size used to compute this metric.",
+    )
+

 class User(
     BaseModelCardField,

@@ -599,6 +604,11 @@ class MetricCard(
         description="Timestamps for each point in the history.",
     )

+    sample_sizes: Optional[List[int]] = Field(
+        None,
+        description="Sample sizes for each point in the history.",
+    )
+

 class MetricCardCollection(BaseModelCardField, composable_with="Overview"):
     """A collection of metric cards to be displayed in the model card."""

cyclops/report/report.py (+62, -30)

@@ -48,6 +48,7 @@
     get_histories,
     get_names,
     get_passed,
+    get_sample_sizes,
     get_slices,
     get_thresholds,
     get_timestamps,
@@ -855,6 +856,7 @@ def log_quantitative_analysis(
         pass_fail_threshold_fns: Optional[
             Union[Callable[[Any, float], bool], List[Callable[[Any, float], bool]]]
         ] = None,
+        sample_size: Optional[int] = None,
         **extra: Any,
     ) -> None:
         """Add a quantitative analysis to the report.

@@ -921,6 +923,7 @@ def log_quantitative_analysis(
             "slice": metric_slice,
             "decision_threshold": decision_threshold,
             "description": description,
+            "sample_size": sample_size,
             **extra,
         }
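Callers can now attach the slice's sample size when logging a single analysis. A minimal usage sketch, assuming report is an already-constructed cyclops report object (not shown here); the metric name, slice, and numbers are illustrative:

# Mirrors the call made by log_performance_metrics further down in this diff.
report.log_quantitative_analysis(
    "performance",
    name="accuracy",
    value=0.88,
    metric_slice="age:[30 - 50)",
    description="Accuracy on the 30-50 age group.",
    pass_fail_thresholds=0.7,
    pass_fail_threshold_fns=lambda value, threshold: bool(value >= threshold),
    sample_size=340,  # new keyword added in this commit
)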

@@ -958,42 +961,70 @@ def log_quantitative_analysis(
             field_type=field_type,
         )

-    def log_performance_metrics(self, metrics: Dict[str, Any]) -> None:
-        """Add a performance metric to the `Quantitative Analysis` section.
+    def log_performance_metrics(
+        self,
+        results: Dict[str, Any],
+        metric_descriptions: Dict[str, str],
+        pass_fail_thresholds: Union[float, Dict[str, float]] = 0.7,
+        pass_fail_threshold_fn: Callable[[float, float], bool] = lambda x,
+        threshold: bool(x >= threshold),
+    ) -> None:
+        """
+        Log all performance metrics to the model card report.

         Parameters
         ----------
-        metrics : Dict[str, Any]
-            A dictionary of performance metrics. The keys should be the name of the
-            metric, and the values should be the value of the metric. If the metric
-            is a slice metric, the key should be the slice name followed by a slash
-            and then the metric name (e.g. "slice_name/metric_name"). If no slice
-            name is provided, the slice name will be "overall".
-
-        Raises
-        ------
-        TypeError
-            If the given metrics are not a dictionary with string keys.
+        results : Dict[str, Any]
+            Dictionary containing the results,
+            with keys in the format "split/metric_name".
+        metric_descriptions : Dict[str, str]
+            Dictionary mapping metric names to their descriptions.
+        pass_fail_thresholds : Union[float, Dict[str, float]], optional
+            The threshold(s) for pass/fail tests.
+            Can be a single float applied to all metrics,
+            or a dictionary mapping "split/metric_name" to individual thresholds.
+            Default is 0.7.
+        pass_fail_threshold_fn : Callable[[float, float], bool], optional
+            Function to determine if a metric passes or fails.
+            Default is lambda x, threshold: bool(x >= threshold).

+        Returns
+        -------
+        None
         """
-        _raise_if_not_dict_with_str_keys(metrics)
-        for metric_name, metric_value in metrics.items():
-            name_split = metric_name.split("/")
-            if len(name_split) == 1:
-                slice_name = "overall"
-                metric_name = name_split[0]  # noqa: PLW2901
-            else:  # everything before the last slash is the slice name
-                slice_name = "/".join(name_split[:-1])
-                metric_name = name_split[-1]  # noqa: PLW2901
-
-            # TODO: create plot
+        # Extract sample sizes
+        sample_sizes = {
+            key.split("/")[0]: value
+            for key, value in results.items()
+            if "sample_size" in key.split("/")[1]
+        }

-            self._log_field(
-                data={"type": metric_name, "value": metric_value, "slice": slice_name},
-                section_name="quantitative_analysis",
-                field_name="performance_metrics",
-                field_type=PerformanceMetric,
-            )
+        # Log metrics
+        for name, metric in results.items():
+            split, metric_name = name.split("/")
+            if metric_name != "sample_size":
+                metric_value = metric.tolist() if hasattr(metric, "tolist") else metric
+
+                # Determine the threshold for this specific metric
+                if isinstance(pass_fail_thresholds, dict):
+                    threshold = pass_fail_thresholds.get(
+                        name, 0.7
+                    )  # Default to 0.7 if not specified
+                else:
+                    threshold = pass_fail_thresholds
+
+                self.log_quantitative_analysis(
+                    "performance",
+                    name=metric_name,
+                    value=metric_value,
+                    description=metric_descriptions.get(
+                        metric_name, "No description provided."
+                    ),
+                    metric_slice=split,
+                    pass_fail_thresholds=threshold,
+                    pass_fail_threshold_fns=pass_fail_threshold_fn,
+                    sample_size=sample_sizes.get(split),
+                )

         # TODO: MERGE/COMPARE MODEL CARDS
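A hedged usage sketch of the reworked log_performance_metrics, assuming report is an already-constructed cyclops report object (not shown here) and that results uses the "split/metric_name" key convention from the docstring; all names and numbers are illustrative:

# "sample_size" entries are pulled out and attached to the logged metrics of the
# matching split instead of being logged as metrics themselves.
results = {
    "overall/accuracy": 0.91,
    "overall/sample_size": 1200,
    "age:[30 - 50)/accuracy": 0.88,
    "age:[30 - 50)/sample_size": 340,
}

report.log_performance_metrics(
    results,
    metric_descriptions={"accuracy": "Fraction of correctly classified samples."},
    pass_fail_thresholds={"overall/accuracy": 0.8},  # unspecified keys fall back to 0.7
)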

@@ -1162,6 +1193,7 @@ def export(
     "get_names": get_names,
     "get_histories": get_histories,
     "get_timestamps": get_timestamps,
+    "get_sample_sizes": get_sample_sizes,
 }
 template.globals.update(func_dict)

0 commit comments