Source code for mmflow.datasets.sintel

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from typing import Optional, Sequence, Union

from .base_dataset import BaseDataset
from .builder import DATASETS

@DATASETS.register_module()
class Sintel(BaseDataset):
    """Sintel optical flow dataset.

    Args:
        pass_style (str): Pass style for Sintel dataset, and it has 2 options
            ['clean', 'final']. Default: 'clean'.
        scene (str, list, optional): Scene in Sintel dataset, if scene is
            None, it means collecting data in all of scene of Sintel dataset.
            Default: None.
    """

    def __init__(self,
                 *args,
                 pass_style: str = 'clean',
                 scene: Optional[Union[str, Sequence[str]]] = None,
                 **kwargs) -> None:
        all_pass = ['clean', 'final']
        assert pass_style in all_pass, (
            f'pass_style must be one of {all_pass}, but got {pass_style!r}')
        # Set before ``super().__init__``: presumably the base constructor
        # triggers ``load_data_info``, which reads both attributes via
        # ``_get_data_dir`` (NOTE(review): confirm against BaseDataset).
        self.pass_style = pass_style
        self.scene = scene
        super().__init__(*args, **kwargs)
        self.dataset_name += f' {self.pass_style} subset'

    def load_data_info(self) -> None:
        """Load data information, including file path of image1, image2 and
        optical flow.

        Occlusion and invalid-mask file paths are collected as well. All of
        them are registered into ``self.data_infos`` through
        ``load_img_info`` / ``load_ann_info``.
        """
        self._get_data_dir()

        def get_filenames(data_dir, data_suffix, img_idx=None):
            """Collect sorted filenames from every directory in ``data_dir``.

            Within each directory (one per scene), ``img_idx=1`` drops the
            last file and ``img_idx=2`` drops the first file. Together these
            form the consecutive frame pairs (frame_i, frame_i+1).
            ``img_idx=None`` keeps every file.
            """
            data_filenames = []
            for data_dir_ in data_dir:
                data_filenames_ = sorted(
                    self.get_data_filename(data_dir_, data_suffix))
                if img_idx == 1:
                    data_filenames += data_filenames_[:-1]
                elif img_idx == 2:
                    data_filenames += data_filenames_[1:]
                else:
                    data_filenames += data_filenames_
            return data_filenames

        img1_filenames = get_filenames(self.img1_dir, self.img1_suffix, 1)
        img2_filenames = get_filenames(self.img2_dir, self.img2_suffix, 2)
        flow_filenames = get_filenames(self.flow_dir, self.flow_suffix)
        occ_filenames = get_filenames(self.occ_dir, self.occ_suffix)
        # Invalid masks are aligned with image1, so the last one per scene is
        # dropped exactly like the image1 list.
        invalid_filenames = get_filenames(self.invalid_dir,
                                          self.invalid_suffix, 1)

        self.load_img_info(self.data_infos, img1_filenames, img2_filenames)
        self.load_ann_info(self.data_infos, flow_filenames, 'filename_flow')
        self.load_ann_info(self.data_infos, occ_filenames, 'filename_occ')
        self.load_ann_info(self.data_infos, invalid_filenames,
                           'filename_invalid')

    def _get_data_dir(self) -> None:
        """Get the paths for images and optical flow."""
        self.img1_suffix = '.png'
        self.img2_suffix = '.png'
        self.flow_suffix = '.flo'
        self.occ_suffix = '.png'
        self.invalid_suffix = '.png'

        # The original expression chose 'training' in both test and train
        # mode. NOTE(review): presumably this is because the flow/occlusion/
        # invalid annotations loaded in ``load_data_info`` are only available
        # for the training split. Confirm before switching test mode to
        # another split.
        self.subset_dir = 'training'
        # NOTE: this rebinds ``self.data_root``, so calling this method twice
        # would nest the subset directory.
        self.data_root = osp.join(self.data_root, self.subset_dir)

        img_root = osp.join(self.data_root, self.pass_style)
        flow_root = osp.join(self.data_root, 'flow')
        occ_root = osp.join(self.data_root, 'occlusions')
        invalid_root = osp.join(self.data_root, 'invalid')

        # ``os.listdir`` order is arbitrary. Sort so that the scene order,
        # and therefore the sample order, is identical on every machine.
        all_scene = sorted(os.listdir(img_root))
        if self.scene is None:
            self.scene = all_scene
        elif not isinstance(self.scene, (list, tuple)):
            self.scene = [self.scene]
        assert set(self.scene).issubset(set(all_scene)), (
            f'Unknown Sintel scene(s): '
            f'{sorted(set(self.scene) - set(all_scene))}')

        self.img1_dir = [osp.join(img_root, s) for s in self.scene]
        self.img2_dir = [osp.join(img_root, s) for s in self.scene]
        self.flow_dir = [osp.join(flow_root, s) for s in self.scene]
        self.occ_dir = [osp.join(occ_root, s) for s in self.scene]
        self.invalid_dir = [osp.join(invalid_root, s) for s in self.scene]

    def pre_pipeline(self, results: dict) -> None:
        """Prepare results dict for pipeline.

        For Sintel, there is an additional annotation, invalid. Its path is
        copied from ``results['ann_info']['filename_invalid']`` to the top
        level of ``results``.
        """
        super().pre_pipeline(results)
        results['filename_invalid'] = results['ann_info']['filename_invalid']
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.