Finalizing async implementation

This commit is contained in:
vandomej 2024-04-06 00:07:10 -07:00
parent b56e37d411
commit 7a1f82ac63
8 changed files with 1439 additions and 626 deletions

View file

@ -6,12 +6,11 @@ pub mod constants;
use anyhow::{anyhow, Context};
use constants::data_format::DataFormat;
use error::Error;
use futures::executor::block_on;
use log::info;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock;
use std::{
borrow::Borrow, fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle}
fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, sync::Arc, thread::{self, JoinHandle}
};
@ -56,6 +55,7 @@ where
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -64,19 +64,22 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() {
/// # #[tokio::main]
/// # async fn main() {
/// let test = Test {
/// a: 1,
/// b: String::from("two"),
/// c: 3.0
/// };
///
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json)
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// assert_eq!(linked_test.readonly().b, String::from("two"));
/// assert_eq!(linked_test.readonly().c, 3.0);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// assert_eq!(readonly_ref.b, String::from("two"));
/// assert_eq!(readonly_ref.c, 3.0);
/// #
/// # drop(linked_test);
/// #
@ -97,6 +100,7 @@ where
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// #[derive(Deserialize, Serialize)]
/// struct Test {
@ -105,19 +109,22 @@ where
/// pub c: f64
/// }
///
/// # fn main() {
/// #[tokio::main]
/// # async fn main() {
/// let test = Test {
/// a: 1,
/// b: String::from("two"),
/// c: 3.0
/// };
///
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json)
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// assert_eq!(linked_test.readonly().b, String::from("two"));
/// assert_eq!(linked_test.readonly().c, 3.0);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// assert_eq!(readonly_ref.b, String::from("two"));
/// assert_eq!(readonly_ref.c, 3.0);
/// #
/// # drop(linked_test);
/// #
@ -207,6 +214,7 @@ where
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -215,21 +223,28 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() -> Result<(), Error> {
/// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test {
/// a: 1,
/// b: String::from(""),
/// c: 0.0
/// };
///
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode)
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// {
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// }
///
/// linked_test.mutate(|t| t.a = 2)?;
/// linked_test.mutate(|t| t.a = 2).await?;
///
/// assert_eq!(linked_test.readonly().a, 2);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 2);
/// #
/// # drop(linked_test);
/// #
@ -262,6 +277,7 @@ where
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -270,25 +286,30 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() -> Result<(), Error> {
/// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test {
/// a: 1,
/// b: String::from(""),
/// c: 0.0
/// };
///
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode)
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
///
/// linked_test.replace(Test {
/// a: 2,
/// b: String::from(""),
/// c: 0.0
/// })?;
/// }).await?;
///
/// assert_eq!(linked_test.readonly().a, 2);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 2);
/// #
/// # drop(linked_test);
/// #
@ -343,6 +364,7 @@ where
/// # use std::fs::OpenOptions;
/// # use std::io::Write;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -351,7 +373,8 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() -> Result<(), Error> {
/// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test {
/// a: 1,
/// b: String::from("2"),
@ -371,9 +394,11 @@ where
/// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode)
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, test.a);
/// assert_eq!(linked_test.readonly().b, test.b);
/// assert_eq!(linked_test.readonly().c, test.c);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, test.a);
/// assert_eq!(readonly_ref.b, test.b);
/// assert_eq!(readonly_ref.c, test.c);
/// #
/// # drop(linked_test);
/// #

View file

@ -3,18 +3,18 @@ extern crate gemla;
#[macro_use]
extern crate log;
mod test_state;
mod fighter_nn;
mod test_state;
use anyhow::Result;
use clap::Parser;
use fighter_nn::FighterNN;
use file_linked::constants::data_format::DataFormat;
use gemla::{
core::{Gemla, GemlaConfig},
error::log_error,
};
use std::{path::PathBuf, time::Instant};
use fighter_nn::FighterNN;
use clap::Parser;
use anyhow::Result;
// const NUM_THREADS: usize = 2;
@ -44,14 +44,17 @@ fn main() -> Result<()> {
.build()?
.block_on(async {
let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
let mut gemla = log_error(Gemla::<FighterNN>::new(
&PathBuf::from(args.file),
GemlaConfig {
generations_per_height: 5,
overwrite: false,
},
DataFormat::Json,
).await)?;
let mut gemla = log_error(
Gemla::<FighterNN>::new(
&PathBuf::from(args.file),
GemlaConfig {
generations_per_height: 5,
overwrite: false,
},
DataFormat::Json,
)
.await,
)?;
// let gemla_arc = Arc::new(gemla);
@ -59,7 +62,8 @@ fn main() -> Result<()> {
// If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
// Example placeholder loop to continuously run simulate
loop { // Arbitrary loop count for demonstration
loop {
// Arbitrary loop count for demonstration
gemla.simulate(1).await?;
}
});

View file

@ -5,7 +5,6 @@ use tokio::sync::Semaphore;
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50;
#[derive(Debug, Clone)]
pub struct FighterContext {
pub shared_semaphore: Arc<Semaphore>,
@ -19,7 +18,6 @@ impl Default for FighterContext {
}
}
// Custom serialization to just output the concurrency limit.
impl Serialize for FighterContext {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>

View file

@ -1,20 +1,29 @@
extern crate fann;
pub mod neural_network_utility;
pub mod fighter_context;
pub mod neural_network_utility;
use std::{cmp::max, collections::{HashSet, VecDeque}, fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, panic::{catch_unwind, AssertUnwindSafe}, path::{Path, PathBuf}, sync::{Arc, Mutex}, time::Duration};
use anyhow::Context;
use async_trait::async_trait;
use fann::{ActivationFunc, Fann};
use futures::{executor::block_on, future::{join, join_all, select_all}, stream::FuturesUnordered, FutureExt, StreamExt};
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
use futures::future::join_all;
use gemla::{
core::genetic_node::{GeneticNode, GeneticNodeContext},
error::Error,
};
use lerp::Lerp;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use anyhow::Context;
use tokio::{process::Command, sync::{mpsc, Semaphore}, task, time::{sleep, timeout, Sleep}};
use uuid::Uuid;
use std::collections::HashMap;
use async_trait::async_trait;
use std::{
cmp::max,
fs::{self, File},
io::{self, BufRead, BufReader},
ops::Range,
path::{Path, PathBuf},
};
use tokio::process::Command;
use uuid::Uuid;
use self::neural_network_utility::{crossbreed, major_mutation};
@ -34,7 +43,8 @@ const NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX: usize = 20;
const SIMULATION_ROUNDS: usize = 5;
const SURVIVAL_RATE: f32 = 0.5;
const GAME_EXECUTABLE_PATH: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";
const GAME_EXECUTABLE_PATH: &str =
"F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";
// Here is the folder structure for the FighterNN:
// base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net
@ -78,11 +88,17 @@ impl GeneticNode for FighterNN {
//Create a new directory for the first generation, using create_dir_all to avoid errors if it already exists
let gen_folder = folder.join("0");
fs::create_dir_all(&gen_folder)
.with_context(|| format!("Failed to create or access the generation folder: {:?}", gen_folder))?;
fs::create_dir_all(&gen_folder).with_context(|| {
format!(
"Failed to create or access the generation folder: {:?}",
gen_folder
)
})?;
let mut nn_shapes = HashMap::new();
let weight_initialization_range = thread_rng().gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0)..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX);
let weight_initialization_range = thread_rng()
.gen_range(NEURAL_NETWORK_INITIAL_WEIGHT_MIN..0.0)
..thread_rng().gen_range(0.0..=NEURAL_NETWORK_INITIAL_WEIGHT_MAX);
// Create the first generation in this folder
for i in 0..POPULATION {
@ -90,17 +106,22 @@ impl GeneticNode for FighterNN {
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
// Randomly generate a neural network shape based on constants
let hidden_layers = thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX);
let hidden_layers = thread_rng()
.gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..NEURAL_NETWORK_HIDDEN_LAYERS_MAX);
let mut nn_shape = vec![NEURAL_NETWORK_INPUTS as u32];
for _ in 0..hidden_layers {
nn_shape.push(thread_rng().gen_range(NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX) as u32);
nn_shape.push(thread_rng().gen_range(
NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX,
) as u32);
}
nn_shape.push(NEURAL_NETWORK_OUTPUTS as u32);
nn_shapes.insert(i as u64, nn_shape.clone());
let mut fann = Fann::new(nn_shape.as_slice())
.with_context(|| "Failed to create nn")?;
fann.randomize_weights(weight_initialization_range.start, weight_initialization_range.end);
let mut fann = Fann::new(nn_shape.as_slice()).with_context(|| "Failed to create nn")?;
fann.randomize_weights(
weight_initialization_range.start,
weight_initialization_range.end,
);
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
// This will overwrite any existing file with the same name
@ -108,7 +129,9 @@ impl GeneticNode for FighterNN {
.with_context(|| format!("Failed to save nn at {:?}", nn))?;
}
let mut crossbreed_segments = thread_rng().gen_range(NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX);
let mut crossbreed_segments = thread_rng().gen_range(
NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX,
);
if crossbreed_segments % 2 == 0 {
crossbreed_segments += 1;
}
@ -141,7 +164,10 @@ impl GeneticNode for FighterNN {
let semaphore_clone = context.gemla_context.shared_semaphore.clone();
let task = async move {
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64));
let nn = self_clone
.folder
.join(format!("{}", self_clone.generation))
.join(self_clone.get_individual_id(i as u64));
let mut simulations = Vec::new();
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
@ -151,11 +177,16 @@ impl GeneticNode for FighterNN {
let generation = self_clone.generation;
let semaphore_clone = semaphore_clone.clone();
let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64));
let random_nn = folder
.join(format!("{}", generation))
.join(self_clone.get_individual_id(random_nn_index as u64));
let nn_clone = nn.clone(); // Clone the path to use in the async block
let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
let permit = semaphore_clone
.acquire_owned()
.await
.with_context(|| "Failed to acquire semaphore permit")?;
let (score, _) = run_1v1_simulation(&nn_clone, &random_nn).await?;
@ -168,7 +199,8 @@ impl GeneticNode for FighterNN {
}
// Wait for all simulation rounds to complete
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
let results: Result<Vec<f32>, Error> =
join_all(simulations).await.into_iter().collect();
let score = match results {
Ok(scores) => scores.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32,
@ -188,34 +220,46 @@ impl GeneticNode for FighterNN {
Ok((index, score)) => {
// Update the original `self` object with the score.
self.scores[self.generation as usize].insert(index as u64, score);
},
}
Err(e) => {
// Handle task panic or execution error
return Err(Error::Other(anyhow::anyhow!(format!("Task failed: {:?}", e))));
},
return Err(Error::Other(anyhow::anyhow!(format!(
"Task failed: {:?}",
e
))));
}
}
}
Ok(())
}
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
// Create the new generation folder
let new_gen_folder = self.folder.join(format!("{}", self.generation + 1));
fs::create_dir_all(&new_gen_folder).with_context(|| format!("Failed to create or access new generation folder: {:?}", new_gen_folder))?;
fs::create_dir_all(&new_gen_folder).with_context(|| {
format!(
"Failed to create or access new generation folder: {:?}",
new_gen_folder
)
})?;
// Remove the 5 nn's with the lowest scores
let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
let to_keep = sorted_scores[survivor_count..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
let to_keep = sorted_scores[survivor_count..]
.iter()
.map(|(k, _)| *k)
.collect::<Vec<_>>();
// Save the remaining 5 nn's to the new generation folder
for i in 0..survivor_count {
let nn_id = to_keep[i];
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
for (i, nn_id) in to_keep.iter().enumerate().take(survivor_count) {
let nn = self
.folder
.join(format!("{}", self.generation))
.join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i));
fs::copy(&nn, &new_nn)?;
}
@ -223,16 +267,25 @@ impl GeneticNode for FighterNN {
// Take the remaining 5 nn's and create 5 new nn's by the following:
for i in 0..survivor_count {
let nn_id = to_keep[i];
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
let fann = Fann::from_file(&nn)
.with_context(|| format!("Failed to load nn"))?;
let nn = self
.folder
.join(format!("{}", self.generation))
.join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id));
let fann = Fann::from_file(&nn).with_context(|| "Failed to load nn")?;
// Load another nn from the current generation and cross breed it with the current nn
let cross_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, to_keep[thread_rng().gen_range(0..survivor_count)]));
let cross_fann = Fann::from_file(&cross_nn)
.with_context(|| format!("Failed to load cross nn"))?;
let cross_nn = self
.folder
.join(format!("{}", self.generation))
.join(format!(
"{:06}_fighter_nn_{}.net",
self.id,
to_keep[thread_rng().gen_range(0..survivor_count)]
));
let cross_fann =
Fann::from_file(&cross_nn).with_context(|| "Failed to load cross nn")?;
let mut new_fann = crossbreed(&self, &fann, &cross_fann, self.crossbreed_segments)?;
let mut new_fann = crossbreed(self, &fann, &cross_fann, self.crossbreed_segments)?;
// For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight)
// And a 5% chance of a major mutation a new neuron is randomly added to a hidden layer
@ -252,9 +305,14 @@ impl GeneticNode for FighterNN {
}
// Save the new nn's to the new generation folder
let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i + survivor_count));
new_fann.save(&new_nn)
.with_context(|| format!("Failed to save nn"))?;
let new_nn = new_gen_folder.join(format!(
"{:06}_fighter_nn_{}.net",
self.id,
i + survivor_count
));
new_fann
.save(&new_nn)
.with_context(|| "Failed to save nn")?;
}
self.generation += 1;
@ -263,18 +321,28 @@ impl GeneticNode for FighterNN {
Ok(())
}
async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, gemla_context: Self::Context) -> Result<Box<FighterNN>, Error> {
async fn merge(
left: &FighterNN,
right: &FighterNN,
id: &Uuid,
gemla_context: Self::Context,
) -> Result<Box<FighterNN>, Error> {
let base_path = PathBuf::from(BASE_DIR);
let folder = base_path.join(format!("fighter_nn_{:06}", id));
// Ensure the folder exists, including the generation subfolder.
fs::create_dir_all(&folder.join("0"))
fs::create_dir_all(folder.join("0"))
.with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?;
let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> {
let mut sorted_scores: Vec<_> = fighter.scores[fighter.generation as usize].iter().collect();
let mut sorted_scores: Vec<_> =
fighter.scores[fighter.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
sorted_scores.iter().take(fighter.population_size / 2).map(|(k, v)| (**k, **v)).collect()
sorted_scores
.iter()
.take(fighter.population_size / 2)
.map(|(k, v)| (**k, **v))
.collect()
};
let left_scores = get_highest_scores(left);
@ -285,18 +353,28 @@ impl GeneticNode for FighterNN {
let mut simulations = Vec::new();
for _ in 0..max(left.population_size, right.population_size)*SIMULATION_ROUNDS {
for _ in 0..max(left.population_size, right.population_size) * SIMULATION_ROUNDS {
let left_nn_id = left_scores[thread_rng().gen_range(0..left_scores.len())].0;
let right_nn_id = right_scores[thread_rng().gen_range(0..right_scores.len())].0;
let left_nn_path = left.folder.join(left.generation.to_string()).join(left.get_individual_id(left_nn_id));
let right_nn_path = right.folder.join(right.generation.to_string()).join(right.get_individual_id(right_nn_id));
let left_nn_path = left
.folder
.join(left.generation.to_string())
.join(left.get_individual_id(left_nn_id));
let right_nn_path = right
.folder
.join(right.generation.to_string())
.join(right.get_individual_id(right_nn_id));
let semaphore_clone = gemla_context.shared_semaphore.clone();
let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
let permit = semaphore_clone
.acquire_owned()
.await
.with_context(|| "Failed to acquire semaphore permit")?;
let (left_score, right_score) = run_1v1_simulation(&left_nn_path, &right_nn_path).await?;
let (left_score, right_score) =
run_1v1_simulation(&left_nn_path, &right_nn_path).await?;
drop(permit);
@ -306,7 +384,8 @@ impl GeneticNode for FighterNN {
simulations.push(future);
}
let results: Result<Vec<(f32, f32)>, Error> = join_all(simulations).await.into_iter().collect();
let results: Result<Vec<(f32, f32)>, Error> =
join_all(simulations).await.into_iter().collect();
let scores = results?;
let total_left_score = scores.iter().map(|(l, _)| l).sum::<f32>();
@ -322,18 +401,36 @@ impl GeneticNode for FighterNN {
let mut nn_shapes = HashMap::new();
// Function to copy NNs from a source FighterNN to the new folder.
let mut copy_nns = |source: &FighterNN, folder: &PathBuf, id: &Uuid, start_idx: usize| -> Result<(), Error> {
let mut sorted_scores: Vec<_> = source.scores[source.generation as usize].iter().collect();
let mut copy_nns = |source: &FighterNN,
folder: &PathBuf,
id: &Uuid,
start_idx: usize|
-> Result<(), Error> {
let mut sorted_scores: Vec<_> =
source.scores[source.generation as usize].iter().collect();
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
let remaining = sorted_scores[(source.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
let remaining = sorted_scores[(source.population_size / 2)..]
.iter()
.map(|(k, _)| *k)
.collect::<Vec<_>>();
for (i, nn_id) in remaining.into_iter().enumerate() {
let nn_path = source.folder.join(source.generation.to_string()).join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id));
let new_nn_path = folder.join("0").join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i));
fs::copy(&nn_path, &new_nn_path)
.with_context(|| format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path))?;
let nn_path = source
.folder
.join(source.generation.to_string())
.join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id));
let new_nn_path =
folder
.join("0")
.join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i));
fs::copy(&nn_path, &new_nn_path).with_context(|| {
format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path)
})?;
nn_shapes.insert((start_idx + i) as u64, source.nn_shapes.get(&nn_id).unwrap().clone());
nn_shapes.insert(
(start_idx + i) as u64,
source.nn_shapes.get(nn_id).unwrap().clone(),
);
}
Ok(())
@ -341,32 +438,54 @@ impl GeneticNode for FighterNN {
// Copy the top half of NNs from each parent to the new folder.
copy_nns(left, &folder, id, 0)?;
copy_nns(right, &folder, id, left.population_size as usize / 2)?;
copy_nns(right, &folder, id, left.population_size / 2)?;
debug!("nn_shapes: {:?}", nn_shapes);
// Lerp the mutation rates and weight ranges
let crossbreed_segments = (left.crossbreed_segments as f32).lerp(right.crossbreed_segments as f32, lerp_amount) as usize;
let crossbreed_segments = (left.crossbreed_segments as f32)
.lerp(right.crossbreed_segments as f32, lerp_amount)
as usize;
let weight_initialization_range_start = left.weight_initialization_range.start.lerp(right.weight_initialization_range.start, lerp_amount);
let weight_initialization_range_end = left.weight_initialization_range.end.lerp(right.weight_initialization_range.end, lerp_amount);
let weight_initialization_range_start = left
.weight_initialization_range
.start
.lerp(right.weight_initialization_range.start, lerp_amount);
let weight_initialization_range_end = left
.weight_initialization_range
.end
.lerp(right.weight_initialization_range.end, lerp_amount);
// Have to ensure the range is valid
let weight_initialization_range = if weight_initialization_range_start < weight_initialization_range_end {
weight_initialization_range_start..weight_initialization_range_end
} else {
weight_initialization_range_end..weight_initialization_range_start
};
let weight_initialization_range =
if weight_initialization_range_start < weight_initialization_range_end {
weight_initialization_range_start..weight_initialization_range_end
} else {
weight_initialization_range_end..weight_initialization_range_start
};
debug!("weight_initialization_range: {:?}", weight_initialization_range);
debug!(
"weight_initialization_range: {:?}",
weight_initialization_range
);
let minor_mutation_rate = left.minor_mutation_rate.lerp(right.minor_mutation_rate, lerp_amount);
let major_mutation_rate = left.major_mutation_rate.lerp(right.major_mutation_rate, lerp_amount);
let minor_mutation_rate = left
.minor_mutation_rate
.lerp(right.minor_mutation_rate, lerp_amount);
let major_mutation_rate = left
.major_mutation_rate
.lerp(right.major_mutation_rate, lerp_amount);
debug!("minor_mutation_rate: {}", minor_mutation_rate);
debug!("major_mutation_rate: {}", major_mutation_rate);
let mutation_weight_range_start = left.mutation_weight_range.start.lerp(right.mutation_weight_range.start, lerp_amount);
let mutation_weight_range_end = left.mutation_weight_range.end.lerp(right.mutation_weight_range.end, lerp_amount);
let mutation_weight_range_start = left
.mutation_weight_range
.start
.lerp(right.mutation_weight_range.start, lerp_amount);
let mutation_weight_range_end = left
.mutation_weight_range
.end
.lerp(right.mutation_weight_range.end, lerp_amount);
// Have to ensure the range is valid
let mutation_weight_range = if mutation_weight_range_start < mutation_weight_range_end {
mutation_weight_range_start..mutation_weight_range_end
@ -398,7 +517,7 @@ impl FighterNN {
}
}
async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<(f32, f32), Error> {
async fn run_1v1_simulation(nn_path_1: &Path, nn_path_2: &Path) -> Result<(f32, f32), Error> {
// Construct the score file path
let base_folder = nn_path_1.parent().unwrap();
let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap();
@ -407,14 +526,18 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
// Check if score file already exists before running the simulation
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await
let round_score = read_score_from_file(&score_file, nn_1_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
let opposing_score = read_score_from_file(&score_file, &nn_2_id).await
let opposing_score = read_score_from_file(&score_file, nn_2_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score);
debug!(
"{} scored {}, while {} scored {}",
nn_1_id, round_score, nn_2_id, opposing_score
);
return Ok((round_score, opposing_score));
}
@ -422,13 +545,22 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
// Check if the opposite round score has been determined
let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
let round_score = read_score_from_file(&opposite_score_file, nn_1_id)
.await
.with_context(|| {
format!("Failed to read score from file: {:?}", opposite_score_file)
})?;
let opposing_score = read_score_from_file(&opposite_score_file, &nn_2_id).await
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
let opposing_score = read_score_from_file(&opposite_score_file, nn_2_id)
.await
.with_context(|| {
format!("Failed to read score from file: {:?}", opposite_score_file)
})?;
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score);
debug!(
"{} scored {}, while {} scored {}",
nn_1_id, round_score, nn_2_id, opposing_score
);
return Ok((round_score, opposing_score));
}
@ -459,20 +591,29 @@ async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<
.expect("Failed to execute game")
};
trace!("Simulation completed for {} vs {}: {}", nn_1_id, nn_2_id, score_file.exists());
trace!(
"Simulation completed for {} vs {}: {}",
nn_1_id,
nn_2_id,
score_file.exists()
);
// Read the score from the file
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await
let round_score = read_score_from_file(&score_file, nn_1_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
let opposing_score = read_score_from_file(&score_file, &nn_2_id).await
let opposing_score = read_score_from_file(&score_file, nn_2_id)
.await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
debug!("{} scored {}, while {} scored {}", nn_1_id, round_score, nn_2_id, opposing_score);
debug!(
"{} scored {}, while {} scored {}",
nn_1_id, round_score, nn_2_id, opposing_score
);
return Ok((round_score, opposing_score))
Ok((round_score, opposing_score))
} else {
warn!("Score file not found: {:?}", score_file);
Ok((0.0, 0.0))
@ -492,7 +633,10 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
if line.starts_with(nn_id) {
let parts: Vec<&str> = line.split(':').collect();
if parts.len() == 2 {
return parts[1].trim().parse::<f32>().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
return parts[1]
.trim()
.parse::<f32>()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
}
}
}
@ -501,16 +645,21 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
io::ErrorKind::NotFound,
"NN ID not found in scores file",
));
},
Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::PermissionDenied || e.kind() == io::ErrorKind::Other => {
if attempts >= 5 { // Attempt 5 times before giving up.
}
Err(e)
if e.kind() == io::ErrorKind::WouldBlock
|| e.kind() == io::ErrorKind::PermissionDenied
|| e.kind() == io::ErrorKind::Other =>
{
if attempts >= 5 {
// Attempt 5 times before giving up.
return Err(e);
}
attempts += 1;
// wait 1 second to ensure the file is written
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
},
}
Err(e) => return Err(e),
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,8 +1,11 @@
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
use async_trait::async_trait;
use gemla::{
core::genetic_node::{GeneticNode, GeneticNodeContext},
error::Error,
};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use async_trait::async_trait;
const POPULATION_SIZE: u64 = 5;
const POPULATION_REDUCTION_SIZE: u64 = 3;
@ -76,7 +79,12 @@ impl GeneticNode for TestState {
Ok(())
}
async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> {
async fn merge(
left: &TestState,
right: &TestState,
id: &Uuid,
gemla_context: Self::Context,
) -> Result<Box<TestState>, Error> {
let mut v = left.population.clone();
v.append(&mut right.population.clone());
@ -87,12 +95,14 @@ impl GeneticNode for TestState {
let mut result = TestState { population: v };
result.mutate(GeneticNodeContext {
id: id.clone(),
generation: 0,
max_generations: 0,
gemla_context: gemla_context
}).await?;
result
.mutate(GeneticNodeContext {
id: *id,
generation: 0,
max_generations: 0,
gemla_context,
})
.await?;
Ok(Box::new(result))
}
@ -105,14 +115,14 @@ mod tests {
#[tokio::test]
async fn test_initialize() {
let state = TestState::initialize(
GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
gemla_context: (),
}
).await.unwrap();
let state = TestState::initialize(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
gemla_context: (),
})
.await
.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
@ -125,35 +135,38 @@ mod tests {
let original_population = state.population.clone();
state.simulate(
GeneticNodeContext {
state
.simulate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
gemla_context: (),
}
).await.unwrap();
})
.await
.unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
.all(|(&a, &b)| b >= a - 1 && b <= a + 2));
state.simulate(
GeneticNodeContext {
state
.simulate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
gemla_context: (),
}
).await.unwrap();
state.simulate(
GeneticNodeContext {
})
.await
.unwrap();
state
.simulate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
gemla_context: (),
}
).await.unwrap();
})
.await
.unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
@ -166,14 +179,15 @@ mod tests {
population: vec![4, 3, 3],
};
state.mutate(
GeneticNodeContext {
state
.mutate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
gemla_context: (),
}
).await.unwrap();
})
.await
.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
@ -188,7 +202,9 @@ mod tests {
population: vec![0, 1, 3, 7],
};
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap();
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
.await
.unwrap();
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
assert!(merged_state.population.iter().any(|&x| x == 7));

View file

@ -5,10 +5,10 @@
use crate::error::Error;
use anyhow::Context;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug;
use uuid::Uuid;
use async_trait::async_trait;
/// An enum used to control the state of a [`GeneticNode`]
///
@ -30,14 +30,14 @@ pub struct GeneticNodeContext<S> {
pub generation: u64,
pub max_generations: u64,
pub id: Uuid,
pub gemla_context: S
pub gemla_context: S,
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
///
/// [`Bracket`]: crate::bracket::Bracket
#[async_trait]
pub trait GeneticNode : Send {
pub trait GeneticNode: Send {
type Context;
/// Initializes a new instance of a [`GeneticState`].
@ -54,7 +54,12 @@ pub trait GeneticNode : Send {
/// TODO
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
async fn merge(
left: &Self,
right: &Self,
id: &Uuid,
context: Self::Context,
) -> Result<Box<Self>, Error>;
}
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
@ -62,7 +67,7 @@ pub trait GeneticNode : Send {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GeneticNodeWrapper<T>
where
T: Clone
T: Clone,
{
node: Option<T>,
state: GeneticState,
@ -73,7 +78,7 @@ where
impl<T> Default for GeneticNodeWrapper<T>
where
T: Clone
T: Clone,
{
fn default() -> Self {
GeneticNodeWrapper {
@ -146,7 +151,8 @@ where
self.state = GeneticState::Simulate;
}
(GeneticState::Simulate, Some(n)) => {
n.simulate(context.clone()).await
n.simulate(context.clone())
.await
.with_context(|| format!("Error simulating node: {:?}", self))?;
self.state = if self.generation >= self.max_generations {
@ -156,7 +162,8 @@ where
};
}
(GeneticState::Mutate, Some(n)) => {
n.mutate(context.clone()).await
n.mutate(context.clone())
.await
.with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1;
@ -186,20 +193,33 @@ mod tests {
impl GeneticNode for TestState {
type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
async fn simulate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
async fn mutate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
Ok(())
}
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
async fn initialize(
_context: GeneticNodeContext<Self::Context>,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
async fn merge(
_l: &TestState,
_r: &TestState,
_id: &Uuid,
_: Self::Context,
) -> Result<Box<TestState>, Error> {
Err(Error::Other(anyhow!("Unable to merge")))
}
}

View file

@ -10,10 +10,11 @@ use futures::future;
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{sync::RwLock, task::JoinHandle};
use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
sync::Arc, time::Instant,
};
use tokio::{sync::RwLock, task::JoinHandle};
use uuid::Uuid;
type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
@ -80,13 +81,18 @@ where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub async fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
pub async fn new(
path: &Path,
config: GemlaConfig,
data_format: DataFormat,
) -> Result<Self, Error> {
match File::open(path) {
// If the file exists we either want to overwrite the file or read from the file
// based on the configuration provided
Ok(_) => Ok(Gemla {
data: if config.overwrite {
FileLinked::new((None, config, T::Context::default()), path, data_format).await?
FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?
} else {
FileLinked::from_file(path, data_format)?
},
@ -94,7 +100,8 @@ where
}),
// If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config, T::Context::default()), path, data_format).await?,
data: FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?,
threads: HashMap::new(),
}),
Err(error) => Err(Error::IO(error)),
@ -106,24 +113,32 @@ where
}
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
{
let tree_completed = {
// Only increase height if the tree is uninitialized or completed
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
if tree_ref.is_none() ||
tree_ref
.map(|t| Gemla::is_completed(t))
.unwrap_or(true)
{
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// in the tree and which nodes have not.
self.data.mutate(|(d, c, _)| {
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
};
if tree_completed {
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// in the tree and which nodes have not.
self.data
.mutate(|(d, c, _)| {
let mut tree: Option<SimulationTree<T>> =
Gemla::increase_height(d.take(), c, steps);
mem::swap(d, &mut tree);
}).await?;
}
})
.await?;
}
{
// Only increase height if the tree is uninitialized or completed
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
info!(
"Height of simulation tree increased to {}",
@ -141,36 +156,36 @@ where
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
is_tree_processed = tree_ref
.map(|t| Gemla::is_completed(t))
.unwrap_or(false)
is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
}
// We need to keep simulating until the tree has been completely processed.
if is_tree_processed
{
if is_tree_processed {
self.join_threads().await?;
info!("Processed tree");
break;
}
if let Some(node) = tree_ref
.and_then(|t| self.get_unprocessed_node(t))
{
let (node, gemla_context) = {
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let (tree_ref, _, gemla_context) = &*data_ref; // (Option<Box<Tree<GeneticNodeWrapper<T>>>, GemlaConfig, T::Context)
let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
(node, gemla_context.clone())
};
if let Some(node) = node {
trace!("Adding node to process list {}", node.id());
let data_arc = self.data.readonly();
let data_ref2 = data_arc.read().await;
let gemla_context = data_ref2.2.clone();
drop(data_ref2);
let gemla_context = gemla_context.clone();
self.threads
.insert(node.id(), tokio::spawn(async move {
Gemla::process_node(node, gemla_context).await
}));
self.threads.insert(
node.id(),
tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
);
} else {
trace!("No node found to process, joining threads");
@ -195,32 +210,34 @@ where
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
match reduced_results {
Ok(r) => {
self.data.mutate_async(|d| async move {
// Scope to limit the duration of the read lock
let (_, context) = {
let data_read = d.read().await;
(data_read.1.clone(), data_read.2.clone())
}; // Read lock is dropped here
self.data
.mutate_async(|d| async move {
// Scope to limit the duration of the read lock
let (_, context) = {
let data_read = d.read().await;
(data_read.1, data_read.2.clone())
}; // Read lock is dropped here
let mut data_write = d.write().await;
let mut data_write = d.write().await;
if let Some(t) = data_write.0.as_mut() {
let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree
if !failed_nodes.is_empty() {
warn!(
"Unable to find {:?} to replace in tree",
failed_nodes.iter().map(|n| n.id())
)
if let Some(t) = data_write.0.as_mut() {
let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree
if !failed_nodes.is_empty() {
warn!(
"Unable to find {:?} to replace in tree",
failed_nodes.iter().map(|n| n.id())
)
}
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t, context.clone()).await
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
}
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t, context.clone()).await
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
}
}).await??;
})
.await??;
}
Err(e) => return Err(e),
}
@ -230,7 +247,10 @@ where
}
#[async_recursion]
async fn merge_completed_nodes<'a>(tree: &'a mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> {
async fn merge_completed_nodes<'a>(
tree: &'a mut SimulationTree<T>,
gemla_context: T::Context,
) -> Result<(), Error> {
if tree.val.state() == GeneticState::Initialize {
match (&mut tree.left, &mut tree.right) {
// If the current node has been initialized, and has children nodes that are completed, then we need
@ -241,7 +261,13 @@ where
{
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
if let (Some(left_node), Some(right_node)) = (l.val.take(), r.val.take()) {
let merged_node = GeneticNode::merge(&left_node, &right_node, &tree.val.id(), gemla_context.clone()).await?;
let merged_node = GeneticNode::merge(
&left_node,
&right_node,
&tree.val.id(),
gemla_context.clone(),
)
.await?;
tree.val = GeneticNodeWrapper::from(
*merged_node,
tree.val.max_generations(),
@ -294,7 +320,10 @@ where
// during join_threads.
(Some(l), Some(r))
if l.val.state() == GeneticState::Finish
&& r.val.state() == GeneticState::Finish => Some(tree.val.clone()),
&& r.val.state() == GeneticState::Finish =>
{
Some(tree.val.clone())
}
(Some(l), Some(r)) => self
.get_unprocessed_node(l)
.or_else(|| self.get_unprocessed_node(r)),
@ -355,7 +384,10 @@ where
tree.val.state() == GeneticState::Finish
}
async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> {
async fn process_node(
mut node: GeneticNodeWrapper<T>,
gemla_context: T::Context,
) -> Result<GeneticNodeWrapper<T>, Error> {
let node_state_time = Instant::now();
let node_state = node.state();
@ -379,10 +411,10 @@ where
#[cfg(test)]
mod tests {
use crate::core::*;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::fs;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use tokio::runtime::Runtime;
use self::genetic_node::GeneticNodeContext;
@ -420,20 +452,33 @@ mod tests {
impl genetic_node::GeneticNode for TestState {
type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
async fn simulate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
async fn mutate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
Ok(())
}
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
async fn initialize(
_context: GeneticNodeContext<Self::Context>,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
async fn merge(
left: &TestState,
right: &TestState,
_id: &Uuid,
_: Self::Context,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.score > right.score {
left.clone()
} else {
@ -498,40 +543,43 @@ mod tests {
Ok(())
})
})
}).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
})
.await
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
// #[tokio::test]
// async fn test_simulate() -> Result<(), Error> {
// let path = PathBuf::from("test_simulate");
// // Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
// tokio::task::spawn_blocking(move || {
// let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
// CleanUp::new(&path).run(move |p| {
// rt.block_on(async {
// // Testing initial creation
// let config = GemlaConfig {
// generations_per_height: 10,
// overwrite: true,
// };
// let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
#[tokio::test]
async fn test_simulate() -> Result<(), Error> {
let path = PathBuf::from("test_simulate");
// Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
tokio::task::spawn_blocking(move || {
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
CleanUp::new(&path).run(move |p| {
rt.block_on(async {
// Testing initial creation
let config = GemlaConfig {
generations_per_height: 10,
overwrite: true,
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
// // Now we can use `.await` within the spawned blocking task.
// gemla.simulate(5).await?;
// let data = gemla.data.readonly();
// let data_lock = data.read().unwrap();
// let tree = data_lock.0.as_ref().unwrap();
// assert_eq!(tree.height(), 5);
// assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
// Now we can use `.await` within the spawned blocking task.
gemla.simulate(5).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
let tree = data_lock.0.as_ref().unwrap();
assert_eq!(tree.height(), 5);
assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
// Ok(())
// })
// })
// }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
// Ok(())
// }
Ok(())
})
})
})
.await
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
}