dootcamp #1
7 changed files with 212 additions and 148 deletions
|
@ -31,3 +31,4 @@ num_cpus = "1.16.0"
|
||||||
easy-parallel = "3.3.1"
|
easy-parallel = "3.3.1"
|
||||||
fann = "0.1.8"
|
fann = "0.1.8"
|
||||||
async-trait = "0.1.78"
|
async-trait = "0.1.78"
|
||||||
|
async-recursion = "1.1.0"
|
||||||
|
|
|
@ -47,7 +47,6 @@ fn main() -> Result<()> {
|
||||||
GemlaConfig {
|
GemlaConfig {
|
||||||
generations_per_height: 5,
|
generations_per_height: 5,
|
||||||
overwrite: false,
|
overwrite: false,
|
||||||
shared_semaphore_concurrency_limit: 50,
|
|
||||||
},
|
},
|
||||||
DataFormat::Json,
|
DataFormat::Json,
|
||||||
))?;
|
))?;
|
||||||
|
|
48
gemla/src/bin/fighter_nn/fighter_context.rs
Normal file
48
gemla/src/bin/fighter_nn/fighter_context.rs
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
|
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 20;
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct FighterContext {
|
||||||
|
pub shared_semaphore: Arc<Semaphore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FighterContext {
|
||||||
|
fn default() -> Self {
|
||||||
|
FighterContext {
|
||||||
|
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Custom serialization to just output the concurrency limit.
|
||||||
|
impl Serialize for FighterContext {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
// Assuming the semaphore's available permits represent the concurrency limit.
|
||||||
|
// This part is tricky since Semaphore does not expose its initial permits.
|
||||||
|
// You might need to store the concurrency limit as a separate field if this assumption doesn't hold.
|
||||||
|
let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT;
|
||||||
|
serializer.serialize_u64(concurrency_limit as u64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom deserialization to reconstruct the FighterContext from a concurrency limit.
|
||||||
|
impl<'de> Deserialize<'de> for FighterContext {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let concurrency_limit = u64::deserialize(deserializer)?;
|
||||||
|
Ok(FighterContext {
|
||||||
|
shared_semaphore: Arc::new(Semaphore::new(concurrency_limit as usize)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
extern crate fann;
|
extern crate fann;
|
||||||
|
|
||||||
pub mod neural_network_utility;
|
pub mod neural_network_utility;
|
||||||
|
pub mod fighter_context;
|
||||||
|
|
||||||
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, path::{Path, PathBuf}, sync::Arc};
|
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, ops::Range, path::{Path, PathBuf}, sync::Arc};
|
||||||
use fann::{ActivationFunc, Fann};
|
use fann::{ActivationFunc, Fann};
|
||||||
|
@ -63,8 +64,10 @@ pub struct FighterNN {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl GeneticNode for FighterNN {
|
impl GeneticNode for FighterNN {
|
||||||
|
type Context = fighter_context::FighterContext;
|
||||||
|
|
||||||
// Check for the highest number of the folder name and increment it by 1
|
// Check for the highest number of the folder name and increment it by 1
|
||||||
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error> {
|
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
|
||||||
let base_path = PathBuf::from(BASE_DIR);
|
let base_path = PathBuf::from(BASE_DIR);
|
||||||
|
|
||||||
let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
|
let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
|
||||||
|
@ -127,100 +130,37 @@ impl GeneticNode for FighterNN {
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> {
|
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
debug!("Context: {:?}", context);
|
debug!("Context: {:?}", context);
|
||||||
let mut tasks = Vec::new();
|
let mut tasks = Vec::new();
|
||||||
|
|
||||||
// For each nn in the current generation:
|
// For each nn in the current generation:
|
||||||
for i in 0..self.population_size {
|
for i in 0..self.population_size {
|
||||||
let self_clone = self.clone();
|
let self_clone = self.clone();
|
||||||
let semaphore_clone = Arc::clone(context.semaphore.as_ref().unwrap());
|
let semaphore_clone = context.gemla_context.shared_semaphore.clone();
|
||||||
|
|
||||||
let task = async move {
|
let task = async move {
|
||||||
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(format!("{:06}_fighter_nn_{}.net", self_clone.id, i));
|
let nn = self_clone.folder.join(format!("{}", self_clone.generation)).join(self_clone.get_individual_id(i as u64));
|
||||||
let mut simulations = Vec::new();
|
let mut simulations = Vec::new();
|
||||||
|
|
||||||
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
|
// Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
|
||||||
for _ in 0..SIMULATION_ROUNDS {
|
for _ in 0..SIMULATION_ROUNDS {
|
||||||
let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
|
let random_nn_index = thread_rng().gen_range(0..self_clone.population_size);
|
||||||
let id = self_clone.id.clone();
|
|
||||||
let folder = self_clone.folder.clone();
|
let folder = self_clone.folder.clone();
|
||||||
let generation = self_clone.generation;
|
let generation = self_clone.generation;
|
||||||
let semaphore_clone = Arc::clone(&semaphore_clone);
|
let semaphore_clone = Arc::clone(&semaphore_clone);
|
||||||
|
|
||||||
let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
|
let random_nn = folder.join(format!("{}", generation)).join(self_clone.get_individual_id(random_nn_index as u64));
|
||||||
let nn_clone = nn.clone(); // Clone the path to use in the async block
|
let nn_clone = nn.clone(); // Clone the path to use in the async block
|
||||||
|
|
||||||
let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
|
|
||||||
let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
|
|
||||||
let disable_unreal_rendering_arg = "-nullrhi".to_string();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
let future = async move {
|
let future = async move {
|
||||||
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
|
let permit = semaphore_clone.acquire_owned().await.with_context(|| "Failed to acquire semaphore permit")?;
|
||||||
|
|
||||||
// Construct the score file path
|
let score = run_1v1_simulation(&nn_clone, &random_nn).await?;
|
||||||
let nn_id = format!("{:06}_fighter_nn_{}", id, i);
|
|
||||||
let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
|
|
||||||
let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
|
|
||||||
let score_file = folder.join(format!("{}", generation)).join(&score_file_name);
|
|
||||||
|
|
||||||
// Check if score file already exists before running the simulation
|
|
||||||
if score_file.exists() {
|
|
||||||
let round_score = read_score_from_file(&score_file, &nn_id).await
|
|
||||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
|
|
||||||
|
|
||||||
debug!("{} scored {}", nn_id, round_score);
|
|
||||||
|
|
||||||
return Ok::<f32, Error>(round_score);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the opposite round score has been determined
|
|
||||||
let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
|
|
||||||
if opposite_score_file.exists() {
|
|
||||||
let round_score = read_score_from_file(&opposite_score_file, &nn_id).await
|
|
||||||
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
|
||||||
|
|
||||||
debug!("{} scored {}", nn_id, round_score);
|
|
||||||
|
|
||||||
return Ok::<f32, Error>(1.0 - round_score);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run simulation until score file is generated
|
|
||||||
while !score_file.exists() {
|
|
||||||
let _output = if thread_rng().gen_range(0..100) < 1 {
|
|
||||||
Command::new(GAME_EXECUTABLE_PATH)
|
|
||||||
.arg(&config1_arg)
|
|
||||||
.arg(&config2_arg)
|
|
||||||
.output()
|
|
||||||
.await
|
|
||||||
.expect("Failed to execute game")
|
|
||||||
} else {
|
|
||||||
Command::new(GAME_EXECUTABLE_PATH)
|
|
||||||
.arg(&config1_arg)
|
|
||||||
.arg(&config2_arg)
|
|
||||||
.arg(&disable_unreal_rendering_arg)
|
|
||||||
.output()
|
|
||||||
.await
|
|
||||||
.expect("Failed to execute game")
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
drop(permit);
|
drop(permit);
|
||||||
|
|
||||||
// Read the score from the file
|
Ok(score)
|
||||||
if score_file.exists() {
|
|
||||||
let round_score = read_score_from_file(&score_file, &nn_id).await
|
|
||||||
.with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
|
|
||||||
|
|
||||||
debug!("{} scored {}", nn_id, round_score);
|
|
||||||
|
|
||||||
Ok(round_score)
|
|
||||||
} else {
|
|
||||||
warn!("Score file not found: {:?}", score_file_name);
|
|
||||||
Ok(0.0)
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
simulations.push(future);
|
simulations.push(future);
|
||||||
|
@ -259,7 +199,7 @@ impl GeneticNode for FighterNN {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
|
let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;
|
||||||
|
|
||||||
// Create the new generation folder
|
// Create the new generation folder
|
||||||
|
@ -322,7 +262,7 @@ impl GeneticNode for FighterNN {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> {
|
async fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid, _: Self::Context) -> Result<Box<FighterNN>, Error> {
|
||||||
let base_path = PathBuf::from(BASE_DIR);
|
let base_path = PathBuf::from(BASE_DIR);
|
||||||
let folder = base_path.join(format!("fighter_nn_{:06}", id));
|
let folder = base_path.join(format!("fighter_nn_{:06}", id));
|
||||||
|
|
||||||
|
@ -365,6 +305,78 @@ impl GeneticNode for FighterNN {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FighterNN {
|
||||||
|
pub fn get_individual_id(&self, nn_id: u64) -> String {
|
||||||
|
format!("{:06}_fighter_nn_{}", self.id, nn_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_1v1_simulation(nn_path_1: &PathBuf, nn_path_2: &PathBuf) -> Result<f32, Error> {
|
||||||
|
// Construct the score file path
|
||||||
|
let base_folder = nn_path_1.parent().unwrap();
|
||||||
|
let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap();
|
||||||
|
let nn_2_id = nn_path_2.file_stem().unwrap().to_str().unwrap();
|
||||||
|
let score_file = base_folder.join(format!("{}_vs_{}.txt", nn_1_id, nn_2_id));
|
||||||
|
|
||||||
|
// Check if score file already exists before running the simulation
|
||||||
|
if score_file.exists() {
|
||||||
|
let round_score = read_score_from_file(&score_file, &nn_1_id).await
|
||||||
|
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||||
|
|
||||||
|
trace!("{} scored {}", nn_1_id, round_score);
|
||||||
|
|
||||||
|
return Ok::<f32, Error>(round_score);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the opposite round score has been determined
|
||||||
|
let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id));
|
||||||
|
if opposite_score_file.exists() {
|
||||||
|
let round_score = read_score_from_file(&opposite_score_file, &nn_1_id).await
|
||||||
|
.with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
|
||||||
|
|
||||||
|
trace!("{} scored {}", nn_1_id, round_score);
|
||||||
|
|
||||||
|
return Ok::<f32, Error>(1.0 - round_score);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run simulation until score file is generated
|
||||||
|
let config1_arg = format!("-NN1Config=\"{}\"", nn_path_1.to_str().unwrap());
|
||||||
|
let config2_arg = format!("-NN2Config=\"{}\"", nn_path_2.to_str().unwrap());
|
||||||
|
let disable_unreal_rendering_arg = "-nullrhi".to_string();
|
||||||
|
|
||||||
|
while !score_file.exists() {
|
||||||
|
let _output = if thread_rng().gen_range(0..100) < 1 {
|
||||||
|
Command::new(GAME_EXECUTABLE_PATH)
|
||||||
|
.arg(&config1_arg)
|
||||||
|
.arg(&config2_arg)
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
.expect("Failed to execute game")
|
||||||
|
} else {
|
||||||
|
Command::new(GAME_EXECUTABLE_PATH)
|
||||||
|
.arg(&config1_arg)
|
||||||
|
.arg(&config2_arg)
|
||||||
|
.arg(&disable_unreal_rendering_arg)
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
.expect("Failed to execute game")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the score from the file
|
||||||
|
if score_file.exists() {
|
||||||
|
let round_score = read_score_from_file(&score_file, &nn_1_id).await
|
||||||
|
.with_context(|| format!("Failed to read score from file: {:?}", score_file))?;
|
||||||
|
|
||||||
|
trace!("{} scored {}", nn_1_id, round_score);
|
||||||
|
|
||||||
|
Ok(round_score)
|
||||||
|
} else {
|
||||||
|
warn!("Score file not found: {:?}", score_file);
|
||||||
|
Ok(0.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> {
|
async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> {
|
||||||
let mut attempts = 0;
|
let mut attempts = 0;
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,9 @@ pub struct TestState {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl GeneticNode for TestState {
|
impl GeneticNode for TestState {
|
||||||
fn initialize(_context: GeneticNodeContext) -> Result<Box<Self>, Error> {
|
type Context = ();
|
||||||
|
|
||||||
|
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
|
||||||
let mut population: Vec<i64> = vec![];
|
let mut population: Vec<i64> = vec![];
|
||||||
|
|
||||||
for _ in 0..POPULATION_SIZE {
|
for _ in 0..POPULATION_SIZE {
|
||||||
|
@ -24,7 +26,7 @@ impl GeneticNode for TestState {
|
||||||
Ok(Box::new(TestState { population }))
|
Ok(Box::new(TestState { population }))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
||||||
self.population = self
|
self.population = self
|
||||||
|
@ -36,7 +38,7 @@ impl GeneticNode for TestState {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
||||||
let mut v = self.population.clone();
|
let mut v = self.population.clone();
|
||||||
|
@ -74,7 +76,7 @@ impl GeneticNode for TestState {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge(left: &TestState, right: &TestState, id: &Uuid) -> Result<Box<TestState>, Error> {
|
async fn merge(left: &TestState, right: &TestState, id: &Uuid, gemla_context: Self::Context) -> Result<Box<TestState>, Error> {
|
||||||
let mut v = left.population.clone();
|
let mut v = left.population.clone();
|
||||||
v.append(&mut right.population.clone());
|
v.append(&mut right.population.clone());
|
||||||
|
|
||||||
|
@ -89,8 +91,8 @@ impl GeneticNode for TestState {
|
||||||
id: id.clone(),
|
id: id.clone(),
|
||||||
generation: 0,
|
generation: 0,
|
||||||
max_generations: 0,
|
max_generations: 0,
|
||||||
semaphore: None,
|
gemla_context: gemla_context
|
||||||
})?;
|
}).await?;
|
||||||
|
|
||||||
Ok(Box::new(result))
|
Ok(Box::new(result))
|
||||||
}
|
}
|
||||||
|
@ -101,16 +103,16 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use gemla::core::genetic_node::GeneticNode;
|
use gemla::core::genetic_node::GeneticNode;
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_initialize() {
|
async fn test_initialize() {
|
||||||
let state = TestState::initialize(
|
let state = TestState::initialize(
|
||||||
GeneticNodeContext {
|
GeneticNodeContext {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
generation: 0,
|
generation: 0,
|
||||||
max_generations: 0,
|
max_generations: 0,
|
||||||
semaphore: None,
|
gemla_context: (),
|
||||||
}
|
}
|
||||||
).unwrap();
|
).await.unwrap();
|
||||||
|
|
||||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||||
}
|
}
|
||||||
|
@ -128,7 +130,7 @@ mod tests {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
generation: 0,
|
generation: 0,
|
||||||
max_generations: 0,
|
max_generations: 0,
|
||||||
semaphore: None,
|
gemla_context: (),
|
||||||
}
|
}
|
||||||
).await.unwrap();
|
).await.unwrap();
|
||||||
assert!(original_population
|
assert!(original_population
|
||||||
|
@ -141,7 +143,7 @@ mod tests {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
generation: 0,
|
generation: 0,
|
||||||
max_generations: 0,
|
max_generations: 0,
|
||||||
semaphore: None,
|
gemla_context: (),
|
||||||
}
|
}
|
||||||
).await.unwrap();
|
).await.unwrap();
|
||||||
state.simulate(
|
state.simulate(
|
||||||
|
@ -149,7 +151,7 @@ mod tests {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
generation: 0,
|
generation: 0,
|
||||||
max_generations: 0,
|
max_generations: 0,
|
||||||
semaphore: None,
|
gemla_context: (),
|
||||||
}
|
}
|
||||||
).await.unwrap();
|
).await.unwrap();
|
||||||
assert!(original_population
|
assert!(original_population
|
||||||
|
@ -158,8 +160,8 @@ mod tests {
|
||||||
.all(|(&a, &b)| b >= a - 3 && b <= a + 6))
|
.all(|(&a, &b)| b >= a - 3 && b <= a + 6))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_mutate() {
|
async fn test_mutate() {
|
||||||
let mut state = TestState {
|
let mut state = TestState {
|
||||||
population: vec![4, 3, 3],
|
population: vec![4, 3, 3],
|
||||||
};
|
};
|
||||||
|
@ -169,15 +171,15 @@ mod tests {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
generation: 0,
|
generation: 0,
|
||||||
max_generations: 0,
|
max_generations: 0,
|
||||||
semaphore: None,
|
gemla_context: (),
|
||||||
}
|
}
|
||||||
).unwrap();
|
).await.unwrap();
|
||||||
|
|
||||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_merge() {
|
async fn test_merge() {
|
||||||
let state1 = TestState {
|
let state1 = TestState {
|
||||||
population: vec![1, 2, 4, 5],
|
population: vec![1, 2, 4, 5],
|
||||||
};
|
};
|
||||||
|
@ -186,7 +188,7 @@ mod tests {
|
||||||
population: vec![0, 1, 3, 7],
|
population: vec![0, 1, 3, 7],
|
||||||
};
|
};
|
||||||
|
|
||||||
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4()).unwrap();
|
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()).await.unwrap();
|
||||||
|
|
||||||
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
|
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
|
||||||
assert!(merged_state.population.iter().any(|&x| x == 7));
|
assert!(merged_state.population.iter().any(|&x| x == 7));
|
||||||
|
|
|
@ -5,9 +5,8 @@
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use tokio::sync::Semaphore;
|
use std::fmt::Debug;
|
||||||
use std::{fmt::Debug, sync::Arc};
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
@ -27,11 +26,11 @@ pub enum GeneticState {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct GeneticNodeContext {
|
pub struct GeneticNodeContext<S> {
|
||||||
pub generation: u64,
|
pub generation: u64,
|
||||||
pub max_generations: u64,
|
pub max_generations: u64,
|
||||||
pub id: Uuid,
|
pub id: Uuid,
|
||||||
pub semaphore: Option<Arc<Semaphore>>,
|
pub gemla_context: S
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||||
|
@ -39,21 +38,23 @@ pub struct GeneticNodeContext {
|
||||||
/// [`Bracket`]: crate::bracket::Bracket
|
/// [`Bracket`]: crate::bracket::Bracket
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait GeneticNode : Send {
|
pub trait GeneticNode : Send {
|
||||||
|
type Context;
|
||||||
|
|
||||||
/// Initializes a new instance of a [`GeneticState`].
|
/// Initializes a new instance of a [`GeneticState`].
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
/// TODO
|
/// TODO
|
||||||
fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error>;
|
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
|
||||||
|
|
||||||
async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
|
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
|
||||||
|
|
||||||
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
|
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
/// TODO
|
/// TODO
|
||||||
fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;
|
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
|
||||||
|
|
||||||
fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>;
|
async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
||||||
|
@ -82,6 +83,7 @@ impl<T> Default for GeneticNodeWrapper<T> {
|
||||||
impl<T> GeneticNodeWrapper<T>
|
impl<T> GeneticNodeWrapper<T>
|
||||||
where
|
where
|
||||||
T: GeneticNode + Debug + Send,
|
T: GeneticNode + Debug + Send,
|
||||||
|
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||||
{
|
{
|
||||||
pub fn new(max_generations: u64) -> Self {
|
pub fn new(max_generations: u64) -> Self {
|
||||||
GeneticNodeWrapper::<T> {
|
GeneticNodeWrapper::<T> {
|
||||||
|
@ -120,17 +122,17 @@ where
|
||||||
self.state
|
self.state
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn process_node(&mut self, semaphore: Arc<Semaphore>) -> Result<GeneticState, Error> {
|
pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> {
|
||||||
let context = GeneticNodeContext {
|
let context = GeneticNodeContext {
|
||||||
generation: self.generation,
|
generation: self.generation,
|
||||||
max_generations: self.max_generations,
|
max_generations: self.max_generations,
|
||||||
id: self.id,
|
id: self.id,
|
||||||
semaphore: Some(semaphore),
|
gemla_context,
|
||||||
};
|
};
|
||||||
|
|
||||||
match (self.state, &mut self.node) {
|
match (self.state, &mut self.node) {
|
||||||
(GeneticState::Initialize, _) => {
|
(GeneticState::Initialize, _) => {
|
||||||
self.node = Some(*T::initialize(context.clone())?);
|
self.node = Some(*T::initialize(context.clone()).await?);
|
||||||
self.state = GeneticState::Simulate;
|
self.state = GeneticState::Simulate;
|
||||||
}
|
}
|
||||||
(GeneticState::Simulate, Some(n)) => {
|
(GeneticState::Simulate, Some(n)) => {
|
||||||
|
@ -144,7 +146,7 @@ where
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
(GeneticState::Mutate, Some(n)) => {
|
(GeneticState::Mutate, Some(n)) => {
|
||||||
n.mutate(context.clone())
|
n.mutate(context.clone()).await
|
||||||
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
||||||
|
|
||||||
self.generation += 1;
|
self.generation += 1;
|
||||||
|
@ -172,20 +174,22 @@ mod tests {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl GeneticNode for TestState {
|
impl GeneticNode for TestState {
|
||||||
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
type Context = ();
|
||||||
|
|
||||||
|
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
self.score += 1.0;
|
self.score += 1.0;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
|
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
|
||||||
Ok(Box::new(TestState { score: 0.0 }))
|
Ok(Box::new(TestState { score: 0.0 }))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge(_l: &TestState, _r: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
|
async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
|
||||||
Err(Error::Other(anyhow!("Unable to merge")))
|
Err(Error::Other(anyhow!("Unable to merge")))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -281,14 +285,13 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_process_node() -> Result<(), Error> {
|
async fn test_process_node() -> Result<(), Error> {
|
||||||
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
|
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
|
||||||
let semaphore = Arc::new(Semaphore::new(1));
|
|
||||||
|
|
||||||
assert_eq!(genetic_node.state(), GeneticState::Initialize);
|
assert_eq!(genetic_node.state(), GeneticState::Initialize);
|
||||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
|
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
|
||||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Mutate);
|
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
|
||||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Simulate);
|
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
|
||||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
|
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
|
||||||
assert_eq!(genetic_node.process_node(semaphore.clone()).await?, GeneticState::Finish);
|
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,15 +4,15 @@
|
||||||
pub mod genetic_node;
|
pub mod genetic_node;
|
||||||
|
|
||||||
use crate::{error::Error, tree::Tree};
|
use crate::{error::Error, tree::Tree};
|
||||||
|
use async_recursion::async_recursion;
|
||||||
use file_linked::{constants::data_format::DataFormat, FileLinked};
|
use file_linked::{constants::data_format::DataFormat, FileLinked};
|
||||||
use futures::future;
|
use futures::{executor::block_on, future};
|
||||||
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
|
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
|
||||||
use log::{info, trace, warn};
|
use log::{info, trace, warn};
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
use tokio::sync::Semaphore;
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, sync::Arc, time::Instant
|
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, time::Instant
|
||||||
};
|
};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
@ -58,7 +58,6 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
|
||||||
pub struct GemlaConfig {
|
pub struct GemlaConfig {
|
||||||
pub generations_per_height: u64,
|
pub generations_per_height: u64,
|
||||||
pub overwrite: bool,
|
pub overwrite: bool,
|
||||||
pub shared_semaphore_concurrency_limit: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
||||||
|
@ -69,16 +68,17 @@ pub struct GemlaConfig {
|
||||||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||||
pub struct Gemla<T>
|
pub struct Gemla<T>
|
||||||
where
|
where
|
||||||
T: Serialize + Clone,
|
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
|
||||||
|
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||||
{
|
{
|
||||||
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
|
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
|
||||||
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
|
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
|
||||||
semaphore: Arc<Semaphore>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: 'static> Gemla<T>
|
impl<T: 'static> Gemla<T>
|
||||||
where
|
where
|
||||||
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
|
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
|
||||||
|
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||||
{
|
{
|
||||||
pub fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
|
pub fn new(path: &Path, config: GemlaConfig, data_format: DataFormat) -> Result<Self, Error> {
|
||||||
match File::open(path) {
|
match File::open(path) {
|
||||||
|
@ -86,18 +86,16 @@ where
|
||||||
// based on the configuration provided
|
// based on the configuration provided
|
||||||
Ok(_) => Ok(Gemla {
|
Ok(_) => Ok(Gemla {
|
||||||
data: if config.overwrite {
|
data: if config.overwrite {
|
||||||
FileLinked::new((None, config), path, data_format)?
|
FileLinked::new((None, config, T::Context::default()), path, data_format)?
|
||||||
} else {
|
} else {
|
||||||
FileLinked::from_file(path, data_format)?
|
FileLinked::from_file(path, data_format)?
|
||||||
},
|
},
|
||||||
threads: HashMap::new(),
|
threads: HashMap::new(),
|
||||||
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
|
|
||||||
}),
|
}),
|
||||||
// If the file doesn't exist we must create it
|
// If the file doesn't exist we must create it
|
||||||
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
|
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
|
||||||
data: FileLinked::new((None, config), path, data_format)?,
|
data: FileLinked::new((None, config, T::Context::default()), path, data_format)?,
|
||||||
threads: HashMap::new(),
|
threads: HashMap::new(),
|
||||||
semaphore: Arc::new(Semaphore::new(config.shared_semaphore_concurrency_limit)),
|
|
||||||
}),
|
}),
|
||||||
Err(error) => Err(Error::IO(error)),
|
Err(error) => Err(Error::IO(error)),
|
||||||
}
|
}
|
||||||
|
@ -117,7 +115,7 @@ where
|
||||||
{
|
{
|
||||||
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
||||||
// in the tree and which nodes have not.
|
// in the tree and which nodes have not.
|
||||||
self.data.mutate(|(d, c)| {
|
self.data.mutate(|(d, c, _)| {
|
||||||
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
|
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
|
||||||
mem::swap(d, &mut tree);
|
mem::swap(d, &mut tree);
|
||||||
})?;
|
})?;
|
||||||
|
@ -151,11 +149,11 @@ where
|
||||||
{
|
{
|
||||||
trace!("Adding node to process list {}", node.id());
|
trace!("Adding node to process list {}", node.id());
|
||||||
|
|
||||||
let semaphore = self.semaphore.clone();
|
let gemla_context = self.data.readonly().2.clone();
|
||||||
|
|
||||||
self.threads
|
self.threads
|
||||||
.insert(node.id(), tokio::spawn(async move {
|
.insert(node.id(), tokio::spawn(async move {
|
||||||
Gemla::process_node(node, semaphore).await
|
Gemla::process_node(node, gemla_context).await
|
||||||
}));
|
}));
|
||||||
} else {
|
} else {
|
||||||
trace!("No node found to process, joining threads");
|
trace!("No node found to process, joining threads");
|
||||||
|
@ -180,7 +178,7 @@ where
|
||||||
|
|
||||||
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
|
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
|
||||||
reduced_results.and_then(|r| {
|
reduced_results.and_then(|r| {
|
||||||
self.data.mutate(|(d, _)| {
|
self.data.mutate(|(d, _, context)| {
|
||||||
if let Some(t) = d {
|
if let Some(t) = d {
|
||||||
let failed_nodes = Gemla::replace_nodes(t, r);
|
let failed_nodes = Gemla::replace_nodes(t, r);
|
||||||
// We receive a list of nodes that were unable to be found in the original tree
|
// We receive a list of nodes that were unable to be found in the original tree
|
||||||
|
@ -192,7 +190,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
|
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
|
||||||
Gemla::merge_completed_nodes(t)
|
block_on(Gemla::merge_completed_nodes(t, context.clone()))
|
||||||
} else {
|
} else {
|
||||||
warn!("Unable to replce nodes {:?} in empty tree", r);
|
warn!("Unable to replce nodes {:?} in empty tree", r);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -204,7 +202,8 @@ where
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge_completed_nodes(tree: &mut SimulationTree<T>) -> Result<(), Error> {
|
#[async_recursion]
|
||||||
|
async fn merge_completed_nodes(tree: &mut SimulationTree<T>, gemla_context: T::Context) -> Result<(), Error> {
|
||||||
if tree.val.state() == GeneticState::Initialize {
|
if tree.val.state() == GeneticState::Initialize {
|
||||||
match (&mut tree.left, &mut tree.right) {
|
match (&mut tree.left, &mut tree.right) {
|
||||||
// If the current node has been initialized, and has children nodes that are completed, then we need
|
// If the current node has been initialized, and has children nodes that are completed, then we need
|
||||||
|
@ -215,7 +214,7 @@ where
|
||||||
{
|
{
|
||||||
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
|
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
|
||||||
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
|
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
|
||||||
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id())?;
|
let merged_node = GeneticNode::merge(left_node, right_node, &tree.val.id(), gemla_context.clone()).await?;
|
||||||
tree.val = GeneticNodeWrapper::from(
|
tree.val = GeneticNodeWrapper::from(
|
||||||
*merged_node,
|
*merged_node,
|
||||||
tree.val.max_generations(),
|
tree.val.max_generations(),
|
||||||
|
@ -224,8 +223,8 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(Some(l), Some(r)) => {
|
(Some(l), Some(r)) => {
|
||||||
Gemla::merge_completed_nodes(l)?;
|
Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
|
||||||
Gemla::merge_completed_nodes(r)?;
|
Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
|
||||||
}
|
}
|
||||||
// If there is only one child node that's completed then we want to copy it to the parent node
|
// If there is only one child node that's completed then we want to copy it to the parent node
|
||||||
(Some(l), None) if l.val.state() == GeneticState::Finish => {
|
(Some(l), None) if l.val.state() == GeneticState::Finish => {
|
||||||
|
@ -239,7 +238,7 @@ where
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(Some(l), None) => Gemla::merge_completed_nodes(l)?,
|
(Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
|
||||||
(None, Some(r)) if r.val.state() == GeneticState::Finish => {
|
(None, Some(r)) if r.val.state() == GeneticState::Finish => {
|
||||||
trace!("Copying node {}", r.val.id());
|
trace!("Copying node {}", r.val.id());
|
||||||
|
|
||||||
|
@ -251,7 +250,7 @@ where
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(None, Some(r)) => Gemla::merge_completed_nodes(r)?,
|
(None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
|
||||||
(_, _) => (),
|
(_, _) => (),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -329,15 +328,15 @@ where
|
||||||
tree.val.state() == GeneticState::Finish
|
tree.val.state() == GeneticState::Finish
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn process_node(mut node: GeneticNodeWrapper<T>, semaphore: Arc<Semaphore>) -> Result<GeneticNodeWrapper<T>, Error> {
|
async fn process_node(mut node: GeneticNodeWrapper<T>, gemla_context: T::Context) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||||
let node_state_time = Instant::now();
|
let node_state_time = Instant::now();
|
||||||
let node_state = node.state();
|
let node_state = node.state();
|
||||||
|
|
||||||
node.process_node(semaphore.clone()).await?;
|
node.process_node(gemla_context.clone()).await?;
|
||||||
|
|
||||||
if node.state() == GeneticState::Simulate
|
if node.state() == GeneticState::Simulate
|
||||||
{
|
{
|
||||||
node.process_node(semaphore.clone()).await?;
|
node.process_node(gemla_context.clone()).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
|
@ -397,20 +396,22 @@ mod tests {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl genetic_node::GeneticNode for TestState {
|
impl genetic_node::GeneticNode for TestState {
|
||||||
async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
type Context = ();
|
||||||
|
|
||||||
|
async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
self.score += 1.0;
|
self.score += 1.0;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
|
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
|
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
|
||||||
Ok(Box::new(TestState { score: 0.0 }))
|
Ok(Box::new(TestState { score: 0.0 }))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge(left: &TestState, right: &TestState, _id: &Uuid) -> Result<Box<TestState>, Error> {
|
async fn merge(left: &TestState, right: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
|
||||||
Ok(Box::new(if left.score > right.score {
|
Ok(Box::new(if left.score > right.score {
|
||||||
left.clone()
|
left.clone()
|
||||||
} else {
|
} else {
|
||||||
|
@ -433,7 +434,6 @@ mod tests {
|
||||||
let mut config = GemlaConfig {
|
let mut config = GemlaConfig {
|
||||||
generations_per_height: 1,
|
generations_per_height: 1,
|
||||||
overwrite: true,
|
overwrite: true,
|
||||||
shared_semaphore_concurrency_limit: 1,
|
|
||||||
};
|
};
|
||||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||||
|
|
||||||
|
@ -483,7 +483,6 @@ mod tests {
|
||||||
let config = GemlaConfig {
|
let config = GemlaConfig {
|
||||||
generations_per_height: 10,
|
generations_per_height: 10,
|
||||||
overwrite: true,
|
overwrite: true,
|
||||||
shared_semaphore_concurrency_limit: 1,
|
|
||||||
};
|
};
|
||||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue