diff --git a/analyze_data.py b/analyze_data.py new file mode 100644 index 0000000..0ac2c16 --- /dev/null +++ b/analyze_data.py @@ -0,0 +1,163 @@ +# Re-importing necessary libraries +import json +import matplotlib.pyplot as plt +from collections import defaultdict +import numpy as np + +# Simplified JSON data for demonstration +with open('gemla/round2.json', 'r') as file: + simplified_json_data = json.load(file) + +target_node_id = '0c1e64dc-6ddf-4dbb-bf6e-e8218b925194' + +# Function to traverse the tree to find a node id +def traverse_left_nodes(node): + if node is None: + return [] + + left_node = node.get("left") + if left_node is None: + return [node] + + return [node] + traverse_left_nodes(left_node) + +# Function to traverse the tree to find a node id +def traverse_right_nodes(node): + if node is None: + return [] + + right_node = node.get("right") + left_node = node.get("left") + + if right_node is None and left_node is None: + return [] + elif right_node and left_node: + return [right_node] + traverse_right_nodes(left_node) + + return [] + + +# Getting the left graph +left_nodes = traverse_left_nodes(simplified_json_data[0]) +left_nodes.reverse() +# print(node) +# Print properties available on the first node +node = left_nodes[0] +# print(node["val"].keys()) + +scores = [] +for node in left_nodes: + # print(node) + # print(f'Node ID: {node["val"]["id"]}') + # print(f'Node scores length: {len(node["val"]["node"]["scores"])}') + if node["val"]["node"]: + node_scores = node["val"]["node"]["scores"] + if node_scores: + for score in node_scores: + scores.append(score) + +# print(scores) + +scores_values = [list(score_set.values()) for score_set in scores] + +# Set up the figure for plotting on the same graph +fig, ax = plt.subplots(figsize=(10, 6)) + +# Generate a boxplot for each set of scores on the same graph +boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))]) + +# Set figure name to 
node id +# fig.canvas.set_window_title('Main node line') + +# Labeling +ax.set_xlabel(f'Scores - Main Line') +ax.set_ylabel('Score Sets') +ax.yaxis.grid(True) # Add horizontal grid lines for clarity + +# Set y-axis labels to be visible +ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))]) + +# Getting most recent right graph +right_nodes = traverse_right_nodes(simplified_json_data[0]) +target_node_id = None # NOTE(review): overrides the id set at the top of the file, so the lookup branch below is dead code — confirm which is intended +target_node = None +if target_node_id: + for node in right_nodes: + if node["val"]["id"] == target_node_id: + target_node = node + break +else: + target_node = right_nodes[1] +scores = target_node["val"]["node"]["scores"] + +scores_values = [list(score_set.values()) for score_set in scores] + +# Set up the figure for plotting on the same graph +fig, ax = plt.subplots(figsize=(10, 6)) + +# Generate a boxplot for each set of scores on the same graph +boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))]) + + +# Labeling +ax.set_xlabel(f'Scores: {target_node["val"]["id"]}') +ax.set_ylabel('Score Sets') +ax.yaxis.grid(True) # Add horizontal grid lines for clarity + +# Set y-axis labels to be visible +ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))]) + +# Find the highest scoring sets combining all scores and generations +scores = [] +for node in left_nodes: + if node["val"]["node"]: + node_scores = node["val"]["node"]["scores"] + translated_node_scores = [] + if node_scores: + for i in range(len(node_scores)): + for (individual, score) in node_scores[i].items(): + translated_node_scores.append((node["val"]["id"], i, score)) + + scores.append(translated_node_scores) + +# Add scores from the right nodes +for node in right_nodes: + if node["val"]["node"]: + node_scores = node["val"]["node"]["scores"] + translated_node_scores = [] + if node_scores: + for i in range(len(node_scores)): + for (individual, score) in node_scores[i].items(): + 
translated_node_scores.append((node["val"]["id"], i, score)) # NOTE(review): translated_node_scores is never appended to `scores` in this loop — right-node scores are discarded; a `scores.append(translated_node_scores)` after the loop appears to be missing + +# Organize scores by individual and then by generation +individual_generation_scores = defaultdict(lambda: defaultdict(list)) +for sublist in scores: + for id, generation, score in sublist: + individual_generation_scores[id][generation].append(score) + +# Calculate Q3 for each individual's generation +individual_generation_q3 = {} +for id, generations in individual_generation_scores.items(): + for gen, scores in generations.items(): + individual_generation_q3[(id, gen)] = np.percentile(scores, 75) + +# Sort by Q3 value, highest first, and select the top 40 (NOTE(review): slice takes 40, but variable names and the plot title say 20 — confirm the intended count) +top_20_individual_generations = sorted(individual_generation_q3, key=individual_generation_q3.get, reverse=True)[:40] + +# Prepare scores for the top 20 for plotting +top_20_scores = [individual_generation_scores[id][gen] for id, gen in top_20_individual_generations] + +# Adjust labels for clarity, indicating both the individual ID and generation +labels = [f'{id[:8]}... 
Gen {gen}' for id, gen in top_20_individual_generations] + +# Generate box and whisker plots for the top 20 individual generations +fig, ax = plt.subplots(figsize=(12, 10)) +ax.boxplot(top_20_scores, vert=False, patch_artist=True, labels=labels) +ax.set_xlabel('Scores') +ax.set_ylabel('Individual Generation') +ax.set_title('Top 20 Individual Generations by Q3 Value') + +# Display the plot +plt.show() + diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs index e41e417..96248cb 100644 --- a/gemla/src/bin/bin.rs +++ b/gemla/src/bin/bin.rs @@ -47,10 +47,7 @@ fn main() -> Result<()> { let mut gemla = log_error( Gemla::::new( &PathBuf::from(args.file), - GemlaConfig { - generations_per_height: 5, - overwrite: false, - }, + GemlaConfig { overwrite: false }, DataFormat::Json, ) .await, diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs index b52deec..820d90b 100644 --- a/gemla/src/bin/fighter_nn/mod.rs +++ b/gemla/src/bin/fighter_nn/mod.rs @@ -3,7 +3,7 @@ extern crate fann; pub mod fighter_context; pub mod neural_network_utility; -use anyhow::Context; +use anyhow::{anyhow, Context}; use async_trait::async_trait; use fann::{ActivationFunc, Fann}; use futures::future::join_all; @@ -22,7 +22,7 @@ use std::{ ops::Range, path::{Path, PathBuf}, }; -use tokio::process::Command; +use tokio::{process::Command, sync::mpsc::channel}; use uuid::Uuid; use self::neural_network_utility::{crossbreed, major_mutation}; @@ -65,12 +65,15 @@ pub struct FighterNN { // A map of each nn identifier in a generation and their physics score pub scores: Vec>, // A map of the id of the nn in the current generation and their neural network shape - pub nn_shapes: HashMap>, + pub nn_shapes: Vec>>, pub crossbreed_segments: usize, pub weight_initialization_range: Range, pub minor_mutation_rate: f32, pub major_mutation_rate: f32, pub mutation_weight_range: Range, + // Shows how individuals are mapped from one generation to the next + pub id_mapping: Vec>, + pub lerp_amount: 
f32, } #[async_trait] @@ -146,115 +149,158 @@ impl GeneticNode for FighterNN { population_size: POPULATION, generation: 0, scores: vec![HashMap::new()], - nn_shapes, + nn_shapes: vec![nn_shapes], // we need crossbreed segments to be even crossbreed_segments, weight_initialization_range, minor_mutation_rate: thread_rng().gen_range(0.0..1.0), major_mutation_rate: thread_rng().gen_range(0.0..1.0), mutation_weight_range: -mutation_weight_amplitude..mutation_weight_amplitude, + id_mapping: vec![HashMap::new()], + lerp_amount: 0.0, })) } - async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error> { + async fn simulate( + &mut self, + context: GeneticNodeContext, + ) -> Result { debug!("Context: {:?}", context); + let mut matches = Vec::new(); + let mut allotted_simulations = Vec::new(); + for i in 0..self.population_size { + allotted_simulations.push((i, SIMULATION_ROUNDS)); + } + + while !allotted_simulations.is_empty() { + let primary_id = { + let id = thread_rng().gen_range(0..allotted_simulations.len()); + let (i, _) = allotted_simulations[id]; + // Decrement the number of simulations left for this nn + allotted_simulations[id].1 -= 1; + // Remove the nn from the list if it has no more simulations left + if allotted_simulations[id].1 == 0 { + allotted_simulations.remove(id); + } + i + }; + + let secondary_id = loop { + let id = thread_rng().gen_range(0..allotted_simulations.len()); + let (i, _) = allotted_simulations[id]; + + if i != primary_id { + // Decrement the number of simulations left for this nn + allotted_simulations[id].1 -= 1; + // Remove the nn from the list if it has no more simulations left + if allotted_simulations[id].1 == 0 { + allotted_simulations.remove(id); + } + break i; + } + }; + + matches.push((primary_id, secondary_id)); + } + + trace!("Matches: {:?}", matches); + + // Create a channel to send the scores back to the main thread + let (tx, mut rx) = channel::<(usize, f32)>(self.population_size * SIMULATION_ROUNDS * 20); 
let mut tasks = Vec::new(); - // For each nn in the current generation: - for i in 0..self.population_size { + for (primary_id, secondary_id) in matches.iter() { let self_clone = self.clone(); let semaphore_clone = context.gemla_context.shared_semaphore.clone(); let display_simulation_semaphore = context.gemla_context.visible_simulations.clone(); + let tx = tx.clone(); let task = async move { - let nn = self_clone + let folder = self_clone.folder.clone(); + let generation = self_clone.generation; + + let primary_nn = self_clone .folder .join(format!("{}", self_clone.generation)) - .join(self_clone.get_individual_id(i as u64)) + .join(self_clone.get_individual_id(*primary_id as u64)) + .with_extension("net"); + let secondary_nn = folder + .join(format!("{}", generation)) + .join(self_clone.get_individual_id(*secondary_id as u64)) .with_extension("net"); - let mut simulations = Vec::new(); - // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently - for _ in 0..SIMULATION_ROUNDS { - let random_nn_index = thread_rng().gen_range(0..self_clone.population_size); - let folder = self_clone.folder.clone(); - let generation = self_clone.generation; - let semaphore_clone = semaphore_clone.clone(); - let display_simulation_semaphore = display_simulation_semaphore.clone(); + let permit = semaphore_clone + .acquire_owned() + .await + .with_context(|| "Failed to acquire semaphore permit")?; - let random_nn = folder - .join(format!("{}", generation)) - .join(self_clone.get_individual_id(random_nn_index as u64)) - .with_extension("net"); - let nn_clone = nn.clone(); // Clone the path to use in the async block + let display_simulation = match display_simulation_semaphore.try_acquire_owned() { + Ok(s) => Some(s), + Err(_) => None, + }; - let future = async move { - let permit = semaphore_clone - .acquire_owned() - .await - .with_context(|| "Failed to acquire semaphore permit")?; - - let display_simulation = - match 
display_simulation_semaphore.try_acquire_owned() { - Ok(s) => Some(s), - Err(_) => None, - }; - - let (score, _) = if let Some(display_simulation) = display_simulation { - let result = run_1v1_simulation(&nn_clone, &random_nn, true).await?; - drop(display_simulation); - result - } else { - run_1v1_simulation(&nn_clone, &random_nn, false).await? - }; - - drop(permit); - - Ok(score) + let (primary_score, secondary_score) = + if let Some(display_simulation) = display_simulation { + let result = run_1v1_simulation(&primary_nn, &secondary_nn, true).await?; + drop(display_simulation); + result + } else { + run_1v1_simulation(&primary_nn, &secondary_nn, false).await? }; - simulations.push(future); - } + drop(permit); - // Wait for all simulation rounds to complete - let results: Result, Error> = - join_all(simulations).await.into_iter().collect(); + debug!( + "{} vs {} -> {} vs {}", + primary_id, secondary_id, primary_score, secondary_score + ); - let score = match results { - Ok(scores) => scores.into_iter().sum::() / SIMULATION_ROUNDS as f32, - Err(e) => return Err(e), // Return the error if results collection failed - }; - debug!("NN {:06}_fighter_nn_{} scored {}", self_clone.id, i, score); - Ok((i, score)) + // Send score using a channel + tx.send((*primary_id, primary_score)) + .await + .with_context(|| "Failed to send score")?; + tx.send((*secondary_id, secondary_score)) + .await + .with_context(|| "Failed to send score")?; + + Ok(()) }; tasks.push(task); } - let results = join_all(tasks).await; + let results: Vec> = join_all(tasks).await; - for result in results { - match result { - Ok((index, score)) => { - // Update the original `self` object with the score. 
- self.scores[self.generation as usize].insert(index as u64, score); - } - Err(e) => { - // Handle task panic or execution error - return Err(Error::Other(anyhow::anyhow!(format!( - "Task failed: {:?}", - e - )))); - } + // resolve results for any errors + for result in results.into_iter() { + result.with_context(|| "Failed to run simulation")?; + } + + // Receive the scores from the channel + let mut scores = HashMap::new(); + while let Some((id, score)) = rx.recv().await { + // If score exists, add the new score to the existing score + if let Some(existing_score) = scores.get_mut(&(id as u64)) { + *existing_score += score; + } else { + scores.insert(id as u64, score); } } - Ok(()) + // Average scores for each individual + for (_, score) in scores.iter_mut() { + *score /= SIMULATION_ROUNDS as f32; + } + + self.scores.push(scores); + + Ok(should_continue(&self.scores)?) } async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize; let mut nn_sizes = Vec::new(); + let mut id_mapping = HashMap::new(); // Create the new generation folder let new_gen_folder = self.folder.join(format!("{}", self.generation + 1)); @@ -280,8 +326,14 @@ impl GeneticNode for FighterNN { .join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i)); debug!("Copying nn from {:?} to {:?}", nn_id, i); + id_mapping.insert(**nn_id, i as u64); fs::copy(&nn, &new_nn)?; - nn_sizes.push(self.nn_shapes.get(nn_id).unwrap().clone()); + nn_sizes.push( + self.nn_shapes[self.generation as usize] + .get(nn_id) + .unwrap() + .clone(), + ); } let weights: HashMap = scores_to_keep.iter().map(|(k, v)| (**k, **v)).collect(); @@ -352,9 +404,7 @@ impl GeneticNode for FighterNN { let new_nn = new_gen_folder .join(self_clone.get_individual_id((i + survivor_count) as u64)) .with_extension("net"); - new_fann - .save(&new_nn) - .with_context(|| 
"Failed to save nn")?; + new_fann.save(new_nn).with_context(|| "Failed to save nn")?; Ok::, Error>(new_fann.get_layer_sizes()) }); @@ -369,8 +419,17 @@ impl GeneticNode for FighterNN { nn_sizes.push(new_size); } + // Use the index of nn_sizes to generate the id for the nn_sizes HashMap + let nn_sizes_map = nn_sizes + .into_iter() + .enumerate() + .map(|(i, v)| (i as u64, v)) + .collect::>(); + self.generation += 1; self.scores.push(HashMap::new()); + self.nn_shapes.push(nn_sizes_map); + self.id_mapping.push(id_mapping); Ok(()) } @@ -502,10 +561,11 @@ impl GeneticNode for FighterNN { format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path) })?; - nn_shapes.insert( - (start_idx + i) as u64, - source.nn_shapes.get(nn_id).unwrap().clone(), - ); + let nn_shape = source.nn_shapes[source.generation as usize] + .get(nn_id) + .unwrap(); + + nn_shapes.insert((start_idx + i) as u64, nn_shape.clone()); } Ok(()) @@ -577,11 +637,13 @@ impl GeneticNode for FighterNN { population_size: nn_shapes.len(), scores: vec![HashMap::new()], crossbreed_segments, - nn_shapes, + nn_shapes: vec![nn_shapes], weight_initialization_range, minor_mutation_rate, major_mutation_rate, mutation_weight_range, + id_mapping: vec![HashMap::new()], + lerp_amount, })) } } @@ -592,6 +654,54 @@ impl FighterNN { } } +fn should_continue(scores: &[HashMap]) -> Result { + if scores.len() < 5 { + return Ok(true); + } + + let mut highest_q3_value = f32::MIN; + let mut generation_with_highest_q3 = 0; + + let mut highest_median = f32::MIN; + let mut generation_with_highest_median = 0; + + for (generation_index, generation) in scores.iter().enumerate() { + let mut scores: Vec = generation.values().copied().collect(); + scores.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + let q3_index = (scores.len() as f32 * 0.75).ceil() as usize - 1; + let q3_value = scores + .get(q3_index) + .ok_or(anyhow!("Failed to get Q3 value"))?; + + if *q3_value > highest_q3_value { + highest_q3_value = *q3_value; + 
generation_with_highest_q3 = generation_index; + } + + let median_index = (scores.len() as f32 * 0.5).ceil() as usize - 1; + let median_value = scores + .get(median_index) + .ok_or(anyhow!("Failed to get median value"))?; + + if *median_value > highest_median { + highest_median = *median_value; + generation_with_highest_median = generation_index; + } + } + + let highest_generation_index = scores.len() - 1; + let result = highest_generation_index - generation_with_highest_q3 < 5 + && highest_generation_index - generation_with_highest_median < 5; + + debug!( + "Highest Q3 value: {} at generation {}, Highest Median value: {} at generation {}, Continuing? {}", + highest_q3_value, generation_with_highest_q3, highest_median, generation_with_highest_median, result + ); + + Ok(result) +} + fn weighted_random_selection(weights: &HashMap) -> T { let mut rng = thread_rng(); @@ -828,4 +938,655 @@ pub mod test { assert_eq!(ids.len(), 0); } + + #[test] + fn test_should_continue() { + let scores = vec![ + // Generation 0 + [ + (37, -7.1222725), + (12, -3.6037624), + (27, -5.202844), + (21, -6.3283415), + (4, -6.0053186), + (8, -4.040202), + (13, -4.0050435), + (17, -5.8206105), + (40, -7.5448103), + (42, -8.027704), + (15, -5.1600137), + (10, -7.9063845), + (1, -6.9830275), + (7, -3.3323112), + (16, -6.1065326), + (23, -6.417853), + (25, -6.410652), + (14, -6.5887403), + (3, -6.3966584), + (19, 0.1242948), + (28, -4.806827), + (18, -6.3310747), + (30, -5.8972425), + (31, -6.398958), + (22, -7.042196), + (29, -5.7098813), + (9, -8.931531), + (33, -5.9806275), + (6, -6.5489874), + (26, -5.892653), + (34, -6.4281516), + (35, -5.5369387), + (38, -5.495344), + (43, 0.9552175), + (44, -6.2549844), + (45, -8.42142), + (24, -7.121878), + (47, -5.373896), + (48, -6.445716), + (39, -6.053849), + (11, -5.8320975), + (49, -10.014197), + (46, -7.0919595), + (20, -6.033137), + (5, -6.3501267), + (32, -4.203919), + (2, -5.743471), + (36, -8.493466), + (41, -7.60419), + (0, -7.388545), + ], + // 
Generation 1 + [ + (18, -6.048934), + (39, -1.1448132), + (48, -7.921489), + (38, -6.0117235), + (27, -6.30289), + (9, -6.5567093), + (29, -5.905172), + (25, -4.2305975), + (40, -5.1198816), + (24, -7.232001), + (46, -6.5581756), + (20, -6.7987585), + (8, -9.346154), + (2, -7.6944494), + (3, -6.487195), + (16, -8.379641), + (32, -7.292016), + (33, -7.91467), + (41, -7.4449363), + (21, -6.0500197), + (19, -5.357873), + (10, -6.9984064), + (7, -5.6824636), + (13, -8.154273), + (45, -7.8713655), + (47, -5.279138), + (49, -1.915852), + (6, -2.682654), + (30, -5.566201), + (1, -1.829716), + (11, -7.7527223), + (12, -10.379072), + (15, -4.866212), + (35, -8.091223), + (36, -8.137203), + (42, -7.2846284), + (44, -4.7636213), + (28, -6.518874), + (34, 1.9858776), + (43, -10.140268), + (0, -3.5068736), + (17, -2.3913155), + (26, -6.1766686), + (22, -9.119884), + (14, -7.470778), + (5, -5.925585), + (23, -6.004782), + (31, -2.696432), + (4, -2.4887466), + (37, -5.5321026), + ], + // Generation 2 + [ + (25, -8.760574), + (0, -2.5970187), + (9, -4.270929), + (11, -0.27550858), + (20, -6.7012835), + (30, 2.3309054), + (4, -7.0107384), + (31, -7.5239167), + (41, -2.337672), + (6, -3.4384027), + (16, -7.9485044), + (37, -7.3155503), + (38, -7.4812994), + (3, -3.958924), + (42, -7.738173), + (43, -6.500585), + (22, -6.318394), + (17, -5.7882595), + (45, -8.782414), + (49, -8.84129), + (23, -10.222613), + (26, -6.06804), + (32, -6.4851217), + (33, -7.3542376), + (34, -2.8723297), + (27, -7.1350646), + (8, -2.7956052), + (18, -5.0000043), + (10, -1.5138103), + (2, 0.10560961), + (7, -1.4954948), + (35, -7.7015786), + (36, -8.602789), + (47, -8.117584), + (28, -9.151132), + (39, -8.035833), + (13, -6.2601876), + (15, -9.050044), + (19, -5.465233), + (44, -8.494604), + (5, -6.9012084), + (12, -9.458872), + (21, -5.980685), + (14, -7.7407913), + (46, -0.701484), + (24, -9.477325), + (29, -6.6444407), + (1, -3.4681067), + (40, -5.4685316), + (48, 0.22965483), + ], + // Generation 3 + [ 
+ (11, -5.7744265), + (12, 0.10171394), + (18, -8.503949), + (3, -1.9760166), + (17, -7.895561), + (20, -8.515409), + (45, -1.9184738), + (6, -5.6488137), + (46, -6.1171823), + (49, -7.006673), + (29, -3.6479561), + (37, -4.025724), + (42, -4.1281996), + (9, -2.7060657), + (33, 0.18799233), + (15, -7.8216696), + (23, -11.02603), + (22, -10.132984), + (7, -6.432255), + (38, -7.2159233), + (10, -2.195277), + (2, -6.7676725), + (27, -1.8040345), + (34, -11.214028), + (40, -6.1334066), + (35, -9.410227), + (44, -0.14929143), + (47, -7.3865366), + (41, -9.200221), + (26, -6.1885824), + (13, -5.5693216), + (31, -8.184256), + (39, -8.06583), + (24, -11.773471), + (25, -15.231514), + (14, -5.4468412), + (30, -5.494699), + (21, -10.619481), + (28, -7.322004), + (16, -7.4136076), + (8, -3.2260292), + (32, -8.187313), + (19, -5.9347467), + (43, -0.112977505), + (5, -1.9279568), + (48, -3.8396995), + (0, -9.317253), + (4, -1.8099403), + (1, -5.4981036), + (36, -3.5487309), + ], + // Generation 4 + [ + (28, -6.2057357), + (40, -6.9324327), + (46, -0.5130272), + (23, -7.9489794), + (47, -7.3411865), + (20, -8.930363), + (26, -3.238875), + (41, -7.376683), + (48, -0.83026105), + (27, -10.048681), + (36, -5.1788163), + (30, -8.002236), + (9, -7.4656434), + (4, -3.8850121), + (16, -3.1768656), + (11, 1.0195583), + (44, -8.7163315), + (45, -6.7038856), + (33, -6.974304), + (22, -10.026589), + (13, -4.342838), + (12, -6.69588), + (31, -2.2994905), + (14, -7.9772606), + (32, -10.55702), + (38, -5.668454), + (34, -10.026564), + (37, -8.128912), + (42, -10.7178335), + (17, -5.18195), + (49, -9.900299), + (21, -12.4000635), + (8, -1.8514707), + (29, -3.365313), + (39, -5.588918), + (43, -8.482417), + (1, -4.390686), + (35, -5.604909), + (24, -7.1810236), + (25, -5.9158974), + (19, -4.5733366), + (0, -5.68081), + (3, -2.8414884), + (6, -1.5809858), + (7, -9.295659), + (5, -3.7936096), + (10, -4.088697), + (2, -2.3494315), + (15, -7.3323736), + (18, -7.7137175), + ], + // Generation 5 + [ 
+ (1, -2.7719336), + (37, -6.097855), + (39, -4.1296787), + (2, -5.4538774), + (34, -11.808794), + (40, -9.822159), + (3, -7.884645), + (42, -14.777964), + (32, -2.6564443), + (16, -5.2442584), + (9, -6.2919874), + (48, -2.4359574), + (25, -11.707236), + (33, -5.5483084), + (35, -0.3632618), + (7, -4.3673687), + (27, -8.139543), + (12, -9.019396), + (17, -0.029791832), + (24, -8.63045), + (18, -11.925819), + (20, -9.040375), + (44, -10.296264), + (47, -15.95397), + (23, -12.38116), + (21, 0.18342426), + (38, -7.695002), + (6, -8.710346), + (28, -2.8542902), + (5, -2.077858), + (10, -3.638583), + (8, -7.360152), + (15, -7.1610765), + (29, -4.8372035), + (45, -11.499393), + (13, -3.8436065), + (22, -5.472387), + (11, -4.259357), + (26, -4.847328), + (4, -2.0376666), + (36, -7.5392637), + (41, -5.3857164), + (19, -8.576212), + (14, -8.267895), + (30, -4.0456495), + (31, -3.806975), + (43, -7.9901657), + (46, -7.181662), + (0, -7.502816), + (49, -7.3067017), + ], + // Generation 6 + [ + (17, -9.793276), + (27, -2.8843281), + (38, -8.737534), + (8, -1.5083166), + (16, -8.267393), + (42, -8.055011), + (47, -2.0843022), + (14, -3.9945045), + (30, -10.208374), + (26, -3.2439823), + (49, -2.5527742), + (25, -10.359426), + (9, -4.4744225), + (19, -7.2775927), + (3, -7.282045), + (36, -8.503307), + (40, -12.083569), + (22, -3.7249084), + (18, -7.5065627), + (41, -3.3326488), + (44, -2.76882), + (45, -12.154654), + (24, -2.8332536), + (5, -5.2674284), + (4, -4.105483), + (10, -6.930478), + (20, -3.7845988), + (2, -4.4593267), + (28, -0.3003047), + (29, -6.5971193), + (32, -5.0542274), + (33, -9.068264), + (43, -7.124672), + (46, -8.358111), + (23, -5.551978), + (11, -7.7810373), + (35, -7.4763336), + (34, -10.868844), + (39, -10.51066), + (7, -4.376377), + (48, -9.093265), + (6, -0.20033613), + (1, -6.125786), + (12, -8.243349), + (0, -7.1646323), + (13, -3.7055316), + (15, -6.295897), + (21, -5.929867), + (31, -7.2123885), + (37, -2.482071), + ], + // Generation 7 + [ + (30, 
-12.467585), + (14, -5.1706576), + (40, -9.03964), + (18, -5.7730474), + (41, -9.061858), + (20, -2.8577142), + (24, -3.3558655), + (42, -7.902747), + (43, -6.1566644), + (21, -5.4271364), + (23, -7.1462164), + (44, -7.9898252), + (11, -2.493559), + (31, -4.6718645), + (48, -12.774545), + (8, -7.252562), + (35, -1.6866531), + (49, -4.437603), + (45, -7.164916), + (7, -4.613396), + (32, -8.156101), + (39, -10.887325), + (0, -0.18116185), + (47, -4.998584), + (10, -8.914183), + (13, -0.8690014), + (27, -0.3714923), + (28, -12.002966), + (9, -6.2789965), + (26, -0.46416503), + (2, -9.865377), + (29, -8.443848), + (46, -6.3264246), + (3, -7.807205), + (4, -6.8240366), + (5, -6.843891), + (12, -5.6381693), + (15, -4.6679296), + (36, -6.8010025), + (16, -8.222928), + (25, -10.326822), + (34, -6.0182467), + (37, -8.713378), + (38, -7.549215), + (17, -7.247555), + (22, -13.296148), + (33, -8.542955), + (19, -7.254419), + (1, -2.8472056), + (6, -5.898753), + ], + // Generation 8 + [ + (7, -3.6624274), + (4, -2.9281456), + (39, -5.9176188), + (13, -8.0644045), + (16, -2.0319564), + (49, -10.309226), + (3, -0.21671781), + (37, -8.295551), + (44, -16.496105), + (46, -6.2466326), + (47, -3.5928986), + (19, -9.298591), + (1, -7.937351), + (15, -8.218504), + (6, -6.945601), + (25, -8.446054), + (12, -5.8477135), + (14, -3.9165816), + (17, -2.4864268), + (20, -7.97737), + (22, -5.347026), + (0, -6.0739775), + (32, -6.7568192), + (36, -4.730008), + (28, -9.923819), + (38, -8.677519), + (42, -4.668519), + (48, 0.14014988), + (5, -8.3167), + (8, -2.5030074), + (21, -1.8195568), + (27, -6.111103), + (45, -12.708131), + (35, -8.089076), + (11, -6.0151362), + (34, -13.688166), + (33, -11.375975), + (2, -4.1082373), + (24, -4.0867376), + (10, -4.2828474), + (41, -9.174506), + (43, -1.1505331), + (29, -3.7704785), + (18, -4.9493446), + (30, -3.727829), + (31, -6.490308), + (9, -6.0947385), + (40, -9.492185), + (26, -13.629112), + (23, -9.773454), + ], + // Generation 9 + [ + (12, 
-1.754871), + (41, 2.712658), + (24, -4.0929146), + (18, -4.9418926), + (44, -9.325021), + (8, -6.4423165), + (1, -0.0946085), + (5, -3.0156248), + (14, -5.29519), + (34, -10.763539), + (11, -7.304751), + (20, -6.8397574), + (22, -5.6720686), + (23, -7.829904), + (7, -3.8627372), + (6, -3.1108487), + (16, -8.803584), + (36, -13.916307), + (21, -10.142917), + (37, -12.171498), + (45, -13.004938), + (19, -3.7237267), + (47, -6.0189786), + (17, -4.612711), + (15, -5.3010545), + (30, -5.671092), + (46, -13.300519), + (25, -8.2948), + (3, -10.556543), + (42, -7.041272), + (48, -9.797744), + (9, -5.6163936), + (26, -6.665021), + (27, -7.074666), + (4, -1.5992731), + (2, -6.4931273), + (29, -3.9785416), + (31, -12.222026), + (10, -2.3970482), + (40, -6.204074), + (49, -7.025599), + (28, -8.562909), + (13, -6.2592154), + (32, -10.465271), + (33, -7.7043953), + (35, -6.4584246), + (38, -2.9016697), + (39, -1.5256255), + (43, -10.858711), + (0, -4.720929), + ], + //Generation 10 + [ + (2, -5.1676617), + (3, -4.521774), + (29, -7.3104324), + (23, -6.550776), + (26, -10.467587), + (18, 1.6576093), + (33, -2.564094), + (20, -3.2697926), + (35, -13.577334), + (37, -6.0147185), + (17, -4.07909), + (0, -9.630419), + (38, -7.011383), + (12, -10.686635), + (43, -8.94728), + (48, -9.350017), + (30, -7.3335466), + (13, -7.7690034), + (4, -2.3488472), + (14, -7.2594194), + (21, -9.08367), + (34, -7.7497597), + (8, -6.2317214), + (27, -8.440135), + (22, -4.4437346), + (32, -2.194015), + (28, -6.6919556), + (40, -8.840385), + (42, -9.781796), + (15, -7.3304253), + (49, -8.720987), + (19, -9.044103), + (6, -5.715863), + (41, -8.395639), + (36, -3.995482), + (25, -9.1373005), + (5, -7.5690002), + (1, -6.0397635), + (16, -8.231512), + (10, -6.5344634), + (44, -7.749376), + (7, -9.302668), + (31, -10.868391), + (39, -2.7578635), + (47, -6.964238), + (24, -4.033315), + (11, -8.211409), + (45, -10.472969), + (9, -7.1529093), + (46, -9.653514), + ], + ]; + + // Transform scores into a vector of 
hashmaps instead + let scores: Vec> = scores + .iter() + .map(|gen_scores| gen_scores.iter().cloned().collect()) + .collect(); + + assert!( + should_continue(scores[..0].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..1].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..2].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..3].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..4].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..5].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..6].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..7].as_ref()) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..8].as_ref()) + .expect("Failed to determine if the simulation should continue") + == false + ); + assert!( + should_continue(scores[..9].as_ref()) + .expect("Failed to determine if the simulation should continue") + == false + ); + assert!( + should_continue(scores[..10].as_ref()) + .expect("Failed to determine if the simulation should continue") + == false + ); + } } diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index af5b316..5bde119 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -13,6 +13,7 @@ const POPULATION_REDUCTION_SIZE: u64 = 3; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct TestState { pub population: Vec, + pub max_generations: u64, } #[async_trait] @@ -26,10 
+27,16 @@ impl GeneticNode for TestState { population.push(thread_rng().gen_range(0..100)) } - Ok(Box::new(TestState { population })) + Ok(Box::new(TestState { + population, + max_generations: 10, + })) } - async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + async fn simulate( + &mut self, + context: GeneticNodeContext, + ) -> Result { let mut rng = thread_rng(); self.population = self @@ -38,7 +45,11 @@ impl GeneticNode for TestState { .map(|p| p.saturating_add(rng.gen_range(-1..2))) .collect(); - Ok(()) + if context.generation >= self.max_generations { + Ok(false) + } else { + Ok(true) + } } async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { @@ -93,13 +104,15 @@ impl GeneticNode for TestState { v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); - let mut result = TestState { population: v }; + let mut result = TestState { + population: v, + max_generations: 10, + }; result .mutate(GeneticNodeContext { id: *id, generation: 0, - max_generations: 0, gemla_context, }) .await?; @@ -118,7 +131,6 @@ mod tests { let state = TestState::initialize(GeneticNodeContext { id: Uuid::new_v4(), generation: 0, - max_generations: 0, gemla_context: (), }) .await @@ -131,6 +143,7 @@ mod tests { async fn test_simulate() { let mut state = TestState { population: vec![1, 1, 2, 3], + max_generations: 1, }; let original_population = state.population.clone(); @@ -139,7 +152,6 @@ mod tests { .simulate(GeneticNodeContext { id: Uuid::new_v4(), generation: 0, - max_generations: 0, gemla_context: (), }) .await @@ -153,7 +165,6 @@ mod tests { .simulate(GeneticNodeContext { id: Uuid::new_v4(), generation: 0, - max_generations: 0, gemla_context: (), }) .await @@ -162,7 +173,6 @@ mod tests { .simulate(GeneticNodeContext { id: Uuid::new_v4(), generation: 0, - max_generations: 0, gemla_context: (), }) .await @@ -177,13 +187,13 @@ mod tests { async fn test_mutate() { let mut state = TestState { population: vec![4, 3, 3], + 
max_generations: 1, }; state .mutate(GeneticNodeContext { id: Uuid::new_v4(), generation: 0, - max_generations: 0, gemla_context: (), }) .await @@ -196,10 +206,12 @@ mod tests { async fn test_merge() { let state1 = TestState { population: vec![1, 2, 4, 5], + max_generations: 1, }; let state2 = TestState { population: vec![0, 1, 3, 7], + max_generations: 1, }; let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()) diff --git a/gemla/src/core/genetic_node.rs b/gemla/src/core/genetic_node.rs index 020d2c6..b85b775 100644 --- a/gemla/src/core/genetic_node.rs +++ b/gemla/src/core/genetic_node.rs @@ -28,7 +28,6 @@ pub enum GeneticState { #[derive(Clone, Debug)] pub struct GeneticNodeContext { pub generation: u64, - pub max_generations: u64, pub id: Uuid, pub gemla_context: S, } @@ -46,7 +45,8 @@ pub trait GeneticNode: Send { /// TODO async fn initialize(context: GeneticNodeContext) -> Result, Error>; - async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; + async fn simulate(&mut self, context: GeneticNodeContext) + -> Result; /// Mutates members in a population and/or crossbreeds them to produce new offspring. 
/// @@ -72,7 +72,6 @@ where node: Option, state: GeneticState, generation: u64, - max_generations: u64, id: Uuid, } @@ -85,7 +84,6 @@ where node: None, state: GeneticState::Initialize, generation: 1, - max_generations: 1, id: Uuid::new_v4(), } } @@ -96,19 +94,17 @@ where T: GeneticNode + Debug + Send + Clone, T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, { - pub fn new(max_generations: u64) -> Self { + pub fn new() -> Self { GeneticNodeWrapper:: { - max_generations, ..Default::default() } } - pub fn from(data: T, max_generations: u64, id: Uuid) -> Self { + pub fn from(data: T, id: Uuid) -> Self { GeneticNodeWrapper { node: Some(data), state: GeneticState::Simulate, generation: 1, - max_generations, id, } } @@ -125,10 +121,6 @@ where self.id } - pub fn max_generations(&self) -> u64 { - self.max_generations - } - pub fn generation(&self) -> u64 { self.generation } @@ -140,7 +132,6 @@ where pub async fn process_node(&mut self, gemla_context: T::Context) -> Result { let context = GeneticNodeContext { generation: self.generation, - max_generations: self.max_generations, id: self.id, gemla_context, }; @@ -151,14 +142,15 @@ where self.state = GeneticState::Simulate; } (GeneticState::Simulate, Some(n)) => { - n.simulate(context.clone()) + let next_generation = n + .simulate(context.clone()) .await .with_context(|| format!("Error simulating node: {:?}", self))?; - self.state = if self.generation >= self.max_generations { - GeneticState::Finish - } else { + self.state = if next_generation { GeneticState::Mutate + } else { + GeneticState::Finish }; } (GeneticState::Mutate, Some(n)) => { @@ -187,6 +179,7 @@ mod tests { #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] struct TestState { pub score: f64, + pub max_generations: u64, } #[async_trait] @@ -195,10 +188,14 @@ mod tests { async fn simulate( &mut self, - _context: GeneticNodeContext, - ) -> Result<(), Error> { + context: GeneticNodeContext, + ) -> Result { 
self.score += 1.0; - Ok(()) + if context.generation >= self.max_generations { + Ok(false) + } else { + Ok(true) + } } async fn mutate( @@ -211,7 +208,10 @@ mod tests { async fn initialize( _context: GeneticNodeContext, ) -> Result, Error> { - Ok(Box::new(TestState { score: 0.0 })) + Ok(Box::new(TestState { + score: 0.0, + max_generations: 2, + })) } async fn merge( @@ -226,13 +226,12 @@ mod tests { #[test] fn test_new() -> Result<(), Error> { - let genetic_node = GeneticNodeWrapper::::new(10); + let genetic_node = GeneticNodeWrapper::::new(); let other_genetic_node = GeneticNodeWrapper:: { node: None, state: GeneticState::Initialize, generation: 1, - max_generations: 10, id: genetic_node.id(), }; @@ -243,15 +242,17 @@ mod tests { #[test] fn test_from() -> Result<(), Error> { - let val = TestState { score: 0.0 }; + let val = TestState { + score: 0.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let other_genetic_node = GeneticNodeWrapper:: { node: Some(val), state: GeneticState::Simulate, generation: 1, - max_generations: 10, id: genetic_node.id(), }; @@ -262,9 +263,12 @@ mod tests { #[test] fn test_as_ref() -> Result<(), Error> { - let val = TestState { score: 3.0 }; + let val = TestState { + score: 3.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let ref_value = genetic_node.as_ref().unwrap(); @@ -275,9 +279,12 @@ mod tests { #[test] fn test_id() -> Result<(), Error> { - let val = TestState { score: 3.0 }; + let val = TestState { + score: 3.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let id_value = genetic_node.id(); @@ -286,24 
+293,14 @@ mod tests { Ok(()) } - #[test] - fn test_max_generations() -> Result<(), Error> { - let val = TestState { score: 3.0 }; - let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); - - let max_generations = genetic_node.max_generations(); - - assert_eq!(max_generations, 10); - - Ok(()) - } - #[test] fn test_state() -> Result<(), Error> { - let val = TestState { score: 3.0 }; + let val = TestState { + score: 3.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let state = genetic_node.state(); @@ -314,7 +311,7 @@ mod tests { #[tokio::test] async fn test_process_node() -> Result<(), Error> { - let mut genetic_node = GeneticNodeWrapper::::new(2); + let mut genetic_node = GeneticNodeWrapper::::new(); assert_eq!(genetic_node.state(), GeneticState::Initialize); assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate); diff --git a/gemla/src/core/mod.rs b/gemla/src/core/mod.rs index 766a7d0..d3de3f3 100644 --- a/gemla/src/core/mod.rs +++ b/gemla/src/core/mod.rs @@ -57,7 +57,6 @@ type SimulationTree = Box>>; /// ``` #[derive(Serialize, Deserialize, Copy, Clone)] pub struct GemlaConfig { - pub generations_per_height: u64, pub overwrite: bool, } @@ -126,9 +125,9 @@ where // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed // in the tree and which nodes have not. 
self.data - .mutate(|(d, c, _)| { + .mutate(|(d, _, _)| { let mut tree: Option> = - Gemla::increase_height(d.take(), c, steps); + Gemla::increase_height(d.take(), steps); mem::swap(d, &mut tree); }) .await?; @@ -268,11 +267,7 @@ where gemla_context.clone(), ) .await?; - tree.val = GeneticNodeWrapper::from( - *merged_node, - tree.val.max_generations(), - tree.val.id(), - ); + tree.val = GeneticNodeWrapper::from(*merged_node, tree.val.id()); } } (Some(l), Some(r)) => { @@ -284,11 +279,7 @@ where trace!("Copying node {}", l.val.id()); if let Some(left_node) = l.val.as_ref() { - GeneticNodeWrapper::from( - left_node.clone(), - tree.val.max_generations(), - tree.val.id(), - ); + GeneticNodeWrapper::from(left_node.clone(), tree.val.id()); } } (Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?, @@ -296,11 +287,7 @@ where trace!("Copying node {}", r.val.id()); if let Some(right_node) = r.val.as_ref() { - tree.val = GeneticNodeWrapper::from( - right_node.clone(), - tree.val.max_generations(), - tree.val.id(), - ); + tree.val = GeneticNodeWrapper::from(right_node.clone(), tree.val.id()); } } (None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?, @@ -353,11 +340,7 @@ where } } - fn increase_height( - tree: Option>, - config: &GemlaConfig, - amount: u64, - ) -> Option> { + fn increase_height(tree: Option>, amount: u64) -> Option> { if amount == 0 { tree } else { @@ -365,13 +348,11 @@ where tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1; Some(Box::new(Tree::new( - GeneticNodeWrapper::new(config.generations_per_height), - Gemla::increase_height(tree, config, amount - 1), + GeneticNodeWrapper::new(), + Gemla::increase_height(tree, amount - 1), // The right branch height has to equal the left branches total height if left_branch_height > 0 { - Some(Box::new(btree!(GeneticNodeWrapper::new( - left_branch_height * config.generations_per_height - )))) + Some(Box::new(btree!(GeneticNodeWrapper::new()))) } else 
{ None }, @@ -446,6 +427,7 @@ mod tests { #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] struct TestState { pub score: f64, + pub max_generations: u64, } #[async_trait] @@ -454,10 +436,10 @@ mod tests { async fn simulate( &mut self, - _context: GeneticNodeContext, - ) -> Result<(), Error> { + context: GeneticNodeContext, + ) -> Result { self.score += 1.0; - Ok(()) + Ok(context.generation < self.max_generations) } async fn mutate( @@ -470,7 +452,10 @@ mod tests { async fn initialize( _context: GeneticNodeContext, ) -> Result, Error> { - Ok(Box::new(TestState { score: 0.0 })) + Ok(Box::new(TestState { + score: 0.0, + max_generations: 10, + })) } async fn merge( @@ -498,10 +483,7 @@ mod tests { assert!(!path.exists()); // Testing initial creation - let mut config = GemlaConfig { - generations_per_height: 1, - overwrite: true, - }; + let mut config = GemlaConfig { overwrite: true }; let mut gemla = Gemla::::new(&p, config, DataFormat::Json).await?; // Now we can use `.await` within the spawned blocking task. @@ -559,10 +541,7 @@ mod tests { CleanUp::new(&path).run(move |p| { rt.block_on(async { // Testing initial creation - let config = GemlaConfig { - generations_per_height: 10, - overwrite: true, - }; + let config = GemlaConfig { overwrite: true }; let mut gemla = Gemla::::new(&p, config, DataFormat::Json).await?; // Now we can use `.await` within the spawned blocking task. diff --git a/analysis.py b/visualize_simulation_tree.py similarity index 100% rename from analysis.py rename to visualize_simulation_tree.py