
CelebA dataset download | HTTPSConnectionPool(host='drive.google.com', port=443) | RuntimeError: Dataset not found

Posted 2023-04-02 09:05 in the category 《随便一记》 (Random Notes)



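CelebA is an open dataset released by the Chinese University of Hong Kong. It contains 202,599 images of 10,177 celebrity identities, all of them already annotated, which makes it a very handy dataset for face-related training.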

Unlike other datasets such as MNIST, however, it cannot simply be downloaded automatically. For MNIST, this is all it takes:

```python
import torchvision.datasets as dset
import torchvision.transforms as transforms

dataroot = './'
imagesize = 64

# MNIST is fetched automatically when download=True
dataset = dset.MNIST(root=dataroot, download=True,
                     transform=transforms.Compose([
                         transforms.Resize(imagesize),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5,), (0.5,)),
                     ]))
```

In torchvision.datasets.celeba.py, CelebA can get its files in two ways:

```python
def download(self) -> None:
    # Option 1: the files were downloaded manually and pass the integrity check
    if self._check_integrity():
        print("Files already downloaded and verified")
        return
    # Option 2: download the files from Google Drive
    for (file_id, md5, filename) in self.file_list:
        download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
    extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
```

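In other words, if the files have not been placed there manually, torchvision falls back to downloading them from Google Drive. Google Drive, however, is not reachable from mainland China without a proxy, so downloading manually is the practical option.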

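The error raised when downloading from Google Drive:

```
requests.exceptions.ConnectionError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=0B7EVK8r0v71pblRyaVFSWGxPY0U&export=download (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x000002746E4E7E20>: Failed to establish a new connection: [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond.'))
```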
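WinError 10060 is a TCP connection timeout to drive.google.com. One way to find out in advance whether download=True has any chance of working is to try a plain TCP connection to port 443. The sketch below is only an illustration and not part of torchvision; can_reach is a hypothetical helper name.

```python
import socket


def can_reach(host: str, port: int = 443, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds.

    `timeout` bounds each connection attempt. Any OSError (DNS failure,
    refused connection, or a timeout such as WinError 10060) is treated
    as "unreachable".
    """
    try:
        # create_connection resolves the host name and opens a TCP socket
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

For example, if can_reach("drive.google.com") returns False you are in the same situation as the error above, and the files have to be fetched manually.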

Baidu Netdisk mirror: CelebA free high-speed download | Baidu Netdisk (baidu.com)

Which raises the question: with so many files, which ones do we actually need, and where do they go once downloaded?

The answer is again in torchvision.datasets.celeba.py, in the integrity-check method _check_integrity():

```python
def _check_integrity(self) -> bool:
    for (_, md5, filename) in self.file_list:
        fpath = os.path.join(self.root, self.base_folder, filename)
        _, ext = os.path.splitext(filename)
        # Allow original archive to be deleted (zip and 7z)
        # Only need the extracted images
        if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
            return False

    # Should check a hash of the images
    return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
```

This method iterates over the entries in self.file_list:

```python
base_folder = "celeba"
# There currently does not appear to be a easy way to extract 7z in python (without introducing additional
# dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
# right now.
file_list = [
    # File ID                                      MD5 Hash                            Filename
    ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
    # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
    # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
    ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
    ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
    ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
    ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
    # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
    ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
]
```

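Three entries are commented out, so only the six uncommented files need to be downloaded. Note also that _check_integrity() skips the MD5 check for .zip and .7z files and instead requires the extracted img_align_celeba folder to exist, so img_align_celeba.zip has to be unzipped in place.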
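To see which files are missing or corrupted before torchvision raises "Dataset not found or corrupted", the same checks can be mirrored with the standard library. This is a hypothetical helper (diagnose_celeba is not a torchvision function); the MD5 values are copied from file_list above, and hashlib's MD5 stands in for torchvision's own check_md5.

```python
import hashlib
import os

# (MD5, filename) pairs copied from the uncommented entries of CelebA.file_list
REQUIRED_FILES = [
    ("00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
    ("75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
    ("32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
    ("00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
    ("cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
    ("d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
]


def md5_of(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def diagnose_celeba(root: str, base_folder: str = "celeba") -> list:
    """List the problems that would make CelebA._check_integrity() return False."""
    folder = os.path.join(root, base_folder)
    problems = []
    for md5, filename in REQUIRED_FILES:
        _, ext = os.path.splitext(filename)
        if ext in (".zip", ".7z"):
            continue  # archives are skipped, as in torchvision
        fpath = os.path.join(folder, filename)
        if not os.path.isfile(fpath):
            problems.append(f"missing: {fpath}")
        elif md5_of(fpath) != md5:
            problems.append(f"MD5 mismatch: {fpath}")
    # torchvision only checks that the extracted image folder exists
    img_dir = os.path.join(folder, "img_align_celeba")
    if not os.path.isdir(img_dir):
        problems.append(f"missing directory: {img_dir}")
    return problems
```

Calling diagnose_celeba("./data") should return an empty list once everything is in place; an MD5 mismatch usually points to a different version of a file (for example, one taken from another mirror).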

Create a folder named data to hold the dataset, create a folder named celeba inside data, and put the downloaded files into celeba.
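These steps can also be scripted. The following is a minimal standard-library sketch; prepare_celeba_layout and its download_dir argument are hypothetical names, not part of torchvision. It assumes, as torchvision's extract_archive call plus the img_align_celeba check imply, that img_align_celeba.zip contains a top-level img_align_celeba/ folder and is extracted next to the zip.

```python
import os
import shutil
import zipfile

# The six files expected by CelebA.file_list (uncommented entries)
CELEBA_FILES = [
    "img_align_celeba.zip",
    "list_attr_celeba.txt",
    "identity_CelebA.txt",
    "list_bbox_celeba.txt",
    "list_landmarks_align_celeba.txt",
    "list_eval_partition.txt",
]


def prepare_celeba_layout(root: str, download_dir: str) -> str:
    """Move manually downloaded CelebA files into <root>/celeba and unzip the images.

    download_dir is a hypothetical folder where the files were saved.
    Returns the path of the <root>/celeba folder.
    """
    folder = os.path.join(root, "celeba")
    os.makedirs(folder, exist_ok=True)
    for name in CELEBA_FILES:
        src = os.path.join(download_dir, name)
        dst = os.path.join(folder, name)
        if os.path.isfile(src) and not os.path.exists(dst):
            shutil.move(src, dst)
    zip_path = os.path.join(folder, "img_align_celeba.zip")
    img_dir = os.path.join(folder, "img_align_celeba")
    # _check_integrity() looks for the extracted folder, not for the zip itself
    if os.path.isfile(zip_path) and not os.path.isdir(img_dir):
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(folder)
    return folder
```

With this layout, root="./data" is what gets passed to dset.CelebA.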

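This works because the file path is built as

```python
fpath = os.path.join(self.root, self.base_folder, filename)
```

with base_folder = "celeba", so you only pass the root directory when creating the dataset, for example: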

```python
import torchvision.datasets as dset
import torchvision.transforms as transforms

dataroot = './data'

dataset = dset.CelebA(root=dataroot, download=True,
                      transform=transforms.Compose([
                          transforms.Resize(64),
                          transforms.CenterCrop(64),
                          transforms.ToTensor(),
                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                      ]))
print(dataset)
```

Output:

```
Files already downloaded and verified
Dataset CelebA
    Number of datapoints: 162770
    Root location: ./data
    Target type: ['attr']
    Split: train
    StandardTransform
Transform: Compose(
               Resize(size=64, interpolation=bilinear, max_size=None, antialias=None)
               CenterCrop(size=(64, 64))
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )
```
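The 162770 datapoints come from list_eval_partition.txt: in CelebA.__init__ (listed in full at the end of this post), split_map maps "train" to label 0, and the mask keeps only rows whose label equals 0. The stdlib-only sketch below counts rows per label the same way; count_splits is a hypothetical helper, and the label-to-name mapping is copied from split_map.

```python
import csv
import io
from collections import Counter

# Same mapping as split_map in CelebA.__init__ (label -> split name)
SPLIT_NAMES = {0: "train", 1: "valid", 2: "test"}


def count_splits(lines) -> Counter:
    """Count images per split from lines of list_eval_partition.txt ("<filename> <label>")."""
    # Same csv settings as CelebA._load_csv
    reader = csv.reader(lines, delimiter=" ", skipinitialspace=True)
    counts = Counter()
    for row in reader:
        if not row:
            continue  # ignore blank lines
        counts[SPLIT_NAMES[int(row[1])]] += 1
    return counts
```

With the files in place, count_splits(open("./data/celeba/list_eval_partition.txt")) should report the same train count as the Number of datapoints line above.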

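Everyone else seems to get this working effortlessly, yet I kept running into problems. First the programmatic download failed; then I downloaded the dataset from Kaggle, and that raised an error too. After reading the official PyTorch documentation I assumed only one file was needed and spent more time downloading it. You can guess how that turned out.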

Most posts on CSDN only introduce the dataset itself; this article gives a fairly concise introduction if you are not yet familiar with CelebA.

Finally, here is celeba.py for reference:

```python
import csv
import os
from collections import namedtuple
from typing import Any, Callable, List, Optional, Union, Tuple

import PIL
import torch

from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
from .vision import VisionDataset

CSV = namedtuple("CSV", ["header", "index", "data"])


class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:

                - ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
                - ``identity`` (int): label for each person (data points with the same identity are the same person)
                - ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
                - ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

            Defaults to ``attr``. If empty, ``None`` will be returned as target.

        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    base_folder = "celeba"
    # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                                      MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(
        self,
        root: str,
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()

        if mask == slice(None):  # if split == "all"
            self.filename = splits.index
        else:
            self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is not None:
            headers = data[header]
            data = data[header + 1 :]
        else:
            headers = []

        indices = [row[0] for row in data]
        data = [row[1:] for row in data]
        data_int = [list(map(int, i)) for i in data]

        return CSV(headers, indices, torch.tensor(data_int))

    def _check_integrity(self) -> bool:
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self) -> None:
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{t}" is not recognized.')

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        return len(self.attr)

    def extra_repr(self) -> str:
        lines = ["Target type: {target_type}", "Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)
```
The check_integrity function lives in torchvision.datasets.utils.py:
```python
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)
```




Permalink: http://zhangshiyu.com/post/57903.html

Copyright © 2020-2022 ZhangShiYu.com Rights Reserved.豫ICP备2022013469号-1