This post walks through how to create a Siamese network using FastAI. It is based on a tutorial from FastAI, so some of the classes and functions are copied directly from there, but I’ve added things as well.


First we import the FastAI library.

from fastai.vision.all import *

Then make sure GPU support is available.

torch.cuda.is_available()
True

Next we download the images. We’ll use the Oxford-IIIT Pet dataset (cats and dogs) because it’s built into FastAI.

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

Now we’ll need a class to hold a pair of images. Subclassing fastuple means FastAI transforms are applied to each element of the tuple, which makes everything downstream easier.

class ImageTuple(fastuple):
    @classmethod
    def create(cls, file_paths):
        """
        creates a tuple of two file paths
        """
        return cls(tuple(PILImage.create(f) for f in file_paths))
    
    def show(self, ctx=None, **kwargs): 
        t1, t2 = self
        if not isinstance(t1, Tensor) or not isinstance(t2, Tensor) or t1.shape != t2.shape:
            return ctx
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([t1,line,t2], dim=2), ctx=ctx, **kwargs)
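The show method simply draws the two images side by side with a 10-pixel black strip between them. A torch-free sketch of that concatenation, using toy 2-D lists of numbers in place of (C, H, W) tensors:

```python
# toy 2-D "images" as lists of rows; a real TensorImage is (C, H, W)
img1 = [[1, 1], [1, 1]]
img2 = [[2, 2], [2, 2]]
gap = 1  # width of the separator strip (10 pixels in the class above)

# concatenate along the width dimension, inserting zeros as the divider
combined = [row1 + [0] * gap + row2 for row1, row2 in zip(img1, img2)]
print(combined)  # [[1, 1, 0, 2, 2], [1, 1, 0, 2, 2]]
```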

Let’s look at an image.

PILImage.create(files[0])


Now let’s look at a tuple.

img = ImageTuple.create((files[0], files[1]))
tst = ToTensor()(img)
type(tst[0]),type(tst[1])
(fastai.torch_core.TensorImage, fastai.torch_core.TensorImage)

FastAI has a ToTensor transform that converts items into tensors. After resizing both images to the same size, we can use it to visualize the tuple.

img1 = Resize(224)(img)
tst = ToTensor()(img1)
tst.show();


Now we create an image tuple block. It’s really simple: it just returns a TransformBlock to which we pass the function that creates a Siamese image tuple.

def ImageTupleBlock():
    return TransformBlock(type_tfms=ImageTuple.create, batch_tfms=IntToFloatTensor)

We have to split the data. We can do a random split.

splits = RandomSplitter()(files)
splits_files = [files[splits[i]] for i in range(2)]
splits_sets = mapped(set, splits_files)
def get_split(f):
    for i,s in enumerate(splits_sets):
        if f in s:
            return i
    raise ValueError(f'File {f} is not present in any split.')
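To see what get_split does, here is the same lookup on a toy pair of splits (the file names here are made up):

```python
# toy stand-in for splits_sets: two disjoint sets of file names
toy_splits_sets = [{"a.jpg", "b.jpg"}, {"c.jpg"}]

def toy_get_split(f):
    # return the index of the split that contains f
    for i, s in enumerate(toy_splits_sets):
        if f in s:
            return i
    raise ValueError(f"File {f} is not present in any split.")

print(toy_get_split("a.jpg"), toy_get_split("c.jpg"))  # 0 1
```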

Now we need a function to return the label given an item. Fortunately, the label is in the filename, so we can use regular expressions to extract it.

def label_func(fname: Path):
    """
    Extract the label from the file name.
    """
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]

label_func(files[0])
'havanese'
labels = list(set(files.map(label_func)))
splbl2files = [{l: [f for f in s if label_func(f) == l] for l in labels} for s in splits_sets]
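Because `.*` is greedy, the regex captures everything up to the final `_<digits>.jpg`, so multi-word breed names survive intact. A quick check of the same pattern on made-up filenames following the Pets naming scheme:

```python
import re

def extract_label(name: str) -> str:
    # Pets filenames look like "<breed>_<index>.jpg"; the greedy group
    # grabs everything before the last "_<digits>.jpg"
    return re.match(r'^(.*)_\d+\.jpg$', name).group(1)

print(extract_label("havanese_12.jpg"))     # havanese
print(extract_label("Maine_Coon_214.jpg"))  # Maine_Coon
```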

def splitter(items):
    """
    This is the function that we actually pass to the DataBlock. All the others are helpers of some form.
    """
    def get_split_files(i):
        return [j for j,(f1,f2,same) in enumerate(items) if get_split(f1)==i]
    return get_split_files(0),get_split_files(1)
def draw_other(f):
    """
    Find the pair for the other Siamese image.
    """
    same = random.random() < 0.5
    cls = label_func(f)
    split = get_split(f)
    if not same:
        cls = random.choice(L(l for l in labels if l != cls)) 
    return random.choice(splbl2files[split][cls]),same
def get_tuples(files):
    """
    This function turns a list of files into a list of inputs and label combos.
    So each element in this list contains paths of two images and the associated label
    """
    return [[f, *draw_other(f)] for f in files]

def get_x(t):
    return t[:2]
def get_y(t):
    return t[2]
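draw_other is the heart of the pairing logic: flip a coin, keep the same class on heads, pick a different class on tails, then sample a file from the chosen class. A self-contained sketch with made-up labels and file names:

```python
import random

labels = ["cat", "dog", "bird"]
files_by_label = {
    "cat": ["cat_1.jpg", "cat_2.jpg"],
    "dog": ["dog_1.jpg", "dog_2.jpg"],
    "bird": ["bird_1.jpg"],
}

def toy_draw_other(label):
    # with probability 0.5 draw from the same class, otherwise from another
    same = random.random() < 0.5
    cls = label if same else random.choice([l for l in labels if l != label])
    return random.choice(files_by_label[cls]), same

random.seed(0)
pairs = [toy_draw_other("cat") for _ in range(100)]
# every "same" pair really is a cat, every "different" pair is not
print(all(f.startswith("cat") == same for f, same in pairs))  # True
```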

Now we can build the DataBlock.

siamese = DataBlock(
    blocks=(ImageTupleBlock, CategoryBlock),
    get_items=get_tuples,
    get_x=get_x, get_y=get_y,
    item_tfms=Resize(224),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)]
)
# siamese.summary(files) # long output

Create Dataloaders

Everything looks good. Now we can create our Dataloaders.

dls = siamese.dataloaders(files)

Let’s make sure it looks good.

b = dls.one_batch()
explode_types(b)
{tuple: [{__main__.ImageTuple: [fastai.torch_core.TensorImage,
    fastai.torch_core.TensorImage]},
  fastai.torch_core.TensorCategory]}

Now we can make a function to show a batch.

@typedispatch
def show_batch(x:ImageTuple, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    if figsize is None:
        figsize = (ncols*6, max_n//ncols * 3)
    if ctxs is None:
        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    return ctxs
dls.show_batch()


One of my favorite things to do with data in FastAI is turn it into a pandas DataFrame. I find it so much easier to clean or modify the data in that format. In this case, we’ll balance the dataset using a DataFrame.

Now let’s turn it into a DataFrame.

def create_dataframe(data, is_valid=False) -> pd.DataFrame:
    """
    Create a pandas DataFrame from the DataLoaders' items.
    """
    split = data.valid if is_valid else data.train
    items = [x[:2] for x in split.items]
    labels = [x[2] for x in split.items]
    return pd.DataFrame({'items': items, 'label': labels, 'is_valid': [is_valid] * len(items)})
train_df = create_dataframe(dls, is_valid=False)
valid_df = create_dataframe(dls, is_valid=True)

Let’s see what we have.

train_df.head()
items label is_valid
0 [/home/julius/.fastai/data/oxford-iiit-pet/images/Ragdoll_117.jpg, /home/julius/.fastai/data/oxford-iiit-pet/images/leonberger_79.jpg] False False
1 [/home/julius/.fastai/data/oxford-iiit-pet/images/Maine_Coon_214.jpg, /home/julius/.fastai/data/oxford-iiit-pet/images/samoyed_40.jpg] False False
2 [/home/julius/.fastai/data/oxford-iiit-pet/images/havanese_55.jpg, /home/julius/.fastai/data/oxford-iiit-pet/images/havanese_30.jpg] True False
3 [/home/julius/.fastai/data/oxford-iiit-pet/images/english_cocker_spaniel_37.jpg, /home/julius/.fastai/data/oxford-iiit-pet/images/english_cocker_spaniel_90.jpg] True False
4 [/home/julius/.fastai/data/oxford-iiit-pet/images/chihuahua_165.jpg, /home/julius/.fastai/data/oxford-iiit-pet/images/english_setter_91.jpg] False False

Now that we have it as a DataFrame, it’s easier to clean or manipulate the data.

Now let’s oversample the data.

train_df['label'].value_counts()
True     2971
False    2941
Name: label, dtype: int64
def oversample(df: pd.DataFrame) -> pd.DataFrame:
    """
    We will use pd.DataFrame.sample to oversample from the rows with labels that need to be oversampled.
    """
    num_max_labels = df['label'].value_counts().max()
    dfs = [df]
    for class_index, group in df.groupby('label'):
        num_new_samples = num_max_labels - len(group)
        dfs.append(group.sample(num_new_samples, replace=True, random_state=42))
    return pd.concat(dfs)
oversampled_train_df = oversample(train_df)
oversampled_train_df['label'].value_counts()
False    2971
True     2971
Name: label, dtype: int64
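The same oversampling logic works without pandas. A stdlib-only sketch on a toy imbalanced list, just to make the mechanics explicit (the data here is made up):

```python
import random
from collections import Counter

random.seed(42)
rows = [("img", True)] * 6 + [("img", False)] * 4  # toy imbalanced labels
counts = Counter(label for _, label in rows)
target = max(counts.values())

extra = []
for label, n in counts.items():
    group = [r for r in rows if r[1] == label]
    # sample with replacement until this class reaches the majority count
    extra += random.choices(group, k=target - n)

balanced = rows + extra
print(Counter(label for _, label in balanced))  # both classes now at 6
```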

Great. Now we’ll combine it with the validation data again. Note that we didn’t mess with the validation dataset’s label balance.

full_df = pd.concat([oversampled_train_df, valid_df])

Now we make another DataBlock that reads from the DataFrame. Fortunately, FastAI’s support for this is great.

new_dblock = DataBlock(
    blocks    = (ImageTupleBlock, CategoryBlock),
    get_x=ColReader(0),
    get_y=ColReader('label'),
    splitter  = ColSplitter(),
    item_tfms = Resize(224),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)]
    )
new_data = new_dblock.dataloaders(full_df)

Modeling

Now we build a simple Siamese model: a shared encoder runs over both images, the resulting features are concatenated, and a head classifies the pair as same or different.

class SiameseModel(Module):
    def __init__(self, encoder, head):
        self.encoder, self.head = encoder, head

    def forward(self, x):
        """
        x is an ImageTuple; run both images through the shared encoder
        and classify the concatenated features.
        """
        x1, x2 = x
        ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)
        return self.head(ftrs)
model_meta[resnet34]
{'cut': -2,
 'split': <function fastai.vision.learner._resnet_split(m)>,
 'stats': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}
encoder = create_body(resnet34, cut=-2)
# the resnet34 body ends with 512 channels; the head sees both images' features
head = create_head(512*2, 2, ps=0.5)
model = SiameseModel(encoder, head)
def siamese_splitter(model):
    return [params(model.encoder), params(model.head)]

def loss_func(out, targ):
    # CrossEntropyLossFlat expects integer targets
    return CrossEntropyLossFlat()(out, targ.long())
learn = Learner(new_data, model, loss_func=loss_func, splitter=siamese_splitter, metrics=accuracy)
learn.freeze()
learn.lr_find()
SuggestedLRs(valley=0.0006918309954926372)


learn.fit_one_cycle(4, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.616689 0.371314 0.854533 00:24
1 0.372786 0.235297 0.917456 00:24
2 0.246533 0.197698 0.930988 00:24
3 0.185981 0.194794 0.928281 00:24

There you go! A fully trained Siamese network.