diff --git a/file_linked/src/lib.rs b/file_linked/src/lib.rs
index 4c027fd..54dcf30 100644
--- a/file_linked/src/lib.rs
+++ b/file_linked/src/lib.rs
@@ -6,12 +6,11 @@ pub mod constants;
 use anyhow::{anyhow, Context};
 use constants::data_format::DataFormat;
 use error::Error;
-use futures::executor::block_on;
 use log::info;
 use serde::{de::DeserializeOwned, Serialize};
 use tokio::sync::RwLock;
 use std::{
-    borrow::Borrow, fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle}
+    fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle}
 };
@@ -56,6 +55,7 @@ where
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
+    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -64,19 +64,22 @@ where
     /// #     pub c: f64
     /// # }
     /// #
-    /// # fn main() {
+    /// # #[tokio::main]
+    /// # async fn main() {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from("two"),
     ///     c: 3.0
     /// };
     ///
-    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json)
+    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
     ///     .expect("Unable to create file linked object");
     ///
-    /// assert_eq!(linked_test.readonly().a, 1);
-    /// assert_eq!(linked_test.readonly().b, String::from("two"));
-    /// assert_eq!(linked_test.readonly().c, 3.0);
+    /// let readonly = linked_test.readonly();
+    /// let readonly_ref = readonly.read().await;
+    /// assert_eq!(readonly_ref.a, 1);
+    /// assert_eq!(readonly_ref.b, String::from("two"));
+    /// assert_eq!(readonly_ref.c, 3.0);
     /// #
     /// # drop(linked_test);
     /// #
@@ -97,6 +100,7 @@ where
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
+    /// # use tokio;
     /// #
     /// #[derive(Deserialize, Serialize)]
     /// struct Test {
@@ -105,19 +109,22 @@ where
     ///     pub c: f64
     /// }
     ///
-    /// # fn main() {
+    /// #[tokio::main]
+    /// # async fn main() {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from("two"),
     ///     c: 3.0
     /// };
     ///
-    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json)
+    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
     ///     .expect("Unable to create file linked object");
     ///
-    /// assert_eq!(linked_test.readonly().a, 1);
-    /// assert_eq!(linked_test.readonly().b, String::from("two"));
-    /// assert_eq!(linked_test.readonly().c, 3.0);
+    /// let readonly = linked_test.readonly();
+    /// let readonly_ref = readonly.read().await;
+    /// assert_eq!(readonly_ref.a, 1);
+    /// assert_eq!(readonly_ref.b, String::from("two"));
+    /// assert_eq!(readonly_ref.c, 3.0);
     /// #
     /// # drop(linked_test);
     /// #
@@ -207,6 +214,7 @@ where
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
+    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -215,21 +223,28 @@ where
     /// #     pub c: f64
     /// # }
     /// #
-    /// # fn main() -> Result<(), Error> {
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Error> {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from(""),
     ///     c: 0.0
     /// };
     ///
-    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode)
+    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
     ///     .expect("Unable to create file linked object");
     ///
-    /// assert_eq!(linked_test.readonly().a, 1);
+    /// {
+    ///     let readonly = linked_test.readonly();
+    ///     let readonly_ref = readonly.read().await;
+    ///     assert_eq!(readonly_ref.a, 1);
+    /// }
     ///
-    /// linked_test.mutate(|t| t.a = 2)?;
+    /// linked_test.mutate(|t| t.a = 2).await?;
     ///
-    /// assert_eq!(linked_test.readonly().a, 2);
+    /// let readonly = linked_test.readonly();
+    /// let readonly_ref = readonly.read().await;
+    /// assert_eq!(readonly_ref.a, 2);
     /// #
     /// # drop(linked_test);
     /// #
@@ -262,6 +277,7 @@ where
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
+    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -270,25 +286,30 @@ where
     /// #     pub c: f64
     /// # }
     /// #
-    /// # fn main() -> Result<(), Error> {
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Error> {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from(""),
     ///     c: 0.0
     /// };
     ///
-    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode)
+    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
     ///     .expect("Unable to create file linked object");
     ///
-    /// assert_eq!(linked_test.readonly().a, 1);
+    /// let readonly = linked_test.readonly();
+    /// let readonly_ref = readonly.read().await;
+    /// assert_eq!(readonly_ref.a, 1);
     ///
     /// linked_test.replace(Test {
     ///     a: 2,
     ///     b: String::from(""),
     ///     c: 0.0
-    /// })?;
+    /// }).await?;
     ///
-    /// assert_eq!(linked_test.readonly().a, 2);
+    /// let readonly = linked_test.readonly();
+    /// let readonly_ref = readonly.read().await;
+    /// assert_eq!(readonly_ref.a, 2);
     /// #
     /// # drop(linked_test);
     /// #
@@ -343,6 +364,7 @@ where
     /// # use std::fs::OpenOptions;
     /// # use std::io::Write;
     /// # use std::path::PathBuf;
+    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -351,7 +373,8 @@ where
     /// #     pub c: f64
     /// # }
     /// #
-    /// # fn main() -> Result<(), Error> {
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Error> {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from("2"),
@@ -371,9 +394,11 @@ where
     /// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode)
     ///     .expect("Unable to create file linked object");
     ///
-    /// assert_eq!(linked_test.readonly().a, test.a);
-    /// assert_eq!(linked_test.readonly().b, test.b);
-    /// assert_eq!(linked_test.readonly().c, test.c);
+    /// let readonly = linked_test.readonly();
+    /// let readonly_ref = readonly.read().await;
+    /// assert_eq!(readonly_ref.a, test.a);
+    /// assert_eq!(readonly_ref.b, test.b);
+    /// assert_eq!(readonly_ref.c, test.c);
     /// #
     /// # drop(linked_test);
     /// #
diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs
index a6f8b38..e41e417 100644
--- a/gemla/src/bin/bin.rs
+++ b/gemla/src/bin/bin.rs
@@ -3,18 +3,18 @@ extern crate gemla;
 #[macro_use]
 extern crate log;
 
-mod test_state;
 mod fighter_nn;
+mod test_state;
 
+use anyhow::Result;
+use clap::Parser;
+use fighter_nn::FighterNN;
 use file_linked::constants::data_format::DataFormat;
 use gemla::{
     core::{Gemla, GemlaConfig},
     error::log_error,
 };
 use std::{path::PathBuf, time::Instant};
-use fighter_nn::FighterNN;
-use clap::Parser;
-use anyhow::Result;
 
 // const NUM_THREADS: usize = 2;
@@ -39,19 +39,22 @@ fn main() -> Result<()> {
     // Manually configure the Tokio runtime
     let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread()
-       .worker_threads(num_cpus::get())
-       // .worker_threads(NUM_THREADS)
+        .worker_threads(num_cpus::get())
+        // .worker_threads(NUM_THREADS)
         .build()?
         .block_on(async {
             let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
 
-            let mut gemla = log_error(Gemla::<FighterNN>::new(
-                &PathBuf::from(args.file),
-                GemlaConfig {
-                    generations_per_height: 5,
-                    overwrite: false,
-                },
-                DataFormat::Json,
-            ).await)?;
+            let mut gemla = log_error(
+                Gemla::<FighterNN>::new(
+                    &PathBuf::from(args.file),
+                    GemlaConfig {
+                        generations_per_height: 5,
+                        overwrite: false,
+                    },
+                    DataFormat::Json,
+                )
+                .await,
+            )?;
 
             // let gemla_arc = Arc::new(gemla);
@@ -59,7 +62,8 @@ fn main() -> Result<()> {
             // If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
 
             // Example placeholder loop to continuously run simulate
-            loop { // Arbitrary loop count for demonstration
+            loop {
+                // Arbitrary loop count for demonstration
                 gemla.simulate(1).await?;
             }
         });
@@ -68,4 +72,4 @@ fn main() -> Result<()> {
     info!("Finished in {:?}", now.elapsed());
 
     Ok(())
-}
\ No newline at end of file
+}
diff --git a/gemla/src/bin/fighter_nn/fighter_context.rs b/gemla/src/bin/fighter_nn/fighter_context.rs
index a87706b..c631627 100644
--- a/gemla/src/bin/fighter_nn/fighter_context.rs
+++ b/gemla/src/bin/fighter_nn/fighter_context.rs
@@ -5,7 +5,6 @@ use tokio::sync::Semaphore;
 
 const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50;
 
-
 #[derive(Debug, Clone)]
 pub struct FighterContext {
     pub shared_semaphore: Arc<Semaphore>,
@@ -19,7 +18,6 @@ impl Default for FighterContext {
     }
 }
 
-
 // Custom serialization to just output the concurrency limit.
 impl Serialize for FighterContext {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -45,4 +43,4 @@ impl<'de> Deserialize<'de> for FighterContext {
             shared_semaphore: Arc::new(Semaphore::new(concurrency_limit as usize)),
         })
     }
-}
\ No newline at end of file
+}
diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs
index d4683b9..ee99b8c 100644
--- a/gemla/src/bin/fighter_nn/mod.rs
+++ b/gemla/src/bin/fighter_nn/mod.rs
@@ -1,20 +1,29 @@
 extern crate fann;
 
-pub mod neural_network_utility;
 pub mod fighter_context;
+pub mod neural_network_utility;
 
-use std::{cmp::max, collections::{HashSet, VecDeque}, fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, panic::{catch_unwind, AssertUnwindSafe}, path::{Path, PathBuf}, sync::{Arc, Mutex}, time::Duration};
+use anyhow::Context;
+use async_trait::async_trait;
 use fann::{ActivationFunc, Fann};
-use futures::{executor::block_on, future::{join, join_all, select_all}, stream::FuturesUnordered, FutureExt, StreamExt};
-use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
+use futures::future::join_all;
+use gemla::{
+    core::genetic_node::{GeneticNode, GeneticNodeContext},
+    error::Error,
+};
 use lerp::Lerp;
 use rand::prelude::*;
 use serde::{Deserialize, Serialize};
-use anyhow::Context;
-use tokio::{process::Command, sync::{mpsc, Semaphore}, task, time::{sleep, timeout, Sleep}};
-use uuid::Uuid;
 use std::collections::HashMap;
-use async_trait::async_trait;
+use std::{
+    cmp::max,
+    fs::{self, File},
+    io::{self, BufRead, BufReader},
+    ops::Range,
+    path::{Path, PathBuf},
+};
+use tokio::process::Command;
+use uuid::Uuid;
 
 use self::neural_network_utility::{crossbreed, major_mutation};
@@ -34,13 +43,14 @@ const NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX: usize = 20;
 const SIMULATION_ROUNDS: usize = 5;
 const SURVIVAL_RATE: f32 = 0.5;
 
-const GAME_EXECUTABLE_PATH: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";
+const GAME_EXECUTABLE_PATH: &str =
"F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe"; // Here is the folder structure for the FighterNN: // base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net // A neural network that utilizes the fann library to save and read nn's from files -// FighterNN contains a list of file locations for the nn's stored, all of which are stored under the same folder which is also contained. +// FighterNN contains a list of file locations for the nn's stored, all of which are stored under the same folder which is also contained. // there is no training happening to the neural networks // the neural networks are only used to simulate the nn's and to save and read the nn's from files // Filenames are stored in the format of "{fighter_id}_fighter_nn_{generation}.net". @@ -70,37 +80,48 @@ impl GeneticNode for FighterNN { // Check for the highest number of the folder name and increment it by 1 async fn initialize(context: GeneticNodeContext) -> Result, Error> { let base_path = PathBuf::from(BASE_DIR); - + let folder = base_path.join(format!("fighter_nn_{:06}", context.id)); // Ensures directory is created if it doesn't exist and does nothing if it exists fs::create_dir_all(&folder) .with_context(|| format!("Failed to create or access the folder: {:?}", folder))?; - + //Create a new directory for the first generation, using create_dir_all to avoid errors if it already exists let gen_folder = folder.join("0"); - fs::create_dir_all(&gen_folder) - .with_context(|| format!("Failed to create or access the generation folder: {:?}", gen_folder))?; + fs::create_dir_all(&gen_folder).with_context(|| { + format!( + "Failed to create or access the generation folder: {:?}", + gen_folder + ) + })?; let mut nn_shapes = HashMap::new(); - let weight_initialization_range = thread_rng().gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0)..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX); - + let weight_initialization_range = thread_rng() + .gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0) + ..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX); + // Create the first generation in this folder for i in 0..POPULATION { // Filenames are stored in the format of "xxxxxx_fighter_nn_0.net", "xxxxxx_fighter_nn_1.net", etc. 
Where xxxxxx is the folder name let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i)); // Randomly generate a neural network shape based on constants - let hidden_layers = thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX); + let hidden_layers = thread_rng() + .gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX); let mut nn_shape = vec![NEURAL_NETWORK_INPUTS as u32]; for _ in 0..hidden_layers { - nn_shape.push(thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX) as u32); + nn_shape.push(thread_rng().gen_range( + NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX, + ) as u32); } nn_shape.push(NEURAL_NETWORK_OUTPUTS as u32); nn_shapes.insert(i as u64, nn_shape.clone()); - let mut fann = Fann::new(nn_shape.as_slice()) - .with_context(|| "Failed to create nn")?; - fann.randomize_weights(weight_initialization_range.start, weight_initialization_range.end); + let mut fann = Fann::new(nn_shape.as_slice()).with_context(|| "Failed to create nn")?; + fann.randomize_weights( + weight_initialization_range.start, + weight_initialization_range.end, + ); fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); // This will overwrite any existing file with the same name @@ -108,13 +129,15 @@ impl GeneticNode for FighterNN { .with_context(|| format!("Failed to save nn at {:?}", nn))?; } - let mut crossbreed_segments = thread_rng().gen_range(NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX); + let mut crossbreed_segments = thread_rng().gen_range( + NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX, + ); if crossbreed_segments % 2 == 0 { crossbreed_segments += 1; } let mutation_weight_amplitude = thread_rng().gen_range(0.0..1.0); - + Ok(Box::new(FighterNN { id: context.id, folder, @@ -141,9 +164,12 @@ impl GeneticNode for FighterNN { let semaphore_clone = context.gemla_context.shared_semaphore.clone(); let task = async move { - let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64)); + let nn = self_clone + .folder + .join(format!("{}", self_clone.generation)) + .join(self_clone.get_individual_id(i as u64)); let mut simulations = Vec::new(); - + // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently for _ in 0..SIMULATION_ROUNDS { let random_nn_index = thread_rng().gen_range(0..self_clone.population_size); @@ -151,11 +177,16 @@ impl GeneticNode for FighterNN { let generation = self_clone.generation; let semaphore_clone = semaphore_clone.clone(); - let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64)); + let random_nn = folder + .join(format!("{}", generation)) + .join(self_clone.get_individual_id(random_nn_index as u64)); let nn_clone = nn.clone(); // Clone the path to use in the async block - + let future = async move { - let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?; + let permit = semaphore_clone + .acquire_owned() + .await + .with_context(|| "Failed to acquire semaphore permit")?; let (score, _) = run_1v1_simulation(&nn_clone, &random_nn).await?; @@ -163,13 +194,14 @@ impl GeneticNode for FighterNN { Ok(score) }; - + simulations.push(future); } - + // Wait for all simulation rounds to 
complete - let results: Result, Error> = join_all(simulations).await.into_iter().collect(); - + let results: Result, Error> = + join_all(simulations).await.into_iter().collect(); + let score = match results { Ok(scores) => scores.into_iter().sum::() / SIMULATION_ROUNDS as f32, Err(e) => return Err(e), // Return the error if results collection failed @@ -188,34 +220,46 @@ impl GeneticNode for FighterNN { Ok((index, score)) => { // Update the original `self` object with the score. self.scores[self.generation as usize].insert(index as u64, score); - }, + } Err(e) => { // Handle task panic or execution error - return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e)))); - }, + return Err(Error::Other(anyhow::anyhow!(format!( + "Task failed: {:?}", + e + )))); + } } } - + Ok(()) } - async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize; // Create the new generation folder let new_gen_folder = self.folder.join(format!("{}", self.generation + 1)); - fs::create_dir_all(&new_gen_folder).with_context(|| format!("Failed to create or access new generation folder: {:?}", new_gen_folder))?; + fs::create_dir_all(&new_gen_folder).with_context(|| { + format!( + "Failed to create or access new generation folder: {:?}", + new_gen_folder + ) + })?; // Remove the 5 nn's with the lowest scores let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect(); sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); - let to_keep = sorted_scores[survivor_count..].iter().map(|(k, _)| *k).collect::>(); + let to_keep = sorted_scores[survivor_count..] + .iter() + .map(|(k, _)| *k) + .collect::>(); // Save the remaining 5 nn's to the new generation folder - for i in 0..survivor_count { - let nn_id = to_keep[i]; - let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); + for (i, nn_id) in to_keep.iter().enumerate().take(survivor_count) { + let nn = self + .folder + .join(format!("{}", self.generation)) + .join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i)); fs::copy(&nn, &new_nn)?; } @@ -223,16 +267,25 @@ impl GeneticNode for FighterNN { // Take the remaining 5 nn's and create 5 new nn's by the following: for i in 0..survivor_count { let nn_id = to_keep[i]; - let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); - let fann = Fann::from_file(&nn) - .with_context(|| format!("Failed to load nn"))?; + let nn = self + .folder + .join(format!("{}", self.generation)) + .join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); + let fann = Fann::from_file(&nn).with_context(|| "Failed to load nn")?; // Load another nn from the current generation and cross breed it with the current nn - let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)])); - let cross_fann = Fann::from_file(&cross_nn) - .with_context(|| format!("Failed to load cross nn"))?; + let cross_nn = self + .folder + .join(format!("{}", self.generation)) + .join(format!( + "{:06}_fighter_nn_{}.net", + self.id, + to_keep[thread_rng().gen_range(0..survivor_count)] + )); + let cross_fann = + Fann::from_file(&cross_nn).with_context(|| "Failed to load cross nn")?; - let mut new_fann = crossbreed(&self, &fann, 
&cross_fann, self.crossbreed_segments)?; + let mut new_fann = crossbreed(self, &fann, &cross_fann, self.crossbreed_segments)?; // For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight) // And a 5% chance of a major mutation a new neuron is randomly added to a hidden layer @@ -246,15 +299,20 @@ impl GeneticNode for FighterNN { } new_fann.set_connections(&connections); - + if thread_rng().gen_range(0.0..1.0) < self.major_mutation_rate { new_fann = major_mutation(&new_fann, self.weight_initialization_range.clone())?; } // Save the new nn's to the new generation folder - let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i + survivor_count)); - new_fann.save(&new_nn) - .with_context(|| format!("Failed to save nn"))?; + let new_nn = new_gen_folder.join(format!( + "{:06}_fighter_nn_{}.net", + self.id, + i + survivor_count + )); + new_fann + .save(&new_nn) + .with_context(|| "Failed to save nn")?; } self.generation += 1; @@ -263,18 +321,28 @@ impl GeneticNode for FighterNN { Ok(()) } - async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, gemla_context: Self::Context) -> Result, Error> { + async fn merge( + left: &FighterNN, + right: &FighterNN, + id: &Uuid, + gemla_context: Self::Context, + ) -> Result, Error> { let base_path = PathBuf::from(BASE_DIR); let folder = base_path.join(format!("fighter_nn_{:06}", id)); - + // Ensure the folder exists, including the generation subfolder. - fs::create_dir_all(&folder.join("0")) + fs::create_dir_all(folder.join("0")) .with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?; let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> { - let mut sorted_scores: Vec<_> = fighter.scores[fighter.generation as usize].iter().collect(); + let mut sorted_scores: Vec<_> = + fighter.scores[fighter.generation as usize].iter().collect(); sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); - sorted_scores.iter().take(fighter.population_size / 2).map(|(k, v)| (**k, **v)).collect() + sorted_scores + .iter() + .take(fighter.population_size / 2) + .map(|(k, v)| (**k, **v)) + .collect() }; let left_scores = get_highest_scores(left); @@ -285,18 +353,28 @@ impl GeneticNode for FighterNN { let mut simulations = Vec::new(); - for _ in 0..max(left.population_size, right.population_size)*SIMULATION_ROUNDS { + for _ in 0..max(left.population_size, right.population_size) * SIMULATION_ROUNDS { let left_nn_id = left_scores[thread_rng().gen_range(0..left_scores.len())].0; let right_nn_id = right_scores[thread_rng().gen_range(0..right_scores.len())].0; - let left_nn_path = left.folder.join(left.generation.to_string()).join(left.get_individual_id(left_nn_id)); - let right_nn_path = right.folder.join(right.generation.to_string()).join(right.get_individual_id(right_nn_id)); + let left_nn_path = left + .folder + .join(left.generation.to_string()) + .join(left.get_individual_id(left_nn_id)); + let right_nn_path = right + .folder + .join(right.generation.to_string()) + .join(right.get_individual_id(right_nn_id)); let semaphore_clone = gemla_context.shared_semaphore.clone(); let future = async move { - let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?; + let permit = semaphore_clone + .acquire_owned() + .await + .with_context(|| "Failed to acquire semaphore permit")?; - let (left_score, right_score) = run_1v1_simulation(&left_nn_path, &right_nn_path).await?; + let (left_score, 
right_score) = + run_1v1_simulation(&left_nn_path, &right_nn_path).await?; drop(permit); @@ -306,7 +384,8 @@ impl GeneticNode for FighterNN { simulations.push(future); } - let results: Result, Error> = join_all(simulations).await.into_iter().collect(); + let results: Result, Error> = + join_all(simulations).await.into_iter().collect(); let scores = results?; let total_left_score = scores.iter().map(|(l, _)| l).sum::(); @@ -320,53 +399,93 @@ impl GeneticNode for FighterNN { let lerp_amount = 1.0 / (1.0 + (-score_difference).exp()); let mut nn_shapes = HashMap::new(); - - // Function to copy NNs from a source FighterNN to the new folder. - let mut copy_nns = |source: &FighterNN, folder: &PathBuf, id: &Uuid, start_idx: usize| -> Result<(), Error> { - let mut sorted_scores: Vec<_> = source.scores[source.generation as usize].iter().collect(); - sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); - let remaining = sorted_scores[(source.population_size / 2)..].iter().map(|(k, _)| *k).collect::>(); - - for (i, nn_id) in remaining.into_iter().enumerate() { - let nn_path = source.folder.join(source.generation.to_string()).join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id)); - let new_nn_path = folder.join("0").join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i)); - fs::copy(&nn_path, &new_nn_path) - .with_context(|| format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path))?; - nn_shapes.insert((start_idx + i) as u64, source.nn_shapes.get(&nn_id).unwrap().clone()); + // Function to copy NNs from a source FighterNN to the new folder. + let mut copy_nns = |source: &FighterNN, + folder: &PathBuf, + id: &Uuid, + start_idx: usize| + -> Result<(), Error> { + let mut sorted_scores: Vec<_> = + source.scores[source.generation as usize].iter().collect(); + sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); + let remaining = sorted_scores[(source.population_size / 2)..] + .iter() + .map(|(k, _)| *k) + .collect::>(); + + for (i, nn_id) in remaining.into_iter().enumerate() { + let nn_path = source + .folder + .join(source.generation.to_string()) + .join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id)); + let new_nn_path = + folder + .join("0") + .join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i)); + fs::copy(&nn_path, &new_nn_path).with_context(|| { + format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path) + })?; + + nn_shapes.insert( + (start_idx + i) as u64, + source.nn_shapes.get(nn_id).unwrap().clone(), + ); } Ok(()) }; - + // Copy the top half of NNs from each parent to the new folder. 
copy_nns(left, &folder, id, 0)?; - copy_nns(right, &folder, id, left.population_size as usize / 2)?; + copy_nns(right, &folder, id, left.population_size / 2)?; debug!("nn_shapes: {:?}", nn_shapes); // Lerp the mutation rates and weight ranges - let crossbreed_segments = (left.crossbreed_segments as f32).lerp(right.crossbreed_segments as f32, lerp_amount) as usize; + let crossbreed_segments = (left.crossbreed_segments as f32) + .lerp(right.crossbreed_segments as f32, lerp_amount) + as usize; - let weight_initialization_range_start = left.weight_initialization_range.start.lerp(right.weight_initialization_range.start, lerp_amount); - let weight_initialization_range_end = left.weight_initialization_range.end.lerp(right.weight_initialization_range.end, lerp_amount); + let weight_initialization_range_start = left + .weight_initialization_range + .start + .lerp(right.weight_initialization_range.start, lerp_amount); + let weight_initialization_range_end = left + .weight_initialization_range + .end + .lerp(right.weight_initialization_range.end, lerp_amount); // Have to ensure the range is valid - let weight_initialization_range = if weight_initialization_range_start < weight_initialization_range_end { - weight_initialization_range_start..weight_initialization_range_end - } else { - weight_initialization_range_end..weight_initialization_range_start - }; + let weight_initialization_range = + if weight_initialization_range_start < weight_initialization_range_end { + weight_initialization_range_start..weight_initialization_range_end + } else { + weight_initialization_range_end..weight_initialization_range_start + }; - debug!("weight_initialization_range: {:?}", weight_initialization_range); + debug!( + "weight_initialization_range: {:?}", + weight_initialization_range + ); - let minor_mutation_rate = left.minor_mutation_rate.lerp(right.minor_mutation_rate, lerp_amount); - let major_mutation_rate = left.major_mutation_rate.lerp(right.major_mutation_rate, lerp_amount); + let minor_mutation_rate = left + .minor_mutation_rate + .lerp(right.minor_mutation_rate, lerp_amount); + let major_mutation_rate = left + .major_mutation_rate + .lerp(right.major_mutation_rate, lerp_amount); debug!("minor_mutation_rate: {}", minor_mutation_rate); debug!("major_mutation_rate: {}", major_mutation_rate); - - let mutation_weight_range_start = left.mutation_weight_range.start.lerp(right.mutation_weight_range.start, lerp_amount); - let mutation_weight_range_end = left.mutation_weight_range.end.lerp(right.mutation_weight_range.end, lerp_amount); + + let mutation_weight_range_start = left + .mutation_weight_range + .start + .lerp(right.mutation_weight_range.start, lerp_amount); + let mutation_weight_range_end = left + .mutation_weight_range + .end + .lerp(right.mutation_weight_range.end, lerp_amount); // Have to ensure the range is valid let mutation_weight_range = if mutation_weight_range_start < mutation_weight_range_end { mutation_weight_range_start..mutation_weight_range_end @@ -398,7 +517,7 @@ impl FighterNN { } } -async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<(f32, f32), Error> { +async fn run_1v1_simulation(nn_path_1: &Path, nn_path_2: &Path) -> Result<(f32, f32), Error> { // Construct the score file path let base_folder = nn_path_1.parent().unwrap(); let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap(); @@ -407,14 +526,18 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result< // Check if score file already exists before running the simulation if 
score_file.exists() { - let round_score = read_score_from_file(&score_file, &nn_1_id).await + let round_score = read_score_from_file(&score_file, nn_1_id) + .await .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; - let opposing_score = read_score_from_file(&score_file, &nn_2_id).await + let opposing_score = read_score_from_file(&score_file, nn_2_id) + .await .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; - debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score); - + debug!( + "{} scored {}, while {} scored {}", + nn_1_id, round_score, nn_2_id, opposing_score + ); return Ok((round_score, opposing_score)); } @@ -422,13 +545,22 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result< // Check if the opposite round score has been determined let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id)); if opposite_score_file.exists() { - let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await - .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; + let round_score = read_score_from_file(&opposite_score_file, nn_1_id) + .await + .with_context(|| { + format!("Failed to read score from file: {:?}", opposite_score_file) + })?; - let opposing_score = read_score_from_file(&opposite_score_file, &nn_2_id).await - .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; + let opposing_score = read_score_from_file(&opposite_score_file, nn_2_id) + .await + .with_context(|| { + format!("Failed to read score from file: {:?}", opposite_score_file) + })?; - debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score); + debug!( + "{} scored {}, while {} scored {}", + nn_1_id, round_score, nn_2_id, opposing_score + ); return Ok((round_score, opposing_score)); } @@ -459,20 +591,29 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result< .expect("Failed to execute game") }; - trace!("Simulation completed for {} vs {}: {}", nn_1_id, nn_2_id, score_file.exists()); + trace!( + "Simulation completed for {} vs {}: {}", + nn_1_id, + nn_2_id, + score_file.exists() + ); // Read the score from the file if score_file.exists() { - let round_score = read_score_from_file(&score_file, &nn_1_id).await + let round_score = read_score_from_file(&score_file, nn_1_id) + .await .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; - let opposing_score = read_score_from_file(&score_file, &nn_2_id).await + let opposing_score = read_score_from_file(&score_file, nn_2_id) + .await .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; - debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score); + debug!( + "{} scored {}, while {} scored {}", + nn_1_id, round_score, nn_2_id, opposing_score + ); - - return Ok((round_score, opposing_score)) + Ok((round_score, opposing_score)) } else { warn!("Score file not found: {:?}", score_file); Ok((0.0, 0.0)) @@ -492,7 +633,10 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result = line.split(':').collect(); if parts.len() == 2 { - return parts[1].trim().parse::().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)); + return parts[1] + .trim() + .parse::() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)); } } } @@ -501,17 +645,22 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) 
-> Result { - if attempts >= 5 { // Attempt 5 times before giving up. + } + Err(e) + if e.kind() == io::ErrorKind::WouldBlock + || e.kind() == io::ErrorKind::PermissionDenied + || e.kind() == io::ErrorKind::Other => + { + if attempts >= 5 { + // Attempt 5 times before giving up. return Err(e); } attempts += 1; // wait 1 second to ensure the file is written tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - }, + } Err(e) => return Err(e), } } -} \ No newline at end of file +} diff --git a/gemla/src/bin/fighter_nn/neural_network_utility.rs b/gemla/src/bin/fighter_nn/neural_network_utility.rs index a3e5bc1..04a0992 100644 --- a/gemla/src/bin/fighter_nn/neural_network_utility.rs +++ b/gemla/src/bin/fighter_nn/neural_network_utility.rs @@ -3,16 +3,24 @@ use std::{cmp::min, collections::HashMap, ops::Range}; use anyhow::Context; use fann::{ActivationFunc, Fann}; use gemla::error::Error; -use rand::{distributions::{Distribution, Uniform}, seq::IteratorRandom, thread_rng, Rng}; +use rand::{ + distributions::{Distribution, Uniform}, + seq::IteratorRandom, + thread_rng, Rng, +}; use super::{FighterNN, NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN}; - /// Crossbreeds two neural networks of different shapes by finding cut points, and swapping neurons between the two networks. /// Algorithm tries to ensure similar functionality is maintained between the two networks. /// It does this by preserving connections between the same neurons from the original to the new network, and if a connection cannot be found /// it will create a new connection with a random weight. -pub fn crossbreed(fighter_nn: &FighterNN, primary: &Fann, secondary: &Fann, crossbreed_segments: usize) -> Result { +pub fn crossbreed( + fighter_nn: &FighterNN, + primary: &Fann, + secondary: &Fann, + crossbreed_segments: usize, +) -> Result { // First we need to get the shape of the networks and transform this into a format that is easier to work with // We want a list of every neuron id, and the layer it is in let primary_shape = primary.get_layer_sizes(); @@ -21,25 +29,27 @@ pub fn crossbreed(fighter_nn: &FighterNN, primary: &Fann, secondary: &Fann, cros let secondary_neurons = generate_neuron_datastructure(&secondary_shape); let segments = generate_segments(primary_shape, secondary_shape, crossbreed_segments); - + let new_neurons = crossbreed_neuron_arrays(segments, primary_neurons, secondary_neurons); - + // Now we need to create the new network with the shape we've determined let mut new_shape = vec![]; for (_, _, layer, _) in new_neurons.iter() { // Check if new_shape has an entry for layer in it - if new_shape.len() <= *layer as usize { + if new_shape.len() <= *layer { new_shape.push(1); - } - else { - new_shape[*layer as usize] += 1; + } else { + new_shape[*layer] += 1; } } - let mut new_fann = Fann::new(new_shape.as_slice()) - .with_context(|| "Failed to create new fann")?; + let mut new_fann = + Fann::new(new_shape.as_slice()).with_context(|| "Failed to create new fann")?; // We need to randomize the weights to a small value - new_fann.randomize_weights(fighter_nn.weight_initialization_range.start, fighter_nn.weight_initialization_range.end); + new_fann.randomize_weights( + fighter_nn.weight_initialization_range.start, + fighter_nn.weight_initialization_range.end, + ); new_fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); new_fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); @@ -48,11 +58,18 @@ pub fn crossbreed(fighter_nn: &FighterNN, primary: &Fann, secondary: &Fann, cros Ok(new_fann) } 
-pub fn generate_segments(primary_shape: Vec, secondary_shape: Vec, crossbreed_segments: usize) -> Vec<(u32, u32)> { +pub fn generate_segments( + primary_shape: Vec, + secondary_shape: Vec, + crossbreed_segments: usize, +) -> Vec<(u32, u32)> { // Now we need to find the cut points for the crossbreed let start = primary_shape[0] + 1; // Start at the first hidden layer - let end = min(primary_shape.iter().sum::() - primary_shape.last().unwrap(), secondary_shape.iter().sum::() - secondary_shape.last().unwrap()); + let end = min( + primary_shape.iter().sum::() - primary_shape.last().unwrap(), + secondary_shape.iter().sum::() - secondary_shape.last().unwrap(), + ); // End at the last hidden layer let segment_distribution = Uniform::from(start..end); // Ensure segments are not too small @@ -77,23 +94,35 @@ pub fn generate_segments(primary_shape: Vec, secondary_shape: Vec, cro segments } -pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: Vec, new_neurons: Vec<(u32, bool, usize, u32)>, new_fann: &mut Fann) { +pub fn consolidate_old_connections( + primary: &Fann, + secondary: &Fann, + new_shape: Vec, + new_neurons: Vec<(u32, bool, usize, u32)>, + new_fann: &mut Fann, +) { // Now we need to copy the connections from the original networks to the new network - // We can do this by referencing our connections array, it will contain the original id's of the neurons + // We can do this by referencing our connections array, it will contain the original id's of the neurons // and their new id as well as their layer. We can iterate one layer at a time and copy the connections let primary_shape = primary.get_layer_sizes(); let secondary_shape = secondary.get_layer_sizes(); debug!("Primary shape: {:?}", primary_shape); debug!("Secondary shape: {:?}", secondary_shape); - + // Start by iterating layer by later let primary_connections = primary.get_connections(); let secondary_connections = secondary.get_connections(); for layer in 1..new_shape.len() { // filter out the connections that are in the current layer and previous layer - let current_layer_connections = new_neurons.iter().filter(|(_, _, l, _)| l == &layer).collect::>(); - let previous_layer_connections = new_neurons.iter().filter(|(_, _, l, _)| l == &(layer - 1)).collect::>(); + let current_layer_connections = new_neurons + .iter() + .filter(|(_, _, l, _)| l == &layer) + .collect::>(); + let previous_layer_connections = new_neurons + .iter() + .filter(|(_, _, l, _)| l == &(layer - 1)) + .collect::>(); // Now we need to iterate over the connections in the current layer for (neuron_id, is_primary, _, new_id) in current_layer_connections.iter() { @@ -104,10 +133,26 @@ pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: let mut connection; let mut found_in_primary = false; if *is_primary { - connection = primary_connections.iter() - .find(|connection| { - let from_neuron = to_non_bias_network_id(connection.from_neuron, &primary_shape); - let to_neuron = to_non_bias_network_id(connection.to_neuron, &primary_shape); + connection = primary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); + + // If both neurons have a Some value + if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { + from_neuron == *previous_neuron_id && to_neuron == *neuron_id + } else { + false + } + }); + + if connection.is_none() { + connection = 
secondary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); // If both neurons have a Some value if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { @@ -116,29 +161,30 @@ pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: false } }); - - if let None = connection { - connection = secondary_connections.iter() - .find(|connection| { - let from_neuron = to_non_bias_network_id(connection.from_neuron, &secondary_shape); - let to_neuron = to_non_bias_network_id(connection.to_neuron, &secondary_shape); - - // If both neurons have a Some value - if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { - from_neuron == *previous_neuron_id && to_neuron == *neuron_id - } else { - false - } - }); } else { found_in_primary = true; } - } - else { - connection = secondary_connections.iter() - .find(|connection| { - let from_neuron = to_non_bias_network_id(connection.from_neuron, &secondary_shape); - let to_neuron = to_non_bias_network_id(connection.to_neuron, &secondary_shape); + } else { + connection = secondary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + + // If both neurons have a Some value + if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { + from_neuron == *previous_neuron_id && to_neuron == *neuron_id + } else { + false + } + }); + + if connection.is_none() { + connection = primary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); // If both neurons have a Some value if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { @@ -147,20 +193,6 @@ pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: false } }); - - if let None = connection { - connection = primary_connections.iter() - .find(|connection| { - let from_neuron = to_non_bias_network_id(connection.from_neuron, &primary_shape); - let to_neuron = to_non_bias_network_id(connection.to_neuron, &primary_shape); - - // If both neurons have a Some value - if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { - from_neuron == *previous_neuron_id && to_neuron == *neuron_id - } else { - false - } - }); } else { found_in_primary = true; } @@ -169,19 +201,29 @@ pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: // If the connection exists, we need to add it to the new network if let Some(connection) = connection { if *is_primary { - let original_from_neuron = to_non_bias_network_id(connection.from_neuron, &primary_shape); - let original_to_neuron = to_non_bias_network_id(connection.to_neuron, &primary_shape); + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let original_to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); trace!("Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", previous_new_id, new_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, previous_neuron_id, 
neuron_id); } else { - let original_from_neuron = to_non_bias_network_id(connection.from_neuron, &secondary_shape); - let original_to_neuron = to_non_bias_network_id(connection.to_neuron, &secondary_shape); + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let original_to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); trace!("Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", previous_new_id, new_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, previous_neuron_id, neuron_id); } let translated_from = to_bias_network_id(previous_new_id, &new_shape); let translated_to = to_bias_network_id(new_id, &new_shape); new_fann.set_weight(translated_from, translated_to, connection.weight); } else { - trace!("Connection not found for ({}, {}) -> ({}, {})", previous_new_id, new_id, previous_neuron_id, neuron_id); + trace!( + "Connection not found for ({}, {}) -> ({}, {})", + previous_new_id, + new_id, + previous_neuron_id, + neuron_id + ); } } } @@ -197,34 +239,35 @@ pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: let mut found_in_primary = false; if *is_primary { let primary_bias_neuron = get_bias_neuron_for_layer(layer, &primary_shape); - if let Some(primary_bias_neuron) = primary_bias_neuron - { - connection = primary_connections.iter() - .find(|connection| { - let to_neuron = to_non_bias_network_id(connection.to_neuron, &primary_shape); + if let Some(primary_bias_neuron) = primary_bias_neuron { + connection = primary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); if let Some(to_neuron) = to_neuron { - connection.from_neuron == primary_bias_neuron && to_neuron == *neuron_id + connection.from_neuron == primary_bias_neuron + && to_neuron == *neuron_id } else { false } }); } - - if let None = connection { - let secondary_bias_neuron = get_bias_neuron_for_layer(layer, &secondary_shape); + if connection.is_none() { + let secondary_bias_neuron = + get_bias_neuron_for_layer(layer, &secondary_shape); if let Some(secondary_bias_neuron) = secondary_bias_neuron { - connection = secondary_connections.iter() - .find(|connection| { - let to_neuron = to_non_bias_network_id(connection.to_neuron, &secondary_shape); + connection = secondary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); - if let Some(to_neuron) = to_neuron { - connection.from_neuron == secondary_bias_neuron && to_neuron == *neuron_id - } else { - false - } - }); + if let Some(to_neuron) = to_neuron { + connection.from_neuron == secondary_bias_neuron + && to_neuron == *neuron_id + } else { + false + } + }); } } else { found_in_primary = true; @@ -232,31 +275,33 @@ pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: } else { let secondary_bias_neuron = get_bias_neuron_for_layer(layer, &secondary_shape); if let Some(secondary_bias_neuron) = secondary_bias_neuron { - connection = secondary_connections.iter() - .find(|connection| { - let to_neuron = to_non_bias_network_id(connection.to_neuron, &secondary_shape); + connection = secondary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + + if let Some(to_neuron) = to_neuron { + connection.from_neuron 
== secondary_bias_neuron + && to_neuron == *neuron_id + } else { + false + } + }); + } + + if connection.is_none() { + let primary_bias_neuron = get_bias_neuron_for_layer(layer, &primary_shape); + if let Some(primary_bias_neuron) = primary_bias_neuron { + connection = primary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); if let Some(to_neuron) = to_neuron { - connection.from_neuron == secondary_bias_neuron && to_neuron == *neuron_id + connection.from_neuron == primary_bias_neuron + && to_neuron == *neuron_id } else { false } }); - } - - if let None = connection { - let primary_bias_neuron = get_bias_neuron_for_layer(layer, &primary_shape); - if let Some(primary_bias_neuron) = primary_bias_neuron { - connection = primary_connections.iter() - .find(|connection| { - let to_neuron = to_non_bias_network_id(connection.to_neuron, &primary_shape); - - if let Some(to_neuron) = to_neuron { - connection.from_neuron == primary_bias_neuron && to_neuron == *neuron_id - } else { - false - } - }); } } else { found_in_primary = true; @@ -265,24 +310,39 @@ pub fn consolidate_old_connections(primary: &Fann, secondary: &Fann, new_shape: if let Some(connection) = connection { if *is_primary { - let original_from_neuron = to_non_bias_network_id(connection.from_neuron, &primary_shape); - let original_to_neuron = to_non_bias_network_id(connection.to_neuron, &primary_shape); + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let original_to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); trace!("Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", bias_neuron, translated_neuron_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, bias_neuron, neuron_id); } else { - let original_from_neuron = to_non_bias_network_id(connection.from_neuron, &secondary_shape); - let original_to_neuron = to_non_bias_network_id(connection.to_neuron, &secondary_shape); + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let original_to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); trace!("Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", bias_neuron, translated_neuron_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, bias_neuron, neuron_id); } new_fann.set_weight(bias_neuron, translated_neuron_id, connection.weight); } else { - trace!("Connection not found for bias ({}, {}) -> ({}, {}) primary: {}", bias_neuron, neuron_id, bias_neuron, translated_neuron_id, is_primary); + trace!( + "Connection not found for bias ({}, {}) -> ({}, {}) primary: {}", + bias_neuron, + neuron_id, + bias_neuron, + translated_neuron_id, + is_primary + ); } } } } } -pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec<(u32, usize)>, secondary_neurons: Vec<(u32, usize)>) -> Vec<(u32, bool, usize, u32)> { +pub fn crossbreed_neuron_arrays( + segments: Vec<(u32, u32)>, + primary_neurons: Vec<(u32, usize)>, + secondary_neurons: Vec<(u32, usize)>, +) -> Vec<(u32, bool, usize, u32)> { // We now need to determine the resulting location of the neurons in the new network. 
// To do this we need a new structure that keeps track of the following information: // - The neuron id from the original network @@ -306,32 +366,32 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< new_neurons.push((*neuron_id, is_primary, current_layer, 0)); if is_primary { primary_last_layer = current_layer; - } - else { + } else { secondary_last_layer = current_layer; } - } - else { + } else { break; } } - } - else { - let target_neurons = if is_primary { &primary_neurons } else { &secondary_neurons }; + } else { + let target_neurons = if is_primary { + &primary_neurons + } else { + &secondary_neurons + }; for (neuron_id, layer) in target_neurons.iter() { // Iterate until neuron_id equals the cut_point if neuron_id >= &segment.0 && neuron_id <= &segment.1 { // We need to do something different depending on whether the neuron layer is, lower, higher or equal to the target layer - + // Equal if layer == ¤t_layer { new_neurons.push((*neuron_id, is_primary, current_layer, 0)); if is_primary { primary_last_layer = current_layer; - } - else { + } else { secondary_last_layer = current_layer; } } @@ -340,19 +400,28 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< // If it's in an earlier layer, add it to the earlier layer // Check if there's a lower id from the same individual in that earlier layer // As long as there isn't a neuron from the other individual in between the lower id and current id, add the id values from the same individual - let earlier_layer_neurons = new_neurons.iter().filter(|(_, _, l, _)| l == layer).collect::>(); + let earlier_layer_neurons = new_neurons + .iter() + .filter(|(_, _, l, _)| l == layer) + .collect::>(); // get max id from that layer - let highest_id = earlier_layer_neurons.iter().max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); + let highest_id = earlier_layer_neurons + .iter() + .max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); if let Some(highest_id) = highest_id { if highest_id.1 == is_primary { - let neurons_to_add = target_neurons.iter().filter(|(id, l)| id > &highest_id.0 && id < neuron_id && l == layer).collect::>(); + let neurons_to_add = target_neurons + .iter() + .filter(|(id, l)| { + id > &highest_id.0 && id < neuron_id && l == layer + }) + .collect::>(); for (neuron_id, layer) in neurons_to_add { new_neurons.push((*neuron_id, is_primary, *layer, 0)); if is_primary { primary_last_layer = *layer; - } - else { + } else { secondary_last_layer = *layer; } } @@ -363,8 +432,7 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< if is_primary { primary_last_layer = *layer; - } - else { + } else { secondary_last_layer = *layer; } } @@ -372,18 +440,24 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< else if layer > ¤t_layer { // If the highest id in the current layer is from the same individual, add anything with a higher id to the current layer before moving to the next layer // First filter new_neurons to look at neurons from the current layer - let current_layer_neurons = new_neurons.iter().filter(|(_, _, l, _)| l == ¤t_layer).collect::>(); - let highest_id = current_layer_neurons.iter().max_by_key(|(id, _, _, _)| id); + let current_layer_neurons = new_neurons + .iter() + .filter(|(_, _, l, _)| l == ¤t_layer) + .collect::>(); + let highest_id = + current_layer_neurons.iter().max_by_key(|(id, _, _, _)| id); if let Some(highest_id) = highest_id { if highest_id.1 == is_primary { - let neurons_to_add = 
target_neurons.iter().filter(|(id, l)| id > &highest_id.0 && *l == layer - 1).collect::>(); + let neurons_to_add = target_neurons + .iter() + .filter(|(id, l)| id > &highest_id.0 && *l == layer - 1) + .collect::>(); for (neuron_id, _) in neurons_to_add { new_neurons.push((*neuron_id, is_primary, current_layer, 0)); if is_primary { primary_last_layer = current_layer; - } - else { + } else { secondary_last_layer = current_layer; } } @@ -395,21 +469,21 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< // Add the neuron to the new network // Along with any neurons that have a lower id in the future layer - let neurons_to_add = target_neurons.iter().filter(|(id, l)| id <= &neuron_id && l == layer).collect::>(); + let neurons_to_add = target_neurons + .iter() + .filter(|(id, l)| id <= neuron_id && l == layer) + .collect::>(); for (neuron_id, _) in neurons_to_add { new_neurons.push((*neuron_id, is_primary, current_layer, 0)); if is_primary { primary_last_layer = current_layer; - } - else { + } else { secondary_last_layer = current_layer; } } } - - } - else if neuron_id >= &segment.1 { + } else if neuron_id >= &segment.1 { break; } } @@ -420,7 +494,11 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< } // For the last segment, copy the remaining neurons - let target_neurons = if is_primary { &primary_neurons } else { &secondary_neurons }; + let target_neurons = if is_primary { + &primary_neurons + } else { + &secondary_neurons + }; // Get output layer number let output_layer = target_neurons.iter().max_by_key(|(_, l)| l).unwrap().1; @@ -436,17 +514,28 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< new_neurons.push((*neuron_id, is_primary, current_layer, 0)); } break; - } - else if *neuron_id == &segments.last().unwrap().1 + 1 { - let target_layer = if is_primary { primary_last_layer } else { secondary_last_layer }; - let earlier_layer_neurons = new_neurons.iter().filter(|(_, _, l, _)| *l >= target_layer && l <= layer).collect::>(); + } else if *neuron_id == &segments.last().unwrap().1 + 1 { + let target_layer = if is_primary { + primary_last_layer + } else { + secondary_last_layer + }; + let earlier_layer_neurons = new_neurons + .iter() + .filter(|(_, _, l, _)| *l >= target_layer && l <= layer) + .collect::>(); // get max neuron from with both // The highest layer // get max id from that layer - let highest_id = earlier_layer_neurons.iter().max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); + let highest_id = earlier_layer_neurons + .iter() + .max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); if let Some(highest_id) = highest_id { if highest_id.1 == is_primary { - let neurons_to_add = target_neurons.iter().filter(|(id, _)| id > &highest_id.0 && id < neuron_id).collect::>(); + let neurons_to_add = target_neurons + .iter() + .filter(|(id, _)| id > &highest_id.0 && id < neuron_id) + .collect::>(); for (neuron_id, l) in neurons_to_add { new_neurons.push((*neuron_id, is_primary, *l, 0)); } @@ -454,33 +543,39 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec< } new_neurons.push((*neuron_id, is_primary, *layer, 0)); - } - else { + } else { new_neurons.push((*neuron_id, is_primary, *layer, 0)); } } } // Filtering layers with too few neurons, if necessary - let layer_counts = new_neurons.iter().fold(vec![0; current_layer + 1], |mut counts, &(_, _, layer, _)| { - counts[layer] += 1; - counts - }); + let layer_counts = new_neurons.iter().fold( + vec![0; current_layer + 1], 
@@ -502,46 +600,64 @@ pub fn crossbreed_neuron_arrays(segments: Vec<(u32, u32)>, primary_neurons: Vec<
 pub fn major_mutation(fann: &Fann, weight_initialization_range: Range<f32>) -> Result<Fann, Error> {
     // add or remove a random neuron from a hidden layer
     let mut mutated_shape = fann.get_layer_sizes().to_vec();
-    let mut mutated_neurons = generate_neuron_datastructure(&mutated_shape).iter().map(|(id, layer)| (*id, true, *layer, *id)).collect::<Vec<_>>();
+    let mut mutated_neurons = generate_neuron_datastructure(&mutated_shape)
+        .iter()
+        .map(|(id, layer)| (*id, true, *layer, *id))
+        .collect::<Vec<_>>();
 
     // Determine first whether to add or remove a neuron
     if thread_rng().gen_range(0..2) == 0 {
         // To add a neuron we need to create a new fann object with the new layer sizes, then copy the information and connections over
-        let max_id = mutated_neurons.iter().max_by_key(|(id, _, _, _)| id).unwrap().0;
+        let max_id = mutated_neurons
+            .iter()
+            .max_by_key(|(id, _, _, _)| id)
+            .unwrap()
+            .0;
 
         // Now we inject the new neuron into mutated_neurons
         let layer = thread_rng().gen_range(1..fann.get_num_layers() - 1) as usize;
         let new_id = max_id + 1;
         mutated_neurons.push((new_id, true, layer, new_id));
         mutated_shape[layer] += 1;
-    }
-    else {
+    } else {
         // Remove a neuron
         let layer = thread_rng().gen_range(1..fann.get_num_layers() - 1) as usize;
 
         // Do not remove from layer if it would result in less than NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN neurons
         if mutated_shape[layer] > NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN as u32 {
-            let remove_id = mutated_neurons.iter().filter(|(_, _, l, _)| l == &layer).choose(&mut thread_rng()).unwrap().0;
+            let remove_id = mutated_neurons
+                .iter()
+                .filter(|(_, _, l, _)| l == &layer)
+                .choose(&mut thread_rng())
+                .unwrap()
+                .0;
 
             mutated_neurons.retain(|(id, _, _, _)| id != &remove_id);
             mutated_shape[layer] -= 1;
         }
     }
 
-    let mut mutated_fann = Fann::new(mutated_shape.as_slice())
-        .with_context(|| "Failed to create new fann")?;
-    mutated_fann.randomize_weights(weight_initialization_range.start, weight_initialization_range.end);
+    let mut mutated_fann =
+        Fann::new(mutated_shape.as_slice()).with_context(|| "Failed to create new fann")?;
+    mutated_fann.randomize_weights(
+        weight_initialization_range.start,
+        weight_initialization_range.end,
+    );
     mutated_fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
     mutated_fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
 
     // We need to regenerate the new_id's in mutated_neurons (the 4th item in the tuple); we can do this by iterating over the mutated_neurons all over again starting from zero
     mutated_neurons.sort_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0)));
-    let mut i = 0;
-    for (_, _, _, new_id) in mutated_neurons.iter_mut() {
-        *new_id = i;
-        i += 1;
+    for (i, (_, _, _, new_id)) in mutated_neurons.iter_mut().enumerate() {
+        *new_id = i as u32;
     }
 
     // We need to copy the connections from the old fann to the new fann
-    consolidate_old_connections(&fann, &fann, mutated_shape, mutated_neurons, &mut mutated_fann);
+    consolidate_old_connections(
+        fann,
+        fann,
+        mutated_shape,
+        mutated_neurons,
+        &mut mutated_fann,
+    );
 
     Ok(mutated_fann)
 }
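`major_mutation`'s add/remove choice reduces to a small edit on the shape vector. A self-contained sketch of just that decision, assuming a plain `Vec<u32>` shape and a stand-in minimum of 2 hidden neurons per layer (the real constant is `NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN`):

```rust
use rand::{thread_rng, Rng};

const HIDDEN_LAYER_SIZE_MIN: u32 = 2; // stand-in for the crate's constant

fn main() {
    let mut shape = vec![4u32, 8, 6, 4]; // input, two hidden layers, output
    // Pick a hidden layer; index 0 is input, the last index is output.
    let layer = thread_rng().gen_range(1..shape.len() - 1);
    if thread_rng().gen_range(0..2) == 0 {
        shape[layer] += 1; // grow the layer
    } else if shape[layer] > HIDDEN_LAYER_SIZE_MIN {
        shape[layer] -= 1; // shrink it, but never below the minimum
    }
    println!("mutated shape: {:?}", shape);
}
```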
@@ -561,8 +677,7 @@ pub fn generate_neuron_datastructure(shape: &[u32]) -> Vec<(u32, usize)> {
     result
 }
-
-fn to_bias_network_id(id: &u32, shape: &Vec<u32>) -> u32 {
+fn to_bias_network_id(id: &u32, shape: &[u32]) -> u32 {
     // The given id comes from a network without a bias neuron at the end of every layer
     // We need to translate this id to the id in the network with bias neurons
     let mut translated_id = 0;
@@ -580,14 +695,13 @@
 }
 
 fn to_non_bias_network_id(id: u32, shape: &[u32]) -> Option<u32> {
-    let mut bias_count = 0; // Count of bias neurons encountered up to the current ID
     let mut total_neurons = 0; // Total count of neurons (excluding bias neurons) processed
 
-    for &neurons in shape {
+    for (bias_count, &neurons) in shape.iter().enumerate() {
         let layer_end = total_neurons + neurons; // End of the current layer, excluding the bias neuron
         if id < layer_end {
             // ID is within the current layer (excluding the bias neuron)
-            return Some(id - bias_count);
+            return Some(id - bias_count as u32);
         }
         if id == layer_end {
             // ID matches the position where a bias neuron would be
@@ -596,7 +710,6 @@
         // Update counts after considering the current layer
         total_neurons += neurons + 1; // Move to the next layer, accounting for the bias neuron
-        bias_count += 1; // Increment bias count as we've moved past where a bias neuron would be
     }
 
     // If the ID is beyond the range of all neurons (including bias), it's treated as invalid
@@ -611,8 +724,8 @@ fn get_bias_neuron_for_layer(layer: usize, shape: &[u32]) -> Option<u32> {
     } else {
         // Compute the bias neuron for intermediate layers
         let mut bias = 0;
-        for i in 0..layer {
-            bias += shape[i];
+        for layer_count in shape.iter().take(layer) {
+            bias += layer_count;
         }
         Some(bias + layer as u32 - 1)
     }
 }
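The translation these helpers perform follows from FANN's layout: one bias neuron trails every layer, so ids shift by one for each completed layer. A sketch of the forward direction on the shape `[4, 8, 6, 4]` — my own condensed reimplementation for illustration, not the function above:

```rust
// Map a bias-free neuron id to its position in a network where every
// layer is followed by one bias neuron.
fn to_bias_id(id: u32, shape: &[u32]) -> u32 {
    let mut passed = 0; // non-bias neurons passed so far
    let mut bias_offset = 0; // bias neurons passed so far
    for &neurons in shape {
        if id < passed + neurons {
            return id + bias_offset;
        }
        passed += neurons;
        bias_offset += 1; // one bias neuron per completed layer
    }
    id + bias_offset
}

fn main() {
    let shape = [4u32, 8, 6, 4];
    assert_eq!(to_bias_id(0, &shape), 0); // input layer is unshifted
    assert_eq!(to_bias_id(4, &shape), 5); // shifted past the input-layer bias
    assert_eq!(to_bias_id(12, &shape), 14); // shifted past two bias neurons
}
```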
@@ -654,7 +767,10 @@ mod tests {
         // Assert that input and output layers have the same size
         assert_eq!(primary_shape[0], new_shape[0]);
-        assert_eq!(primary_shape[primary_shape.len() - 1], new_shape[new_shape.len() - 1]);
+        assert_eq!(
+            primary_shape[primary_shape.len() - 1],
+            new_shape[new_shape.len() - 1]
+        );
 
         // Determine if a neuron was removed or added
         if new_shape.iter().sum::<u32>() == primary_shape.iter().sum::<u32>() + 1 {
@@ -671,14 +787,25 @@ mod tests {
             }
 
             for connection in connections.iter() {
-                if connection.from_neuron == added_neuron_id || connection.to_neuron == added_neuron_id {
-                    assert!(connection.weight < 0.0, "Connection: {:?}, Added Neuron: {}", connection, added_neuron_id);
+                if connection.from_neuron == added_neuron_id
+                    || connection.to_neuron == added_neuron_id
+                {
+                    assert!(
+                        connection.weight < 0.0,
+                        "Connection: {:?}, Added Neuron: {}",
+                        connection,
+                        added_neuron_id
+                    );
                 } else {
-                    assert!(connection.weight > 0.0, "Connection: {:?}, Added Neuron: {}", connection, added_neuron_id);
+                    assert!(
+                        connection.weight > 0.0,
+                        "Connection: {:?}, Added Neuron: {}",
+                        connection,
+                        added_neuron_id
+                    );
                 }
             }
-        }
-        else if new_shape.iter().sum::<u32>() == primary_shape.iter().sum::<u32>() - 1 {
+        } else if new_shape.iter().sum::<u32>() == primary_shape.iter().sum::<u32>() - 1 {
             //Neuron was removed
             for connection in connections.iter() {
                 assert!(connection.weight > 0.0, "Connection: {:?}", connection);
@@ -687,11 +814,14 @@ mod tests {
             for (i, layer) in new_shape.iter().enumerate() {
                 // if layer isn't input or output
                 if i != 0 && i as u32 != new_shape.len() as u32 - 1 {
-                    assert!(*layer >= NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN as u32, "Layer: {}", layer);
+                    assert!(
+                        *layer >= NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN as u32,
+                        "Layer: {}",
+                        layer
+                    );
                 }
             }
-        }
-        else {
+        } else {
             //Neuron was neither added nor removed
             for connection in connections.iter() {
                 assert!(connection.weight > 0.0, "Connection: {:?}", connection);
@@ -710,12 +840,20 @@ mod tests {
         let crossbreed_segments = 5;
 
         // Act
-        let result = generate_segments(primary_shape.clone(), secondary_shape.clone(), crossbreed_segments);
+        let result = generate_segments(
+            primary_shape.clone(),
+            secondary_shape.clone(),
+            crossbreed_segments,
+        );
 
         println!("{:?}", result);
 
         // Assert
-        assert!(result.len() <= crossbreed_segments, "Segments: {:?}", result);
+        assert!(
+            result.len() <= crossbreed_segments,
+            "Segments: {:?}",
+            result
+        );
         //Assert that segments are within the bounds of the layers
         for (start, end) in result.iter() {
             // Bounds are the end of the first layer to the end of the second to last layer
         }
 
         //Assert that segments start and end are in ascending order
-        for (start, end) in result.iter()
-        {
+        for (start, end) in result.iter() {
             assert!(*start <= *end, "Start: {}, End: {}", start, end);
         }
@@ -740,7 +877,11 @@
         let crossbreed_segments = 15;
 
         // Act
-        let result = generate_segments(primary_shape.clone(), secondary_shape.clone(), crossbreed_segments);
+        let result = generate_segments(
+            primary_shape.clone(),
+            secondary_shape.clone(),
+            crossbreed_segments,
+        );
 
         println!("{:?}", result);
@@ -754,8 +895,7 @@
         }
 
         //Assert that segments start and end are in ascending order
-        for (start, end) in result.iter()
-        {
+        for (start, end) in result.iter() {
             assert!(*start <= *end, "Start: {}, End: {}", start, end);
         }
@@ -805,25 +945,52 @@
     fn crossbreed_neuron_arrays_test() {
         // Assign
        let segments = vec![(0, 3), (4, 6), (7, 8), (9, 10)];
-        
+
        let primary_network = generate_neuron_datastructure(&vec![4, 8, 6, 4]);
        let secondary_network = generate_neuron_datastructure(&vec![4, 3, 3, 3, 3, 3, 4]);
 
         // Act
-        let result = crossbreed_neuron_arrays(segments.clone(), primary_network.clone(), secondary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            primary_network.clone(),
+            secondary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 8
-            (4, false, 1, 4), (5, false, 1, 5), (6, false, 1, 6), (7, true, 1, 7), (8, true, 1, 8), (9, true, 1, 9), (10, true, 1, 10), (11, true, 1, 11),
+            (4, false, 1, 4),
+            (5, false, 1, 5),
+            (6, false, 1, 6),
+            (7, true, 1, 7),
+            (8, true, 1, 8),
+            (9, true, 1, 9),
+            (10, true, 1, 10),
+            (11, true, 1, 11),
             // Hidden Layer 2: Expect 9
-            (7, false, 2, 12), (8, false, 2, 13), (9, false, 2, 14), (12, true, 2, 15), (13, true, 2, 16), (14, true, 2, 17), (15, true, 2, 18), (16, true, 2, 19), (17, true, 2, 20),
+            (7, false, 2, 12),
+            (8, false, 2, 13),
+            (9, false, 2, 14),
+            (12, true, 2, 15),
+            (13, true, 2, 16),
+            (14, true, 2, 17),
+            (15, true, 2, 18),
+            (16, true, 2, 19),
+            (17, true, 2, 20),
             // Output Layer: Expect 4
-            (18, true, 3, 21), (19, true, 3, 22), (20, true, 3, 23), (21, true, 3, 24),
-        ].into_iter().collect();
+            (18, true, 3, 21),
+            (19, true, 3, 22),
+            (20, true, 3, 23),
+            (21, true, 3, 24),
+        ]
+        .into_iter()
+        .collect();
 
         // Convert Result to HashSet for Comparison
         let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect();
@@ -833,25 +1000,51 @@
 
         // Now we test the opposite case
         // Act
-        let result = crossbreed_neuron_arrays(segments.clone(), secondary_network.clone(), primary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            secondary_network.clone(),
+            primary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 7
-            (4, false, 1, 4), (5, false, 1, 5), (6, false, 1, 6), (7, false, 1, 7), (8, false, 1, 8), (9, false, 1, 9), (10, false, 1, 10),
+            (4, false, 1, 4),
+            (5, false, 1, 5),
+            (6, false, 1, 6),
+            (7, false, 1, 7),
+            (8, false, 1, 8),
+            (9, false, 1, 9),
+            (10, false, 1, 10),
             // Hidden Layer 2: Expect 3
-            (7, true, 2, 11), (8, true, 2, 12), (9, true, 2, 13),
+            (7, true, 2, 11),
+            (8, true, 2, 12),
+            (9, true, 2, 13),
             // Hidden Layer 3: Expect 3
-            (10, true, 3, 14), (11, true, 3, 15), (12, true, 3, 16),
+            (10, true, 3, 14),
+            (11, true, 3, 15),
+            (12, true, 3, 16),
             // Hidden Layer 4: Expect 3
-            (13, true, 4, 17), (14, true, 4, 18), (15, true, 4, 19),
+            (13, true, 4, 17),
+            (14, true, 4, 18),
+            (15, true, 4, 19),
             // Hidden Layer 5: Expect 3
-            (16, true, 5, 20), (17, true, 5, 21), (18, true, 5, 22),
+            (16, true, 5, 20),
+            (17, true, 5, 21),
+            (18, true, 5, 22),
             // Output Layer: Expect 4
-            (19, true, 6, 23), (20, true, 6, 24), (21, true, 6, 25), (22, true, 6, 26),
-        ].into_iter().collect();
+            (19, true, 6, 23),
+            (20, true, 6, 24),
+            (21, true, 6, 25),
+            (22, true, 6, 26),
+        ]
+        .into_iter()
+        .collect();
 
         // Convert Result to HashSet for Comparison
         let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect();
@@ -864,23 +1057,46 @@
         let segments = vec![(0, 4), (5, 14), (15, 15), (16, 16)];
 
         // Act
-        let result = crossbreed_neuron_arrays(segments.clone(), primary_network.clone(), secondary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            primary_network.clone(),
+            secondary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 3
-            (4, true, 1, 4), (5, false, 1, 5),
-            (6, false, 1, 6),
+            (4, true, 1, 4),
+            (5, false, 1, 5),
+            (6, false, 1, 6),
             // Hidden Layer 2: Expect 6
-            (7, false, 2, 7), (8, false, 2, 8), (9, false, 2, 9), (15, true, 2, 10), (16, true, 2, 11), (17, true, 2, 12),
+            (7, false, 2, 7),
+            (8, false, 2, 8),
+            (9, false, 2, 9),
+            (15, true, 2, 10),
+            (16, true, 2, 11),
+            (17, true, 2, 12),
             // Hidden Layer 3: Expect 3
-            (10, false, 3, 13), (11, false, 3, 14), (12, false, 3, 15),
+            (10, false, 3, 13),
+            (11, false, 3, 14),
+            (12, false, 3, 15),
             // Hidden Layer 4: Expect 3
-            (13, false, 4, 16), (14, false, 4, 17), (15, false, 4, 18),
+            (13, false, 4, 16),
+            (14, false, 4, 17),
+            (15, false, 4, 18),
             // Output Layer: Expect 4
-            (18, true, 5, 19), (19, true, 5, 20), (20, true, 5, 21), (21, true, 5, 22),
-        ].into_iter().collect();
+            (18, true, 5, 19),
+            (19, true, 5, 20),
+            (20, true, 5, 21),
+            (21, true, 5, 22),
+        ]
+        .into_iter()
+        .collect();
 
         // print result before comparison
         for r in result.iter() {
@@ -894,23 +1110,50 @@
         assert_eq!(result_set, expected);
 
         // Swapping order
-        let result = crossbreed_neuron_arrays(segments.clone(), secondary_network.clone(), primary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            secondary_network.clone(),
+            primary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 8
-            (4, true, 1, 4), (5, false, 1, 5), (6, false, 1, 6), (7, false, 1, 7), (8, false, 1, 8), (9, false, 1, 9), (10, false, 1, 10), (11, false, 1, 11),
+            (4, true, 1, 4),
+            (5, false, 1, 5),
+            (6, false, 1, 6),
+            (7, false, 1, 7),
+            (8, false, 1, 8),
+            (9, false, 1, 9),
+            (10, false, 1, 10),
+            (11, false, 1, 11),
             // Hidden Layer 2: Expect 5
-            (12, false, 2, 12), (13, false, 2, 13), (14, false, 2, 14), (15, false, 2, 15), (16, false, 2, 16),
+            (12, false, 2, 12),
+            (13, false, 2, 13),
+            (14, false, 2, 14),
+            (15, false, 2, 15),
+            (16, false, 2, 16),
             // Hidden Layer 3: Expect 3
-            (13, true, 3, 17), (14, true, 3, 18), (15, true, 3, 19),
+            (13, true, 3, 17),
+            (14, true, 3, 18),
+            (15, true, 3, 19),
             // Hidden Layer 4: Expect 3
-            (16, true, 4, 20), (17, true, 4, 21), (18, true, 4, 22),
+            (16, true, 4, 20),
+            (17, true, 4, 21),
+            (18, true, 4, 22),
             // Output Layer: Expect 4
-            (19, true, 5, 23), (20, true, 5, 24), (21, true, 5, 25), (22, true, 5, 26),
-        ].into_iter().collect();
+            (19, true, 5, 23),
+            (20, true, 5, 24),
+            (21, true, 5, 25),
+            (22, true, 5, 26),
+        ]
+        .into_iter()
+        .collect();
 
         // print result before comparison
         for r in result.iter() {
@@ -928,21 +1171,48 @@
         let segments = vec![(0, 7), (8, 9), (10, 10), (11, 12)];
 
         // Act
-        let result = crossbreed_neuron_arrays(segments.clone(), primary_network.clone(), secondary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            primary_network.clone(),
+            secondary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 7
-            (4, true, 1, 4), (5, true, 1, 5), (6, true, 1, 6), (7, true, 1, 7), (8, true, 1, 8), (9, true, 1, 9), (10, true, 1, 10),
+            (4, true, 1, 4),
+            (5, true, 1, 5),
+            (6, true, 1, 6),
+            (7, true, 1, 7),
+            (8, true, 1, 8),
+            (9, true, 1, 9),
+            (10, true, 1, 10),
             // Hidden Layer 2: Expect 8
-            (7, false, 2, 11), (8, false, 2, 12), (9, false, 2, 13), (13, true, 2, 14), (14, true, 2, 15), (15, true, 2, 16), (16, true, 2, 17), (17, true, 2, 18),
+            (7, false, 2, 11),
+            (8, false, 2, 12),
+            (9, false, 2, 13),
+            (13, true, 2, 14),
+            (14, true, 2, 15),
+            (15, true, 2, 16),
+            (16, true, 2, 17),
+            (17, true, 2, 18),
             // Hidden Layer 3: Expect 3
-            (10, false, 3, 19), (11, false, 3, 20), (12, false, 3, 21),
+            (10, false, 3, 19),
+            (11, false, 3, 20),
+            (12, false, 3, 21),
             // Output Layer: Expect 4
-            (18, true, 4, 22), (19, true, 4, 23), (20, true, 4, 24), (21, true, 4, 25),
-        ].into_iter().collect();
+            (18, true, 4, 22),
+            (19, true, 4, 23),
+            (20, true, 4, 24),
+            (21, true, 4, 25),
+        ]
+        .into_iter()
+        .collect();
 
         // print result before comparison
         for r in result.iter() {
@@ -956,25 +1226,52 @@
         assert_eq!(result_set, expected);
 
         // Swapping order
-        let result = crossbreed_neuron_arrays(segments.clone(), secondary_network.clone(), primary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            secondary_network.clone(),
+            primary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 7
-            (4, true, 1, 4), (5, true, 1, 5), (6, true, 1, 6), (8, false, 1, 7), (9, false, 1, 8), (10, false, 1, 9), (11, false, 1, 10),
+            (4, true, 1, 4),
+            (5, true, 1, 5),
+            (6, true, 1, 6),
+            (8, false, 1, 7),
+            (9, false, 1, 8),
+            (10, false, 1, 9),
+            (11, false, 1, 10),
             // Hidden Layer 2: Expect 4
-            (7, true, 2, 11), (8, true, 2, 12), (9, true, 2, 13), (12, false, 2, 14),
+            (7, true, 2, 11),
+            (8, true, 2, 12),
+            (9, true, 2, 13),
+            (12, false, 2, 14),
             // Hidden Layer 3: Expect 3
-            (10, true, 3, 15), (11, true, 3, 16), (12, true, 3, 17),
+            (10, true, 3, 15),
+            (11, true, 3, 16),
+            (12, true, 3, 17),
             // Hidden Layer 4: Expect 3
-            (13, true, 4, 18), (14, true, 4, 19), (15, true, 4, 20),
+            (13, true, 4, 18),
+            (14, true, 4, 19),
+            (15, true, 4, 20),
             // Hidden Layer 5: Expect 3
-            (16, true, 5, 21), (17, true, 5, 22), (18, true, 5, 23),
+            (16, true, 5, 21),
+            (17, true, 5, 22),
+            (18, true, 5, 23),
             // Output Layer: Expect 4
-            (19, true, 6, 24), (20, true, 6, 25), (21, true, 6, 26), (22, true, 6, 27),
-        ].into_iter().collect();
+            (19, true, 6, 24),
+            (20, true, 6, 25),
+            (21, true, 6, 26),
+            (22, true, 6, 27),
+        ]
+        .into_iter()
+        .collect();
 
         // print result before comparison
         for r in result.iter() {
@@ -990,40 +1287,76 @@
         // Testing networks with the same size
         // Assign
         let segments = vec![(0, 3), (4, 6), (7, 8), (9, 11)];
-        
+
         let primary_network = generate_neuron_datastructure(&vec![4, 3, 4, 5, 4]);
-        
+
         vec![
             // Input layer
-            (0, 0), (1, 0), (2, 0), (3, 0),
+            (0, 0),
+            (1, 0),
+            (2, 0),
+            (3, 0),
             // Hidden layer 1: 3 neurons
-            (4, 1), (5, 1), (6, 1),
+            (4, 1),
+            (5, 1),
+            (6, 1),
             // Hidden Layer 2: 4 neurons
-            (7, 2), (8, 2), (9, 2), (10, 2),
+            (7, 2),
+            (8, 2),
+            (9, 2),
+            (10, 2),
             // Hidden Layer 3: 5 neurons
-            (11, 3), (12, 3), (13, 3), (14, 3), (15, 3),
+            (11, 3),
+            (12, 3),
+            (13, 3),
+            (14, 3),
+            (15, 3),
             // Output layer
-            (16, 4), (17, 4), (18, 4), (19, 4),
+            (16, 4),
+            (17, 4),
+            (18, 4),
+            (19, 4),
         ];
 
         let secondary_network = primary_network.clone();
 
         // Act
-        let result = crossbreed_neuron_arrays(segments.clone(), primary_network.clone(), secondary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            primary_network.clone(),
+            secondary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 3
-            (4, false, 1, 4), (5, false, 1, 5), (6, false, 1, 6),
+            (4, false, 1, 4),
+            (5, false, 1, 5),
+            (6, false, 1, 6),
             // Hidden Layer 2: Expect 4
-            (7, true, 2, 7), (8, true, 2, 8), (9, false, 2, 9), (10, false, 2, 10),
+            (7, true, 2, 7),
+            (8, true, 2, 8),
+            (9, false, 2, 9),
+            (10, false, 2, 10),
             // Hidden Layer 3: Expect 5
-            (11, false, 3, 11), (12, true, 3, 12), (13, true, 3, 13), (14, true, 3, 14), (15, true, 3, 15),
+            (11, false, 3, 11),
+            (12, true, 3, 12),
+            (13, true, 3, 13),
+            (14, true, 3, 14),
+            (15, true, 3, 15),
             // Output Layer: Expect 4
-            (16, true, 4, 16), (17, true, 4, 17), (18, true, 4, 18), (19, true, 4, 19),
-        ].into_iter().collect();
+            (16, true, 4, 16),
+            (17, true, 4, 17),
+            (18, true, 4, 18),
+            (19, true, 4, 19),
+        ]
+        .into_iter()
+        .collect();
 
         // print result before comparison
         for r in result.iter() {
@@ -1040,21 +1373,42 @@
         let segments = vec![(0, 5), (6, 6), (7, 11), (12, 13)];
 
         // Act
-        let result = crossbreed_neuron_arrays(segments.clone(), primary_network.clone(), secondary_network.clone());
+        let result = crossbreed_neuron_arrays(
+            segments.clone(),
+            primary_network.clone(),
+            secondary_network.clone(),
+        );
 
         // Expected Result Set
         let expected: HashSet<(u32, bool, usize, u32)> = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 3
-            (4, true, 1, 4), (5, true, 1, 5), (6, false, 1, 6),
+            (4, true, 1, 4),
+            (5, true, 1, 5),
+            (6, false, 1, 6),
             // Hidden Layer 2: Expect 4
-            (7, true, 2, 7), (8, true, 2, 8), (9, true, 2, 9), (10, true, 2, 10),
+            (7, true, 2, 7),
+            (8, true, 2, 8),
+            (9, true, 2, 9),
+            (10, true, 2, 10),
             // Hidden Layer 3: Expect 5
-            (11, true, 3, 11), (12, false, 3, 12), (13, false, 3, 13), (14, true, 3, 14), (15, true, 3, 15),
+            (11, true, 3, 11),
+            (12, false, 3, 12),
+            (13, false, 3, 13),
+            (14, true, 3, 14),
+            (15, true, 3, 15),
             // Output Layer: Expect 4
-            (16, true, 4, 16), (17, true, 4, 17), (18, true, 4, 18), (19, true, 4, 19),
-        ].into_iter().collect();
+            (16, true, 4, 16),
+            (17, true, 4, 17),
+            (18, true, 4, 18),
+            (19, true, 4, 19),
+        ]
+        .into_iter()
+        .collect();
 
         // print result before comparison
         for r in result.iter() {
@@ -1078,10 +1432,22 @@
 
         // Expected Result
         let expected: Vec<(u32, usize)> = vec![
-            (0, 0), (1, 0), (2, 0), (3, 0),
-            (4, 1), (5, 1), (6, 1),
-            (7, 2), (8, 2), (9, 2), (10, 2), (11, 2),
-            (12, 3), (13, 3), (14, 3), (15, 3),
+            (0, 0),
+            (1, 0),
+            (2, 0),
+            (3, 0),
+            (4, 1),
+            (5, 1),
+            (6, 1),
+            (7, 2),
+            (8, 2),
+            (9, 2),
+            (10, 2),
+            (11, 2),
+            (12, 3),
+            (13, 3),
+            (14, 3),
+            (15, 3),
         ];
 
         // Assert
@@ -1095,10 +1461,22 @@
 
         let expected = vec![
             // (input, expected output)
-            (0, 0), (1, 1), (2, 2), (3, 3),
-            (4, 5), (5, 6), (6, 7),
-            (7, 9), (8, 10), (9, 11), (10, 12), (11, 13),
-            (12, 15), (13, 16), (14, 17), (15, 18),
+            (0, 0),
+            (1, 1),
+            (2, 2),
+            (3, 3),
+            (4, 5),
+            (5, 6),
+            (6, 7),
+            (7, 9),
+            (8, 10),
+            (9, 11),
+            (10, 12),
+            (11, 13),
+            (12, 15),
+            (13, 16),
+            (14, 17),
+            (15, 18),
         ];
 
         // Act
@@ -1113,8 +1491,7 @@
         // Assert
         if let Some(result) = result {
             assert_eq!(result, input);
-        }
-        else {
+        } else {
             assert!(false, "Expected Some, got None");
         }
     }
@@ -1131,7 +1508,7 @@
     }
 
     #[test]
-    fn consolidate_old_connections_test() -> Result<(), Box<dyn std::error::Error>>{
+    fn consolidate_old_connections_test() -> Result<(), Box<dyn std::error::Error>> {
         // Assign
         let primary_shape = vec![4, 8, 6, 4];
         let secondary_shape = vec![4, 3, 3, 3, 3, 3, 4];
@@ -1154,13 +1531,34 @@
         let new_neurons = vec![
             // Input layer: Expect 4
-            (0, true, 0, 0), (1, true, 0, 1), (2, true, 0, 2), (3, true, 0, 3),
+            (0, true, 0, 0),
+            (1, true, 0, 1),
+            (2, true, 0, 2),
+            (3, true, 0, 3),
             // Hidden Layer 1: Expect 8
-            (4, false, 1, 4), (5, false, 1, 5), (6, false, 1, 6), (7, true, 1, 7), (8, true, 1, 8), (9, true, 1, 9), (10, true, 1, 10), (11, true, 1, 11),
+            (4, false, 1, 4),
+            (5, false, 1, 5),
+            (6, false, 1, 6),
+            (7, true, 1, 7),
+            (8, true, 1, 8),
+            (9, true, 1, 9),
+            (10, true, 1, 10),
+            (11, true, 1, 11),
             // Hidden Layer 2: Expect 9
-            (7, false, 2, 12), (8, false, 2, 13), (9, false, 2, 14), (12, true, 2, 15), (13, true, 2, 16), (14, true, 2, 17), (15, true, 2, 18), (16, true, 2, 19), (17, true, 2, 20),
+            (7, false, 2, 12),
+            (8, false, 2, 13),
+            (9, false, 2, 14),
+            (12, true, 2, 15),
+            (13, true, 2, 16),
+            (14, true, 2, 17),
+            (15, true, 2, 18),
+            (16, true, 2, 19),
+            (17, true, 2, 20),
             // Output Layer: Expect 4
-            (18, true, 3, 21), (19, true, 3, 22), (20, true, 3, 23), (21, true, 3, 24),
+            (18, true, 3, 21),
+            (19, true, 3, 22),
+            (20, true, 3, 23),
+            (21, true, 3, 24),
         ];
         let new_shape = vec![4, 8, 9, 4];
         let mut new_fann = Fann::new(&[4, 8, 9, 4])?;
@@ -1172,7 +1570,13 @@
         new_fann.set_connections(&new_connections);
 
         // Act
-        consolidate_old_connections(&primary_fann, &secondary_fann, new_shape, new_neurons, &mut new_fann);
+        consolidate_old_connections(
+            &primary_fann,
+            &secondary_fann,
+            new_shape,
+            new_neurons,
+            &mut new_fann,
+        );
 
         // Bias neurons
         // Layer 1: 4
@@ -1181,29 +1585,148 @@
         let expected_connections = vec![
             // (from_neuron, to_neuron, weight)
             // Hidden Layer 1 (5-12)
-            (0, 5, -5.0), (1, 5, -105.0), (2, 5, -205.0), (3, 5, -305.0),
-            (0, 6, -6.0), (1, 6, -106.0), (2, 6, -206.0), (3, 6, -306.0),
-            (0, 7, -7.0), (1, 7, -107.0), (2, 7, -207.0), (3, 7, -307.0),
-            (0, 8, 8.0), (1, 8, 108.0), (2, 8, 208.0), (3, 8, 308.0),
-            (0, 9, 9.0), (1, 9, 109.0), (2, 9, 209.0), (3, 9, 309.0),
-            (0, 10, 10.0), (1, 10, 110.0), (2, 10, 210.0), (3, 10, 310.0),
-            (0, 11, 11.0), (1, 11, 111.0), (2, 11, 211.0), (3, 11, 311.0),
-            (0, 12, 12.0), (1, 12, 112.0), (2, 12, 212.0), (3, 12, 312.0),
+            (0, 5, -5.0),
+            (1, 5, -105.0),
+            (2, 5, -205.0),
+            (3, 5, -305.0),
+            (0, 6, -6.0),
+            (1, 6, -106.0),
+            (2, 6, -206.0),
+            (3, 6, -306.0),
+            (0, 7, -7.0),
+            (1, 7, -107.0),
+            (2, 7, -207.0),
+            (3, 7, -307.0),
+            (0, 8, 8.0),
+            (1, 8, 108.0),
+            (2, 8, 208.0),
+            (3, 8, 308.0),
+            (0, 9, 9.0),
+            (1, 9, 109.0),
+            (2, 9, 209.0),
+            (3, 9, 309.0),
+            (0, 10, 10.0),
+            (1, 10, 110.0),
+            (2, 10, 210.0),
+            (3, 10, 310.0),
+            (0, 11, 11.0),
+            (1, 11, 111.0),
+            (2, 11, 211.0),
+            (3, 11, 311.0),
+            (0, 12, 12.0),
+            (1, 12, 112.0),
+            (2, 12, 212.0),
+            (3, 12, 312.0),
             // Hidden Layer 2 (14-22)
-            (5, 14, -509.0), (6, 14, -609.0), (7, 14, -709.0), (8, 14, 0.0), (9, 14, 0.0), (10, 14, 0.0), (11, 14, 0.0), (12, 14, 0.0),
-            (5, 15, -510.0), (6, 15, -610.0), (7, 15, -710.0), (8, 15, 0.0), (9, 15, 0.0), (10, 15, 0.0), (11, 15, 0.0), (12, 15, 0.0),
-            (5, 16, -511.0), (6, 16, -611.0), (7, 16, -711.0), (8, 16, 0.0), (9, 16, 0.0), (10, 16, 0.0), (11, 16, 0.0), (12, 16, 0.0),
-            (5,
17, 514.0), (6, 17, 614.0), (7, 17, 714.0), (8, 17, 814.0), (9, 17, 914.0), (10, 17, 1014.0), (11, 17, 1114.0), (12, 17, 1214.0),
-            (5, 18, 515.0), (6, 18, 615.0), (7, 18, 715.0), (8, 18, 815.0), (9, 18, 915.0), (10, 18, 1015.0), (11, 18, 1115.0), (12, 18, 1215.0),
-            (5, 19, 516.0), (6, 19, 616.0), (7, 19, 716.0), (8, 19, 816.0), (9, 19, 916.0), (10, 19, 1016.0), (11, 19, 1116.0), (12, 19, 1216.0),
-            (5, 20, 517.0), (6, 20, 617.0), (7, 20, 717.0), (8, 20, 817.0), (9, 20, 917.0), (10, 20, 1017.0), (11, 20, 1117.0), (12, 20, 1217.0),
-            (5, 21, 518.0), (6, 21, 618.0), (7, 21, 718.0), (8, 21, 818.0), (9, 21, 918.0), (10, 21, 1018.0), (11, 21, 1118.0), (12, 21, 1218.0),
-            (5, 22, 519.0), (6, 22, 619.0), (7, 22, 719.0), (8, 22, 819.0), (9, 22, 919.0), (10, 22, 1019.0), (11, 22, 1119.0), (12, 22, 1219.0),
+            (5, 14, -509.0),
+            (6, 14, -609.0),
+            (7, 14, -709.0),
+            (8, 14, 0.0),
+            (9, 14, 0.0),
+            (10, 14, 0.0),
+            (11, 14, 0.0),
+            (12, 14, 0.0),
+            (5, 15, -510.0),
+            (6, 15, -610.0),
+            (7, 15, -710.0),
+            (8, 15, 0.0),
+            (9, 15, 0.0),
+            (10, 15, 0.0),
+            (11, 15, 0.0),
+            (12, 15, 0.0),
+            (5, 16, -511.0),
+            (6, 16, -611.0),
+            (7, 16, -711.0),
+            (8, 16, 0.0),
+            (9, 16, 0.0),
+            (10, 16, 0.0),
+            (11, 16, 0.0),
+            (12, 16, 0.0),
+            (5, 17, 514.0),
+            (6, 17, 614.0),
+            (7, 17, 714.0),
+            (8, 17, 814.0),
+            (9, 17, 914.0),
+            (10, 17, 1014.0),
+            (11, 17, 1114.0),
+            (12, 17, 1214.0),
+            (5, 18, 515.0),
+            (6, 18, 615.0),
+            (7, 18, 715.0),
+            (8, 18, 815.0),
+            (9, 18, 915.0),
+            (10, 18, 1015.0),
+            (11, 18, 1115.0),
+            (12, 18, 1215.0),
+            (5, 19, 516.0),
+            (6, 19, 616.0),
+            (7, 19, 716.0),
+            (8, 19, 816.0),
+            (9, 19, 916.0),
+            (10, 19, 1016.0),
+            (11, 19, 1116.0),
+            (12, 19, 1216.0),
+            (5, 20, 517.0),
+            (6, 20, 617.0),
+            (7, 20, 717.0),
+            (8, 20, 817.0),
+            (9, 20, 917.0),
+            (10, 20, 1017.0),
+            (11, 20, 1117.0),
+            (12, 20, 1217.0),
+            (5, 21, 518.0),
+            (6, 21, 618.0),
+            (7, 21, 718.0),
+            (8, 21, 818.0),
+            (9, 21, 918.0),
+            (10, 21, 1018.0),
+            (11, 21, 1118.0),
+            (12, 21, 1218.0),
+            (5, 22, 519.0),
+            (6, 22, 619.0),
+            (7, 22, 719.0),
+            (8, 22, 819.0),
+            (9, 22, 919.0),
+            (10, 22, 1019.0),
+            (11, 22, 1119.0),
+            (12, 22, 1219.0),
             // Output layer (24-27)
-            (14, 24, 0.0), (15, 24, 0.0), (16, 24, 0.0), (17, 24, 1421.0), (18, 24, 1521.0), (19, 24, 1621.0), (20, 24, 1721.0), (21, 24, 1821.0), (22, 24, 1921.0),
-            (14, 25, 0.0), (15, 25, 0.0), (16, 25, 0.0), (17, 25, 1422.0), (18, 25, 1522.0), (19, 25, 1622.0), (20, 25, 1722.0), (21, 25, 1822.0), (22, 25, 1922.0),
-            (14, 26, 0.0), (15, 26, 0.0), (16, 26, 0.0), (17, 26, 1423.0), (18, 26, 1523.0), (19, 26, 1623.0), (20, 26, 1723.0), (21, 26, 1823.0), (22, 26, 1923.0),
-            (14, 27, 0.0), (15, 27, 0.0), (16, 27, 0.0), (17, 27, 1424.0), (18, 27, 1524.0), (19, 27, 1624.0), (20, 27, 1724.0), (21, 27, 1824.0), (22, 27, 1924.0),
+            (14, 24, 0.0),
+            (15, 24, 0.0),
+            (16, 24, 0.0),
+            (17, 24, 1421.0),
+            (18, 24, 1521.0),
+            (19, 24, 1621.0),
+            (20, 24, 1721.0),
+            (21, 24, 1821.0),
+            (22, 24, 1921.0),
+            (14, 25, 0.0),
+            (15, 25, 0.0),
+            (16, 25, 0.0),
+            (17, 25, 1422.0),
+            (18, 25, 1522.0),
+            (19, 25, 1622.0),
+            (20, 25, 1722.0),
+            (21, 25, 1822.0),
+            (22, 25, 1922.0),
+            (14, 26, 0.0),
+            (15, 26, 0.0),
+            (16, 26, 0.0),
+            (17, 26, 1423.0),
+            (18, 26, 1523.0),
+            (19, 26, 1623.0),
+            (20, 26, 1723.0),
+            (21, 26, 1823.0),
+            (22, 26, 1923.0),
+            (14, 27, 0.0),
+            (15, 27, 0.0),
+            (16, 27, 0.0),
+            (17, 27, 1424.0),
+            (18, 27, 1524.0),
+            (19, 27, 1624.0),
+            (20, 27, 1724.0),
+            (21, 27, 1824.0),
+            (22, 27, 1924.0),
         ];
 
         for connection in new_fann.get_connections().iter() {
@@ -1214,9 +1737,15 @@
         // Compare each connection to the expected connection
         let new_connections = new_fann.get_connections();
         for connection in expected_connections.iter() {
-            let matching_connection = new_connections.iter().find(|&c| c.from_neuron == connection.0 && c.to_neuron == connection.1);
+            let matching_connection = new_connections
+                .iter()
+                .find(|&c| c.from_neuron == connection.0 && c.to_neuron == connection.1);
             if let Some(matching_connection) = matching_connection {
-                assert_eq!(matching_connection.weight, connection.2, "Connection: {:?}", matching_connection);
+                assert_eq!(
+                    matching_connection.weight, connection.2,
+                    "Connection: {:?}",
+                    matching_connection
+                );
             } else {
                 assert!(false, "Connection not found: {:?}", connection);
             }
@@ -1226,17 +1755,41 @@
             // (from_neuron, to_neuron, weight)
             // Bias Neurons
             // Layer 2: bias neuron_id 4
-            (4, 5, -405.0), (4, 6, -406.0), (4, 7, -407.0), (4, 8, 408.0), (4, 9, 409.0), (4, 10, 410.0), (4, 11, 411.0), (4, 12, 412.0),
+            (4, 5, -405.0),
+            (4, 6, -406.0),
+            (4, 7, -407.0),
+            (4, 8, 408.0),
+            (4, 9, 409.0),
+            (4, 10, 410.0),
+            (4, 11, 411.0),
+            (4, 12, 412.0),
             // Layer 3: bias neuron_id 13
-            (13, 14, -809.0), (13, 15, -810.0), (13, 16, -811.0), (13, 17, 1314.0), (13, 18, 1315.0), (13, 19, 1316.0), (13, 20, 1317.0), (13, 21, 1318.0), (13, 22, 1319.0),
+            (13, 14, -809.0),
+            (13, 15, -810.0),
+            (13, 16, -811.0),
+            (13, 17, 1314.0),
+            (13, 18, 1315.0),
+            (13, 19, 1316.0),
+            (13, 20, 1317.0),
+            (13, 21, 1318.0),
+            (13, 22, 1319.0),
             // Layer 4: bias neuron_id 23
-            (23, 24, 2021.0), (23, 25, 2022.0), (23, 26, 2023.0), (23, 27, 2024.0),
+            (23, 24, 2021.0),
+            (23, 25, 2022.0),
+            (23, 26, 2023.0),
+            (23, 27, 2024.0),
         ];
 
         for connection in expected_bias_neuron_connections.iter() {
-            let matching_connection = new_connections.iter().find(|&c| c.from_neuron == connection.0 && c.to_neuron == connection.1);
+            let matching_connection = new_connections
+                .iter()
+                .find(|&c| c.from_neuron == connection.0 && c.to_neuron == connection.1);
             if let Some(matching_connection) = matching_connection {
-                assert_eq!(matching_connection.weight, connection.2, "Connection: {:?}", matching_connection);
+                assert_eq!(
+                    matching_connection.weight, connection.2,
+                    "Connection: {:?}",
+                    matching_connection
+                );
             } else {
                 assert!(false, "Connection not found: {:?}", connection);
             }
@@ -1244,4 +1797,4 @@
 
         Ok(())
     }
-}
\ No newline at end of file
+}
diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs
index c11c668..af5b316 100644
--- a/gemla/src/bin/test_state/mod.rs
+++ b/gemla/src/bin/test_state/mod.rs
@@ -1,8 +1,11 @@
-use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
+use async_trait::async_trait;
+use gemla::{
+    core::genetic_node::{GeneticNode, GeneticNodeContext},
+    error::Error,
+};
 use rand::prelude::*;
 use serde::{Deserialize, Serialize};
 use uuid::Uuid;
-use async_trait::async_trait;
 
 const POPULATION_SIZE: u64 = 5;
 const POPULATION_REDUCTION_SIZE: u64 = 3;
@@ -76,7 +79,12 @@ impl GeneticNode for TestState {
         Ok(())
     }
 
-    async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> {
+    async fn merge(
+        left: &TestState,
+        right: &TestState,
+        id: &Uuid,
+        gemla_context: Self::Context,
+    ) -> Result<Box<TestState>, Error> {
         let mut v = left.population.clone();
         v.append(&mut right.population.clone());
 
@@ -87,12 +95,14 @@
 
         let mut result = TestState { population: v };
 
-        result.mutate(GeneticNodeContext {
-            id: id.clone(),
-            generation: 0,
-            max_generations: 0,
-            gemla_context: gemla_context
-        }).await?;
+        result
+            .mutate(GeneticNodeContext {
+                id: *id,
+                generation: 0,
+                max_generations: 0,
+                gemla_context,
+            })
+            .await?;
 
         Ok(Box::new(result))
     }
@@ -105,14 +115,14 @@ mod tests {
 
     #[tokio::test]
     async fn test_initialize() {
-        let state = TestState::initialize(
-            GeneticNodeContext {
-                id: Uuid::new_v4(),
-                generation: 0,
-                max_generations: 0,
-                gemla_context: (),
-            }
-        ).await.unwrap();
+        let state = TestState::initialize(GeneticNodeContext {
+            id: Uuid::new_v4(),
+            generation: 0,
+            max_generations: 0,
+            gemla_context: (),
+        })
+        .await
+        .unwrap();
 
         assert_eq!(state.population.len(), POPULATION_SIZE as usize);
     }
@@ -125,35 +135,38 @@ mod tests {
 
         let original_population = state.population.clone();
 
-        state.simulate(
-            GeneticNodeContext {
+        state
+            .simulate(GeneticNodeContext {
                 id: Uuid::new_v4(),
                 generation: 0,
                 max_generations: 0,
                 gemla_context: (),
-            }
-        ).await.unwrap();
+            })
+            .await
+            .unwrap();
 
         assert!(original_population
             .iter()
             .zip(state.population.iter())
            .all(|(&a, &b)| b >= a - 1 && b <= a + 2));
 
-        state.simulate(
-            GeneticNodeContext {
+        state
+            .simulate(GeneticNodeContext {
                 id: Uuid::new_v4(),
                 generation: 0,
                 max_generations: 0,
                 gemla_context: (),
-            }
-        ).await.unwrap();
-        state.simulate(
-            GeneticNodeContext {
+            })
+            .await
+            .unwrap();
+        state
+            .simulate(GeneticNodeContext {
                 id: Uuid::new_v4(),
                 generation: 0,
                 max_generations: 0,
                 gemla_context: (),
-            }
-        ).await.unwrap();
+            })
+            .await
+            .unwrap();
 
         assert!(original_population
             .iter()
             .zip(state.population.iter())
@@ -166,14 +179,15 @@
             population: vec![4, 3, 3],
         };
 
-        state.mutate(
-            GeneticNodeContext {
+        state
+            .mutate(GeneticNodeContext {
                 id: Uuid::new_v4(),
                 generation: 0,
                 max_generations: 0,
                 gemla_context: (),
-            }
-        ).await.unwrap();
+            })
+            .await
+            .unwrap();
 
         assert_eq!(state.population.len(), POPULATION_SIZE as usize);
     }
@@ -188,7 +202,9 @@
             population: vec![0, 1, 3, 7],
         };
 
-        let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap();
+        let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
+            .await
+            .unwrap();
 
         assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
         assert!(merged_state.population.iter().any(|&x| x == 7));
diff --git a/gemla/src/core/genetic_node.rs b/gemla/src/core/genetic_node.rs
index d9d7bc0..020d2c6 100644
--- a/gemla/src/core/genetic_node.rs
+++ b/gemla/src/core/genetic_node.rs
@@ -5,10 +5,10 @@
 use crate::error::Error;
 use anyhow::Context;
+use async_trait::async_trait;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::fmt::Debug;
 use uuid::Uuid;
-use async_trait::async_trait;
 
 /// An enum used to control the state of a [`GeneticNode`]
 ///
@@ -30,14 +30,14 @@ pub struct GeneticNodeContext<S> {
     pub generation: u64,
     pub max_generations: u64,
     pub id: Uuid,
-    pub gemla_context: S
+    pub gemla_context: S,
 }
 
 /// A trait used to interact with the internal state of nodes within the [`Bracket`]
 ///
 /// [`Bracket`]: crate::bracket::Bracket
 #[async_trait]
-pub trait GeneticNode : Send {
+pub trait GeneticNode: Send {
     type Context;
 
     /// Initializes a new instance of a [`GeneticState`].
@@ -54,15 +54,20 @@ pub trait GeneticNode: Send {
     /// TODO
     async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
 
-    async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
+    async fn merge(
+        left: &Self,
+        right: &Self,
+        id: &Uuid,
+        context: Self::Context,
+    ) -> Result<Box<Self>, Error>;
 }
 
 /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
 /// well as signal recovery. Transition states are given by [`GeneticState`]
 #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
-pub struct GeneticNodeWrapper<T>
-where
-    T: Clone
+pub struct GeneticNodeWrapper<T>
+where
+    T: Clone,
 {
     node: Option<T>,
     state: GeneticState,
@@ -71,9 +76,9 @@
     id: Uuid,
 }
 
-impl<T> Default for GeneticNodeWrapper<T>
+impl<T> Default for GeneticNodeWrapper<T>
 where
-    T: Clone
+    T: Clone,
 {
     fn default() -> Self {
         GeneticNodeWrapper {
@@ -146,7 +151,8 @@ where
                 self.state = GeneticState::Simulate;
             }
             (GeneticState::Simulate, Some(n)) => {
-                n.simulate(context.clone()).await
+                n.simulate(context.clone())
+                    .await
                     .with_context(|| format!("Error simulating node: {:?}", self))?;
 
                 self.state = if self.generation >= self.max_generations {
@@ -156,7 +162,8 @@ where
                 };
             }
             (GeneticState::Mutate, Some(n)) => {
-                n.mutate(context.clone()).await
+                n.mutate(context.clone())
+                    .await
                     .with_context(|| format!("Error mutating node: {:?}", self))?;
 
                 self.generation += 1;
@@ -186,20 +193,33 @@ mod tests {
     impl GeneticNode for TestState {
         type Context = ();
 
-        async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
+        async fn simulate(
+            &mut self,
+            _context: GeneticNodeContext<Self::Context>,
+        ) -> Result<(), Error> {
             self.score += 1.0;
             Ok(())
         }
 
-        async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
+        async fn mutate(
+            &mut self,
+            _context: GeneticNodeContext<Self::Context>,
+        ) -> Result<(), Error> {
             Ok(())
         }
 
-        async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
+        async fn initialize(
+            _context: GeneticNodeContext<Self::Context>,
+        ) -> Result<Box<TestState>, Error> {
             Ok(Box::new(TestState { score: 0.0 }))
         }
 
-        async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
+        async fn merge(
+            _l: &TestState,
+            _r: &TestState,
+            _id: &Uuid,
+            _: Self::Context,
+        ) -> Result<Box<TestState>, Error> {
             Err(Error::Other(anyhow!("Unable to merge")))
         }
     }
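Before the next file, it may help to see the wrapper's lifecycle in isolation. A minimal sketch of the state progression the hunks above drive (Initialize → Simulate → Mutate → … → Finish, with `generation` advanced on mutate); the enum and `step` function here are illustrative stand-ins, not the crate's API:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum State {
    Initialize,
    Simulate,
    Mutate,
    Finish,
}

fn step(state: State, generation: &mut u64, max_generations: u64) -> State {
    match state {
        State::Initialize => State::Simulate,
        // Simulate checks whether enough generations have run.
        State::Simulate if *generation >= max_generations => State::Finish,
        State::Simulate => State::Mutate,
        State::Mutate => {
            *generation += 1; // mutating advances the generation counter
            State::Simulate
        }
        State::Finish => State::Finish,
    }
}

fn main() {
    let (mut state, mut generation) = (State::Initialize, 0);
    while state != State::Finish {
        state = step(state, &mut generation, 3);
    }
    assert_eq!(generation, 3);
}
```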
diff --git a/gemla/src/core/mod.rs b/gemla/src/core/mod.rs
index e779318..4cf9b7a 100644
--- a/gemla/src/core/mod.rs
+++ b/gemla/src/core/mod.rs
@@ -10,37 +10,38 @@
 use futures::future;
 use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
 use log::{info, trace, warn};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use tokio::{sync::RwLock, task::JoinHandle};
 use std::{
-    collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
+    collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
+    sync::Arc, time::Instant,
 };
+use tokio::{sync::RwLock, task::JoinHandle};
 use uuid::Uuid;
 
 type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
 
 /// Provides configuration options for managing a [`Gemla`] object as it executes.
-/// 
+///
 /// # Examples
 /// ```rust,ignore
 /// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
 /// struct TestState {
 ///     pub score: f64,
 /// }
-/// 
+///
 /// impl genetic_node::GeneticNode for TestState {
 ///     fn simulate(&mut self) -> Result<(), Error> {
 ///         self.score += 1.0;
 ///         Ok(())
 ///     }
-/// 
+///
 ///     fn mutate(&mut self) -> Result<(), Error> {
 ///         Ok(())
 ///     }
-/// 
+///
 ///     fn initialize() -> Result<Box<Self>, Error> {
 ///         Ok(Box::new(TestState { score: 0.0 }))
 ///     }
-/// 
+///
 ///     fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
 ///         Ok(Box::new(if left.score > right.score {
 ///             left.clone()
@@ -49,7 +50,7 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
 ///         }))
 ///     }
 /// }
-/// 
+///
 /// fn main() {
 ///
 /// }
@@ -80,13 +81,18 @@ where
     T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
     T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
 {
-    pub async fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
+    pub async fn new(
+        path: &Path,
+        config: GemlaConfig,
+        data_format: DataFormat,
+    ) -> Result<Self, Error> {
         match File::open(path) {
-            // If the file exists we either want to overwrite the file or read from the file 
+            // If the file exists we either want to overwrite the file or read from the file
             // based on the configuration provided
             Ok(_) => Ok(Gemla {
                 data: if config.overwrite {
-                    FileLinked::new((None, config, T::Context::default()), path, data_format).await?
+                    FileLinked::new((None, config, T::Context::default()), path, data_format)
+                        .await?
                 } else {
                     FileLinked::from_file(path, data_format)?
                 },
@@ -94,7 +100,8 @@
             }),
             // If the file doesn't exist we must create it
             Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
-                data: FileLinked::new((None, config, T::Context::default()), path, data_format).await?,
+                data: FileLinked::new((None, config, T::Context::default()), path, data_format)
+                    .await?,
                 threads: HashMap::new(),
             }),
             Err(error) => Err(Error::IO(error)),
@@ -106,24 +113,32 @@
     }
 
     pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
-        {
+        let tree_completed = {
             // Only increase height if the tree is uninitialized or completed
             let data_arc = self.data.readonly();
             let data_ref = data_arc.read().await;
             let tree_ref = data_ref.0.as_ref();
 
-            if tree_ref.is_none() ||
-                tree_ref
-                    .map(|t| Gemla::is_completed(t))
-                    .unwrap_or(true)
-            {
-                // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
-                // in the tree and which nodes have not.
-                self.data.mutate(|(d, c, _)| {
-                    let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
+            tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
+        };
+
+        if tree_completed {
+            // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
+            // in the tree and which nodes have not.
+            self.data
+                .mutate(|(d, c, _)| {
+                    let mut tree: Option<SimulationTree<T>> =
+                        Gemla::increase_height(d.take(), c, steps);
                     mem::swap(d, &mut tree);
-                }).await?;
-            }
+                })
+                .await?;
+        }
+
+        {
+            // Only increase height if the tree is uninitialized or completed
+            let data_arc = self.data.readonly();
+            let data_ref = data_arc.read().await;
+            let tree_ref = data_ref.0.as_ref();
 
             info!(
                 "Height of simulation tree increased to {}",
@@ -141,36 +156,36 @@ where
             let data_ref = data_arc.read().await;
             let tree_ref = data_ref.0.as_ref();
 
-            is_tree_processed = tree_ref
-                .map(|t| Gemla::is_completed(t))
-                .unwrap_or(false)
+            is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
         }
 
-        // We need to keep simulating until the tree has been completely processed.
-        if is_tree_processed
-        {
-            
+        if is_tree_processed {
             self.join_threads().await?;
 
             info!("Processed tree");
 
             break;
         }
 
-        if let Some(node) = tree_ref
-            .and_then(|t| self.get_unprocessed_node(t))
-        {
+        let (node, gemla_context) = {
+            let data_arc = self.data.readonly();
+            let data_ref = data_arc.read().await;
+            let (tree_ref, _, gemla_context) = &*data_ref; // (Option<SimulationTree<T>>, GemlaConfig, T::Context)
+
+            let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
+
+            (node, gemla_context.clone())
+        };
+
+        if let Some(node) = node {
             trace!("Adding node to process list {}", node.id());
 
-            let data_arc = self.data.readonly();
-            let data_ref2 = data_arc.read().await;
-            let gemla_context = data_ref2.2.clone();
-            drop(data_ref2);
+            let gemla_context = gemla_context.clone();
 
-            self.threads
-                .insert(node.id(), tokio::spawn(async move {
-                    Gemla::process_node(node, gemla_context).await
-                }));
+            self.threads.insert(
+                node.id(),
+                tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
+            );
         } else {
             trace!("No node found to process, joining threads");
@@ -186,7 +201,7 @@ where
         trace!("Joining threads for nodes {:?}", self.threads.keys());
 
         let results = future::join_all(self.threads.values_mut()).await;
-        
+
         // Converting a list of results into a result wrapping the list
         let reduced_results: Result<Vec<GeneticNodeWrapper<T>>, Error> =
             results.into_iter().flatten().collect();
@@ -195,32 +210,34 @@
         // We need to retrieve the processed nodes from the resulting list and replace them in the original list
         match reduced_results {
             Ok(r) => {
-                self.data.mutate_async(|d| async move {
-                    // Scope to limit the duration of the read lock
-                    let (_, context) = {
-                        let data_read = d.read().await;
-                        (data_read.1.clone(), data_read.2.clone())
-                    }; // Read lock is dropped here
+                self.data
+                    .mutate_async(|d| async move {
+                        // Scope to limit the duration of the read lock
+                        let (_, context) = {
+                            let data_read = d.read().await;
+                            (data_read.1, data_read.2.clone())
+                        }; // Read lock is dropped here
 
-                    let mut data_write = d.write().await;
-                    
-                    if let Some(t) = data_write.0.as_mut() {
-                        let failed_nodes = Gemla::replace_nodes(t, r);
-                        // We receive a list of nodes that were unable to be found in the original tree
-                        if !failed_nodes.is_empty() {
-                            warn!(
-                                "Unable to find {:?} to replace in tree",
-                                failed_nodes.iter().map(|n| n.id())
-                            )
+                        let mut data_write = d.write().await;
+
+                        if let Some(t) = data_write.0.as_mut() {
+                            let failed_nodes = Gemla::replace_nodes(t, r);
+                            // We receive a list of nodes that were unable to be found in the original tree
+                            if !failed_nodes.is_empty() {
+                                warn!(
+                                    "Unable to find {:?} to replace in tree",
+                                    failed_nodes.iter().map(|n| n.id())
+                                )
+                            }
+
+                            // Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
+                            Gemla::merge_completed_nodes(t, context.clone()).await
+                        } else {
+                            warn!("Unable to replace nodes {:?} in empty tree", r);
+                            Ok(())
                         }
-                        
-                        // Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
-                        Gemla::merge_completed_nodes(t, context.clone()).await
-                    } else {
-                        warn!("Unable to replace nodes {:?} in empty tree", r);
-                        Ok(())
-                    }
-                }).await??;
+                    })
+                    .await??;
             }
             Err(e) => return Err(e),
         }
@@ -230,7 +247,10 @@ where
     }
 
     #[async_recursion]
-    async fn merge_completed_nodes<'a>(tree: &'a mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> {
+    async fn merge_completed_nodes<'a>(
+        tree: &'a mut SimulationTree<T>,
+        gemla_context: T::Context,
+    ) -> Result<(), Error> {
         if tree.val.state() == GeneticState::Initialize {
             match (&mut tree.left, &mut tree.right) {
                 // If the current node has been initialized, and has children nodes that are completed, then we need
@@ -241,7 +261,13 @@
                 {
                     info!("Merging nodes {} and {}", l.val.id(), r.val.id());
                     if let (Some(left_node), Some(right_node)) = (l.val.take(), r.val.take()) {
-                        let merged_node = GeneticNode::merge(&left_node, &right_node, &tree.val.id(), gemla_context.clone()).await?;
+                        let merged_node = GeneticNode::merge(
+                            &left_node,
+                            &right_node,
+                            &tree.val.id(),
+                            gemla_context.clone(),
+                        )
+                        .await?;
                         tree.val = GeneticNodeWrapper::from(
                             *merged_node,
                             tree.val.max_generations(),
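The hunk after this note reworks `get_unprocessed_node`; the bottom-up search it performs is easier to see on a stripped-down tree. A sketch under assumed toy types (the real code also skips nodes already present in the thread map and returns a clone of the wrapper rather than a reference):

```rust
struct Node {
    finished: bool,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

fn find_unprocessed(node: &Node) -> Option<&Node> {
    if node.finished {
        return None; // nothing left to do in this subtree
    }
    match (&node.left, &node.right) {
        // Both children done: this node is ready to process.
        (Some(l), Some(r)) if l.finished && r.finished => Some(node),
        // Otherwise keep searching below, left branch first.
        (Some(l), Some(r)) => find_unprocessed(l).or_else(|| find_unprocessed(r)),
        (Some(l), None) => find_unprocessed(l),
        (None, Some(r)) => find_unprocessed(r),
        (None, None) => Some(node), // a leaf is processed directly
    }
}

fn main() {
    let leaf = |finished| Box::new(Node { finished, left: None, right: None });
    let root = Node { finished: false, left: Some(leaf(true)), right: Some(leaf(false)) };
    // The unfinished right leaf is found before the root becomes eligible.
    assert!(find_unprocessed(&root).is_some());
}
```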
@@ -286,15 +312,18 @@
     }
 
     fn get_unprocessed_node(&self, tree: &SimulationTree<T>) -> Option<GeneticNodeWrapper<T>> {
-        // If the current node has been processed or exists in the thread list then we want to stop recursing. Checking if it exists in the thread list 
+        // If the current node has been processed or exists in the thread list then we want to stop recursing. Checking if it exists in the thread list
         // should be fine because we process the tree from bottom to top.
         if tree.val.state() != GeneticState::Finish && !self.threads.contains_key(&tree.val.id()) {
             match (&tree.left, &tree.right) {
-                // If the children are finished we can start processing the current node. The current node should be merged from the children already
+                // If the children are finished we can start processing the current node. The current node should be merged from the children already
                 // during join_threads.
                 (Some(l), Some(r))
                     if l.val.state() == GeneticState::Finish
-                        && r.val.state() == GeneticState::Finish => Some(tree.val.clone()),
+                        && r.val.state() == GeneticState::Finish =>
+                {
+                    Some(tree.val.clone())
+                }
                 (Some(l), Some(r)) => self
                     .get_unprocessed_node(l)
                     .or_else(|| self.get_unprocessed_node(r)),
@@ -334,7 +363,7 @@ where
         } else {
             let left_branch_height =
                 tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1;
-            
+
             Some(Box::new(Tree::new(
                 GeneticNodeWrapper::new(config.generations_per_height),
                 Gemla::increase_height(tree, config, amount - 1),
@@ -352,10 +381,13 @@ where
     fn is_completed(tree: &SimulationTree<T>) -> bool {
         // If the current node is finished, then by convention the children should all be finished as well
-        tree.val.state() == GeneticState::Finish 
+        tree.val.state() == GeneticState::Finish
     }
 
-    async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> {
+    async fn process_node(
+        mut node: GeneticNodeWrapper<T>,
+        gemla_context: T::Context,
+    ) -> Result<GeneticNodeWrapper<T>, Error> {
         let node_state_time = Instant::now();
         let node_state = node.state();
 
@@ -379,10 +411,10 @@ where
 #[cfg(test)]
 mod tests {
     use crate::core::*;
-    use serde::{Deserialize, Serialize};
-    use std::path::PathBuf;
-    use std::fs;
     use async_trait::async_trait;
+    use serde::{Deserialize, Serialize};
+    use std::fs;
+    use std::path::PathBuf;
     use tokio::runtime::Runtime;
 
     use self::genetic_node::GeneticNodeContext;
@@ -420,20 +452,33 @@ mod tests {
     impl genetic_node::GeneticNode for TestState {
         type Context = ();
 
-        async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
+        async fn simulate(
+            &mut self,
+            _context: GeneticNodeContext<Self::Context>,
+        ) -> Result<(), Error> {
            self.score += 1.0;
            Ok(())
         }
 
-        async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
+        async fn mutate(
+            &mut self,
+            _context: GeneticNodeContext<Self::Context>,
+        ) -> Result<(), Error> {
             Ok(())
         }
 
-        async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
+        async fn initialize(
+            _context: GeneticNodeContext<Self::Context>,
+        ) -> Result<Box<TestState>, Error> {
             Ok(Box::new(TestState { score: 0.0 }))
         }
 
-        async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
+        async fn merge(
+            left: &TestState,
+            right: &TestState,
+            _id: &Uuid,
+            _: Self::Context,
+        ) -> Result<Box<TestState>, Error> {
             Ok(Box::new(if left.score > right.score {
                 left.clone()
             } else {
@@ -464,7 +509,7 @@ mod tests {
         let data = gemla.data.readonly();
         let data_lock = data.read().await;
         assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
-        
+
         drop(data_lock);
         drop(gemla);
         assert!(path.exists());
@@ -498,40 +543,43 @@ mod tests {
                     Ok(())
                 })
             })
-        }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
+        })
+        .await
+        .unwrap()?; // Wait for the blocking task to complete, then handle the Result.
 
         Ok(())
     }
 
-    // #[tokio::test]
-    // async fn test_simulate() -> Result<(), Error> {
-    //     let path = PathBuf::from("test_simulate");
-    //     // Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
-    //     tokio::task::spawn_blocking(move || {
-    //         let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
-    //         CleanUp::new(&path).run(move |p| {
-    //             rt.block_on(async {
-    //                 // Testing initial creation
-    //                 let config = GemlaConfig {
-    //                     generations_per_height: 10,
-    //                     overwrite: true,
-    //                 };
-    //                 let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
+    #[tokio::test]
+    async fn test_simulate() -> Result<(), Error> {
+        let path = PathBuf::from("test_simulate");
+        // Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
+        tokio::task::spawn_blocking(move || {
+            let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
+            CleanUp::new(&path).run(move |p| {
+                rt.block_on(async {
+                    // Testing initial creation
+                    let config = GemlaConfig {
+                        generations_per_height: 10,
+                        overwrite: true,
+                    };
+                    let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
 
-    //                 // Now we can use `.await` within the spawned blocking task.
-    //                 gemla.simulate(5).await?;
-    //                 let data = gemla.data.readonly();
-    //                 let data_lock = data.read().unwrap();
-    //                 let tree = data_lock.0.as_ref().unwrap();
-    //                 assert_eq!(tree.height(), 5);
-    //                 assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
+                    // Now we can use `.await` within the spawned blocking task.
+                    gemla.simulate(5).await?;
+                    let data = gemla.data.readonly();
+                    let data_lock = data.read().await;
+                    let tree = data_lock.0.as_ref().unwrap();
+                    assert_eq!(tree.height(), 5);
+                    assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
 
-    //                 Ok(())
-    //             })
-    //         })
-    //     }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
-
-    //     Ok(())
-    // }
+                    Ok(())
+                })
+            })
+        })
+        .await
+        .unwrap()?; // Wait for the blocking task to complete, then handle the Result.
+        Ok(())
+    }
 }