当前位置:首页 » 《随便一记》 » 正文

Java程序员学深度学习 DJL上手1_编程圈子-谢厂节的博客

20 人参与  2021年11月30日 09:43  分类 : 《随便一记》  评论

点击全文阅读


Java程序员学深度学习 DJL上手1

  • 一、简介
  • 二、准备环境
  • 二、从头开始
      • 1. 使用anaconda新建mxnet环境
      • 2. 新建一个空的 idea maven项目。
      • 3. 安装Python Commu Edition 插件
      • 4. 在Module新增conda的环境
  • 三、一个简单的模型

一、简介

官网地址:http://djl.ai/
DPL是一款开源的Java深度学习框架,易启动、Java程序员容易上手操作。

下面是一个推理的伪代码示例:

  // Assume user uses a pre-trained model from model zoo, they just need to load it
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .optApplication(Application.CV.OBJECT_DETECTION) // find object dection model
                    .setTypes(Image.class, Classifications.class) // define input and output
                    .optFilter("backbone", "resnet50") // choose network architecture
                    .build();

    try (ZooModel<Image, Classifications> model = criteria.loadModel()) {
        try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image img = ImageFactory.getInstance().fromUrl("http://..."); // read image
            Classifications result = predictor.predict(img);

            // get the classification and probability
            ...
        }
    }

训练的伪代码示例:

    // Construct your neural network with built-in blocks
    Block block = new Mlp(28, 28);

    try (Model model = Model.newInstance("mlp")) { // Create an empty model
        model.setBlock(block); // set neural network to model

        // Get training and validation dataset (MNIST dataset)
        Dataset trainingSet = new Mnist.Builder().setUsage(Usage.TRAIN) ... .build();
        Dataset validateSet = new Mnist.Builder().setUsage(Usage.TEST) ... .build();

        // Setup training configurations, such as Initializer, Optimizer, Loss ...
        TrainingConfig config = setupTrainingConfig();
        try (Trainer trainer = model.newTrainer(config)) {
            /*
             * Configure input shape based on dataset to initialize the trainer.
             * 1st axis is batch axis, we can use 1 for initialization.
             * MNIST is 28x28 grayscale image and pre processed into 28 * 28 NDArray.
             */
            Shape inputShape = new Shape(1, 28 * 28);
            trainer.initialize(new Shape[] {inputShape});

            EasyTrain.fit(trainer, epoch, trainingSet, validateSet);
        }

        // Save the model
        model.save(modelDir, "mlp");
    }

DPL 仍需要其它的深度学习框架。下面 demo 需要安装pytorch。

二、准备环境

  • mac
  • 已安装anaconda
  • idea

二、从头开始

1. 使用anaconda新建mxnet环境

conda create -n mxnet
conda activate mxnet
conda install mxnet

2. 新建一个空的 idea maven项目。

3. 安装Python Commu Edition 插件

在这里插入图片描述

4. 在Module新增conda的环境

在这里插入图片描述
在这里插入图片描述

三、一个简单的模型

import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;

import java.io.File;
import java.io.FileInputStream;

public class SimpleSSDExample {
    public static void main(String[] args) throws Exception{
        // 下载aws 预先训练好的 resnet 模型 
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        // 下载标签
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
        // 图片预处理
        Pipeline pipeline = new Pipeline();
        pipeline.add(new Resize(256))
                .add(new CenterCrop(224, 224))
                .add(new ToTensor())
                .add(new Normalize(
                        new float[]{0.485f, 0.456f, 0.406f},
                        new float[]{0.229f, 0.224f, 0.225f}));

        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optApplySoftmax(true)
                .build();

        // 设置模型
        System.setProperty("ai.djl.repository.zoo.location", "build/pytorch_models/resnet18");
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                // only search the model in local directory
                // "ai.djl.localmodelzoo:{name of the model}"
                .optArtifactId("ai.djl.localmodelzoo:resnet18")
                .optTranslator(translator)
                .optProgress(new ProgressBar()).build();
        ZooModel<Image,Classifications> model = ModelZoo.loadModel(criteria);
        // 加载本地图片
        File fs=new File("/Users/apple/Downloads/1.jpeg");
        Image img = ImageFactory.getInstance().fromInputStream(new FileInputStream(fs));
        // 执行推理
        Predictor<Image, Classifications> predictor = model.newPredictor();
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);
    }
}

随便找了一张狗的图片:
在这里插入图片描述

运行结果:
在这里插入图片描述


点击全文阅读


本文链接:http://zhangshiyu.com/post/31426.html

环境  模型  安装  
<< 上一篇 下一篇 >>

  • 评论(0)
  • 赞助本站

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

关于我们 | 我要投稿 | 免责申明

Copyright © 2020-2022 ZhangShiYu.com Rights Reserved.豫ICP备2022013469号-1