This post contains some of my notes from switching to PyTorch after working with TensorFlow and Keras for a long time.
import imageio
import torch
PyTorch's convolutional layers expect images in channels-first order (C x H x W), while imageio loads them as H x W x C, so you may have to use the permute method to get your images into the right shape.
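image_path = 'roo.jpg'
img_arr = imageio.imread(image_path)  # H x W x C NumPy array
img = torch.from_numpy(img_arr)
print(img.shape)  # torch.Size([256, 192, 3])
fixed_img = img.permute(2, 0, 1)  # reorder dimensions to C x H x W
print(fixed_img.shape)  # torch.Size([3, 256, 192])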
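If you don't have roo.jpg to hand, the same steps can be tried on a synthetic array. This is a minimal sketch: the all-zeros array is an assumed stand-in for the output of imageio.imread, and the 256 x 192 size just mirrors the example above. It also undoes the permutation and adds a leading batch dimension, since layers such as nn.Conv2d expect N x C x H x W input.

```python
import numpy as np
import torch

# Synthetic stand-in for imageio.imread output: an H x W x C uint8 array
img_arr = np.zeros((256, 192, 3), dtype=np.uint8)
img = torch.from_numpy(img_arr)  # shares memory with img_arr
print(img.shape)  # torch.Size([256, 192, 3])

# Channels first: C x H x W
fixed_img = img.permute(2, 0, 1)
print(fixed_img.shape)  # torch.Size([3, 256, 192])

# permute(1, 2, 0) undoes permute(2, 0, 1), e.g. before plotting the image
restored = fixed_img.permute(1, 2, 0)
print(torch.equal(restored, img))  # True

# Add a leading batch dimension: N x C x H x W
batch = fixed_img.unsqueeze(0)
print(batch.shape)  # torch.Size([1, 3, 256, 192])
```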
fixed_img isn’t a copy of the original: permute returns a view that shares the same underlying data as img, with only the order of the dimensions changed. So if you go on to change img, fixed_img will change as well.
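A quick way to convince yourself of this is to write into img and read the same element back through fixed_img. This sketch uses a small synthetic tensor rather than the image, with arbitrary shapes and values. If you need an independent channels-first copy, calling .clone() on the permuted tensor allocates new memory.

```python
import torch

# Small synthetic H x W x C tensor (4 x 3 x 3) standing in for an image
img = torch.zeros(4, 3, 3, dtype=torch.uint8)
fixed_img = img.permute(2, 0, 1)  # C x H x W view, no data copied

# Both tensors point at the same storage
print(fixed_img.data_ptr() == img.data_ptr())  # True
print(fixed_img.is_contiguous())  # False: only the strides were reordered

# Writing through img is visible through fixed_img
img[0, 0, 0] = 255
print(fixed_img[0, 0, 0].item())  # 255

# .clone() makes an independent copy with its own memory
independent = img.permute(2, 0, 1).clone()
img[0, 0, 0] = 0
print(fixed_img[0, 0, 0].item())    # 0 (still a view of img)
print(independent[0, 0, 0].item())  # 255 (unaffected by the write)
```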
Missing from PyTorch
There is no model.summary() like there is in Keras, so instead I recommend torchinfo. It shows each layer's output shape and parameter count, which makes it a great tool for sanity checking a network.
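As a rough sketch, here is how a torchinfo summary might be produced for a small convolutional network. The model is a made-up example sized for the 3 x 256 x 192 image above, and torchinfo is a separate package (pip install torchinfo), so the import is guarded. Plain PyTorch can still list the layers with print(model), and counting parameters by hand is a handy cross-check on the summary's totals.

```python
import torch
from torch import nn

# torchinfo is a third-party package: pip install torchinfo
try:
    from torchinfo import summary
except ImportError:
    summary = None

# Hypothetical toy model expecting N x 3 x H x W input
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # 3*8*3*3 + 8 = 224 params
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # -> N x 8 x 1 x 1
    nn.Flatten(),             # -> N x 8
    nn.Linear(8, 2),          # 8*2 + 2 = 18 params
)

# Built into PyTorch: lists the layers, but no output shapes or totals
print(model)

# Manual parameter count to cross-check the summary
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 242

if summary is not None:
    # Keras-style table of per-layer output shapes and parameter counts
    print(summary(model, input_size=(1, 3, 256, 192), verbose=0))
```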