Spark Support for DJL

Overview

This module contains the Spark support extension, which allows DJL to be used seamlessly with Apache Spark.

Some key features of the DJL Spark Extension include:

Documentation

The latest javadocs can be found here.

You can also build the latest javadocs locally using the following command:

./gradlew javadoc

Installation

You can pull the module from the central Maven repository by including the following dependency in your pom.xml file:

<dependency>
    <groupId>ai.djl.spark</groupId>
    <artifactId>spark_2.12</artifactId>
    <version>0.31.0</version>
</dependency>

Usage

The following example uses the DJL Spark Extension to run image classification over a Spark DataFrame of images (`df`):

Scala

import ai.djl.spark.task.vision.ImageClassifier

val classifier = new ImageClassifier()
  .setInputCols(Array("origin", "height", "width", "nChannels", "mode", "data"))
  .setOutputCol("prediction")
  .setEngine("PyTorch")
  .setModelUrl("djl://ai.djl.pytorch/resnet")
  .setTopK(2)
val outputDf = classifier.classify(df)

Python

from djl_spark.task.vision import ImageClassifier

classifier = ImageClassifier(input_cols=["origin", "height", "width", "nChannels", "mode", "data"],
                             output_col="prediction",
                             engine="PyTorch",
                             model_url="djl://ai.djl.pytorch/resnet",
                             top_k=2)
output_df = classifier.classify(df)
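
Both snippets assume that `df` already exists and has the columns passed as input columns. Those names match the fields of the `image` struct produced by Spark's built-in image data source, so one way to build `df` is sketched below. The helper name `load_image_df`, the path argument, and the 3-channel filter are assumptions made for illustration and are not part of the DJL Spark API.

```python
def load_image_df(spark, path):
    """Build a DataFrame of decoded images from `path`.

    `spark` is expected to be an active pyspark.sql.SparkSession.
    """
    # Spark's image data source yields a single struct column named "image"
    # with fields origin, height, width, nChannels, mode and data.
    images = (
        spark.read.format("image")
        .option("dropInvalid", True)  # skip files that cannot be decoded
        .load(path)
    )
    # Flatten the struct so each field becomes a top-level column.
    flat = images.select("image.*")
    # Assumption: the model expects 3-channel images, so others are dropped.
    return flat.filter("nChannels = 3")
```

With an active `SparkSession` named `spark`, `df = load_image_df(spark, path)` yields a DataFrame that can be passed to `classifier.classify(df)`.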

See examples for more details.