Compare commits

..

No commits in common. "32a6813cf472576c87d50bc0aed2c750ea05472e" and "5ab3c2382e019e18d1cd5da12c8b2ef304032e9f" have entirely different histories.

5 changed files with 102 additions and 379 deletions

View file

@ -19,7 +19,5 @@ num_cpus = "1.17.0"
rand = "0.9.2" rand = "0.9.2"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.143" serde_json = "1.0.143"
tempfile = "3.21.0"
tokio = { version = "1.47.1", features = ["full"] } tokio = { version = "1.47.1", features = ["full"] }
tokio-test = "0.4.4"
uuid = "1.18.1" uuid = "1.18.1"

View file

@ -3,7 +3,7 @@ extern crate fann;
pub mod fighter_context; pub mod fighter_context;
pub mod neural_network_utility; pub mod neural_network_utility;
use anyhow::{Context, anyhow}; use anyhow::{anyhow, Context};
use async_trait::async_trait; use async_trait::async_trait;
use fann::{ActivationFunc, Fann}; use fann::{ActivationFunc, Fann};
use futures::future::join_all; use futures::future::join_all;
@ -114,10 +114,8 @@ impl GeneticNode for FighterNN {
})?; })?;
let mut nn_shapes = HashMap::new(); let mut nn_shapes = HashMap::new();
let weight_initialization_amplitude = let weight_initialization_amplitude = rng().random_range(0.0..NEURAL_NETWORK_INITIAL_WEIGHT_MAX);
rng().random_range(0.0..NEURAL_NETWORK_INITIAL_WEIGHT_MAX); let weight_initialization_range = -weight_initialization_amplitude..weight_initialization_amplitude;
let weight_initialization_range =
-weight_initialization_amplitude..weight_initialization_amplitude;
// Create the first generation in this folder // Create the first generation in this folder
for i in 0..POPULATION { for i in 0..POPULATION {
@ -203,6 +201,7 @@ impl GeneticNode for FighterNN {
i i
}; };
let secondary_id = loop { let secondary_id = loop {
if allotted_simulations.is_empty() || allotted_simulations.len() == 1 { if allotted_simulations.is_empty() || allotted_simulations.len() == 1 {
// Select a random id // Select a random id
@ -244,8 +243,7 @@ impl GeneticNode for FighterNN {
let task = { let task = {
let self_clone = self.clone(); let self_clone = self.clone();
let semaphore_clone = context.gemla_context.shared_semaphore.clone(); let semaphore_clone = context.gemla_context.shared_semaphore.clone();
let display_simulation_semaphore = let display_simulation_semaphore = context.gemla_context.visible_simulations.clone();
context.gemla_context.visible_simulations.clone();
let folder = self_clone.folder.clone(); let folder = self_clone.folder.clone();
let generation = self_clone.r#generation; let generation = self_clone.r#generation;
@ -262,16 +260,12 @@ impl GeneticNode for FighterNN {
// Introducing a new scope for acquiring permits and running simulations // Introducing a new scope for acquiring permits and running simulations
let simulation_result = async move { let simulation_result = async move {
let permit = semaphore_clone let permit = semaphore_clone.acquire_owned().await
.acquire_owned()
.await
.with_context(|| "Failed to acquire semaphore permit")?; .with_context(|| "Failed to acquire semaphore permit")?;
let display_simulation = display_simulation_semaphore.try_acquire_owned().ok(); let display_simulation = display_simulation_semaphore.try_acquire_owned().ok();
let (primary_score, secondary_score) = if let Some(display_simulation) = let (primary_score, secondary_score) = if let Some(display_simulation) = display_simulation {
display_simulation
{
let result = run_1v1_simulation(&primary_nn, &secondary_nn, true).await?; let result = run_1v1_simulation(&primary_nn, &secondary_nn, true).await?;
drop(display_simulation); // Explicitly dropping resources no longer needed drop(display_simulation); // Explicitly dropping resources no longer needed
result result
@ -306,8 +300,7 @@ impl GeneticNode for FighterNN {
// resolve results for any errors // resolve results for any errors
let mut scores = HashMap::new(); let mut scores = HashMap::new();
for result in results.into_iter() { for result in results.into_iter() {
let (primary_id, primary_score, secondary_id, secondary_score) = let (primary_id, primary_score, secondary_id, secondary_score) = result.with_context(|| "Failed to run simulation")?;
result.with_context(|| "Failed to run simulation")?;
// If score exists, add the new score to the existing score // If score exists, add the new score to the existing score
if let Some((existing_score, count)) = scores.get_mut(&(primary_id as u64)) { if let Some((existing_score, count)) = scores.get_mut(&(primary_id as u64)) {
@ -487,9 +480,8 @@ impl GeneticNode for FighterNN {
.with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?; .with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?;
let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> { let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> {
let mut sorted_scores: Vec<_> = fighter.scores[fighter.r#generation as usize] let mut sorted_scores: Vec<_> =
.iter() fighter.scores[fighter.r#generation as usize].iter().collect();
.collect();
sorted_scores.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); sorted_scores.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
sorted_scores sorted_scores
.iter() .iter()
@ -545,10 +537,7 @@ impl GeneticNode for FighterNN {
run_1v1_simulation(&left_nn_path, &right_nn_path, false).await? run_1v1_simulation(&left_nn_path, &right_nn_path, false).await?
}; };
debug!( debug!("{} vs {} -> {} vs {}", left_nn_id, right_nn_id, left_score, right_score);
"{} vs {} -> {} vs {}",
left_nn_id, right_nn_id, left_score, right_score
);
drop(permit); drop(permit);
@ -745,11 +734,7 @@ fn should_continue(scores: &[HashMap<u64, f32>], lenience: u64) -> Result<bool,
debug!( debug!(
"Highest Q3 value: {} at generation {}, Highest Median value: {} at generation {}, Continuing? {}", "Highest Q3 value: {} at generation {}, Highest Median value: {} at generation {}, Continuing? {}",
highest_q3_value, highest_q3_value, generation_with_highest_q3 + 1, highest_median, generation_with_highest_median + 1, result
generation_with_highest_q3 + 1,
highest_median,
generation_with_highest_median + 1,
result
); );
Ok(result) Ok(result)
@ -841,7 +826,10 @@ async fn run_1v1_simulation(
trace!( trace!(
"Executing the following command {} {} {} {}", "Executing the following command {} {} {} {}",
GAME_EXECUTABLE_PATH, config1_arg, config2_arg, disable_unreal_rendering_arg GAME_EXECUTABLE_PATH,
config1_arg,
config2_arg,
disable_unreal_rendering_arg
); );
trace!("Running simulation for {} vs {}", nn_1_id, nn_2_id); trace!("Running simulation for {} vs {}", nn_1_id, nn_2_id);
@ -918,7 +906,8 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
"NN ID not found in scores file", "NN ID not found in scores file",
)); ));
} }
Err(_) => { Err(_) =>
{
if attempts >= 2 { if attempts >= 2 {
// Attempt 5 times before giving up. // Attempt 5 times before giving up.
return Ok(-100.0); return Ok(-100.0);
@ -935,74 +924,6 @@ async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::
#[cfg(test)] #[cfg(test)]
pub mod test { pub mod test {
use super::*; use super::*;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use tempfile::tempdir;
use tokio_test::block_on;
use uuid::Uuid;
#[test]
fn test_get_individual_id_format() {
let nn = FighterNN {
id: Uuid::new_v4(),
folder: PathBuf::new(),
generation: 0,
population_size: 10,
scores: vec![],
nn_shapes: vec![],
crossbreed_segments: 2,
weight_initialization_range: -0.5..0.5,
minor_mutation_rate: 0.1,
major_mutation_rate: 0.05,
mutation_weight_range: -0.1..0.1,
id_mapping: vec![],
lerp_amount: 0.0,
generational_lenience: 3,
survival_rate: 0.5,
};
let id_str = nn.get_individual_id(42);
assert!(id_str.contains("fighter_nn_"));
assert!(id_str.ends_with("_42"));
assert!(id_str.starts_with(&format!("{:06}_", nn.id)));
}
#[test]
fn test_read_score_from_file_found() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("score.txt");
let nn_id = "test_nn";
let mut file = File::create(&file_path).unwrap();
writeln!(file, "{}: 123.45", nn_id).unwrap();
let score = block_on(read_score_from_file(&file_path, nn_id)).unwrap();
assert!((score - 123.45).abs() < 1e-5);
}
#[test]
fn test_read_score_from_file_not_found() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("score.txt");
let nn_id = "not_in_file";
let mut file = File::create(&file_path).unwrap();
writeln!(file, "other_nn: 1.0").unwrap();
let result = block_on(read_score_from_file(&file_path, nn_id));
assert!(result.is_err());
}
#[test]
fn test_run_1v1_simulation_reads_existing_score() {
let dir = tempdir().unwrap();
let nn_id1 = "nn1";
let nn_id2 = "nn2";
let file_path = dir.path().join(format!("{}_vs_{}.txt", nn_id1, nn_id2));
let mut file = File::create(&file_path).unwrap();
writeln!(file, "{}: 10.0", nn_id1).unwrap();
writeln!(file, "{}: 20.0", nn_id2).unwrap();
let nn_path_1 = dir.path().join(format!("{}.net", nn_id1));
let nn_path_2 = dir.path().join(format!("{}.net", nn_id2));
let result = block_on(run_1v1_simulation(&nn_path_1, &nn_path_2, false)).unwrap();
assert_eq!(result, (10.0, 20.0));
}
#[test] #[test]
fn test_weighted_random_selection() { fn test_weighted_random_selection() {

View file

@ -1,13 +1,12 @@
use std::{cmp::Ordering, cmp::min, collections::HashMap, ops::Range}; use std::{cmp::min, cmp::Ordering, collections::HashMap, ops::Range};
use anyhow::Context; use anyhow::Context;
use fann::{ActivationFunc, Fann}; use fann::{ActivationFunc, Fann};
use gemla::error::Error; use gemla::error::Error;
use rand::{ use rand::{
Rng,
distr::{Distribution, Uniform}, distr::{Distribution, Uniform},
rng,
seq::IteratorRandom, seq::IteratorRandom,
rng, Rng,
}; };
use super::{ use super::{
@ -209,37 +208,13 @@ pub fn consolidate_old_connections(
to_non_bias_network_id(connection.from_neuron, &primary_shape); to_non_bias_network_id(connection.from_neuron, &primary_shape);
let original_to_neuron = let original_to_neuron =
to_non_bias_network_id(connection.to_neuron, &primary_shape); to_non_bias_network_id(connection.to_neuron, &primary_shape);
trace!( trace!("Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", previous_new_id, new_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, previous_neuron_id, neuron_id);
"Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]",
previous_new_id,
new_id,
original_from_neuron,
original_to_neuron,
connection.weight,
found_in_primary,
connection.from_neuron,
connection.to_neuron,
previous_neuron_id,
neuron_id
);
} else { } else {
let original_from_neuron = let original_from_neuron =
to_non_bias_network_id(connection.from_neuron, &secondary_shape); to_non_bias_network_id(connection.from_neuron, &secondary_shape);
let original_to_neuron = let original_to_neuron =
to_non_bias_network_id(connection.to_neuron, &secondary_shape); to_non_bias_network_id(connection.to_neuron, &secondary_shape);
trace!( trace!("Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", previous_new_id, new_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, previous_neuron_id, neuron_id);
"Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]",
previous_new_id,
new_id,
original_from_neuron,
original_to_neuron,
connection.weight,
found_in_primary,
connection.from_neuron,
connection.to_neuron,
previous_neuron_id,
neuron_id
);
} }
let translated_from = to_bias_network_id(previous_new_id, &new_shape); let translated_from = to_bias_network_id(previous_new_id, &new_shape);
let translated_to = to_bias_network_id(new_id, &new_shape); let translated_to = to_bias_network_id(new_id, &new_shape);
@ -247,7 +222,10 @@ pub fn consolidate_old_connections(
} else { } else {
trace!( trace!(
"Connection not found for ({}, {}) -> ({}, {})", "Connection not found for ({}, {}) -> ({}, {})",
previous_new_id, new_id, previous_neuron_id, neuron_id previous_new_id,
new_id,
previous_neuron_id,
neuron_id
); );
} }
} }
@ -339,43 +317,23 @@ pub fn consolidate_old_connections(
to_non_bias_network_id(connection.from_neuron, &primary_shape); to_non_bias_network_id(connection.from_neuron, &primary_shape);
let original_to_neuron = let original_to_neuron =
to_non_bias_network_id(connection.to_neuron, &primary_shape); to_non_bias_network_id(connection.to_neuron, &primary_shape);
trace!( trace!("Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", bias_neuron, translated_neuron_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, bias_neuron, neuron_id);
"Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]",
bias_neuron,
translated_neuron_id,
original_from_neuron,
original_to_neuron,
connection.weight,
found_in_primary,
connection.from_neuron,
connection.to_neuron,
bias_neuron,
neuron_id
);
} else { } else {
let original_from_neuron = let original_from_neuron =
to_non_bias_network_id(connection.from_neuron, &secondary_shape); to_non_bias_network_id(connection.from_neuron, &secondary_shape);
let original_to_neuron = let original_to_neuron =
to_non_bias_network_id(connection.to_neuron, &secondary_shape); to_non_bias_network_id(connection.to_neuron, &secondary_shape);
trace!( trace!("Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", bias_neuron, translated_neuron_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, bias_neuron, neuron_id);
"Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]",
bias_neuron,
translated_neuron_id,
original_from_neuron,
original_to_neuron,
connection.weight,
found_in_primary,
connection.from_neuron,
connection.to_neuron,
bias_neuron,
neuron_id
);
} }
new_fann.set_weight(bias_neuron, translated_neuron_id, connection.weight); new_fann.set_weight(bias_neuron, translated_neuron_id, connection.weight);
} else { } else {
trace!( trace!(
"Connection not found for bias ({}, {}) -> ({}, {}) primary: {}", "Connection not found for bias ({}, {}) -> ({}, {}) primary: {}",
bias_neuron, neuron_id, bias_neuron, translated_neuron_id, is_primary bias_neuron,
neuron_id,
bias_neuron,
translated_neuron_id,
is_primary
); );
} }
} }
@ -409,8 +367,11 @@ pub fn crossbreed_neuron_arrays(
current_layer += 1; current_layer += 1;
} }
new_neurons.push((*neuron_id, is_primary, current_layer, 0)); new_neurons.push((*neuron_id, is_primary, current_layer, 0));
// The first segment is always from the primary network, so we can set primary_last_layer here if is_primary {
primary_last_layer = current_layer; primary_last_layer = current_layer;
} else {
secondary_last_layer = current_layer;
}
} else { } else {
break; break;
} }
@ -427,7 +388,8 @@ pub fn crossbreed_neuron_arrays(
if neuron_id >= &segment.0 && neuron_id <= &segment.1 { if neuron_id >= &segment.0 && neuron_id <= &segment.1 {
// We need to do something different depending on whether the neuron layer is, lower, higher or equal to the target layer // We need to do something different depending on whether the neuron layer is, lower, higher or equal to the target layer
match layer.cmp(&current_layer) { match layer.cmp(&current_layer)
{
Ordering::Equal => { Ordering::Equal => {
new_neurons.push((*neuron_id, is_primary, current_layer, 0)); new_neurons.push((*neuron_id, is_primary, current_layer, 0));
@ -449,9 +411,8 @@ pub fn crossbreed_neuron_arrays(
let highest_id = earlier_layer_neurons let highest_id = earlier_layer_neurons
.iter() .iter()
.max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); .max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0)));
if let Some(highest_id) = highest_id if let Some(highest_id) = highest_id {
&& highest_id.1 == is_primary if highest_id.1 == is_primary {
{
let neurons_to_add = target_neurons let neurons_to_add = target_neurons
.iter() .iter()
.filter(|(id, l)| { .filter(|(id, l)| {
@ -468,6 +429,7 @@ pub fn crossbreed_neuron_arrays(
} }
} }
} }
}
new_neurons.push((*neuron_id, is_primary, *layer, 0)); new_neurons.push((*neuron_id, is_primary, *layer, 0));
@ -486,9 +448,8 @@ pub fn crossbreed_neuron_arrays(
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let highest_id = let highest_id =
current_layer_neurons.iter().max_by_key(|(id, _, _, _)| id); current_layer_neurons.iter().max_by_key(|(id, _, _, _)| id);
if let Some(highest_id) = highest_id if let Some(highest_id) = highest_id {
&& highest_id.1 == is_primary if highest_id.1 == is_primary {
{
let neurons_to_add = target_neurons let neurons_to_add = target_neurons
.iter() .iter()
.filter(|(id, l)| id > &highest_id.0 && *l == layer - 1) .filter(|(id, l)| id > &highest_id.0 && *l == layer - 1)
@ -503,6 +464,7 @@ pub fn crossbreed_neuron_arrays(
} }
} }
} }
}
// If it's in a future layer, move to the next layer // If it's in a future layer, move to the next layer
current_layer += 1; current_layer += 1;
@ -555,9 +517,7 @@ pub fn crossbreed_neuron_arrays(
new_neurons.push((*neuron_id, is_primary, current_layer, 0)); new_neurons.push((*neuron_id, is_primary, current_layer, 0));
} }
break; break;
} } else if *neuron_id == &segments.last().unwrap().1 + 1 {
// If the neuron id is exactly one more than the last neuron id, we need to ensure that there's at least one neuron from the same individual in the previous layer
else if *neuron_id == &segments.last().unwrap().1 + 1 {
let target_layer = if is_primary { let target_layer = if is_primary {
primary_last_layer primary_last_layer
} else { } else {
@ -573,9 +533,8 @@ pub fn crossbreed_neuron_arrays(
let highest_id = earlier_layer_neurons let highest_id = earlier_layer_neurons
.iter() .iter()
.max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); .max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0)));
if let Some(highest_id) = highest_id if let Some(highest_id) = highest_id {
&& highest_id.1 == is_primary if highest_id.1 == is_primary {
{
let neurons_to_add = target_neurons let neurons_to_add = target_neurons
.iter() .iter()
.filter(|(id, _)| id > &highest_id.0 && id < neuron_id) .filter(|(id, _)| id > &highest_id.0 && id < neuron_id)
@ -584,6 +543,7 @@ pub fn crossbreed_neuron_arrays(
new_neurons.push((*neuron_id, is_primary, *l, 0)); new_neurons.push((*neuron_id, is_primary, *l, 0));
} }
} }
}
new_neurons.push((*neuron_id, is_primary, *layer, 0)); new_neurons.push((*neuron_id, is_primary, *layer, 0));
} else { } else {
@ -802,45 +762,6 @@ mod tests {
use super::*; use super::*;
#[test]
fn crossbreed_basic_test() -> Result<(), Box<dyn std::error::Error>> {
// Create a dummy FighterNN
let fighter_nn = FighterNN {
id: uuid::Uuid::new_v4(),
folder: std::path::PathBuf::from("/tmp"),
population_size: 2,
generation: 0,
scores: vec![],
nn_shapes: vec![],
crossbreed_segments: 1,
weight_initialization_range: -0.5..0.5,
minor_mutation_rate: 0.1,
major_mutation_rate: 0.1,
mutation_weight_range: -0.5..0.5,
id_mapping: vec![],
lerp_amount: 0.5,
generational_lenience: 1,
survival_rate: 0.5,
};
// Use very small networks to avoid FANN memory/resource issues
let primary = Fann::new(&[5, 3, 3])?;
let secondary = Fann::new(&[5, 3, 3])?;
// Run crossbreed
let result = crossbreed(&fighter_nn, &primary, &secondary, 3)?;
// Check that the result has the correct input and output size
let shape = result.get_layer_sizes();
assert_eq!(shape[0], 5);
assert_eq!(*shape.last().unwrap(), 3);
// All hidden layers should have at least 1 neuron
for (i, &layer_size) in shape.iter().enumerate().skip(1).take(shape.len() - 2) {
assert!(layer_size > 0, "Hidden layer {} has zero neurons", i);
}
Ok(())
}
#[test] #[test]
fn major_mutation_test() -> Result<(), Box<dyn std::error::Error>> { fn major_mutation_test() -> Result<(), Box<dyn std::error::Error>> {
// Assign // Assign
@ -1526,51 +1447,6 @@ mod tests {
assert_eq!(result_set, expected); assert_eq!(result_set, expected);
} }
#[test]
fn crossbreed_neuron_arrays_secondary_last_layer_final_segment() {
// Use 3 segments and larger neuron arrays to ensure the else-if branch is covered
let segments = vec![(0, 2), (3, 5), (6, 8)];
// secondary_neurons: (id, layer)
let primary_neurons = generate_neuron_datastructure(&vec![3, 3, 3, 3, 2]);
let secondary_neurons = generate_neuron_datastructure(&vec![3, 3, 3, 3, 3]);
// The last segment is (6, 8), so neuron 9 in secondary_neurons will trigger the else-if branch
let result = crossbreed_neuron_arrays(segments, primary_neurons, secondary_neurons);
// Assert: The result should contain a secondary neuron with id 9 and layer 3
let has_secondary_9_layer_3 = result
.iter()
.any(|&(id, is_primary, layer, _)| id == 9 && !is_primary && layer == 3);
assert!(
has_secondary_9_layer_3,
"Expected a secondary neuron with id 9 in layer 3, indicating secondary_last_layer was used"
);
}
#[test]
fn crossbreed_neuron_arrays_prune_layer_exceeds_max() {
// Use the real constant from the module
let max = NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX;
let layer = 1;
// Create more than max neurons in layer 1, alternating primary/secondary
let primary_neurons = generate_neuron_datastructure(&vec![1, (max + 3) as u32, 1]);
let secondary_neurons = generate_neuron_datastructure(&vec![1, (max + 3) as u32, 1]);
// Segments: one for input, one for the large layer, one for output
let segments = vec![
(0, 0),
(1, (max + 3) as u32),
((max + 4) as u32, (max + 4) as u32),
];
let result = crossbreed_neuron_arrays(segments, primary_neurons, secondary_neurons);
// Count neurons in layer 1
let count = result.iter().filter(|(_, _, l, _)| *l == layer).count();
assert!(
count <= max,
"Layer should be pruned to NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX neurons"
);
}
#[test] #[test]
fn generate_neuron_datastructure_test() { fn generate_neuron_datastructure_test() {
// Assign // Assign
@ -1944,92 +1820,6 @@ mod tests {
} }
} }
// Setup for bias connection fallback
// We'll target hidden layer 3 (layer index 3), neuron 20 (first in that layer, is_primary = false)
let bias_layer = 3;
let bias_neuron_primary = get_bias_neuron_for_layer(bias_layer, &primary_shape).unwrap();
// Add a bias connection in primary if not present
let mut primary_connections = primary_fann.get_connections();
for connection in primary_connections.iter_mut() {
if connection.from_neuron == bias_neuron_primary {
// Set to a unique weight to verify it's not used
connection.weight = 12345.0;
}
}
primary_fann.set_connections(&primary_connections);
// Secondary network needs to have only 1 hidden layer, so it won't have a bias connection for layer 3
let secondary_shape = vec![4, 8, 6];
let mut secondary_fann = Fann::new(&secondary_shape)?;
let mut secondary_connections = secondary_fann.get_connections();
for connection in secondary_connections.iter_mut() {
connection.weight = ((connection.from_neuron * 100) + connection.to_neuron) as f32;
connection.weight = connection.weight * -1.0;
}
secondary_fann.set_connections(&secondary_connections);
let new_neurons = vec![
// Input layer: Expect 4
(0, true, 0, 0),
(1, true, 0, 1),
(2, true, 0, 2),
(3, true, 0, 3),
// Hidden Layer 1: Expect 8
(4, false, 1, 4),
(5, false, 1, 5),
(6, false, 1, 6),
(7, true, 1, 7),
(8, true, 1, 8),
(9, true, 1, 9),
(10, true, 1, 10),
(11, true, 1, 11),
// Hidden Layer 2: Expect 6
(12, true, 2, 12),
(13, true, 2, 13),
(14, true, 2, 14),
(15, true, 2, 15),
(16, true, 2, 16),
(17, true, 2, 17),
// Output Layer: Expect 4
(18, true, 3, 18),
(19, true, 3, 19),
(20, false, 3, 20),
(21, true, 3, 21),
];
let new_shape = vec![4, 8, 6, 4];
let mut new_fann = Fann::new(&new_shape)?;
// Initialize weights to 0
let mut new_connections = new_fann.get_connections();
for connection in new_connections.iter_mut() {
connection.weight = 0.0;
}
new_fann.set_connections(&new_connections);
// Act
consolidate_old_connections(
&primary_fann,
&secondary_fann,
new_shape,
new_neurons,
&mut new_fann,
);
// Assert that the fallback bias connection was used
let new_connections = new_fann.get_connections();
for connection in new_connections.iter() {
println!("{:?}", connection);
}
let found = new_connections
.iter()
.any(|c| c.from_neuron == bias_neuron_primary && c.weight == 12345.0);
assert!(
found,
"Expected fallback bias connection from primary network to be used"
);
Ok(()) Ok(())
} }
} }

View file

@ -100,5 +100,19 @@ def visualize_fann_network(network_file):
plt.show() plt.show()
# Path to the FANN network file # Path to the FANN network file
fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_4f2be613-ab26-4384-9a65-450e043984ea\\6\\4f2be613-ab26-4384-9a65-450e043984ea_fighter_nn_0.net'
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_fc294503-7b2a-40f8-be59-ccc486eb3f79\\0\\fc294503-7b2a-40f8-be59-ccc486eb3f79_fighter_nn_0.net"
# fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_99c30a7f-40ab-4faf-b16a-b44703fdb6cd\\0\\99c30a7f-40ab-4faf-b16a-b44703fdb6cd_fighter_nn_0.net'
# Has a 4 layer network
# # Generation 1
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\1\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 5
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\5\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 10
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\10\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 20
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\20\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 32
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\32\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
fann_path = select_file() fann_path = select_file()
visualize_fann_network(fann_path) visualize_fann_network(fann_path)