From 95699bd47eabd0f7d1b7d88fa1fa8b4cb498b03a Mon Sep 17 00:00:00 2001 From: vandomej Date: Thu, 4 Apr 2024 18:40:17 -0700 Subject: [PATCH] Adding user defined, shared context --- gemla/Cargo.toml | 1 + gemla/src/bin/bin.rs | 1 - gemla/src/bin/fighter_nn/fighter_context.rs | 48 ++++++ gemla/src/bin/fighter_nn/mod.rs | 156 +++++++++++--------- gemla/src/bin/test_state/mod.rs | 42 +++--- gemla/src/core/genetic_node.rs | 51 ++++--- gemla/src/core/mod.rs | 61 ++++---- 7 files changed, 212 insertions(+), 148 deletions(-) create mode 100644 gemla/src/bin/fighter_nn/fighter_context.rs diff --git a/gemla/Cargo.toml b/gemla/Cargo.toml index ee4951d..161d544 100644 --- a/gemla/Cargo.toml +++ b/gemla/Cargo.toml @@ -31,3 +31,4 @@ num_cpus = "1.16.0" easy-parallel = "3.3.1" fann = "0.1.8" async-trait = "0.1.78" +async-recursion = "1.1.0" diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs index 16539fb..464f1aa 100644 --- a/gemla/src/bin/bin.rs +++ b/gemla/src/bin/bin.rs @@ -47,7 +47,6 @@ fn main() -> Result<()> { GemlaConfig { generations_per_height: 5, overwrite: false, - shared_semaphore_concurrency_limit: 50, }, DataFormat::Json, ))?; diff --git a/gemla/src/bin/fighter_nn/fighter_context.rs b/gemla/src/bin/fighter_nn/fighter_context.rs new file mode 100644 index 0000000..503bbf8 --- /dev/null +++ b/gemla/src/bin/fighter_nn/fighter_context.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use tokio::sync::Semaphore; + +const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 20; + + +#[derive(Debug, Clone)] +pub struct FighterContext { + pub shared_semaphore: Arc, +} + +impl Default for FighterContext { + fn default() -> Self { + FighterContext { + shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)), + } + } +} + + +// Custom serialization to just output the concurrency limit. +impl Serialize for FighterContext { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + // Assuming the semaphore's available permits represent the concurrency limit. + // This part is tricky since Semaphore does not expose its initial permits. + // You might need to store the concurrency limit as a separate field if this assumption doesn't hold. + let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT; + serializer.serialize_u64(concurrency_limit as u64) + } +} + +// Custom deserialization to reconstruct the FighterContext from a concurrency limit. +impl<'de> Deserialize<'de> for FighterContext { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let concurrency_limit = u64::deserialize(deserializer)?; + Ok(FighterContext { + shared_semaphore: Arc::new(Semaphore::new(concurrency_limit as usize)), + }) + } +} \ No newline at end of file diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs index 4161079..b7b2c43 100644 --- a/gemla/src/bin/fighter_nn/mod.rs +++ b/gemla/src/bin/fighter_nn/mod.rs @@ -1,6 +1,7 @@ extern crate fann; pub mod neural_network_utility; +pub mod fighter_context; use std::{fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, path::{Path, PathBuf}, sync::Arc}; use fann::{ActivationFunc, Fann}; @@ -63,8 +64,10 @@ pub struct FighterNN { #[async_trait] impl GeneticNode for FighterNN { + type Context = fighter_context::FighterContext; + // Check for the highest number of the folder name and increment it by 1 - fn initialize(context: GeneticNodeContext) -> Result, Error> { + async fn initialize(context: GeneticNodeContext) -> Result, Error> { let base_path = PathBuf::from(BASE_DIR); let folder = base_path.join(format!("fighter_nn_{:06}", context.id)); @@ -127,100 +130,37 @@ impl GeneticNode for FighterNN { })) } - async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> { + async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> { debug!("Context: {:?}", context); let mut tasks = Vec::new(); // For each nn in the current generation: for i in 0..self.population_size { let self_clone = self.clone(); - let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap()); + let semaphore_clone = context.gemla_context.shared_semaphore.clone(); let task = async move { - let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i)); + let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64)); let mut simulations = Vec::new(); // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently for _ in 0..SIMULATION_ROUNDS { let random_nn_index = thread_rng().gen_range(0..self_clone.population_size); - let id = self_clone.id.clone(); let folder = self_clone.folder.clone(); let generation = self_clone.generation; let semaphore_clone = Arc::clone(&semaphore_clone); - let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index)); + let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64)); let nn_clone = nn.clone(); // Clone the path to use in the async block - let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap()); - let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap()); - let disable_unreal_rendering_arg = "-nullrhi".to_string(); - - - let future = async move { let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?; - // Construct the score file path - let nn_id = format!("{:06}_fighter_nn_{}", id, i); - let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index); - let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id); - let score_file = folder.join(format!("{}", generation)).join(&score_file_name); - - // Check if score file already exists before running the simulation - if score_file.exists() { - let round_score = read_score_from_file(&score_file, &nn_id).await - .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?; - - debug!("{} scored {}", nn_id, round_score); - - return Ok::(round_score); - } - - // Check if the opposite round score has been determined - let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id)); - if opposite_score_file.exists() { - let round_score = read_score_from_file(&opposite_score_file, &nn_id).await - .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; - - debug!("{} scored {}", nn_id, round_score); - - return Ok::(1.0 - round_score); - } - - // Run simulation until score file is generated - while !score_file.exists() { - let _output = if thread_rng().gen_range(0..100) < 1 { - Command::new(GAME_EXECUTABLE_PATH) - .arg(&config1_arg) - .arg(&config2_arg) - .output() - .await - .expect("Failed to execute game") - } else { - Command::new(GAME_EXECUTABLE_PATH) - .arg(&config1_arg) - .arg(&config2_arg) - .arg(&disable_unreal_rendering_arg) - .output() - .await - .expect("Failed to execute game") - }; - } + let score = run_1v1_simulation(&nn_clone, &random_nn).await?; drop(permit); - // Read the score from the file - if score_file.exists() { - let round_score = read_score_from_file(&score_file, &nn_id).await - .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?; - - debug!("{} scored {}", nn_id, round_score); - - Ok(round_score) - } else { - warn!("Score file not found: {:?}", score_file_name); - Ok(0.0) - } + Ok(score) }; simulations.push(future); @@ -259,7 +199,7 @@ impl GeneticNode for FighterNN { } - fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize; // Create the new generation folder @@ -322,7 +262,7 @@ impl GeneticNode for FighterNN { Ok(()) } - fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result, Error> { + async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, _: Self::Context) -> Result, Error> { let base_path = PathBuf::from(BASE_DIR); let folder = base_path.join(format!("fighter_nn_{:06}", id)); @@ -365,6 +305,78 @@ impl GeneticNode for FighterNN { } } +impl FighterNN { + pub fn get_individual_id(&self, nn_id: u64) -> String { + format!("{:06}_fighter_nn_{}", self.id, nn_id) + } +} + +async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result { + // Construct the score file path + let base_folder = nn_path_1.parent().unwrap(); + let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap(); + let nn_2_id = nn_path_2.file_stem().unwrap().to_str().unwrap(); + let score_file = base_folder.join(format!("{}_vs_{}.txt", nn_1_id, nn_2_id)); + + // Check if score file already exists before running the simulation + if score_file.exists() { + let round_score = read_score_from_file(&score_file, &nn_1_id).await + .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; + + trace!("{} scored {}", nn_1_id, round_score); + + return Ok::(round_score); + } + + // Check if the opposite round score has been determined + let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id)); + if opposite_score_file.exists() { + let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await + .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; + + trace!("{} scored {}", nn_1_id, round_score); + + return Ok::(1.0 - round_score); + } + + // Run simulation until score file is generated + let config1_arg = format!("-NN1Config=\"{}\"", nn_path_1.to_str().unwrap()); + let config2_arg = format!("-NN2Config=\"{}\"", nn_path_2.to_str().unwrap()); + let disable_unreal_rendering_arg = "-nullrhi".to_string(); + + while !score_file.exists() { + let _output = if thread_rng().gen_range(0..100) < 1 { + Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .output() + .await + .expect("Failed to execute game") + } else { + Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .arg(&disable_unreal_rendering_arg) + .output() + .await + .expect("Failed to execute game") + }; + } + + // Read the score from the file + if score_file.exists() { + let round_score = read_score_from_file(&score_file, &nn_1_id).await + .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; + + trace!("{} scored {}", nn_1_id, round_score); + + Ok(round_score) + } else { + warn!("Score file not found: {:?}", score_file); + Ok(0.0) + } +} + async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result { let mut attempts = 0; diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index 73df36e..c11c668 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -14,7 +14,9 @@ pub struct TestState { #[async_trait] impl GeneticNode for TestState { - fn initialize(_context: GeneticNodeContext) -> Result, Error> { + type Context = (); + + async fn initialize(_context: GeneticNodeContext) -> Result, Error> { let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { @@ -24,7 +26,7 @@ impl GeneticNode for TestState { Ok(Box::new(TestState { population })) } - async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let mut rng = thread_rng(); self.population = self @@ -36,7 +38,7 @@ impl GeneticNode for TestState { Ok(()) } - fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let mut rng = thread_rng(); let mut v = self.population.clone(); @@ -74,7 +76,7 @@ impl GeneticNode for TestState { Ok(()) } - fn merge(left: &TestState, right: &TestState, id: &Uuid) -> Result, Error> { + async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result, Error> { let mut v = left.population.clone(); v.append(&mut right.population.clone()); @@ -89,8 +91,8 @@ impl GeneticNode for TestState { id: id.clone(), generation: 0, max_generations: 0, - semaphore: None, - })?; + gemla_context: gemla_context + }).await?; Ok(Box::new(result)) } @@ -101,16 +103,16 @@ mod tests { use super::*; use gemla::core::genetic_node::GeneticNode; - #[test] - fn test_initialize() { + #[tokio::test] + async fn test_initialize() { let state = TestState::initialize( GeneticNodeContext { id: Uuid::new_v4(), generation: 0, max_generations: 0, - semaphore: None, + gemla_context: (), } - ).unwrap(); + ).await.unwrap(); assert_eq!(state.population.len(), POPULATION_SIZE as usize); } @@ -128,7 +130,7 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, - semaphore: None, + gemla_context: (), } ).await.unwrap(); assert!(original_population @@ -141,7 +143,7 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, - semaphore: None, + gemla_context: (), } ).await.unwrap(); state.simulate( @@ -149,7 +151,7 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, - semaphore: None, + gemla_context: (), } ).await.unwrap(); assert!(original_population @@ -158,8 +160,8 @@ mod tests { .all(|(&a, &b)| b >= a - 3 && b <= a + 6)) } - #[test] - fn test_mutate() { + #[tokio::test] + async fn test_mutate() { let mut state = TestState { population: vec![4, 3, 3], }; @@ -169,15 +171,15 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, - semaphore: None, + gemla_context: (), } - ).unwrap(); + ).await.unwrap(); assert_eq!(state.population.len(), POPULATION_SIZE as usize); } - #[test] - fn test_merge() { + #[tokio::test] + async fn test_merge() { let state1 = TestState { population: vec![1, 2, 4, 5], }; @@ -186,7 +188,7 @@ mod tests { population: vec![0, 1, 3, 7], }; - let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4()).unwrap(); + let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap(); assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); assert!(merged_state.population.iter().any(|&x| x == 7)); diff --git a/gemla/src/core/genetic_node.rs b/gemla/src/core/genetic_node.rs index 0838dc8..eec30bc 100644 --- a/gemla/src/core/genetic_node.rs +++ b/gemla/src/core/genetic_node.rs @@ -5,9 +5,8 @@ use crate::error::Error; use anyhow::Context; -use serde::{Deserialize, Serialize}; -use tokio::sync::Semaphore; -use std::{fmt::Debug, sync::Arc}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::fmt::Debug; use uuid::Uuid; use async_trait::async_trait; @@ -27,33 +26,35 @@ pub enum GeneticState { } #[derive(Clone, Debug)] -pub struct GeneticNodeContext { +pub struct GeneticNodeContext { pub generation: u64, pub max_generations: u64, pub id: Uuid, - pub semaphore: Option>, + pub gemla_context: S } /// A trait used to interact with the internal state of nodes within the [`Bracket`] /// /// [`Bracket`]: crate::bracket::Bracket #[async_trait] -pub trait GeneticNode: Send { +pub trait GeneticNode : Send { + type Context; + /// Initializes a new instance of a [`GeneticState`]. /// /// # Examples /// TODO - fn initialize(context: GeneticNodeContext) -> Result, Error>; + async fn initialize(context: GeneticNodeContext) -> Result, Error>; - async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; + async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; /// Mutates members in a population and/or crossbreeds them to produce new offspring. /// /// # Examples /// TODO - fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; + async fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; - fn merge(left: &Self, right: &Self, id: &Uuid) -> Result, Error>; + async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result, Error>; } /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as @@ -82,6 +83,7 @@ impl Default for GeneticNodeWrapper { impl GeneticNodeWrapper where T: GeneticNode + Debug + Send, + T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, { pub fn new(max_generations: u64) -> Self { GeneticNodeWrapper:: { @@ -120,17 +122,17 @@ where self.state } - pub async fn process_node(&mut self, semaphore: Arc) -> Result { + pub async fn process_node(&mut self, gemla_context: T::Context) -> Result { let context = GeneticNodeContext { generation: self.generation, max_generations: self.max_generations, id: self.id, - semaphore: Some(semaphore), + gemla_context, }; match (self.state, &mut self.node) { (GeneticState::Initialize, _) => { - self.node = Some(*T::initialize(context.clone())?); + self.node = Some(*T::initialize(context.clone()).await?); self.state = GeneticState::Simulate; } (GeneticState::Simulate, Some(n)) => { @@ -144,7 +146,7 @@ where }; } (GeneticState::Mutate, Some(n)) => { - n.mutate(context.clone()) + n.mutate(context.clone()).await .with_context(|| format!("Error mutating node: {:?}", self))?; self.generation += 1; @@ -172,20 +174,22 @@ mod tests { #[async_trait] impl GeneticNode for TestState { - async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + type Context = (); + + async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { self.score += 1.0; Ok(()) } - fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { Ok(()) } - fn initialize(_context: GeneticNodeContext) -> Result, Error> { + async fn initialize(_context: GeneticNodeContext) -> Result, Error> { Ok(Box::new(TestState { score: 0.0 })) } - fn merge(_l: &TestState, _r: &TestState, _id: &Uuid) -> Result, Error> { + async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result, Error> { Err(Error::Other(anyhow!("Unable to merge"))) } } @@ -281,14 +285,13 @@ mod tests { #[tokio::test] async fn test_process_node() -> Result<(), Error> { let mut genetic_node = GeneticNodeWrapper::::new(2); - let semaphore = Arc::new(Semaphore::new(1)); assert_eq!(genetic_node.state(), GeneticState::Initialize); - assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate); - assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish); - assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish); Ok(()) } diff --git a/gemla/src/core/mod.rs b/gemla/src/core/mod.rs index b3e7faa..9d831a3 100644 --- a/gemla/src/core/mod.rs +++ b/gemla/src/core/mod.rs @@ -4,15 +4,15 @@ pub mod genetic_node; use crate::{error::Error, tree::Tree}; +use async_recursion::async_recursion; use file_linked::{constants::data_format::DataFormat, FileLinked}; -use futures::future; +use futures::{executor::block_on, future}; use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState}; use log::{info, trace, warn}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::task::JoinHandle; -use tokio::sync::Semaphore; use std::{ - collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant + collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, time::Instant }; use uuid::Uuid; @@ -58,7 +58,6 @@ type SimulationTree = Box>>; pub struct GemlaConfig { pub generations_per_height: u64, pub overwrite: bool, - pub shared_semaphore_concurrency_limit: usize, } /// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. @@ -69,16 +68,17 @@ pub struct GemlaConfig { /// [`GeneticNode`]: genetic_node::GeneticNode pub struct Gemla where - T: Serialize + Clone, + T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send, + T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, { - pub data: FileLinked<(Option>, GemlaConfig)>, + pub data: FileLinked<(Option>, GemlaConfig, T::Context)>, threads: HashMap, Error>>>, - semaphore: Arc, } impl Gemla where T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send, + T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, { pub fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result { match File::open(path) { @@ -86,18 +86,16 @@ where // based on the configuration provided Ok(_) => Ok(Gemla { data: if config.overwrite { - FileLinked::new((None, config), path, data_format)? + FileLinked::new((None, config, T::Context::default()), path, data_format)? } else { FileLinked::from_file(path, data_format)? }, threads: HashMap::new(), - semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)), }), // If the file doesn't exist we must create it Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { - data: FileLinked::new((None, config), path, data_format)?, + data: FileLinked::new((None, config, T::Context::default()), path, data_format)?, threads: HashMap::new(), - semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)), }), Err(error) => Err(Error::IO(error)), } @@ -117,7 +115,7 @@ where { // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed // in the tree and which nodes have not. - self.data.mutate(|(d, c)| { + self.data.mutate(|(d, c, _)| { let mut tree: Option> = Gemla::increase_height(d.take(), c, steps); mem::swap(d, &mut tree); })?; @@ -151,11 +149,11 @@ where { trace!("Adding node to process list {}", node.id()); - let semaphore = self.semaphore.clone(); + let gemla_context = self.data.readonly().2.clone(); self.threads .insert(node.id(), tokio::spawn(async move { - Gemla::process_node(node, semaphore).await + Gemla::process_node(node, gemla_context).await })); } else { trace!("No node found to process, joining threads"); @@ -180,7 +178,7 @@ where // We need to retrieve the processed nodes from the resulting list and replace them in the original list reduced_results.and_then(|r| { - self.data.mutate(|(d, _)| { + self.data.mutate(|(d, _, context)| { if let Some(t) = d { let failed_nodes = Gemla::replace_nodes(t, r); // We receive a list of nodes that were unable to be found in the original tree @@ -192,7 +190,7 @@ where } // Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes - Gemla::merge_completed_nodes(t) + block_on(Gemla::merge_completed_nodes(t, context.clone())) } else { warn!("Unable to replce nodes {:?} in empty tree", r); Ok(()) @@ -204,7 +202,8 @@ where Ok(()) } - fn merge_completed_nodes(tree: &mut SimulationTree) -> Result<(), Error> { + #[async_recursion] + async fn merge_completed_nodes(tree: &mut SimulationTree, gemla_context: T::Context) -> Result<(), Error> { if tree.val.state() == GeneticState::Initialize { match (&mut tree.left, &mut tree.right) { // If the current node has been initialized, and has children nodes that are completed, then we need @@ -215,7 +214,7 @@ where { info!("Merging nodes {} and {}", l.val.id(), r.val.id()); if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) { - let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id())?; + let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id(), gemla_context.clone()).await?; tree.val = GeneticNodeWrapper::from( *merged_node, tree.val.max_generations(), @@ -224,8 +223,8 @@ where } } (Some(l), Some(r)) => { - Gemla::merge_completed_nodes(l)?; - Gemla::merge_completed_nodes(r)?; + Gemla::merge_completed_nodes(l, gemla_context.clone()).await?; + Gemla::merge_completed_nodes(r, gemla_context.clone()).await?; } // If there is only one child node that's completed then we want to copy it to the parent node (Some(l), None) if l.val.state() == GeneticState::Finish => { @@ -239,7 +238,7 @@ where ); } } - (Some(l), None) => Gemla::merge_completed_nodes(l)?, + (Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?, (None, Some(r)) if r.val.state() == GeneticState::Finish => { trace!("Copying node {}", r.val.id()); @@ -251,7 +250,7 @@ where ); } } - (None, Some(r)) => Gemla::merge_completed_nodes(r)?, + (None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?, (_, _) => (), } } @@ -329,15 +328,15 @@ where tree.val.state() == GeneticState::Finish } - async fn process_node(mut node: GeneticNodeWrapper, semaphore: Arc) -> Result, Error> { + async fn process_node(mut node: GeneticNodeWrapper, gemla_context: T::Context) -> Result, Error> { let node_state_time = Instant::now(); let node_state = node.state(); - node.process_node(semaphore.clone()).await?; + node.process_node(gemla_context.clone()).await?; if node.state() == GeneticState::Simulate { - node.process_node(semaphore.clone()).await?; + node.process_node(gemla_context.clone()).await?; } trace!( @@ -397,20 +396,22 @@ mod tests { #[async_trait] impl genetic_node::GeneticNode for TestState { - async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + type Context = (); + + async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { self.score += 1.0; Ok(()) } - fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { Ok(()) } - fn initialize(_context: GeneticNodeContext) -> Result, Error> { + async fn initialize(_context: GeneticNodeContext) -> Result, Error> { Ok(Box::new(TestState { score: 0.0 })) } - fn merge(left: &TestState, right: &TestState, _id: &Uuid) -> Result, Error> { + async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result, Error> { Ok(Box::new(if left.score > right.score { left.clone() } else { @@ -433,7 +434,6 @@ mod tests { let mut config = GemlaConfig { generations_per_height: 1, overwrite: true, - shared_semaphore_concurrency_limit: 1, }; let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; @@ -483,7 +483,6 @@ mod tests { let config = GemlaConfig { generations_per_height: 10, overwrite: true, - shared_semaphore_concurrency_limit: 1, }; let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?;