In this tutorial, I will show you how to get started with Deep Java Library (DJL), a new open source library from Amazon for building and deploying deep learning models in Java.
Open IntelliJ, create a new Maven project, and add the following to the pom.xml file.
<properties>
    <maven.compiler.source>11</maven.compiler.source>
    <maven.compiler.target>11</maven.compiler.target>
</properties>
<repositories>
    <repository>
        <id>djl.ai</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
    </repository>
</repositories>
<dependencies>
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.3.0-SNAPSHOT</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-model-zoo</artifactId>
        <version>0.3.0-SNAPSHOT</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-native-mkl</artifactId>
        <version>1.6.0-b-SNAPSHOT</version>
        <classifier>osx-x86_64</classifier>
        <!-- <classifier>linux-x86_64</classifier>-->
        <!-- <classifier>win-x86_64</classifier>-->
        <scope>runtime</scope>
    </dependency>
</dependencies>
I'm using the osx-x86_64 classifier because I'm on a Mac. If you are on Linux or Windows, replace it with the matching classifier from the commented-out lines:
<!-- <classifier>linux-x86_64</classifier>-->
<!-- <classifier>win-x86_64</classifier>-->
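If you are not sure which classifier applies to your machine, the JVM can tell you. The following is a minimal sketch using only the standard library; `NativeClassifier` and `forOs` are hypothetical names made up for this illustration, not part of DJL, and the sketch assumes a 64-bit x86 machine because those are the only classifiers listed above.

```java
import java.util.Locale;

// Hypothetical helper (not part of DJL): maps an os.name value to one of the
// three classifiers listed in the pom.xml above.
class NativeClassifier {

    static String forOs(String osName) {
        String os = osName.toLowerCase(Locale.ROOT);
        // Check for macOS first: "darwin" also contains the substring "win".
        if (os.contains("mac") || os.contains("darwin")) {
            return "osx-x86_64";
        }
        if (os.contains("win")) {
            return "win-x86_64";
        }
        if (os.contains("linux")) {
            return "linux-x86_64";
        }
        throw new IllegalArgumentException("No classifier listed for OS: " + osName);
    }

    public static void main(String[] args) {
        // Prints the classifier that matches the machine running this program.
        System.out.println(forOs(System.getProperty("os.name")));
    }
}
```

Run it once and copy the printed value into the `<classifier>` element.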
Now create the DetectObject.java file and make it look like the following:
package com.kodnito.djl;

import ai.djl.MalformedModelException;
import ai.djl.modality.cv.DetectedObjects;
import ai.djl.modality.cv.ImageVisualization;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.mxnet.zoo.MxModelZoo;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Paths;

public class DetectObject {

    private static final Logger logger = LoggerFactory.getLogger(DetectObject.class);

    public static void main(String[] args) throws MalformedModelException, ModelNotFoundException, TranslateException, IOException {
        var detectedObjects = new DetectObject().predict();
        logger.info("{}", detectedObjects);
    }

    public DetectedObjects predict() throws MalformedModelException, ModelNotFoundException, IOException, TranslateException {
        // Load the input image from disk (path is relative to the project root).
        var imageFile = Paths.get("src/main/resources/new-york.jpg");
        var img = BufferedImageUtils.fromFile(imageFile);

        // Load the pre-trained SSD model from the MXNet model zoo.
        ZooModel<BufferedImage, DetectedObjects> model =
                MxModelZoo.SSD.loadModel(new ProgressBar());

        // Run detection; the return value is the detection result, not the predictor.
        var predictResult = model.newPredictor().predict(img);

        // Draw the detections onto the image and save it as a PNG.
        ImageVisualization.drawBoundingBoxes(img, predictResult);
        ImageIO.write(img, "png", new File("new-york.png"));

        model.close();
        return predictResult;
    }
}
First, we load the image from disk.
var imageFile = Paths.get("src/main/resources/new-york.jpg");
var img = BufferedImageUtils.fromFile(imageFile);
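The path above is relative, so it only resolves when the program runs from the project root. As a rough sketch of how you might check the file before handing it to the model, here is a small helper built on the JDK's own `ImageIO` rather than on DJL. `ImageCheck` and `readOrFail` are hypothetical names used only for this illustration.

```java
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper (not part of DJL): fails early with a clear message
// when the image is missing or cannot be decoded.
class ImageCheck {

    static BufferedImage readOrFail(Path path) throws IOException {
        if (!Files.isRegularFile(path)) {
            throw new IOException("Image not found: " + path.toAbsolutePath());
        }
        // ImageIO.read returns null when no registered reader understands the file.
        BufferedImage image = ImageIO.read(path.toFile());
        if (image == null) {
            throw new IOException("Not a readable image: " + path);
        }
        return image;
    }
}
```

The absolute path in the error message makes it easy to see which working directory the program was started from.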
Now, we load an SSD (Single Shot MultiBox Detector) model from the MXNet model zoo.
The Model Zoo is a collection of pre-trained models, which are ready to use out of the box.
ZooModel<BufferedImage, DetectedObjects> model =
        MxModelZoo.SSD.loadModel(new ProgressBar());
We create a predictor and detect the objects in the image.
var predictResult = model.newPredictor().predict(img);
To check the detected result, we draw the bounding boxes onto the image.
ImageVisualization.drawBoundingBoxes(img, predictResult);
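To give a feel for what drawing a bounding box involves, here is a minimal Java2D sketch. It is not DJL's implementation: `BoxDrawer` and `drawBox` are hypothetical names, and the sketch assumes the box is given as fractions of the image width and height, which is a convention chosen for this illustration. Label text is left out to keep it short.

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;

// Hypothetical illustration (not DJL code): outlines one box on an image.
class BoxDrawer {

    // x, y, w, h are fractions in [0, 1] of the image width and height
    // (an assumption made for this sketch).
    static void drawBox(BufferedImage img, double x, double y, double w, double h) {
        // Scale the fractional coordinates to pixel coordinates.
        int px = (int) (x * img.getWidth());
        int py = (int) (y * img.getHeight());
        int pw = (int) (w * img.getWidth());
        int ph = (int) (h * img.getHeight());

        Graphics2D g = img.createGraphics();
        try {
            g.setColor(Color.RED);
            g.drawRect(px, py, pw, ph); // outline only, the interior is untouched
        } finally {
            g.dispose(); // release the graphics context
        }
    }
}
```

Because the drawing happens directly on the `BufferedImage`, the same image object can be written to disk afterwards without any extra copy.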
Finally, we save the result.
ImageIO.write(img, "png", new File("new-york.png"));
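One detail worth knowing: `ImageIO.write` returns `false` instead of throwing when no writer is available for the requested format, so a failed save can go unnoticed. A small sketch of a wrapper that turns that case into an exception follows; `ImageSaver` and `save` are hypothetical names for this illustration.

```java
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

// Hypothetical helper: saves an image and fails loudly if nothing was written.
class ImageSaver {

    static void save(BufferedImage img, String format, File target) throws IOException {
        // ImageIO.write reports "no suitable writer" through its return value.
        boolean written = ImageIO.write(img, format, target);
        if (!written) {
            throw new IOException("No image writer found for format: " + format);
        }
    }
}
```

With this in place, the last step of the tutorial could be written as `ImageSaver.save(img, "png", new File("new-york.png"));`.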
Image we loaded:
Useful links:
Download source code: GitHub
DJL Website
DJL GitHub
MXNet