Shortcuts

Source code for mmflow.core.evaluation.eval_hooks

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Union

import mmcv
from mmcv.runner import Hook, IterBasedRunner, get_dist_info
from torch.utils.data import DataLoader

from .evaluation import (multi_gpu_online_evaluation,
                         single_gpu_online_evaluation)


class EvalHook(Hook):
    """Evaluation hook running single-GPU online evaluation during training.

    Args:
        dataloader (DataLoader | list[DataLoader] | tuple[DataLoader]): A
            PyTorch dataloader, or a list/tuple of dataloaders that are
            evaluated one after another.
        interval (int): Evaluation interval, counted in epochs when
            ``by_epoch`` is True and in iterations otherwise. Default: 1.
        by_epoch (bool): Whether to evaluate by epoch (True) or by
            iteration (False). Default: False.
        dataset_name (str | list[str] | tuple[str], optional): Name of the
            dataset behind each dataloader, one entry per dataloader. When
            an entry is not None, its metrics are logged under the key
            ``'{metric} in {dataset_name}'``; otherwise the bare metric name
            is used. Default: None.
        eval_kwargs (Any): Extra keyword arguments forwarded to the online
            evaluation function.

    Raises:
        TypeError: If ``dataloader`` is neither a DataLoader nor a list/tuple
            of DataLoaders.
        ValueError: If the number of dataloaders differs from the number of
            dataset names.
    """

    def __init__(self,
                 dataloader: Union[DataLoader, Sequence[DataLoader]],
                 interval: int = 1,
                 by_epoch: bool = False,
                 dataset_name: Optional[Union[str, Sequence[str]]] = None,
                 **eval_kwargs: Any) -> None:
        # Normalise to a sequence so ``evaluate`` can always zip over it.
        # Tuples are accepted as well as lists, matching this normalisation.
        dataloaders = (
            dataloader if isinstance(dataloader, (tuple, list)) else
            [dataloader])
        if not all(isinstance(loader, DataLoader) for loader in dataloaders):
            raise TypeError('dataloader must be a pytorch DataLoader, but got'
                            f' {type(dataloader)}')

        dataset_names = (
            dataset_name if isinstance(dataset_name, (tuple, list)) else
            [dataset_name])
        # Each dataloader is paired with exactly one name in ``evaluate``;
        # an explicit raise (not ``assert``) survives ``python -O``.
        if len(dataloaders) != len(dataset_names):
            raise ValueError(
                f'The number of dataloaders ({len(dataloaders)}) must match '
                f'the number of dataset names ({len(dataset_names)}).')

        self.dataloader = dataloaders
        self.interval = interval
        self.by_epoch = by_epoch
        self.eval_kwargs = eval_kwargs
        self.dataset_name = dataset_names

    def after_train_iter(self, runner: IterBasedRunner) -> None:
        """Evaluate every ``interval`` iterations when not evaluating by
        epoch."""
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return
        # Clear the log buffer before writing evaluation metrics into it.
        runner.log_buffer.clear()
        self.evaluate(runner)

    def after_train_epoch(self, runner: IterBasedRunner) -> None:
        """Evaluate every ``interval`` epochs when evaluating by epoch."""
        # Mirror ``after_train_iter``: epoch-end evaluation only applies in
        # by-epoch mode, otherwise iteration mode would also evaluate here.
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return
        self.evaluate(runner)

    def evaluate(self, runner: IterBasedRunner) -> None:
        """Run online evaluation on every dataloader and log the metrics.

        Each metric is written to ``runner.log_buffer.output`` and the
        buffer is marked ready.
        """
        for i_dataset, i_dataloader in zip(self.dataset_name,
                                           self.dataloader):
            results_metrics = single_gpu_online_evaluation(
                runner.model, i_dataloader, **self.eval_kwargs)
            for name, val in results_metrics.items():
                if i_dataset is not None:
                    key = f'{name} in {i_dataset}'
                else:
                    key = name
                runner.log_buffer.output[key] = val
                runner.log_buffer.ready = True
class DistEvalHook(EvalHook):
    """Distributed evaluation hook running multi-GPU online evaluation.

    Args:
        dataloader (DataLoader | list[DataLoader]): A PyTorch dataloader, or
            a list of dataloaders that are evaluated one after another.
        interval (int): Evaluation interval, counted in epochs when
            ``by_epoch`` is True and in iterations otherwise. Default: 1.
        tmpdir (str | None): Temporary directory intended for saving the
            results of all processes. Default: None.
        gpu_collect (bool): Whether to use GPU or CPU to collect results.
            Default: False.
        by_epoch (bool): Whether to evaluate by epoch (True) or by
            iteration (False). Default: False.
        dataset_name (str | list[str], optional): Name of the dataset behind
            each dataloader, one entry per dataloader. When an entry is not
            None, its metrics are logged under ``'{metric} in {name}'``.
            Default: None.
        eval_kwargs (Any): Extra keyword arguments forwarded to the online
            evaluation function.

    Note:
        ``tmpdir`` and ``gpu_collect`` are stored on the hook but are not
        currently passed to ``multi_gpu_online_evaluation``.
    """

    def __init__(self,
                 dataloader: Union[DataLoader, Sequence[DataLoader]],
                 interval: int = 1,
                 tmpdir: Optional[str] = None,
                 gpu_collect: bool = False,
                 by_epoch: bool = False,
                 dataset_name: Optional[Union[str, Sequence[str]]] = None,
                 **eval_kwargs: Any) -> None:
        # Validation and normalisation of dataloaders / dataset names is
        # identical to the single-GPU hook, so reuse it instead of copying.
        super().__init__(
            dataloader,
            interval=interval,
            by_epoch=by_epoch,
            dataset_name=dataset_name,
            **eval_kwargs)
        self.tmpdir = tmpdir
        self.gpu_collect = gpu_collect

    def evaluate(self, runner: IterBasedRunner) -> None:
        """Run online evaluation on every dataloader and log on rank 0.

        Evaluation runs on every process; only rank 0 writes the metrics
        into ``runner.log_buffer.output`` and marks the buffer ready.
        """
        # The process rank does not change between datasets; query it once.
        rank, _ = get_dist_info()
        for i_dataset, i_dataloader in zip(self.dataset_name,
                                           self.dataloader):
            # Called on every rank (presumably a collective operation), so it
            # must stay outside the rank check below.
            # NOTE(review): self.tmpdir / self.gpu_collect are not forwarded
            # here; confirm whether multi_gpu_online_evaluation accepts them.
            results_metrics = multi_gpu_online_evaluation(
                runner.model, i_dataloader, **self.eval_kwargs)
            if rank != 0:
                continue
            for name, val in results_metrics.items():
                if i_dataset is not None:
                    key = f'{name} in {i_dataset}'
                else:
                    key = name
                runner.log_buffer.output[key] = val
                runner.log_buffer.ready = True
Read the Docs v: latest
Versions
latest
stable
1.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.