Adding user defined, shared context

This commit is contained in:
vandomej 2024-04-04 18:40:17 -07:00
parent d473970325
commit 95699bd47e
7 changed files with 212 additions and 148 deletions

View file

@ -31,3 +31,4 @@ num_cpus = "1.16.0"
easy-parallel = "3.3.1" easy-parallel = "3.3.1"
fann = "0.1.8" fann = "0.1.8"
async-trait = "0.1.78" async-trait = "0.1.78"
async-recursion = "1.1.0"

View file

@ -47,7 +47,6 @@ fn main() -> Result<()> {
GemlaConfig { GemlaConfig {
generations_per_height: 5, generations_per_height: 5,
overwrite: false, overwrite: false,
shared_semaphore_concurrency_limit: 50,
}, },
DataFormat::Json, DataFormat::Json,
))?; ))?;

View file

@ -0,0 +1,48 @@
use std::sync::Arc;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tokio::sync::Semaphore;
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 20;
#[derive(Debug, Clone)]
pub struct FighterContext {
pub shared_semaphore: Arc<Semaphore>,
}
impl Default for FighterContext {
fn default() -> Self {
FighterContext {
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
}
}
}
// Custom serialization to just output the concurrency limit.
impl Serialize for FighterContext {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Assuming the semaphore's available permits represent the concurrency limit.
// This part is tricky since Semaphore does not expose its initial permits.
// You might need to store the concurrency limit as a separate field if this assumption doesn't hold.
let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT;
serializer.serialize_u64(concurrency_limit as u64)
}
}
// Custom deserialization to reconstruct the FighterContext from a concurrency limit.
impl<'de> Deserialize<'de> for FighterContext {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let concurrency_limit = u64::deserialize(deserializer)?;
Ok(FighterContext {
shared_semaphore: Arc::new(Semaphore::new(concurrency_limit as usize)),
})
}
}

View file

@ -1,6 +1,7 @@
extern crate fann; extern crate fann;
pub mod neural_network_utility; pub mod neural_network_utility;
pub mod fighter_context;
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, path::{Path, PathBuf}, sync::Arc}; use std::{fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, path::{Path, PathBuf}, sync::Arc};
use fann::{ActivationFunc, Fann}; use fann::{ActivationFunc, Fann};
@ -63,8 +64,10 @@ pub struct FighterNN {
#[async_trait] #[async_trait]
impl GeneticNode for FighterNN { impl GeneticNode for FighterNN {
type Context = fighter_context::FighterContext;
// Check for the highest number of the folder name and increment it by 1 // Check for the highest number of the folder name and increment it by 1
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error> { async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
let base_path = PathBuf::from(BASE_DIR); let base_path = PathBuf::from(BASE_DIR);
let folder = base_path.join(format!("fighter_nn_{:06}", context.id)); let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
@ -127,100 +130,37 @@ impl GeneticNode for FighterNN {
})) }))
} }
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> { async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
debug!("Context: {:?}", context); debug!("Context: {:?}", context);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
// For each nn in the current generation: // For each nn in the current generation:
for i in 0..self.population_size { for i in 0..self.population_size {
let self_clone = self.clone(); let self_clone = self.clone();
let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap()); let semaphore_clone = context.gemla_context.shared_semaphore.clone();
let task = async move { let task = async move {
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i)); let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64));
let mut simulations = Vec::new(); let mut simulations = Vec::new();
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
for _ in 0..SIMULATION_ROUNDS { for _ in 0..SIMULATION_ROUNDS {
let random_nn_index = thread_rng().gen_range(0..self_clone.population_size); let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
let id = self_clone.id.clone();
let folder = self_clone.folder.clone(); let folder = self_clone.folder.clone();
let generation = self_clone.generation; let generation = self_clone.generation;
let semaphore_clone = Arc::clone(&semaphore_clone); let semaphore_clone = Arc::clone(&semaphore_clone);
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index)); let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64));
let nn_clone = nn.clone(); // Clone the path to use in the async block let nn_clone = nn.clone(); // Clone the path to use in the async block
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string();
let future = async move { let future = async move {
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?; let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
// Construct the score file path let score = run_1v1_simulation(&nn_clone, &random_nn).await?;
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
// Check if score file already exists before running the simulation
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
debug!("{} scored {}", nn_id, round_score);
return Ok::<f32, Error>(round_score);
}
// Check if the opposite round score has been determined
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_id).await
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
debug!("{} scored {}", nn_id, round_score);
return Ok::<f32, Error>(1.0 - round_score);
}
// Run simulation until score file is generated
while !score_file.exists() {
let _output = if thread_rng().gen_range(0..100) < 1 {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.output()
.await
.expect("Failed to execute game")
} else {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.arg(&disable_unreal_rendering_arg)
.output()
.await
.expect("Failed to execute game")
};
}
drop(permit); drop(permit);
// Read the score from the file Ok(score)
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
debug!("{} scored {}", nn_id, round_score);
Ok(round_score)
} else {
warn!("Score file not found: {:?}", score_file_name);
Ok(0.0)
}
}; };
simulations.push(future); simulations.push(future);
@ -259,7 +199,7 @@ impl GeneticNode for FighterNN {
} }
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize; let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
// Create the new generation folder // Create the new generation folder
@ -322,7 +262,7 @@ impl GeneticNode for FighterNN {
Ok(()) Ok(())
} }
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> { async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, _: Self::Context) -> Result<Box<FighterNN>, Error> {
let base_path = PathBuf::from(BASE_DIR); let base_path = PathBuf::from(BASE_DIR);
let folder = base_path.join(format!("fighter_nn_{:06}", id)); let folder = base_path.join(format!("fighter_nn_{:06}", id));
@ -365,6 +305,78 @@ impl GeneticNode for FighterNN {
} }
} }
impl FighterNN {
pub fn get_individual_id(&self, nn_id: u64) -> String {
format!("{:06}_fighter_nn_{}", self.id, nn_id)
}
}
async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<f32, Error> {
// Construct the score file path
let base_folder = nn_path_1.parent().unwrap();
let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap();
let nn_2_id = nn_path_2.file_stem().unwrap().to_str().unwrap();
let score_file = base_folder.join(format!("{}_vs_{}.txt", nn_1_id, nn_2_id));
// Check if score file already exists before running the simulation
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
trace!("{} scored {}", nn_1_id, round_score);
return Ok::<f32, Error>(round_score);
}
// Check if the opposite round score has been determined
let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id));
if opposite_score_file.exists() {
let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
trace!("{} scored {}", nn_1_id, round_score);
return Ok::<f32, Error>(1.0 - round_score);
}
// Run simulation until score file is generated
let config1_arg = format!("-NN1Config=\"{}\"", nn_path_1.to_str().unwrap());
let config2_arg = format!("-NN2Config=\"{}\"", nn_path_2.to_str().unwrap());
let disable_unreal_rendering_arg = "-nullrhi".to_string();
while !score_file.exists() {
let _output = if thread_rng().gen_range(0..100) < 1 {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.output()
.await
.expect("Failed to execute game")
} else {
Command::new(GAME_EXECUTABLE_PATH)
.arg(&config1_arg)
.arg(&config2_arg)
.arg(&disable_unreal_rendering_arg)
.output()
.await
.expect("Failed to execute game")
};
}
// Read the score from the file
if score_file.exists() {
let round_score = read_score_from_file(&score_file, &nn_1_id).await
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
trace!("{} scored {}", nn_1_id, round_score);
Ok(round_score)
} else {
warn!("Score file not found: {:?}", score_file);
Ok(0.0)
}
}
async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> { async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> {
let mut attempts = 0; let mut attempts = 0;

View file

@ -14,7 +14,9 @@ pub struct TestState {
#[async_trait] #[async_trait]
impl GeneticNode for TestState { impl GeneticNode for TestState {
fn initialize(_context: GeneticNodeContext) -> Result<Box<Self>, Error> { type Context = ();
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
let mut population: Vec<i64> = vec![]; let mut population: Vec<i64> = vec![];
for _ in 0..POPULATION_SIZE { for _ in 0..POPULATION_SIZE {
@ -24,7 +26,7 @@ impl GeneticNode for TestState {
Ok(Box::new(TestState { population })) Ok(Box::new(TestState { population }))
} }
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let mut rng = thread_rng(); let mut rng = thread_rng();
self.population = self self.population = self
@ -36,7 +38,7 @@ impl GeneticNode for TestState {
Ok(()) Ok(())
} }
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let mut rng = thread_rng(); let mut rng = thread_rng();
let mut v = self.population.clone(); let mut v = self.population.clone();
@ -74,7 +76,7 @@ impl GeneticNode for TestState {
Ok(()) Ok(())
} }
fn merge(left: &TestState, right: &TestState, id: &Uuid) -> Result<Box<TestState>, Error> { async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> {
let mut v = left.population.clone(); let mut v = left.population.clone();
v.append(&mut right.population.clone()); v.append(&mut right.population.clone());
@ -89,8 +91,8 @@ impl GeneticNode for TestState {
id: id.clone(), id: id.clone(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None, gemla_context: gemla_context
})?; }).await?;
Ok(Box::new(result)) Ok(Box::new(result))
} }
@ -101,16 +103,16 @@ mod tests {
use super::*; use super::*;
use gemla::core::genetic_node::GeneticNode; use gemla::core::genetic_node::GeneticNode;
#[test] #[tokio::test]
fn test_initialize() { async fn test_initialize() {
let state = TestState::initialize( let state = TestState::initialize(
GeneticNodeContext { GeneticNodeContext {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None, gemla_context: (),
} }
).unwrap(); ).await.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize); assert_eq!(state.population.len(), POPULATION_SIZE as usize);
} }
@ -128,7 +130,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None, gemla_context: (),
} }
).await.unwrap(); ).await.unwrap();
assert!(original_population assert!(original_population
@ -141,7 +143,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None, gemla_context: (),
} }
).await.unwrap(); ).await.unwrap();
state.simulate( state.simulate(
@ -149,7 +151,7 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None, gemla_context: (),
} }
).await.unwrap(); ).await.unwrap();
assert!(original_population assert!(original_population
@ -158,8 +160,8 @@ mod tests {
.all(|(&a, &b)| b >= a - 3 && b <= a + 6)) .all(|(&a, &b)| b >= a - 3 && b <= a + 6))
} }
#[test] #[tokio::test]
fn test_mutate() { async fn test_mutate() {
let mut state = TestState { let mut state = TestState {
population: vec![4, 3, 3], population: vec![4, 3, 3],
}; };
@ -169,15 +171,15 @@ mod tests {
id: Uuid::new_v4(), id: Uuid::new_v4(),
generation: 0, generation: 0,
max_generations: 0, max_generations: 0,
semaphore: None, gemla_context: (),
} }
).unwrap(); ).await.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize); assert_eq!(state.population.len(), POPULATION_SIZE as usize);
} }
#[test] #[tokio::test]
fn test_merge() { async fn test_merge() {
let state1 = TestState { let state1 = TestState {
population: vec![1, 2, 4, 5], population: vec![1, 2, 4, 5],
}; };
@ -186,7 +188,7 @@ mod tests {
population: vec![0, 1, 3, 7], population: vec![0, 1, 3, 7],
}; };
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4()).unwrap(); let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap();
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
assert!(merged_state.population.iter().any(|&x| x == 7)); assert!(merged_state.population.iter().any(|&x| x == 7));

View file

@ -5,9 +5,8 @@
use crate::error::Error; use crate::error::Error;
use anyhow::Context; use anyhow::Context;
use serde::{Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::sync::Semaphore; use std::fmt::Debug;
use std::{fmt::Debug, sync::Arc};
use uuid::Uuid; use uuid::Uuid;
use async_trait::async_trait; use async_trait::async_trait;
@ -27,11 +26,11 @@ pub enum GeneticState {
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GeneticNodeContext { pub struct GeneticNodeContext<S> {
pub generation: u64, pub generation: u64,
pub max_generations: u64, pub max_generations: u64,
pub id: Uuid, pub id: Uuid,
pub semaphore: Option<Arc<Semaphore>>, pub gemla_context: S
} }
/// A trait used to interact with the internal state of nodes within the [`Bracket`] /// A trait used to interact with the internal state of nodes within the [`Bracket`]
@ -39,21 +38,23 @@ pub struct GeneticNodeContext {
/// [`Bracket`]: crate::bracket::Bracket /// [`Bracket`]: crate::bracket::Bracket
#[async_trait] #[async_trait]
pub trait GeneticNode : Send { pub trait GeneticNode : Send {
type Context;
/// Initializes a new instance of a [`GeneticState`]. /// Initializes a new instance of a [`GeneticState`].
/// ///
/// # Examples /// # Examples
/// TODO /// TODO
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error>; async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring. /// Mutates members in a population and/or crossbreeds them to produce new offspring.
/// ///
/// # Examples /// # Examples
/// TODO /// TODO
fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>; async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
} }
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
@ -82,6 +83,7 @@ impl<T> Default for GeneticNodeWrapper<T> {
impl<T> GeneticNodeWrapper<T> impl<T> GeneticNodeWrapper<T>
where where
T: GeneticNode + Debug + Send, T: GeneticNode + Debug + Send,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{ {
pub fn new(max_generations: u64) -> Self { pub fn new(max_generations: u64) -> Self {
GeneticNodeWrapper::<T> { GeneticNodeWrapper::<T> {
@ -120,17 +122,17 @@ where
self.state self.state
} }
pub async fn process_node(&mut self, semaphore: Arc<Semaphore>) -> Result<GeneticState, Error> { pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> {
let context = GeneticNodeContext { let context = GeneticNodeContext {
generation: self.generation, generation: self.generation,
max_generations: self.max_generations, max_generations: self.max_generations,
id: self.id, id: self.id,
semaphore: Some(semaphore), gemla_context,
}; };
match (self.state, &mut self.node) { match (self.state, &mut self.node) {
(GeneticState::Initialize, _) => { (GeneticState::Initialize, _) => {
self.node = Some(*T::initialize(context.clone())?); self.node = Some(*T::initialize(context.clone()).await?);
self.state = GeneticState::Simulate; self.state = GeneticState::Simulate;
} }
(GeneticState::Simulate, Some(n)) => { (GeneticState::Simulate, Some(n)) => {
@ -144,7 +146,7 @@ where
}; };
} }
(GeneticState::Mutate, Some(n)) => { (GeneticState::Mutate, Some(n)) => {
n.mutate(context.clone()) n.mutate(context.clone()).await
.with_context(|| format!("Error mutating node: {:?}", self))?; .with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1; self.generation += 1;
@ -172,20 +174,22 @@ mod tests {
#[async_trait] #[async_trait]
impl GeneticNode for TestState { impl GeneticNode for TestState {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
self.score += 1.0; self.score += 1.0;
Ok(()) Ok(())
} }
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> { async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 })) Ok(Box::new(TestState { score: 0.0 }))
} }
fn merge(_l: &TestState, _r: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> { async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
Err(Error::Other(anyhow!("Unable to merge"))) Err(Error::Other(anyhow!("Unable to merge")))
} }
} }
@ -281,14 +285,13 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_process_node() -> Result<(), Error> { async fn test_process_node() -> Result<(), Error> {
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2); let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
let semaphore = Arc::new(Semaphore::new(1));
assert_eq!(genetic_node.state(), GeneticState::Initialize); assert_eq!(genetic_node.state(), GeneticState::Initialize);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate); assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate); assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate); assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish); assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish); assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
Ok(()) Ok(())
} }

View file

@ -4,15 +4,15 @@
pub mod genetic_node; pub mod genetic_node;
use crate::{error::Error, tree::Tree}; use crate::{error::Error, tree::Tree};
use async_recursion::async_recursion;
use file_linked::{constants::data_format::DataFormat, FileLinked}; use file_linked::{constants::data_format::DataFormat, FileLinked};
use futures::future; use futures::{executor::block_on, future};
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState}; use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn}; use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::sync::Semaphore;
use std::{ use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, time::Instant
}; };
use uuid::Uuid; use uuid::Uuid;
@ -58,7 +58,6 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
pub struct GemlaConfig { pub struct GemlaConfig {
pub generations_per_height: u64, pub generations_per_height: u64,
pub overwrite: bool, pub overwrite: bool,
pub shared_semaphore_concurrency_limit: usize,
} }
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. /// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
@ -69,16 +68,17 @@ pub struct GemlaConfig {
/// [`GeneticNode`]: genetic_node::GeneticNode /// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<T> pub struct Gemla<T>
where where
T: Serialize + Clone, T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{ {
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>, pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>, threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
semaphore: Arc<Semaphore>,
} }
impl<T: 'static> Gemla<T> impl<T: 'static> Gemla<T>
where where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send, T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{ {
pub fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> { pub fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
match File::open(path) { match File::open(path) {
@ -86,18 +86,16 @@ where
// based on the configuration provided // based on the configuration provided
Ok(_) => Ok(Gemla { Ok(_) => Ok(Gemla {
data: if config.overwrite { data: if config.overwrite {
FileLinked::new((None, config), path, data_format)? FileLinked::new((None, config, T::Context::default()), path, data_format)?
} else { } else {
FileLinked::from_file(path, data_format)? FileLinked::from_file(path, data_format)?
}, },
threads: HashMap::new(), threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}), }),
// If the file doesn't exist we must create it // If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config), path, data_format)?, data: FileLinked::new((None, config, T::Context::default()), path, data_format)?,
threads: HashMap::new(), threads: HashMap::new(),
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
}), }),
Err(error) => Err(Error::IO(error)), Err(error) => Err(Error::IO(error)),
} }
@ -117,7 +115,7 @@ where
{ {
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// in the tree and which nodes have not. // in the tree and which nodes have not.
self.data.mutate(|(d, c)| { self.data.mutate(|(d, c, _)| {
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps); let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
mem::swap(d, &mut tree); mem::swap(d, &mut tree);
})?; })?;
@ -151,11 +149,11 @@ where
{ {
trace!("Adding node to process list {}", node.id()); trace!("Adding node to process list {}", node.id());
let semaphore = self.semaphore.clone(); let gemla_context = self.data.readonly().2.clone();
self.threads self.threads
.insert(node.id(), tokio::spawn(async move { .insert(node.id(), tokio::spawn(async move {
Gemla::process_node(node, semaphore).await Gemla::process_node(node, gemla_context).await
})); }));
} else { } else {
trace!("No node found to process, joining threads"); trace!("No node found to process, joining threads");
@ -180,7 +178,7 @@ where
// We need to retrieve the processed nodes from the resulting list and replace them in the original list // We need to retrieve the processed nodes from the resulting list and replace them in the original list
reduced_results.and_then(|r| { reduced_results.and_then(|r| {
self.data.mutate(|(d, _)| { self.data.mutate(|(d, _, context)| {
if let Some(t) = d { if let Some(t) = d {
let failed_nodes = Gemla::replace_nodes(t, r); let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree // We receive a list of nodes that were unable to be found in the original tree
@ -192,7 +190,7 @@ where
} }
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes // Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t) block_on(Gemla::merge_completed_nodes(t, context.clone()))
} else { } else {
warn!("Unable to replce nodes {:?} in empty tree", r); warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(()) Ok(())
@ -204,7 +202,8 @@ where
Ok(()) Ok(())
} }
fn merge_completed_nodes(tree: &mut SimulationTree<T>) -> Result<(), Error> { #[async_recursion]
async fn merge_completed_nodes(tree: &mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> {
if tree.val.state() == GeneticState::Initialize { if tree.val.state() == GeneticState::Initialize {
match (&mut tree.left, &mut tree.right) { match (&mut tree.left, &mut tree.right) {
// If the current node has been initialized, and has children nodes that are completed, then we need // If the current node has been initialized, and has children nodes that are completed, then we need
@ -215,7 +214,7 @@ where
{ {
info!("Merging nodes {} and {}", l.val.id(), r.val.id()); info!("Merging nodes {} and {}", l.val.id(), r.val.id());
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) { if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id())?; let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id(), gemla_context.clone()).await?;
tree.val = GeneticNodeWrapper::from( tree.val = GeneticNodeWrapper::from(
*merged_node, *merged_node,
tree.val.max_generations(), tree.val.max_generations(),
@ -224,8 +223,8 @@ where
} }
} }
(Some(l), Some(r)) => { (Some(l), Some(r)) => {
Gemla::merge_completed_nodes(l)?; Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
Gemla::merge_completed_nodes(r)?; Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
} }
// If there is only one child node that's completed then we want to copy it to the parent node // If there is only one child node that's completed then we want to copy it to the parent node
(Some(l), None) if l.val.state() == GeneticState::Finish => { (Some(l), None) if l.val.state() == GeneticState::Finish => {
@ -239,7 +238,7 @@ where
); );
} }
} }
(Some(l), None) => Gemla::merge_completed_nodes(l)?, (Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
(None, Some(r)) if r.val.state() == GeneticState::Finish => { (None, Some(r)) if r.val.state() == GeneticState::Finish => {
trace!("Copying node {}", r.val.id()); trace!("Copying node {}", r.val.id());
@ -251,7 +250,7 @@ where
); );
} }
} }
(None, Some(r)) => Gemla::merge_completed_nodes(r)?, (None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
(_, _) => (), (_, _) => (),
} }
} }
@ -329,15 +328,15 @@ where
tree.val.state() == GeneticState::Finish tree.val.state() == GeneticState::Finish
} }
async fn process_node(mut node: GeneticNodeWrapper<T>, semaphore: Arc<Semaphore>) -> Result<GeneticNodeWrapper<T>, Error> { async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> {
let node_state_time = Instant::now(); let node_state_time = Instant::now();
let node_state = node.state(); let node_state = node.state();
node.process_node(semaphore.clone()).await?; node.process_node(gemla_context.clone()).await?;
if node.state() == GeneticState::Simulate if node.state() == GeneticState::Simulate
{ {
node.process_node(semaphore.clone()).await?; node.process_node(gemla_context.clone()).await?;
} }
trace!( trace!(
@ -397,20 +396,22 @@ mod tests {
#[async_trait] #[async_trait]
impl genetic_node::GeneticNode for TestState { impl genetic_node::GeneticNode for TestState {
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { type Context = ();
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
self.score += 1.0; self.score += 1.0;
Ok(()) Ok(())
} }
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> { async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 })) Ok(Box::new(TestState { score: 0.0 }))
} }
fn merge(left: &TestState, right: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> { async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.score > right.score { Ok(Box::new(if left.score > right.score {
left.clone() left.clone()
} else { } else {
@ -433,7 +434,6 @@ mod tests {
let mut config = GemlaConfig { let mut config = GemlaConfig {
generations_per_height: 1, generations_per_height: 1,
overwrite: true, overwrite: true,
shared_semaphore_concurrency_limit: 1,
}; };
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?; let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
@ -483,7 +483,6 @@ mod tests {
let config = GemlaConfig { let config = GemlaConfig {
generations_per_height: 10, generations_per_height: 10,
overwrite: true, overwrite: true,
shared_semaphore_concurrency_limit: 1,
}; };
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?; let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;