
Fixing the PyTorch DataLoader error: Trying to resize storage that is not resizable

8 participants · April 12, 2024, 13:50 · Category: 《我的小黑屋》



TL;DR

If you hit this error, especially when your code normally runs fine but breaks the moment you switch datasets, the problem is most likely in the dataset itself. Debug along that line of thinking.

Problem description

With num_workers set to any value greater than 0, the DataLoader raises the following error:

Traceback (most recent call last):
  File "/home/username/distort/main.py", line 131, in <module>
    model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
  File "/home/username/distort/main.py", line 65, in train_model
    for img, y in train_dataloader:
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
    return self._process_data(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
    data.reraise()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 140, in default_collate
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable

Setting num_workers to 0 produced a different error instead:

Traceback (most recent call last):
  File "/home/username/distort/main.py", line 130, in <module>
    model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
  File "/home/username/distort/main.py", line 64, in train_model
    for img, y in train_dataloader:
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 721, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 141, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 64, 64] at entry 0 and [1, 64, 64] at entry 32

Troubleshooting

The second error is fairly easy to track down. In the custom Dataset class's __getitem__() method, add a check: whenever the loaded tensor's shape[0] is 1, print the path of the original data file it came from.
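A minimal sketch of that debugging step (the class and attribute names here are mine, and tensors stand in for the image files a real dataset would load):

```python
import torch
from torch.utils.data import Dataset

class ChannelCheckDataset(Dataset):
    """Illustrative dataset: prints the source path whenever a loaded
    image tensor turns out to have only one channel."""
    def __init__(self, tensors, paths, labels):
        self.tensors, self.paths, self.labels = tensors, paths, labels

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, idx):
        img = self.tensors[idx]   # in a real dataset: load + transform the image file
        if img.shape[0] == 1:     # a grayscale image in a dataset assumed to be RGB
            print(f"1-channel image: {self.paths[idx]}")
        return img, self.labels[idx]
```

Running a full pass over the dataset with this check (even just a plain loop, no DataLoader) lists every offending file.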

Sure enough, the dataset really did contain single-channel images (I was using tiny-imagenet-200). I hadn't expected the dataset itself to be at fault.

Solution

In __getitem__(), use the tensor's expand() method: for any tensor with the wrong channel count, call expand(3, -1, -1). After that, the dataset loads correctly whether num_workers is 0 or any positive number.
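A minimal sketch of the fix, as a helper you could call from __getitem__() (the function name is mine):

```python
import torch

def to_three_channels(img: torch.Tensor) -> torch.Tensor:
    """Broadcast a single-channel (1, H, W) tensor to (3, H, W).
    expand() returns a view sharing the original storage, so nothing
    is copied; 3-channel tensors pass through unchanged."""
    if img.shape[0] == 1:
        img = img.expand(3, -1, -1)
    return img
```

Because expand() returns a view rather than a copy, in-place writes to the result are unsafe; if a later transform modifies the tensor in place, use img.repeat(3, 1, 1) (which copies) instead.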

One more note: some blog posts claim num_workers has to match the number of GPU cores, which makes no sense. The first traceback above already shows that the failure has nothing to do with the CUDA libraries, so it cannot be a GPU-related problem. At least with the usual way of loading datasets, num_workers simply sets how many CPU worker processes the DataLoader uses to load data.
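To make that concrete, here is a toy in-memory example of the setting (dataset and names are mine); the parameter only controls CPU-side loading parallelism:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# num_workers sets how many CPU worker processes the DataLoader spawns to
# load batches in parallel; it is unrelated to how many GPUs you have.
ds = TensorDataset(torch.zeros(8, 3, 64, 64), torch.zeros(8, dtype=torch.long))
loader = DataLoader(ds, batch_size=4, num_workers=0)  # 0 = load in the main process
imgs, ys = next(iter(loader))
```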




Original post: http://zhangshiyu.com/post/94282.html


Copyright © 2020-2022 ZhangShiYu.com Rights Reserved.豫ICP备2022013469号-1