A model is a collection of artifacts that is created by the training process. In deep learning, running inference on a Model usually involves pre-processing and post-processing. DJL provides a ZooModel class, which makes it easy to combine data processing with the model.
This document will show you how to load a pre-trained model in various scenarios.
We recommend you use the ModelZoo API to load models.
The ModelZoo API provides a unified way to load models. The declarative nature of this API allows you to store model information inside a configuration file. This gives you great flexibility to test and deploy your model. See our reference project: DJL Spring Boot Starter.
You can use the Criteria class to narrow down your search conditions and locate the model you want to load. The Criteria class follows the DJL Builder convention: methods that start with set are for required fields, and methods that start with opt are for optional fields.
You must call the setTypes() method when creating a Criteria object:
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.build();
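To make the set/opt convention concrete, here is a minimal plain-Java sketch of a builder that follows it. ToyCriteria is a hypothetical class invented for illustration, not DJL's Criteria: its setTypes() is required and build() rejects a builder that skipped it, while optModelName() may be left out. The real Criteria builder has many more options and may validate differently.

```java
import java.util.Optional;

// Hypothetical toy class, NOT DJL's Criteria: it only illustrates the set/opt convention.
final class ToyCriteria {

    private final Class<?> inputClass;
    private final Class<?> outputClass;
    private final String modelName; // optional; null when optModelName() was not called

    private ToyCriteria(Builder builder) {
        this.inputClass = builder.inputClass;
        this.outputClass = builder.outputClass;
        this.modelName = builder.modelName;
    }

    static Builder builder() {
        return new Builder();
    }

    Class<?> getInputClass() {
        return inputClass;
    }

    Class<?> getOutputClass() {
        return outputClass;
    }

    Optional<String> getModelName() {
        return Optional.ofNullable(modelName);
    }

    static final class Builder {
        private Class<?> inputClass;
        private Class<?> outputClass;
        private String modelName;

        // "set" prefix: required field; build() rejects a builder that never called it.
        Builder setTypes(Class<?> input, Class<?> output) {
            this.inputClass = input;
            this.outputClass = output;
            return this;
        }

        // "opt" prefix: optional field; omitting it leaves the value empty.
        Builder optModelName(String name) {
            this.modelName = name;
            return this;
        }

        ToyCriteria build() {
            if (inputClass == null || outputClass == null) {
                throw new IllegalStateException("setTypes() must be called before build()");
            }
            return new ToyCriteria(this);
        }
    }

    public static void main(String[] args) {
        ToyCriteria criteria = ToyCriteria.builder()
                .setTypes(String.class, Integer.class) // required
                .optModelName("resnet50")             // optional
                .build();
        System.out.println(criteria.getModelName().orElse("<default>")); // prints "resnet50"
    }
}
```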
The criteria accept the following optional information:
Note: If multiple models match the criteria you specified, the first one will be returned. The result is not deterministic.
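The following sketch illustrates why "first match wins" is not deterministic. It is a hypothetical model of the search, not DJL's implementation: FirstMatchSketch and ModelEntry are invented names, and the two entries reuse the action_recognition models (vgg16 and inceptionv3 backbones) from the DJL model zoo listing. When the criteria constrain only the application, both entries match, so whichever comes first in the candidate list is returned.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical sketch, NOT DJL's search implementation.
final class FirstMatchSketch {

    // Simplified model entry: application, artifact id, and filter properties.
    static final class ModelEntry {
        final String application;
        final String artifactId;
        final Map<String, String> properties;

        ModelEntry(String application, String artifactId, Map<String, String> properties) {
            this.application = application;
            this.artifactId = artifactId;
            this.properties = properties;
        }
    }

    // Returns the first candidate whose application matches and whose properties
    // contain every requested filter key with the same value.
    static Optional<ModelEntry> firstMatch(
            List<ModelEntry> candidates, String application, Map<String, String> filters) {
        return candidates.stream()
                .filter(e -> e.application.equals(application))
                .filter(e -> filters.entrySet().stream()
                        .allMatch(f -> f.getValue().equals(e.properties.get(f.getKey()))))
                .findFirst();
    }

    public static void main(String[] args) {
        ModelEntry vgg = new ModelEntry("CV.ACTION_RECOGNITION", "ai.djl.mxnet:action_recognition",
                Map.of("backbone", "vgg16", "dataset", "ucf101"));
        ModelEntry inception = new ModelEntry("CV.ACTION_RECOGNITION", "ai.djl.mxnet:action_recognition",
                Map.of("backbone", "inceptionv3", "dataset", "ucf101"));

        // Only the application is constrained, so both entries match and list order decides.
        System.out.println(firstMatch(List.of(vgg, inception), "CV.ACTION_RECOGNITION", Map.of())
                .get().properties.get("backbone")); // prints "vgg16"
        System.out.println(firstMatch(List.of(inception, vgg), "CV.ACTION_RECOGNITION", Map.of())
                .get().properties.get("backbone")); // prints "inceptionv3"

        // A more specific filter removes the ambiguity regardless of order.
        System.out.println(firstMatch(List.of(vgg, inception), "CV.ACTION_RECOGNITION",
                Map.of("backbone", "inceptionv3")).get().properties.get("backbone")); // prints "inceptionv3"
    }
}
```

The practical takeaway is the last call: narrowing the criteria until only one model can match makes the result independent of discovery order.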
The advantage of using the ModelZoo repository is that it provides a way to manage model versions. DJL allows you to update your model in the repository without conflicting with existing models, and model consumers can pick up new models without any code changes. DJL searches the classpath and locates the available ModelZoos in the system.
DJL provides several built-in ModelZoos:
You can create your own model zoo if needed, but we are still working on improving the tools to help create custom model zoo repositories.
The following shows how to load a pre-trained model from a file path:
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class) // defines input and output data type
.optTranslator(ImageClassificationTranslator.builder().setSynsetArtifactName("synset.txt").build())
.optModelPath(Paths.get("/var/models/my_resnet50")) // search models in specified path
.optModelName("model/resnet50") // specify model file prefix
.build();
ZooModel<Image, Classifications> model = criteria.loadModel();
DJL supports loading a pre-trained model from a local directory or an archive file.
By default, DJL uses the directory or file name of the URL as the model's modelName. DJL uses modelName as the prefix of the model file to load from the directory. We recommend giving the model file the same name as the directory or archive file.
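A minimal sketch of the default naming rule above, written in plain Java. DefaultModelName is a hypothetical helper, and it assumes that archive extensions such as .zip, .tar, .tgz, and .tar.gz are stripped from the file name; check DJL's behavior for your version before relying on that exact list.

```java
import java.net.URI;
import java.util.List;

// Hypothetical helper, NOT DJL code: derives a default model name from a model URL.
final class DefaultModelName {

    // ASSUMPTION: these archive suffixes are removed when deriving the name.
    private static final List<String> ARCHIVE_SUFFIXES = List.of(".tar.gz", ".tgz", ".tar", ".zip");

    static String fromUrl(String url) {
        // getPath() drops the scheme, host, and any query string such as ?model_name=...
        String path = URI.create(url).getPath();
        // Ignore a trailing slash on directory URLs.
        if (path.endsWith("/")) {
            path = path.substring(0, path.length() - 1);
        }
        // Keep only the last path segment (the directory or archive file name).
        String name = path.substring(path.lastIndexOf('/') + 1);
        for (String suffix : ARCHIVE_SUFFIXES) {
            if (name.endsWith(suffix)) {
                return name.substring(0, name.length() - suffix.length());
            }
        }
        return name;
    }

    public static void main(String[] args) {
        System.out.println(fromUrl("file:///var/models/my_resnet50")); // prints "my_resnet50"
        System.out.println(fromUrl("https://resources.djl.ai/benchmark/squeezenet_v1.1.tar.gz")); // prints "squeezenet_v1.1"
    }
}
```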
If your model file is located in a subfolder of the model directory or has a different name, you can specify modelName with .optModelName() in the criteria:
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class) // defines input and output data type
.optModelName("traced_model/resnet18.pt") // specify model file prefix
.build();
You can also use the URL query string to tell DJL how to load the model:
file:///var/models/resnet.zip?model_name=saved_model/resnet-18
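In such a URL, model_name travels in the query part, separate from the path of the archive. The plain-Java sketch below shows how the two pieces can be told apart with java.net.URI; ModelUrlQuery is a hypothetical helper, not DJL's own URL parser.

```java
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, NOT DJL code: splits a model URL's query string into key/value pairs.
final class ModelUrlQuery {

    static Map<String, String> parseQuery(String url) {
        Map<String, String> params = new HashMap<>();
        String query = URI.create(url).getRawQuery(); // null when the URL has no "?..."
        if (query == null || query.isEmpty()) {
            return params;
        }
        for (String pair : query.split("&")) {
            int eq = pair.indexOf('=');
            String key = eq < 0 ? pair : pair.substring(0, eq);
            String value = eq < 0 ? "" : pair.substring(eq + 1);
            // Decode percent-escapes such as %2F (requires Java 10+ for the Charset overload).
            params.put(URLDecoder.decode(key, StandardCharsets.UTF_8),
                    URLDecoder.decode(value, StandardCharsets.UTF_8));
        }
        return params;
    }

    public static void main(String[] args) {
        String url = "file:///var/models/resnet.zip?model_name=saved_model/resnet-18";
        System.out.println(URI.create(url).getPath());            // prints "/var/models/resnet.zip"
        System.out.println(parseQuery(url).get("model_name"));    // prints "saved_model/resnet-18"
    }
}
```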
DJL supports loading a model from a URL. Since a model consists of multiple files, some URL types must point to an archive file rather than a directory.
Currently supported URL schemes:
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class) // defines input and output data type
.optTranslator(ImageClassificationTranslator.builder().setSynsetArtifactName("synset.txt").build())
.optModelUrls("https://resources.djl.ai/benchmark/squeezenet_v1.1.tar.gz") // load the model from the specified URL
.build();
ZooModel<Image, Classifications> model = criteria.loadModel();
You can customize the artifactId and modelName the same way as when loading a model from the local file system.
DJL supports loading a model from an S3 bucket using an s3:// URL and the AWS plugin. See the AWS plugin documentation for details.
DJL supports loading a model from a Hadoop HDFS file system using an hdfs:// URL and the Hadoop plugin. See the Hadoop plugin documentation for details.
You may want to create additional model zoos using other protocols such as:
DJL is highly extensible, and its API allows you to add your own URL protocol handling by implementing the Repository interface:
Create a class that implements the RepositoryFactory interface, and make sure its getSupportedScheme() method returns the URI schemes that you want to handle.
Create a class that implements the Repository interface.
Register your RepositoryFactory by adding a file named META-INF/services/ai.djl.repository.RepositoryFactory that contains the fully qualified class name of your implementation. See the Java ServiceLoader documentation for more details.
You can refer to the AWS S3 Repository for an example.
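The sketch below mimics the dispatch idea in plain Java: each factory reports the schemes it handles, and the scheme of a URL selects the factory. Repo, RepoFactory, and FtpRepoFactory are hypothetical stand-ins, not DJL's Repository and RepositoryFactory interfaces, and the ftp scheme is only an example. Factories are registered by hand here so the code runs without a META-INF/services file.

```java
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical, simplified analogue of the factory pattern; these are NOT DJL's interfaces.
final class SchemeDispatchSketch {

    // Stand-in for a repository implementation.
    interface Repo {
        String describe();
    }

    // Stand-in for a factory that advertises the URI schemes it handles.
    interface RepoFactory {
        Repo newRepo(String name, URI uri);

        Set<String> getSupportedScheme();
    }

    // Hypothetical factory for an "ftp" scheme, used only for illustration.
    static final class FtpRepoFactory implements RepoFactory {
        @Override
        public Repo newRepo(String name, URI uri) {
            return () -> name + " from ftp host " + uri.getHost();
        }

        @Override
        public Set<String> getSupportedScheme() {
            return Set.of("ftp");
        }
    }

    private final Map<String, RepoFactory> factoriesByScheme = new HashMap<>();

    // DJL discovers factories through ServiceLoader; this sketch registers them by hand
    // so that it runs without a META-INF/services file.
    void register(RepoFactory factory) {
        for (String scheme : factory.getSupportedScheme()) {
            factoriesByScheme.put(scheme, factory);
        }
    }

    // Picks the factory whose supported schemes include the URL's scheme.
    Repo open(String name, String url) {
        URI uri = URI.create(url);
        RepoFactory factory = factoriesByScheme.get(uri.getScheme());
        if (factory == null) {
            throw new IllegalArgumentException("No factory registered for scheme: " + uri.getScheme());
        }
        return factory.newRepo(name, uri);
    }

    public static void main(String[] args) {
        SchemeDispatchSketch dispatcher = new SchemeDispatchSketch();
        dispatcher.register(new FtpRepoFactory());
        // prints "my_models from ftp host example.com"
        System.out.println(dispatcher.open("my_models", "ftp://example.com/models/resnet.zip").describe());
    }
}
```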
DJL provides a way for developers to configure a system-wide model search path by setting the ai.djl.repository.zoo.location system property:
-Dai.djl.repository.zoo.location=https://djl-ai.s3.amazonaws.com/resnet.zip,s3://djl-misc/test/models,file:///myModels
The value can be a comma-delimited list of URLs.
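A minimal sketch of splitting such a value into individual URIs, assuming entries are separated by commas with optional surrounding whitespace. ZooLocationSketch is a hypothetical helper for illustration; DJL's own parsing may differ in details.

```java
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, NOT DJL code: splits a comma-delimited location list into URIs.
final class ZooLocationSketch {

    static List<URI> parseLocations(String value) {
        List<URI> uris = new ArrayList<>();
        if (value == null) {
            return uris;
        }
        for (String part : value.split(",")) {
            String trimmed = part.trim(); // ASSUMPTION: whitespace around entries is ignored
            if (!trimmed.isEmpty()) {
                uris.add(URI.create(trimmed));
            }
        }
        return uris;
    }

    public static void main(String[] args) {
        // Reads the system property, falling back to the example value from this document.
        String value = System.getProperty("ai.djl.repository.zoo.location",
                "https://djl-ai.s3.amazonaws.com/resnet.zip,s3://djl-misc/test/models,file:///myModels");
        for (URI uri : parseLocations(value)) {
            System.out.println(uri.getScheme() + " -> " + uri);
        }
    }
}
```

Printing the scheme of each entry is a quick way to confirm that every location uses a scheme your classpath can handle (for example, s3:// needs the AWS plugin).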
You may run into a ModelNotFoundException. In most cases, it is caused by Criteria that do not match the desired model.
Here are a few tips to help you debug model loading issues:
Enable the debug log (see the DJL logging documentation for how to do this).
Use the ModelZoo.listModels() API to query the available models.
Use the following command to list models in the DJL model zoo:
./gradlew :examples:listmodels
[INFO ] - CV.ACTION_RECOGNITION ai.djl.mxnet:action_recognition:0.0.1 {"backbone":"vgg16","dataset":"ucf101"}
[INFO ] - CV.ACTION_RECOGNITION ai.djl.mxnet:action_recognition:0.0.1 {"backbone":"inceptionv3","dataset":"ucf101"}
[INFO ] - CV.IMAGE_CLASSIFICATION ai.djl.zoo:resnet:0.0.1 {"layers":"50","flavor":"v1","dataset":"cifar10"}
[INFO ] - CV.IMAGE_CLASSIFICATION ai.djl.zoo:mlp:0.0.2 {"dataset":"mnist"}
[INFO ] - NLP.QUESTION_ANSWER ai.djl.mxnet:bertqa:0.0.1 {"backbone":"bert","dataset":"book_corpus_wiki_en_uncased"}
...
You can list the models in your own model folder with debug logging enabled:
./gradlew :examples:listmodels -Dai.djl.logging.level=debug -Dai.djl.repository.zoo.location=file:///mymodels