diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs index 2033667..248dbb0 100644 --- a/gemla/src/bin/fighter_nn/mod.rs +++ b/gemla/src/bin/fighter_nn/mod.rs @@ -12,7 +12,7 @@ use std::collections::HashMap; const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations"; const POPULATION: usize = 100; -const NEURAL_NETWORK_SHAPE: &[u32; 3] = &[10, 10, 10]; +const NEURAL_NETWORK_SHAPE: &[u32; 5] = &[14, 20, 20, 12, 8]; const SIMULATION_ROUNDS: usize = 10; const SURVIVAL_RATE: f32 = 0.5; @@ -90,7 +90,7 @@ impl GeneticNode for FighterNN { let random_fann = Fann::from_file(&random_nn) .with_context(|| format!("Failed to load random nn"))?; - let inputs: Vec = (0..10).map(|_| thread_rng().gen_range(-1.0..1.0)).collect(); + let inputs: Vec = (0..NEURAL_NETWORK_SHAPE[0]).map(|_| thread_rng().gen_range(-1.0..1.0)).collect(); let outputs = fann.run(&inputs) .with_context(|| format!("Failed to run nn"))?; let random_outputs = random_fann.run(&inputs) @@ -179,9 +179,10 @@ impl GeneticNode for FighterNN { for c in &mut connections { if thread_rng().gen_range(0..100) < 20 { c.weight += thread_rng().gen_range(-0.1..0.1); - } else if thread_rng().gen_range(0..100) < 5 { - c.weight += thread_rng().gen_range(-0.3..0.3); } + // else if thread_rng().gen_range(0..100) < 5 { + // c.weight += thread_rng().gen_range(-0.3..0.3); + // } } fann.set_connections(&connections);