@@ -48,6 +48,7 @@
     get_histories,
     get_names,
     get_passed,
+    get_sample_sizes,
     get_slices,
     get_thresholds,
     get_timestamps,
@@ -855,6 +856,7 @@ def log_quantitative_analysis(
         pass_fail_threshold_fns: Optional[
             Union[Callable[[Any, float], bool], List[Callable[[Any, float], bool]]]
         ] = None,
+        sample_size: Optional[int] = None,
         **extra: Any,
     ) -> None:
         """Add a quantitative analysis to the report.
@@ -921,6 +923,7 @@ def log_quantitative_analysis(
             "slice": metric_slice,
             "decision_threshold": decision_threshold,
             "description": description,
+            "sample_size": sample_size,
             **extra,
         }

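For reference, a direct call to `log_quantitative_analysis` that exercises the new `sample_size` argument could look like the sketch below. This is a hypothetical usage example, not part of the change: `report` is a placeholder for the object exposing the method, and the keyword arguments mirror the call made by the new `log_performance_metrics` later in this diff.

# Hypothetical sketch: `report` stands in for the reporting object.
report.log_quantitative_analysis(
    "performance",
    name="accuracy",
    value=0.91,
    description="Fraction of correctly classified samples.",
    metric_slice="test",
    pass_fail_thresholds=0.8,
    pass_fail_threshold_fns=lambda value, threshold: value >= threshold,
    sample_size=2500,  # newly added: number of samples backing the metric
)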
@@ -958,42 +961,70 @@ def log_quantitative_analysis(
             field_type=field_type,
         )

-    def log_performance_metrics(self, metrics: Dict[str, Any]) -> None:
-        """Add a performance metric to the `Quantitative Analysis` section.
+    def log_performance_metrics(
+        self,
+        results: Dict[str, Any],
+        metric_descriptions: Dict[str, str],
+        pass_fail_thresholds: Union[float, Dict[str, float]] = 0.7,
+        pass_fail_threshold_fn: Callable[[float, float], bool] = lambda x,
+        threshold: bool(x >= threshold),
+    ) -> None:
+        """
+        Log all performance metrics to the model card report.

         Parameters
         ----------
-        metrics : Dict[str, Any]
-            A dictionary of performance metrics. The keys should be the name of the
-            metric, and the values should be the value of the metric. If the metric
-            is a slice metric, the key should be the slice name followed by a slash
-            and then the metric name (e.g. "slice_name/metric_name"). If no slice
-            name is provided, the slice name will be "overall".
-
-        Raises
-        ------
-        TypeError
-            If the given metrics are not a dictionary with string keys.
+        results : Dict[str, Any]
+            Dictionary containing the results,
+            with keys in the format "split/metric_name".
+        metric_descriptions : Dict[str, str]
+            Dictionary mapping metric names to their descriptions.
+        pass_fail_thresholds : Union[float, Dict[str, float]], optional
+            The threshold(s) for pass/fail tests.
+            Can be a single float applied to all metrics,
+            or a dictionary mapping "split/metric_name" to individual thresholds.
+            Default is 0.7.
+        pass_fail_threshold_fn : Callable[[float, float], bool], optional
+            Function to determine if a metric passes or fails.
+            Default is lambda x, threshold: bool(x >= threshold).

+        Returns
+        -------
+        None
         """
-        _raise_if_not_dict_with_str_keys(metrics)
-        for metric_name, metric_value in metrics.items():
-            name_split = metric_name.split("/")
-            if len(name_split) == 1:
-                slice_name = "overall"
-                metric_name = name_split[0]  # noqa: PLW2901
-            else:  # everything before the last slash is the slice name
-                slice_name = "/".join(name_split[:-1])
-                metric_name = name_split[-1]  # noqa: PLW2901
-
-            # TODO: create plot
+        # Extract sample sizes
+        sample_sizes = {
+            key.split("/")[0]: value
+            for key, value in results.items()
+            if "sample_size" in key.split("/")[1]
+        }

-            self._log_field(
-                data={"type": metric_name, "value": metric_value, "slice": slice_name},
-                section_name="quantitative_analysis",
-                field_name="performance_metrics",
-                field_type=PerformanceMetric,
-            )
+        # Log metrics
+        for name, metric in results.items():
+            split, metric_name = name.split("/")
+            if metric_name != "sample_size":
+                metric_value = metric.tolist() if hasattr(metric, "tolist") else metric
+
+                # Determine the threshold for this specific metric
+                if isinstance(pass_fail_thresholds, dict):
+                    threshold = pass_fail_thresholds.get(
+                        name, 0.7
+                    )  # Default to 0.7 if not specified
+                else:
+                    threshold = pass_fail_thresholds
+
+                self.log_quantitative_analysis(
+                    "performance",
+                    name=metric_name,
+                    value=metric_value,
+                    description=metric_descriptions.get(
+                        metric_name, "No description provided."
+                    ),
+                    metric_slice=split,
+                    pass_fail_thresholds=threshold,
+                    pass_fail_threshold_fns=pass_fail_threshold_fn,
+                    sample_size=sample_sizes.get(split),
+                )

     # TODO: MERGE/COMPARE MODEL CARDS

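A hedged usage sketch of the reworked `log_performance_metrics` follows; the result keys and values are made up for illustration, and `report` is again a placeholder for the reporting object. Note that, as implemented above, every key in `results` must contain exactly one slash ("split/metric_name"), and per-split sample sizes are picked up from the "<split>/sample_size" entries.

# Hypothetical example input; keys follow the "split/metric_name" convention.
results = {
    "test/accuracy": 0.91,
    "test/f1_score": 0.87,
    "test/sample_size": 2500,  # consumed by the sample-size extraction above
}
metric_descriptions = {
    "accuracy": "Fraction of correctly classified samples.",
    "f1_score": "Harmonic mean of precision and recall.",
}

# Logs accuracy and f1_score for the "test" slice, each with sample_size=2500
# and its own pass/fail threshold; the sample_size entry itself is skipped.
report.log_performance_metrics(
    results=results,
    metric_descriptions=metric_descriptions,
    pass_fail_thresholds={"test/accuracy": 0.8, "test/f1_score": 0.75},
)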
@@ -1162,6 +1193,7 @@ def export(
             "get_names": get_names,
             "get_histories": get_histories,
             "get_timestamps": get_timestamps,
+            "get_sample_sizes": get_sample_sizes,
         }
         template.globals.update(func_dict)

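The export hunk only registers `get_sample_sizes` next to the existing helpers so the report template can call it while rendering. The body of `get_sample_sizes` is not part of this diff; the snippet below is a minimal, hypothetical sketch of the general mechanism (Jinja2 template globals, as `template.globals.update` suggests) with a stand-in helper.

from jinja2 import Template

def get_sample_sizes(metrics):
    # Stand-in helper; the real get_sample_sizes lives elsewhere in the package.
    return [m.get("sample_size") for m in metrics]

template = Template("Sample sizes: {{ get_sample_sizes(metrics) }}")
template.globals.update({"get_sample_sizes": get_sample_sizes})
print(template.render(metrics=[{"sample_size": 2500}]))  # Sample sizes: [2500]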