Saving and loading neural networks is always a little tricky. The best way to do it depends on what exactly you’re trying to do. Do you want to continue training the model? If so, you’ll need to save the optimizer state. If you just want to run it for inference, you might not need this. It also gets more complicated with custom functions. In this post I’ll walk through how to save a FastAI model and then load it again for inference.
Let’s train a simple model straight from the FastAI tutorial.
from fastai.vision.all import *
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))
Due to IPython and Windows limitation, python multiprocessing isn't available now. So `number_workers` is changed to 0 to avoid getting stuck
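The labelling rule is worth a quick sanity check on its own. The sketch below assumes the tutorial’s convention that cat breeds in the Pets dataset have capitalised file names and dog breeds do not; the two file names are hypothetical examples in that style, not files taken from the dataset.

```python
def label_func(f):
    # True means "cat": only the first character of the file name is checked.
    return f[0].isupper()

# Hypothetical file names in the style of the Pets dataset.
samples = ["Birman_12.jpg", "beagle_7.jpg"]
labels = {name: label_func(name) for name in samples}
print(labels)

# Checking the whole name instead would label every file False,
# because the rest of the name contains lowercase letters.
whole_name_check = {name: name.isupper() for name in samples}
print(whole_name_check)
```

Only the first character carries the label, which is why the function indexes into the name rather than testing the whole string.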
# if you're on Windows, you need to add this
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
That looks good. Now let’s save it.
There are two options for saving models in FastAI: learn.save and learn.export.
learn.save saves the model and, by default, also saves the optimizer state. This is what you want to do if you want to resume training.
If you used learn.save, you’ll need to use learn.load to load it. Because this restores weights into an existing learner rather than rebuilding one, you’ll need all of your functions from the previous notebook again.
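As a rough sketch of that save-then-resume workflow, the two helpers below wrap the calls. They assume `learn` is a FastAI Learner that has already been rebuilt with the same data loaders and architecture; the helper names and the checkpoint name `catsvdogs-stage1` are made up for illustration, and continuing with `fit_one_cycle` is just one possible choice.

```python
def save_checkpoint(learn, name="catsvdogs-stage1"):
    """Write weights plus optimizer state so training can resume later."""
    # with_opt=True is already the default; it is spelled out for clarity.
    return learn.save(name, with_opt=True)


def resume_training(learn, name="catsvdogs-stage1", epochs=1):
    """Restore a checkpoint into an already-rebuilt Learner and keep training."""
    # The Learner itself must exist first: load only fills in saved state.
    learn.load(name, with_opt=True)
    learn.fit_one_cycle(epochs)
    return learn
```

By default the checkpoint should end up as a .pth file in the learner’s model directory, so the name is passed without an extension.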
If you’re going to save it for inference, I recommend using .export(). You should also make sure the model is going to save in the right place by setting learn.path:
learn.path = Path(os.getenv('MODELS'))
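One thing to watch: if the MODELS environment variable isn’t set, os.getenv returns None and Path(None) raises a TypeError. A small helper with a fallback avoids that. This is only a sketch; the name `models_dir` and the default folder `models` are assumptions for illustration.

```python
import os
from pathlib import Path


def models_dir(env_var="MODELS", default="models"):
    """Return the export directory named by env_var, or a fallback path."""
    # os.getenv gives None (or an empty string) when the variable is unset
    # or blank; `or` swaps in the default in both cases.
    return Path(os.getenv(env_var) or default)


print(models_dir())
```

With a helper like this, the export step would be learn.path = models_dir() followed by learn.export('catsvdogs.pkl', pickle_module=dill).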
If you have custom functions as we do above with label_func, you’ll need to use the dill module to pickle your learner. If you use the default one, it will save the function names but not the content.
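You can see the “names but not the content” behaviour with the standard library alone. The sketch below fakes the training notebook with a throwaway module (the name `training_notebook` is made up), pickles label_func from it, and then tries to unpickle after that module is gone, which is roughly what happens in a fresh inference session.

```python
import pickle
import sys
import types

# Stand-in for the training notebook: a throwaway module holding label_func.
notebook = types.ModuleType("training_notebook")
exec("def label_func(f): return f[0].isupper()", notebook.__dict__)
sys.modules["training_notebook"] = notebook

payload = pickle.dumps(notebook.label_func)
name_stored = b"label_func" in payload  # the function's name is in the pickle
body_stored = b"isupper" in payload     # the function's code is not
print(name_stored, body_stored)

# Stand-in for a fresh session in which that module no longer exists.
del sys.modules["training_notebook"]
try:
    pickle.loads(payload)
    load_error = None
except ImportError as err:
    load_error = err
print(repr(load_error))
```

The standard pickle only records where to find the function, so loading fails once that place no longer exists. That’s the gap dill is used to fill here.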
Nothing in this section will use any of the code above. This could be done in a completely separate file or notebook.
import dill
from fastai.tabular.all import *
When you load a model, you can load it to the GPU or CPU. By default, it will load to the CPU. If you want it to load to the GPU, you’ll need to pass cpu=False:
learn = load_learner(Path(os.getenv('MODELS')) / 'catsvdogs.pkl', cpu=False, pickle_module=dill)
test = np.zeros((100, 100))
learn.predict(test)
('True', tensor(1), tensor([0.3902, 0.6098]))
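The result is a tuple of the decoded label, the class index, and the per-class probabilities. Here’s a small sketch of pulling those apart, using plain Python values as stand-ins for the tensors so it runs without FastAI. It assumes, as in the tutorial, that a True label means cat.

```python
# Plain-Python stand-in for ('True', tensor(1), tensor([0.3902, 0.6098])).
prediction = ("True", 1, [0.3902, 0.6098])

label, class_idx, probs = prediction

# label_func returned True for cats, so the string "True" means cat.
animal = "cat" if label == "True" else "dog"

# The class index points at the probability of the predicted class.
confidence = probs[class_idx]
summary = f"{animal} ({confidence:.1%})"
print(summary)
```

With real tensors you’d convert first, for example with int(class_idx) and probs.tolist(). The all-zeros test image is neither a cat nor a dog, so a probability this close to 50/50 is about what you’d hope for.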