From ac71b28c7c38348f767390973f134b38477d4986 Mon Sep 17 00:00:00 2001 From: vandomej Date: Thu, 21 Mar 2024 12:55:42 -0700 Subject: [PATCH] Adding global semaphore to better control resources. --- gemla/src/bin/bin.rs | 3 +- gemla/src/bin/fighter_nn/mod.rs | 178 +++++++++++++++++++------------- gemla/src/bin/test_state/mod.rs | 6 ++ gemla/src/core/genetic_node.rs | 20 ++-- gemla/src/core/mod.rs | 20 ++-- 5 files changed, 140 insertions(+), 87 deletions(-) diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs index afb2ba7..200bb2e 100644 --- a/gemla/src/bin/bin.rs +++ b/gemla/src/bin/bin.rs @@ -16,7 +16,7 @@ use fighter_nn::FighterNN; use clap::Parser; use anyhow::Result; -// const NUM_THREADS: usize = 12; +// const NUM_THREADS: usize = 2; #[derive(Parser)] #[command(version, about, long_about = None)] @@ -47,6 +47,7 @@ fn main() -> Result<()> { GemlaConfig { generations_per_height: 10, overwrite: false, + shared_semaphore_concurrency_limit: 30, }, DataFormat::Json, ))?; diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs index bb641d8..2de6b96 100644 --- a/gemla/src/bin/fighter_nn/mod.rs +++ b/gemla/src/bin/fighter_nn/mod.rs @@ -1,6 +1,6 @@ extern crate fann; -use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}}; +use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}, sync::Arc}; use fann::{ActivationFunc, Fann}; use futures::future::join_all; use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}; @@ -8,9 +8,9 @@ use rand::prelude::*; use rand::distributions::{Distribution, Uniform}; use serde::{Deserialize, Serialize}; use anyhow::Context; +use tokio::process::Command; use uuid::Uuid; use std::collections::HashMap; -use tokio::process::Command; use async_trait::async_trait; const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations"; @@ -79,82 +79,116 @@ impl GeneticNode for FighterNN { })) } - async fn simulate(&mut self, 
_context: GeneticNodeContext) -> Result<(), Error> { + async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> { + trace!("Context: {:?}", context); + let mut tasks = Vec::new(); + // For each nn in the current generation: for i in 0..self.population_size { - // load the nn - let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i)); - let mut simulations = Vec::new(); - - // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently - for _ in 0..SIMULATION_ROUNDS { - let random_nn_index = thread_rng().gen_range(0..self.population_size); - let id = self.id.clone(); - let folder = self.folder.clone(); - let generation = self.generation; + let self_clone = self.clone(); + let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap()); - let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index)); - let nn_clone = nn.clone(); // Clone the path to use in the async block - - let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap()); - let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap()); - let disable_unreal_rendering_arg = "-nullrhi".to_string(); - - let future = async move { - // Construct the score file path - let nn_id = format!("{:06}_fighter_nn_{}", id, i); - let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index); - let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id); - let score_file = folder.join(format!("{}", generation)).join(&score_file_name); + let task = async move { + let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i)); + let mut simulations = Vec::new(); + + // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently + for _ in 0..SIMULATION_ROUNDS { + let 
random_nn_index = thread_rng().gen_range(0..self_clone.population_size); + let id = self_clone.id.clone(); + let folder = self_clone.folder.clone(); + let generation = self_clone.generation; + let semaphore_clone = Arc::clone(&semaphore_clone); - // Check if score file already exists before running the simulation - if score_file.exists() { + let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index)); + let nn_clone = nn.clone(); // Clone the path to use in the async block + + let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap()); + let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap()); + let disable_unreal_rendering_arg = "-nullrhi".to_string(); + + + + let future = async move { + let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?; + + // Construct the score file path + let nn_id = format!("{:06}_fighter_nn_{}", id, i); + let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index); + let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id); + let score_file = folder.join(format!("{}", generation)).join(&score_file_name); + + // Check if score file already exists before running the simulation + if score_file.exists() { + let round_score = read_score_from_file(&score_file, &nn_id) + .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?; + return Ok::(round_score); + } + + // Check if the opposite round score has been determined + let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id)); + if opposite_score_file.exists() { + let round_score = read_score_from_file(&opposite_score_file, &nn_id) + .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; + return Ok::(1.0 - round_score); + } + + let _output = if thread_rng().gen_range(0..100) < 0 { + 
Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .output() + .await + .expect("Failed to execute game") + } else { + Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .arg(&disable_unreal_rendering_arg) + .output() + .await + .expect("Failed to execute game") + }; + + drop(permit); + + // Read the score from the file let round_score = read_score_from_file(&score_file, &nn_id) .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?; - return Ok::(round_score); - } - // Check if the opposite round score has been determined - let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id)); - if opposite_score_file.exists() { - let round_score = read_score_from_file(&opposite_score_file, &nn_id) - .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; - return Ok::(1.0 - round_score); - } - - if thread_rng().gen_range(0..100) < 4 { - let _output = Command::new(GAME_EXECUTABLE_PATH) - .arg(&config1_arg) - .arg(&config2_arg) - .output() - .await - .expect("Failed to execute game"); - } else { - let _output = Command::new(GAME_EXECUTABLE_PATH) - .arg(&config1_arg) - .arg(&config2_arg) - .arg(&disable_unreal_rendering_arg) - .output() - .await - .expect("Failed to execute game"); - } - - // Read the score from the file - let round_score = read_score_from_file(&score_file, &nn_id) - .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?; - - Ok::(round_score) + Ok::(round_score) + }; + + simulations.push(future); + } + + // Wait for all simulation rounds to complete + let results: Result, Error> = join_all(simulations).await.into_iter().collect(); + + let score = match results { + Ok(scores) => scores.into_iter().sum::() / SIMULATION_ROUNDS as f32, + Err(e) => return Err(e), // Return the error if results collection failed }; - - simulations.push(future); + trace!("NN 
{:06}_fighter_nn_{} scored {}", self_clone.id, i, score); + Ok((i, score)) + }; + + tasks.push(task); + } + + let results = join_all(tasks).await; + + for result in results { + match result { + Ok((index, score)) => { + // Update the original `self` object with the score. + self.scores[self.generation as usize].insert(index as u64, score); + }, + Err(e) => { + // Handle task panic or execution error + return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e)))); + }, } - - // Wait for all simulation rounds to complete - let results: Result, Error> = join_all(simulations).await.into_iter().collect(); - - let score = results?.into_iter().sum::() / SIMULATION_ROUNDS as f32; - trace!("NN {:06}_fighter_nn_{} scored {}", self.id, i, score); - self.scores[self.generation as usize].insert(i as u64, score); } Ok(()) @@ -228,9 +262,9 @@ impl GeneticNode for FighterNN { if thread_rng().gen_range(0..100) < 20 { c.weight += thread_rng().gen_range(-0.1..0.1); } - // else if thread_rng().gen_range(0..100) < 5 { - // c.weight += thread_rng().gen_range(-0.3..0.3); - // } + else if thread_rng().gen_range(0..100) < 5 { + c.weight += thread_rng().gen_range(-0.3..0.3); + } } fann.set_connections(&connections); diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index 0e8d11d..73df36e 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -89,6 +89,7 @@ impl GeneticNode for TestState { id: id.clone(), generation: 0, max_generations: 0, + semaphore: None, })?; Ok(Box::new(result)) @@ -107,6 +108,7 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, + semaphore: None, } ).unwrap(); @@ -126,6 +128,7 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, + semaphore: None, } ).await.unwrap(); assert!(original_population @@ -138,6 +141,7 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, + semaphore: None, } ).await.unwrap(); state.simulate( @@ -145,6 +149,7 @@ mod 
tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, + semaphore: None, } ).await.unwrap(); assert!(original_population @@ -164,6 +169,7 @@ mod tests { id: Uuid::new_v4(), generation: 0, max_generations: 0, + semaphore: None, } ).unwrap(); diff --git a/gemla/src/core/genetic_node.rs b/gemla/src/core/genetic_node.rs index 92a7f48..0838dc8 100644 --- a/gemla/src/core/genetic_node.rs +++ b/gemla/src/core/genetic_node.rs @@ -6,7 +6,8 @@ use crate::error::Error; use anyhow::Context; use serde::{Deserialize, Serialize}; -use std::fmt::Debug; +use tokio::sync::Semaphore; +use std::{fmt::Debug, sync::Arc}; use uuid::Uuid; use async_trait::async_trait; @@ -25,11 +26,12 @@ pub enum GeneticState { Finish, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct GeneticNodeContext { pub generation: u64, pub max_generations: u64, pub id: Uuid, + pub semaphore: Option>, } /// A trait used to interact with the internal state of nodes within the [`Bracket`] @@ -118,11 +120,12 @@ where self.state } - pub async fn process_node(&mut self) -> Result { + pub async fn process_node(&mut self, semaphore: Arc) -> Result { let context = GeneticNodeContext { generation: self.generation, max_generations: self.max_generations, id: self.id, + semaphore: Some(semaphore), }; match (self.state, &mut self.node) { @@ -278,13 +281,14 @@ mod tests { #[tokio::test] async fn test_process_node() -> Result<(), Error> { let mut genetic_node = GeneticNodeWrapper::::new(2); + let semaphore = Arc::new(Semaphore::new(1)); assert_eq!(genetic_node.state(), GeneticState::Initialize); - assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node().await?, GeneticState::Mutate); - assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node().await?, GeneticState::Finish); - assert_eq!(genetic_node.process_node().await?, GeneticState::Finish); + 
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate); + assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish); + assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish); Ok(()) } diff --git a/gemla/src/core/mod.rs b/gemla/src/core/mod.rs index 42a1fc7..b3e7faa 100644 --- a/gemla/src/core/mod.rs +++ b/gemla/src/core/mod.rs @@ -10,9 +10,9 @@ use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState}; use log::{info, trace, warn}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::task::JoinHandle; +use tokio::sync::Semaphore; use std::{ - collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, - time::Instant, + collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant }; use uuid::Uuid; @@ -58,6 +58,7 @@ type SimulationTree = Box>>; pub struct GemlaConfig { pub generations_per_height: u64, pub overwrite: bool, + pub shared_semaphore_concurrency_limit: usize, } /// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. @@ -72,6 +73,7 @@ where { pub data: FileLinked<(Option>, GemlaConfig)>, threads: HashMap, Error>>>, + semaphore: Arc, } impl Gemla @@ -89,11 +91,13 @@ where FileLinked::from_file(path, data_format)? 
}, threads: HashMap::new(), + semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)), }), // If the file doesn't exist we must create it Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { data: FileLinked::new((None, config), path, data_format)?, threads: HashMap::new(), + semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)), }), Err(error) => Err(Error::IO(error)), } @@ -147,9 +151,11 @@ where { trace!("Adding node to process list {}", node.id()); + let semaphore = self.semaphore.clone(); + self.threads .insert(node.id(), tokio::spawn(async move { - Gemla::process_node(node).await + Gemla::process_node(node, semaphore).await })); } else { trace!("No node found to process, joining threads"); @@ -323,15 +329,15 @@ where tree.val.state() == GeneticState::Finish } - async fn process_node(mut node: GeneticNodeWrapper) -> Result, Error> { + async fn process_node(mut node: GeneticNodeWrapper, semaphore: Arc) -> Result, Error> { let node_state_time = Instant::now(); let node_state = node.state(); - node.process_node().await?; + node.process_node(semaphore.clone()).await?; if node.state() == GeneticState::Simulate { - node.process_node().await?; + node.process_node(semaphore.clone()).await?; } trace!( @@ -427,6 +433,7 @@ mod tests { let mut config = GemlaConfig { generations_per_height: 1, overwrite: true, + shared_semaphore_concurrency_limit: 1, }; let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; @@ -476,6 +483,7 @@ mod tests { let config = GemlaConfig { generations_per_height: 10, overwrite: true, + shared_semaphore_concurrency_limit: 1, }; let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?;