Adding global sempahore to better control resources.

This commit is contained in:
vandomej 2024-03-21 12:55:42 -07:00
parent 97086fdbe0
commit ac71b28c7c
5 changed files with 140 additions and 87 deletions

View file

@ -16,7 +16,7 @@ use fighter_nn::FighterNN;
use clap::Parser;
use anyhow::Result;
// const NUM_THREADS: usize = 12;
// const NUM_THREADS: usize = 2;
#[derive(Parser)]
#[command(version, about, long_about = None)]
@ -47,6 +47,7 @@ fn main() -> Result<()> {
GemlaConfig {
generations_per_height: 10,
overwrite: false,
shared_semaphore_concurrency_limit: 30,
},
DataFormat::Json,
))?;

View file

@ -1,6 +1,6 @@
extern crate fann;
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}};
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}, sync::Arc};
use fann::{ActivationFunc, Fann};
use futures::future::join_all;
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
@ -8,9 +8,9 @@ use rand::prelude::*;
use rand::distributions::{Distribution, Uniform};
use serde::{Deserialize, Serialize};
use anyhow::Context;
use tokio::process::Command;
use uuid::Uuid;
use std::collections::HashMap;
use tokio::process::Command;
use async_trait::async_trait;
const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
@ -79,82 +79,116 @@ impl GeneticNode for FighterNN {
}))
}
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> {
trace!("Context: {:?}", context);
let mut tasks = Vec::new();
// For each nn in the current generation:
for i in 0..self.population_size {
// load the nn
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i));
let mut simulations = Vec::new();
let self_clone = self.clone();
let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap());
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
for _ in 0..SIMULATION_ROUNDS {
let random_nn_index = thread_rng().gen_range(0..self.population_size);
let id = self.id.clone();
let folder = self.folder.clone();
let generation = self.generation;
let task = async move {
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i));
let mut simulations = Vec::new();
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
let nn_clone = nn.clone(); // Clone the path to use in the async block
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
for _ in 0..SIMULATION_ROUNDS {
let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
let id = self_clone.id.clone();
let folder = self_clone.folder.clone();
let generation = self_clone.generation;
let semaphore_clone = Arc::clone(&semaphore_clone);
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string();
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
let nn_clone = nn.clone(); // Clone the path to use in the async block
let future = async move {
// Construct the score file path
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string();
// Check if score file already exists before running the simulation
if score_file.exists() {
let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
// Construct the score file path
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
// Check if score file already exists before running the simulation
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
return Ok::<f32, Error>(round_score);
}
// Check if the opposite round score has been determined
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
return Ok::<f32, Error>(1.0 - round_score);
}
let _output = if thread_rng().gen_range(0..100) < 0 {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.output()
.await
.expect("Failed to execute game")
} else {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.arg(&disable_unreal_rendering_arg)
.output()
.await
.expect("Failed to execute game")
};
drop(permit);
// Read the score from the file
let round_score = read_score_from_file(&score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
return Ok::<f32, Error>(round_score);
}
// Check if the opposite round score has been determined
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
return Ok::<f32, Error>(1.0 - round_score);
}
Ok::<f32, Error>(round_score)
};
if thread_rng().gen_range(0..100) < 4 {
let _output = Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.output()
.await
.expect("Failed to execute game");
} else {
let _output = Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.arg(&disable_unreal_rendering_arg)
.output()
.await
.expect("Failed to execute game");
}
simulations.push(future);
}
// Read the score from the file
let round_score = read_score_from_file(&score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
// Wait for all simulation rounds to complete
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
Ok::<f32, Error>(round_score)
let score = match results {
Ok(scores) => scores.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32,
Err(e) => return Err(e), // Return the error if results collection failed
};
trace!("NN {:06}_fighter_nn_{} scored {}", self_clone.id, i, score);
Ok((i, score))
};
simulations.push(future);
tasks.push(task);
}
let results = join_all(tasks).await;
for result in results {
match result {
Ok((index, score)) => {
// Update the original `self` object with the score.
self.scores[self.generation as usize].insert(index as u64, score);
},
Err(e) => {
// Handle task panic or execution error
return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e))));
},
}
// Wait for all simulation rounds to complete
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
let score = results?.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32;
trace!("NN {:06}_fighter_nn_{} scored {}", self.id, i, score);
self.scores[self.generation as usize].insert(i as u64, score);
}
Ok(())
@ -228,9 +262,9 @@ impl GeneticNode for FighterNN {
if thread_rng().gen_range(0..100) < 20 {
c.weight += thread_rng().gen_range(-0.1..0.1);
}
// else if thread_rng().gen_range(0..100) < 5 {
// c.weight += thread_rng().gen_range(-0.3..0.3);
// }
else if thread_rng().gen_range(0..100) < 5 {
c.weight += thread_rng().gen_range(-0.3..0.3);
}
}
fann.set_connections(&connections);

View file

@ -89,6 +89,7 @@ impl GeneticNode for TestState {
id: id.clone(),
generation: 0,
max_generations: 0,
semaphore: None,
})?;
Ok(Box::new(result))
@ -107,6 +108,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
}
).unwrap();
@ -126,6 +128,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
}
).await.unwrap();
assert!(original_population
@ -138,6 +141,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
}
).await.unwrap();
state.simulate(
@ -145,6 +149,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
}
).await.unwrap();
assert!(original_population
@ -164,6 +169,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
}
).unwrap();

View file

@ -6,7 +6,8 @@ use crate::error::Error;
use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use tokio::sync::Semaphore;
use std::{fmt::Debug, sync::Arc};
use uuid::Uuid;
use async_trait::async_trait;
@ -25,11 +26,12 @@ pub enum GeneticState {
Finish,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct GeneticNodeContext {
pub generation: u64,
pub max_generations: u64,
pub id: Uuid,
pub semaphore: Option<Arc<Semaphore>>,
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
@ -118,11 +120,12 @@ where
self.state
}
pub async fn process_node(&mut self) -> Result<GeneticState, Error> {
pub async fn process_node(&mut self, semaphore: Arc<Semaphore>) -> Result<GeneticState, Error> {
let context = GeneticNodeContext {
generation: self.generation,
max_generations: self.max_generations,
id: self.id,
semaphore: Some(semaphore),
};
match (self.state, &mut self.node) {
@ -278,13 +281,14 @@ mod tests {
#[tokio::test]
async fn test_process_node() -> Result<(), Error> {
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
let semaphore = Arc::new(Semaphore::new(1));
assert_eq!(genetic_node.state(), GeneticState::Initialize);
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
Ok(())
}

View file

@ -10,9 +10,9 @@ use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::task::JoinHandle;
use tokio::sync::Semaphore;
use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
time::Instant,
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
};
use uuid::Uuid;
@ -58,6 +58,7 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
pub struct GemlaConfig {
pub generations_per_height: u64,
pub overwrite: bool,
pub shared_semaphore_concurrency_limit: usize,
}
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
@ -72,6 +73,7 @@ where
{
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
semaphore: Arc<Semaphore>,
}
impl<T: 'static> Gemla<T>
@ -89,11 +91,13 @@ where
FileLinked::from_file(path, data_format)?
},
threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}),
// If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config), path, data_format)?,
threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}),
Err(error) => Err(Error::IO(error)),
}
@ -147,9 +151,11 @@ where
{
trace!("Adding node to process list {}", node.id());
let semaphore = self.semaphore.clone();
self.threads
.insert(node.id(), tokio::spawn(async move {
Gemla::process_node(node).await
Gemla::process_node(node, semaphore).await
}));
} else {
trace!("No node found to process, joining threads");
@ -323,15 +329,15 @@ where
tree.val.state() == GeneticState::Finish
}
async fn process_node(mut node: GeneticNodeWrapper<T>) -> Result<GeneticNodeWrapper<T>, Error> {
async fn process_node(mut node: GeneticNodeWrapper<T>, semaphore: Arc<Semaphore>) -> Result<GeneticNodeWrapper<T>, Error> {
let node_state_time = Instant::now();
let node_state = node.state();
node.process_node().await?;
node.process_node(semaphore.clone()).await?;
if node.state() == GeneticState::Simulate
{
node.process_node().await?;
node.process_node(semaphore.clone()).await?;
}
trace!(
@ -427,6 +433,7 @@ mod tests {
let mut config = GemlaConfig {
generations_per_height: 1,
overwrite: true,
shared_semaphore_concurrency_limit: 1,
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
@ -476,6 +483,7 @@ mod tests {
let config = GemlaConfig {
generations_per_height: 10,
overwrite: true,
shared_semaphore_concurrency_limit: 1,
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;