This post contains some of my notes from switching to PyTorch after having worked with TensorFlow and Keras for a long time.

import imageio
import torch

Channels First

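PyTorch's image layers expect channels first, i.e. (channels, height, width), while imageio loads images as (height, width, channels). So you may have to use the permute method to get your images in the right shape.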

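image_path = 'roo.jpg'
img_arr = imageio.imread(image_path)  # NumPy array in (height, width, channels) order
img = torch.from_numpy(img_arr)
img.shape  # torch.Size([256, 192, 3])
fixed_img = img.permute(2, 0, 1)  # reorder to (channels, height, width)
fixed_img.shape  # torch.Size([3, 256, 192])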

Note that fixed_img isn’t a copy of the original; it’s a view onto the same memory with the dimensions reordered. So if you go on to change img in place, fixed_img will change as well.
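
Here is a minimal sketch of that sharing behaviour. It uses a zero-filled tensor as a stand-in for the loaded image (an assumption, so the snippet runs without roo.jpg), and copied_img is a name I made up for the independent copy you get from clone().

```python
import torch

# Stand-in for a decoded image in (height, width, channels) order.
img = torch.zeros(256, 192, 3, dtype=torch.uint8)

fixed_img = img.permute(2, 0, 1)            # view: shares storage with img
copied_img = img.permute(2, 0, 1).clone()   # independent copy of the data

img[0, 0, 0] = 255  # change the original in place

view_value = fixed_img[0, 0, 0].item()   # follows the change to img
copy_value = copied_img[0, 0, 0].item()  # unaffected by the change

shares_memory = fixed_img.data_ptr() == img.data_ptr()
is_contiguous = fixed_img.is_contiguous()  # permuted views are generally not contiguous

print(view_value, copy_value, shares_memory, is_contiguous)
```

If you need a tensor that is safe to modify independently, clone() is the simplest route. contiguous() also gives you a fresh, densely laid out tensor here, which matters for operations such as view() that require contiguous memory.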

Missing from PyTorch


There is no model.summary() like there is in Keras, so instead I recommend torchinfo. This is a great tool for sanity checking a network.