Adding context information to genetic node

This commit is contained in:
vandomej 2024-03-11 01:23:43 -07:00
parent 774a0df5d7
commit 7ffd48f186
4 changed files with 105 additions and 61 deletions

View file

@ -2,10 +2,11 @@ extern crate fann;
use std::{fs, path::PathBuf};
use fann::{ActivationFunc, Fann};
use gemla::{core::genetic_node::GeneticNode, error::Error};
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use anyhow::Context;
use uuid::Uuid;
use std::collections::HashMap;
const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
@ -22,13 +23,13 @@ const SURVIVAL_RATE: f32 = 0.5;
// there is no training happening to the neural networks
// the neural networks are only used to simulate the nn's and to save and read the nn's from files
// Filenames are stored in the format of "{fighter_id}_fighter_nn_{generation}.net".
// The folder name is stored in the format of "fighter_nn_xxxxxx" where xxxxxx is an incrementing number, checking for the highest number and incrementing it by 1
// The main folder contains a subfolder for each generation, containing a population of 10 nn's
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FighterNN {
pub id: u64,
pub id: Uuid,
pub folder: PathBuf,
pub population_size: usize,
pub generation: u64,
// A map of each nn identifier in a generation and their physics score
pub scores: Vec<HashMap<u64, f32>>,
@ -36,16 +37,10 @@ pub struct FighterNN {
impl GeneticNode for FighterNN {
// Check for the highest number of the folder name and increment it by 1
fn initialize() -> Result<Box<Self>, Error> {
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
let base_path = PathBuf::from(BASE_DIR);
let mut highest = 0;
let mut folder = base_path.join(format!("fighter_nn_{:06}", highest));
while folder.exists() {
highest += 1;
folder = base_path.join(format!("fighter_nn_{:06}", highest));
}
let mut folder = base_path.join(format!("fighter_nn_{:06}", context.id));
fs::create_dir(&folder)?;
//Create a new directory for the first generation
@ -55,7 +50,7 @@ impl GeneticNode for FighterNN {
// Create the first generation in this folder
for i in 0..POPULATION {
// Filenames are stored in the format of "xxxxxx_fighter_nn_0.net", "xxxxxx_fighter_nn_1.net", etc. Where xxxxxx is the folder name
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", highest, i));
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
let mut fann = Fann::new(NEURAL_NETWORK_SHAPE)
.with_context(|| format!("Failed to create nn"))?;
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
@ -65,16 +60,17 @@ impl GeneticNode for FighterNN {
}
Ok(Box::new(FighterNN {
id: highest,
id: context.id,
folder,
population_size: POPULATION,
generation: 0,
scores: vec![HashMap::new()],
}))
}
fn simulate(&mut self) -> Result<(), Error> {
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
// For each nn in the current generation:
for i in 0..POPULATION {
for i in 0..self.population_size {
// load the nn
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i));
let fann = Fann::from_file(&nn)
@ -85,7 +81,7 @@ impl GeneticNode for FighterNN {
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation
for _ in 0..SIMULATION_ROUNDS {
let random_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, thread_rng().gen_range(0..POPULATION)));
let random_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, thread_rng().gen_range(0..self.population_size)));
let random_fann = Fann::from_file(&random_nn)
.with_context(|| format!("Failed to load random nn"))?;
@ -112,8 +108,8 @@ impl GeneticNode for FighterNN {
}
fn mutate(&mut self) -> Result<(), Error> {
let survivor_count = (POPULATION as f32 * SURVIVAL_RATE) as usize;
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
// Create the new generation folder
let new_gen_folder = self.folder.join(format!("{}", self.generation + 1));
@ -163,17 +159,11 @@ impl GeneticNode for FighterNN {
Ok(())
}
fn merge(left: &FighterNN, right: &FighterNN) -> Result<Box<FighterNN>, Error> {
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> {
let base_path = PathBuf::from(BASE_DIR);
// Find next highest
let mut highest = 0;
let mut folder = base_path.join(format!("fighter_nn_{:06}", highest));
while folder.exists() {
highest += 1;
folder = base_path.join(format!("fighter_nn_{:06}", highest));
}
let folder = base_path.join(format!("fighter_nn_{:06}", id));
fs::create_dir(&folder)?;
//Create a new directory for the first generation
@ -183,10 +173,10 @@ impl GeneticNode for FighterNN {
// Take the 5 nn's with the highest scores from the left nn's and save them to the new fighter folder
let mut sorted_scores: Vec<_> = left.scores[left.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
let mut remaining = sorted_scores[(POPULATION / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
for i in 0..(POPULATION / 2) {
let mut remaining = sorted_scores[(left.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
for i in 0..(left.population_size / 2) {
let nn = left.folder.join(format!("{}", left.generation)).join(format!("{:06}_fighter_nn_{}.net", left.id, remaining.pop().unwrap()));
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", highest, i));
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", id, i));
trace!("From: {:?}, To: {:?}", &nn, &new_nn);
fs::copy(&nn, &new_nn)
.with_context(|| format!("Failed to copy left nn"))?;
@ -195,19 +185,20 @@ impl GeneticNode for FighterNN {
// Take the 5 nn's with the highest scores from the right nn's and save them to the new fighter folder
sorted_scores = right.scores[right.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
remaining = sorted_scores[(POPULATION / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
for i in (POPULATION / 2)..POPULATION {
remaining = sorted_scores[(right.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
for i in (right.population_size / 2)..right.population_size {
let nn = right.folder.join(format!("{}", right.generation)).join(format!("{:06}_fighter_nn_{}.net", right.id, remaining.pop().unwrap()));
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", highest, i));
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", id, i));
trace!("From: {:?}, To: {:?}", &nn, &new_nn);
fs::copy(&nn, &new_nn)
.with_context(|| format!("Failed to copy right nn"))?;
}
Ok(Box::new(FighterNN {
id: highest,
id: *id,
folder,
generation: 0,
population_size: POPULATION,
scores: vec![HashMap::new()],
}))
}

View file

@ -1,6 +1,7 @@
use gemla::{core::genetic_node::GeneticNode, error::Error};
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
const POPULATION_SIZE: u64 = 5;
const POPULATION_REDUCTION_SIZE: u64 = 3;
@ -11,7 +12,7 @@ pub struct TestState {
}
impl GeneticNode for TestState {
fn initialize() -> Result<Box<Self>, Error> {
fn initialize(_context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
let mut population: Vec<i64> = vec![];
for _ in 0..POPULATION_SIZE {
@ -21,7 +22,7 @@ impl GeneticNode for TestState {
Ok(Box::new(TestState { population }))
}
fn simulate(&mut self) -> Result<(), Error> {
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
let mut rng = thread_rng();
self.population = self
@ -33,7 +34,7 @@ impl GeneticNode for TestState {
Ok(())
}
fn mutate(&mut self) -> Result<(), Error> {
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
let mut rng = thread_rng();
let mut v = self.population.clone();
@ -71,7 +72,7 @@ impl GeneticNode for TestState {
Ok(())
}
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
fn merge(left: &TestState, right: &TestState, id: &Uuid) -> Result<Box<TestState>, Error> {
let mut v = left.population.clone();
v.append(&mut right.population.clone());
@ -82,7 +83,11 @@ impl GeneticNode for TestState {
let mut result = TestState { population: v };
result.mutate()?;
result.mutate(&GeneticNodeContext {
id: id.clone(),
generation: 0,
max_generations: 0,
})?;
Ok(Box::new(result))
}
@ -95,7 +100,13 @@ mod tests {
#[test]
fn test_initialize() {
let state = TestState::initialize().unwrap();
let state = TestState::initialize(
&GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
@ -108,14 +119,32 @@ mod tests {
let original_population = state.population.clone();
state.simulate().unwrap();
state.simulate(
&GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
.all(|(&a, &b)| b >= a - 1 && b <= a + 2));
state.simulate().unwrap();
state.simulate().unwrap();
state.simulate(
&GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
state.simulate(
&GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
@ -128,7 +157,13 @@ mod tests {
population: vec![4, 3, 3],
};
state.mutate().unwrap();
state.mutate(
&GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
@ -143,7 +178,7 @@ mod tests {
population: vec![0, 1, 3, 7],
};
let merged_state = TestState::merge(&state1, &state2).unwrap();
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4()).unwrap();
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
assert!(merged_state.population.iter().any(|&x| x == 7));

View file

@ -24,6 +24,12 @@ pub enum GeneticState {
Finish,
}
pub struct GeneticNodeContext {
pub generation: u64,
pub max_generations: u64,
pub id: Uuid,
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
///
/// [`Bracket`]: crate::bracket::Bracket
@ -32,17 +38,17 @@ pub trait GeneticNode {
///
/// # Examples
/// TODO
fn initialize() -> Result<Box<Self>, Error>;
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error>;
fn simulate(&mut self) -> Result<(), Error>;
fn simulate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
///
/// # Examples
/// TODO
fn mutate(&mut self) -> Result<(), Error>;
fn mutate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>;
fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>;
}
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
@ -101,18 +107,28 @@ where
self.max_generations
}
pub fn generation(&self) -> u64 {
self.generation
}
pub fn state(&self) -> GeneticState {
self.state
}
pub fn process_node(&mut self) -> Result<GeneticState, Error> {
let context = GeneticNodeContext {
generation: self.generation,
max_generations: self.max_generations,
id: self.id,
};
match (self.state, &mut self.node) {
(GeneticState::Initialize, _) => {
self.node = Some(*T::initialize()?);
self.node = Some(*T::initialize(&context)?);
self.state = GeneticState::Simulate;
}
(GeneticState::Simulate, Some(n)) => {
n.simulate()
n.simulate(&context)
.with_context(|| format!("Error simulating node: {:?}", self))?;
self.state = if self.generation >= self.max_generations {
@ -122,7 +138,7 @@ where
};
}
(GeneticState::Mutate, Some(n)) => {
n.mutate()
n.mutate(&context)
.with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1;
@ -148,20 +164,20 @@ mod tests {
}
impl GeneticNode for TestState {
fn simulate(&mut self) -> Result<(), Error> {
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
fn mutate(&mut self) -> Result<(), Error> {
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
Ok(())
}
fn initialize() -> Result<Box<TestState>, Error> {
fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
fn merge(_l: &TestState, _r: &TestState) -> Result<Box<TestState>, Error> {
fn merge(_l: &TestState, _r: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
Err(Error::Other(anyhow!("Unable to merge")))
}
}

View file

@ -195,7 +195,7 @@ where
{
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
let merged_node = GeneticNode::merge(left_node, right_node)?;
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id())?;
tree.val = GeneticNodeWrapper::from(
*merged_node,
tree.val.max_generations(),
@ -337,6 +337,8 @@ mod tests {
use std::path::PathBuf;
use std::fs;
use self::genetic_node::GeneticNodeContext;
struct CleanUp {
path: PathBuf,
}
@ -367,20 +369,20 @@ mod tests {
}
impl genetic_node::GeneticNode for TestState {
fn simulate(&mut self) -> Result<(), Error> {
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
fn mutate(&mut self) -> Result<(), Error> {
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
Ok(())
}
fn initialize() -> Result<Box<TestState>, Error> {
fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
fn merge(left: &TestState, right: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.score > right.score {
left.clone()
} else {