Adding global sempahore to better control resources.
This commit is contained in:
parent
97086fdbe0
commit
ac71b28c7c
5 changed files with 140 additions and 87 deletions
|
@ -16,7 +16,7 @@ use fighter_nn::FighterNN;
|
|||
use clap::Parser;
|
||||
use anyhow::Result;
|
||||
|
||||
// const NUM_THREADS: usize = 12;
|
||||
// const NUM_THREADS: usize = 2;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(version, about, long_about = None)]
|
||||
|
@ -47,6 +47,7 @@ fn main() -> Result<()> {
|
|||
GemlaConfig {
|
||||
generations_per_height: 10,
|
||||
overwrite: false,
|
||||
shared_semaphore_concurrency_limit: 30,
|
||||
},
|
||||
DataFormat::Json,
|
||||
))?;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
extern crate fann;
|
||||
|
||||
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}};
|
||||
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}, sync::Arc};
|
||||
use fann::{ActivationFunc, Fann};
|
||||
use futures::future::join_all;
|
||||
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
||||
|
@ -8,9 +8,9 @@ use rand::prelude::*;
|
|||
use rand::distributions::{Distribution, Uniform};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use anyhow::Context;
|
||||
use tokio::process::Command;
|
||||
use uuid::Uuid;
|
||||
use std::collections::HashMap;
|
||||
use tokio::process::Command;
|
||||
use async_trait::async_trait;
|
||||
|
||||
const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
|
||||
|
@ -79,82 +79,116 @@ impl GeneticNode for FighterNN {
|
|||
}))
|
||||
}
|
||||
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> {
|
||||
trace!("Context: {:?}", context);
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
// For each nn in the current generation:
|
||||
for i in 0..self.population_size {
|
||||
// load the nn
|
||||
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i));
|
||||
let mut simulations = Vec::new();
|
||||
let self_clone = self.clone();
|
||||
let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap());
|
||||
|
||||
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
|
||||
for _ in 0..SIMULATION_ROUNDS {
|
||||
let random_nn_index = thread_rng().gen_range(0..self.population_size);
|
||||
let id = self.id.clone();
|
||||
let folder = self.folder.clone();
|
||||
let generation = self.generation;
|
||||
let task = async move {
|
||||
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i));
|
||||
let mut simulations = Vec::new();
|
||||
|
||||
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
|
||||
let nn_clone = nn.clone(); // Clone the path to use in the async block
|
||||
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
|
||||
for _ in 0..SIMULATION_ROUNDS {
|
||||
let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
|
||||
let id = self_clone.id.clone();
|
||||
let folder = self_clone.folder.clone();
|
||||
let generation = self_clone.generation;
|
||||
let semaphore_clone = Arc::clone(&semaphore_clone);
|
||||
|
||||
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
|
||||
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
|
||||
let disable_unreal_rendering_arg = "-nullrhi".to_string();
|
||||
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
|
||||
let nn_clone = nn.clone(); // Clone the path to use in the async block
|
||||
|
||||
let future = async move {
|
||||
// Construct the score file path
|
||||
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
|
||||
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
|
||||
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
|
||||
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
|
||||
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
|
||||
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
|
||||
let disable_unreal_rendering_arg = "-nullrhi".to_string();
|
||||
|
||||
// Check if score file already exists before running the simulation
|
||||
if score_file.exists() {
|
||||
|
||||
|
||||
let future = async move {
|
||||
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
|
||||
|
||||
// Construct the score file path
|
||||
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
|
||||
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
|
||||
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
|
||||
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
|
||||
|
||||
// Check if score file already exists before running the simulation
|
||||
if score_file.exists() {
|
||||
let round_score = read_score_from_file(&score_file, &nn_id)
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
|
||||
return Ok::<f32, Error>(round_score);
|
||||
}
|
||||
|
||||
// Check if the opposite round score has been determined
|
||||
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
|
||||
if opposite_score_file.exists() {
|
||||
let round_score = read_score_from_file(&opposite_score_file, &nn_id)
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
||||
return Ok::<f32, Error>(1.0 - round_score);
|
||||
}
|
||||
|
||||
let _output = if thread_rng().gen_range(0..100) < 0 {
|
||||
Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game")
|
||||
} else {
|
||||
Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.arg(&disable_unreal_rendering_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game")
|
||||
};
|
||||
|
||||
drop(permit);
|
||||
|
||||
// Read the score from the file
|
||||
let round_score = read_score_from_file(&score_file, &nn_id)
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
|
||||
return Ok::<f32, Error>(round_score);
|
||||
}
|
||||
|
||||
// Check if the opposite round score has been determined
|
||||
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
|
||||
if opposite_score_file.exists() {
|
||||
let round_score = read_score_from_file(&opposite_score_file, &nn_id)
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
||||
return Ok::<f32, Error>(1.0 - round_score);
|
||||
}
|
||||
Ok::<f32, Error>(round_score)
|
||||
};
|
||||
|
||||
if thread_rng().gen_range(0..100) < 4 {
|
||||
let _output = Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game");
|
||||
} else {
|
||||
let _output = Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.arg(&disable_unreal_rendering_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game");
|
||||
}
|
||||
simulations.push(future);
|
||||
}
|
||||
|
||||
// Read the score from the file
|
||||
let round_score = read_score_from_file(&score_file, &nn_id)
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
|
||||
// Wait for all simulation rounds to complete
|
||||
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
|
||||
|
||||
Ok::<f32, Error>(round_score)
|
||||
let score = match results {
|
||||
Ok(scores) => scores.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32,
|
||||
Err(e) => return Err(e), // Return the error if results collection failed
|
||||
};
|
||||
trace!("NN {:06}_fighter_nn_{} scored {}", self_clone.id, i, score);
|
||||
Ok((i, score))
|
||||
};
|
||||
|
||||
simulations.push(future);
|
||||
tasks.push(task);
|
||||
}
|
||||
|
||||
let results = join_all(tasks).await;
|
||||
|
||||
for result in results {
|
||||
match result {
|
||||
Ok((index, score)) => {
|
||||
// Update the original `self` object with the score.
|
||||
self.scores[self.generation as usize].insert(index as u64, score);
|
||||
},
|
||||
Err(e) => {
|
||||
// Handle task panic or execution error
|
||||
return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e))));
|
||||
},
|
||||
}
|
||||
|
||||
// Wait for all simulation rounds to complete
|
||||
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
|
||||
|
||||
let score = results?.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32;
|
||||
trace!("NN {:06}_fighter_nn_{} scored {}", self.id, i, score);
|
||||
self.scores[self.generation as usize].insert(i as u64, score);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -228,9 +262,9 @@ impl GeneticNode for FighterNN {
|
|||
if thread_rng().gen_range(0..100) < 20 {
|
||||
c.weight += thread_rng().gen_range(-0.1..0.1);
|
||||
}
|
||||
// else if thread_rng().gen_range(0..100) < 5 {
|
||||
// c.weight += thread_rng().gen_range(-0.3..0.3);
|
||||
// }
|
||||
else if thread_rng().gen_range(0..100) < 5 {
|
||||
c.weight += thread_rng().gen_range(-0.3..0.3);
|
||||
}
|
||||
}
|
||||
fann.set_connections(&connections);
|
||||
|
||||
|
|
|
@ -89,6 +89,7 @@ impl GeneticNode for TestState {
|
|||
id: id.clone(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
})?;
|
||||
|
||||
Ok(Box::new(result))
|
||||
|
@ -107,6 +108,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
}
|
||||
).unwrap();
|
||||
|
||||
|
@ -126,6 +128,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
}
|
||||
).await.unwrap();
|
||||
assert!(original_population
|
||||
|
@ -138,6 +141,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
}
|
||||
).await.unwrap();
|
||||
state.simulate(
|
||||
|
@ -145,6 +149,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
}
|
||||
).await.unwrap();
|
||||
assert!(original_population
|
||||
|
@ -164,6 +169,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
}
|
||||
).unwrap();
|
||||
|
||||
|
|
|
@ -6,7 +6,8 @@ use crate::error::Error;
|
|||
|
||||
use anyhow::Context;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use tokio::sync::Semaphore;
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
use uuid::Uuid;
|
||||
use async_trait::async_trait;
|
||||
|
||||
|
@ -25,11 +26,12 @@ pub enum GeneticState {
|
|||
Finish,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GeneticNodeContext {
|
||||
pub generation: u64,
|
||||
pub max_generations: u64,
|
||||
pub id: Uuid,
|
||||
pub semaphore: Option<Arc<Semaphore>>,
|
||||
}
|
||||
|
||||
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||
|
@ -118,11 +120,12 @@ where
|
|||
self.state
|
||||
}
|
||||
|
||||
pub async fn process_node(&mut self) -> Result<GeneticState, Error> {
|
||||
pub async fn process_node(&mut self, semaphore: Arc<Semaphore>) -> Result<GeneticState, Error> {
|
||||
let context = GeneticNodeContext {
|
||||
generation: self.generation,
|
||||
max_generations: self.max_generations,
|
||||
id: self.id,
|
||||
semaphore: Some(semaphore),
|
||||
};
|
||||
|
||||
match (self.state, &mut self.node) {
|
||||
|
@ -278,13 +281,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_process_node() -> Result<(), Error> {
|
||||
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
|
||||
let semaphore = Arc::new(Semaphore::new(1));
|
||||
|
||||
assert_eq!(genetic_node.state(), GeneticState::Initialize);
|
||||
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node().await?, GeneticState::Mutate);
|
||||
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -10,9 +10,9 @@ use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
|
|||
use log::{info, trace, warn};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::Semaphore;
|
||||
use std::{
|
||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
|
||||
time::Instant,
|
||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -58,6 +58,7 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
|
|||
pub struct GemlaConfig {
|
||||
pub generations_per_height: u64,
|
||||
pub overwrite: bool,
|
||||
pub shared_semaphore_concurrency_limit: usize,
|
||||
}
|
||||
|
||||
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
||||
|
@ -72,6 +73,7 @@ where
|
|||
{
|
||||
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
|
||||
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl<T: 'static> Gemla<T>
|
||||
|
@ -89,11 +91,13 @@ where
|
|||
FileLinked::from_file(path, data_format)?
|
||||
},
|
||||
threads: HashMap::new(),
|
||||
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
|
||||
}),
|
||||
// If the file doesn't exist we must create it
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
|
||||
data: FileLinked::new((None, config), path, data_format)?,
|
||||
threads: HashMap::new(),
|
||||
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
|
||||
}),
|
||||
Err(error) => Err(Error::IO(error)),
|
||||
}
|
||||
|
@ -147,9 +151,11 @@ where
|
|||
{
|
||||
trace!("Adding node to process list {}", node.id());
|
||||
|
||||
let semaphore = self.semaphore.clone();
|
||||
|
||||
self.threads
|
||||
.insert(node.id(), tokio::spawn(async move {
|
||||
Gemla::process_node(node).await
|
||||
Gemla::process_node(node, semaphore).await
|
||||
}));
|
||||
} else {
|
||||
trace!("No node found to process, joining threads");
|
||||
|
@ -323,15 +329,15 @@ where
|
|||
tree.val.state() == GeneticState::Finish
|
||||
}
|
||||
|
||||
async fn process_node(mut node: GeneticNodeWrapper<T>) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
async fn process_node(mut node: GeneticNodeWrapper<T>, semaphore: Arc<Semaphore>) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
let node_state_time = Instant::now();
|
||||
let node_state = node.state();
|
||||
|
||||
node.process_node().await?;
|
||||
node.process_node(semaphore.clone()).await?;
|
||||
|
||||
if node.state() == GeneticState::Simulate
|
||||
{
|
||||
node.process_node().await?;
|
||||
node.process_node(semaphore.clone()).await?;
|
||||
}
|
||||
|
||||
trace!(
|
||||
|
@ -427,6 +433,7 @@ mod tests {
|
|||
let mut config = GemlaConfig {
|
||||
generations_per_height: 1,
|
||||
overwrite: true,
|
||||
shared_semaphore_concurrency_limit: 1,
|
||||
};
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||
|
||||
|
@ -476,6 +483,7 @@ mod tests {
|
|||
let config = GemlaConfig {
|
||||
generations_per_height: 10,
|
||||
overwrite: true,
|
||||
shared_semaphore_concurrency_limit: 1,
|
||||
};
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue