Adjusting NN size

This commit is contained in:
vandomej 2024-03-19 22:09:08 -07:00
parent ca3989421d
commit c9b746e59d

View file

@ -12,7 +12,7 @@ use std::collections::HashMap;
const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
const POPULATION: usize = 100;
const NEURAL_NETWORK_SHAPE: &[u32; 3] = &[10, 10, 10];
const NEURAL_NETWORK_SHAPE: &[u32; 5] = &[14, 20, 20, 12, 8];
const SIMULATION_ROUNDS: usize = 10;
const SURVIVAL_RATE: f32 = 0.5;
@ -90,7 +90,7 @@ impl GeneticNode for FighterNN {
let random_fann = Fann::from_file(&random_nn)
.with_context(|| format!("Failed to load random nn"))?;
let inputs: Vec<f32> = (0..10).map(|_| thread_rng().gen_range(-1.0..1.0)).collect();
let inputs: Vec<f32> = (0..NEURAL_NETWORK_SHAPE[0]).map(|_| thread_rng().gen_range(-1.0..1.0)).collect();
let outputs = fann.run(&inputs)
.with_context(|| format!("Failed to run nn"))?;
let random_outputs = random_fann.run(&inputs)
@ -179,9 +179,10 @@ impl GeneticNode for FighterNN {
for c in &mut connections {
if thread_rng().gen_range(0..100) < 20 {
c.weight += thread_rng().gen_range(-0.1..0.1);
} else if thread_rng().gen_range(0..100) < 5 {
c.weight += thread_rng().gen_range(-0.3..0.3);
}
// else if thread_rng().gen_range(0..100) < 5 {
// c.weight += thread_rng().gen_range(-0.3..0.3);
// }
}
fann.set_connections(&connections);