dootcamp #1
1 changed files with 34 additions and 0 deletions
|
@ -4,6 +4,7 @@ use std::{fs, path::PathBuf};
|
||||||
use fann::{ActivationFunc, Fann};
|
use fann::{ActivationFunc, Fann};
|
||||||
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
use rand::distributions::{Distribution, Uniform};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
@ -135,6 +136,39 @@ impl GeneticNode for FighterNN {
|
||||||
let mut fann = Fann::from_file(&nn)
|
let mut fann = Fann::from_file(&nn)
|
||||||
.with_context(|| format!("Failed to load nn"))?;
|
.with_context(|| format!("Failed to load nn"))?;
|
||||||
|
|
||||||
|
// Load another nn from the current generation and cross breed it with the current nn
|
||||||
|
let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)]));
|
||||||
|
let cross_fann = Fann::from_file(&cross_nn)
|
||||||
|
.with_context(|| format!("Failed to load cross nn"))?;
|
||||||
|
|
||||||
|
let mut connections = fann.get_connections(); // Vector of connections
|
||||||
|
let cross_connections = cross_fann.get_connections(); // Vector of connections
|
||||||
|
let segment_count: usize = 3; // For example, choose 3 segments to swap
|
||||||
|
let segment_distribution = Uniform::from(1..connections.len() / segment_count); // Ensure segments are not too small
|
||||||
|
|
||||||
|
let mut start_points = vec![];
|
||||||
|
|
||||||
|
for _ in 0..segment_count {
|
||||||
|
let start_point = segment_distribution.sample(&mut rand::thread_rng());
|
||||||
|
start_points.push(start_point);
|
||||||
|
}
|
||||||
|
start_points.sort_unstable(); // Ensure segments are in order
|
||||||
|
|
||||||
|
for (j, &start) in start_points.iter().enumerate() {
|
||||||
|
let end = if j < segment_count - 1 {
|
||||||
|
start_points[j + 1]
|
||||||
|
} else {
|
||||||
|
connections.len()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Swap segments
|
||||||
|
for k in start..end {
|
||||||
|
connections[k] = cross_connections[k].clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fann.set_connections(&connections);
|
||||||
|
|
||||||
// For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight)
|
// For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight)
|
||||||
// And a 5% chance of a major mutation (a random number between -0.3 and 0.3 is added to the weight)
|
// And a 5% chance of a major mutation (a random number between -0.3 and 0.3 is added to the weight)
|
||||||
let mut connections = fann.get_connections(); // Vector of connections
|
let mut connections = fann.get_connections(); // Vector of connections
|
||||||
|
|
Loading…
Add table
Reference in a new issue