Deep Learning for Java Programmers: Getting Started with DJL (Part 1)
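- I. Introduction
- II. Environment
- III. Starting from Scratch
- 1. Create an mxnet environment with Anaconda
- 2. Create a new, empty Maven project in IDEA
- 3. Install the Python Community Edition plugin
- 4. Add the conda environment to the Module
- IV. A Simple Model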
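I. Introduction
Official website: http://djl.ai/
DJL (Deep Java Library) is an open-source deep learning framework for Java. It is easy to set up and approachable for Java programmers.
Here is a pseudocode example of inference: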
// Assume the user takes a pre-trained model from the model zoo; they only need to load it
Criteria<Image, Classifications> criteria =
        Criteria.builder()
                .optApplication(Application.CV.OBJECT_DETECTION) // find an object detection model
                .setTypes(Image.class, Classifications.class) // define input and output
                .optFilter("backbone", "resnet50") // choose network architecture
                .build();
try (ZooModel<Image, Classifications> model = criteria.loadModel()) {
    try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Image img = ImageFactory.getInstance().fromUrl("http://..."); // read image
        Classifications result = predictor.predict(img);
        // get the classification and probability
        ...
    }
}
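Both ZooModel and Predictor are AutoCloseable, which is why the pseudocode wraps them in try-with-resources: resources are released in reverse order of creation, so the predictor is closed before the model. The following is a minimal sketch of that ordering in plain Java. `FakeModel` and `FakePredictor` are hypothetical stand-ins made up for illustration; they are not DJL classes.

```java
import java.util.ArrayList;
import java.util.List;

public class CloseOrderSketch {
    // Records the order in which the stand-in resources are closed.
    static final List<String> LOG = new ArrayList<>();

    // Hypothetical stand-in for a loaded model (not a DJL class).
    static class FakeModel implements AutoCloseable {
        FakePredictor newPredictor() {
            return new FakePredictor();
        }

        @Override
        public void close() {
            LOG.add("model closed");
        }
    }

    // Hypothetical stand-in for a predictor (not a DJL class).
    static class FakePredictor implements AutoCloseable {
        String predict(String input) {
            return "label-for-" + input;
        }

        @Override
        public void close() {
            LOG.add("predictor closed");
        }
    }

    // Opens both resources, predicts once, and lets try-with-resources close them.
    static String run(String input) {
        try (FakeModel model = new FakeModel();
             FakePredictor predictor = model.newPredictor()) {
            return predictor.predict(input);
        }
    }

    public static void main(String[] args) {
        System.out.println(run("dog.jpg"));
        System.out.println(LOG); // close order: predictor first, then model
    }
}
```

Declaring both resources in a single try header behaves the same as the nested form in the pseudocode; it only saves one level of indentation.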
And a pseudocode example of training:
// Construct your neural network with built-in blocks
// (recent DJL releases use Mlp(input, output, hiddenSizes), e.g. new Mlp(28 * 28, 10, new int[] {128, 64}))
Block block = new Mlp(28, 28);
try (Model model = Model.newInstance("mlp")) { // create an empty model
    model.setBlock(block); // set the neural network on the model
    // Get the training and validation datasets (MNIST)
    Dataset trainingSet = new Mnist.Builder().setUsage(Usage.TRAIN) ... .build();
    Dataset validateSet = new Mnist.Builder().setUsage(Usage.TEST) ... .build();
    // Set up training configuration, such as Initializer, Optimizer, Loss ...
    TrainingConfig config = setupTrainingConfig();
    try (Trainer trainer = model.newTrainer(config)) {
        /*
         * Configure the input shape based on the dataset to initialize the trainer.
         * The 1st axis is the batch axis; 1 is enough for initialization.
         * MNIST images are 28x28 grayscale and are preprocessed into a 28 * 28 NDArray.
         */
        Shape inputShape = new Shape(1, 28 * 28);
        trainer.initialize(new Shape[] {inputShape});
        EasyTrain.fit(trainer, epoch, trainingSet, validateSet);
    }
    // Save the model
    model.save(modelDir, "mlp");
}
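The `Shape(1, 28 * 28)` in the pseudocode describes a batch holding a single entry of 784 values, that is, one 28x28 image flattened row by row. Below is a minimal plain-Java sketch of that flattening, independent of DJL; the class and method names are illustrative only.

```java
public class FlattenSketch {
    // Flattens a [height][width] image into a [1][height * width] batch, row by row.
    static float[][] toBatchOfOne(float[][] image) {
        int height = image.length;
        int width = image[0].length;
        float[][] batch = new float[1][height * width];
        for (int row = 0; row < height; row++) {
            // Row r of the image lands at offset r * width in the flat vector.
            System.arraycopy(image[row], 0, batch[0], row * width, width);
        }
        return batch;
    }

    public static void main(String[] args) {
        float[][] image = new float[28][28];
        image[2][3] = 1.0f; // mark one pixel so it can be located after flattening
        float[][] batch = toBatchOfOne(image);
        System.out.println(batch.length + " x " + batch[0].length); // prints "1 x 784"
        System.out.println(batch[0][2 * 28 + 3]); // prints "1.0"
    }
}
```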
DJL still relies on an underlying deep learning framework as its engine. The demo below requires PyTorch.
II. Environment
- macOS
- Anaconda already installed
- IntelliJ IDEA
III. Starting from Scratch
1. Create an mxnet environment with Anaconda
conda create -n mxnet
conda activate mxnet
conda install mxnet
2. Create a new, empty Maven project in IDEA.
3. Install the Python Community Edition plugin
4. Add the conda environment to the Module
IV. A Simple Model
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import java.io.File;
import java.io.FileInputStream;
public class SimpleSSDExample {
    public static void main(String[] args) throws Exception {
        // Download the ResNet model pre-trained by AWS
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        // Download the class labels
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
        // Image preprocessing
        Pipeline pipeline = new Pipeline();
        pipeline.add(new Resize(256))
                .add(new CenterCrop(224, 224))
                .add(new ToTensor())
                .add(new Normalize(
                        new float[]{0.485f, 0.456f, 0.406f},
                        new float[]{0.229f, 0.224f, 0.225f}));
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optApplySoftmax(true)
                .build();
        // Configure where the model is looked up
        System.setProperty("ai.djl.repository.zoo.location", "build/pytorch_models/resnet18");
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                // only search for the model in the local directory
                // "ai.djl.localmodelzoo:{name of the model}"
                .optArtifactId("ai.djl.localmodelzoo:resnet18")
                .optTranslator(translator)
                .optProgress(new ProgressBar()).build();
        // Load the model and create a predictor; try-with-resources closes both when done
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // Load a local image
            File fs = new File("/Users/apple/Downloads/1.jpeg");
            Image img = ImageFactory.getInstance().fromInputStream(new FileInputStream(fs));
            // Run inference
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }
}
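Two numeric steps in the listing above are easy to overlook: `ToTensor` followed by `Normalize` rescales every channel using the given mean and standard deviation, and `optApplySoftmax(true)` converts the raw model outputs into probabilities. Here is a rough plain-Java sketch of both calculations, assuming 8-bit channel values are first scaled to the range 0 to 1. The helper names are illustrative and not part of the DJL API.

```java
public class PrePostSketch {
    // Per-channel mean and standard deviation, as passed to Normalize in the listing.
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    // Scales an 8-bit channel value to [0, 1], then standardizes it for that channel.
    static float normalize(int pixelValue, int channel) {
        float scaled = pixelValue / 255.0f;
        return (scaled - MEAN[channel]) / STD[channel];
    }

    // Numerically stable softmax: subtract the maximum before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // Red channel (index 0) at full intensity.
        System.out.println(normalize(255, 0));
        // Three made-up raw scores; the resulting probabilities sum to 1.
        for (double p : softmax(new double[] {2.0, 1.0, 0.1})) {
            System.out.println(p);
        }
    }
}
```

Softmax preserves the ranking of the raw scores, so it does not change which class comes out on top; it only makes the printed values readable as probabilities.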
I picked a random picture of a dog:
Output: