Adding context information to genetic node
This commit is contained in:
parent
774a0df5d7
commit
7ffd48f186
4 changed files with 105 additions and 61 deletions
|
@ -2,10 +2,11 @@ extern crate fann;
|
|||
|
||||
use std::{fs, path::PathBuf};
|
||||
use fann::{ActivationFunc, Fann};
|
||||
use gemla::{core::genetic_node::GeneticNode, error::Error};
|
||||
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use anyhow::Context;
|
||||
use uuid::Uuid;
|
||||
use std::collections::HashMap;
|
||||
|
||||
const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
|
||||
|
@ -22,13 +23,13 @@ const SURVIVAL_RATE: f32 = 0.5;
|
|||
// there is no training happening to the neural networks
|
||||
// the neural networks are only used to simulate the nn's and to save and read the nn's from files
|
||||
// Filenames are stored in the format of "{fighter_id}_fighter_nn_{generation}.net".
|
||||
// The folder name is stored in the format of "fighter_nn_xxxxxx" where xxxxxx is an incrementing number, checking for the highest number and incrementing it by 1
|
||||
// The main folder contains a subfolder for each generation, containing a population of 10 nn's
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct FighterNN {
|
||||
pub id: u64,
|
||||
pub id: Uuid,
|
||||
pub folder: PathBuf,
|
||||
pub population_size: usize,
|
||||
pub generation: u64,
|
||||
// A map of each nn identifier in a generation and their physics score
|
||||
pub scores: Vec<HashMap<u64, f32>>,
|
||||
|
@ -36,16 +37,10 @@ pub struct FighterNN {
|
|||
|
||||
impl GeneticNode for FighterNN {
|
||||
// Check for the highest number of the folder name and increment it by 1
|
||||
fn initialize() -> Result<Box<Self>, Error> {
|
||||
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
|
||||
let base_path = PathBuf::from(BASE_DIR);
|
||||
|
||||
let mut highest = 0;
|
||||
let mut folder = base_path.join(format!("fighter_nn_{:06}", highest));
|
||||
while folder.exists() {
|
||||
highest += 1;
|
||||
folder = base_path.join(format!("fighter_nn_{:06}", highest));
|
||||
}
|
||||
|
||||
let mut folder = base_path.join(format!("fighter_nn_{:06}", context.id));
|
||||
fs::create_dir(&folder)?;
|
||||
|
||||
//Create a new directory for the first generation
|
||||
|
@ -55,7 +50,7 @@ impl GeneticNode for FighterNN {
|
|||
// Create the first generation in this folder
|
||||
for i in 0..POPULATION {
|
||||
// Filenames are stored in the format of "xxxxxx_fighter_nn_0.net", "xxxxxx_fighter_nn_1.net", etc. Where xxxxxx is the folder name
|
||||
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", highest, i));
|
||||
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
|
||||
let mut fann = Fann::new(NEURAL_NETWORK_SHAPE)
|
||||
.with_context(|| format!("Failed to create nn"))?;
|
||||
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
|
||||
|
@ -65,16 +60,17 @@ impl GeneticNode for FighterNN {
|
|||
}
|
||||
|
||||
Ok(Box::new(FighterNN {
|
||||
id: highest,
|
||||
id: context.id,
|
||||
folder,
|
||||
population_size: POPULATION,
|
||||
generation: 0,
|
||||
scores: vec![HashMap::new()],
|
||||
}))
|
||||
}
|
||||
|
||||
fn simulate(&mut self) -> Result<(), Error> {
|
||||
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
// For each nn in the current generation:
|
||||
for i in 0..POPULATION {
|
||||
for i in 0..self.population_size {
|
||||
// load the nn
|
||||
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i));
|
||||
let fann = Fann::from_file(&nn)
|
||||
|
@ -85,7 +81,7 @@ impl GeneticNode for FighterNN {
|
|||
|
||||
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation
|
||||
for _ in 0..SIMULATION_ROUNDS {
|
||||
let random_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, thread_rng().gen_range(0..POPULATION)));
|
||||
let random_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, thread_rng().gen_range(0..self.population_size)));
|
||||
let random_fann = Fann::from_file(&random_nn)
|
||||
.with_context(|| format!("Failed to load random nn"))?;
|
||||
|
||||
|
@ -112,8 +108,8 @@ impl GeneticNode for FighterNN {
|
|||
}
|
||||
|
||||
|
||||
fn mutate(&mut self) -> Result<(), Error> {
|
||||
let survivor_count = (POPULATION as f32 * SURVIVAL_RATE) as usize;
|
||||
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
|
||||
|
||||
// Create the new generation folder
|
||||
let new_gen_folder = self.folder.join(format!("{}", self.generation + 1));
|
||||
|
@ -163,17 +159,11 @@ impl GeneticNode for FighterNN {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn merge(left: &FighterNN, right: &FighterNN) -> Result<Box<FighterNN>, Error> {
|
||||
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> {
|
||||
let base_path = PathBuf::from(BASE_DIR);
|
||||
|
||||
// Find next highest
|
||||
let mut highest = 0;
|
||||
let mut folder = base_path.join(format!("fighter_nn_{:06}", highest));
|
||||
while folder.exists() {
|
||||
highest += 1;
|
||||
folder = base_path.join(format!("fighter_nn_{:06}", highest));
|
||||
}
|
||||
|
||||
let folder = base_path.join(format!("fighter_nn_{:06}", id));
|
||||
fs::create_dir(&folder)?;
|
||||
|
||||
//Create a new directory for the first generation
|
||||
|
@ -183,10 +173,10 @@ impl GeneticNode for FighterNN {
|
|||
// Take the 5 nn's with the highest scores from the left nn's and save them to the new fighter folder
|
||||
let mut sorted_scores: Vec<_> = left.scores[left.generation as usize].iter().collect();
|
||||
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
||||
let mut remaining = sorted_scores[(POPULATION / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
||||
for i in 0..(POPULATION / 2) {
|
||||
let mut remaining = sorted_scores[(left.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
||||
for i in 0..(left.population_size / 2) {
|
||||
let nn = left.folder.join(format!("{}", left.generation)).join(format!("{:06}_fighter_nn_{}.net", left.id, remaining.pop().unwrap()));
|
||||
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", highest, i));
|
||||
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", id, i));
|
||||
trace!("From: {:?}, To: {:?}", &nn, &new_nn);
|
||||
fs::copy(&nn, &new_nn)
|
||||
.with_context(|| format!("Failed to copy left nn"))?;
|
||||
|
@ -195,19 +185,20 @@ impl GeneticNode for FighterNN {
|
|||
// Take the 5 nn's with the highest scores from the right nn's and save them to the new fighter folder
|
||||
sorted_scores = right.scores[right.generation as usize].iter().collect();
|
||||
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
||||
remaining = sorted_scores[(POPULATION / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
||||
for i in (POPULATION / 2)..POPULATION {
|
||||
remaining = sorted_scores[(right.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
||||
for i in (right.population_size / 2)..right.population_size {
|
||||
let nn = right.folder.join(format!("{}", right.generation)).join(format!("{:06}_fighter_nn_{}.net", right.id, remaining.pop().unwrap()));
|
||||
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", highest, i));
|
||||
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", id, i));
|
||||
trace!("From: {:?}, To: {:?}", &nn, &new_nn);
|
||||
fs::copy(&nn, &new_nn)
|
||||
.with_context(|| format!("Failed to copy right nn"))?;
|
||||
}
|
||||
|
||||
Ok(Box::new(FighterNN {
|
||||
id: highest,
|
||||
id: *id,
|
||||
folder,
|
||||
generation: 0,
|
||||
population_size: POPULATION,
|
||||
scores: vec![HashMap::new()],
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use gemla::{core::genetic_node::GeneticNode, error::Error};
|
||||
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
const POPULATION_SIZE: u64 = 5;
|
||||
const POPULATION_REDUCTION_SIZE: u64 = 3;
|
||||
|
@ -11,7 +12,7 @@ pub struct TestState {
|
|||
}
|
||||
|
||||
impl GeneticNode for TestState {
|
||||
fn initialize() -> Result<Box<Self>, Error> {
|
||||
fn initialize(_context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
|
||||
let mut population: Vec<i64> = vec![];
|
||||
|
||||
for _ in 0..POPULATION_SIZE {
|
||||
|
@ -21,7 +22,7 @@ impl GeneticNode for TestState {
|
|||
Ok(Box::new(TestState { population }))
|
||||
}
|
||||
|
||||
fn simulate(&mut self) -> Result<(), Error> {
|
||||
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
self.population = self
|
||||
|
@ -33,7 +34,7 @@ impl GeneticNode for TestState {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self) -> Result<(), Error> {
|
||||
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let mut v = self.population.clone();
|
||||
|
@ -71,7 +72,7 @@ impl GeneticNode for TestState {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
||||
fn merge(left: &TestState, right: &TestState, id: &Uuid) -> Result<Box<TestState>, Error> {
|
||||
let mut v = left.population.clone();
|
||||
v.append(&mut right.population.clone());
|
||||
|
||||
|
@ -82,7 +83,11 @@ impl GeneticNode for TestState {
|
|||
|
||||
let mut result = TestState { population: v };
|
||||
|
||||
result.mutate()?;
|
||||
result.mutate(&GeneticNodeContext {
|
||||
id: id.clone(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
})?;
|
||||
|
||||
Ok(Box::new(result))
|
||||
}
|
||||
|
@ -95,7 +100,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_initialize() {
|
||||
let state = TestState::initialize().unwrap();
|
||||
let state = TestState::initialize(
|
||||
&GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
}
|
||||
).unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
@ -108,14 +119,32 @@ mod tests {
|
|||
|
||||
let original_population = state.population.clone();
|
||||
|
||||
state.simulate().unwrap();
|
||||
state.simulate(
|
||||
&GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
}
|
||||
).unwrap();
|
||||
assert!(original_population
|
||||
.iter()
|
||||
.zip(state.population.iter())
|
||||
.all(|(&a, &b)| b >= a - 1 && b <= a + 2));
|
||||
|
||||
state.simulate().unwrap();
|
||||
state.simulate().unwrap();
|
||||
state.simulate(
|
||||
&GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
}
|
||||
).unwrap();
|
||||
state.simulate(
|
||||
&GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
}
|
||||
).unwrap();
|
||||
assert!(original_population
|
||||
.iter()
|
||||
.zip(state.population.iter())
|
||||
|
@ -128,7 +157,13 @@ mod tests {
|
|||
population: vec![4, 3, 3],
|
||||
};
|
||||
|
||||
state.mutate().unwrap();
|
||||
state.mutate(
|
||||
&GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
}
|
||||
).unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
@ -143,7 +178,7 @@ mod tests {
|
|||
population: vec![0, 1, 3, 7],
|
||||
};
|
||||
|
||||
let merged_state = TestState::merge(&state1, &state2).unwrap();
|
||||
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4()).unwrap();
|
||||
|
||||
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
|
||||
assert!(merged_state.population.iter().any(|&x| x == 7));
|
||||
|
|
|
@ -24,6 +24,12 @@ pub enum GeneticState {
|
|||
Finish,
|
||||
}
|
||||
|
||||
pub struct GeneticNodeContext {
|
||||
pub generation: u64,
|
||||
pub max_generations: u64,
|
||||
pub id: Uuid,
|
||||
}
|
||||
|
||||
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||
///
|
||||
/// [`Bracket`]: crate::bracket::Bracket
|
||||
|
@ -32,17 +38,17 @@ pub trait GeneticNode {
|
|||
///
|
||||
/// # Examples
|
||||
/// TODO
|
||||
fn initialize() -> Result<Box<Self>, Error>;
|
||||
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error>;
|
||||
|
||||
fn simulate(&mut self) -> Result<(), Error>;
|
||||
fn simulate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
|
||||
|
||||
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
|
||||
///
|
||||
/// # Examples
|
||||
/// TODO
|
||||
fn mutate(&mut self) -> Result<(), Error>;
|
||||
fn mutate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
|
||||
|
||||
fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>;
|
||||
fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>;
|
||||
}
|
||||
|
||||
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
||||
|
@ -101,18 +107,28 @@ where
|
|||
self.max_generations
|
||||
}
|
||||
|
||||
pub fn generation(&self) -> u64 {
|
||||
self.generation
|
||||
}
|
||||
|
||||
pub fn state(&self) -> GeneticState {
|
||||
self.state
|
||||
}
|
||||
|
||||
pub fn process_node(&mut self) -> Result<GeneticState, Error> {
|
||||
let context = GeneticNodeContext {
|
||||
generation: self.generation,
|
||||
max_generations: self.max_generations,
|
||||
id: self.id,
|
||||
};
|
||||
|
||||
match (self.state, &mut self.node) {
|
||||
(GeneticState::Initialize, _) => {
|
||||
self.node = Some(*T::initialize()?);
|
||||
self.node = Some(*T::initialize(&context)?);
|
||||
self.state = GeneticState::Simulate;
|
||||
}
|
||||
(GeneticState::Simulate, Some(n)) => {
|
||||
n.simulate()
|
||||
n.simulate(&context)
|
||||
.with_context(|| format!("Error simulating node: {:?}", self))?;
|
||||
|
||||
self.state = if self.generation >= self.max_generations {
|
||||
|
@ -122,7 +138,7 @@ where
|
|||
};
|
||||
}
|
||||
(GeneticState::Mutate, Some(n)) => {
|
||||
n.mutate()
|
||||
n.mutate(&context)
|
||||
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
||||
|
||||
self.generation += 1;
|
||||
|
@ -148,20 +164,20 @@ mod tests {
|
|||
}
|
||||
|
||||
impl GeneticNode for TestState {
|
||||
fn simulate(&mut self) -> Result<(), Error> {
|
||||
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
self.score += 1.0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self) -> Result<(), Error> {
|
||||
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn initialize() -> Result<Box<TestState>, Error> {
|
||||
fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
fn merge(_l: &TestState, _r: &TestState) -> Result<Box<TestState>, Error> {
|
||||
fn merge(_l: &TestState, _r: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
|
||||
Err(Error::Other(anyhow!("Unable to merge")))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -195,7 +195,7 @@ where
|
|||
{
|
||||
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
|
||||
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
|
||||
let merged_node = GeneticNode::merge(left_node, right_node)?;
|
||||
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id())?;
|
||||
tree.val = GeneticNodeWrapper::from(
|
||||
*merged_node,
|
||||
tree.val.max_generations(),
|
||||
|
@ -337,6 +337,8 @@ mod tests {
|
|||
use std::path::PathBuf;
|
||||
use std::fs;
|
||||
|
||||
use self::genetic_node::GeneticNodeContext;
|
||||
|
||||
struct CleanUp {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
@ -367,20 +369,20 @@ mod tests {
|
|||
}
|
||||
|
||||
impl genetic_node::GeneticNode for TestState {
|
||||
fn simulate(&mut self) -> Result<(), Error> {
|
||||
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
self.score += 1.0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self) -> Result<(), Error> {
|
||||
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn initialize() -> Result<Box<TestState>, Error> {
|
||||
fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
||||
fn merge(left: &TestState, right: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(if left.score > right.score {
|
||||
left.clone()
|
||||
} else {
|
||||
|
|
Loading…
Add table
Reference in a new issue