チラ裏備忘録

情報整理

int64型のnumpy配列では,ToTensor()による正規化が行われない?

型をuint8にすると正しく[0, 1]になりました.以後,気をつけます.

import numpy as np
import torchvision.transforms as transforms

t = transforms.ToTensor()

a = np.random.randint(0, 255, size=(4, 3))

print(a.dtype)
print(t(a)) # なぜか[0, 1]にならない

a = a.astype('uint8')

print(a.dtype)
print(t(a)) # dtypeを変更すると[0, 1]になる

# output
int64
tensor([[[129,  27,  22],
         [ 23, 107,  42],
         [ 62,  44, 228],
         [206, 226,  22]]])
uint8
tensor([[[0.5059, 0.1059, 0.0863],
         [0.0902, 0.4196, 0.1647],
         [0.2431, 0.1725, 0.8941],
         [0.8078, 0.8863, 0.0863]]])