dootcamp #1

Merged
tepichord merged 26 commits from dootcamp into master 2025-09-05 09:37:40 -07:00
8 changed files with 1439 additions and 626 deletions
Showing only changes of commit 7a1f82ac63 - Show all commits

View file

@ -6,12 +6,11 @@ pub mod constants;
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use constants::data_format::DataFormat; use constants::data_format::DataFormat;
use error::Error; use error::Error;
use futures::executor::block_on;
use log::info; use log::info;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use std::{ use std::{
borrow::Borrow, fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle} fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle}
}; };
@ -56,6 +55,7 @@ where
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path::PathBuf; /// # use std::path::PathBuf;
/// # use tokio;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -64,19 +64,22 @@ where
/// # pub c: f64 /// # pub c: f64
/// # } /// # }
/// # /// #
/// # fn main() { /// # #[tokio::main]
/// # async fn main() {
/// let test = Test { /// let test = Test {
/// a: 1, /// a: 1,
/// b: String::from("two"), /// b: String::from("two"),
/// c: 3.0 /// c: 3.0
/// }; /// };
/// ///
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json) /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// let readonly = linked_test.readonly();
/// assert_eq!(linked_test.readonly().b, String::from("two")); /// let readonly_ref = readonly.read().await;
/// assert_eq!(linked_test.readonly().c, 3.0); /// assert_eq!(readonly_ref.a, 1);
/// assert_eq!(readonly_ref.b, String::from("two"));
/// assert_eq!(readonly_ref.c, 3.0);
/// # /// #
/// # drop(linked_test); /// # drop(linked_test);
/// # /// #
@ -97,6 +100,7 @@ where
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path::PathBuf; /// # use std::path::PathBuf;
/// # use tokio;
/// # /// #
/// #[derive(Deserialize, Serialize)] /// #[derive(Deserialize, Serialize)]
/// struct Test { /// struct Test {
@ -105,19 +109,22 @@ where
/// pub c: f64 /// pub c: f64
/// } /// }
/// ///
/// # fn main() { /// #[tokio::main]
/// # async fn main() {
/// let test = Test { /// let test = Test {
/// a: 1, /// a: 1,
/// b: String::from("two"), /// b: String::from("two"),
/// c: 3.0 /// c: 3.0
/// }; /// };
/// ///
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json) /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// let readonly = linked_test.readonly();
/// assert_eq!(linked_test.readonly().b, String::from("two")); /// let readonly_ref = readonly.read().await;
/// assert_eq!(linked_test.readonly().c, 3.0); /// assert_eq!(readonly_ref.a, 1);
/// assert_eq!(readonly_ref.b, String::from("two"));
/// assert_eq!(readonly_ref.c, 3.0);
/// # /// #
/// # drop(linked_test); /// # drop(linked_test);
/// # /// #
@ -207,6 +214,7 @@ where
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path::PathBuf; /// # use std::path::PathBuf;
/// # use tokio;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -215,21 +223,28 @@ where
/// # pub c: f64 /// # pub c: f64
/// # } /// # }
/// # /// #
/// # fn main() -> Result<(), Error> { /// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test { /// let test = Test {
/// a: 1, /// a: 1,
/// b: String::from(""), /// b: String::from(""),
/// c: 0.0 /// c: 0.0
/// }; /// };
/// ///
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode) /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// {
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// }
/// ///
/// linked_test.mutate(|t| t.a = 2)?; /// linked_test.mutate(|t| t.a = 2).await?;
/// ///
/// assert_eq!(linked_test.readonly().a, 2); /// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 2);
/// # /// #
/// # drop(linked_test); /// # drop(linked_test);
/// # /// #
@ -262,6 +277,7 @@ where
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path::PathBuf; /// # use std::path::PathBuf;
/// # use tokio;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -270,25 +286,30 @@ where
/// # pub c: f64 /// # pub c: f64
/// # } /// # }
/// # /// #
/// # fn main() -> Result<(), Error> { /// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test { /// let test = Test {
/// a: 1, /// a: 1,
/// b: String::from(""), /// b: String::from(""),
/// c: 0.0 /// c: 0.0
/// }; /// };
/// ///
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode) /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// ///
/// linked_test.replace(Test { /// linked_test.replace(Test {
/// a: 2, /// a: 2,
/// b: String::from(""), /// b: String::from(""),
/// c: 0.0 /// c: 0.0
/// })?; /// }).await?;
/// ///
/// assert_eq!(linked_test.readonly().a, 2); /// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 2);
/// # /// #
/// # drop(linked_test); /// # drop(linked_test);
/// # /// #
@ -343,6 +364,7 @@ where
/// # use std::fs::OpenOptions; /// # use std::fs::OpenOptions;
/// # use std::io::Write; /// # use std::io::Write;
/// # use std::path::PathBuf; /// # use std::path::PathBuf;
/// # use tokio;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -351,7 +373,8 @@ where
/// # pub c: f64 /// # pub c: f64
/// # } /// # }
/// # /// #
/// # fn main() -> Result<(), Error> { /// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test { /// let test = Test {
/// a: 1, /// a: 1,
/// b: String::from("2"), /// b: String::from("2"),
@ -371,9 +394,11 @@ where
/// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode) /// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode)
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, test.a); /// let readonly = linked_test.readonly();
/// assert_eq!(linked_test.readonly().b, test.b); /// let readonly_ref = readonly.read().await;
/// assert_eq!(linked_test.readonly().c, test.c); /// assert_eq!(readonly_ref.a, test.a);
/// assert_eq!(readonly_ref.b, test.b);
/// assert_eq!(readonly_ref.c, test.c);
/// # /// #
/// # drop(linked_test); /// # drop(linked_test);
/// # /// #

View file

@ -3,18 +3,18 @@ extern crate gemla;
#[macro_use] #[macro_use]
extern crate log; extern crate log;
mod test_state;
mod fighter_nn; mod fighter_nn;
mod test_state;
use anyhow::Result;
use clap::Parser;
use fighter_nn::FighterNN;
use file_linked::constants::data_format::DataFormat; use file_linked::constants::data_format::DataFormat;
use gemla::{ use gemla::{
core::{Gemla, GemlaConfig}, core::{Gemla, GemlaConfig},
error::log_error, error::log_error,
}; };
use std::{path::PathBuf, time::Instant}; use std::{path::PathBuf, time::Instant};
use fighter_nn::FighterNN;
use clap::Parser;
use anyhow::Result;
// const NUM_THREADS: usize = 2; // const NUM_THREADS: usize = 2;
@ -44,14 +44,17 @@ fn main() -> Result<()> {
.build()? .build()?
.block_on(async { .block_on(async {
let args = Args::parse(); // Assuming Args::parse() doesn't need to be async let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
let mut gemla = log_error(Gemla::<FighterNN>::new( let mut gemla = log_error(
&PathBuf::from(args.file), Gemla::<FighterNN>::new(
GemlaConfig { &PathBuf::from(args.file),
generations_per_height: 5, GemlaConfig {
overwrite: false, generations_per_height: 5,
}, overwrite: false,
DataFormat::Json, },
).await)?; DataFormat::Json,
)
.await,
)?;
// let gemla_arc = Arc::new(gemla); // let gemla_arc = Arc::new(gemla);
@ -59,7 +62,8 @@ fn main() -> Result<()> {
// If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks // If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
// Example placeholder loop to continuously run simulate // Example placeholder loop to continuously run simulate
loop { // Arbitrary loop count for demonstration loop {
// Arbitrary loop count for demonstration
gemla.simulate(1).await?; gemla.simulate(1).await?;
} }
}); });

View file

@ -5,7 +5,6 @@ use tokio::sync::Semaphore;
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50; const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct FighterContext { pub struct FighterContext {
pub shared_semaphore: Arc<Semaphore>, pub shared_semaphore: Arc<Semaphore>,
@ -19,7 +18,6 @@ impl Default for FighterContext {
} }
} }
// Custom serialization to just output the concurrency limit. // Custom serialization to just output the concurrency limit.
impl Serialize for FighterContext { impl Serialize for FighterContext {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>

View file

@ -1,20 +1,29 @@
extern crate fann; extern crate fann;
pub mod neural_network_utility;
pub mod fighter_context; pub mod fighter_context;
pub mod neural_network_utility;
use std::{cmp::max, collections::{HashSet, VecDeque}, fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, panic::{catch_unwind, AssertUnwindSafe}, path::{Path, PathBuf}, sync::{Arc, Mutex}, time::Duration}; use anyhow::Context;
use async_trait::async_trait;
use fann::{ActivationFunc, Fann}; use fann::{ActivationFunc, Fann};
use futures::{executor::block_on, future::{join, join_all, select_all}, stream::FuturesUnordered, FutureExt, StreamExt}; use futures::future::join_all;
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}; use gemla::{
core::genetic_node::{GeneticNode, GeneticNodeContext},
error::Error,
};
use lerp::Lerp; use lerp::Lerp;
use rand::prelude::*; use rand::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use anyhow::Context;
use tokio::{process::Command, sync::{mpsc, Semaphore}, task, time::{sleep, timeout, Sleep}};
use uuid::Uuid;
use std::collections::HashMap; use std::collections::HashMap;
use async_trait::async_trait; use std::{
cmp::max,
fs::{self, File},
io::{self, BufRead, BufReader},
ops::Range,
path::{Path, PathBuf},
};
use tokio::process::Command;
use uuid::Uuid;
use self::neural_network_utility::{crossbreed, major_mutation}; use self::neural_network_utility::{crossbreed, major_mutation};
@ -34,7 +43,8 @@ const NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX: usize = 20;
const SIMULATION_ROUNDS: usize = 5; const SIMULATION_ROUNDS: usize = 5;
const SURVIVAL_RATE: f32 = 0.5; const SURVIVAL_RATE: f32 = 0.5;
const GAME_EXECUTABLE_PATH: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe"; const GAME_EXECUTABLE_PATH: &str =
"F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";
// Here is the folder structure for the FighterNN: // Here is the folder structure for the FighterNN:
// base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net // base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net
@ -78,11 +88,17 @@ impl GeneticNode for FighterNN {
//Create a new directory for the first generation, using create_dir_all to avoid errors if it already exists //Create a new directory for the first generation, using create_dir_all to avoid errors if it already exists
let gen_folder = folder.join("0"); let gen_folder = folder.join("0");
fs::create_dir_all(&gen_folder) fs::create_dir_all(&gen_folder).with_context(|| {
.with_context(|| format!("Failed to create or access the generation folder: {:?}", gen_folder))?; format!(
"Failed to create or access the generation folder: {:?}",
gen_folder
)
})?;
let mut nn_shapes = HashMap::new(); let mut nn_shapes = HashMap::new();
let weight_initialization_range = thread_rng().gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0)..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX); let weight_initialization_range = thread_rng()
.gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0)
..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX);
// Create the first generation in this folder // Create the first generation in this folder
for i in 0..POPULATION { for i in 0..POPULATION {
@ -90,17 +106,22 @@ impl GeneticNode for FighterNN {
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i)); let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
// Randomly generate a neural network shape based on constants // Randomly generate a neural network shape based on constants
let hidden_layers = thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX); let hidden_layers = thread_rng()
.gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX);
let mut nn_shape = vec![NEURAL_NETWORK_INPUTS as u32]; let mut nn_shape = vec![NEURAL_NETWORK_INPUTS as u32];
for _ in 0..hidden_layers { for _ in 0..hidden_layers {
nn_shape.push(thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX) as u32); nn_shape.push(thread_rng().gen_range(
NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX,
) as u32);
} }
nn_shape.push(NEURAL_NETWORK_OUTPUTS as u32); nn_shape.push(NEURAL_NETWORK_OUTPUTS as u32);
nn_shapes.insert(i as u64, nn_shape.clone()); nn_shapes.insert(i as u64, nn_shape.clone());
let mut fann = Fann::new(nn_shape.as_slice()) let mut fann = Fann::new(nn_shape.as_slice()).with_context(|| "Failed to create nn")?;
.with_context(|| "Failed to create nn")?; fann.randomize_weights(
fann.randomize_weights(weight_initialization_range.start, weight_initialization_range.end); weight_initialization_range.start,
weight_initialization_range.end,
);
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
// This will overwrite any existing file with the same name // This will overwrite any existing file with the same name
@ -108,7 +129,9 @@ impl GeneticNode for FighterNN {
.with_context(|| format!("Failed to save nn at {:?}", nn))?; .with_context(|| format!("Failed to save nn at {:?}", nn))?;
} }
let mut crossbreed_segments = thread_rng().gen_range(NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX); let mut crossbreed_segments = thread_rng().gen_range(
NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX,
);
if crossbreed_segments % 2 == 0 { if crossbreed_segments % 2 == 0 {
crossbreed_segments += 1; crossbreed_segments += 1;
} }
@ -141,7 +164,10 @@ impl GeneticNode for FighterNN {
let semaphore_clone = context.gemla_context.shared_semaphore.clone(); let semaphore_clone = context.gemla_context.shared_semaphore.clone();
let task = async move { let task = async move {
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64)); let nn = self_clone
.folder
.join(format!("{}", self_clone.generation))
.join(self_clone.get_individual_id(i as u64));
let mut simulations = Vec::new(); let mut simulations = Vec::new();
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
@ -151,11 +177,16 @@ impl GeneticNode for FighterNN {
let generation = self_clone.generation; let generation = self_clone.generation;
let semaphore_clone = semaphore_clone.clone(); let semaphore_clone = semaphore_clone.clone();
let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64)); let random_nn = folder
.join(format!("{}", generation))
.join(self_clone.get_individual_id(random_nn_index as u64));
let nn_clone = nn.clone(); // Clone the path to use in the async block let nn_clone = nn.clone(); // Clone the path to use in the async block
let future = async move { let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?; let permit = semaphore_clone
.acquire_owned()
.await
.with_context(|| "Failed to acquire semaphore permit")?;
let (score, _) = run_1v1_simulation(&nn_clone, &random_nn).await?; let (score, _) = run_1v1_simulation(&nn_clone, &random_nn).await?;
@ -168,7 +199,8 @@ impl GeneticNode for FighterNN {
} }
// Wait for all simulation rounds to complete // Wait for all simulation rounds to complete
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect(); let results: Result<Vec<f32>, Error> =
join_all(simulations).await.into_iter().collect();
let score = match results { let score = match results {
Ok(scores) => scores.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32, Ok(scores) => scores.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32,
@ -188,34 +220,46 @@ impl GeneticNode for FighterNN {
Ok((index, score)) => { Ok((index, score)) => {
// Update the original `self` object with the score. // Update the original `self` object with the score.
self.scores[self.generation as usize].insert(index as u64, score); self.scores[self.generation as usize].insert(index as u64, score);
}, }
Err(e) => { Err(e) => {
// Handle task panic or execution error // Handle task panic or execution error
return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e)))); return Err(Error::Other(anyhow::anyhow!(format!(
}, "Task failed: {:?}",
e
))));
}
} }
} }
Ok(()) Ok(())
} }
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> { async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize; let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
// Create the new generation folder // Create the new generation folder
let new_gen_folder = self.folder.join(format!("{}", self.generation + 1)); let new_gen_folder = self.folder.join(format!("{}", self.generation + 1));
fs::create_dir_all(&new_gen_folder).with_context(|| format!("Failed to create or access new generation folder: {:?}", new_gen_folder))?; fs::create_dir_all(&new_gen_folder).with_context(|| {
format!(
"Failed to create or access new generation folder: {:?}",
new_gen_folder
)
})?;
// Remove the 5 nn's with the lowest scores // Remove the 5 nn's with the lowest scores
let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect(); let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
let to_keep = sorted_scores[survivor_count..].iter().map(|(k, _)| *k).collect::<Vec<_>>(); let to_keep = sorted_scores[survivor_count..]
.iter()
.map(|(k, _)| *k)
.collect::<Vec<_>>();
// Save the remaining 5 nn's to the new generation folder // Save the remaining 5 nn's to the new generation folder
for i in 0..survivor_count { for (i, nn_id) in to_keep.iter().enumerate().take(survivor_count) {
let nn_id = to_keep[i]; let nn = self
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); .folder
.join(format!("{}", self.generation))
.join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i)); let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i));
fs::copy(&nn, &new_nn)?; fs::copy(&nn, &new_nn)?;
} }
@ -223,16 +267,25 @@ impl GeneticNode for FighterNN {
// Take the remaining 5 nn's and create 5 new nn's by the following: // Take the remaining 5 nn's and create 5 new nn's by the following:
for i in 0..survivor_count { for i in 0..survivor_count {
let nn_id = to_keep[i]; let nn_id = to_keep[i];
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); let nn = self
let fann = Fann::from_file(&nn) .folder
.with_context(|| format!("Failed to load nn"))?; .join(format!("{}", self.generation))
.join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
let fann = Fann::from_file(&nn).with_context(|| "Failed to load nn")?;
// Load another nn from the current generation and cross breed it with the current nn // Load another nn from the current generation and cross breed it with the current nn
let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)])); let cross_nn = self
let cross_fann = Fann::from_file(&cross_nn) .folder
.with_context(|| format!("Failed to load cross nn"))?; .join(format!("{}", self.generation))
.join(format!(
"{:06}_fighter_nn_{}.net",
self.id,
to_keep[thread_rng().gen_range(0..survivor_count)]
));
let cross_fann =
Fann::from_file(&cross_nn).with_context(|| "Failed to load cross nn")?;
let mut new_fann = crossbreed(&self, &fann, &cross_fann, self.crossbreed_segments)?; let mut new_fann = crossbreed(self, &fann, &cross_fann, self.crossbreed_segments)?;
// For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight) // For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight)
// And a 5% chance of a major mutation a new neuron is randomly added to a hidden layer // And a 5% chance of a major mutation a new neuron is randomly added to a hidden layer
@ -252,9 +305,14 @@ impl GeneticNode for FighterNN {
} }
// Save the new nn's to the new generation folder // Save the new nn's to the new generation folder
let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i + survivor_count)); let new_nn = new_gen_folder.join(format!(
new_fann.save(&new_nn) "{:06}_fighter_nn_{}.net",
.with_context(|| format!("Failed to save nn"))?; self.id,
i + survivor_count
));
new_fann
.save(&new_nn)
.with_context(|| "Failed to save nn")?;
} }
self.generation += 1; self.generation += 1;
@ -263,18 +321,28 @@ impl GeneticNode for FighterNN {
Ok(()) Ok(())
} }
async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, gemla_context: Self::Context) -> Result<Box<FighterNN>, Error> { async fn merge(
left: &FighterNN,
right: &FighterNN,
id: &Uuid,
gemla_context: Self::Context,
) -> Result<Box<FighterNN>, Error> {
let base_path = PathBuf::from(BASE_DIR); let base_path = PathBuf::from(BASE_DIR);
let folder = base_path.join(format!("fighter_nn_{:06}", id)); let folder = base_path.join(format!("fighter_nn_{:06}", id));
// Ensure the folder exists, including the generation subfolder. // Ensure the folder exists, including the generation subfolder.
fs::create_dir_all(&folder.join("0")) fs::create_dir_all(folder.join("0"))
.with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?; .with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?;
let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> { let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> {
let mut sorted_scores: Vec<_> = fighter.scores[fighter.generation as usize].iter().collect(); let mut sorted_scores: Vec<_> =
fighter.scores[fighter.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
sorted_scores.iter().take(fighter.population_size / 2).map(|(k, v)| (**k, **v)).collect() sorted_scores
.iter()
.take(fighter.population_size / 2)
.map(|(k, v)| (**k, **v))
.collect()
}; };
let left_scores = get_highest_scores(left); let left_scores = get_highest_scores(left);
@ -285,18 +353,28 @@ impl GeneticNode for FighterNN {
let mut simulations = Vec::new(); let mut simulations = Vec::new();
for _ in 0..max(left.population_size, right.population_size)*SIMULATION_ROUNDS { for _ in 0..max(left.population_size, right.population_size) * SIMULATION_ROUNDS {
let left_nn_id = left_scores[thread_rng().gen_range(0..left_scores.len())].0; let left_nn_id = left_scores[thread_rng().gen_range(0..left_scores.len())].0;
let right_nn_id = right_scores[thread_rng().gen_range(0..right_scores.len())].0; let right_nn_id = right_scores[thread_rng().gen_range(0..right_scores.len())].0;
let left_nn_path = left.folder.join(left.generation.to_string()).join(left.get_individual_id(left_nn_id)); let left_nn_path = left
let right_nn_path = right.folder.join(right.generation.to_string()).join(right.get_individual_id(right_nn_id)); .folder
.join(left.generation.to_string())
.join(left.get_individual_id(left_nn_id));
let right_nn_path = right
.folder
.join(right.generation.to_string())
.join(right.get_individual_id(right_nn_id));
let semaphore_clone = gemla_context.shared_semaphore.clone(); let semaphore_clone = gemla_context.shared_semaphore.clone();
let future = async move { let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?; let permit = semaphore_clone
.acquire_owned()
.await
.with_context(|| "Failed to acquire semaphore permit")?;
let (left_score, right_score) = run_1v1_simulation(&left_nn_path, &right_nn_path).await?; let (left_score, right_score) =
run_1v1_simulation(&left_nn_path, &right_nn_path).await?;
drop(permit); drop(permit);
@ -306,7 +384,8 @@ impl GeneticNode for FighterNN {
simulations.push(future); simulations.push(future);
} }
let results: Result<Vec<(f32, f32)>, Error> = join_all(simulations).await.into_iter().collect(); let results: Result<Vec<(f32, f32)>, Error> =
join_all(simulations).await.into_iter().collect();
let scores = results?; let scores = results?;
let total_left_score = scores.iter().map(|(l, _)| l).sum::<f32>(); let total_left_score = scores.iter().map(|(l, _)| l).sum::<f32>();
@ -322,18 +401,36 @@ impl GeneticNode for FighterNN {
let mut nn_shapes = HashMap::new(); let mut nn_shapes = HashMap::new();
// Function to copy NNs from a source FighterNN to the new folder. // Function to copy NNs from a source FighterNN to the new folder.
let mut copy_nns = |source: &FighterNN, folder: &PathBuf, id: &Uuid, start_idx: usize| -> Result<(), Error> { let mut copy_nns = |source: &FighterNN,
let mut sorted_scores: Vec<_> = source.scores[source.generation as usize].iter().collect(); folder: &PathBuf,
id: &Uuid,
start_idx: usize|
-> Result<(), Error> {
let mut sorted_scores: Vec<_> =
source.scores[source.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
let remaining = sorted_scores[(source.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>(); let remaining = sorted_scores[(source.population_size / 2)..]
.iter()
.map(|(k, _)| *k)
.collect::<Vec<_>>();
for (i, nn_id) in remaining.into_iter().enumerate() { for (i, nn_id) in remaining.into_iter().enumerate() {
let nn_path = source.folder.join(source.generation.to_string()).join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id)); let nn_path = source
let new_nn_path = folder.join("0").join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i)); .folder
fs::copy(&nn_path, &new_nn_path) .join(source.generation.to_string())
.with_context(|| format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path))?; .join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id));
let new_nn_path =
folder
.join("0")
.join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i));
fs::copy(&nn_path, &new_nn_path).with_context(|| {
format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path)
})?;
nn_shapes.insert((start_idx + i) as u64, source.nn_shapes.get(&nn_id).unwrap().clone()); nn_shapes.insert(
(start_idx + i) as u64,
source.nn_shapes.get(nn_id).unwrap().clone(),
);
} }
Ok(()) Ok(())
@ -341,32 +438,54 @@ impl GeneticNode for FighterNN {
// Copy the top half of NNs from each parent to the new folder. // Copy the top half of NNs from each parent to the new folder.
copy_nns(left, &folder, id, 0)?; copy_nns(left, &folder, id, 0)?;
copy_nns(right, &folder, id, left.population_size as usize / 2)?; copy_nns(right, &folder, id, left.population_size / 2)?;
debug!("nn_shapes: {:?}", nn_shapes); debug!("nn_shapes: {:?}", nn_shapes);
// Lerp the mutation rates and weight ranges // Lerp the mutation rates and weight ranges
let crossbreed_segments = (left.crossbreed_segments as f32).lerp(right.crossbreed_segments as f32, lerp_amount) as usize; let crossbreed_segments = (left.crossbreed_segments as f32)
.lerp(right.crossbreed_segments as f32, lerp_amount)
as usize;
let weight_initialization_range_start = left.weight_initialization_range.start.lerp(right.weight_initialization_range.start, lerp_amount); let weight_initialization_range_start = left
let weight_initialization_range_end = left.weight_initialization_range.end.lerp(right.weight_initialization_range.end, lerp_amount); .weight_initialization_range
.start
.lerp(right.weight_initialization_range.start, lerp_amount);
let weight_initialization_range_end = left
.weight_initialization_range
.end
.lerp(right.weight_initialization_range.end, lerp_amount);
// Have to ensure the range is valid // Have to ensure the range is valid
let weight_initialization_range = if weight_initialization_range_start < weight_initialization_range_end { let weight_initialization_range =
weight_initialization_range_start..weight_initialization_range_end if weight_initialization_range_start < weight_initialization_range_end {
} else { weight_initialization_range_start..weight_initialization_range_end
weight_initialization_range_end..weight_initialization_range_start } else {
}; weight_initialization_range_end..weight_initialization_range_start
};
debug!("weight_initialization_range: {:?}", weight_initialization_range); debug!(
"weight_initialization_range: {:?}",
weight_initialization_range
);
let minor_mutation_rate = left.minor_mutation_rate.lerp(right.minor_mutation_rate, lerp_amount); let minor_mutation_rate = left
let major_mutation_rate = left.major_mutation_rate.lerp(right.major_mutation_rate, lerp_amount); .minor_mutation_rate
.lerp(right.minor_mutation_rate, lerp_amount);
let major_mutation_rate = left
.major_mutation_rate
.lerp(right.major_mutation_rate, lerp_amount);
debug!("minor_mutation_rate: {}", minor_mutation_rate); debug!("minor_mutation_rate: {}", minor_mutation_rate);
debug!("major_mutation_rate: {}", major_mutation_rate); debug!("major_mutation_rate: {}", major_mutation_rate);
let mutation_weight_range_start = left.mutation_weight_range.start.lerp(right.mutation_weight_range.start, lerp_amount); let mutation_weight_range_start = left
let mutation_weight_range_end = left.mutation_weight_range.end.lerp(right.mutation_weight_range.end, lerp_amount); .mutation_weight_range
.start
.lerp(right.mutation_weight_range.start, lerp_amount);
let mutation_weight_range_end = left
.mutation_weight_range
.end
.lerp(right.mutation_weight_range.end, lerp_amount);
// Have to ensure the range is valid // Have to ensure the range is valid
let mutation_weight_range = if mutation_weight_range_start < mutation_weight_range_end { let mutation_weight_range = if mutation_weight_range_start < mutation_weight_range_end {
mutation_weight_range_start..mutation_weight_range_end mutation_weight_range_start..mutation_weight_range_end
@ -398,7 +517,7 @@ impl FighterNN {
} }
} }
async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<(f32, f32), Error> { async fn run_1v1_simulation(nn_path_1: &Path, nn_path_2: &Path) -> Result<(f32, f32), Error> {
// Construct the score file path // Construct the score file path
let base_folder = nn_path_1.parent().unwrap(); let base_folder = nn_path_1.parent().unwrap();
let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap(); let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap();
@ -407,14 +526,18 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
// Check if score file already exists before running the simulation // Check if score file already exists before running the simulation
if score_file.exists() { if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await let round_score = read_score_from_file(&score_file, nn_1_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?; .with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
let opposing_score = read_score_from_file(&score_file, &nn_2_id).await let opposing_score = read_score_from_file(&score_file, nn_2_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?; .with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score); debug!(
"{} scored {}, while {} scored {}",
nn_1_id, round_score, nn_2_id, opposing_score
);
return Ok((round_score, opposing_score)); return Ok((round_score, opposing_score));
} }
@ -422,13 +545,22 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
// Check if the opposite round score has been determined // Check if the opposite round score has been determined
let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id)); let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id));
if opposite_score_file.exists() { if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await let round_score = read_score_from_file(&opposite_score_file, nn_1_id)
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; .await
.with_context(|| {
format!("Failed to read score from file: {:?}", opposite_score_file)
})?;
let opposing_score = read_score_from_file(&opposite_score_file, &nn_2_id).await let opposing_score = read_score_from_file(&opposite_score_file, nn_2_id)
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; .await
.with_context(|| {
format!("Failed to read score from file: {:?}", opposite_score_file)
})?;
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score); debug!(
"{} scored {}, while {} scored {}",
nn_1_id, round_score, nn_2_id, opposing_score
);
return Ok((round_score, opposing_score)); return Ok((round_score, opposing_score));
} }
@ -459,20 +591,29 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
.expect("Failed to execute game") .expect("Failed to execute game")
}; };
trace!("Simulation completed for {} vs {}: {}", nn_1_id, nn_2_id, score_file.exists()); trace!(
"Simulation completed for {} vs {}: {}",
nn_1_id,
nn_2_id,
score_file.exists()
);
// Read the score from the file // Read the score from the file
if score_file.exists() { if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await let round_score = read_score_from_file(&score_file, nn_1_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?; .with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
let opposing_score = read_score_from_file(&score_file, &nn_2_id).await let opposing_score = read_score_from_file(&score_file, nn_2_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?; .with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score); debug!(
"{} scored {}, while {} scored {}",
nn_1_id, round_score, nn_2_id, opposing_score
);
Ok((round_score, opposing_score))
return Ok((round_score, opposing_score))
} else { } else {
warn!("Score file not found: {:?}", score_file); warn!("Score file not found: {:?}", score_file);
Ok((0.0, 0.0)) Ok((0.0, 0.0))
@ -492,7 +633,10 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
if line.starts_with(nn_id) { if line.starts_with(nn_id) {
let parts: Vec<&str> = line.split(':').collect(); let parts: Vec<&str> = line.split(':').collect();
if parts.len() == 2 { if parts.len() == 2 {
return parts[1].trim().parse::<f32>().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)); return parts[1]
.trim()
.parse::<f32>()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
} }
} }
} }
@ -501,16 +645,21 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
io::ErrorKind::NotFound, io::ErrorKind::NotFound,
"NN ID not found in scores file", "NN ID not found in scores file",
)); ));
}, }
Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::PermissionDenied || e.kind() == io::ErrorKind::Other => { Err(e)
if attempts >= 5 { // Attempt 5 times before giving up. if e.kind() == io::ErrorKind::WouldBlock
|| e.kind() == io::ErrorKind::PermissionDenied
|| e.kind() == io::ErrorKind::Other =>
{
if attempts >= 5 {
// Attempt 5 times before giving up.
return Err(e); return Err(e);
} }
attempts += 1; attempts += 1;
// wait 1 second to ensure the file is written // wait 1 second to ensure the file is written
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
}, }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,11 @@
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}; use async_trait::async_trait;
use gemla::{
core::genetic_node::{GeneticNode, GeneticNodeContext},
error::Error,
};
use rand::prelude::*; use rand::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use async_trait::async_trait;
const POPULATION_SIZE: u64 = 5; const POPULATION_SIZE: u64 = 5;
const POPULATION_REDUCTION_SIZE: u64 = 3; const POPULATION_REDUCTION_SIZE: u64 = 3;
@ -76,7 +79,12 @@ impl GeneticNode for TestState {
Ok(()) Ok(())
} }
async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> { async fn merge(
left: &TestState,
right: &TestState,
id: &Uuid,
gemla_context: Self::Context,
) -> Result<Box<TestState>, Error> {
let mut v = left.population.clone(); let mut v = left.population.clone();
v.append(&mut right.population.clone()); v.append(&mut right.population.clone());
@ -87,12 +95,14 @@ impl GeneticNode for TestState {
let mut result = TestState { population: v }; let mut result = TestState { population: v };
result.mutate(GeneticNodeContext { result
id: id.clone(), .mutate(GeneticNodeContext {
generation: 0, id: *id,
max_generations: 0, generation: 0,
gemla_context: gemla_context max_generations: 0,
}).await?; gemla_context,
})
.await?;
Ok(Box::new(result)) Ok(Box::new(result))
} }
@ -105,14 +115,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_initialize() { async fn test_initialize() {
let state = TestState::initialize( let state = TestState::initialize(GeneticNodeContext {
GeneticNodeContext { id: Uuid::new_v4(),
id: Uuid::new_v4(), generation: 0,
generation: 0, max_generations: 0,
max_generations: 0, gemla_context: (),
gemla_context: (), })
} .await
).await.unwrap(); .unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize); assert_eq!(state.population.len(), POPULATION_SIZE as usize);
} }
@ -125,35 +135,38 @@ mod tests {
let original_population = state.population.clone(); let original_population = state.population.clone();
state.simulate( state
GeneticNodeContext { .simulate(GeneticNodeContext {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
gemla_context: (), gemla_context: (),
} })
).await.unwrap(); .await
.unwrap();
assert!(original_population assert!(original_population
.iter() .iter()
.zip(state.population.iter()) .zip(state.population.iter())
.all(|(&a, &b)| b >= a - 1 && b <= a + 2)); .all(|(&a, &b)| b >= a - 1 && b <= a + 2));
state.simulate( state
GeneticNodeContext { .simulate(GeneticNodeContext {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
gemla_context: (), gemla_context: (),
} })
).await.unwrap(); .await
state.simulate( .unwrap();
GeneticNodeContext { state
.simulate(GeneticNodeContext {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
gemla_context: (), gemla_context: (),
} })
).await.unwrap(); .await
.unwrap();
assert!(original_population assert!(original_population
.iter() .iter()
.zip(state.population.iter()) .zip(state.population.iter())
@ -166,14 +179,15 @@ mod tests {
population: vec![4, 3, 3], population: vec![4, 3, 3],
}; };
state.mutate( state
GeneticNodeContext { .mutate(GeneticNodeContext {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
gemla_context: (), gemla_context: (),
} })
).await.unwrap(); .await
.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize); assert_eq!(state.population.len(), POPULATION_SIZE as usize);
} }
@ -188,7 +202,9 @@ mod tests {
population: vec![0, 1, 3, 7], population: vec![0, 1, 3, 7],
}; };
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap(); let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
.await
.unwrap();
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
assert!(merged_state.population.iter().any(|&x| x == 7)); assert!(merged_state.population.iter().any(|&x| x == 7));

View file

@ -5,10 +5,10 @@
use crate::error::Error; use crate::error::Error;
use anyhow::Context; use anyhow::Context;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug; use std::fmt::Debug;
use uuid::Uuid; use uuid::Uuid;
use async_trait::async_trait;
/// An enum used to control the state of a [`GeneticNode`] /// An enum used to control the state of a [`GeneticNode`]
/// ///
@ -30,14 +30,14 @@ pub struct GeneticNodeContext<S> {
pub generation: u64, pub generation: u64,
pub max_generations: u64, pub max_generations: u64,
pub id: Uuid, pub id: Uuid,
pub gemla_context: S pub gemla_context: S,
} }
/// A trait used to interact with the internal state of nodes within the [`Bracket`] /// A trait used to interact with the internal state of nodes within the [`Bracket`]
/// ///
/// [`Bracket`]: crate::bracket::Bracket /// [`Bracket`]: crate::bracket::Bracket
#[async_trait] #[async_trait]
pub trait GeneticNode : Send { pub trait GeneticNode: Send {
type Context; type Context;
/// Initializes a new instance of a [`GeneticState`]. /// Initializes a new instance of a [`GeneticState`].
@ -54,7 +54,12 @@ pub trait GeneticNode : Send {
/// TODO /// TODO
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>; async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>; async fn merge(
left: &Self,
right: &Self,
id: &Uuid,
context: Self::Context,
) -> Result<Box<Self>, Error>;
} }
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
@ -62,7 +67,7 @@ pub trait GeneticNode : Send {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GeneticNodeWrapper<T> pub struct GeneticNodeWrapper<T>
where where
T: Clone T: Clone,
{ {
node: Option<T>, node: Option<T>,
state: GeneticState, state: GeneticState,
@ -73,7 +78,7 @@ where
impl<T> Default for GeneticNodeWrapper<T> impl<T> Default for GeneticNodeWrapper<T>
where where
T: Clone T: Clone,
{ {
fn default() -> Self { fn default() -> Self {
GeneticNodeWrapper { GeneticNodeWrapper {
@ -146,7 +151,8 @@ where
self.state = GeneticState::Simulate; self.state = GeneticState::Simulate;
} }
(GeneticState::Simulate, Some(n)) => { (GeneticState::Simulate, Some(n)) => {
n.simulate(context.clone()).await n.simulate(context.clone())
.await
.with_context(|| format!("Error simulating node: {:?}", self))?; .with_context(|| format!("Error simulating node: {:?}", self))?;
self.state = if self.generation >= self.max_generations { self.state = if self.generation >= self.max_generations {
@ -156,7 +162,8 @@ where
}; };
} }
(GeneticState::Mutate, Some(n)) => { (GeneticState::Mutate, Some(n)) => {
n.mutate(context.clone()).await n.mutate(context.clone())
.await
.with_context(|| format!("Error mutating node: {:?}", self))?; .with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1; self.generation += 1;
@ -186,20 +193,33 @@ mod tests {
impl GeneticNode for TestState { impl GeneticNode for TestState {
type Context = (); type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> { async fn simulate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
self.score += 1.0; self.score += 1.0;
Ok(()) Ok(())
} }
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> { async fn mutate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
Ok(()) Ok(())
} }
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> { async fn initialize(
_context: GeneticNodeContext<Self::Context>,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 })) Ok(Box::new(TestState { score: 0.0 }))
} }
async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> { async fn merge(
_l: &TestState,
_r: &TestState,
_id: &Uuid,
_: Self::Context,
) -> Result<Box<TestState>, Error> {
Err(Error::Other(anyhow!("Unable to merge"))) Err(Error::Other(anyhow!("Unable to merge")))
} }
} }

View file

@ -10,10 +10,11 @@ use futures::future;
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState}; use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn}; use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{sync::RwLock, task::JoinHandle};
use std::{ use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
sync::Arc, time::Instant,
}; };
use tokio::{sync::RwLock, task::JoinHandle};
use uuid::Uuid; use uuid::Uuid;
type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>; type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
@ -80,13 +81,18 @@ where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone, T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{ {
pub async fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> { pub async fn new(
path: &Path,
config: GemlaConfig,
data_format: DataFormat,
) -> Result<Self, Error> {
match File::open(path) { match File::open(path) {
// If the file exists we either want to overwrite the file or read from the file // If the file exists we either want to overwrite the file or read from the file
// based on the configuration provided // based on the configuration provided
Ok(_) => Ok(Gemla { Ok(_) => Ok(Gemla {
data: if config.overwrite { data: if config.overwrite {
FileLinked::new((None, config, T::Context::default()), path, data_format).await? FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?
} else { } else {
FileLinked::from_file(path, data_format)? FileLinked::from_file(path, data_format)?
}, },
@ -94,7 +100,8 @@ where
}), }),
// If the file doesn't exist we must create it // If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config, T::Context::default()), path, data_format).await?, data: FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?,
threads: HashMap::new(), threads: HashMap::new(),
}), }),
Err(error) => Err(Error::IO(error)), Err(error) => Err(Error::IO(error)),
@ -106,24 +113,32 @@ where
} }
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> { pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
{ let tree_completed = {
// Only increase height if the tree is uninitialized or completed // Only increase height if the tree is uninitialized or completed
let data_arc = self.data.readonly(); let data_arc = self.data.readonly();
let data_ref = data_arc.read().await; let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref(); let tree_ref = data_ref.0.as_ref();
if tree_ref.is_none() || tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
tree_ref };
.map(|t| Gemla::is_completed(t))
.unwrap_or(true) if tree_completed {
{ // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed // in the tree and which nodes have not.
// in the tree and which nodes have not. self.data
self.data.mutate(|(d, c, _)| { .mutate(|(d, c, _)| {
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps); let mut tree: Option<SimulationTree<T>> =
Gemla::increase_height(d.take(), c, steps);
mem::swap(d, &mut tree); mem::swap(d, &mut tree);
}).await?; })
} .await?;
}
{
// Only increase height if the tree is uninitialized or completed
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
info!( info!(
"Height of simulation tree increased to {}", "Height of simulation tree increased to {}",
@ -141,36 +156,36 @@ where
let data_ref = data_arc.read().await; let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref(); let tree_ref = data_ref.0.as_ref();
is_tree_processed = tree_ref is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
.map(|t| Gemla::is_completed(t))
.unwrap_or(false)
} }
// We need to keep simulating until the tree has been completely processed. // We need to keep simulating until the tree has been completely processed.
if is_tree_processed if is_tree_processed {
{
self.join_threads().await?; self.join_threads().await?;
info!("Processed tree"); info!("Processed tree");
break; break;
} }
if let Some(node) = tree_ref let (node, gemla_context) = {
.and_then(|t| self.get_unprocessed_node(t)) let data_arc = self.data.readonly();
{ let data_ref = data_arc.read().await;
let (tree_ref, _, gemla_context) = &*data_ref; // (Option<Box<Tree<GeneticNodeWrapper<T>>>, GemlaConfig, T::Context)
let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
(node, gemla_context.clone())
};
if let Some(node) = node {
trace!("Adding node to process list {}", node.id()); trace!("Adding node to process list {}", node.id());
let data_arc = self.data.readonly(); let gemla_context = gemla_context.clone();
let data_ref2 = data_arc.read().await;
let gemla_context = data_ref2.2.clone();
drop(data_ref2);
self.threads self.threads.insert(
.insert(node.id(), tokio::spawn(async move { node.id(),
Gemla::process_node(node, gemla_context).await tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
})); );
} else { } else {
trace!("No node found to process, joining threads"); trace!("No node found to process, joining threads");
@ -195,32 +210,34 @@ where
// We need to retrieve the processed nodes from the resulting list and replace them in the original list // We need to retrieve the processed nodes from the resulting list and replace them in the original list
match reduced_results { match reduced_results {
Ok(r) => { Ok(r) => {
self.data.mutate_async(|d| async move { self.data
// Scope to limit the duration of the read lock .mutate_async(|d| async move {
let (_, context) = { // Scope to limit the duration of the read lock
let data_read = d.read().await; let (_, context) = {
(data_read.1.clone(), data_read.2.clone()) let data_read = d.read().await;
}; // Read lock is dropped here (data_read.1, data_read.2.clone())
}; // Read lock is dropped here
let mut data_write = d.write().await; let mut data_write = d.write().await;
if let Some(t) = data_write.0.as_mut() { if let Some(t) = data_write.0.as_mut() {
let failed_nodes = Gemla::replace_nodes(t, r); let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree // We receive a list of nodes that were unable to be found in the original tree
if !failed_nodes.is_empty() { if !failed_nodes.is_empty() {
warn!( warn!(
"Unable to find {:?} to replace in tree", "Unable to find {:?} to replace in tree",
failed_nodes.iter().map(|n| n.id()) failed_nodes.iter().map(|n| n.id())
) )
}
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t, context.clone()).await
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
} }
})
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes .await??;
Gemla::merge_completed_nodes(t, context.clone()).await
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
}
}).await??;
} }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
@ -230,7 +247,10 @@ where
} }
#[async_recursion] #[async_recursion]
async fn merge_completed_nodes<'a>(tree: &'a mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> { async fn merge_completed_nodes<'a>(
tree: &'a mut SimulationTree<T>,
gemla_context: T::Context,
) -> Result<(), Error> {
if tree.val.state() == GeneticState::Initialize { if tree.val.state() == GeneticState::Initialize {
match (&mut tree.left, &mut tree.right) { match (&mut tree.left, &mut tree.right) {
// If the current node has been initialized, and has children nodes that are completed, then we need // If the current node has been initialized, and has children nodes that are completed, then we need
@ -241,7 +261,13 @@ where
{ {
info!("Merging nodes {} and {}", l.val.id(), r.val.id()); info!("Merging nodes {} and {}", l.val.id(), r.val.id());
if let (Some(left_node), Some(right_node)) = (l.val.take(), r.val.take()) { if let (Some(left_node), Some(right_node)) = (l.val.take(), r.val.take()) {
let merged_node = GeneticNode::merge(&left_node, &right_node, &tree.val.id(), gemla_context.clone()).await?; let merged_node = GeneticNode::merge(
&left_node,
&right_node,
&tree.val.id(),
gemla_context.clone(),
)
.await?;
tree.val = GeneticNodeWrapper::from( tree.val = GeneticNodeWrapper::from(
*merged_node, *merged_node,
tree.val.max_generations(), tree.val.max_generations(),
@ -294,7 +320,10 @@ where
// during join_threads. // during join_threads.
(Some(l), Some(r)) (Some(l), Some(r))
if l.val.state() == GeneticState::Finish if l.val.state() == GeneticState::Finish
&& r.val.state() == GeneticState::Finish => Some(tree.val.clone()), && r.val.state() == GeneticState::Finish =>
{
Some(tree.val.clone())
}
(Some(l), Some(r)) => self (Some(l), Some(r)) => self
.get_unprocessed_node(l) .get_unprocessed_node(l)
.or_else(|| self.get_unprocessed_node(r)), .or_else(|| self.get_unprocessed_node(r)),
@ -355,7 +384,10 @@ where
tree.val.state() == GeneticState::Finish tree.val.state() == GeneticState::Finish
} }
async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> { async fn process_node(
mut node: GeneticNodeWrapper<T>,
gemla_context: T::Context,
) -> Result<GeneticNodeWrapper<T>, Error> {
let node_state_time = Instant::now(); let node_state_time = Instant::now();
let node_state = node.state(); let node_state = node.state();
@ -379,10 +411,10 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::core::*; use crate::core::*;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::fs;
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use self::genetic_node::GeneticNodeContext; use self::genetic_node::GeneticNodeContext;
@ -420,20 +452,33 @@ mod tests {
impl genetic_node::GeneticNode for TestState { impl genetic_node::GeneticNode for TestState {
type Context = (); type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> { async fn simulate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
self.score += 1.0; self.score += 1.0;
Ok(()) Ok(())
} }
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> { async fn mutate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
Ok(()) Ok(())
} }
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> { async fn initialize(
_context: GeneticNodeContext<Self::Context>,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 })) Ok(Box::new(TestState { score: 0.0 }))
} }
async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> { async fn merge(
left: &TestState,
right: &TestState,
_id: &Uuid,
_: Self::Context,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.score > right.score { Ok(Box::new(if left.score > right.score {
left.clone() left.clone()
} else { } else {
@ -498,40 +543,43 @@ mod tests {
Ok(()) Ok(())
}) })
}) })
}).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result. })
.await
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(()) Ok(())
} }
// #[tokio::test] #[tokio::test]
// async fn test_simulate() -> Result<(), Error> { async fn test_simulate() -> Result<(), Error> {
// let path = PathBuf::from("test_simulate"); let path = PathBuf::from("test_simulate");
// // Use `spawn_blocking` to run the synchronous closure that internally awaits async code. // Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
// tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
// let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block. let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
// CleanUp::new(&path).run(move |p| { CleanUp::new(&path).run(move |p| {
// rt.block_on(async { rt.block_on(async {
// // Testing initial creation // Testing initial creation
// let config = GemlaConfig { let config = GemlaConfig {
// generations_per_height: 10, generations_per_height: 10,
// overwrite: true, overwrite: true,
// }; };
// let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?; let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
// // Now we can use `.await` within the spawned blocking task. // Now we can use `.await` within the spawned blocking task.
// gemla.simulate(5).await?; gemla.simulate(5).await?;
// let data = gemla.data.readonly(); let data = gemla.data.readonly();
// let data_lock = data.read().unwrap(); let data_lock = data.read().await;
// let tree = data_lock.0.as_ref().unwrap(); let tree = data_lock.0.as_ref().unwrap();
// assert_eq!(tree.height(), 5); assert_eq!(tree.height(), 5);
// assert_eq!(tree.val.as_ref().unwrap().score, 50.0); assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
// Ok(()) Ok(())
// }) })
// }) })
// }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result. })
.await
// Ok(()) .unwrap()?; // Wait for the blocking task to complete, then handle the Result.
// }
Ok(())
}
} }