NEAT(增强拓扑的进化神经网络)是一种基于遗传算法和神经网络的机器学习算法,不同于全连接神经网络,NEAT的神经网络是可以跨层相连的;它由最初始的输入层和输出层神经元连接,迭代繁衍和进化达到最终形态。
具体的介绍可以参考莫凡老师的教程。
最近因为搜索 raylib
的项目,发现了一个 简单的 NEAT
库simpleNEAT,觉得挺有趣,就拿来做个实验。
将 simpleNEAT
中的lib目录放到项目文件夹,新建 main.cpp
,写入以下代码,根据需要修改参数:
#include <iostream>
#include <valarray>
#include <vector>
#include <map>
#include "raylib.h"
#include "lib/SimpleNEAT.hpp"
int screenWidth = 1600; // 设置窗口宽度
int screenHeight = 900; // 设置窗口高度
float initRotation = 0.f; // 个体的初始化旋转角度
float objectSize = 15.f; // 个体的半径
float sensorMax = 300.f; // 距离传感器长度
float stepSize = 10.f; // 移动和旋转的步长
Vector2 initPosistion = {100.f, 100.f}; // 个体出生位置,初始在窗口从左到右从上往下 100,100
Vector2 targetPosition = {float(screenWidth) - 100.f, float(screenHeight) - 100.f}; // 目标位置,初始在窗口从右到左从下往上 100,100
bool isStart = false; // 是否开始训练
int frameLimit = 1500; // 每一代帧数限制
const int sensorCount = 11; // 具体传感器数量
int fps = 60; // 训练时的fps限制
int winnerCount = 0; // 到达目标的个体计数器
znn::SimpleNeat initNeat() { // 初始化神经网络和种群
znn::Opts.InputSize = sensorCount + 2; // 设置神经网络输入节点数量=传感器数量+个体与目标的相对距离+个体朝向与目标的相对角度
znn::Opts.OutputSize = 2; // 设置神经网络输出节点数量
znn::Opts.ActiveFunction = znn::Sigmoid; // 使用的激活函数
znn::Opts.IterationTimes = 0; // 迭代次数,0为不限制
znn::Opts.FitnessThreshold = 0.f; // 个体适应值阈值,0为不限制
znn::Opts.IterationCheckPoint = 1; // 保存最优神经网络的迭代次数,1为每代保存
znn::Opts.ThreadCount = 16; // 多线程数量,不设置则为设备默认数量
znn::Opts.MutateAddNeuronRate = 0.03f; // 添加新神经元的概率
znn::Opts.MutateAddConnectionRate = 0.99f; // 添加新连接的概率
znn::Opts.PopulationSize = 150; // 训练个体的数量
znn::Opts.NewSize = 0; // 每一代新生个体的数量
znn::Opts.ChampionToNewSize = 90; // 冠军被复制和交配的目标数量
znn::Opts.ChampionKeepSize = 15; // 冠军的数量
znn::Opts.KeepWorstSize = 0; // 保留最差个体的数量
znn::Opts.KeepComplexSize = 1; // 保留最复杂神经网络个体的数量,用于交配产生更复杂的神经网络
znn::Opts.WeightRange = 12; // 神经连接的权重范围,-12至12
znn::Opts.BiasRange = 6; // 神经元的偏置范围,-6至6
znn::Opts.MutateBiasRate = 1.f; // 神经元的偏置变异概率
znn::Opts.MutateWeightRate = 1.f; // 神经元连接的权重变异概率
znn::Opts.MutateBiasDirectOrNear = 0.5f; // 神经元偏置随机变异和就近变异的比例
znn::Opts.MutateWeightDirectOrNear = 0.5f; // 神经元连接权重随机变异和就近变异的比例
znn::Opts.Enable3dNN = false; // 是否显示3d实时可视化神经网络,不能启用,因为用的raylib库,训练环境也用的raylib库,不能开启多窗口
znn::Opts.CheckPointPath = "/tmp/raylib_path_findder"; // 自动保存神经网络和NEAT创新ID的路径
srandom((unsigned) clock()); // 初始化随机种子
znn::SimpleNeat sneat; // 创建NEAT对象
sneat.Start(); // 初始化NEAT神经网络和种群,如果自动保存路径存在,则导入,不存在则新建
return sneat;
}
Vector2 getXY(float angle, float distance) { // 通过角度和距离计算坐标
Vector2 result;
float radians = angle * PI / 180.f;
result.x = distance * std::cos(radians);
result.y = distance * std::sin(radians);
return result;
}
struct myWall { // 障碍物
std::vector<Vector2> path; // 存储坐标的容器
void add() { // 添加坐标
Vector2 mousePos = GetMousePosition();
if (path.empty() || (!path.empty() && path[path.size() - 1].x != mousePos.x && path[path.size() - 1].y != mousePos.y)) { // 判断是否和上一个坐标重复
path.push_back(mousePos);
}
}
void draw() { // 绘制障碍
if (!path.empty()) {
for (int i = 1; i < path.size(); ++i) {
DrawLineEx(path[i - 1], path[i], 1.f, WHITE);
}
}
}
};
std::vector<myWall> walls; // 存储多个障碍
myWall createScreenWall() { // 创建窗口四周的障碍
myWall screenWall;
screenWall.path.push_back({0, 0});
screenWall.path.push_back({0, float(screenHeight)});
screenWall.path.push_back({float(screenWidth), float(screenHeight)});
screenWall.path.push_back({float(screenWidth), 0});
screenWall.path.push_back({0, 0});
return screenWall;
}
bool getCollion(Vector2 center, Vector2 sensorTail, Vector2 &collisionPoint, float &sensorDistance) { // 根据两条线的起止坐标判断是否相交
std::map<float, Vector2> dis2Pos;
bool isCollision = false;
for (auto &w: walls) {
for (int i = 1; i < w.path.size(); ++i) {
Vector2 collisionPos;
if (CheckCollisionLines(center, sensorTail, w.path[i - 1], w.path[i], &collisionPos)) {
float distance = std::sqrt(std::pow(collisionPos.x - center.x, 2.f) + std::pow(collisionPos.y - center.y, 2.f));
dis2Pos[distance] = collisionPos;
isCollision = true;
}
}
}
collisionPoint = dis2Pos.begin()->second;
sensorDistance = dis2Pos.begin()->first;
return isCollision;
}
struct object { // 训练个体
float rotation = initRotation; // 初始旋转角度
Vector2 position = initPosistion; // 出生位置
bool isDead = false; // 是否死亡
float speed = 0.f; // 速度
std::vector<Vector2> path; // 走过的路径
std::vector<float> pathWidth; // 走过路径对应的路宽,根据速度判断
Vector2 sensorsPos[sensorCount]{}; // 距离传感器的相对末端坐标
Vector2 sensorCol[sensorCount]{}; // 距离传感器的相对探测到障碍的交叉坐标
float sensorDis[sensorCount]; // 距离传感器到障碍的长度
float score = 0.f; // 记录得分
float targetAngle = std::atan2((targetPosition.y - position.y), (targetPosition.x - position.x)) / PI * 180.f - float(int(rotation + 180.f) * 10 % 3600 - 1800) / 10.f; // 个体朝向和目标的相对角度
float targetDistance = std::sqrt(std::pow(position.x - targetPosition.x, 2.f) + std::pow(position.y - targetPosition.y, 2.f)); // 个体和目标的距离
float beginDistance = 0.f; // 个体和出生位置的距离
void setSensors() { // 放置具体传感器
for (int i = 0; i < sensorCount; ++i) {
Vector2 sensorTail = getXY(float(i) * 30.f + rotation - 150.f, sensorMax);
sensorsPos[i].x = sensorTail.x + position.x;
sensorsPos[i].y = sensorTail.y + position.y;
}
}
void getSensorsInfo() { // 更新传感器数据
for (int i = 0; i < sensorCount; ++i) {
if (!getCollion(position, sensorsPos[i], sensorCol[i], sensorDis[i])) {
sensorCol[i] = sensorsPos[i]; // 传感器与障碍物交叉的位置,没有交叉则为传感器目标位置
sensorDis[i] = sensorMax; // 传感器与障碍之间的距离,没有交叉则为预设最大值
}
if (sensorDis[i] < objectSize) {
isDead = true; // 如果传感器道障碍的距离小于个体半径,则判断为死亡
}
}
if (std::abs(position.x - targetPosition.x) < 10.f && std::abs(position.y - targetPosition.y) < 10.f) { // 判断个体是否到达目标坐标
isDead = true; // 到达坐标则死亡
++winnerCount; // 达到目标的个体计数器更新
score += 10000.f; // 到达目标加分
}
if (path.size() > 2) { // 判断个体是否走了老路,通过路径记录和碰撞判断
for (int i = 1; i < path.size() - 2; ++i) {
if (CheckCollisionLines(path[path.size() - 1], path[path.size() - 2], path[i], path[i - 1], nullptr)) {
isDead = true; // 如果和自己的运动轨迹碰撞则死亡
break;
}
}
}
targetDistance = std::sqrt(std::pow(position.x - targetPosition.x, 2.f) + std::pow(position.y - targetPosition.y, 2.f)); // 更新个体和目标的距离
if (targetDistance < 1.f) { // 为便于分数判定,需要将目标距离作为被除数
targetDistance = 1.f; // 如果距离小于1则为1,避免分数特别大
}
targetAngle = std::atan2((targetPosition.y - position.y), (targetPosition.x - position.x)) / PI * 180.f - float(int(rotation + 180.f) * 10 % 3600 - 1800) / 10.f; // 更新个体朝向与目标的相对角度
if (std::abs(targetAngle) < 10.f) { // 如果相对角度小于10,则加分
score += 1.f;
}
}
object() { // 创建个体时的初始化操作
setSensors(); // 更新传感器位置
getSensorsInfo(); // 更新传感器数据
}
void rotate(float angle) { // 个体旋转操作
rotation = float(int((rotation + angle * stepSize) * 10.f) % 3600) / 10.f;
setSensors();
getSensorsInfo();
}
void move(float distance) { // 个体移动操作
speed = distance * 30.f; // 更新速度,用于可视化尾喷长度
score += speed; // 更新分数,叠加速度
Vector2 movePos = getXY(rotation, distance * stepSize); // 获取需要移动的相对坐标
position.x += movePos.x; // 更新个体坐标x
position.y += movePos.y; // 更新个体坐标y
if (distance > 0.01f) { // 如果个体移动距离太小,则判断为死亡
path.push_back(position);
if (distance > 0.3f) { // 为防止可视化路径的时候宽度太小,则设置最小宽度0.3
pathWidth.push_back(distance);
} else {
pathWidth.push_back(0.3f);
}
} else {
isDead = true;
}
setSensors(); // 更新传感器位置
getSensorsInfo(); // 更新传感器数据
}
void draw() { // 绘制个体
if (path.size() > 1) { // 绘制个体移动路径,线条需要两个坐标
for (int i = 1; i < path.size(); ++i) {
DrawLineEx(path[i - 1], path[i], 1., ColorAlpha(GREEN, pathWidth[i] * 0.3f)); // 路径宽度改为路径透明度由宽度判定
}
}
if (!isDead) { // 如果个体存活则绘制传感器
for (auto sc: sensorCol) {
DrawLineEx(position, sc, 1, ColorAlpha(BLUE, 0.5f));
DrawCircleV(sc, objectSize / 10.f * 3.f, ColorAlpha(RED, 0.5f));
}
}
auto objColor = WHITE; // 如果个体存活,则本体为白色,死亡为红色
if (isDead) {
objColor = RED;
}
DrawPolyLinesEx(position, 3, objectSize, rotation + 30.f, objectSize / 5.f, objColor); // 绘制个体本体,三角形
Vector2 headPos = getXY(rotation, objectSize); // 获取头部坐标用于给个体头部画一根线
DrawLineEx(position, {headPos.x + position.x, headPos.y + position.y}, 1, objColor); // 给个体头部画一根线分辨方向
if (!isDead) { // 如果存活则绘制尾喷
Vector2 tailPos = getXY(rotation + 180.f, speed); // 获取尾喷相对位置用于绘制,长度由速度决定
DrawLineEx(position, {tailPos.x + position.x, tailPos.y + position.y}, objectSize / 10.f * 3.f, YELLOW);
}
}
};
void keyControl() { // 用户输入控制
if (IsMouseButtonDown(0)) { // 鼠标左键绘制障碍
walls[walls.size() - 1].add();
}
if (IsMouseButtonReleased(0)) { // 鼠标左键抬起用于添加新的障碍物列表,避免只绘制一条线
walls.push_back(myWall{});
}
if (IsMouseButtonDown(1)) { // 鼠标右键清除障碍
walls.clear();
walls.push_back(createScreenWall()); // 清除障碍以后先添加窗口四周的障碍
walls.push_back(myWall{});
}
if (IsKeyPressed('B')) { // B键用于设置个体出生位置
initPosistion = GetMousePosition();
initRotation = std::atan2(targetPosition.y - initPosistion.y, targetPosition.x - initPosistion.x) / PI * 180.f; // 设置完出生位置后更新个体初始朝向
}
if (IsKeyPressed('T')) { // T键用于设置目标位置
targetPosition = GetMousePosition();
initRotation = std::atan2(targetPosition.y - initPosistion.y, targetPosition.x - initPosistion.x) / PI * 180.f; // 设置完目标位置后更新个体初始朝向
}
if (IsKeyPressed(KEY_SPACE)) { // 空格键用于控制是否开始训练
if (isStart) {
isStart = false;
SetTargetFPS(30); // 没开始训练时帧率限制为30
} else {
isStart = true;
SetTargetFPS(fps);
}
}
}
bool isBreakFunc() { // 用于NEAT训练循环中判断是否中断
return !isStart;
};
int main() {
SetConfigFlags(FLAG_MSAA_4X_HINT); // 设置抗锯齿
InitWindow(screenWidth, screenHeight, "寻路实验"); // 初始化raylib窗口
SetTargetFPS(30); // 设置帧率
walls.push_back(createScreenWall()); // 创建窗口四周障碍
walls.push_back(myWall{});
initRotation = std::atan2(targetPosition.y - initPosistion.y, targetPosition.x - initPosistion.x) / PI * 180.f; // 初始化个体朝向
auto sneat = initNeat(); // 初始化神经网络和种群
int stepCount = 0; // 训练迭代计数器
std::function<std::map<znn::NetworkGenome *, float>()> interactiveFunc = [&]() {
++stepCount;
std::vector<object> objs; // 新建个体集容器
for (int i = 0; i < znn::Opts.PopulationSize; ++i) { // 塞满个体
objs.emplace_back();
}
std::map<znn::NetworkGenome *, float> populationFitness; // 创建神经网络地址对应的适应度map
for (int step = 0; step < frameLimit; ++step) { // 每一代训练,基于帧数限制的循环
keyControl(); // 用户输入控制
if (stepCount % znn::Opts.IterationCheckPoint == 0) { // 如果达到自动保存次数,则可视化显示
BeginDrawing(); // 开始绘制
ClearBackground(BLACK); // 清空背景
DrawCircleV(initPosistion, 10.f, GRAY); // 绘制出生点
DrawCircleV(targetPosition, 10.f, RED); // 绘制目标点
for (auto &w: walls) { // 绘制障碍
w.draw();
}
}
int deadCount = 0; // 死亡个体计数器
for (int i = 0; i < znn::Opts.PopulationSize; ++i) { // 每个训练个体的神经网络判断输入和输出
if (!objs[i].isDead) { // 如果个体存活则继续
if ((step > 100 && objs[i].path.size() < 50) || (objs[i].path.size() > 100 && std::abs(objs[i].path[objs[i].path.size() - 1].y - objs[i].path[objs[i].path.size() - 100].y) < 3 &&
std::abs(objs[i].path[objs[i].path.size() - 1].x - objs[i].path[objs[i].path.size() - 100].x) < 3) || objs[i].position.x < 0 ||
objs[i].position.x > float(screenWidth) || objs[i].position.y < 0 || objs[i].position.y > float(screenHeight)) { // 简单判断死亡
objs[i].isDead = true;
} else {
std::vector<float> perInputs; // 准备神经网络输入数据
for (auto &sd: objs[i].sensorDis) { // 输入数据放入传感器到障碍物的距离
perInputs.push_back(1.f - ((sd-objectSize) / (sensorMax-objectSize))); // 距离除以传感器最大值,离得越近数值越大,同时排除个体自身尺寸,使得输入值在0-1范围
}
perInputs.push_back(objs[i].targetAngle); // 输入数据放入个体朝向到目标的相对角度
perInputs.push_back(objs[i].targetDistance); // 输入数据放入个体和目标的距离
std::vector<float> nextMove = sneat.population.generation.neuralNetwork.FeedForwardPredict(&sneat.population.NeuralNetworks[i], perInputs); // 根据输入数据计算每个神经网络的输出
objs[i].rotate((nextMove[0] - 0.5f) * 2.f); // 执行输出结果的旋转操作
objs[i].move(nextMove[1]); // 执行输出结果的移动操作
if (perInputs[(sensorCount - 1) / 2] > 0.f && nextMove[1] < 1.f) { // 简单判断个体前方有障碍则减速的加分
objs[i].score += 1;
}
}
} else { // 如果个体死亡则更新死亡计数器
++deadCount;
}
if (stepCount % znn::Opts.IterationCheckPoint == 0) { // 如果达到自动保存次数,则可视化显示
objs[i].draw(); // 绘制个体
}
}
if (stepCount % znn::Opts.IterationCheckPoint == 0) { // 如果达到自动保存次数,则可视化显示
DrawFPS(10, 10); // 绘制fps
EndDrawing(); // 单帧绘制完毕
}
if (deadCount == znn::Opts.PopulationSize) { // 如果全部个体死亡则终止本代
break;
}
}
for (int i = 0; i < znn::Opts.PopulationSize; ++i) { // 更新每个个体的适应度(得分)
objs[i].beginDistance = std::sqrt(std::pow(objs[i].position.x - initPosistion.x, 2.f) + std::pow(objs[i].position.y - initPosistion.y, 2.f));
populationFitness[&sneat.population.NeuralNetworks[i]] = (objs[i].beginDistance + objs[i].score * 5.f) / objs[i].targetDistance;
}
std::cout << "Winner: " << winnerCount << "\n";
winnerCount = 0; // 重置达到目标的个体计数
return populationFitness; // NEAT训练函数的格式
};
while (!WindowShouldClose()) { // 判断窗口是否关闭
while (!isStart) { // 如果没开始训练,则不更新和绘制个体
keyControl();
BeginDrawing();
ClearBackground(BLACK);
DrawCircleV(initPosistion, 10.f, GRAY);
DrawCircleV(targetPosition, 10.f, RED);
for (auto &w: walls) {
w.draw();
}
DrawFPS(10, 10);
EndDrawing();
}
stepCount = 0; // 重置训练迭代计数器
auto best = sneat.TrainByInteractive(interactiveFunc, isBreakFunc); // NEAT训练函数,开始训练
printf("Pause\n"); // 如果训练循环终止,则重新开始训练
}
CloseWindow(); // 关闭窗口
return 0;
}
g++ -lraylib -std=c++17 -O2 main.cpp
执行编译后生成的 a.out
:
./a.out
训练开始以后可以实时用添加障碍