diff --git a/analysis.py b/analysis.py new file mode 100644 index 0000000..f5db293 --- /dev/null +++ b/analysis.py @@ -0,0 +1,101 @@ +# Re-importing necessary libraries +import json +import matplotlib.pyplot as plt +import networkx as nx +import random + +def hierarchy_pos(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5): + if not nx.is_tree(G): + raise TypeError('cannot use hierarchy_pos on a graph that is not a tree') + + if root is None: + if isinstance(G, nx.DiGraph): + root = next(iter(nx.topological_sort(G))) + else: + root = random.choice(list(G.nodes)) + + def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None): + if pos is None: + pos = {root: (xcenter, vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + children = list(G.successors(root)) # Use successors to get children for DiGraph + if not isinstance(G, nx.DiGraph): + if parent is not None: + children.remove(parent) + if len(children) != 0: + dx = width / len(children) + nextx = xcenter - width / 2 - dx / 2 + for child in children: + nextx += dx + pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap, + vert_loc=vert_loc - vert_gap, xcenter=nextx, + pos=pos, parent=root) + return pos + + return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) + +# Simplified JSON data for demonstration +with open('gemla/test.json', 'r') as file: + simplified_json_data = json.load(file) + +# Function to traverse the tree and create a graph +def traverse(node, graph, parent=None): + if node is None: + return + + node_id = node["val"]["id"] + if "node" in node["val"] and node["val"]["node"]: + scores = node["val"]["node"]["scores"] + generations = node["val"]["node"]["generation"] + population_size = node["val"]["node"]["population_size"] + # Prepare to track the highest score across all generations and the corresponding individual + overall_max_score = float('-inf') + overall_max_score_individual = None + overall_max_score_gen = None + + for gen, gen_scores in enumerate(scores): + if gen_scores: # Ensure the dictionary is not empty + # Find the max score and the individual for this generation + max_score_for_gen = max(gen_scores.values()) + individual_with_max_score_for_gen = max(gen_scores, key=gen_scores.get) + + # if max_score_for_gen > overall_max_score: + overall_max_score = max_score_for_gen + overall_max_score_individual = individual_with_max_score_for_gen + overall_max_score_gen = gen + + label = f"{node_id}\nGenerations: {generations}, Population: {population_size}\nMax score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen})" + else: + label = node_id + + graph.add_node(node_id, label=label) + if parent: + graph.add_edge(parent, node_id) + + traverse(node.get("left"), graph, parent=node_id) + traverse(node.get("right"), graph, parent=node_id) + + +# Create a directed graph +G = nx.DiGraph() + +# Populate the graph +traverse(simplified_json_data[0], G) + +# Find the root node (a node with no incoming edges) +root_candidates = [node for node, indeg in G.in_degree() if indeg == 0] + +if root_candidates: + root_node = root_candidates[0] # Assuming there's only one root candidate +else: + root_node = None # This should ideally never happen in a properly structured tree + +# Use the determined root node for hierarchy_pos +if root_node is not None: + pos = hierarchy_pos(G, root=root_node) + labels = nx.get_node_attributes(G, 'label') + nx.draw(G, pos, labels=labels, with_labels=True, arrows=True) + plt.show() +else: + print("No root node found. Cannot draw the tree.") \ No newline at end of file diff --git a/gemla/Cargo.toml b/gemla/Cargo.toml index 6119210..ee4951d 100644 --- a/gemla/Cargo.toml +++ b/gemla/Cargo.toml @@ -26,8 +26,8 @@ rand = "0.8.5" log = "0.4.21" env_logger = "0.11.3" futures = "0.3.30" -smol = "2.0.0" -smol-potat = "1.1.2" +tokio = { version = "1.36.0", features = ["full"] } num_cpus = "1.16.0" easy-parallel = "3.3.1" fann = "0.1.8" +async-trait = "0.1.78" diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs index 4cb05c8..afb2ba7 100644 --- a/gemla/src/bin/bin.rs +++ b/gemla/src/bin/bin.rs @@ -6,16 +6,17 @@ extern crate log; mod test_state; mod fighter_nn; -use easy_parallel::Parallel; use file_linked::constants::data_format::DataFormat; use gemla::{ core::{Gemla, GemlaConfig}, - error::{log_error, Error}, + error::log_error, }; -use smol::{channel, channel::RecvError, future, Executor}; use std::{path::PathBuf, time::Instant}; use fighter_nn::FighterNN; use clap::Parser; +use anyhow::Result; + +// const NUM_THREADS: usize = 12; #[derive(Parser)] #[command(version, about, long_about = None)] @@ -29,48 +30,40 @@ struct Args { /// /// Use the -h, --h, or --help flag to see usage syntax. /// TODO -fn main() -> anyhow::Result<()> { +fn main() -> Result<()> { env_logger::init(); info!("Starting"); - let now = Instant::now(); - // Obtainning number of threads to use - let num_threads = num_cpus::get().max(1); - let ex = Executor::new(); - let (signal, shutdown) = channel::unbounded::<()>(); + // Manually configure the Tokio runtime + let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_cpus::get()) + // .worker_threads(NUM_THREADS) + .build()? + .block_on(async { + let args = Args::parse(); // Assuming Args::parse() doesn't need to be async + let mut gemla = log_error(Gemla::::new( + &PathBuf::from(args.file), + GemlaConfig { + generations_per_height: 10, + overwrite: false, + }, + DataFormat::Json, + ))?; - // Create an executor thread pool. - let (_, result): (Vec>, Result<(), Error>) = Parallel::new() - .each(0..num_threads, |_| { - future::block_on(ex.run(shutdown.recv())) - }) - .finish(|| { - smol::block_on(async { - drop(signal); + // let gemla_arc = Arc::new(gemla); - // Command line arguments are parsed with the clap crate. - let args = Args::parse(); + // Setup your application logic here + // If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks - // Checking that the first argument is a valid file - let mut gemla = log_error(Gemla::::new( - &PathBuf::from(args.file), - GemlaConfig { - generations_per_height: 3, - overwrite: false, - }, - DataFormat::Json, - ))?; - - loop { - log_error(gemla.simulate(5).await)?; - } - }) + // Example placeholder loop to continuously run simulate + loop { // Arbitrary loop count for demonstration + gemla.simulate(5).await?; + } }); - result?; + runtime?; // Handle errors from the block_on call info!("Finished in {:?}", now.elapsed()); - Ok(()) -} +} \ No newline at end of file diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs index 248dbb0..bb641d8 100644 --- a/gemla/src/bin/fighter_nn/mod.rs +++ b/gemla/src/bin/fighter_nn/mod.rs @@ -1,7 +1,8 @@ extern crate fann; -use std::{fs, path::PathBuf}; +use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}}; use fann::{ActivationFunc, Fann}; +use futures::future::join_all; use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}; use rand::prelude::*; use rand::distributions::{Distribution, Uniform}; @@ -9,12 +10,15 @@ use serde::{Deserialize, Serialize}; use anyhow::Context; use uuid::Uuid; use std::collections::HashMap; +use tokio::process::Command; +use async_trait::async_trait; const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations"; -const POPULATION: usize = 100; +const POPULATION: usize = 50; const NEURAL_NETWORK_SHAPE: &[u32; 5] = &[14, 20, 20, 12, 8]; -const SIMULATION_ROUNDS: usize = 10; +const SIMULATION_ROUNDS: usize = 5; const SURVIVAL_RATE: f32 = 0.5; +const GAME_EXECUTABLE_PATH: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe"; // Here is the folder structure for the FighterNN: // base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net @@ -36,9 +40,10 @@ pub struct FighterNN { pub scores: Vec>, } +#[async_trait] impl GeneticNode for FighterNN { // Check for the highest number of the folder name and increment it by 1 - fn initialize(context: &GeneticNodeContext) -> Result, Error> { + fn initialize(context: GeneticNodeContext) -> Result, Error> { let base_path = PathBuf::from(BASE_DIR); let folder = base_path.join(format!("fighter_nn_{:06}", context.id)); @@ -57,6 +62,7 @@ impl GeneticNode for FighterNN { let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i)); let mut fann = Fann::new(NEURAL_NETWORK_SHAPE) .with_context(|| "Failed to create nn")?; + fann.randomize_weights(-0.8, 0.8); fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); // This will overwrite any existing file with the same name @@ -73,47 +79,89 @@ impl GeneticNode for FighterNN { })) } - fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { // For each nn in the current generation: for i in 0..self.population_size { // load the nn let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i)); - let fann = Fann::from_file(&nn) - .with_context(|| format!("Failed to load nn"))?; - - // Simulate the nn against the random nn - let mut score = 0.0; - - // Using the same original nn, repeat the simulation with 5 random nn's from the current generation + let mut simulations = Vec::new(); + + // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently for _ in 0..SIMULATION_ROUNDS { - let random_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, thread_rng().gen_range(0..self.population_size))); - let random_fann = Fann::from_file(&random_nn) - .with_context(|| format!("Failed to load random nn"))?; + let random_nn_index = thread_rng().gen_range(0..self.population_size); + let id = self.id.clone(); + let folder = self.folder.clone(); + let generation = self.generation; - let inputs: Vec = (0..NEURAL_NETWORK_SHAPE[0]).map(|_| thread_rng().gen_range(-1.0..1.0)).collect(); - let outputs = fann.run(&inputs) - .with_context(|| format!("Failed to run nn"))?; - let random_outputs = random_fann.run(&inputs) - .with_context(|| format!("Failed to run random nn"))?; - - // Average the difference between the outputs of the nn and random_nn and add the result to score - let mut round_score = 0.0; - for (o, r) in outputs.iter().zip(random_outputs.iter()) { - round_score += o - r; - } - score += round_score / fann.get_num_output() as f32; + let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index)); + let nn_clone = nn.clone(); // Clone the path to use in the async block + + let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap()); + let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap()); + let disable_unreal_rendering_arg = "-nullrhi".to_string(); + + let future = async move { + // Construct the score file path + let nn_id = format!("{:06}_fighter_nn_{}", id, i); + let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index); + let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id); + let score_file = folder.join(format!("{}", generation)).join(&score_file_name); + // Check if score file already exists before running the simulation + if score_file.exists() { + let round_score = read_score_from_file(&score_file, &nn_id) + .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?; + return Ok::(round_score); + } + + // Check if the opposite round score has been determined + let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id)); + if opposite_score_file.exists() { + let round_score = read_score_from_file(&opposite_score_file, &nn_id) + .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?; + return Ok::(1.0 - round_score); + } + + if thread_rng().gen_range(0..100) < 4 { + let _output = Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .output() + .await + .expect("Failed to execute game"); + } else { + let _output = Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .arg(&disable_unreal_rendering_arg) + .output() + .await + .expect("Failed to execute game"); + } + + // Read the score from the file + let round_score = read_score_from_file(&score_file, &nn_id) + .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?; + + Ok::(round_score) + }; + + simulations.push(future); } - - score /= 5.0; + + // Wait for all simulation rounds to complete + let results: Result, Error> = join_all(simulations).await.into_iter().collect(); + + let score = results?.into_iter().sum::() / SIMULATION_ROUNDS as f32; + trace!("NN {:06}_fighter_nn_{} scored {}", self.id, i, score); self.scores[self.generation as usize].insert(i as u64, score); } - + Ok(()) } - fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize; // Create the new generation folder @@ -234,3 +282,23 @@ impl GeneticNode for FighterNN { })) } } + +fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result { + let file = File::open(file_path)?; + let reader = BufReader::new(file); + + for line in reader.lines() { + let line = line?; + if line.starts_with(nn_id) { + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() == 2 { + return parts[1].trim().parse::().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)); + } + } + } + + Err(io::Error::new( + io::ErrorKind::NotFound, + "NN ID not found in scores file", + )) +} \ No newline at end of file diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index 35ef7cc..0e8d11d 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -2,6 +2,7 @@ use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error} use rand::prelude::*; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use async_trait::async_trait; const POPULATION_SIZE: u64 = 5; const POPULATION_REDUCTION_SIZE: u64 = 3; @@ -11,8 +12,9 @@ pub struct TestState { pub population: Vec, } +#[async_trait] impl GeneticNode for TestState { - fn initialize(_context: &GeneticNodeContext) -> Result, Error> { + fn initialize(_context: GeneticNodeContext) -> Result, Error> { let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { @@ -22,7 +24,7 @@ impl GeneticNode for TestState { Ok(Box::new(TestState { population })) } - fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let mut rng = thread_rng(); self.population = self @@ -34,7 +36,7 @@ impl GeneticNode for TestState { Ok(()) } - fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let mut rng = thread_rng(); let mut v = self.population.clone(); @@ -83,7 +85,7 @@ impl GeneticNode for TestState { let mut result = TestState { population: v }; - result.mutate(&GeneticNodeContext { + result.mutate(GeneticNodeContext { id: id.clone(), generation: 0, max_generations: 0, @@ -101,7 +103,7 @@ mod tests { #[test] fn test_initialize() { let state = TestState::initialize( - &GeneticNodeContext { + GeneticNodeContext { id: Uuid::new_v4(), generation: 0, max_generations: 0, @@ -111,8 +113,8 @@ mod tests { assert_eq!(state.population.len(), POPULATION_SIZE as usize); } - #[test] - fn test_simulate() { + #[tokio::test] + async fn test_simulate() { let mut state = TestState { population: vec![1, 1, 2, 3], }; @@ -120,31 +122,31 @@ mod tests { let original_population = state.population.clone(); state.simulate( - &GeneticNodeContext { + GeneticNodeContext { id: Uuid::new_v4(), generation: 0, max_generations: 0, } - ).unwrap(); + ).await.unwrap(); assert!(original_population .iter() .zip(state.population.iter()) .all(|(&a, &b)| b >= a - 1 && b <= a + 2)); state.simulate( - &GeneticNodeContext { + GeneticNodeContext { id: Uuid::new_v4(), generation: 0, max_generations: 0, } - ).unwrap(); + ).await.unwrap(); state.simulate( - &GeneticNodeContext { + GeneticNodeContext { id: Uuid::new_v4(), generation: 0, max_generations: 0, } - ).unwrap(); + ).await.unwrap(); assert!(original_population .iter() .zip(state.population.iter()) @@ -158,7 +160,7 @@ mod tests { }; state.mutate( - &GeneticNodeContext { + GeneticNodeContext { id: Uuid::new_v4(), generation: 0, max_generations: 0, diff --git a/gemla/src/core/genetic_node.rs b/gemla/src/core/genetic_node.rs index c1e669f..92a7f48 100644 --- a/gemla/src/core/genetic_node.rs +++ b/gemla/src/core/genetic_node.rs @@ -8,6 +8,7 @@ use anyhow::Context; use serde::{Deserialize, Serialize}; use std::fmt::Debug; use uuid::Uuid; +use async_trait::async_trait; /// An enum used to control the state of a [`GeneticNode`] /// @@ -24,6 +25,7 @@ pub enum GeneticState { Finish, } +#[derive(Clone)] pub struct GeneticNodeContext { pub generation: u64, pub max_generations: u64, @@ -33,20 +35,21 @@ pub struct GeneticNodeContext { /// A trait used to interact with the internal state of nodes within the [`Bracket`] /// /// [`Bracket`]: crate::bracket::Bracket -pub trait GeneticNode { +#[async_trait] +pub trait GeneticNode: Send { /// Initializes a new instance of a [`GeneticState`]. /// /// # Examples /// TODO - fn initialize(context: &GeneticNodeContext) -> Result, Error>; + fn initialize(context: GeneticNodeContext) -> Result, Error>; - fn simulate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>; + async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; /// Mutates members in a population and/or crossbreeds them to produce new offspring. /// /// # Examples /// TODO - fn mutate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>; + fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; fn merge(left: &Self, right: &Self, id: &Uuid) -> Result, Error>; } @@ -76,7 +79,7 @@ impl Default for GeneticNodeWrapper { impl GeneticNodeWrapper where - T: GeneticNode + Debug, + T: GeneticNode + Debug + Send, { pub fn new(max_generations: u64) -> Self { GeneticNodeWrapper:: { @@ -115,7 +118,7 @@ where self.state } - pub fn process_node(&mut self) -> Result { + pub async fn process_node(&mut self) -> Result { let context = GeneticNodeContext { generation: self.generation, max_generations: self.max_generations, @@ -124,11 +127,11 @@ where match (self.state, &mut self.node) { (GeneticState::Initialize, _) => { - self.node = Some(*T::initialize(&context)?); + self.node = Some(*T::initialize(context.clone())?); self.state = GeneticState::Simulate; } (GeneticState::Simulate, Some(n)) => { - n.simulate(&context) + n.simulate(context.clone()).await .with_context(|| format!("Error simulating node: {:?}", self))?; self.state = if self.generation >= self.max_generations { @@ -138,7 +141,7 @@ where }; } (GeneticState::Mutate, Some(n)) => { - n.mutate(&context) + n.mutate(context.clone()) .with_context(|| format!("Error mutating node: {:?}", self))?; self.generation += 1; @@ -157,23 +160,25 @@ mod tests { use super::*; use crate::error::Error; use anyhow::anyhow; + use async_trait::async_trait; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] struct TestState { pub score: f64, } + #[async_trait] impl GeneticNode for TestState { - fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { self.score += 1.0; Ok(()) } - fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { Ok(()) } - fn initialize(_context: &GeneticNodeContext) -> Result, Error> { + fn initialize(_context: GeneticNodeContext) -> Result, Error> { Ok(Box::new(TestState { score: 0.0 })) } @@ -270,16 +275,16 @@ mod tests { Ok(()) } - #[test] - fn test_process_node() -> Result<(), Error> { + #[tokio::test] + async fn test_process_node() -> Result<(), Error> { let mut genetic_node = GeneticNodeWrapper::::new(2); assert_eq!(genetic_node.state(), GeneticState::Initialize); - assert_eq!(genetic_node.process_node()?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node()?, GeneticState::Mutate); - assert_eq!(genetic_node.process_node()?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node()?, GeneticState::Finish); - assert_eq!(genetic_node.process_node()?, GeneticState::Finish); + assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node().await?, GeneticState::Mutate); + assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node().await?, GeneticState::Finish); + assert_eq!(genetic_node.process_node().await?, GeneticState::Finish); Ok(()) } diff --git a/gemla/src/core/mod.rs b/gemla/src/core/mod.rs index 152ed6d..42a1fc7 100644 --- a/gemla/src/core/mod.rs +++ b/gemla/src/core/mod.rs @@ -5,10 +5,11 @@ pub mod genetic_node; use crate::{error::Error, tree::Tree}; use file_linked::{constants::data_format::DataFormat, FileLinked}; -use futures::{future, future::BoxFuture}; +use futures::future; use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState}; use log::{info, trace, warn}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tokio::task::JoinHandle; use std::{ collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, time::Instant, @@ -65,15 +66,15 @@ pub struct GemlaConfig { /// individuals. /// /// [`GeneticNode`]: genetic_node::GeneticNode -pub struct Gemla<'a, T> +pub struct Gemla where T: Serialize + Clone, { pub data: FileLinked<(Option>, GemlaConfig)>, - threads: HashMap, Error>>>, + threads: HashMap, Error>>>, } -impl<'a, T: 'a> Gemla<'a, T> +impl Gemla where T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send, { @@ -147,7 +148,9 @@ where trace!("Adding node to process list {}", node.id()); self.threads - .insert(node.id(), Box::pin(Gemla::process_node(node))); + .insert(node.id(), tokio::spawn(async move { + Gemla::process_node(node).await + })); } else { trace!("No node found to process, joining threads"); @@ -163,9 +166,10 @@ where trace!("Joining threads for nodes {:?}", self.threads.keys()); let results = future::join_all(self.threads.values_mut()).await; + // Converting a list of results into a result wrapping the list let reduced_results: Result>, Error> = - results.into_iter().collect(); + results.into_iter().flatten().collect(); self.threads.clear(); // We need to retrieve the processed nodes from the resulting list and replace them in the original list @@ -323,7 +327,12 @@ where let node_state_time = Instant::now(); let node_state = node.state(); - node.process_node()?; + node.process_node().await?; + + if node.state() == GeneticState::Simulate + { + node.process_node().await?; + } trace!( "{:?} completed in {:?} for {}", @@ -346,6 +355,8 @@ mod tests { use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::fs; + use async_trait::async_trait; + use tokio::runtime::Runtime; use self::genetic_node::GeneticNodeContext; @@ -378,17 +389,18 @@ mod tests { pub score: f64, } + #[async_trait] impl genetic_node::GeneticNode for TestState { - fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { self.score += 1.0; Ok(()) } - fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> { + fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { Ok(()) } - fn initialize(_context: &GeneticNodeContext) -> Result, Error> { + fn initialize(_context: GeneticNodeContext) -> Result, Error> { Ok(Box::new(TestState { score: 0.0 })) } @@ -401,66 +413,84 @@ mod tests { } } - #[test] - fn test_new() -> Result<(), Error> { + #[tokio::test] + async fn test_new() -> Result<(), Error> { let path = PathBuf::from("test_new_non_existing"); - CleanUp::new(&path).run(|p| { - assert!(!path.exists()); + // Use `spawn_blocking` to run synchronous code that needs to call async code internally. + tokio::task::spawn_blocking(move || { + let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block. + CleanUp::new(&path).run(move |p| { + rt.block_on(async { + assert!(!path.exists()); - // Testing initial creation - let mut config = GemlaConfig { - generations_per_height: 1, - overwrite: true - }; - let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; + // Testing initial creation + let mut config = GemlaConfig { + generations_per_height: 1, + overwrite: true, + }; + let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; - smol::block_on(gemla.simulate(2))?; - assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2); - - drop(gemla); - assert!(path.exists()); + // Now we can use `.await` within the spawned blocking task. + gemla.simulate(2).await?; + assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2); - // Testing overwriting data - let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; + drop(gemla); + assert!(path.exists()); - smol::block_on(gemla.simulate(2))?; - assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2); + // Testing overwriting data + let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; - drop(gemla); - assert!(path.exists()); + gemla.simulate(2).await?; + assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2); - // Testing not-overwriting data - config.overwrite = false; - let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; + drop(gemla); + assert!(path.exists()); - smol::block_on(gemla.simulate(2))?; - assert_eq!(gemla.tree_ref().unwrap().height(), 4); + // Testing not-overwriting data + config.overwrite = false; + let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; - drop(gemla); - assert!(path.exists()); + gemla.simulate(2).await?; + assert_eq!(gemla.tree_ref().unwrap().height(), 4); - Ok(()) - }) + drop(gemla); + assert!(path.exists()); + + Ok(()) + }) + }) + }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result. + + Ok(()) } - #[test] - fn test_simulate() -> Result<(), Error> { + #[tokio::test] + async fn test_simulate() -> Result<(), Error> { let path = PathBuf::from("test_simulate"); - CleanUp::new(&path).run(|p| { - // Testing initial creation - let config = GemlaConfig { - generations_per_height: 10, - overwrite: true - }; - let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; + // Use `spawn_blocking` to run the synchronous closure that internally awaits async code. + tokio::task::spawn_blocking(move || { + let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block. + CleanUp::new(&path).run(move |p| { + rt.block_on(async { + // Testing initial creation + let config = GemlaConfig { + generations_per_height: 10, + overwrite: true, + }; + let mut gemla = Gemla::::new(&p, config, DataFormat::Json)?; - smol::block_on(gemla.simulate(5))?; - let tree = gemla.tree_ref().unwrap(); - assert_eq!(tree.height(), 5); - assert_eq!(tree.val.as_ref().unwrap().score, 50.0); + // Now we can use `.await` within the spawned blocking task. + gemla.simulate(5).await?; + let tree = gemla.tree_ref().unwrap(); + assert_eq!(tree.height(), 5); + assert_eq!(tree.val.as_ref().unwrap().score, 50.0); - Ok(()) - }) + Ok(()) + }) + }) + }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result. + + Ok(()) } }