From 69b026593ed5d2f322ccc68b97eb03c74696a50f Mon Sep 17 00:00:00 2001 From: vandomej Date: Mon, 11 Mar 2024 01:56:48 -0700 Subject: [PATCH] Add cross-breeding functionality to FighterNN --- gemla/src/bin/fighter_nn/mod.rs | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs index 3504e82..cabbb29 100644 --- a/gemla/src/bin/fighter_nn/mod.rs +++ b/gemla/src/bin/fighter_nn/mod.rs @@ -4,6 +4,7 @@ use std::{fs, path::PathBuf}; use fann::{ActivationFunc, Fann}; use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}; use rand::prelude::*; +use rand::distributions::{Distribution, Uniform}; use serde::{Deserialize, Serialize}; use anyhow::Context; use uuid::Uuid; @@ -135,6 +136,39 @@ impl GeneticNode for FighterNN { let mut fann = Fann::from_file(&nn) .with_context(|| format!("Failed to load nn"))?; + // Load another nn from the current generation and cross breed it with the current nn + let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)])); + let cross_fann = Fann::from_file(&cross_nn) + .with_context(|| format!("Failed to load cross nn"))?; + + let mut connections = fann.get_connections(); // Vector of connections + let cross_connections = cross_fann.get_connections(); // Vector of connections + let segment_count: usize = 3; // For example, choose 3 segments to swap + let segment_distribution = Uniform::from(1..connections.len() / segment_count); // Ensure segments are not too small + + let mut start_points = vec![]; + + for _ in 0..segment_count { + let start_point = segment_distribution.sample(&mut rand::thread_rng()); + start_points.push(start_point); + } + start_points.sort_unstable(); // Ensure segments are in order + + for (j, &start) in start_points.iter().enumerate() { + let end = if j < segment_count - 1 { + start_points[j + 1] + } else { + connections.len() + }; + + // Swap segments + for k in start..end { + connections[k] = cross_connections[k].clone(); + } + } + + fann.set_connections(&connections); + // For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight) // And a 5% chance of a major mutation (a random number between -0.3 and 0.3 is added to the weight) let mut connections = fann.get_connections(); // Vector of connections