关于图像识别问题,图片输入格式的错误提示
收藏
报错类型TypeError: img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got
出错的代码是:
File "", line 45, in __getitem__
img = self.train_transforms(img)
定义的getitem为继承了Datasetd 子类
def __getitem__(self, idx):
img = cv2.imread(self.all_image_path[idx])
if self.mode == 'train':
img = self.train_transforms(img)
label = int(self.label_index[str(self.all_image_path[idx]).split('/')[1]])
return img, label
self.train_transforms定义如下:
self.train_transforms = Compose([
Resize(size=image_size),
RandomHorizontalFlip(),
RandomRotation(15),
Transpose(),
Normalize(mean=127.5, std=127.5)
])
请问大佬为什么会出现这种错误提示,求解。。。
0
收藏
请登录后评论
imread返回的是mat吧?需要转一下格式。
新版Paddle与Numpy兼容很好,建议转成ndarray 。
img = cv2.imread(self.all_image_path[idx]) 后,打印下img的shape看看,有时读不出来甚至是空的
如果图片没问题,可以把 Compose 里的数据增强加加减减,看看哪个处理出错了
再多点报错信息呗。
PIL Image or Tensor Image or ndarray 。格式只支持 PIL.OPEN,Tensor 和 array 三个。而且维度要求为2或3。 一个可能是你读的格式不是这三种,可以用type()进行查看。最可能的是第二种:opencv 读取的数据格式会把 通道放在第三个维度 即 HWC 但是transforms要求的格式为 CHW。你可以通过把 数据的shape打印出来看下 第一个维度是否为 通道维度。如果不是的话,使用np.transpose进行变换。