Add cross-breeding functionality to FighterNN

This commit is contained in:
vandomej 2024-03-11 01:56:48 -07:00
parent 7ffd48f186
commit 69b026593e

View file

@ -4,6 +4,7 @@ use std::{fs, path::PathBuf};
use fann::{ActivationFunc, Fann};
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
use rand::prelude::*;
use rand::distributions::{Distribution, Uniform};
use serde::{Deserialize, Serialize};
use anyhow::Context;
use uuid::Uuid;
@ -135,6 +136,39 @@ impl GeneticNode for FighterNN {
let mut fann = Fann::from_file(&nn)
.with_context(|| format!("Failed to load nn"))?;
// Load another nn from the current generation and cross breed it with the current nn
let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)]));
let cross_fann = Fann::from_file(&cross_nn)
.with_context(|| format!("Failed to load cross nn"))?;
let mut connections = fann.get_connections(); // Vector of connections
let cross_connections = cross_fann.get_connections(); // Vector of connections
let segment_count: usize = 3; // For example, choose 3 segments to swap
let segment_distribution = Uniform::from(1..connections.len() / segment_count); // Ensure segments are not too small
let mut start_points = vec![];
for _ in 0..segment_count {
let start_point = segment_distribution.sample(&mut rand::thread_rng());
start_points.push(start_point);
}
start_points.sort_unstable(); // Ensure segments are in order
for (j, &start) in start_points.iter().enumerate() {
let end = if j < segment_count - 1 {
start_points[j + 1]
} else {
connections.len()
};
// Swap segments
for k in start..end {
connections[k] = cross_connections[k].clone();
}
}
fann.set_connections(&connections);
// For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight)
// And a 5% chance of a major mutation (a random number between -0.3 and 0.3 is added to the weight)
let mut connections = fann.get_connections(); // Vector of connections