This post is a walkthrough of creating a Siamese network with FastAI. I had planned to simply use the tutorial from FastAI, but I had to change so much to be able to load the model and make it all work with the latest versions that I figured I would turn it into a blog post. This is really similar to my other post on Siamese Networks with FastAI, except that I will follow this one up with a post about how to evaluate the model.

Table of Contents

The first major thing I had to do was create a file called siam_utils.py where I put all the code that I would need for both my training and inference. If I just copied the code to my inference notebook it wouldn’t work, so this step is essential.

siam_utils.py

Here is the code for siam_utils.py:


import PIL
import re
from fastai.vision.all import *


class SiameseImage(fastuple):
    """A pair of images with an optional third element used as the plot title.

    The tuple holds ``(img1, img2)`` or ``(img1, img2, similarity)``.
    """

    def show(self, ctx=None, **kwargs):
        """Draw both images side by side, separated by a 10-pixel-wide strip of zeros.

        When only two elements are present the title is ``"Undetermined"``.
        Non-tensor images are converted with ``tensor``; the second image is
        first resized to the size of the first if the sizes differ.
        """
        if len(self) > 2:
            first, second, caption = self
        else:
            first, second = self
            caption = "Undetermined"

        if isinstance(first, Tensor):
            left, right = first, second
        else:
            # Widths must agree along dims 0 and 1 for the concatenation below.
            if second.size != first.size:
                second = second.resize(first.size)
            # Move the last axis to the front (assumes H x W x C input - TODO confirm).
            left = tensor(first).permute(2, 0, 1)
            right = tensor(second).permute(2, 0, 1)

        gap = left.new_zeros(left.shape[0], left.shape[1], 10)
        side_by_side = torch.cat([left, gap, right], dim=2)
        return show_image(side_by_side, title=caption, ctx=ctx, **kwargs)


def open_image(fname, size=224):
    """Load an image file as a float tensor with values divided by 255.

    The image is converted to RGB and resized to ``size`` x ``size``; the
    array's last axis is then moved to the front, so the result for the
    default size has shape ``(3, 224, 224)``.
    """
    rgb = PIL.Image.open(fname).convert("RGB").resize((size, size))
    pixels = torch.Tensor(np.array(rgb))
    channels_first = pixels.permute(2, 0, 1).float()
    return channels_first / 255.0


def label_func(fname):
    """Return the class label encoded in an image file name.

    File names are expected to look like ``<label>_<number>.jpg``; everything
    before the final ``_<number>`` is the label (labels may themselves contain
    underscores, e.g. ``great_pyrenees_12.jpg`` -> ``great_pyrenees``).

    ``fname`` must expose a ``.name`` attribute (e.g. a ``pathlib.Path``).
    A name that does not follow the pattern raises ``AttributeError`` because
    ``re.match`` returns ``None``.
    """
    # The dot is escaped so that only a literal ".jpg" extension matches;
    # unescaped it would accept any character in that position.
    return re.match(r"^(.*)_\d+\.jpg$", fname.name).groups()[0]


# Module-level dataset setup: fetch the PETS dataset via fastai's untar_data and
# list every image under its "images" folder. NOTE: this runs at import time.
path = untar_data(URLs.PETS)
files = get_image_files(path / "images")
# Label of every file, computed once (leading underscore keeps it out of `import *`).
_file_labels = files.map(label_func)
# Sorted so the label order does not depend on set iteration order, which
# keeps seeded random draws over `labels` reproducible between runs.
labels = sorted(set(_file_labels))
# Map each label to the files carrying it, preserving the order of `files`.
lbl2files = {l: [f for f, fl in zip(files, _file_labels) if fl == l] for l in labels}


class SiameseTransform(Transform):
    """Turn an image file into a ``SiameseImage`` of (image, partner image, same-class flag).

    Partners for validation files are drawn once at construction time and
    stored in ``self.valid``; any other file gets a fresh random partner from
    split 0 on every call.
    """

    def __init__(self, files, splits):
        # One label -> files mapping per split (index 0 and index 1 of `splits`).
        self.splbl2files = [self._group_by_label(files[splits[i]]) for i in range(2)]
        # Fixed (partner, same) pair for every file of split 1.
        self.valid = {f: self._draw(f, 1) for f in files[splits[1]]}

    def _group_by_label(self, split_files):
        """Group ``split_files`` by label in one pass; every known label gets a (possibly empty) list."""
        grouped = {l: [] for l in labels}
        for f in split_files:
            bucket = grouped.get(label_func(f))
            # Files whose label is not in `labels` are left out.
            if bucket is not None:
                bucket.append(f)
        return grouped

    def encodes(self, f):
        # Only draw a new partner when no fixed one exists; passing the draw as
        # a dict.get default would evaluate it (and consume random numbers)
        # on every call, even for validation files.
        if f in self.valid:
            f2, same = self.valid[f]
        else:
            f2, same = self._draw(f, 0)
        img1, img2 = PILImage.create(f), PILImage.create(f2)
        return SiameseImage(img1, img2, int(same))

    def _draw(self, f, split=0):
        """Pick a random partner for ``f`` from ``split``; return ``(partner, same_class)``."""
        # Roughly half of the pairs share the class of `f`.
        same = random.random() < 0.5
        cls = label_func(f)
        if not same:
            cls = random.choice([l for l in labels if l != cls])
        return random.choice(self.splbl2files[split][cls]), same

Now on to the training notebook. I haven’t included much verbiage in this post. If you want to see the motivation behind some steps, I recommend you check out the FastAI tutorial, which has the same steps in general.

from fastai.vision.all import *
from siam_utils import *
# Open the first dataset image with PIL; leaving `img` as the last expression displays it in the notebook.
img = PIL.Image.open(files[0])
img

png

# Sanity-check the open_image helper from siam_utils by looking at the shape it returns.
open_image(files[0]).shape
torch.Size([3, 224, 224])

Writing your own data block

class ImageTuple(fastuple):
    """A tuple of two images that can be created from file names and shown side by side."""

    @classmethod
    def create(cls, fns):
        """Build an ``ImageTuple`` by opening every file name in ``fns`` with ``PILImage.create``."""
        images = tuple(PILImage.create(fn) for fn in fns)
        return cls(images)

    def show(self, ctx=None, **kwargs):
        """Show both images next to each other, separated by a 10-pixel-wide strip of zeros.

        Nothing is drawn (``ctx`` is returned unchanged) unless both elements
        are tensors of identical shape.
        """
        left, right = self
        both_tensors = isinstance(left, Tensor) and isinstance(right, Tensor)
        if not both_tensors or left.shape != right.shape:
            return ctx
        gap = left.new_zeros(left.shape[0], left.shape[1], 10)
        combined = torch.cat([left, gap, right], dim=2)
        return show_image(combined, ctx=ctx, **kwargs)
# Build an ImageTuple from the first two files and run it through ToTensor.
img = ImageTuple.create((files[0], files[1]))
tst = ToTensor()(img)
# Check the types of the two elements after the conversion.
type(tst[0]), type(tst[1])
(fastai.torch_core.TensorImage, fastai.torch_core.TensorImage)
# Resize the tuple before converting it to tensors, then display it.
# ImageTuple.show only draws when both elements are tensors of the same shape.
img1 = Resize(224)(img)
tst = ToTensor()(img1)
tst.show();

png

def ImageTupleBlock():
    """Return a ``TransformBlock`` that creates ``ImageTuple`` items and applies ``IntToFloatTensor`` per batch."""
    block = TransformBlock(type_tfms=ImageTuple.create, batch_tfms=IntToFloatTensor)
    return block
# Randomly split the files, then keep each split both as a list of files and as a set
# (get_split uses the sets for membership tests).
splits = RandomSplitter()(files)
splits_files = [files[splits[i]] for i in range(2)]
splits_sets = mapped(set, splits_files)
def get_split(f):
    """Return the index of the first split in ``splits_sets`` that contains ``f``.

    Raises:
        ValueError: if ``f`` belongs to none of the splits.
    """
    split_idx = next((i for i, members in enumerate(splits_sets) if f in members), None)
    if split_idx is None:
        raise ValueError(f"File {f} is not presented in any split.")
    return split_idx
# For each split, map every label to the files of that split carrying that label.
splbl2files = [{l: [f for f in s if label_func(f) == l] for l in labels} for s in splits_sets]
def splitter(items):
    """Split ``(f1, f2, same)`` tuples into two index lists according to the split of ``f1``.

    Returns the positions of the items whose first file is in split 0, then
    those whose first file is in split 1.
    """
    first_split, second_split = [], []
    for idx, (f1, f2, same) in enumerate(items):
        split_idx = get_split(f1)
        if split_idx == 0:
            first_split.append(idx)
        elif split_idx == 1:
            second_split.append(idx)
    return first_split, second_split
def draw_other(f):
    """Draw a random partner for ``f`` from the same split; return ``(partner, same_class)``.

    With probability of about one half the partner shares the label of ``f``;
    otherwise it comes from a randomly chosen different label.
    """
    is_same = random.random() < 0.5
    target_cls = label_func(f)
    split_idx = get_split(f)
    if not is_same:
        other_classes = L(l for l in labels if l != target_cls)
        target_cls = random.choice(other_classes)
    candidates = splbl2files[split_idx][target_cls]
    return random.choice(candidates), is_same
def get_tuples(files):
    """Return one ``[file, partner, same]`` list per file, with the partner drawn by ``draw_other``."""
    tuples = []
    for f in files:
        partner, same = draw_other(f)
        tuples.append([f, partner, same])
    return tuples
def get_x(t):
    """Return the first two entries of sample ``t`` (the two files of the pair)."""
    image_pair = t[:2]
    return image_pair


def get_y(t):
    """Return the third entry of sample ``t`` (the same-class flag)."""
    same_flag = t[2]
    return same_flag
# Assemble the DataBlock: items are the [file, partner, same] lists from get_tuples,
# the input is the image pair (get_x) and the target the same-class flag (get_y).
# The custom splitter keeps every pair in the split of its first file.
siamese = DataBlock(
    blocks=(ImageTupleBlock, CategoryBlock),
    get_items=get_tuples,
    get_x=get_x,
    get_y=get_y,
    splitter=splitter,
    item_tfms=Resize(224),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)],
)
# Build the DataLoaders from the file list and inspect the types inside one batch.
dls = siamese.dataloaders(files)
b = dls.one_batch()
explode_types(b)
{tuple: [{__main__.ImageTuple: [fastai.torch_core.TensorImage,
    fastai.torch_core.TensorImage]},
  fastai.torch_core.TensorCategory]}
@typedispatch
def show_batch(x: ImageTuple, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """``show_batch`` overload for ``ImageTuple`` inputs.

    Picks a default figure size from ``ncols`` and ``max_n``, creates the plot
    grid when no ``ctxs`` are supplied, and then delegates the drawing to the
    ``object`` implementation of ``show_batch``.
    """
    figsize = (ncols * 6, max_n // ncols * 3) if figsize is None else figsize
    if ctxs is None:
        n_shown = min(len(samples), max_n)
        ctxs = get_grid(n_shown, nrows=nrows, ncols=ncols, figsize=figsize)
    return show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
# Display a few sample pairs from the DataLoaders.
dls.show_batch()

png

class SiameseModel(Module):
    """Siamese network: one shared encoder applied to both inputs, followed by a head.

    NOTE(review): relies on fastai's ``Module`` base class, which presumably
    removes the need for an explicit ``super().__init__()`` - confirm.
    """

    def __init__(self, encoder, head):
        self.encoder = encoder
        self.head = head

    def forward(self, x1, x2):
        """Encode both inputs with the same encoder, concatenate along dim 1 and apply the head."""
        features1 = self.encoder(x1)
        features2 = self.encoder(x2)
        combined = torch.cat([features1, features2], dim=1)
        return self.head(combined)
# Build the model: a ResNet-34 body with ImageNet weights as the shared encoder.
# encoder = create_body(resnet34, cut=-2) # worked in old version of fastai/torchvision
encoder = create_body(resnet34(weights=ResNet34_Weights.IMAGENET1K_V1), cut=-2) # update for new torchvision
# The head receives 512 * 2 features because SiameseModel concatenates the two encodings
# (presumably 512 channels each - confirm) and has 2 outputs.
head = create_head(512 * 2, 2, ps=0.5)
model = SiameseModel(encoder, head)
def siamese_splitter(model):
    """Split the model's parameters into two groups: the encoder's, then the head's."""
    encoder_params = params(model.encoder)
    head_params = params(model.head)
    return [encoder_params, head_params]
def loss_func(out, targ):
    """Compute ``CrossEntropyLossFlat`` between ``out`` and ``targ`` after casting ``targ`` to long."""
    criterion = CrossEntropyLossFlat()
    return criterion(out, targ.long())
# Build the DataLoaders used for training from the SiameseTransform in siam_utils.
# This reassigns `splits` and `dls`, replacing the DataBlock-based ones created earlier.
splits = RandomSplitter()(files)
tfm = SiameseTransform(files, splits)
tls = TfmdLists(files, tfm, splits=splits)
dls = tls.dataloaders(
    after_item=[Resize(224), ToTensor], after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
)
# Sanity check: none of the partner files fixed for the validation items belongs to the training split.
valids = [v[0] for k, v in tfm.valid.items()]
assert not [v for v in valids if v in files[splits[0]]]

Create a Learner

# Create the Learner with cross-entropy loss, the two-group parameter splitter and accuracy as metric.
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), splitter=siamese_splitter, metrics=accuracy)
# Freeze the model; per fastai's docs this leaves only the last parameter group
# (the head, given siamese_splitter) trainable.
learn.freeze()

Train the Model

# Run the learning-rate finder to get a suggested learning rate.
learn.lr_find()
SuggestedLRs(valley=0.0030199517495930195)

png

# Train for 4 epochs with 3e-3, close to the value suggested by the learning-rate finder.
learn.fit_one_cycle(4, 3e-3)
epoch train_loss valid_loss accuracy time
0 0.543962 0.344742 0.841678 00:21
1 0.367816 0.221206 0.919486 00:21
2 0.298733 0.179088 0.935724 00:22
3 0.252596 0.172828 0.937754 00:22

See the Results

@typedispatch
def show_results(x: SiameseImage, y, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """``show_results`` overload for ``SiameseImage`` batches.

    Draws up to ``max_n`` pairs, each titled with the actual label taken from
    ``x[2]`` and the prediction taken from ``y[2]``.
    NOTE(review): this assumes fastai passes as ``y`` a copy of the batch whose
    third element holds the predictions - confirm against DataLoader.show_results.
    """
    if figsize is None:
        figsize = (ncols * 6, max_n // ncols * 3)
    if ctxs is None:
        # Forward the caller's nrows; it used to be hard-coded to None, which
        # silently ignored the argument.
        ctxs = get_grid(min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    for i, ctx in enumerate(ctxs):
        title = f'Actual: {["Not similar","Similar"][x[2][i].item()]} \n Prediction: {["Not similar","Similar"][y[2][i].item()]}'
        SiameseImage(x[0][i], x[1][i], title).show(ctx=ctx)
# Show actual vs. predicted similarity for a few pairs.
learn.show_results()

png

@patch
def siampredict(self: Learner, item, rm_type_tfms=None, with_input=False):
    """Predict on a pair, display it titled with the prediction, and return ``predict``'s result.

    Args:
        item: the pair to predict on; ``item[0]`` and ``item[1]`` are shown.
        rm_type_tfms: forwarded to ``Learner.predict``.
        with_input: forwarded to ``Learner.predict``.

    Returns:
        Whatever ``Learner.predict`` returns for the given arguments.
    """
    # Forward the caller's arguments; they used to be overridden with None/False.
    res = self.predict(item, rm_type_tfms=rm_type_tfms, with_input=with_input)
    # With with_input=True, predict prepends the decoded input to its result,
    # which moves the decoded prediction from index 0 to index 1.
    prediction = res[1] if with_input else res[0]
    if prediction == tensor(0):
        SiameseImage(item[0], item[1], "Prediction: Not similar").show()
    else:
        SiameseImage(item[0], item[1], "Prediction: Similar").show()
    return res
# Build an unlabelled pair from two dataset images and display it;
# without a third element SiameseImage.show titles it "Undetermined".
imgtest = PILImage.create(files[0])
imgval = PILImage.create(files[100])
siamtest = SiameseImage(imgval, imgtest)
siamtest.show();

png

# Predict on the pair; siampredict also displays it with the prediction as its title.
res = learn.siampredict(siamtest)

png

Save the Model

The last step is to save the model so you can use it later.

import dill
# Store the artifacts under the directory named by the MODELS environment variable.
# NOTE(review): os.getenv returns None when MODELS is unset, and Path(None) raises TypeError.
learn.path = Path(os.getenv("MODELS"))
# Export the Learner, using dill as the pickle module.
learn.export("siam_catsvdogs_tutorial.pkl", pickle_module=dill)
# Save the model as well; the output below shows the resulting .pth file.
learn.save("siam_catsvdogs_tutorial_save")
Path('/home/julius/models/models/siam_catsvdogs_tutorial_save.pth')