dootcamp #1

Merged
tepichord merged 26 commits from dootcamp into master 2025-09-05 09:37:40 -07:00
5 changed files with 140 additions and 87 deletions
Showing only changes of commit ac71b28c7c - Show all commits

View file

@ -16,7 +16,7 @@ use fighter_nn::FighterNN;
use clap::Parser; use clap::Parser;
use anyhow::Result; use anyhow::Result;
// const NUM_THREADS: usize = 12; // const NUM_THREADS: usize = 2;
#[derive(Parser)] #[derive(Parser)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
@ -47,6 +47,7 @@ fn main() -> Result<()> {
GemlaConfig { GemlaConfig {
generations_per_height: 10, generations_per_height: 10,
overwrite: false, overwrite: false,
shared_semaphore_concurrency_limit: 30,
}, },
DataFormat::Json, DataFormat::Json,
))?; ))?;

View file

@ -1,6 +1,6 @@
extern crate fann; extern crate fann;
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}}; use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}, sync::Arc};
use fann::{ActivationFunc, Fann}; use fann::{ActivationFunc, Fann};
use futures::future::join_all; use futures::future::join_all;
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}; use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
@ -8,9 +8,9 @@ use rand::prelude::*;
use rand::distributions::{Distribution, Uniform}; use rand::distributions::{Distribution, Uniform};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use anyhow::Context; use anyhow::Context;
use tokio::process::Command;
use uuid::Uuid; use uuid::Uuid;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::process::Command;
use async_trait::async_trait; use async_trait::async_trait;
const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations"; const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
@ -79,19 +79,26 @@ impl GeneticNode for FighterNN {
})) }))
} }
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> {
trace!("Context: {:?}", context);
let mut tasks = Vec::new();
// For each nn in the current generation: // For each nn in the current generation:
for i in 0..self.population_size { for i in 0..self.population_size {
// load the nn let self_clone = self.clone();
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i)); let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap());
let task = async move {
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i));
let mut simulations = Vec::new(); let mut simulations = Vec::new();
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
for _ in 0..SIMULATION_ROUNDS { for _ in 0..SIMULATION_ROUNDS {
let random_nn_index = thread_rng().gen_range(0..self.population_size); let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
let id = self.id.clone(); let id = self_clone.id.clone();
let folder = self.folder.clone(); let folder = self_clone.folder.clone();
let generation = self.generation; let generation = self_clone.generation;
let semaphore_clone = Arc::clone(&semaphore_clone);
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index)); let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
let nn_clone = nn.clone(); // Clone the path to use in the async block let nn_clone = nn.clone(); // Clone the path to use in the async block
@ -100,7 +107,11 @@ impl GeneticNode for FighterNN {
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap()); let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string(); let disable_unreal_rendering_arg = "-nullrhi".to_string();
let future = async move { let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
// Construct the score file path // Construct the score file path
let nn_id = format!("{:06}_fighter_nn_{}", id, i); let nn_id = format!("{:06}_fighter_nn_{}", id, i);
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index); let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
@ -122,22 +133,24 @@ impl GeneticNode for FighterNN {
return Ok::<f32, Error>(1.0 - round_score); return Ok::<f32, Error>(1.0 - round_score);
} }
if thread_rng().gen_range(0..100) < 4 { let _output = if thread_rng().gen_range(0..100) < 0 {
let _output = Command::new(GAME_EXECUTABLE_PATH) Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg) .arg(&config1_arg)
.arg(&config2_arg) .arg(&config2_arg)
.output() .output()
.await .await
.expect("Failed to execute game"); .expect("Failed to execute game")
} else { } else {
let _output = Command::new(GAME_EXECUTABLE_PATH) Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg) .arg(&config1_arg)
.arg(&config2_arg) .arg(&config2_arg)
.arg(&disable_unreal_rendering_arg) .arg(&disable_unreal_rendering_arg)
.output() .output()
.await .await
.expect("Failed to execute game"); .expect("Failed to execute game")
} };
drop(permit);
// Read the score from the file // Read the score from the file
let round_score = read_score_from_file(&score_file, &nn_id) let round_score = read_score_from_file(&score_file, &nn_id)
@ -152,9 +165,30 @@ impl GeneticNode for FighterNN {
// Wait for all simulation rounds to complete // Wait for all simulation rounds to complete
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect(); let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
let score = results?.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32; let score = match results {
trace!("NN {:06}_fighter_nn_{} scored {}", self.id, i, score); Ok(scores) => scores.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32,
self.scores[self.generation as usize].insert(i as u64, score); Err(e) => return Err(e), // Return the error if results collection failed
};
trace!("NN {:06}_fighter_nn_{} scored {}", self_clone.id, i, score);
Ok((i, score))
};
tasks.push(task);
}
let results = join_all(tasks).await;
for result in results {
match result {
Ok((index, score)) => {
// Update the original `self` object with the score.
self.scores[self.generation as usize].insert(index as u64, score);
},
Err(e) => {
// Handle task panic or execution error
return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e))));
},
}
} }
Ok(()) Ok(())
@ -228,9 +262,9 @@ impl GeneticNode for FighterNN {
if thread_rng().gen_range(0..100) < 20 { if thread_rng().gen_range(0..100) < 20 {
c.weight += thread_rng().gen_range(-0.1..0.1); c.weight += thread_rng().gen_range(-0.1..0.1);
} }
// else if thread_rng().gen_range(0..100) < 5 { else if thread_rng().gen_range(0..100) < 5 {
// c.weight += thread_rng().gen_range(-0.3..0.3); c.weight += thread_rng().gen_range(-0.3..0.3);
// } }
} }
fann.set_connections(&connections); fann.set_connections(&connections);

View file

@ -89,6 +89,7 @@ impl GeneticNode for TestState {
id: id.clone(), id: id.clone(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None,
})?; })?;
Ok(Box::new(result)) Ok(Box::new(result))
@ -107,6 +108,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None,
} }
).unwrap(); ).unwrap();
@ -126,6 +128,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None,
} }
).await.unwrap(); ).await.unwrap();
assert!(original_population assert!(original_population
@ -138,6 +141,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None,
} }
).await.unwrap(); ).await.unwrap();
state.simulate( state.simulate(
@ -145,6 +149,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None,
} }
).await.unwrap(); ).await.unwrap();
assert!(original_population assert!(original_population
@ -164,6 +169,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None,
} }
).unwrap(); ).unwrap();

View file

@ -6,7 +6,8 @@ use crate::error::Error;
use anyhow::Context; use anyhow::Context;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Debug; use tokio::sync::Semaphore;
use std::{fmt::Debug, sync::Arc};
use uuid::Uuid; use uuid::Uuid;
use async_trait::async_trait; use async_trait::async_trait;
@ -25,11 +26,12 @@ pub enum GeneticState {
Finish, Finish,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct GeneticNodeContext { pub struct GeneticNodeContext {
pub generation: u64, pub generation: u64,
pub max_generations: u64, pub max_generations: u64,
pub id: Uuid, pub id: Uuid,
pub semaphore: Option<Arc<Semaphore>>,
} }
/// A trait used to interact with the internal state of nodes within the [`Bracket`] /// A trait used to interact with the internal state of nodes within the [`Bracket`]
@ -118,11 +120,12 @@ where
self.state self.state
} }
pub async fn process_node(&mut self) -> Result<GeneticState, Error> { pub async fn process_node(&mut self, semaphore: Arc<Semaphore>) -> Result<GeneticState, Error> {
let context = GeneticNodeContext { let context = GeneticNodeContext {
generation: self.generation, generation: self.generation,
max_generations: self.max_generations, max_generations: self.max_generations,
id: self.id, id: self.id,
semaphore: Some(semaphore),
}; };
match (self.state, &mut self.node) { match (self.state, &mut self.node) {
@ -278,13 +281,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_process_node() -> Result<(), Error> { async fn test_process_node() -> Result<(), Error> {
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2); let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
let semaphore = Arc::new(Semaphore::new(1));
assert_eq!(genetic_node.state(), GeneticState::Initialize); assert_eq!(genetic_node.state(), GeneticState::Initialize);
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate); assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Mutate); assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate); assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish); assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish); assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
Ok(()) Ok(())
} }

View file

@ -10,9 +10,9 @@ use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn}; use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::sync::Semaphore;
use std::{ use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
time::Instant,
}; };
use uuid::Uuid; use uuid::Uuid;
@ -58,6 +58,7 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
pub struct GemlaConfig { pub struct GemlaConfig {
pub generations_per_height: u64, pub generations_per_height: u64,
pub overwrite: bool, pub overwrite: bool,
pub shared_semaphore_concurrency_limit: usize,
} }
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. /// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
@ -72,6 +73,7 @@ where
{ {
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>, pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>, threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
semaphore: Arc<Semaphore>,
} }
impl<T: 'static> Gemla<T> impl<T: 'static> Gemla<T>
@ -89,11 +91,13 @@ where
FileLinked::from_file(path, data_format)? FileLinked::from_file(path, data_format)?
}, },
threads: HashMap::new(), threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}), }),
// If the file doesn't exist we must create it // If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config), path, data_format)?, data: FileLinked::new((None, config), path, data_format)?,
threads: HashMap::new(), threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}), }),
Err(error) => Err(Error::IO(error)), Err(error) => Err(Error::IO(error)),
} }
@ -147,9 +151,11 @@ where
{ {
trace!("Adding node to process list {}", node.id()); trace!("Adding node to process list {}", node.id());
let semaphore = self.semaphore.clone();
self.threads self.threads
.insert(node.id(), tokio::spawn(async move { .insert(node.id(), tokio::spawn(async move {
Gemla::process_node(node).await Gemla::process_node(node, semaphore).await
})); }));
} else { } else {
trace!("No node found to process, joining threads"); trace!("No node found to process, joining threads");
@ -323,15 +329,15 @@ where
tree.val.state() == GeneticState::Finish tree.val.state() == GeneticState::Finish
} }
async fn process_node(mut node: GeneticNodeWrapper<T>) -> Result<GeneticNodeWrapper<T>, Error> { async fn process_node(mut node: GeneticNodeWrapper<T>, semaphore: Arc<Semaphore>) -> Result<GeneticNodeWrapper<T>, Error> {
let node_state_time = Instant::now(); let node_state_time = Instant::now();
let node_state = node.state(); let node_state = node.state();
node.process_node().await?; node.process_node(semaphore.clone()).await?;
if node.state() == GeneticState::Simulate if node.state() == GeneticState::Simulate
{ {
node.process_node().await?; node.process_node(semaphore.clone()).await?;
} }
trace!( trace!(
@ -427,6 +433,7 @@ mod tests {
let mut config = GemlaConfig { let mut config = GemlaConfig {
generations_per_height: 1, generations_per_height: 1,
overwrite: true, overwrite: true,
shared_semaphore_concurrency_limit: 1,
}; };
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?; let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
@ -476,6 +483,7 @@ mod tests {
let config = GemlaConfig { let config = GemlaConfig {
generations_per_height: 10, generations_per_height: 10,
overwrite: true, overwrite: true,
shared_semaphore_concurrency_limit: 1,
}; };
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?; let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;