Saving and loading neural networks is always a little tricky. The best way to do it depends on what exactly you’re trying to do. Do you want to continue training the model? If so, you’ll need to save the optimizer state. If you just want to run it for inference, you might not need this. It also gets more complicated with custom functions. In this post, I’ll walk through how to save a FastAI model and then load it again for inference.
Training
Let’s train a simple model straight from the FastAI tutorial.
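```python
from fastai.vision.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Oxford-IIIT Pets dataset, cat images have filenames starting with an
# uppercase letter, so this returns True for cats and False for dogs
def label_func(f):
    return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))
```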
On Windows, creating the `DataLoaders` prints this warning:

```
Due to IPython and Windows limitation, python multiprocessing isn't available now.
So `number_workers` is changed to 0 to avoid getting stuck
```
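To see what `label_func` does, here is a plain-Python sketch that applies it to a few filenames. The paths are illustrative examples in the Pets naming style (breed, underscore, number), not files read from disk. `from_name_func` passes only each file's name, not its full path, to `label_func`, so the sketch does the same.

```python
from pathlib import Path

def label_func(f):
    # True for cat breeds (capitalized filenames), False for dog breeds
    return f[0].isupper()

# Illustrative paths in the Oxford-IIIT Pets naming style
files = [
    Path("images/Abyssinian_1.jpg"),
    Path("images/Bengal_12.jpg"),
    Path("images/beagle_3.jpg"),
    Path("images/pug_40.jpg"),
]

# Apply label_func to each file's name, as from_name_func does
labels = {p.name: label_func(p.name) for p in files}
for name, is_cat in labels.items():
    print(f"{name:20} -> {is_cat}")
```

Passing the full path instead would label every file by the first letter of its directory, which is why only the name is used.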
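```python
# if you're on Windows, you need to add this
# (lets PIL open truncated image files instead of raising an error)
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# in fastai >= 2.7, cnn_learner is deprecated in favor of vision_learner
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
```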
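`fine_tune(1)` first trains for one epoch with the pretrained body frozen, then for one epoch with the whole network unfrozen, so it prints two results tables:

| epoch | train_loss | valid_loss | error_rate | time |
|---|---|---|---|---|
| 0 | 0.818718 | 1.532427 | 0.527273 | 00:08 |

| epoch | train_loss | valid_loss | error_rate | time |
|---|---|---|---|---|
| 0 | 0.300550 | 0.545057 | 0.254545 | 00:05 |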
```python
learn.show_results()
```
That looks good. Now let’s save it.
Saving
learn.save
There are two options for saving models in FastAI: `learn.save` and `learn.export`. `learn.save` saves the model's weights and, by default, also saves the optimizer state (`with_opt=True`). It writes a `.pth` file to the `models` folder under `learn.path`. This is what you want if you plan to resume training.
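If you used `learn.save`, you'll need to use `learn.load` to load it. Because the file holds only the weights (and optimizer state), not the model definition or data pipeline, `learn.load` loads them into an existing `Learner`. That means recreating the `Learner` first, so you'll need all of your functions from the previous notebook again, including `label_func`.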
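`learn.save` and `learn.load` resolve a checkpoint name the same way: relative to `learn.path`, inside `learn.model_dir` (`"models"` by default), with `.pth` appended. The helper below is a hypothetical, stdlib-only sketch of that layout for checking where a checkpoint will land; it is not part of fastai, and the name `stage-1` and the `data/pets` folder are made up for illustration.

```python
from pathlib import Path

def checkpoint_path(learn_path, name, model_dir="models"):
    """Hypothetical helper: the file learn.save(name) writes and learn.load(name) reads,
    assuming fastai's default layout of learn.path / model_dir / f"{name}.pth"."""
    return Path(learn_path) / model_dir / f"{name}.pth"

# e.g. for learn.save('stage-1') on a Learner whose path is data/pets
print(checkpoint_path(Path("data") / "pets", "stage-1"))  # data/pets/models/stage-1.pth on Linux/macOS
```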
learn.export
If you're saving the model for inference, I recommend using `learn.export()`. It writes the whole `Learner` to a single pickle file (`export.pkl` by default) directly in `learn.path`, so set `learn.path` first to make sure the file ends up in the right place:
```python
# MODELS is an environment variable holding the directory to export into
learn.path = Path(os.getenv('MODELS'))
```
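`os.getenv` returns `None` when the variable is unset, and `Path(None)` then fails with a `TypeError` that doesn't mention the missing variable. A small helper can make that failure explicit. This is a hypothetical sketch; `MODELS` is simply the variable name this post uses.

```python
import os
from pathlib import Path

def models_dir(var="MODELS"):
    """Hypothetical helper: read the model directory from an environment variable,
    failing with a clear message if it isn't set (or is empty)."""
    value = os.getenv(var)
    if not value:
        raise RuntimeError(f"Set the {var} environment variable to your models directory")
    return Path(value)

# usage in the notebook: learn.path = models_dir()
```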
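If you have custom functions, as we do above with `label_func`, you'll need to use the `dill` module to pickle your learner. The default `pickle` module stores a function by reference (its module and name) rather than its code, so loading fails unless the same function is defined in the loading script.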
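The difference is easy to check with the standard library alone. This sketch uses plain `pickle` (no fastai or dill): it pickles a module-level `label_func`, shows that the payload contains the function's name but not its body, then simulates loading in a script where `label_func` was never defined. It assumes the code runs as a top-level script.

```python
import pickle
import sys

def label_func(f):
    return f[0].isupper()

# pickle records a module-level function by reference: its module and its name
payload = pickle.dumps(label_func)
print(b"label_func" in payload)  # True: the name is stored
print(b"isupper" in payload)     # False: the function body is not

# Simulate a fresh inference script in which label_func was never defined
main = sys.modules[label_func.__module__]
original = main.label_func
del main.label_func
try:
    pickle.loads(payload)
    load_error = None
except AttributeError as err:
    load_error = err
finally:
    main.label_func = original  # put it back

print("without label_func:", repr(load_error))
print("with label_func:", pickle.loads(payload) is label_func)  # True
```

So with the default `pickle`, loading only works if `label_func` is defined (or importable) in the inference script; `dill` sidesteps this by serializing functions defined in the notebook itself by value.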
```python
import dill

# dill can serialize functions defined in the notebook itself, such as label_func
learn.export('catsvdogs.pkl', pickle_module=dill)
```
Loading
Nothing in this section will use any of the code above. This could be done in a completely separate file or notebook.
```python
import dill
from fastai.vision.all import *  # vision, to match the model we exported
```
When you load a model, you can load it onto the CPU or the GPU. `load_learner` loads to the CPU by default; to load onto the GPU, pass `cpu=False`.
```python
# use the same pickle module that was used for export
learn = load_learner(Path(os.getenv('MODELS')) / 'catsvdogs.pkl', cpu=False, pickle_module=dill)
```
```python
# a dummy all-black 100x100 image, just to check that prediction runs
test = np.zeros((100, 100))
learn.predict(test)
```
```
('True', tensor(1), tensor([0.3902, 0.6098]))
```
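`predict` returns the decoded label, the index of the predicted class, and the probability of each class. The plain-Python sketch below shows how the first two follow from the probabilities in the output above, assuming the class vocabulary is `['False', 'True']` (inferred from the output; in fastai the real list is `learn.dls.vocab`). Since `True` means cat here, the blank test image is scored as a cat with about 61% probability, which only confirms that the pipeline runs.

```python
# Values copied from the predict output above
probs = [0.3902, 0.6098]
vocab = ["False", "True"]  # assumption: stands in for learn.dls.vocab

# The predicted index is the position of the largest probability
pred_idx = max(range(len(probs)), key=probs.__getitem__)
pred_label = vocab[pred_idx]

print(pred_label, pred_idx, probs[pred_idx])  # True 1 0.6098
```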