[{"data":1,"prerenderedAt":268},["ShallowReactive",2],{"article-13":3},{"id":4,"title":5,"body":6,"create":256,"description":12,"extension":257,"labels":258,"locked":260,"meta":261,"navigation":262,"path":263,"seo":264,"stem":265,"update":266,"__hash__":267},"articles/article/13.md","MMPretrain 实现 AUC 评价指标",{"type":7,"value":8,"toc":250},"minimark",[9,13,17,110,113,122,128,137,143,160,177,183,186,192,198,204,214,220,223,241,247],[10,11,12],"p",{},"在使用 MMPretrain 训练模型的时候，发现其提供的指标中并没有 AUC 指标。但自己需要用到该指标，于是按照官方文档的指导，实现并在训练过程中使用了该指标。",[14,15,16],"h2",{"id":16},"官方文档对添加指标的说明",[18,19,20,23,45,56,66,77,83,96,102],"blockquote",{},[10,21,22],{},"MMPretrain 支持为追求更高定制化的用户实现定制化的评估指标。",[10,24,25,26,30,31,34,35,38,39,44],{},"您需要在 ",[27,28,29],"code",{},"mmpretrain/evaluation/metrics"," 下创建一个新文件，并在该文件中实现新的指标，例如，在\n",[27,32,33],{},"mmpretrain/evaluation/metrics/my_metric.py"," 中。并创建一个自定义的评估指标类 ",[27,36,37],{},"MyMetric","\n继承 ",[40,41,43],"a",{"href":42},"https://mmengine.readthedocs.io/en/latest/api/generated/mmengine.evaluator.BaseMetric.html","MMEngine 中的 BaseMetric","。",[10,46,47,48,51,52,55],{},"需要分别覆盖数据格式处理方法",[27,49,50],{},"process","和度量计算方法",[27,53,54],{},"compute_metrics","。 将其添加到“METRICS”注册表以实施任何自定义评估指标。",[57,58,64],"pre",{"className":59,"code":61,"language":62,"meta":63},[60],"language-python","from mmengine.evaluator import BaseMetric\nfrom mmpretrain.registry import METRICS\n\n@METRICS.register_module()\nclass MyMetric(BaseMetric):\n\n   def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]):\n   \"\"\" The processed results should be stored in ``self.results``, which will\n       be used to computed the metrics when all batches have been processed.\n       `data_batch` stores the batch data from dataloader,\n       and `data_samples` stores the batch outputs from model.\n   \"\"\"\n       ...\n\n   def compute_metrics(self, results: List):\n   \"\"\" Compute the metrics from processed results and returns the evaluation results.\n   \"\"\"\n       
...\n","python","",[27,65,61],{"__ignoreMap":63},[10,67,68,69,72,73,76],{},"然后，将其导入 ",[27,70,71],{},"mmpretrain/evaluation/metrics/__init__.py"," 以将其添加到 ",[27,74,75],{},"mmpretrain.evaluation"," 包中。",[57,78,81],{"className":79,"code":80,"language":62,"meta":63},[60],"# In mmpretrain/evaluation/metrics/__init__.py\n...\nfrom .my_metric import MyMetric\n\n__all__ = [..., 'MyMetric']\n",[27,82,80],{"__ignoreMap":63},[10,84,85,86,89,90,93,94,44],{},"最后，在配置文件的 ",[27,87,88],{},"val_evaluator"," 和 ",[27,91,92],{},"test_evaluator"," 字段中使用 ",[27,95,37],{},[57,97,100],{"className":98,"code":99,"language":62,"meta":63},[60],"val_evaluator = dict(type='MyMetric', ...)\ntest_evaluator = val_evaluator\n",[27,101,99],{"__ignoreMap":63},[57,103,108],{"className":104,"code":106,"language":107,"meta":63},[105],"language-text","更多的细节可以参考 {external+mmengine:doc}`MMEngine 文档: Evaluation \u003Cdesign/evaluation>`.\n","text",[27,109,106],{"__ignoreMap":63},[14,111,112],{"id":112},"错误尝试",[10,114,115,116,118,119],{},"按照文档指引，在 ",[27,117,29],{}," 目录下创建新的文件 ",[27,120,121],{},"auc.py",[57,123,126],{"className":124,"code":125,"language":62,"meta":63},[60],"@METRICS.register_module()\nclass SingleLabelAUC(BaseMetric):\n  default_prefix: Optional[str] = 'auc' # 在终端打印时的前缀，如 `compute_metrics` 中的字典是 {'foo': 1, 'bar': 0.9}，前缀是 `bzz` 那么在打印结果的时候会显示 `bzz/foo: 1 bzz/bar: 0.9`\n  def process(self, data_batch, data_samples: Sequence[dict]) -> None:\n    # 这一函数将训练的结果存入 self.results 变量中\n\n  def compute_metrics(self, results: List) -> dict:\n    # 这一函数计算并返回评价指标，字典格式，返回的字典会在验证和测试的时候将所有的键值对输出到终端中\n",[27,127,125],{"__ignoreMap":63},[10,129,130,132,133,136],{},[27,131,50],{}," 函数的写法，参照 ",[27,134,135],{},"mmpretrain/evaluation/metrics/single_label.py"," 文件中的写法：",[57,138,141],{"className":139,"code":140,"language":62,"meta":63},[60],"for data_sample in data_samples:\n    result = dict()\n    if 'pred_score' in data_sample:\n        result['pred_score'] = data_sample['pred_score'].cpu()\n    else:\n        
result['pred_label'] = data_sample['pred_label'].cpu()\n    result['gt_label'] = data_sample['gt_label'].cpu()\n    # Save the result to `self.results`.\n    self.results.append(result)\n",[27,142,140],{"__ignoreMap":63},[10,144,145,146,148,149,152,153,156,157,159],{},"从这段代码不难看出，",[27,147,50],{}," 函数仅仅是将模型预测的结果与真实的标签存入 ",[27,150,151],{},"self.results"," 中，而计算 ",[27,154,155],{},"auc"," 也只需要这两个参数，所以在\n",[27,158,54],{}," 函数中，我们只需要读取这两个参数并计算。",[10,161,162,163,165,166,169,170,176],{},"对于 ",[27,164,155],{}," 的计算，",[27,167,168],{},"scikit-learn","\n库中已有",[40,171,175],{"href":172,"rel":173},"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score",[174],"nofollow","现成的方法","\n，所以在实现 auc 的时候，直接调用该库的函数：",[57,178,181],{"className":179,"code":180,"language":62,"meta":63},[60],"metrics = {}\n\ntarget = torch.cat([res['gt_label'] for res in results]) # 拼接所有的正确标签\nif 'pred_score' in results[0]:\n    pred = torch.stack([res['pred_score'] for res in results]) # 拼接所有的预测结果\n    auc = roc_auc_score(target, pred, average='macro', sample_weight=None,\n                        max_fpr=None, multi_class='ovr', labels=None)\n\n    metrics['auc'] = auc\nelse:\n    # If only label in the `pred_label`.\n    pred = torch.cat([res['pred_label'] for res in results]) # 拼接所有的预测结果\n    auc = roc_auc_score(target, pred, average='macro', sample_weight=None,\n                        max_fpr=None, multi_class='ovr', labels=None)\n    metrics['auc'] = auc\n\nreturn metrics\n",[27,182,180],{"__ignoreMap":63},[10,184,185],{},"在这个函数中，返回的字典键名没有特殊要求，可以随便写。",[10,187,188,189,191],{},"写完这个类之后，我们需要在 ",[27,190,71],{}," 中注册自定义的评价指标：",[57,193,196],{"className":194,"code":195,"language":62,"meta":63},[60],"from .auc import SingleLabelAUC\n\n__all__ = [..., 'SingleLabelAUC']\n",[27,197,195],{"__ignoreMap":63},[10,199,200,201],{},"最后在我们的训练配置文件添加如下配置：",[27,202,203],{},"val_evaluator = dict(type='SingleLabelAUC')",[10,205,206,207,210,211,213],{},"测试的时候使用的是 
",[27,208,209],{},"resnet18_8xb16_cifar10.py"," 配置文件，能够输出运算的结果，完整的 ",[27,212,121],{}," 文件如下：",[57,215,218],{"className":216,"code":217,"language":62,"meta":63},[60],"import torch\n\nfrom mmengine.evaluator import BaseMetric\nfrom mmpretrain.registry import METRICS\nfrom sklearn.metrics import roc_auc_score\nfrom typing import List, Optional, Sequence, Any\n\n\n@METRICS.register_module()\nclass SingleLabelAUC(BaseMetric):\n    default_prefix: Optional[str] = 'auc'\n\n    def process(self, data_batch, data_samples: Sequence[dict]) -> None:\n        for data_sample in data_samples:\n            result = dict()\n            if 'pred_score' in data_sample:\n                result['pred_score'] = data_sample['pred_score'].cpu()\n            else:\n                result['pred_label'] = data_sample['pred_label'].cpu()\n            result['gt_label'] = data_sample['gt_label'].cpu()\n            self.results.append(result)\n\n    def compute_metrics(self, results: List) -> dict:\n        metrics = {}\n\n        # concat\n        target = torch.cat([res['gt_label'] for res in results])\n        if 'pred_score' in results[0]:\n            pred = torch.stack([res['pred_score'] for res in results])\n\n            auc = roc_auc_score(target, pred, average='macro', sample_weight=None,\n                                max_fpr=None, multi_class='ovr', labels=None)\n\n            metrics['auc'] = auc\n        else:\n            pred = torch.cat([res['pred_label'] for res in results])\n            auc = roc_auc_score(target, pred, average='macro', sample_weight=None,\n                                max_fpr=None, multi_class='ovr', labels=None)\n            metrics['auc'] = auc\n\n        return metrics\n",[27,219,217],{"__ignoreMap":63},[14,221,222],{"id":222},"正确写法",[10,224,225,226,229,230,232,233,236,237,240],{},"在所有文件设置完成之后，进行了一次测试，发现该方法在 ",[27,227,228],{},"acc=0.5"," 的情况下，",[27,231,155],{}," 达到了 ",[27,234,235],{},"0.9","\n（训练时的第一次验证），不是很正常，于是放弃自己写的代码，直接复制 
",[27,238,239],{},"SingleLabelMetric"," 的代码，并替换一些关键算法：",[57,242,245],{"className":243,"code":244,"language":62,"meta":63},[60],"import torch\nimport numpy as np\nimport torch.nn.functional as F\n\nfrom mmengine.evaluator import BaseMetric\nfrom mmpretrain.registry import METRICS\nfrom sklearn.metrics import roc_auc_score\nfrom typing import List, Optional, Sequence, Union\nfrom .single_label import to_tensor\n\n\n@METRICS.register_module()\nclass SingleLabelAUC(BaseMetric):\n\n    default_prefix: Optional[str] = 'single-label'\n\n    def __init__(self,\n                 thrs: Union[float, Sequence[Union[float, None]], None] = 0.,\n                 average: Optional[str] = 'macro',\n                 num_classes: Optional[int] = None,\n                 collect_device: str = 'cpu',\n                 prefix: Optional[str] = None) -> None:\n        super().__init__(collect_device=collect_device, prefix=prefix)\n\n        if isinstance(thrs, float) or thrs is None:\n            self.thrs = (thrs, )\n        else:\n            self.thrs = tuple(thrs)\n\n        self.average = average\n        self.num_classes = num_classes\n\n    def process(self, data_batch, data_samples: Sequence[dict]):\n\n\n        for data_sample in data_samples:\n            result = dict()\n            if 'pred_score' in data_sample:\n                result['pred_score'] = data_sample['pred_score'].cpu()\n            else:\n                num_classes = self.num_classes or data_sample.get(\n                    'num_classes')\n                assert num_classes is not None, \\\n                    'The `num_classes` must be specified if no `pred_score`.'\n                result['pred_label'] = data_sample['pred_label'].cpu()\n                result['num_classes'] = num_classes\n            result['gt_label'] = data_sample['gt_label'].cpu()\n            self.results.append(result)\n\n    def compute_metrics(self, results: List):\n\n        metrics = {}\n\n        # concat\n        target = 
torch.cat([res['gt_label'] for res in results])\n        if 'pred_score' in results[0]:\n            pred = torch.stack([res['pred_score'] for res in results])\n            auc = self.calculate(\n                pred, target, thrs=self.thrs, average=self.average)\n\n            multi_thrs = len(self.thrs) > 1\n            for i, thr in enumerate(self.thrs):\n                if multi_thrs:\n                    suffix = 'auc_no-thr' if thr is None else f'_thr-{thr:.2f}'\n                else:\n                    suffix = 'auc'\n                print(type(auc), auc)\n                for k, v in enumerate(auc):\n                    print(k, v)\n                    metrics[str(k)+suffix] = v\n        else:\n            pred = torch.cat([res['pred_label'] for res in results])\n            auc = self.calculate(\n                pred,\n                target,\n                average=self.average,\n                num_classes=results[0]['num_classes'])\n            metrics['auc'] = auc\n\n        return metrics\n\n    @staticmethod\n    def calculate(\n        pred: Union[torch.Tensor, np.ndarray, Sequence],\n        target: Union[torch.Tensor, np.ndarray, Sequence],\n        thrs: Sequence[Union[float, None]] = (0., ),\n        average: Optional[str] = 'macro',\n        num_classes: Optional[int] = None,\n    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n        average_options = ['micro', 'macro', None]\n        assert average in average_options, 'Invalid `average` argument, ' \\\n            f'please specify from {average_options}.'\n\n        pred = to_tensor(pred)\n        target = to_tensor(target).to(torch.int64)\n        assert pred.size(0) == target.size(0), \\\n            f\"The size of pred ({pred.size(0)}) doesn't match \"\\\n            f'the target ({target.size(0)}).'\n\n        if pred.ndim == 1:\n            assert num_classes is not None, \\\n                'Please specify the `num_classes` if the `pred` is labels ' \\\n                'intead of 
scores.'\n            gt_positive = F.one_hot(target.flatten(), num_classes)\n            pred_positive = F.one_hot(pred.to(torch.int64), num_classes)\n            return roc_auc_score(gt_positive, pred_positive,\n                                 average=average)\n        else:\n            # For pred score, calculate on all thresholds.\n            num_classes = pred.size(1)\n            pred_score, pred_label = torch.topk(pred, k=1)\n            pred_score = pred_score.flatten()\n            pred_label = pred_label.flatten()\n\n            gt_positive = F.one_hot(target.flatten(), num_classes)\n\n            results = []\n            for thr in thrs:\n                pred_positive = F.one_hot(pred_label, num_classes)\n                if thr is not None:\n                    pred_positive[pred_score \u003C= thr] = 0\n                results.append(\n                    roc_auc_score(gt_positive, pred_positive,\n                                  average=average))\n\n            return results\n",[27,246,244],{"__ignoreMap":63},[10,248,249],{},"经过再一次的实验，在其他条件同样的情况下，第一次训练验证时，auc 在 0.75 左右，符合预期。需要注意的是：这一版本先取 top-1 预测再做 one-hot，相当于只用单一阈值下的硬标签计算 AUC，此时每个类别的 AUC 恰好等于 (召回率 + 特异度) / 2。以 CIFAR-10 为例，若准确率为 0.5 且错误大致均匀地分布在各个类别上，各类召回率约为 0.5、特异度约为 1 − 0.5/9 ≈ 0.94，宏平均 AUC 约为 0.72，与上面的 0.75 基本吻合。而标准的 AUC 是基于预测分数（pred_score）的排序计算的，在 10 分类的 one-vs-rest 宏平均下明显高于准确率并不罕见，因此第一版直接用 pred_score 得到的 0.9 未必是错误的；如果需要标准意义上的 AUC，应将 pred_score 直接传给 roc_auc_score。",{"title":63,"searchDepth":251,"depth":251,"links":252},2,[253,254,255],{"id":16,"depth":251,"text":16},{"id":112,"depth":251,"text":112},{"id":222,"depth":251,"text":222},"2024-09-16T02:46:46.000Z","md",[259],"其他",false,{},true,"/article/13",{"title":5,"description":12},"article/13","2024-09-24T03:07:09.000Z","_hFEiSPxXljI8fPofXHU6cL6jJab6f0FJGZh7o_YE5g",1755235549197]