Adding user defined, shared context

This commit is contained in:
vandomej 2024-04-04 18:40:17 -07:00
parent d473970325
commit 95699bd47e
7 changed files with 212 additions and 148 deletions

View file

@ -31,3 +31,4 @@ num_cpus = "1.16.0"
easy-parallel = "3.3.1"
fann = "0.1.8"
async-trait = "0.1.78"
async-recursion = "1.1.0"

View file

@ -47,7 +47,6 @@ fn main() -> Result<()> {
GemlaConfig {
generations_per_height: 5,
overwrite: false,
shared_semaphore_concurrency_limit: 50,
},
DataFormat::Json,
))?;

View file

@ -0,0 +1,48 @@
use std::sync::Arc;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tokio::sync::Semaphore;
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 20;
#[derive(Debug, Clone)]
pub struct FighterContext {
pub shared_semaphore: Arc<Semaphore>,
}
impl Default for FighterContext {
fn default() -> Self {
FighterContext {
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
}
}
}
// Custom serialization to just output the concurrency limit.
impl Serialize for FighterContext {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Assuming the semaphore's available permits represent the concurrency limit.
// This part is tricky since Semaphore does not expose its initial permits.
// You might need to store the concurrency limit as a separate field if this assumption doesn't hold.
let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT;
serializer.serialize_u64(concurrency_limit as u64)
}
}
// Custom deserialization to reconstruct the FighterContext from a concurrency limit.
impl<'de> Deserialize<'de> for FighterContext {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let concurrency_limit = u64::deserialize(deserializer)?;
Ok(FighterContext {
shared_semaphore: Arc::new(Semaphore::new(concurrency_limit as usize)),
})
}
}

View file

@ -1,6 +1,7 @@
extern crate fann;
pub mod neural_network_utility;
pub mod fighter_context;
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, path::{Path, PathBuf}, sync::Arc};
use fann::{ActivationFunc, Fann};
@ -63,8 +64,10 @@ pub struct FighterNN {
#[async_trait]
impl GeneticNode for FighterNN {
type Context = fighter_context::FighterContext;
// Check for the highest number of the folder name and increment it by 1
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error> {
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
let base_path = PathBuf::from(BASE_DIR);
let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
@ -127,100 +130,37 @@ impl GeneticNode for FighterNN {
}))
}
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> {
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
debug!("Context: {:?}", context);
let mut tasks = Vec::new();
// For each nn in the current generation:
for i in 0..self.population_size {
let self_clone = self.clone();
let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap());
let semaphore_clone = context.gemla_context.shared_semaphore.clone();
let task = async move {
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i));
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64));
let mut simulations = Vec::new();
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
for _ in 0..SIMULATION_ROUNDS {
let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
let id = self_clone.id.clone();
let folder = self_clone.folder.clone();
let generation = self_clone.generation;
let semaphore_clone = Arc::clone(&semaphore_clone);
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64));
let nn_clone = nn.clone(); // Clone the path to use in the async block
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string();
let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
// Construct the score file path
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
// Check if score file already exists before running the simulation
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
debug!("{} scored {}", nn_id, round_score);
return Ok::<f32, Error>(round_score);
}
// Check if the opposite round score has been determined
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_id).await
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
debug!("{} scored {}", nn_id, round_score);
return Ok::<f32, Error>(1.0 - round_score);
}
// Run simulation until score file is generated
while !score_file.exists() {
let _output = if thread_rng().gen_range(0..100) < 1 {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.output()
.await
.expect("Failed to execute game")
} else {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.arg(&disable_unreal_rendering_arg)
.output()
.await
.expect("Failed to execute game")
};
}
let score = run_1v1_simulation(&nn_clone, &random_nn).await?;
drop(permit);
// Read the score from the file
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
debug!("{} scored {}", nn_id, round_score);
Ok(round_score)
} else {
warn!("Score file not found: {:?}", score_file_name);
Ok(0.0)
}
Ok(score)
};
simulations.push(future);
@ -259,7 +199,7 @@ impl GeneticNode for FighterNN {
}
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
// Create the new generation folder
@ -322,7 +262,7 @@ impl GeneticNode for FighterNN {
Ok(())
}
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> {
async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, _: Self::Context) -> Result<Box<FighterNN>, Error> {
let base_path = PathBuf::from(BASE_DIR);
let folder = base_path.join(format!("fighter_nn_{:06}", id));
@ -365,6 +305,78 @@ impl GeneticNode for FighterNN {
}
}
impl FighterNN {
pub fn get_individual_id(&self, nn_id: u64) -> String {
format!("{:06}_fighter_nn_{}", self.id, nn_id)
}
}
async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<f32, Error> {
// Construct the score file path
let base_folder = nn_path_1.parent().unwrap();
let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap();
let nn_2_id = nn_path_2.file_stem().unwrap().to_str().unwrap();
let score_file = base_folder.join(format!("{}_vs_{}.txt", nn_1_id, nn_2_id));
// Check if score file already exists before running the simulation
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
trace!("{} scored {}", nn_1_id, round_score);
return Ok::<f32, Error>(round_score);
}
// Check if the opposite round score has been determined
let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
trace!("{} scored {}", nn_1_id, round_score);
return Ok::<f32, Error>(1.0 - round_score);
}
// Run simulation until score file is generated
let config1_arg = format!("-NN1Config=\"{}\"", nn_path_1.to_str().unwrap());
let config2_arg = format!("-NN2Config=\"{}\"", nn_path_2.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string();
while !score_file.exists() {
let _output = if thread_rng().gen_range(0..100) < 1 {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.output()
.await
.expect("Failed to execute game")
} else {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.arg(&disable_unreal_rendering_arg)
.output()
.await
.expect("Failed to execute game")
};
}
// Read the score from the file
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
trace!("{} scored {}", nn_1_id, round_score);
Ok(round_score)
} else {
warn!("Score file not found: {:?}", score_file);
Ok(0.0)
}
}
async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> {
let mut attempts = 0;

View file

@ -14,7 +14,9 @@ pub struct TestState {
#[async_trait]
impl GeneticNode for TestState {
fn initialize(_context: GeneticNodeContext) -> Result<Box<Self>, Error> {
type Context = ();
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
let mut population: Vec<i64> = vec![];
for _ in 0..POPULATION_SIZE {
@ -24,7 +26,7 @@ impl GeneticNode for TestState {
Ok(Box::new(TestState { population }))
}
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let mut rng = thread_rng();
self.population = self
@ -36,7 +38,7 @@ impl GeneticNode for TestState {
Ok(())
}
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let mut rng = thread_rng();
let mut v = self.population.clone();
@ -74,7 +76,7 @@ impl GeneticNode for TestState {
Ok(())
}
fn merge(left: &TestState, right: &TestState, id: &Uuid) -> Result<Box<TestState>, Error> {
async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> {
let mut v = left.population.clone();
v.append(&mut right.population.clone());
@ -89,8 +91,8 @@ impl GeneticNode for TestState {
id: id.clone(),
generation: 0,
max_generations: 0,
semaphore: None,
})?;
gemla_context: gemla_context
}).await?;
Ok(Box::new(result))
}
@ -101,16 +103,16 @@ mod tests {
use super::*;
use gemla::core::genetic_node::GeneticNode;
#[test]
fn test_initialize() {
#[tokio::test]
async fn test_initialize() {
let state = TestState::initialize(
GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
gemla_context: (),
}
).unwrap();
).await.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
@ -128,7 +130,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
gemla_context: (),
}
).await.unwrap();
assert!(original_population
@ -141,7 +143,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
gemla_context: (),
}
).await.unwrap();
state.simulate(
@ -149,7 +151,7 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
gemla_context: (),
}
).await.unwrap();
assert!(original_population
@ -158,8 +160,8 @@ mod tests {
.all(|(&a, &b)| b >= a - 3 && b <= a + 6))
}
#[test]
fn test_mutate() {
#[tokio::test]
async fn test_mutate() {
let mut state = TestState {
population: vec![4, 3, 3],
};
@ -169,15 +171,15 @@ mod tests {
id: Uuid::new_v4(),
generation: 0,
max_generations: 0,
semaphore: None,
gemla_context: (),
}
).unwrap();
).await.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
#[test]
fn test_merge() {
#[tokio::test]
async fn test_merge() {
let state1 = TestState {
population: vec![1, 2, 4, 5],
};
@ -186,7 +188,7 @@ mod tests {
population: vec![0, 1, 3, 7],
};
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4()).unwrap();
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap();
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
assert!(merged_state.population.iter().any(|&x| x == 7));

View file

@ -5,9 +5,8 @@
use crate::error::Error;
use anyhow::Context;
use serde::{Deserialize, Serialize};
use tokio::sync::Semaphore;
use std::{fmt::Debug, sync::Arc};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug;
use uuid::Uuid;
use async_trait::async_trait;
@ -27,33 +26,35 @@ pub enum GeneticState {
}
#[derive(Clone, Debug)]
pub struct GeneticNodeContext {
pub struct GeneticNodeContext<S> {
pub generation: u64,
pub max_generations: u64,
pub id: Uuid,
pub semaphore: Option<Arc<Semaphore>>,
pub gemla_context: S
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
///
/// [`Bracket`]: crate::bracket::Bracket
#[async_trait]
pub trait GeneticNode: Send {
pub trait GeneticNode : Send {
type Context;
/// Initializes a new instance of a [`GeneticState`].
///
/// # Examples
/// TODO
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error>;
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
///
/// # Examples
/// TODO
fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>;
async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
}
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
@ -82,6 +83,7 @@ impl<T> Default for GeneticNodeWrapper<T> {
impl<T> GeneticNodeWrapper<T>
where
T: GeneticNode + Debug + Send,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub fn new(max_generations: u64) -> Self {
GeneticNodeWrapper::<T> {
@ -120,17 +122,17 @@ where
self.state
}
pub async fn process_node(&mut self, semaphore: Arc<Semaphore>) -> Result<GeneticState, Error> {
pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> {
let context = GeneticNodeContext {
generation: self.generation,
max_generations: self.max_generations,
id: self.id,
semaphore: Some(semaphore),
gemla_context,
};
match (self.state, &mut self.node) {
(GeneticState::Initialize, _) => {
self.node = Some(*T::initialize(context.clone())?);
self.node = Some(*T::initialize(context.clone()).await?);
self.state = GeneticState::Simulate;
}
(GeneticState::Simulate, Some(n)) => {
@ -144,7 +146,7 @@ where
};
}
(GeneticState::Mutate, Some(n)) => {
n.mutate(context.clone())
n.mutate(context.clone()).await
.with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1;
@ -172,20 +174,22 @@ mod tests {
#[async_trait]
impl GeneticNode for TestState {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
Ok(())
}
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
fn merge(_l: &TestState, _r: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
Err(Error::Other(anyhow!("Unable to merge")))
}
}
@ -281,14 +285,13 @@ mod tests {
#[tokio::test]
async fn test_process_node() -> Result<(), Error> {
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
let semaphore = Arc::new(Semaphore::new(1));
assert_eq!(genetic_node.state(), GeneticState::Initialize);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
Ok(())
}

View file

@ -4,15 +4,15 @@
pub mod genetic_node;
use crate::{error::Error, tree::Tree};
use async_recursion::async_recursion;
use file_linked::{constants::data_format::DataFormat, FileLinked};
use futures::future;
use futures::{executor::block_on, future};
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::task::JoinHandle;
use tokio::sync::Semaphore;
use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, time::Instant
};
use uuid::Uuid;
@ -58,7 +58,6 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
pub struct GemlaConfig {
pub generations_per_height: u64,
pub overwrite: bool,
pub shared_semaphore_concurrency_limit: usize,
}
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
@ -69,16 +68,17 @@ pub struct GemlaConfig {
/// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<T>
where
T: Serialize + Clone,
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
semaphore: Arc<Semaphore>,
}
impl<T: 'static> Gemla<T>
where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
match File::open(path) {
@ -86,18 +86,16 @@ where
// based on the configuration provided
Ok(_) => Ok(Gemla {
data: if config.overwrite {
FileLinked::new((None, config), path, data_format)?
FileLinked::new((None, config, T::Context::default()), path, data_format)?
} else {
FileLinked::from_file(path, data_format)?
},
threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}),
// If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config), path, data_format)?,
data: FileLinked::new((None, config, T::Context::default()), path, data_format)?,
threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}),
Err(error) => Err(Error::IO(error)),
}
@ -117,7 +115,7 @@ where
{
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// in the tree and which nodes have not.
self.data.mutate(|(d, c)| {
self.data.mutate(|(d, c, _)| {
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
mem::swap(d, &mut tree);
})?;
@ -151,11 +149,11 @@ where
{
trace!("Adding node to process list {}", node.id());
let semaphore = self.semaphore.clone();
let gemla_context = self.data.readonly().2.clone();
self.threads
.insert(node.id(), tokio::spawn(async move {
Gemla::process_node(node, semaphore).await
Gemla::process_node(node, gemla_context).await
}));
} else {
trace!("No node found to process, joining threads");
@ -180,7 +178,7 @@ where
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
reduced_results.and_then(|r| {
self.data.mutate(|(d, _)| {
self.data.mutate(|(d, _, context)| {
if let Some(t) = d {
let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree
@ -192,7 +190,7 @@ where
}
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t)
block_on(Gemla::merge_completed_nodes(t, context.clone()))
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
@ -204,7 +202,8 @@ where
Ok(())
}
fn merge_completed_nodes(tree: &mut SimulationTree<T>) -> Result<(), Error> {
#[async_recursion]
async fn merge_completed_nodes(tree: &mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> {
if tree.val.state() == GeneticState::Initialize {
match (&mut tree.left, &mut tree.right) {
// If the current node has been initialized, and has children nodes that are completed, then we need
@ -215,7 +214,7 @@ where
{
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id())?;
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id(), gemla_context.clone()).await?;
tree.val = GeneticNodeWrapper::from(
*merged_node,
tree.val.max_generations(),
@ -224,8 +223,8 @@ where
}
}
(Some(l), Some(r)) => {
Gemla::merge_completed_nodes(l)?;
Gemla::merge_completed_nodes(r)?;
Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
}
// If there is only one child node that's completed then we want to copy it to the parent node
(Some(l), None) if l.val.state() == GeneticState::Finish => {
@ -239,7 +238,7 @@ where
);
}
}
(Some(l), None) => Gemla::merge_completed_nodes(l)?,
(Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
(None, Some(r)) if r.val.state() == GeneticState::Finish => {
trace!("Copying node {}", r.val.id());
@ -251,7 +250,7 @@ where
);
}
}
(None, Some(r)) => Gemla::merge_completed_nodes(r)?,
(None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
(_, _) => (),
}
}
@ -329,15 +328,15 @@ where
tree.val.state() == GeneticState::Finish
}
async fn process_node(mut node: GeneticNodeWrapper<T>, semaphore: Arc<Semaphore>) -> Result<GeneticNodeWrapper<T>, Error> {
async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> {
let node_state_time = Instant::now();
let node_state = node.state();
node.process_node(semaphore.clone()).await?;
node.process_node(gemla_context.clone()).await?;
if node.state() == GeneticState::Simulate
{
node.process_node(semaphore.clone()).await?;
node.process_node(gemla_context.clone()).await?;
}
trace!(
@ -397,20 +396,22 @@ mod tests {
#[async_trait]
impl genetic_node::GeneticNode for TestState {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
self.score += 1.0;
Ok(())
}
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
Ok(())
}
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
fn merge(left: &TestState, right: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.score > right.score {
left.clone()
} else {
@ -433,7 +434,6 @@ mod tests {
let mut config = GemlaConfig {
generations_per_height: 1,
overwrite: true,
shared_semaphore_concurrency_limit: 1,
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
@ -483,7 +483,6 @@ mod tests {
let config = GemlaConfig {
generations_per_height: 10,
overwrite: true,
shared_semaphore_concurrency_limit: 1,
};
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;