Update logic for increasing height
This commit is contained in:
parent
69b026593e
commit
ca3989421d
3 changed files with 68 additions and 62 deletions
|
@ -56,15 +56,15 @@ fn main() -> anyhow::Result<()> {
|
||||||
let mut gemla = log_error(Gemla::<FighterNN>::new(
|
let mut gemla = log_error(Gemla::<FighterNN>::new(
|
||||||
&PathBuf::from(args.file),
|
&PathBuf::from(args.file),
|
||||||
GemlaConfig {
|
GemlaConfig {
|
||||||
generations_per_node: 3,
|
generations_per_height: 3,
|
||||||
overwrite: true,
|
overwrite: false,
|
||||||
},
|
},
|
||||||
DataFormat::Json,
|
DataFormat::Json,
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
log_error(gemla.simulate(3).await)?;
|
loop {
|
||||||
|
log_error(gemla.simulate(5).await)?;
|
||||||
Ok(())
|
}
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -41,23 +41,27 @@ impl GeneticNode for FighterNN {
|
||||||
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
|
fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
|
||||||
let base_path = PathBuf::from(BASE_DIR);
|
let base_path = PathBuf::from(BASE_DIR);
|
||||||
|
|
||||||
let mut folder = base_path.join(format!("fighter_nn_{:06}", context.id));
|
let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
|
||||||
fs::create_dir(&folder)?;
|
// Ensures directory is created if it doesn't exist and does nothing if it exists
|
||||||
|
fs::create_dir_all(&folder)
|
||||||
|
.with_context(|| format!("Failed to create or access the folder: {:?}", folder))?;
|
||||||
|
|
||||||
//Create a new directory for the first generation
|
//Create a new directory for the first generation, using create_dir_all to avoid errors if it already exists
|
||||||
let gen_folder = folder.join("0");
|
let gen_folder = folder.join("0");
|
||||||
fs::create_dir(&gen_folder)?;
|
fs::create_dir_all(&gen_folder)
|
||||||
|
.with_context(|| format!("Failed to create or access the generation folder: {:?}", gen_folder))?;
|
||||||
|
|
||||||
// Create the first generation in this folder
|
// Create the first generation in this folder
|
||||||
for i in 0..POPULATION {
|
for i in 0..POPULATION {
|
||||||
// Filenames are stored in the format of "xxxxxx_fighter_nn_0.net", "xxxxxx_fighter_nn_1.net", etc. Where xxxxxx is the folder name
|
// Filenames are stored in the format of "xxxxxx_fighter_nn_0.net", "xxxxxx_fighter_nn_1.net", etc. Where xxxxxx is the folder name
|
||||||
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
|
let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
|
||||||
let mut fann = Fann::new(NEURAL_NETWORK_SHAPE)
|
let mut fann = Fann::new(NEURAL_NETWORK_SHAPE)
|
||||||
.with_context(|| format!("Failed to create nn"))?;
|
.with_context(|| "Failed to create nn")?;
|
||||||
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
|
fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
|
||||||
fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
|
fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
|
||||||
|
// This will overwrite any existing file with the same name
|
||||||
fann.save(&nn)
|
fann.save(&nn)
|
||||||
.with_context(|| format!("Failed to save nn"))?;
|
.with_context(|| format!("Failed to save nn at {:?}", nn))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Box::new(FighterNN {
|
Ok(Box::new(FighterNN {
|
||||||
|
@ -114,7 +118,7 @@ impl GeneticNode for FighterNN {
|
||||||
|
|
||||||
// Create the new generation folder
|
// Create the new generation folder
|
||||||
let new_gen_folder = self.folder.join(format!("{}", self.generation + 1));
|
let new_gen_folder = self.folder.join(format!("{}", self.generation + 1));
|
||||||
fs::create_dir(&new_gen_folder)?;
|
fs::create_dir_all(&new_gen_folder).with_context(|| format!("Failed to create or access new generation folder: {:?}", new_gen_folder))?;
|
||||||
|
|
||||||
// Remove the 5 nn's with the lowest scores
|
// Remove the 5 nn's with the lowest scores
|
||||||
let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect();
|
let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect();
|
||||||
|
@ -195,44 +199,36 @@ impl GeneticNode for FighterNN {
|
||||||
|
|
||||||
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> {
|
fn merge(left: &FighterNN, right: &FighterNN, id: &Uuid) -> Result<Box<FighterNN>, Error> {
|
||||||
let base_path = PathBuf::from(BASE_DIR);
|
let base_path = PathBuf::from(BASE_DIR);
|
||||||
|
|
||||||
// Find next highest
|
|
||||||
let folder = base_path.join(format!("fighter_nn_{:06}", id));
|
let folder = base_path.join(format!("fighter_nn_{:06}", id));
|
||||||
fs::create_dir(&folder)?;
|
|
||||||
|
|
||||||
//Create a new directory for the first generation
|
// Ensure the folder exists, including the generation subfolder.
|
||||||
let gen_folder = folder.join("0");
|
fs::create_dir_all(&folder.join("0"))
|
||||||
fs::create_dir(&gen_folder)?;
|
.with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?;
|
||||||
|
|
||||||
// Take the 5 nn's with the highest scores from the left nn's and save them to the new fighter folder
|
// Function to copy NNs from a source FighterNN to the new folder.
|
||||||
let mut sorted_scores: Vec<_> = left.scores[left.generation as usize].iter().collect();
|
let copy_nns = |source: &FighterNN, folder: &PathBuf, id: &Uuid, start_idx: usize| -> Result<(), Error> {
|
||||||
|
let mut sorted_scores: Vec<_> = source.scores[source.generation as usize].iter().collect();
|
||||||
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
||||||
let mut remaining = sorted_scores[(left.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
let remaining = sorted_scores[(source.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
||||||
for i in 0..(left.population_size / 2) {
|
|
||||||
let nn = left.folder.join(format!("{}", left.generation)).join(format!("{:06}_fighter_nn_{}.net", left.id, remaining.pop().unwrap()));
|
|
||||||
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", id, i));
|
|
||||||
trace!("From: {:?}, To: {:?}", &nn, &new_nn);
|
|
||||||
fs::copy(&nn, &new_nn)
|
|
||||||
.with_context(|| format!("Failed to copy left nn"))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take the 5 nn's with the highest scores from the right nn's and save them to the new fighter folder
|
for (i, nn_id) in remaining.into_iter().enumerate() {
|
||||||
sorted_scores = right.scores[right.generation as usize].iter().collect();
|
let nn_path = source.folder.join(source.generation.to_string()).join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id));
|
||||||
sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
|
let new_nn_path = folder.join("0").join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i));
|
||||||
remaining = sorted_scores[(right.population_size / 2)..].iter().map(|(k, _)| *k).collect::<Vec<_>>();
|
fs::copy(&nn_path, &new_nn_path)
|
||||||
for i in (right.population_size / 2)..right.population_size {
|
.with_context(|| format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path))?;
|
||||||
let nn = right.folder.join(format!("{}", right.generation)).join(format!("{:06}_fighter_nn_{}.net", right.id, remaining.pop().unwrap()));
|
|
||||||
let new_nn = folder.join(format!("0")).join(format!("{:06}_fighter_nn_{}.net", id, i));
|
|
||||||
trace!("From: {:?}, To: {:?}", &nn, &new_nn);
|
|
||||||
fs::copy(&nn, &new_nn)
|
|
||||||
.with_context(|| format!("Failed to copy right nn"))?;
|
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
// Copy the top half of NNs from each parent to the new folder.
|
||||||
|
copy_nns(left, &folder, id, 0)?;
|
||||||
|
copy_nns(right, &folder, id, left.population_size as usize / 2)?;
|
||||||
|
|
||||||
Ok(Box::new(FighterNN {
|
Ok(Box::new(FighterNN {
|
||||||
id: *id,
|
id: *id,
|
||||||
folder,
|
folder,
|
||||||
generation: 0,
|
generation: 0,
|
||||||
population_size: POPULATION,
|
population_size: left.population_size, // Assuming left and right have the same population size.
|
||||||
scores: vec![HashMap::new()],
|
scores: vec![HashMap::new()],
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Serialize, Deserialize, Copy, Clone)]
|
#[derive(Serialize, Deserialize, Copy, Clone)]
|
||||||
pub struct GemlaConfig {
|
pub struct GemlaConfig {
|
||||||
pub generations_per_node: u64,
|
pub generations_per_height: u64,
|
||||||
pub overwrite: bool,
|
pub overwrite: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,12 +103,22 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
|
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
|
||||||
|
// Only increase height if the tree is uninitialized or completed
|
||||||
|
if self.tree_ref().is_none() ||
|
||||||
|
self
|
||||||
|
.tree_ref()
|
||||||
|
.map(|t| Gemla::is_completed(t))
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
||||||
// in the tree and which nodes have not.
|
// in the tree and which nodes have not.
|
||||||
self.data.mutate(|(d, c)| {
|
self.data.mutate(|(d, c)| {
|
||||||
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
|
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
|
||||||
mem::swap(d, &mut tree);
|
mem::swap(d, &mut tree);
|
||||||
})?;
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Height of simulation tree increased to {}",
|
"Height of simulation tree increased to {}",
|
||||||
|
@ -286,16 +296,16 @@ where
|
||||||
if amount == 0 {
|
if amount == 0 {
|
||||||
tree
|
tree
|
||||||
} else {
|
} else {
|
||||||
let left_branch_right =
|
let left_branch_height =
|
||||||
tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1;
|
tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1;
|
||||||
|
|
||||||
Some(Box::new(Tree::new(
|
Some(Box::new(Tree::new(
|
||||||
GeneticNodeWrapper::new(config.generations_per_node),
|
GeneticNodeWrapper::new(config.generations_per_height),
|
||||||
Gemla::increase_height(tree, config, amount - 1),
|
Gemla::increase_height(tree, config, amount - 1),
|
||||||
// The right branch height has to equal the left branches total height
|
// The right branch height has to equal the left branches total height
|
||||||
if left_branch_right > 0 {
|
if left_branch_height > 0 {
|
||||||
Some(Box::new(btree!(GeneticNodeWrapper::new(
|
Some(Box::new(btree!(GeneticNodeWrapper::new(
|
||||||
left_branch_right * config.generations_per_node
|
left_branch_height * config.generations_per_height
|
||||||
))))
|
))))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -399,7 +409,7 @@ mod tests {
|
||||||
|
|
||||||
// Testing initial creation
|
// Testing initial creation
|
||||||
let mut config = GemlaConfig {
|
let mut config = GemlaConfig {
|
||||||
generations_per_node: 1,
|
generations_per_height: 1,
|
||||||
overwrite: true
|
overwrite: true
|
||||||
};
|
};
|
||||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||||
|
@ -439,7 +449,7 @@ mod tests {
|
||||||
CleanUp::new(&path).run(|p| {
|
CleanUp::new(&path).run(|p| {
|
||||||
// Testing initial creation
|
// Testing initial creation
|
||||||
let config = GemlaConfig {
|
let config = GemlaConfig {
|
||||||
generations_per_node: 10,
|
generations_per_height: 10,
|
||||||
overwrite: true
|
overwrite: true
|
||||||
};
|
};
|
||||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
|
||||||
|
|
Loading…
Add table
Reference in a new issue