Java程序员学深度学习 DJL上手5 训练自己的模型
- 一、准备环境
- 二、创建示例项目
- 三、准备数据集
- 四、创建模型
- 五、创建训练器
- 1. 训练器配置
- 2. 初始化训练器
- 3. 训练模型
- 4. 保存模型
- 六、源代码
- 1. pom
- 2. java
一、准备环境
- windows
- idea
- maven
二、创建示例项目
三、准备数据集
int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
这里对数据集进行了分批处理,每批大小32,合适的分批大小将在训练时显著提升性能。
四、创建模型
本节会根据之前文章创建模型。由于 MNIST 数据集中的图像为 28x28 灰度图像,这里我们创建一个具有 28 x 28 输入的 MLP 块。
输出的图输出为 10,因为每个图像可能有 10 个可能的类(0 到 9)。
对于隐藏的层,其大小是猜测的值new int[] {128, 64}
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
五、创建训练器
1. 训练器配置
- 损失函数,用来测量模型与测试数据集的匹配程度,值越低越好;这里定义为
softmaxCrossEntropyLoss()
- 评估函数,也用于测量模型与数据集的匹配程度。与损失不同,它们只供人们查看,不用于优化模型。
- 监听器,用来监控训练过程。
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(config);
2. 初始化训练器
这里使用输入的形状来初始化训练器。初始化函数里形状的第一个参数是批次大小,这个不影响参数初始化。
第二个参数是输入图像的像素数,即28*28。
trainer.initialize(new Shape(1, 28 * 28));
3. 训练模型
这里使用了DJL的EasyTrain,
int epoch = 2;
EasyTrain.fit(trainer, epoch, mnist, null);
4. 保存模型
保存模型还可以添加一些元数据,如训练迭代次数、训练精度等。
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");
System.out.println(model);
六、源代码
1. pom
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.xundh</groupId>
<artifactId>djl-learning</artifactId>
<version>0.1-SNAPSHOT</version>
<properties>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<java.version>8</java.version>
<djl.version>0.13.0-SNAPSHOT</djl.version>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>${djl.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
</dependency>
<!-- Pytorch -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>1.7.0</version>
</dependency>
</dependencies>
</project>
2. java
package com.xundh;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
public class NDArrayLearning {
public static void main(String[] args) throws IOException, TranslateException {
int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[]{128, 64}));
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 28 * 28));
int epoch = 2;
EasyTrain.fit(trainer, epoch, mnist, null);
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");
System.out.println(model);
}
}
运行结果示例: