今天就跟大家聊聊有关dl4j如何使用遗传神经网络完成手写数字识别,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。
实现步骤
1.随机初始化若干个智能体(神经网络),并让智能体识别训练数据,并对识别结果进行排序
2.随机在排序结果中选择一个作为母本,并在比母本识别率更高的智能体中随机选择一个作为父本
3.对每一个参数位,随机选择母本或父本同位的神经网络参数(权重),组成新的智能体
4.按照母本的排序对新智能体的参数进行调整(突变),调整比例在1%~10%之间,排序越靠后调整幅度越大
5.让新的智能体识别训练集并放入排行榜,并移除排行榜最后一位
6.重复2~5过程,让识别率越来越高
这个过程类似于自然界的优胜劣汰:将神经网络的参数(权重)看作DNA,将参数的调整看作DNA的突变。当然还可以把拥有不同隐藏层的神经网络看作不同的物种,让竞争过程更加多样化;这里我们只讨论一种神经网络的情况。
优势: 可以解决很多没有头绪的问题 劣势: 训练效率极低
gitee地址:
https://gitee.com/ichiva/gnn.git
实现步骤 1.进化接口
/**
 * Genetic operators applied to a "DNA", represented here as an {@code INDArray}.
 */
public interface Evolution {
/**
 * Inheritance: produces a child DNA from a mother DNA and a father DNA.
 *
 * @param mDna DNA of the mother
 * @param fDna DNA of the father
 * @return the DNA derived from the two parents
 */
INDArray inheritance(INDArray mDna,INDArray fDna);
/**
 * Mutation: perturbs values of the given DNA.
 *
 * @param dna the DNA to mutate
 * @param v   rate value controlling how much of the DNA is affected
 *            (exact interpretation is left to the implementation)
 * @param r   mutation range
 * @return the resulting DNA
 */
INDArray mutation(INDArray dna,double v, double r);
/**
 * Substitution: replaces values of the given DNA with other values of that DNA.
 *
 * @param dna the DNA to operate on
 * @param v   rate value controlling how much of the DNA is affected
 *            (exact interpretation is left to the implementation)
 * @return the resulting DNA
 */
INDArray substitution(INDArray dna,double v);
/**
 * Foreign source: replaces values of the given DNA with values that do not
 * come from the DNA itself.
 *
 * @param dna the DNA to operate on
 * @param v   rate value controlling how much of the DNA is affected
 *            (exact interpretation is left to the implementation)
 * @return the resulting DNA
 */
INDArray other(INDArray dna,double v);
/**
 * Tells whether two DNAs are homologous, i.e. compatible for inheritance.
 *
 * @param mDna DNA of the mother
 * @param fDna DNA of the father
 * @return {@code true} when the two DNAs are homologous
 */
boolean iSogeny(INDArray mDna, INDArray fDna);
}
一个比较通用的实现
/**
 * General-purpose {@link Evolution} implementation that treats every element of
 * an {@code INDArray} as one gene.
 *
 * <p>Every operator returns a newly allocated array and leaves its arguments
 * untouched, so a parent's DNA is never modified while breeding a child.
 */
public class MnistEvolution implements Evolution {
    private static final MnistEvolution instance = new MnistEvolution();

    /**
     * @return the shared instance of this class
     */
    public static MnistEvolution getInstance() {
        return instance;
    }

    /**
     * Builds a child DNA by picking, for every position, the gene of either the
     * father or the mother with equal probability.
     *
     * @param mDna DNA of the mother
     * @param fDna DNA of the father
     * @return a new array with the same shape as the parents
     * @throws IllegalArgumentException if the two DNAs do not have the same shape
     */
    @Override
    public INDArray inheritance(INDArray mDna, INDArray fDna) {
        if (mDna == fDna) {
            // Same parent on both sides: hand back a copy so the result never
            // aliases the parent's own array.
            return mDna.dup();
        }
        if (!iSogeny(mDna, fDna)) {
            throw new IllegalArgumentException("非同源dna");
        }
        long[] shape = mDna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            INDArray source = Math.random() > 0.5 ? fDna : mDna;
            nDna.putScalar(next, source.getDouble(next));
        }
        return nDna;
    }

    /**
     * Copies {@code dna} and, with probability {@code v} per gene, shifts the
     * gene by a uniform random offset in {@code [-r, r)}.
     *
     * @param dna the DNA to mutate; not modified
     * @param v   per-gene probability of mutation
     * @param r   mutation range (maximum absolute offset)
     * @return a new array holding the mutated DNA
     */
    @Override
    public INDArray mutation(INDArray dna, double v, double r) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            double gene = dna.getDouble(next);
            if (Math.random() < v) {
                gene += (Math.random() - 0.5) * r * 2;
            }
            // Always write into the copy; the mutated value used to be written
            // back into the input, leaving a zero in the result.
            nDna.putScalar(next, gene);
        }
        return nDna;
    }

    /**
     * Copies {@code dna} and, with probability {@code v} per gene, replaces the
     * gene with the value found at a uniformly random position of {@code dna}.
     *
     * @param dna the DNA to operate on; not modified
     * @param v   per-gene probability of substitution
     * @return a new array holding the resulting DNA
     */
    @Override
    public INDArray substitution(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            // "< v" keeps v a change rate, consistent with mutation().
            if (Math.random() < v) {
                long[] tag = new long[shape.length];
                for (int i = 0; i < shape.length; i++) {
                    tag[i] = (long) (Math.random() * shape[i]);
                }
                nDna.putScalar(next, dna.getDouble(tag));
            } else {
                nDna.putScalar(next, dna.getDouble(next));
            }
        }
        return nDna;
    }

    /**
     * Copies {@code dna} and, with probability {@code v} per gene, replaces the
     * gene with a fresh {@code Math.random()} value.
     *
     * @param dna the DNA to operate on; not modified
     * @param v   per-gene probability of replacement
     * @return a new array holding the resulting DNA
     */
    @Override
    public INDArray other(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            // "< v" keeps v a change rate, consistent with mutation().
            if (Math.random() < v) {
                nDna.putScalar(next, Math.random());
            } else {
                nDna.putScalar(next, dna.getDouble(next));
            }
        }
        return nDna;
    }

    /**
     * Two DNAs are considered homologous when their shapes are identical
     * (same rank and same size along every dimension).
     *
     * @param mDna DNA of the mother
     * @param fDna DNA of the father
     * @return {@code true} when both shapes are equal
     */
    @Override
    public boolean iSogeny(INDArray mDna, INDArray fDna) {
        return java.util.Arrays.equals(mDna.shape(), fDna.shape());
    }
}
定义智能体配置接口
/**
 * Configuration of an agent: input size, output size and the neural network
 * configuration to build the agent's network from.
 */
public interface AgentConfig {
/**
 * Input size.
 *
 * @return the number of inputs
 */
int getInput();
/**
 * Output size.
 *
 * @return the number of outputs
 */
int getOutput();
/**
 * Neural network configuration.
 *
 * @return the {@code MultiLayerConfiguration} describing the network
 */
MultiLayerConfiguration getMultiLayerConfiguration();
}
按手写数字识别进行配置实现
/**
 * {@link AgentConfig} for handwritten digit recognition: 28x28 inputs, 10
 * outputs, and a network made of one dense hidden layer followed by a softmax
 * output layer.
 */
public class MnistConfig implements AgentConfig {
    /** Number of units in the hidden dense layer. */
    private static final int HIDDEN_UNITS = 1000;
    private static final int IMAGE_SIDE = 28;
    private static final int CLASS_COUNT = 10;

    /**
     * @return the number of inputs, one per pixel of a 28x28 image
     */
    @Override
    public int getInput() {
        return IMAGE_SIDE * IMAGE_SIDE;
    }

    /**
     * @return the number of outputs, one per digit class
     */
    @Override
    public int getOutput() {
        return CLASS_COUNT;
    }

    /**
     * Builds a fresh configuration with a randomly drawn seed.
     *
     * @return a two-layer configuration: dense RELU layer, then softmax output
     */
    @Override
    public MultiLayerConfiguration getMultiLayerConfiguration() {
        // Drawn first so each configuration gets its own seed.
        long seed = (long) (Math.random() * Long.MAX_VALUE);

        // Hidden layer: getInput() -> HIDDEN_UNITS.
        DenseLayer hidden = new DenseLayer.Builder()
                .nIn(getInput())
                .nOut(HIDDEN_UNITS)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build();

        // Output layer: HIDDEN_UNITS -> getOutput().
        OutputLayer output = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(HIDDEN_UNITS)
                .nOut(getOutput())
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build();

        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater(new Nesterovs(0.006, 0.9))
                .l2(1e-4)
                .list()
                .layer(0, hidden)
                .layer(1, output)
                .pretrain(false).backprop(true)
                .build();
    }
}
智能体基类
/**
 * Base class of an agent: a neural network built from an {@link AgentConfig}
 * together with the parameter array ("DNA") it was created with.
 */
@Getter
public class Agent {
    private final AgentConfig config;
    private final INDArray dna;
    private final MultiLayerNetwork multiLayerNetwork;

    /**
     * Creates an agent whose parameters are initialised by the network's
     * default initialisation.
     *
     * @param config the agent configuration
     */
    public Agent(AgentConfig config) {
        this(config, null);
    }

    /**
     * Creates an agent from the given configuration and DNA.
     *
     * @param config the agent configuration
     * @param dna    parameters to initialise the network with, or {@code null}
     *               to use the network's default initialisation
     */
    public Agent(AgentConfig config, INDArray dna) {
        this.config = config;
        this.multiLayerNetwork = new MultiLayerNetwork(config.getMultiLayerConfiguration());
        if (dna != null) {
            // Initialise the network from the supplied parameters and keep them as DNA.
            this.multiLayerNetwork.init(dna, true);
            this.dna = dna;
        } else {
            // Default initialisation; the DNA is whatever the network produced.
            this.multiLayerNetwork.init();
            this.dna = this.multiLayerNetwork.params();
        }
    }
}
手写数字智能体实现类
/**
 * Agent specialised for handwritten digit recognition. Each instance receives
 * a sequential name of the form {@code agent-N}.
 */
@Getter
@Setter
public class MnistAgent extends Agent {
    /** Counter used to number agents as they are created. */
    private static final AtomicInteger SEQUENCE = new AtomicInteger(0);

    /** Configuration shared by the static factory methods. */
    public static MnistConfig mnistConfig = new MnistConfig();

    private String name;

    /** Environment fitness score. */
    private double score;

    /** Validation score. */
    private double validScore;

    /**
     * Creates an agent with default-initialised parameters.
     *
     * @param config the agent configuration
     */
    public MnistAgent(AgentConfig config) {
        this(config, null);
    }

    /**
     * Creates an agent from the given DNA and assigns it the next name.
     *
     * @param config the agent configuration
     * @param dna    parameters to initialise the network with, may be {@code null}
     */
    public MnistAgent(AgentConfig config, INDArray dna) {
        super(config, dna);
        int number = SEQUENCE.incrementAndGet();
        this.name = "agent-" + number;
    }

    /**
     * @return a new agent built from {@link #mnistConfig} with default parameters
     */
    public static MnistAgent newInstance() {
        return new MnistAgent(mnistConfig);
    }

    /**
     * @param dna parameters to initialise the network with
     * @return a new agent built from {@link #mnistConfig} and the given DNA
     */
    public static MnistAgent create(INDArray dna) {
        return new MnistAgent(mnistConfig, dna);
    }
}
手写数字识别环境构建
@Slf4j
public class MnistEnv {
/**
* 环境数据
*/
private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> {
try {
return new MnistDataSetIterator(128, true, 0);
} catch (IOException e) {
throw new RuntimeException("mnist 文件读取失败");
}
});
private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> {
try {
return new MnistDataSetIterator(128, false, 0);
} catch (IOException e) {
throw new RuntimeException("mnist 文件读取失败");
}
});
private static final MnistEvolution evolution = MnistEvolution.getInstance();
/**
* 环境承载上限
*
* 超过上限AI会进行激烈竞争
*/
private final int max;
private Double maxScore,minScore;
/**
* 环境中的生命体
*
* 新生代与历史代共同排序,选出最适应环境的个体
*/
//2个变量,一个队列保存KEY的顺序,一个MAP保存KEY对应的具体对象的数据 线程安全map
private final TreeMap<Double,MnistAgent> lives = new TreeMap<>();
/**
* 初始化环境
*
* 1.向环境中初始化ai
* 2.将初始化ai进行环境适应性测试,并排序
* @param max
*/
public MnistEnv(int max){
this.max = max;
for (int i = 0; i < max; i++) {
MnistAgent agent = MnistAgent.newInstance();
test(agent);
synchronized (lives) {
lives.put(agent.getScore(),agent);
}
log.info("初始化智能体 name = {} , score = {}",i,agent.getScore());
}
synchronized (lives) {
minScore = lives.firstKey();
maxScore = lives.lastKey();
}
}
/**
* 环境适应性评估
* @param ai
*/
public void test(MnistAgent ai){
MultiLayerNetwork network = ai.getMultiLayerNetwork();
MnistDataSetIterator dataIterator = tLocal.get();
Evaluation eval = new Evaluation(ai.getConfig().getOutput());
try {
while (dataIterator.hasNext()) {
DataSet data = dataIterator.next();
INDArray output = network.output(data.getFeatures(), false);
eval.eval(data.getLabels(),output);
}
}finally {
dataIterator.reset();
}
ai.setScore(eval.accuracy());
}
/**
* 迁移评估
*
* @param ai
*/
public void validation(MnistAgent ai){
MultiLayerNetwork network = ai.getMultiLayerNetwork();
MnistDataSetIterator dataIterator = testLocal.get();
Evaluation eval = new Evaluation(ai.getConfig().getOutput());
try {
while (dataIterator.hasNext()) {
DataSet data = dataIterator.next();
INDArray output = network.output(data.getFeatures(), false);
eval.eval(data.getLabels(),output);
}
}finally {
dataIterator.reset();
}
ai.setValidScore(eval.accuracy());
}
/**
* 进化
*
* 每轮随机创建ai并放入环境中进行优胜劣汰
* @param n 进化次数
*/
public void evolution(int n){
BlockThreadPool blockThreadPool=new BlockThreadPool(2);
for (int i = 0; i < n; i++) {
blockThreadPool.execute(() -> contend(newLive()));
}
// for (int i = 0; i < n; i++) {
// contend(newLive());
// }
}
/**
* 竞争
* @param ai
*/
public void contend(MnistAgent ai){
test(ai);
quality(ai);
double score = ai.getScore();
if(score <= minScore){
UI.put("无法生存",String.format("name = %s, score = %s", ai.getName(),ai.getScore()));
return;
}
Map.Entry<Double, MnistAgent> lastEntry;
synchronized (lives) {
lives.put(score,ai);
if (lives.size() > max) {
MnistAgent lastAI = lives.remove(lives.firstKey());
UI.put("淘 汰 ",String.format("name = %s, score = %s", lastAI.getName(),lastAI.getScore()));
}
lastEntry = lives.lastEntry();
minScore = lives.firstKey();
}
Double lastScore = lastEntry.getKey();
if(lastScore > maxScore){
maxScore = lastScore;
MnistAgent agent = lastEntry.getValue();
validation(agent);
UI.put("max验证",String.format("score = %s,validScore = %s",lastScore,agent.getValidScore()));
try {
Warehouse.write(agent);
} catch (IOException ex) {
log.error("保存对象失败",ex);
}
}
}
ArrayList<Double> scoreList = new ArrayList<>(100);
ArrayList<Integer> avgList = new ArrayList<>();
private void quality(MnistAgent ai) {
synchronized (scoreList) {
scoreList.add(ai.getScore());
if (scoreList.size() >= 100) {
double avg = scoreList.stream().mapToDouble(e -> e)
.average().getAsDouble();
avgList.add((int) (avg * 1000));
StringBuffer buffer = new StringBuffer();
avgList.forEach(e -> buffer.append(e).append('\t'));
UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString()));
scoreList.clear();
}
}
}
/**
* 随机生成新智能体
*
* 完全随机产生母本
* 随机从比目标相同或更高评分中选择父本
*
* 基因进化在1%~10%之间进行,评分越高基于越稳定
*/
public MnistAgent newLive(){
double r = Math.random();
//基因突变率
double v = r / 11 + 0.01;
//母本
MnistAgent mAgent = getMother(r);
//父本
MnistAgent fAgent = getFather(r);
int i = (int) (Math.random() * 3);
INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna());
switch (i){
case 0:
newDNA = evolution.other(newDNA,v);
break;
case 1:
newDNA = evolution.mutation(newDNA,v,0.1);
break;
case 2:
newDNA = evolution.substitution(newDNA,v);
break;
}
return MnistAgent.create(newDNA);
}
/**
* 父本只选择比母本评分高的样本
* @param r
* @return
*/
private MnistAgent getFather(double r) {
r += (Math.random() * (1-r));
return getMother(r);
}
private MnistAgent getMother(double r) {
int index = (int) (r * max);
return getMnistAgent(index);
}
private MnistAgent getMnistAgent(int index) {
synchronized (lives) {
Iterator<Map.Entry<Double, MnistAgent>> it = lives.entrySet().iterator();
for (int i = 0; i < index; i++) {
it.next();
}
return it.next().getValue();
}
}
}
主函数
/**
 * Entry point: records the start time in the UI, creates an environment of
 * 128 agents and starts evolving it.
 */
@Slf4j
public class Program {
    /** Capacity of the environment created by {@link #main(String[])}. */
    private static final int POPULATION = 128;

    /**
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        // Date.toLocaleString() is deprecated; this is its documented replacement
        // and produces the same text.
        String startedAt = java.text.DateFormat.getDateTimeInstance().format(new Date());
        UI.put("开始时间", startedAt);
        MnistEnv env = new MnistEnv(POPULATION);
        env.evolution(Integer.MAX_VALUE);
    }
}
运行截图

看完上述内容,你们对dl4j如何使用遗传神经网络完成手写数字识别有进一步的了解吗?如果还想了解更多知识或者相关内容,请关注天达云行业资讯频道,感谢大家的支持。