In this example, you learn how to implement inference code with a PyTorch model to extract and compare face features.
Extract face features:
The source code can be found at FeatureExtraction.java.
The model's GitHub repository can be found at facenet-pytorch.
Compare face features:
The source code can be found at FeatureComparison.java.
To configure your development environment, follow the setup instructions.
You can find the images used in this example in the project test resource folder:
src/test/resources/kana1.jpg
src/test/resources/kana2.jpg
Use the following command to run the feature extraction example:
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.face.FeatureExtraction
Your output should look like the following:
[INFO ] - [-0.04026184, -0.019486362, -0.09802659, 0.01700999, 0.037829027, ...]
Use the following command to run the feature comparison example:
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.face.FeatureComparison
Your output should look like the following:
[INFO ] - 0.9022607
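The single number printed by the comparison step is a score for how close the two face feature vectors are. As a rough illustration of how such a score can be computed, the sketch below uses cosine similarity rescaled to the range [0, 1]. Whether FeatureComparison.java uses exactly this metric is an assumption, and the function names and sample vectors here are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Return the cosine of the angle between two equal-length feature vectors."""
    if len(a) != len(b):
        raise ValueError("feature vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("feature vectors must be non-zero")
    return dot / (norm_a * norm_b)


def similarity_score(a, b):
    """Map cosine similarity from [-1, 1] to [0, 1] (this rescaling is an assumption)."""
    return (cosine_similarity(a, b) + 1.0) / 2.0


# Made-up 4-dimensional vectors for illustration; real face embeddings are much longer.
feature_a = [0.1, -0.2, 0.3, 0.4]
feature_b = [0.1, -0.1, 0.35, 0.45]
print(similarity_score(feature_a, feature_b))
```

With a score like this, values near 1 suggest the two images show the same face, and a threshold for "same person" would need to be chosen empirically.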
The TorchScript model can be exported from facenet-pytorch. First install the package:
# With pip:
pip install facenet-pytorch
# or clone this repo, removing the '-' to allow python imports:
git clone https://github.com/timesler/facenet-pytorch.git facenet_pytorch
Trace the pretrained model and save it as face_feature.pt:
from facenet_pytorch import InceptionResnetV1
import torch
from PIL import Image
import ssl
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    # Legacy Python that doesn't verify HTTPS certificates by default
    pass
else:
    # Handle target environment that doesn't support HTTPS verification
    ssl._create_default_https_context = _create_unverified_https_context
# Create an inception resnet (in eval mode):
resnet = InceptionResnetV1(pretrained='vggface2').eval()
img = Image.open('/path/to/any/face/image.jpg')
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 320, 320)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(resnet, example)
# If the model has data-dependent control flow, use torch.jit.script instead:
# script_module = torch.jit.script(resnet)
# Save the TorchScript model
traced_script_module.save("face_feature.pt")
output = traced_script_module(torch.rand(1, 3, 320, 320))
# print(traced_script_module.code)
print(output)
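After saving a traced model, a quick sanity check is to reload the file with torch.jit.load and confirm that it produces the same output as the in-memory traced module. The sketch below shows that round trip. To keep it runnable without downloading pretrained weights, it uses a small hypothetical stand-in module (TinyEmbedder) rather than InceptionResnetV1, and the file name tiny_feature.pt is also made up for the example.

```python
import os
import tempfile

import torch
from torch import nn


class TinyEmbedder(nn.Module):
    """Hypothetical stand-in for InceptionResnetV1: image batch in, small embedding out."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4, 8)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


def roundtrip_matches(model, example, path):
    """Trace `model`, save it to `path`, reload it, and compare outputs on `example`."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced.save(path)
        reloaded = torch.jit.load(path)
        reloaded.eval()
        return torch.allclose(traced(example), reloaded(example), atol=1e-6)


torch.manual_seed(0)
example = torch.rand(1, 3, 320, 320)
with tempfile.TemporaryDirectory() as tmp:
    ok = roundtrip_matches(TinyEmbedder(), example, os.path.join(tmp, "tiny_feature.pt"))
print(ok)
```

The same check should apply to face_feature.pt by passing the real model and path in place of the stand-ins.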