Finalizing async implementation
This commit is contained in:
parent
b56e37d411
commit
7a1f82ac63
8 changed files with 1439 additions and 626 deletions
|
@ -6,12 +6,11 @@ pub mod constants;
|
|||
use anyhow::{anyhow, Context};
|
||||
use constants::data_format::DataFormat;
|
||||
use error::Error;
|
||||
use futures::executor::block_on;
|
||||
use log::info;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
use std::{
|
||||
borrow::Borrow, fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle}
|
||||
fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle}
|
||||
};
|
||||
|
||||
|
||||
|
@ -56,6 +55,7 @@ where
|
|||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -64,19 +64,22 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from("two"),
|
||||
/// c: 3.0
|
||||
/// };
|
||||
///
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json)
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// assert_eq!(linked_test.readonly().b, String::from("two"));
|
||||
/// assert_eq!(linked_test.readonly().c, 3.0);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
/// assert_eq!(readonly_ref.b, String::from("two"));
|
||||
/// assert_eq!(readonly_ref.c, 3.0);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
@ -97,6 +100,7 @@ where
|
|||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// #[derive(Deserialize, Serialize)]
|
||||
/// struct Test {
|
||||
|
@ -105,19 +109,22 @@ where
|
|||
/// pub c: f64
|
||||
/// }
|
||||
///
|
||||
/// # fn main() {
|
||||
/// #[tokio::main]
|
||||
/// # async fn main() {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from("two"),
|
||||
/// c: 3.0
|
||||
/// };
|
||||
///
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json)
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// assert_eq!(linked_test.readonly().b, String::from("two"));
|
||||
/// assert_eq!(linked_test.readonly().c, 3.0);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
/// assert_eq!(readonly_ref.b, String::from("two"));
|
||||
/// assert_eq!(readonly_ref.c, 3.0);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
@ -207,6 +214,7 @@ where
|
|||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -215,21 +223,28 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() -> Result<(), Error> {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from(""),
|
||||
/// c: 0.0
|
||||
/// };
|
||||
///
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode)
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// {
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
/// }
|
||||
///
|
||||
/// linked_test.mutate(|t| t.a = 2)?;
|
||||
/// linked_test.mutate(|t| t.a = 2).await?;
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 2);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 2);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
@ -262,6 +277,7 @@ where
|
|||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -270,25 +286,30 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() -> Result<(), Error> {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from(""),
|
||||
/// c: 0.0
|
||||
/// };
|
||||
///
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode)
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
///
|
||||
/// linked_test.replace(Test {
|
||||
/// a: 2,
|
||||
/// b: String::from(""),
|
||||
/// c: 0.0
|
||||
/// })?;
|
||||
/// }).await?;
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 2);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 2);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
@ -343,6 +364,7 @@ where
|
|||
/// # use std::fs::OpenOptions;
|
||||
/// # use std::io::Write;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -351,7 +373,8 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() -> Result<(), Error> {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from("2"),
|
||||
|
@ -371,9 +394,11 @@ where
|
|||
/// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode)
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, test.a);
|
||||
/// assert_eq!(linked_test.readonly().b, test.b);
|
||||
/// assert_eq!(linked_test.readonly().c, test.c);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, test.a);
|
||||
/// assert_eq!(readonly_ref.b, test.b);
|
||||
/// assert_eq!(readonly_ref.c, test.c);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
|
|
@ -3,18 +3,18 @@ extern crate gemla;
|
|||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
mod test_state;
|
||||
mod fighter_nn;
|
||||
mod test_state;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use fighter_nn::FighterNN;
|
||||
use file_linked::constants::data_format::DataFormat;
|
||||
use gemla::{
|
||||
core::{Gemla, GemlaConfig},
|
||||
error::log_error,
|
||||
};
|
||||
use std::{path::PathBuf, time::Instant};
|
||||
use fighter_nn::FighterNN;
|
||||
use clap::Parser;
|
||||
use anyhow::Result;
|
||||
|
||||
// const NUM_THREADS: usize = 2;
|
||||
|
||||
|
@ -44,14 +44,17 @@ fn main() -> Result<()> {
|
|||
.build()?
|
||||
.block_on(async {
|
||||
let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
|
||||
let mut gemla = log_error(Gemla::<FighterNN>::new(
|
||||
&PathBuf::from(args.file),
|
||||
GemlaConfig {
|
||||
generations_per_height: 5,
|
||||
overwrite: false,
|
||||
},
|
||||
DataFormat::Json,
|
||||
).await)?;
|
||||
let mut gemla = log_error(
|
||||
Gemla::<FighterNN>::new(
|
||||
&PathBuf::from(args.file),
|
||||
GemlaConfig {
|
||||
generations_per_height: 5,
|
||||
overwrite: false,
|
||||
},
|
||||
DataFormat::Json,
|
||||
)
|
||||
.await,
|
||||
)?;
|
||||
|
||||
// let gemla_arc = Arc::new(gemla);
|
||||
|
||||
|
@ -59,7 +62,8 @@ fn main() -> Result<()> {
|
|||
// If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
|
||||
|
||||
// Example placeholder loop to continuously run simulate
|
||||
loop { // Arbitrary loop count for demonstration
|
||||
loop {
|
||||
// Arbitrary loop count for demonstration
|
||||
gemla.simulate(1).await?;
|
||||
}
|
||||
});
|
||||
|
|
|
@ -5,7 +5,6 @@ use tokio::sync::Semaphore;
|
|||
|
||||
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50;
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FighterContext {
|
||||
pub shared_semaphore: Arc<Semaphore>,
|
||||
|
@ -19,7 +18,6 @@ impl Default for FighterContext {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Custom serialization to just output the concurrency limit.
|
||||
impl Serialize for FighterContext {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
|
|
|
@ -1,20 +1,29 @@
|
|||
extern crate fann;
|
||||
|
||||
pub mod neural_network_utility;
|
||||
pub mod fighter_context;
|
||||
pub mod neural_network_utility;
|
||||
|
||||
use std::{cmp::max, collections::{HashSet, VecDeque}, fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, panic::{catch_unwind, AssertUnwindSafe}, path::{Path, PathBuf}, sync::{Arc, Mutex}, time::Duration};
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use fann::{ActivationFunc, Fann};
|
||||
use futures::{executor::block_on, future::{join, join_all, select_all}, stream::FuturesUnordered, FutureExt, StreamExt};
|
||||
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
||||
use futures::future::join_all;
|
||||
use gemla::{
|
||||
core::genetic_node::{GeneticNode, GeneticNodeContext},
|
||||
error::Error,
|
||||
};
|
||||
use lerp::Lerp;
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use anyhow::Context;
|
||||
use tokio::{process::Command, sync::{mpsc, Semaphore}, task, time::{sleep, timeout, Sleep}};
|
||||
use uuid::Uuid;
|
||||
use std::collections::HashMap;
|
||||
use async_trait::async_trait;
|
||||
use std::{
|
||||
cmp::max,
|
||||
fs::{self, File},
|
||||
io::{self, BufRead, BufReader},
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use tokio::process::Command;
|
||||
use uuid::Uuid;
|
||||
|
||||
use self::neural_network_utility::{crossbreed, major_mutation};
|
||||
|
||||
|
@ -34,7 +43,8 @@ const NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX: usize = 20;
|
|||
|
||||
const SIMULATION_ROUNDS: usize = 5;
|
||||
const SURVIVAL_RATE: f32 = 0.5;
|
||||
const GAME_EXECUTABLE_PATH: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";
|
||||
const GAME_EXECUTABLE_PATH: &str =
|
||||
"F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";
|
||||
|
||||
// Here is the folder structure for the FighterNN:
|
||||
// base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net
|
||||
|
@ -78,11 +88,17 @@ impl GeneticNode for FighterNN {
|
|||
|
||||
//Create a new directory for the first generation, using create_dir_all to avoid errors if it already exists
|
||||
let gen_folder = folder.join("0");
|
||||
fs::create_dir_all(&gen_folder)
|
||||
.with_context(|| format!("Failed to create or access the generation folder: {:?}", gen_folder))?;
|
||||
fs::create_dir_all(&gen_folder).with_context(|| {
|
||||
format!(
|
||||
"Failed to create or access the generation folder: {:?}",
|
||||
gen_folder
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut nn_shapes = HashMap::new();
|
||||
let weight_initialization_range = thread_rng().gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0)..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX);
|
||||
let weight_initialization_range = thread_rng()
|
||||
.gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0)
|
||||
..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX);
|
||||
|
||||
// Create the first generation in this folder
|
||||
for i in 0..POPULATION {
|
||||
|
@ -90,17 +106,22 @@ impl GeneticNode for FighterNN {
|
|||
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
|
||||
|
||||
// Randomly generate a neural network shape based on constants
|
||||
let hidden_layers = thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX);
|
||||
let hidden_layers = thread_rng()
|
||||
.gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX);
|
||||
let mut nn_shape = vec![NEURAL_NETWORK_INPUTS as u32];
|
||||
for _ in 0..hidden_layers {
|
||||
nn_shape.push(thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX) as u32);
|
||||
nn_shape.push(thread_rng().gen_range(
|
||||
NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX,
|
||||
) as u32);
|
||||
}
|
||||
nn_shape.push(NEURAL_NETWORK_OUTPUTS as u32);
|
||||
nn_shapes.insert(i as u64, nn_shape.clone());
|
||||
|
||||
let mut fann = Fann::new(nn_shape.as_slice())
|
||||
.with_context(|| "Failed to create nn")?;
|
||||
fann.randomize_weights(weight_initialization_range.start, weight_initialization_range.end);
|
||||
let mut fann = Fann::new(nn_shape.as_slice()).with_context(|| "Failed to create nn")?;
|
||||
fann.randomize_weights(
|
||||
weight_initialization_range.start,
|
||||
weight_initialization_range.end,
|
||||
);
|
||||
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
|
||||
fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
|
||||
// This will overwrite any existing file with the same name
|
||||
|
@ -108,7 +129,9 @@ impl GeneticNode for FighterNN {
|
|||
.with_context(|| format!("Failed to save nn at {:?}", nn))?;
|
||||
}
|
||||
|
||||
let mut crossbreed_segments = thread_rng().gen_range(NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX);
|
||||
let mut crossbreed_segments = thread_rng().gen_range(
|
||||
NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX,
|
||||
);
|
||||
if crossbreed_segments % 2 == 0 {
|
||||
crossbreed_segments += 1;
|
||||
}
|
||||
|
@ -141,7 +164,10 @@ impl GeneticNode for FighterNN {
|
|||
let semaphore_clone = context.gemla_context.shared_semaphore.clone();
|
||||
|
||||
let task = async move {
|
||||
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64));
|
||||
let nn = self_clone
|
||||
.folder
|
||||
.join(format!("{}", self_clone.generation))
|
||||
.join(self_clone.get_individual_id(i as u64));
|
||||
let mut simulations = Vec::new();
|
||||
|
||||
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
|
||||
|
@ -151,11 +177,16 @@ impl GeneticNode for FighterNN {
|
|||
let generation = self_clone.generation;
|
||||
let semaphore_clone = semaphore_clone.clone();
|
||||
|
||||
let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64));
|
||||
let random_nn = folder
|
||||
.join(format!("{}", generation))
|
||||
.join(self_clone.get_individual_id(random_nn_index as u64));
|
||||
let nn_clone = nn.clone(); // Clone the path to use in the async block
|
||||
|
||||
let future = async move {
|
||||
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
|
||||
let permit = semaphore_clone
|
||||
.acquire_owned()
|
||||
.await
|
||||
.with_context(|| "Failed to acquire semaphore permit")?;
|
||||
|
||||
let (score, _) = run_1v1_simulation(&nn_clone, &random_nn).await?;
|
||||
|
||||
|
@ -168,7 +199,8 @@ impl GeneticNode for FighterNN {
|
|||
}
|
||||
|
||||
// Wait for all simulation rounds to complete
|
||||
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
|
||||
let results: Result<Vec<f32>, Error> =
|
||||
join_all(simulations).await.into_iter().collect();
|
||||
|
||||
let score = match results {
|
||||
Ok(scores) => scores.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32,
|
||||
|
@ -188,34 +220,46 @@ impl GeneticNode for FighterNN {
|
|||
Ok((index, score)) => {
|
||||
// Update the original `self` object with the score.
|
||||
self.scores[self.generation as usize].insert(index as u64, score);
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
// Handle task panic or execution error
|
||||
return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e))));
|
||||
},
|
||||
return Err(Error::Other(anyhow::anyhow!(format!(
|
||||
"Task failed: {:?}",
|
||||
e
|
||||
))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
|
||||
|
||||
// Create the new generation folder
|
||||
let new_gen_folder = self.folder.join(format!("{}", self.generation + 1));
|
||||
fs::create_dir_all(&new_gen_folder).with_context(|| format!("Failed to create or access new generation folder: {:?}", new_gen_folder))?;
|
||||
fs::create_dir_all(&new_gen_folder).with_context(|| {
|
||||
format!(
|
||||
"Failed to create or access new generation folder: {:?}",
|
||||
new_gen_folder
|
||||
)
|
||||
})?;
|
||||
|
||||
// Remove the 5 nn's with the lowest scores
|
||||
let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect();
|
||||
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
||||
let to_keep = sorted_scores[survivor_count..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
||||
let to_keep = sorted_scores[survivor_count..]
|
||||
.iter()
|
||||
.map(|(k, _)| *k)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Save the remaining 5 nn's to the new generation folder
|
||||
for i in 0..survivor_count {
|
||||
let nn_id = to_keep[i];
|
||||
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
|
||||
for (i, nn_id) in to_keep.iter().enumerate().take(survivor_count) {
|
||||
let nn = self
|
||||
.folder
|
||||
.join(format!("{}", self.generation))
|
||||
.join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
|
||||
let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i));
|
||||
fs::copy(&nn, &new_nn)?;
|
||||
}
|
||||
|
@ -223,16 +267,25 @@ impl GeneticNode for FighterNN {
|
|||
// Take the remaining 5 nn's and create 5 new nn's by the following:
|
||||
for i in 0..survivor_count {
|
||||
let nn_id = to_keep[i];
|
||||
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
|
||||
let fann = Fann::from_file(&nn)
|
||||
.with_context(|| format!("Failed to load nn"))?;
|
||||
let nn = self
|
||||
.folder
|
||||
.join(format!("{}", self.generation))
|
||||
.join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
|
||||
let fann = Fann::from_file(&nn).with_context(|| "Failed to load nn")?;
|
||||
|
||||
// Load another nn from the current generation and cross breed it with the current nn
|
||||
let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)]));
|
||||
let cross_fann = Fann::from_file(&cross_nn)
|
||||
.with_context(|| format!("Failed to load cross nn"))?;
|
||||
let cross_nn = self
|
||||
.folder
|
||||
.join(format!("{}", self.generation))
|
||||
.join(format!(
|
||||
"{:06}_fighter_nn_{}.net",
|
||||
self.id,
|
||||
to_keep[thread_rng().gen_range(0..survivor_count)]
|
||||
));
|
||||
let cross_fann =
|
||||
Fann::from_file(&cross_nn).with_context(|| "Failed to load cross nn")?;
|
||||
|
||||
let mut new_fann = crossbreed(&self, &fann, &cross_fann, self.crossbreed_segments)?;
|
||||
let mut new_fann = crossbreed(self, &fann, &cross_fann, self.crossbreed_segments)?;
|
||||
|
||||
// For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight)
|
||||
// And a 5% chance of a major mutation a new neuron is randomly added to a hidden layer
|
||||
|
@ -252,9 +305,14 @@ impl GeneticNode for FighterNN {
|
|||
}
|
||||
|
||||
// Save the new nn's to the new generation folder
|
||||
let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i + survivor_count));
|
||||
new_fann.save(&new_nn)
|
||||
.with_context(|| format!("Failed to save nn"))?;
|
||||
let new_nn = new_gen_folder.join(format!(
|
||||
"{:06}_fighter_nn_{}.net",
|
||||
self.id,
|
||||
i + survivor_count
|
||||
));
|
||||
new_fann
|
||||
.save(&new_nn)
|
||||
.with_context(|| "Failed to save nn")?;
|
||||
}
|
||||
|
||||
self.generation += 1;
|
||||
|
@ -263,18 +321,28 @@ impl GeneticNode for FighterNN {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, gemla_context: Self::Context) -> Result<Box<FighterNN>, Error> {
|
||||
async fn merge(
|
||||
left: &FighterNN,
|
||||
right: &FighterNN,
|
||||
id: &Uuid,
|
||||
gemla_context: Self::Context,
|
||||
) -> Result<Box<FighterNN>, Error> {
|
||||
let base_path = PathBuf::from(BASE_DIR);
|
||||
let folder = base_path.join(format!("fighter_nn_{:06}", id));
|
||||
|
||||
// Ensure the folder exists, including the generation subfolder.
|
||||
fs::create_dir_all(&folder.join("0"))
|
||||
fs::create_dir_all(folder.join("0"))
|
||||
.with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?;
|
||||
|
||||
let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> {
|
||||
let mut sorted_scores: Vec<_> = fighter.scores[fighter.generation as usize].iter().collect();
|
||||
let mut sorted_scores: Vec<_> =
|
||||
fighter.scores[fighter.generation as usize].iter().collect();
|
||||
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
||||
sorted_scores.iter().take(fighter.population_size / 2).map(|(k, v)| (**k, **v)).collect()
|
||||
sorted_scores
|
||||
.iter()
|
||||
.take(fighter.population_size / 2)
|
||||
.map(|(k, v)| (**k, **v))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let left_scores = get_highest_scores(left);
|
||||
|
@ -285,18 +353,28 @@ impl GeneticNode for FighterNN {
|
|||
|
||||
let mut simulations = Vec::new();
|
||||
|
||||
for _ in 0..max(left.population_size, right.population_size)*SIMULATION_ROUNDS {
|
||||
for _ in 0..max(left.population_size, right.population_size) * SIMULATION_ROUNDS {
|
||||
let left_nn_id = left_scores[thread_rng().gen_range(0..left_scores.len())].0;
|
||||
let right_nn_id = right_scores[thread_rng().gen_range(0..right_scores.len())].0;
|
||||
|
||||
let left_nn_path = left.folder.join(left.generation.to_string()).join(left.get_individual_id(left_nn_id));
|
||||
let right_nn_path = right.folder.join(right.generation.to_string()).join(right.get_individual_id(right_nn_id));
|
||||
let left_nn_path = left
|
||||
.folder
|
||||
.join(left.generation.to_string())
|
||||
.join(left.get_individual_id(left_nn_id));
|
||||
let right_nn_path = right
|
||||
.folder
|
||||
.join(right.generation.to_string())
|
||||
.join(right.get_individual_id(right_nn_id));
|
||||
let semaphore_clone = gemla_context.shared_semaphore.clone();
|
||||
|
||||
let future = async move {
|
||||
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
|
||||
let permit = semaphore_clone
|
||||
.acquire_owned()
|
||||
.await
|
||||
.with_context(|| "Failed to acquire semaphore permit")?;
|
||||
|
||||
let (left_score, right_score) = run_1v1_simulation(&left_nn_path, &right_nn_path).await?;
|
||||
let (left_score, right_score) =
|
||||
run_1v1_simulation(&left_nn_path, &right_nn_path).await?;
|
||||
|
||||
drop(permit);
|
||||
|
||||
|
@ -306,7 +384,8 @@ impl GeneticNode for FighterNN {
|
|||
simulations.push(future);
|
||||
}
|
||||
|
||||
let results: Result<Vec<(f32, f32)>, Error> = join_all(simulations).await.into_iter().collect();
|
||||
let results: Result<Vec<(f32, f32)>, Error> =
|
||||
join_all(simulations).await.into_iter().collect();
|
||||
let scores = results?;
|
||||
|
||||
let total_left_score = scores.iter().map(|(l, _)| l).sum::<f32>();
|
||||
|
@ -322,18 +401,36 @@ impl GeneticNode for FighterNN {
|
|||
let mut nn_shapes = HashMap::new();
|
||||
|
||||
// Function to copy NNs from a source FighterNN to the new folder.
|
||||
let mut copy_nns = |source: &FighterNN, folder: &PathBuf, id: &Uuid, start_idx: usize| -> Result<(), Error> {
|
||||
let mut sorted_scores: Vec<_> = source.scores[source.generation as usize].iter().collect();
|
||||
let mut copy_nns = |source: &FighterNN,
|
||||
folder: &PathBuf,
|
||||
id: &Uuid,
|
||||
start_idx: usize|
|
||||
-> Result<(), Error> {
|
||||
let mut sorted_scores: Vec<_> =
|
||||
source.scores[source.generation as usize].iter().collect();
|
||||
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
||||
let remaining = sorted_scores[(source.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
||||
let remaining = sorted_scores[(source.population_size / 2)..]
|
||||
.iter()
|
||||
.map(|(k, _)| *k)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (i, nn_id) in remaining.into_iter().enumerate() {
|
||||
let nn_path = source.folder.join(source.generation.to_string()).join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id));
|
||||
let new_nn_path = folder.join("0").join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i));
|
||||
fs::copy(&nn_path, &new_nn_path)
|
||||
.with_context(|| format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path))?;
|
||||
let nn_path = source
|
||||
.folder
|
||||
.join(source.generation.to_string())
|
||||
.join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id));
|
||||
let new_nn_path =
|
||||
folder
|
||||
.join("0")
|
||||
.join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i));
|
||||
fs::copy(&nn_path, &new_nn_path).with_context(|| {
|
||||
format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path)
|
||||
})?;
|
||||
|
||||
nn_shapes.insert((start_idx + i) as u64, source.nn_shapes.get(&nn_id).unwrap().clone());
|
||||
nn_shapes.insert(
|
||||
(start_idx + i) as u64,
|
||||
source.nn_shapes.get(nn_id).unwrap().clone(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -341,32 +438,54 @@ impl GeneticNode for FighterNN {
|
|||
|
||||
// Copy the top half of NNs from each parent to the new folder.
|
||||
copy_nns(left, &folder, id, 0)?;
|
||||
copy_nns(right, &folder, id, left.population_size as usize / 2)?;
|
||||
copy_nns(right, &folder, id, left.population_size / 2)?;
|
||||
|
||||
debug!("nn_shapes: {:?}", nn_shapes);
|
||||
|
||||
// Lerp the mutation rates and weight ranges
|
||||
let crossbreed_segments = (left.crossbreed_segments as f32).lerp(right.crossbreed_segments as f32, lerp_amount) as usize;
|
||||
let crossbreed_segments = (left.crossbreed_segments as f32)
|
||||
.lerp(right.crossbreed_segments as f32, lerp_amount)
|
||||
as usize;
|
||||
|
||||
let weight_initialization_range_start = left.weight_initialization_range.start.lerp(right.weight_initialization_range.start, lerp_amount);
|
||||
let weight_initialization_range_end = left.weight_initialization_range.end.lerp(right.weight_initialization_range.end, lerp_amount);
|
||||
let weight_initialization_range_start = left
|
||||
.weight_initialization_range
|
||||
.start
|
||||
.lerp(right.weight_initialization_range.start, lerp_amount);
|
||||
let weight_initialization_range_end = left
|
||||
.weight_initialization_range
|
||||
.end
|
||||
.lerp(right.weight_initialization_range.end, lerp_amount);
|
||||
// Have to ensure the range is valid
|
||||
let weight_initialization_range = if weight_initialization_range_start < weight_initialization_range_end {
|
||||
weight_initialization_range_start..weight_initialization_range_end
|
||||
} else {
|
||||
weight_initialization_range_end..weight_initialization_range_start
|
||||
};
|
||||
let weight_initialization_range =
|
||||
if weight_initialization_range_start < weight_initialization_range_end {
|
||||
weight_initialization_range_start..weight_initialization_range_end
|
||||
} else {
|
||||
weight_initialization_range_end..weight_initialization_range_start
|
||||
};
|
||||
|
||||
debug!("weight_initialization_range: {:?}", weight_initialization_range);
|
||||
debug!(
|
||||
"weight_initialization_range: {:?}",
|
||||
weight_initialization_range
|
||||
);
|
||||
|
||||
let minor_mutation_rate = left.minor_mutation_rate.lerp(right.minor_mutation_rate, lerp_amount);
|
||||
let major_mutation_rate = left.major_mutation_rate.lerp(right.major_mutation_rate, lerp_amount);
|
||||
let minor_mutation_rate = left
|
||||
.minor_mutation_rate
|
||||
.lerp(right.minor_mutation_rate, lerp_amount);
|
||||
let major_mutation_rate = left
|
||||
.major_mutation_rate
|
||||
.lerp(right.major_mutation_rate, lerp_amount);
|
||||
|
||||
debug!("minor_mutation_rate: {}", minor_mutation_rate);
|
||||
debug!("major_mutation_rate: {}", major_mutation_rate);
|
||||
|
||||
let mutation_weight_range_start = left.mutation_weight_range.start.lerp(right.mutation_weight_range.start, lerp_amount);
|
||||
let mutation_weight_range_end = left.mutation_weight_range.end.lerp(right.mutation_weight_range.end, lerp_amount);
|
||||
let mutation_weight_range_start = left
|
||||
.mutation_weight_range
|
||||
.start
|
||||
.lerp(right.mutation_weight_range.start, lerp_amount);
|
||||
let mutation_weight_range_end = left
|
||||
.mutation_weight_range
|
||||
.end
|
||||
.lerp(right.mutation_weight_range.end, lerp_amount);
|
||||
// Have to ensure the range is valid
|
||||
let mutation_weight_range = if mutation_weight_range_start < mutation_weight_range_end {
|
||||
mutation_weight_range_start..mutation_weight_range_end
|
||||
|
@ -398,7 +517,7 @@ impl FighterNN {
|
|||
}
|
||||
}
|
||||
|
||||
async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<(f32, f32), Error> {
|
||||
async fn run_1v1_simulation(nn_path_1: &Path, nn_path_2: &Path) -> Result<(f32, f32), Error> {
|
||||
// Construct the score file path
|
||||
let base_folder = nn_path_1.parent().unwrap();
|
||||
let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap();
|
||||
|
@ -407,14 +526,18 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
|
|||
|
||||
// Check if score file already exists before running the simulation
|
||||
if score_file.exists() {
|
||||
let round_score = read_score_from_file(&score_file, &nn_1_id).await
|
||||
let round_score = read_score_from_file(&score_file, nn_1_id)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||
|
||||
let opposing_score = read_score_from_file(&score_file, &nn_2_id).await
|
||||
let opposing_score = read_score_from_file(&score_file, nn_2_id)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||
|
||||
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score);
|
||||
|
||||
debug!(
|
||||
"{} scored {}, while {} scored {}",
|
||||
nn_1_id, round_score, nn_2_id, opposing_score
|
||||
);
|
||||
|
||||
return Ok((round_score, opposing_score));
|
||||
}
|
||||
|
@ -422,13 +545,22 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
|
|||
// Check if the opposite round score has been determined
|
||||
let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id));
|
||||
if opposite_score_file.exists() {
|
||||
let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
||||
let round_score = read_score_from_file(&opposite_score_file, nn_1_id)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("Failed to read score from file: {:?}", opposite_score_file)
|
||||
})?;
|
||||
|
||||
let opposing_score = read_score_from_file(&opposite_score_file, &nn_2_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
||||
let opposing_score = read_score_from_file(&opposite_score_file, nn_2_id)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("Failed to read score from file: {:?}", opposite_score_file)
|
||||
})?;
|
||||
|
||||
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score);
|
||||
debug!(
|
||||
"{} scored {}, while {} scored {}",
|
||||
nn_1_id, round_score, nn_2_id, opposing_score
|
||||
);
|
||||
|
||||
return Ok((round_score, opposing_score));
|
||||
}
|
||||
|
@ -459,20 +591,29 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
|
|||
.expect("Failed to execute game")
|
||||
};
|
||||
|
||||
trace!("Simulation completed for {} vs {}: {}", nn_1_id, nn_2_id, score_file.exists());
|
||||
trace!(
|
||||
"Simulation completed for {} vs {}: {}",
|
||||
nn_1_id,
|
||||
nn_2_id,
|
||||
score_file.exists()
|
||||
);
|
||||
|
||||
// Read the score from the file
|
||||
if score_file.exists() {
|
||||
let round_score = read_score_from_file(&score_file, &nn_1_id).await
|
||||
let round_score = read_score_from_file(&score_file, nn_1_id)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||
|
||||
let opposing_score = read_score_from_file(&score_file, &nn_2_id).await
|
||||
let opposing_score = read_score_from_file(&score_file, nn_2_id)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||
|
||||
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score);
|
||||
debug!(
|
||||
"{} scored {}, while {} scored {}",
|
||||
nn_1_id, round_score, nn_2_id, opposing_score
|
||||
);
|
||||
|
||||
|
||||
return Ok((round_score, opposing_score))
|
||||
Ok((round_score, opposing_score))
|
||||
} else {
|
||||
warn!("Score file not found: {:?}", score_file);
|
||||
Ok((0.0, 0.0))
|
||||
|
@ -492,7 +633,10 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
|
|||
if line.starts_with(nn_id) {
|
||||
let parts: Vec<&str> = line.split(':').collect();
|
||||
if parts.len() == 2 {
|
||||
return parts[1].trim().parse::<f32>().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
|
||||
return parts[1]
|
||||
.trim()
|
||||
.parse::<f32>()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -501,16 +645,21 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
|
|||
io::ErrorKind::NotFound,
|
||||
"NN ID not found in scores file",
|
||||
));
|
||||
},
|
||||
Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::PermissionDenied || e.kind() == io::ErrorKind::Other => {
|
||||
if attempts >= 5 { // Attempt 5 times before giving up.
|
||||
}
|
||||
Err(e)
|
||||
if e.kind() == io::ErrorKind::WouldBlock
|
||||
|| e.kind() == io::ErrorKind::PermissionDenied
|
||||
|| e.kind() == io::ErrorKind::Other =>
|
||||
{
|
||||
if attempts >= 5 {
|
||||
// Attempt 5 times before giving up.
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
// wait 1 second to ensure the file is written
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
},
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,8 +1,11 @@
|
|||
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
|
||||
use async_trait::async_trait;
|
||||
use gemla::{
|
||||
core::genetic_node::{GeneticNode, GeneticNodeContext},
|
||||
error::Error,
|
||||
};
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
use async_trait::async_trait;
|
||||
|
||||
const POPULATION_SIZE: u64 = 5;
|
||||
const POPULATION_REDUCTION_SIZE: u64 = 3;
|
||||
|
@ -76,7 +79,12 @@ impl GeneticNode for TestState {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(
|
||||
left: &TestState,
|
||||
right: &TestState,
|
||||
id: &Uuid,
|
||||
gemla_context: Self::Context,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
let mut v = left.population.clone();
|
||||
v.append(&mut right.population.clone());
|
||||
|
||||
|
@ -87,12 +95,14 @@ impl GeneticNode for TestState {
|
|||
|
||||
let mut result = TestState { population: v };
|
||||
|
||||
result.mutate(GeneticNodeContext {
|
||||
id: id.clone(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context: gemla_context
|
||||
}).await?;
|
||||
result
|
||||
.mutate(GeneticNodeContext {
|
||||
id: *id,
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(result))
|
||||
}
|
||||
|
@ -105,14 +115,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_initialize() {
|
||||
let state = TestState::initialize(
|
||||
GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
let state = TestState::initialize(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context: (),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
@ -125,35 +135,38 @@ mod tests {
|
|||
|
||||
let original_population = state.population.clone();
|
||||
|
||||
state.simulate(
|
||||
GeneticNodeContext {
|
||||
state
|
||||
.simulate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(original_population
|
||||
.iter()
|
||||
.zip(state.population.iter())
|
||||
.all(|(&a, &b)| b >= a - 1 && b <= a + 2));
|
||||
|
||||
state.simulate(
|
||||
GeneticNodeContext {
|
||||
state
|
||||
.simulate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
state.simulate(
|
||||
GeneticNodeContext {
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
state
|
||||
.simulate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(original_population
|
||||
.iter()
|
||||
.zip(state.population.iter())
|
||||
|
@ -166,14 +179,15 @@ mod tests {
|
|||
population: vec![4, 3, 3],
|
||||
};
|
||||
|
||||
state.mutate(
|
||||
GeneticNodeContext {
|
||||
state
|
||||
.mutate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
@ -188,7 +202,9 @@ mod tests {
|
|||
population: vec![0, 1, 3, 7],
|
||||
};
|
||||
|
||||
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap();
|
||||
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
|
||||
assert!(merged_state.population.iter().any(|&x| x == 7));
|
||||
|
|
|
@ -5,10 +5,10 @@
|
|||
use crate::error::Error;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use uuid::Uuid;
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// An enum used to control the state of a [`GeneticNode`]
|
||||
///
|
||||
|
@ -30,14 +30,14 @@ pub struct GeneticNodeContext<S> {
|
|||
pub generation: u64,
|
||||
pub max_generations: u64,
|
||||
pub id: Uuid,
|
||||
pub gemla_context: S
|
||||
pub gemla_context: S,
|
||||
}
|
||||
|
||||
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||
///
|
||||
/// [`Bracket`]: crate::bracket::Bracket
|
||||
#[async_trait]
|
||||
pub trait GeneticNode : Send {
|
||||
pub trait GeneticNode: Send {
|
||||
type Context;
|
||||
|
||||
/// Initializes a new instance of a [`GeneticState`].
|
||||
|
@ -54,7 +54,12 @@ pub trait GeneticNode : Send {
|
|||
/// TODO
|
||||
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
|
||||
|
||||
async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
|
||||
async fn merge(
|
||||
left: &Self,
|
||||
right: &Self,
|
||||
id: &Uuid,
|
||||
context: Self::Context,
|
||||
) -> Result<Box<Self>, Error>;
|
||||
}
|
||||
|
||||
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
||||
|
@ -62,7 +67,7 @@ pub trait GeneticNode : Send {
|
|||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub struct GeneticNodeWrapper<T>
|
||||
where
|
||||
T: Clone
|
||||
T: Clone,
|
||||
{
|
||||
node: Option<T>,
|
||||
state: GeneticState,
|
||||
|
@ -73,7 +78,7 @@ where
|
|||
|
||||
impl<T> Default for GeneticNodeWrapper<T>
|
||||
where
|
||||
T: Clone
|
||||
T: Clone,
|
||||
{
|
||||
fn default() -> Self {
|
||||
GeneticNodeWrapper {
|
||||
|
@ -146,7 +151,8 @@ where
|
|||
self.state = GeneticState::Simulate;
|
||||
}
|
||||
(GeneticState::Simulate, Some(n)) => {
|
||||
n.simulate(context.clone()).await
|
||||
n.simulate(context.clone())
|
||||
.await
|
||||
.with_context(|| format!("Error simulating node: {:?}", self))?;
|
||||
|
||||
self.state = if self.generation >= self.max_generations {
|
||||
|
@ -156,7 +162,8 @@ where
|
|||
};
|
||||
}
|
||||
(GeneticState::Mutate, Some(n)) => {
|
||||
n.mutate(context.clone()).await
|
||||
n.mutate(context.clone())
|
||||
.await
|
||||
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
||||
|
||||
self.generation += 1;
|
||||
|
@ -186,20 +193,33 @@ mod tests {
|
|||
impl GeneticNode for TestState {
|
||||
type Context = ();
|
||||
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
async fn simulate(
|
||||
&mut self,
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<(), Error> {
|
||||
self.score += 1.0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
async fn mutate(
|
||||
&mut self,
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
|
||||
async fn initialize(
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(
|
||||
_l: &TestState,
|
||||
_r: &TestState,
|
||||
_id: &Uuid,
|
||||
_: Self::Context,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Err(Error::Other(anyhow!("Unable to merge")))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,10 +10,11 @@ use futures::future;
|
|||
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
|
||||
use log::{info, trace, warn};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::{sync::RwLock, task::JoinHandle};
|
||||
use std::{
|
||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
|
||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
|
||||
sync::Arc, time::Instant,
|
||||
};
|
||||
use tokio::{sync::RwLock, task::JoinHandle};
|
||||
use uuid::Uuid;
|
||||
|
||||
type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
|
||||
|
@ -80,13 +81,18 @@ where
|
|||
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
|
||||
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||
{
|
||||
pub async fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
|
||||
pub async fn new(
|
||||
path: &Path,
|
||||
config: GemlaConfig,
|
||||
data_format: DataFormat,
|
||||
) -> Result<Self, Error> {
|
||||
match File::open(path) {
|
||||
// If the file exists we either want to overwrite the file or read from the file
|
||||
// based on the configuration provided
|
||||
Ok(_) => Ok(Gemla {
|
||||
data: if config.overwrite {
|
||||
FileLinked::new((None, config, T::Context::default()), path, data_format).await?
|
||||
FileLinked::new((None, config, T::Context::default()), path, data_format)
|
||||
.await?
|
||||
} else {
|
||||
FileLinked::from_file(path, data_format)?
|
||||
},
|
||||
|
@ -94,7 +100,8 @@ where
|
|||
}),
|
||||
// If the file doesn't exist we must create it
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
|
||||
data: FileLinked::new((None, config, T::Context::default()), path, data_format).await?,
|
||||
data: FileLinked::new((None, config, T::Context::default()), path, data_format)
|
||||
.await?,
|
||||
threads: HashMap::new(),
|
||||
}),
|
||||
Err(error) => Err(Error::IO(error)),
|
||||
|
@ -106,24 +113,32 @@ where
|
|||
}
|
||||
|
||||
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
|
||||
{
|
||||
let tree_completed = {
|
||||
// Only increase height if the tree is uninitialized or completed
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref = data_arc.read().await;
|
||||
let tree_ref = data_ref.0.as_ref();
|
||||
|
||||
if tree_ref.is_none() ||
|
||||
tree_ref
|
||||
.map(|t| Gemla::is_completed(t))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
||||
// in the tree and which nodes have not.
|
||||
self.data.mutate(|(d, c, _)| {
|
||||
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
|
||||
tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
|
||||
};
|
||||
|
||||
if tree_completed {
|
||||
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
||||
// in the tree and which nodes have not.
|
||||
self.data
|
||||
.mutate(|(d, c, _)| {
|
||||
let mut tree: Option<SimulationTree<T>> =
|
||||
Gemla::increase_height(d.take(), c, steps);
|
||||
mem::swap(d, &mut tree);
|
||||
}).await?;
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
{
|
||||
// Only increase height if the tree is uninitialized or completed
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref = data_arc.read().await;
|
||||
let tree_ref = data_ref.0.as_ref();
|
||||
|
||||
info!(
|
||||
"Height of simulation tree increased to {}",
|
||||
|
@ -141,36 +156,36 @@ where
|
|||
let data_ref = data_arc.read().await;
|
||||
let tree_ref = data_ref.0.as_ref();
|
||||
|
||||
is_tree_processed = tree_ref
|
||||
.map(|t| Gemla::is_completed(t))
|
||||
.unwrap_or(false)
|
||||
is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
|
||||
}
|
||||
|
||||
|
||||
// We need to keep simulating until the tree has been completely processed.
|
||||
if is_tree_processed
|
||||
{
|
||||
|
||||
if is_tree_processed {
|
||||
self.join_threads().await?;
|
||||
|
||||
info!("Processed tree");
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(node) = tree_ref
|
||||
.and_then(|t| self.get_unprocessed_node(t))
|
||||
{
|
||||
let (node, gemla_context) = {
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref = data_arc.read().await;
|
||||
let (tree_ref, _, gemla_context) = &*data_ref; // (Option<Box<Tree<GeneticNodeWrapper<T>>>, GemlaConfig, T::Context)
|
||||
|
||||
let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
|
||||
|
||||
(node, gemla_context.clone())
|
||||
};
|
||||
|
||||
if let Some(node) = node {
|
||||
trace!("Adding node to process list {}", node.id());
|
||||
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref2 = data_arc.read().await;
|
||||
let gemla_context = data_ref2.2.clone();
|
||||
drop(data_ref2);
|
||||
let gemla_context = gemla_context.clone();
|
||||
|
||||
self.threads
|
||||
.insert(node.id(), tokio::spawn(async move {
|
||||
Gemla::process_node(node, gemla_context).await
|
||||
}));
|
||||
self.threads.insert(
|
||||
node.id(),
|
||||
tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
|
||||
);
|
||||
} else {
|
||||
trace!("No node found to process, joining threads");
|
||||
|
||||
|
@ -195,32 +210,34 @@ where
|
|||
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
|
||||
match reduced_results {
|
||||
Ok(r) => {
|
||||
self.data.mutate_async(|d| async move {
|
||||
// Scope to limit the duration of the read lock
|
||||
let (_, context) = {
|
||||
let data_read = d.read().await;
|
||||
(data_read.1.clone(), data_read.2.clone())
|
||||
}; // Read lock is dropped here
|
||||
self.data
|
||||
.mutate_async(|d| async move {
|
||||
// Scope to limit the duration of the read lock
|
||||
let (_, context) = {
|
||||
let data_read = d.read().await;
|
||||
(data_read.1, data_read.2.clone())
|
||||
}; // Read lock is dropped here
|
||||
|
||||
let mut data_write = d.write().await;
|
||||
let mut data_write = d.write().await;
|
||||
|
||||
if let Some(t) = data_write.0.as_mut() {
|
||||
let failed_nodes = Gemla::replace_nodes(t, r);
|
||||
// We receive a list of nodes that were unable to be found in the original tree
|
||||
if !failed_nodes.is_empty() {
|
||||
warn!(
|
||||
"Unable to find {:?} to replace in tree",
|
||||
failed_nodes.iter().map(|n| n.id())
|
||||
)
|
||||
if let Some(t) = data_write.0.as_mut() {
|
||||
let failed_nodes = Gemla::replace_nodes(t, r);
|
||||
// We receive a list of nodes that were unable to be found in the original tree
|
||||
if !failed_nodes.is_empty() {
|
||||
warn!(
|
||||
"Unable to find {:?} to replace in tree",
|
||||
failed_nodes.iter().map(|n| n.id())
|
||||
)
|
||||
}
|
||||
|
||||
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
|
||||
Gemla::merge_completed_nodes(t, context.clone()).await
|
||||
} else {
|
||||
warn!("Unable to replce nodes {:?} in empty tree", r);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
|
||||
Gemla::merge_completed_nodes(t, context.clone()).await
|
||||
} else {
|
||||
warn!("Unable to replce nodes {:?} in empty tree", r);
|
||||
Ok(())
|
||||
}
|
||||
}).await??;
|
||||
})
|
||||
.await??;
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
@ -230,7 +247,10 @@ where
|
|||
}
|
||||
|
||||
#[async_recursion]
|
||||
async fn merge_completed_nodes<'a>(tree: &'a mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> {
|
||||
async fn merge_completed_nodes<'a>(
|
||||
tree: &'a mut SimulationTree<T>,
|
||||
gemla_context: T::Context,
|
||||
) -> Result<(), Error> {
|
||||
if tree.val.state() == GeneticState::Initialize {
|
||||
match (&mut tree.left, &mut tree.right) {
|
||||
// If the current node has been initialized, and has children nodes that are completed, then we need
|
||||
|
@ -241,7 +261,13 @@ where
|
|||
{
|
||||
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
|
||||
if let (Some(left_node), Some(right_node)) = (l.val.take(), r.val.take()) {
|
||||
let merged_node = GeneticNode::merge(&left_node, &right_node, &tree.val.id(), gemla_context.clone()).await?;
|
||||
let merged_node = GeneticNode::merge(
|
||||
&left_node,
|
||||
&right_node,
|
||||
&tree.val.id(),
|
||||
gemla_context.clone(),
|
||||
)
|
||||
.await?;
|
||||
tree.val = GeneticNodeWrapper::from(
|
||||
*merged_node,
|
||||
tree.val.max_generations(),
|
||||
|
@ -294,7 +320,10 @@ where
|
|||
// during join_threads.
|
||||
(Some(l), Some(r))
|
||||
if l.val.state() == GeneticState::Finish
|
||||
&& r.val.state() == GeneticState::Finish => Some(tree.val.clone()),
|
||||
&& r.val.state() == GeneticState::Finish =>
|
||||
{
|
||||
Some(tree.val.clone())
|
||||
}
|
||||
(Some(l), Some(r)) => self
|
||||
.get_unprocessed_node(l)
|
||||
.or_else(|| self.get_unprocessed_node(r)),
|
||||
|
@ -355,7 +384,10 @@ where
|
|||
tree.val.state() == GeneticState::Finish
|
||||
}
|
||||
|
||||
async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
async fn process_node(
|
||||
mut node: GeneticNodeWrapper<T>,
|
||||
gemla_context: T::Context,
|
||||
) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
let node_state_time = Instant::now();
|
||||
let node_state = node.state();
|
||||
|
||||
|
@ -379,10 +411,10 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::core::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::fs;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use self::genetic_node::GeneticNodeContext;
|
||||
|
@ -420,20 +452,33 @@ mod tests {
|
|||
impl genetic_node::GeneticNode for TestState {
|
||||
type Context = ();
|
||||
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
async fn simulate(
|
||||
&mut self,
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<(), Error> {
|
||||
self.score += 1.0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
async fn mutate(
|
||||
&mut self,
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
|
||||
async fn initialize(
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(
|
||||
left: &TestState,
|
||||
right: &TestState,
|
||||
_id: &Uuid,
|
||||
_: Self::Context,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(if left.score > right.score {
|
||||
left.clone()
|
||||
} else {
|
||||
|
@ -498,40 +543,43 @@ mod tests {
|
|||
Ok(())
|
||||
})
|
||||
})
|
||||
}).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
|
||||
})
|
||||
.await
|
||||
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn test_simulate() -> Result<(), Error> {
|
||||
// let path = PathBuf::from("test_simulate");
|
||||
// // Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
|
||||
// tokio::task::spawn_blocking(move || {
|
||||
// let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
|
||||
// CleanUp::new(&path).run(move |p| {
|
||||
// rt.block_on(async {
|
||||
// // Testing initial creation
|
||||
// let config = GemlaConfig {
|
||||
// generations_per_height: 10,
|
||||
// overwrite: true,
|
||||
// };
|
||||
// let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||
#[tokio::test]
|
||||
async fn test_simulate() -> Result<(), Error> {
|
||||
let path = PathBuf::from("test_simulate");
|
||||
// Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
|
||||
CleanUp::new(&path).run(move |p| {
|
||||
rt.block_on(async {
|
||||
// Testing initial creation
|
||||
let config = GemlaConfig {
|
||||
generations_per_height: 10,
|
||||
overwrite: true,
|
||||
};
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
|
||||
|
||||
// // Now we can use `.await` within the spawned blocking task.
|
||||
// gemla.simulate(5).await?;
|
||||
// let data = gemla.data.readonly();
|
||||
// let data_lock = data.read().unwrap();
|
||||
// let tree = data_lock.0.as_ref().unwrap();
|
||||
// assert_eq!(tree.height(), 5);
|
||||
// assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
|
||||
// Now we can use `.await` within the spawned blocking task.
|
||||
gemla.simulate(5).await?;
|
||||
let data = gemla.data.readonly();
|
||||
let data_lock = data.read().await;
|
||||
let tree = data_lock.0.as_ref().unwrap();
|
||||
assert_eq!(tree.height(), 5);
|
||||
assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
|
||||
|
||||
// Ok(())
|
||||
// })
|
||||
// })
|
||||
// }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue