Modifying executor runtime for efficiency

This commit is contained in:
vandomej 2024-03-21 10:21:15 -07:00
parent c9b746e59d
commit 97086fdbe0
7 changed files with 356 additions and 157 deletions

101
analysis.py Normal file
View file

@ -0,0 +1,101 @@
# Re-importing necessary libraries
import json
import matplotlib.pyplot as plt
import networkx as nx
import random
def hierarchy_pos(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
if not nx.is_tree(G):
raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')
if root is None:
if isinstance(G, nx.DiGraph):
root = next(iter(nx.topological_sort(G)))
else:
root = random.choice(list(G.nodes))
def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
if pos is None:
pos = {root: (xcenter, vert_loc)}
else:
pos[root] = (xcenter, vert_loc)
children = list(G.successors(root)) # Use successors to get children for DiGraph
if not isinstance(G, nx.DiGraph):
if parent is not None:
children.remove(parent)
if len(children) != 0:
dx = width / len(children)
nextx = xcenter - width / 2 - dx / 2
for child in children:
nextx += dx
pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap,
vert_loc=vert_loc - vert_gap, xcenter=nextx,
pos=pos, parent=root)
return pos
return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
# Simplified JSON data for demonstration
with open('gemla/test.json', 'r') as file:
simplified_json_data = json.load(file)
# Function to traverse the tree and create a graph
def traverse(node, graph, parent=None):
if node is None:
return
node_id = node["val"]["id"]
if "node" in node["val"] and node["val"]["node"]:
scores = node["val"]["node"]["scores"]
generations = node["val"]["node"]["generation"]
population_size = node["val"]["node"]["population_size"]
# Prepare to track the highest score across all generations and the corresponding individual
overall_max_score = float('-inf')
overall_max_score_individual = None
overall_max_score_gen = None
for gen, gen_scores in enumerate(scores):
if gen_scores: # Ensure the dictionary is not empty
# Find the max score and the individual for this generation
max_score_for_gen = max(gen_scores.values())
individual_with_max_score_for_gen = max(gen_scores, key=gen_scores.get)
# if max_score_for_gen > overall_max_score:
overall_max_score = max_score_for_gen
overall_max_score_individual = individual_with_max_score_for_gen
overall_max_score_gen = gen
label = f"{node_id}\nGenerations: {generations}, Population: {population_size}\nMax score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen})"
else:
label = node_id
graph.add_node(node_id, label=label)
if parent:
graph.add_edge(parent, node_id)
traverse(node.get("left"), graph, parent=node_id)
traverse(node.get("right"), graph, parent=node_id)
# Create a directed graph
G = nx.DiGraph()
# Populate the graph
traverse(simplified_json_data[0], G)
# Find the root node (a node with no incoming edges)
root_candidates = [node for node, indeg in G.in_degree() if indeg == 0]
if root_candidates:
root_node = root_candidates[0] # Assuming there's only one root candidate
else:
root_node = None # This should ideally never happen in a properly structured tree
# Use the determined root node for hierarchy_pos
if root_node is not None:
pos = hierarchy_pos(G, root=root_node)
labels = nx.get_node_attributes(G, 'label')
nx.draw(G, pos, labels=labels, with_labels=True, arrows=True)
plt.show()
else:
print("No root node found. Cannot draw the tree.")

View file

@ -26,8 +26,8 @@ rand = "0.8.5"
log = "0.4.21"
env_logger = "0.11.3"
futures = "0.3.30"
smol = "2.0.0"
smol-potat = "1.1.2"
tokio = { version = "1.36.0", features = ["full"] }
num_cpus = "1.16.0"
easy-parallel = "3.3.1"
fann = "0.1.8"
async-trait = "0.1.78"

View file

@ -6,16 +6,17 @@ extern crate log;
mod test_state;
mod fighter_nn;
use easy_parallel::Parallel;
use file_linked::constants::data_format::DataFormat;
use gemla::{
core::{Gemla, GemlaConfig},
error::{log_error, Error},
error::log_error,
};
use smol::{channel, channel::RecvError, future, Executor};
use std::{path::PathBuf, time::Instant};
use fighter_nn::FighterNN;
use clap::Parser;
use anyhow::Result;
// const NUM_THREADS: usize = 12;
#[derive(Parser)]
#[command(version, about, long_about = None)]
@ -29,48 +30,40 @@ struct Args {
///
/// Use the -h, --h, or --help flag to see usage syntax.
/// TODO
fn main() -> anyhow::Result<()> {
fn main() -> Result<()> {
env_logger::init();
info!("Starting");
let now = Instant::now();
// Obtainning number of threads to use
let num_threads = num_cpus::get().max(1);
let ex = Executor::new();
let (signal, shutdown) = channel::unbounded::<()>();
// Manually configure the Tokio runtime
let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_cpus::get())
// .worker_threads(NUM_THREADS)
.build()?
.block_on(async {
let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
let mut gemla = log_error(Gemla::<FighterNN>::new(
&PathBuf::from(args.file),
GemlaConfig {
generations_per_height: 10,
overwrite: false,
},
DataFormat::Json,
))?;
// Create an executor thread pool.
let (_, result): (Vec<Result<(), RecvError>>, Result<(), Error>) = Parallel::new()
.each(0..num_threads, |_| {
future::block_on(ex.run(shutdown.recv()))
})
.finish(|| {
smol::block_on(async {
drop(signal);
// let gemla_arc = Arc::new(gemla);
// Command line arguments are parsed with the clap crate.
let args = Args::parse();
// Setup your application logic here
// If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
// Checking that the first argument <FILE> is a valid file
let mut gemla = log_error(Gemla::<FighterNN>::new(
&PathBuf::from(args.file),
GemlaConfig {
generations_per_height: 3,
overwrite: false,
},
DataFormat::Json,
))?;
loop {
log_error(gemla.simulate(5).await)?;
}
})
// Example placeholder loop to continuously run simulate
loop { // Arbitrary loop count for demonstration
gemla.simulate(5).await?;
}
});
result?;
runtime?; // Handle errors from the block_on call
info!("Finished in {:?}", now.elapsed());
Ok(())
}
}

View file

@ -1,7 +1,8 @@
extern crate fann;
use std::{fs, path::PathBuf};
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}};
use fann::{ActivationFunc, Fann};
use futures::future::join_all;
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
use rand::prelude::*;
use rand::distributions::{Distribution, Uniform};
@ -9,12 +10,15 @@ use serde::{Deserialize, Serialize};
use anyhow::Context;
use uuid::Uuid;
use std::collections::HashMap;
use tokio::process::Command;
use async_trait::async_trait;
const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
const POPULATION: usize = 100;
const POPULATION: usize = 50;
const NEURAL_NETWORK_SHAPE: &[u32; 5] = &[14, 20, 20, 12, 8];
const SIMULATION_ROUNDS: usize = 10;
const SIMULATION_ROUNDS: usize = 5;
const SURVIVAL_RATE: f32 = 0.5;
const GAME_EXECUTABLE_PATH: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";
// Here is the folder structure for the FighterNN:
// base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net
@ -36,9 +40,10 @@ pub struct FighterNN {
pub scores: Vec<HashMap<u64, f32>>,
}
#[async_trait]
impl GeneticNode for FighterNN {
// Check for the highest number of the folder name and increment it by 1
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error> {
let base_path = PathBuf::from(BASE_DIR);
let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
@ -57,6 +62,7 @@ impl GeneticNode for FighterNN {
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
let mut fann = Fann::new(NEURAL_NETWORK_SHAPE)
.with_context(|| "Failed to create nn")?;
fann.randomize_weights(-0.8, 0.8);
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
// This will overwrite any existing file with the same name
@ -73,47 +79,89 @@ impl GeneticNode for FighterNN {
}))
}
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
// For each nn in the current generation:
for i in 0..self.population_size {
// load the nn
let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i));
let fann = Fann::from_file(&nn)
.with_context(|| format!("Failed to load nn"))?;
// Simulate the nn against the random nn
let mut score = 0.0;
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation
let mut simulations = Vec::new();
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
for _ in 0..SIMULATION_ROUNDS {
let random_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, thread_rng().gen_range(0..self.population_size)));
let random_fann = Fann::from_file(&random_nn)
.with_context(|| format!("Failed to load random nn"))?;
let random_nn_index = thread_rng().gen_range(0..self.population_size);
let id = self.id.clone();
let folder = self.folder.clone();
let generation = self.generation;
let inputs: Vec<f32> = (0..NEURAL_NETWORK_SHAPE[0]).map(|_| thread_rng().gen_range(-1.0..1.0)).collect();
let outputs = fann.run(&inputs)
.with_context(|| format!("Failed to run nn"))?;
let random_outputs = random_fann.run(&inputs)
.with_context(|| format!("Failed to run random nn"))?;
// Average the difference between the outputs of the nn and random_nn and add the result to score
let mut round_score = 0.0;
for (o, r) in outputs.iter().zip(random_outputs.iter()) {
round_score += o - r;
}
score += round_score / fann.get_num_output() as f32;
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
let nn_clone = nn.clone(); // Clone the path to use in the async block
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string();
let future = async move {
// Construct the score file path
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
// Check if score file already exists before running the simulation
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
return Ok::<f32, Error>(round_score);
}
// Check if the opposite round score has been determined
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
return Ok::<f32, Error>(1.0 - round_score);
}
if thread_rng().gen_range(0..100) < 4 {
let _output = Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.output()
.await
.expect("Failed to execute game");
} else {
let _output = Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.arg(&disable_unreal_rendering_arg)
.output()
.await
.expect("Failed to execute game");
}
// Read the score from the file
let round_score = read_score_from_file(&score_file, &nn_id)
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
Ok::<f32, Error>(round_score)
};
simulations.push(future);
}
score /= 5.0;
// Wait for all simulation rounds to complete
let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();
let score = results?.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32;
trace!("NN {:06}_fighter_nn_{} scored {}", self.id, i, score);
self.scores[self.generation as usize].insert(i as u64, score);
}
Ok(())
}
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
// Create the new generation folder
@ -234,3 +282,23 @@ impl GeneticNode for FighterNN {
}))
}
}
fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> {
let file = File::open(file_path)?;
let reader = BufReader::new(file);
for line in reader.lines() {
let line = line?;
if line.starts_with(nn_id) {
let parts: Vec<&str> = line.split(':').collect();
if parts.len() == 2 {
return parts[1].trim().parse::<f32>().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
}
}
}
Err(io::Error::new(
io::ErrorKind::NotFound,
"NN ID not found in scores file",
))
}

View file

@ -2,6 +2,7 @@ use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use async_trait::async_trait;
const POPULATION_SIZE: u64 = 5;
const POPULATION_REDUCTION_SIZE: u64 = 3;
@ -11,8 +12,9 @@ pub struct TestState {
pub population: Vec<i64>,
}
#[async_trait]
impl GeneticNode for TestState {
fn initialize(_context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
fn initialize(_context: GeneticNodeContext) -> Result<Box<Self>, Error> {
let mut population: Vec<i64> = vec![];
for _ in 0..POPULATION_SIZE {
@ -22,7 +24,7 @@ impl GeneticNode for TestState {
Ok(Box::new(TestState { population }))
}
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
let mut rng = thread_rng();
self.population = self
@ -34,7 +36,7 @@ impl GeneticNode for TestState {
Ok(())
}
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
let mut rng = thread_rng();
let mut v = self.population.clone();
@ -83,7 +85,7 @@ impl GeneticNode for TestState {
let mut result = TestState { population: v };
result.mutate(&GeneticNodeContext {
result.mutate(GeneticNodeContext {
id: id.clone(),
generation: 0,
max_generations: 0,
@ -101,7 +103,7 @@ mod tests {
#[test]
fn test_initialize() {
let state = TestState::initialize(
&GeneticNodeContext {
GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
@ -111,8 +113,8 @@ mod tests {
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
#[test]
fn test_simulate() {
#[tokio::test]
async fn test_simulate() {
let mut state = TestState {
population: vec![1, 1, 2, 3],
};
@ -120,31 +122,31 @@ mod tests {
let original_population = state.population.clone();
state.simulate(
&GeneticNodeContext {
GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
).await.unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
.all(|(&a, &b)| b >= a - 1 && b <= a + 2));
state.simulate(
&GeneticNodeContext {
GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
).await.unwrap();
state.simulate(
&GeneticNodeContext {
GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
}
).unwrap();
).await.unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
@ -158,7 +160,7 @@ mod tests {
};
state.mutate(
&GeneticNodeContext {
GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,

View file

@ -8,6 +8,7 @@ use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use uuid::Uuid;
use async_trait::async_trait;
/// An enum used to control the state of a [`GeneticNode`]
///
@ -24,6 +25,7 @@ pub enum GeneticState {
Finish,
}
#[derive(Clone)]
pub struct GeneticNodeContext {
pub generation: u64,
pub max_generations: u64,
@ -33,20 +35,21 @@ pub struct GeneticNodeContext {
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
///
/// [`Bracket`]: crate::bracket::Bracket
pub trait GeneticNode {
#[async_trait]
pub trait GeneticNode: Send {
/// Initializes a new instance of a [`GeneticState`].
///
/// # Examples
/// TODO
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error>;
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error>;
fn simulate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
///
/// # Examples
/// TODO
fn mutate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>;
}
@ -76,7 +79,7 @@ impl<T> Default for GeneticNodeWrapper<T> {
impl<T> GeneticNodeWrapper<T>
where
T: GeneticNode + Debug,
T: GeneticNode + Debug + Send,
{
pub fn new(max_generations: u64) -> Self {
GeneticNodeWrapper::<T> {
@ -115,7 +118,7 @@ where
self.state
}
pub fn process_node(&mut self) -> Result<GeneticState, Error> {
pub async fn process_node(&mut self) -> Result<GeneticState, Error> {
let context = GeneticNodeContext {
generation: self.generation,
max_generations: self.max_generations,
@ -124,11 +127,11 @@ where
match (self.state, &mut self.node) {
(GeneticState::Initialize, _) => {
self.node = Some(*T::initialize(&context)?);
self.node = Some(*T::initialize(context.clone())?);
self.state = GeneticState::Simulate;
}
(GeneticState::Simulate, Some(n)) => {
n.simulate(&context)
n.simulate(context.clone()).await
.with_context(|| format!("Error simulating node: {:?}", self))?;
self.state = if self.generation >= self.max_generations {
@ -138,7 +141,7 @@ where
};
}
(GeneticState::Mutate, Some(n)) => {
n.mutate(&context)
n.mutate(context.clone())
.with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1;
@ -157,23 +160,25 @@ mod tests {
use super::*;
use crate::error::Error;
use anyhow::anyhow;
use async_trait::async_trait;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
struct TestState {
pub score: f64,
}
#[async_trait]
impl GeneticNode for TestState {
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
Ok(())
}
fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
@ -270,16 +275,16 @@ mod tests {
Ok(())
}
#[test]
fn test_process_node() -> Result<(), Error> {
#[tokio::test]
async fn test_process_node() -> Result<(), Error> {
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
assert_eq!(genetic_node.state(), GeneticState::Initialize);
assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node()?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);
Ok(())
}

View file

@ -5,10 +5,11 @@ pub mod genetic_node;
use crate::{error::Error, tree::Tree};
use file_linked::{constants::data_format::DataFormat, FileLinked};
use futures::{future, future::BoxFuture};
use futures::future;
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::task::JoinHandle;
use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
time::Instant,
@ -65,15 +66,15 @@ pub struct GemlaConfig {
/// individuals.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<'a, T>
pub struct Gemla<T>
where
T: Serialize + Clone,
{
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
threads: HashMap<Uuid, BoxFuture<'a, Result<GeneticNodeWrapper<T>, Error>>>,
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
}
impl<'a, T: 'a> Gemla<'a, T>
impl<T: 'static> Gemla<T>
where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
{
@ -147,7 +148,9 @@ where
trace!("Adding node to process list {}", node.id());
self.threads
.insert(node.id(), Box::pin(Gemla::process_node(node)));
.insert(node.id(), tokio::spawn(async move {
Gemla::process_node(node).await
}));
} else {
trace!("No node found to process, joining threads");
@ -163,9 +166,10 @@ where
trace!("Joining threads for nodes {:?}", self.threads.keys());
let results = future::join_all(self.threads.values_mut()).await;
// Converting a list of results into a result wrapping the list
let reduced_results: Result<Vec<GeneticNodeWrapper<T>>, Error> =
results.into_iter().collect();
results.into_iter().flatten().collect();
self.threads.clear();
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
@ -323,7 +327,12 @@ where
let node_state_time = Instant::now();
let node_state = node.state();
node.process_node()?;
node.process_node().await?;
if node.state() == GeneticState::Simulate
{
node.process_node().await?;
}
trace!(
"{:?} completed in {:?} for {}",
@ -346,6 +355,8 @@ mod tests {
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::fs;
use async_trait::async_trait;
use tokio::runtime::Runtime;
use self::genetic_node::GeneticNodeContext;
@ -378,17 +389,18 @@ mod tests {
pub score: f64,
}
#[async_trait]
impl genetic_node::GeneticNode for TestState {
fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
Ok(())
}
fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
@ -401,66 +413,84 @@ mod tests {
}
}
#[test]
fn test_new() -> Result<(), Error> {
#[tokio::test]
async fn test_new() -> Result<(), Error> {
let path = PathBuf::from("test_new_non_existing");
CleanUp::new(&path).run(|p| {
assert!(!path.exists());
// Use `spawn_blocking` to run synchronous code that needs to call async code internally.
tokio::task::spawn_blocking(move || {
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
CleanUp::new(&path).run(move |p| {
rt.block_on(async {
assert!(!path.exists());
// Testing initial creation
let mut config = GemlaConfig {
generations_per_height: 1,
overwrite: true
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
// Testing initial creation
let mut config = GemlaConfig {
generations_per_height: 1,
overwrite: true,
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
smol::block_on(gemla.simulate(2))?;
assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
drop(gemla);
assert!(path.exists());
// Now we can use `.await` within the spawned blocking task.
gemla.simulate(2).await?;
assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
// Testing overwriting data
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
drop(gemla);
assert!(path.exists());
smol::block_on(gemla.simulate(2))?;
assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
// Testing overwriting data
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
drop(gemla);
assert!(path.exists());
gemla.simulate(2).await?;
assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
// Testing not-overwriting data
config.overwrite = false;
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
drop(gemla);
assert!(path.exists());
smol::block_on(gemla.simulate(2))?;
assert_eq!(gemla.tree_ref().unwrap().height(), 4);
// Testing not-overwriting data
config.overwrite = false;
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
drop(gemla);
assert!(path.exists());
gemla.simulate(2).await?;
assert_eq!(gemla.tree_ref().unwrap().height(), 4);
Ok(())
})
drop(gemla);
assert!(path.exists());
Ok(())
})
})
}).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
#[test]
fn test_simulate() -> Result<(), Error> {
#[tokio::test]
async fn test_simulate() -> Result<(), Error> {
let path = PathBuf::from("test_simulate");
CleanUp::new(&path).run(|p| {
// Testing initial creation
let config = GemlaConfig {
generations_per_height: 10,
overwrite: true
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
// Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
tokio::task::spawn_blocking(move || {
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
CleanUp::new(&path).run(move |p| {
rt.block_on(async {
// Testing initial creation
let config = GemlaConfig {
generations_per_height: 10,
overwrite: true,
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
smol::block_on(gemla.simulate(5))?;
let tree = gemla.tree_ref().unwrap();
assert_eq!(tree.height(), 5);
assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
// Now we can use `.await` within the spawned blocking task.
gemla.simulate(5).await?;
let tree = gemla.tree_ref().unwrap();
assert_eq!(tree.height(), 5);
assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
Ok(())
})
Ok(())
})
})
}).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
}