Adding user defined, shared context
This commit is contained in:
parent
d473970325
commit
95699bd47e
7 changed files with 212 additions and 148 deletions
|
@ -31,3 +31,4 @@ num_cpus = "1.16.0"
|
|||
easy-parallel = "3.3.1"
|
||||
fann = "0.1.8"
|
||||
async-trait = "0.1.78"
|
||||
async-recursion = "1.1.0"
|
||||
|
|
|
@ -47,7 +47,6 @@ fn main() -> Result<()> {
|
|||
GemlaConfig {
|
||||
generations_per_height: 5,
|
||||
overwrite: false,
|
||||
shared_semaphore_concurrency_limit: 50,
|
||||
},
|
||||
DataFormat::Json,
|
||||
))?;
|
||||
|
|
48
gemla/src/bin/fighter_nn/fighter_context.rs
Normal file
48
gemla/src/bin/fighter_nn/fighter_context.rs
Normal file
|
@ -0,0 +1,48 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 20;
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FighterContext {
|
||||
pub shared_semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl Default for FighterContext {
|
||||
fn default() -> Self {
|
||||
FighterContext {
|
||||
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Custom serialization to just output the concurrency limit.
|
||||
impl Serialize for FighterContext {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
// Assuming the semaphore's available permits represent the concurrency limit.
|
||||
// This part is tricky since Semaphore does not expose its initial permits.
|
||||
// You might need to store the concurrency limit as a separate field if this assumption doesn't hold.
|
||||
let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT;
|
||||
serializer.serialize_u64(concurrency_limit as u64)
|
||||
}
|
||||
}
|
||||
|
||||
// Custom deserialization to reconstruct the FighterContext from a concurrency limit.
|
||||
impl<'de> Deserialize<'de> for FighterContext {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let concurrency_limit = u64::deserialize(deserializer)?;
|
||||
Ok(FighterContext {
|
||||
shared_semaphore: Arc::new(Semaphore::new(concurrency_limit as usize)),
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
extern crate fann;
|
||||
|
||||
pub mod neural_network_utility;
|
||||
pub mod fighter_context;
|
||||
|
||||
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, path::{Path, PathBuf}, sync::Arc};
|
||||
use fann::{ActivationFunc, Fann};
|
||||
|
@ -63,8 +64,10 @@ pub struct FighterNN {
|
|||
|
||||
#[async_trait]
|
||||
impl GeneticNode for FighterNN {
|
||||
type Context = fighter_context::FighterContext;
|
||||
|
||||
// Check for the highest number of the folder name and increment it by 1
|
||||
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error> {
|
||||
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
|
||||
let base_path = PathBuf::from(BASE_DIR);
|
||||
|
||||
let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
|
||||
|
@ -127,100 +130,37 @@ impl GeneticNode for FighterNN {
|
|||
}))
|
||||
}
|
||||
|
||||
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> {
|
||||
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
debug!("Context: {:?}", context);
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
// For each nn in the current generation:
|
||||
for i in 0..self.population_size {
|
||||
let self_clone = self.clone();
|
||||
let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap());
|
||||
let semaphore_clone = context.gemla_context.shared_semaphore.clone();
|
||||
|
||||
let task = async move {
|
||||
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i));
|
||||
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64));
|
||||
let mut simulations = Vec::new();
|
||||
|
||||
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
|
||||
for _ in 0..SIMULATION_ROUNDS {
|
||||
let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
|
||||
let id = self_clone.id.clone();
|
||||
let folder = self_clone.folder.clone();
|
||||
let generation = self_clone.generation;
|
||||
let semaphore_clone = Arc::clone(&semaphore_clone);
|
||||
|
||||
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
|
||||
let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64));
|
||||
let nn_clone = nn.clone(); // Clone the path to use in the async block
|
||||
|
||||
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
|
||||
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
|
||||
let disable_unreal_rendering_arg = "-nullrhi".to_string();
|
||||
|
||||
|
||||
|
||||
let future = async move {
|
||||
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
|
||||
|
||||
// Construct the score file path
|
||||
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
|
||||
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
|
||||
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
|
||||
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
|
||||
|
||||
// Check if score file already exists before running the simulation
|
||||
if score_file.exists() {
|
||||
let round_score = read_score_from_file(&score_file, &nn_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
|
||||
|
||||
debug!("{} scored {}", nn_id, round_score);
|
||||
|
||||
return Ok::<f32, Error>(round_score);
|
||||
}
|
||||
|
||||
// Check if the opposite round score has been determined
|
||||
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
|
||||
if opposite_score_file.exists() {
|
||||
let round_score = read_score_from_file(&opposite_score_file, &nn_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
||||
|
||||
debug!("{} scored {}", nn_id, round_score);
|
||||
|
||||
return Ok::<f32, Error>(1.0 - round_score);
|
||||
}
|
||||
|
||||
// Run simulation until score file is generated
|
||||
while !score_file.exists() {
|
||||
let _output = if thread_rng().gen_range(0..100) < 1 {
|
||||
Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game")
|
||||
} else {
|
||||
Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.arg(&disable_unreal_rendering_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game")
|
||||
};
|
||||
}
|
||||
let score = run_1v1_simulation(&nn_clone, &random_nn).await?;
|
||||
|
||||
drop(permit);
|
||||
|
||||
// Read the score from the file
|
||||
if score_file.exists() {
|
||||
let round_score = read_score_from_file(&score_file, &nn_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
|
||||
|
||||
debug!("{} scored {}", nn_id, round_score);
|
||||
|
||||
Ok(round_score)
|
||||
} else {
|
||||
warn!("Score file not found: {:?}", score_file_name);
|
||||
Ok(0.0)
|
||||
}
|
||||
Ok(score)
|
||||
};
|
||||
|
||||
simulations.push(future);
|
||||
|
@ -259,7 +199,7 @@ impl GeneticNode for FighterNN {
|
|||
}
|
||||
|
||||
|
||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
|
||||
|
||||
// Create the new generation folder
|
||||
|
@ -322,7 +262,7 @@ impl GeneticNode for FighterNN {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> {
|
||||
async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, _: Self::Context) -> Result<Box<FighterNN>, Error> {
|
||||
let base_path = PathBuf::from(BASE_DIR);
|
||||
let folder = base_path.join(format!("fighter_nn_{:06}", id));
|
||||
|
||||
|
@ -365,6 +305,78 @@ impl GeneticNode for FighterNN {
|
|||
}
|
||||
}
|
||||
|
||||
impl FighterNN {
|
||||
pub fn get_individual_id(&self, nn_id: u64) -> String {
|
||||
format!("{:06}_fighter_nn_{}", self.id, nn_id)
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<f32, Error> {
|
||||
// Construct the score file path
|
||||
let base_folder = nn_path_1.parent().unwrap();
|
||||
let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap();
|
||||
let nn_2_id = nn_path_2.file_stem().unwrap().to_str().unwrap();
|
||||
let score_file = base_folder.join(format!("{}_vs_{}.txt", nn_1_id, nn_2_id));
|
||||
|
||||
// Check if score file already exists before running the simulation
|
||||
if score_file.exists() {
|
||||
let round_score = read_score_from_file(&score_file, &nn_1_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||
|
||||
trace!("{} scored {}", nn_1_id, round_score);
|
||||
|
||||
return Ok::<f32, Error>(round_score);
|
||||
}
|
||||
|
||||
// Check if the opposite round score has been determined
|
||||
let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id));
|
||||
if opposite_score_file.exists() {
|
||||
let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
||||
|
||||
trace!("{} scored {}", nn_1_id, round_score);
|
||||
|
||||
return Ok::<f32, Error>(1.0 - round_score);
|
||||
}
|
||||
|
||||
// Run simulation until score file is generated
|
||||
let config1_arg = format!("-NN1Config=\"{}\"", nn_path_1.to_str().unwrap());
|
||||
let config2_arg = format!("-NN2Config=\"{}\"", nn_path_2.to_str().unwrap());
|
||||
let disable_unreal_rendering_arg = "-nullrhi".to_string();
|
||||
|
||||
while !score_file.exists() {
|
||||
let _output = if thread_rng().gen_range(0..100) < 1 {
|
||||
Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game")
|
||||
} else {
|
||||
Command::new(GAME_EXECUTABLE_PATH)
|
||||
.arg(&config1_arg)
|
||||
.arg(&config2_arg)
|
||||
.arg(&disable_unreal_rendering_arg)
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute game")
|
||||
};
|
||||
}
|
||||
|
||||
// Read the score from the file
|
||||
if score_file.exists() {
|
||||
let round_score = read_score_from_file(&score_file, &nn_1_id).await
|
||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||
|
||||
trace!("{} scored {}", nn_1_id, round_score);
|
||||
|
||||
Ok(round_score)
|
||||
} else {
|
||||
warn!("Score file not found: {:?}", score_file);
|
||||
Ok(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> {
|
||||
let mut attempts = 0;
|
||||
|
||||
|
|
|
@ -14,7 +14,9 @@ pub struct TestState {
|
|||
|
||||
#[async_trait]
|
||||
impl GeneticNode for TestState {
|
||||
fn initialize(_context: GeneticNodeContext) -> Result<Box<Self>, Error> {
|
||||
type Context = ();
|
||||
|
||||
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
|
||||
let mut population: Vec<i64> = vec![];
|
||||
|
||||
for _ in 0..POPULATION_SIZE {
|
||||
|
@ -24,7 +26,7 @@ impl GeneticNode for TestState {
|
|||
Ok(Box::new(TestState { population }))
|
||||
}
|
||||
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
self.population = self
|
||||
|
@ -36,7 +38,7 @@ impl GeneticNode for TestState {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let mut v = self.population.clone();
|
||||
|
@ -74,7 +76,7 @@ impl GeneticNode for TestState {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn merge(left: &TestState, right: &TestState, id: &Uuid) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> {
|
||||
let mut v = left.population.clone();
|
||||
v.append(&mut right.population.clone());
|
||||
|
||||
|
@ -89,8 +91,8 @@ impl GeneticNode for TestState {
|
|||
id: id.clone(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
})?;
|
||||
gemla_context: gemla_context
|
||||
}).await?;
|
||||
|
||||
Ok(Box::new(result))
|
||||
}
|
||||
|
@ -101,16 +103,16 @@ mod tests {
|
|||
use super::*;
|
||||
use gemla::core::genetic_node::GeneticNode;
|
||||
|
||||
#[test]
|
||||
fn test_initialize() {
|
||||
#[tokio::test]
|
||||
async fn test_initialize() {
|
||||
let state = TestState::initialize(
|
||||
GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
gemla_context: (),
|
||||
}
|
||||
).unwrap();
|
||||
).await.unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
@ -128,7 +130,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
assert!(original_population
|
||||
|
@ -141,7 +143,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
state.simulate(
|
||||
|
@ -149,7 +151,7 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
gemla_context: (),
|
||||
}
|
||||
).await.unwrap();
|
||||
assert!(original_population
|
||||
|
@ -158,8 +160,8 @@ mod tests {
|
|||
.all(|(&a, &b)| b >= a - 3 && b <= a + 6))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mutate() {
|
||||
#[tokio::test]
|
||||
async fn test_mutate() {
|
||||
let mut state = TestState {
|
||||
population: vec![4, 3, 3],
|
||||
};
|
||||
|
@ -169,15 +171,15 @@ mod tests {
|
|||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
max_generations: 0,
|
||||
semaphore: None,
|
||||
gemla_context: (),
|
||||
}
|
||||
).unwrap();
|
||||
).await.unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge() {
|
||||
#[tokio::test]
|
||||
async fn test_merge() {
|
||||
let state1 = TestState {
|
||||
population: vec![1, 2, 4, 5],
|
||||
};
|
||||
|
@ -186,7 +188,7 @@ mod tests {
|
|||
population: vec![0, 1, 3, 7],
|
||||
};
|
||||
|
||||
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4()).unwrap();
|
||||
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap();
|
||||
|
||||
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
|
||||
assert!(merged_state.population.iter().any(|&x| x == 7));
|
||||
|
|
|
@ -5,9 +5,8 @@
|
|||
use crate::error::Error;
|
||||
|
||||
use anyhow::Context;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Semaphore;
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use uuid::Uuid;
|
||||
use async_trait::async_trait;
|
||||
|
||||
|
@ -27,33 +26,35 @@ pub enum GeneticState {
|
|||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GeneticNodeContext {
|
||||
pub struct GeneticNodeContext<S> {
|
||||
pub generation: u64,
|
||||
pub max_generations: u64,
|
||||
pub id: Uuid,
|
||||
pub semaphore: Option<Arc<Semaphore>>,
|
||||
pub gemla_context: S
|
||||
}
|
||||
|
||||
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||
///
|
||||
/// [`Bracket`]: crate::bracket::Bracket
|
||||
#[async_trait]
|
||||
pub trait GeneticNode: Send {
|
||||
pub trait GeneticNode : Send {
|
||||
type Context;
|
||||
|
||||
/// Initializes a new instance of a [`GeneticState`].
|
||||
///
|
||||
/// # Examples
|
||||
/// TODO
|
||||
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error>;
|
||||
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
|
||||
|
||||
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
|
||||
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
|
||||
|
||||
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
|
||||
///
|
||||
/// # Examples
|
||||
/// TODO
|
||||
fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
|
||||
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
|
||||
|
||||
fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>;
|
||||
async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
|
||||
}
|
||||
|
||||
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
||||
|
@ -82,6 +83,7 @@ impl<T> Default for GeneticNodeWrapper<T> {
|
|||
impl<T> GeneticNodeWrapper<T>
|
||||
where
|
||||
T: GeneticNode + Debug + Send,
|
||||
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||
{
|
||||
pub fn new(max_generations: u64) -> Self {
|
||||
GeneticNodeWrapper::<T> {
|
||||
|
@ -120,17 +122,17 @@ where
|
|||
self.state
|
||||
}
|
||||
|
||||
pub async fn process_node(&mut self, semaphore: Arc<Semaphore>) -> Result<GeneticState, Error> {
|
||||
pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> {
|
||||
let context = GeneticNodeContext {
|
||||
generation: self.generation,
|
||||
max_generations: self.max_generations,
|
||||
id: self.id,
|
||||
semaphore: Some(semaphore),
|
||||
gemla_context,
|
||||
};
|
||||
|
||||
match (self.state, &mut self.node) {
|
||||
(GeneticState::Initialize, _) => {
|
||||
self.node = Some(*T::initialize(context.clone())?);
|
||||
self.node = Some(*T::initialize(context.clone()).await?);
|
||||
self.state = GeneticState::Simulate;
|
||||
}
|
||||
(GeneticState::Simulate, Some(n)) => {
|
||||
|
@ -144,7 +146,7 @@ where
|
|||
};
|
||||
}
|
||||
(GeneticState::Mutate, Some(n)) => {
|
||||
n.mutate(context.clone())
|
||||
n.mutate(context.clone()).await
|
||||
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
||||
|
||||
self.generation += 1;
|
||||
|
@ -172,20 +174,22 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl GeneticNode for TestState {
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
type Context = ();
|
||||
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
self.score += 1.0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
|
||||
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
fn merge(_l: &TestState, _r: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
|
||||
Err(Error::Other(anyhow!("Unable to merge")))
|
||||
}
|
||||
}
|
||||
|
@ -281,14 +285,13 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_process_node() -> Result<(), Error> {
|
||||
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
|
||||
let semaphore = Arc::new(Semaphore::new(1));
|
||||
|
||||
assert_eq!(genetic_node.state(), GeneticState::Initialize);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -4,15 +4,15 @@
|
|||
pub mod genetic_node;
|
||||
|
||||
use crate::{error::Error, tree::Tree};
|
||||
use async_recursion::async_recursion;
|
||||
use file_linked::{constants::data_format::DataFormat, FileLinked};
|
||||
use futures::future;
|
||||
use futures::{executor::block_on, future};
|
||||
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
|
||||
use log::{info, trace, warn};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::Semaphore;
|
||||
use std::{
|
||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
|
||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, time::Instant
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -58,7 +58,6 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
|
|||
pub struct GemlaConfig {
|
||||
pub generations_per_height: u64,
|
||||
pub overwrite: bool,
|
||||
pub shared_semaphore_concurrency_limit: usize,
|
||||
}
|
||||
|
||||
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
||||
|
@ -69,16 +68,17 @@ pub struct GemlaConfig {
|
|||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||
pub struct Gemla<T>
|
||||
where
|
||||
T: Serialize + Clone,
|
||||
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
|
||||
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||
{
|
||||
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
|
||||
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
|
||||
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl<T: 'static> Gemla<T>
|
||||
where
|
||||
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
|
||||
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||
{
|
||||
pub fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
|
||||
match File::open(path) {
|
||||
|
@ -86,18 +86,16 @@ where
|
|||
// based on the configuration provided
|
||||
Ok(_) => Ok(Gemla {
|
||||
data: if config.overwrite {
|
||||
FileLinked::new((None, config), path, data_format)?
|
||||
FileLinked::new((None, config, T::Context::default()), path, data_format)?
|
||||
} else {
|
||||
FileLinked::from_file(path, data_format)?
|
||||
},
|
||||
threads: HashMap::new(),
|
||||
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
|
||||
}),
|
||||
// If the file doesn't exist we must create it
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
|
||||
data: FileLinked::new((None, config), path, data_format)?,
|
||||
data: FileLinked::new((None, config, T::Context::default()), path, data_format)?,
|
||||
threads: HashMap::new(),
|
||||
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
|
||||
}),
|
||||
Err(error) => Err(Error::IO(error)),
|
||||
}
|
||||
|
@ -117,7 +115,7 @@ where
|
|||
{
|
||||
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
||||
// in the tree and which nodes have not.
|
||||
self.data.mutate(|(d, c)| {
|
||||
self.data.mutate(|(d, c, _)| {
|
||||
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
|
||||
mem::swap(d, &mut tree);
|
||||
})?;
|
||||
|
@ -151,11 +149,11 @@ where
|
|||
{
|
||||
trace!("Adding node to process list {}", node.id());
|
||||
|
||||
let semaphore = self.semaphore.clone();
|
||||
let gemla_context = self.data.readonly().2.clone();
|
||||
|
||||
self.threads
|
||||
.insert(node.id(), tokio::spawn(async move {
|
||||
Gemla::process_node(node, semaphore).await
|
||||
Gemla::process_node(node, gemla_context).await
|
||||
}));
|
||||
} else {
|
||||
trace!("No node found to process, joining threads");
|
||||
|
@ -180,7 +178,7 @@ where
|
|||
|
||||
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
|
||||
reduced_results.and_then(|r| {
|
||||
self.data.mutate(|(d, _)| {
|
||||
self.data.mutate(|(d, _, context)| {
|
||||
if let Some(t) = d {
|
||||
let failed_nodes = Gemla::replace_nodes(t, r);
|
||||
// We receive a list of nodes that were unable to be found in the original tree
|
||||
|
@ -192,7 +190,7 @@ where
|
|||
}
|
||||
|
||||
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
|
||||
Gemla::merge_completed_nodes(t)
|
||||
block_on(Gemla::merge_completed_nodes(t, context.clone()))
|
||||
} else {
|
||||
warn!("Unable to replce nodes {:?} in empty tree", r);
|
||||
Ok(())
|
||||
|
@ -204,7 +202,8 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn merge_completed_nodes(tree: &mut SimulationTree<T>) -> Result<(), Error> {
|
||||
#[async_recursion]
|
||||
async fn merge_completed_nodes(tree: &mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> {
|
||||
if tree.val.state() == GeneticState::Initialize {
|
||||
match (&mut tree.left, &mut tree.right) {
|
||||
// If the current node has been initialized, and has children nodes that are completed, then we need
|
||||
|
@ -215,7 +214,7 @@ where
|
|||
{
|
||||
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
|
||||
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
|
||||
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id())?;
|
||||
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id(), gemla_context.clone()).await?;
|
||||
tree.val = GeneticNodeWrapper::from(
|
||||
*merged_node,
|
||||
tree.val.max_generations(),
|
||||
|
@ -224,8 +223,8 @@ where
|
|||
}
|
||||
}
|
||||
(Some(l), Some(r)) => {
|
||||
Gemla::merge_completed_nodes(l)?;
|
||||
Gemla::merge_completed_nodes(r)?;
|
||||
Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
|
||||
Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
|
||||
}
|
||||
// If there is only one child node that's completed then we want to copy it to the parent node
|
||||
(Some(l), None) if l.val.state() == GeneticState::Finish => {
|
||||
|
@ -239,7 +238,7 @@ where
|
|||
);
|
||||
}
|
||||
}
|
||||
(Some(l), None) => Gemla::merge_completed_nodes(l)?,
|
||||
(Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
|
||||
(None, Some(r)) if r.val.state() == GeneticState::Finish => {
|
||||
trace!("Copying node {}", r.val.id());
|
||||
|
||||
|
@ -251,7 +250,7 @@ where
|
|||
);
|
||||
}
|
||||
}
|
||||
(None, Some(r)) => Gemla::merge_completed_nodes(r)?,
|
||||
(None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
|
||||
(_, _) => (),
|
||||
}
|
||||
}
|
||||
|
@ -329,15 +328,15 @@ where
|
|||
tree.val.state() == GeneticState::Finish
|
||||
}
|
||||
|
||||
async fn process_node(mut node: GeneticNodeWrapper<T>, semaphore: Arc<Semaphore>) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
let node_state_time = Instant::now();
|
||||
let node_state = node.state();
|
||||
|
||||
node.process_node(semaphore.clone()).await?;
|
||||
node.process_node(gemla_context.clone()).await?;
|
||||
|
||||
if node.state() == GeneticState::Simulate
|
||||
{
|
||||
node.process_node(semaphore.clone()).await?;
|
||||
node.process_node(gemla_context.clone()).await?;
|
||||
}
|
||||
|
||||
trace!(
|
||||
|
@ -397,20 +396,22 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl genetic_node::GeneticNode for TestState {
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
type Context = ();
|
||||
|
||||
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
self.score += 1.0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
|
||||
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
fn merge(left: &TestState, right: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(if left.score > right.score {
|
||||
left.clone()
|
||||
} else {
|
||||
|
@ -433,7 +434,6 @@ mod tests {
|
|||
let mut config = GemlaConfig {
|
||||
generations_per_height: 1,
|
||||
overwrite: true,
|
||||
shared_semaphore_concurrency_limit: 1,
|
||||
};
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||
|
||||
|
@ -483,7 +483,6 @@ mod tests {
|
|||
let config = GemlaConfig {
|
||||
generations_per_height: 10,
|
||||
overwrite: true,
|
||||
shared_semaphore_concurrency_limit: 1,
|
||||
};
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue