Add cross-breeding functionality to FighterNN
This commit is contained in:
parent
7ffd48f186
commit
69b026593e
1 changed files with 34 additions and 0 deletions
|
@ -4,6 +4,7 @@ use std::{fs, path::PathBuf};
|
|||
use fann::{ActivationFunc, Fann};
|
||||
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
||||
use rand::prelude::*;
|
||||
use rand::distributions::{Distribution, Uniform};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use anyhow::Context;
|
||||
use uuid::Uuid;
|
||||
|
@ -135,6 +136,39 @@ impl GeneticNode for FighterNN {
|
|||
let mut fann = Fann::from_file(&nn)
|
||||
.with_context(|| format!("Failed to load nn"))?;
|
||||
|
||||
// Load another nn from the current generation and cross breed it with the current nn
|
||||
let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)]));
|
||||
let cross_fann = Fann::from_file(&cross_nn)
|
||||
.with_context(|| format!("Failed to load cross nn"))?;
|
||||
|
||||
let mut connections = fann.get_connections(); // Vector of connections
|
||||
let cross_connections = cross_fann.get_connections(); // Vector of connections
|
||||
let segment_count: usize = 3; // For example, choose 3 segments to swap
|
||||
let segment_distribution = Uniform::from(1..connections.len() / segment_count); // Ensure segments are not too small
|
||||
|
||||
let mut start_points = vec![];
|
||||
|
||||
for _ in 0..segment_count {
|
||||
let start_point = segment_distribution.sample(&mut rand::thread_rng());
|
||||
start_points.push(start_point);
|
||||
}
|
||||
start_points.sort_unstable(); // Ensure segments are in order
|
||||
|
||||
for (j, &start) in start_points.iter().enumerate() {
|
||||
let end = if j < segment_count - 1 {
|
||||
start_points[j + 1]
|
||||
} else {
|
||||
connections.len()
|
||||
};
|
||||
|
||||
// Swap segments
|
||||
for k in start..end {
|
||||
connections[k] = cross_connections[k].clone();
|
||||
}
|
||||
}
|
||||
|
||||
fann.set_connections(&connections);
|
||||
|
||||
// For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight)
|
||||
// And a 5% chance of a major mutation (a random number between -0.3 and 0.3 is added to the weight)
|
||||
let mut connections = fann.get_connections(); // Vector of connections
|
||||
|
|
Loading…
Add table
Reference in a new issue